diff --git a/cmd/aa/main.go b/cmd/aa/main.go index 5d32e9331..b0737de77 100644 --- a/cmd/aa/main.go +++ b/cmd/aa/main.go @@ -8,6 +8,9 @@ import ( "flag" "fmt" "os" + "os/exec" + "regexp" + "slices" "strings" "github.com/roddhjav/apparmor.d/pkg/aa" @@ -15,12 +18,14 @@ import ( "github.com/roddhjav/apparmor.d/pkg/paths" ) -const usage = `aa [-h] [--lint | --format | --tree] [-s] [-F file] [profiles...] +const usage = `aa [-h] [--lint | --format | --tree | --complain | --enfore] [-s] [-F file] [profiles...] Various AppArmor profiles development tools Options: -h, --help Show this help message and exit. + -e, --enforce Switch the given profile(s) to enforce mode. + -c, --complain Switch the given profile(s) to complain mode. -f, --format Format the AppArmor profiles. -l, --lint Lint the AppArmor profiles. -t, --tree Generate a tree of visited profiles. @@ -31,12 +36,19 @@ Options: // Command line options var ( - help bool - path string - systemd bool - lint bool - format bool - tree bool + help bool + path string + systemd bool + enforce bool + complain bool + lint bool + format bool + tree bool +) + +var ( + regFlags = regexp.MustCompile(`flags=\(([^)]+)\) `) + regProfileHeader = regexp.MustCompile(` {\n`) ) type kind uint8 @@ -60,6 +72,10 @@ func init() { flag.StringVar(&path, "file", "", "Set a logfile or a suffix to the default log file.") flag.BoolVar(&systemd, "s", false, "Parse systemd logs from journalctl.") flag.BoolVar(&systemd, "systemd", false, "Parse systemd logs from journalctl.") + flag.BoolVar(&enforce, "e", false, "Switch the given profile to enforce mode.") + flag.BoolVar(&enforce, "enforce", false, "Switch the given profile to enforce mode.") + flag.BoolVar(&complain, "c", false, "Switch the given profile to complain mode.") + flag.BoolVar(&complain, "complain", false, "Switch the given profile to complain mode.") } func getIndentationLevel(input string) int { @@ -111,7 +127,7 @@ func formatFile(kind kind, profile string) (string, error) { for idx, rules := range rulesByParagraph { aa.IndentationLevel = getIndentationLevel(paragraphs[idx]) rules = rules.Merge().Sort().Format() - profile = strings.ReplaceAll(profile, paragraphs[idx], rules.String()+"\n") + fmt.Printf(rules.String() + "\n") } return profile, nil } @@ -152,17 +168,95 @@ func aaFormat(files paths.PathList) error { return nil } +func aaLint(files paths.PathList) error { + for _, file := range files { + fmt.Printf("wip: %v\n", file) + } + return nil +} + +func setFlag(profile string, flag string) (string, error) { + f := aa.DefaultTunables() + if _, err := f.Parse(profile); err != nil { + return profile, err + } + + flags := f.GetDefaultProfile().Flags + switch flag { + case "enforce": + if len(flags) == 0 || slices.Contains(flags, "enforce") { + return profile, nil // Nothing to do + } + idx := slices.Index(flags, "complain") + if idx == -1 { + return profile, nil // No complain flag, nothing to do + } + flags = slices.Delete(flags, idx, idx+1) + + case "complain": + if slices.Contains(flags, "complain") { + return profile, nil // Nothing to do + } + flags = append(flags, "complain") + + default: + return profile, fmt.Errorf("unknown flag: %s", flag) + } + strFlags := " flags=(" + strings.Join(flags, ",") + ") {\n" + + // Remove all flags definition, then the new flags + profile = regFlags.ReplaceAllLiteralString(profile, "") + if len(flags) > 0 { + profile = regProfileHeader.ReplaceAllLiteralString(profile, strFlags) + } + return profile, nil +} + +func aaSetFlag(files paths.PathList, flag string) error { + for _, file := range files { + profile, err := file.ReadFileAsString() + if err != nil { + return err + } + profile, err = setFlag(profile, flag) + if err != nil { + return err + } + if err = file.WriteFile([]byte(profile)); err != nil { + return err + } + if err = reloadProfile(file); err != nil { + return err + } + } + return nil +} + func aaTree() error { return nil } +func reloadProfile(file *paths.Path) error { + cmd := exec.Command("apparmor_parser", "--replace", file.String()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("apparmor_parser failed: %w", err) + } + return nil +} + func pathsFromArgs() (paths.PathList, error) { res := paths.PathList{} for _, arg := range flag.Args() { path := paths.New(arg) switch { case !path.Exist(): - return nil, fmt.Errorf("file %s not found", path) + if aa.MagicRoot.Join(arg).Exist() { + res = append(res, aa.MagicRoot.Join(arg)) + } else { + return nil, fmt.Errorf("file %s not found", path) + } case path.IsDir(): files, err := path.ReadDirRecursiveFiltered(nil, paths.FilterOutDirectories(), @@ -190,7 +284,26 @@ func main() { var err error var files paths.PathList switch { + case enforce: + files, err = pathsFromArgs() + if err != nil { + logging.Fatal("%s", err.Error()) + } + err = aaSetFlag(files, "enforce") + + case complain: + files, err = pathsFromArgs() + if err != nil { + logging.Fatal("%s", err.Error()) + } + err = aaSetFlag(files, "complain") + case lint: + files, err = pathsFromArgs() + if err != nil { + logging.Fatal("%s", err.Error()) + } + err = aaLint(files) case format: files, err = pathsFromArgs()