From 992cab1fa49c4a14219393a138d083e92c5833f2 Mon Sep 17 00:00:00 2001 From: Alexandre Pujol Date: Thu, 30 May 2024 12:32:30 +0100 Subject: [PATCH] feat(aa): move conversion function to its own file & add unit tests. --- pkg/aa/convert.go | 120 +++++++++++++++++++++++++++++++++++++++++ pkg/aa/convert_test.go | 92 +++++++++++++++++++++++++++++++ pkg/aa/file.go | 17 ------ pkg/aa/rules.go | 90 ------------------------------- 4 files changed, 212 insertions(+), 107 deletions(-) create mode 100644 pkg/aa/convert.go create mode 100644 pkg/aa/convert_test.go diff --git a/pkg/aa/convert.go b/pkg/aa/convert.go new file mode 100644 index 000000000..b78dc00b0 --- /dev/null +++ b/pkg/aa/convert.go @@ -0,0 +1,120 @@ +// apparmor.d - Full set of apparmor profiles +// Copyright (C) 2021-2024 Alexandre Pujol +// SPDX-License-Identifier: GPL-2.0-only + +package aa + +import ( + "fmt" + "slices" + "strings" +) + +// Must is a helper that wraps a call to a function returning (any, error) and +// panics if the error is non-nil. +func Must[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} + +// cmpFileAccess compares two access strings for file rules. +// It is aimed to be used in slices.SortFunc. +func cmpFileAccess(i, j string) int { + if slices.Contains(requirements[FILE]["access"], i) && + slices.Contains(requirements[FILE]["access"], j) { + return requirementsWeights[FILE]["access"][i] - requirementsWeights[FILE]["access"][j] + } + if slices.Contains(requirements[FILE]["transition"], i) && + slices.Contains(requirements[FILE]["transition"], j) { + return requirementsWeights[FILE]["transition"][i] - requirementsWeights[FILE]["transition"][j] + } + if slices.Contains(requirements[FILE]["access"], i) { + return -1 + } + return 1 +} + +func validateValues(kind Kind, key string, values []string) error { + for _, v := range values { + if v == "" { + continue + } + if !slices.Contains(requirements[kind][key], v) { + return fmt.Errorf("invalid mode '%s'", v) + } + } + return nil +} + +// Helper function to convert a string to a slice of rule values according to +// the rule requirements as defined in the requirements map. +func toValues(kind Kind, key string, input string) ([]string, error) { + req, ok := requirements[kind][key] + if !ok { + return nil, fmt.Errorf("unrecognized requirement '%s' for rule %s", key, kind) + } + + res := tokenToSlice(input) + for idx := range res { + res[idx] = strings.Trim(res[idx], `" `) + if res[idx] == "" { + res = slices.Delete(res, idx, idx+1) + continue + } + if !slices.Contains(req, res[idx]) { + return nil, fmt.Errorf("unrecognized %s: %s", key, res[idx]) + } + } + slices.SortFunc(res, func(i, j string) int { + return requirementsWeights[kind][key][i] - requirementsWeights[kind][key][j] + }) + return slices.Compact(res), nil +} + +// Helper function to convert an access string to a slice of access according to +// the rule requirements as defined in the requirements map. +func toAccess(kind Kind, input string) ([]string, error) { + var res []string + + switch kind { + case FILE: + raw := strings.Split(input, "") + trans := []string{} + for _, access := range raw { + if slices.Contains(requirements[FILE]["access"], access) { + res = append(res, access) + } else { + trans = append(trans, access) + } + } + + transition := strings.Join(trans, "") + if len(transition) > 0 { + if slices.Contains(requirements[FILE]["transition"], transition) { + res = append(res, transition) + } else { + return nil, fmt.Errorf("unrecognized transition: %s", transition) + } + } + + case FILE + "-log": + raw := strings.Split(input, "") + for _, access := range raw { + if slices.Contains(requirements[FILE]["access"], access) { + res = append(res, access) + } else if maskToAccess[access] != "" { + res = append(res, maskToAccess[access]) + } else { + return nil, fmt.Errorf("toAccess: unrecognized file access '%s' for %s", input, kind) + } + } + + default: + return toValues(kind, "access", input) + } + + slices.SortFunc(res, cmpFileAccess) + return slices.Compact(res), nil +} diff --git a/pkg/aa/convert_test.go b/pkg/aa/convert_test.go new file mode 100644 index 000000000..8a027ffa3 --- /dev/null +++ b/pkg/aa/convert_test.go @@ -0,0 +1,92 @@ +// apparmor.d - Full set of apparmor profiles +// Copyright (C) 2021-2024 Alexandre Pujol +// SPDX-License-Identifier: GPL-2.0-only + +package aa + +import ( + "reflect" + "testing" + + "github.com/k0kubun/pp/v3" +) + +func Test_toAccess(t *testing.T) { + tests := []struct { + name string + kind Kind + inputs []string + wants [][]string + wantsErr []bool + }{ + { + name: "empty", + kind: FILE, + inputs: []string{""}, + wants: [][]string{nil}, + wantsErr: []bool{false}, + }, + { + name: "file", + kind: FILE, + inputs: []string{ + "rPx", "rPUx", "mr", "rm", "rix", "rcx", "rCUx", "rmix", "rwlk", + "mrwkl", "", "r", "x", "w", "wr", "px", "Px", "Ux", "mrwlkPix", + }, + wants: [][]string{ + {"r", "Px"}, {"r", "PUx"}, {"m", "r"}, {"m", "r"}, {"r", "ix"}, + {"r", "cx"}, {"r", "CUx"}, {"m", "r", "ix"}, {"r", "w", "l", "k"}, + {"m", "r", "w", "l", "k"}, nil, {"r"}, {"x"}, {"w"}, {"r", "w"}, + {"px"}, {"Px"}, {"Ux"}, {"m", "r", "w", "l", "k", "Pix"}, + }, + wantsErr: []bool{ + false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, + }, + }, + { + name: "file-log", + kind: FILE + "-log", + inputs: []string{ + "mr", "rm", "x", "rwlk", "mrwkl", "r", "c", "wc", "d", "wr", + }, + wants: [][]string{ + {"m", "r"}, {"m", "r"}, {"ix"}, {"r", "w", "l", "k"}, + {"m", "r", "w", "l", "k"}, {"r"}, {"w"}, {"w"}, {"w"}, {"r", "w"}, + }, + wantsErr: []bool{ + false, false, false, false, false, false, false, false, false, false, + }, + }, + { + name: "signal", + kind: SIGNAL, + inputs: []string{"send receive rw"}, + wants: [][]string{{"rw", "send", "receive"}}, + wantsErr: []bool{false}, + }, + { + name: "ptrace", + kind: PTRACE, + inputs: []string{"readby", "tracedby", "read readby", "r w", "rw", ""}, + wants: [][]string{ + {"readby"}, {"tracedby"}, {"read", "readby"}, {"r", "w"}, {"rw"}, {}, + }, + wantsErr: []bool{false, false, false, false, false, false}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i, input := range tt.inputs { + got, err := toAccess(tt.kind, input) + if (err != nil) != tt.wantsErr[i] { + t.Errorf("toAccess() error = %v, wantErr %v", err, tt.wantsErr[i]) + return + } + if !reflect.DeepEqual(got, tt.wants[i]) { + t.Errorf("toAccess() = %v, want %v", pp.Sprint(got), pp.Sprint(tt.wants[i])) + } + } + }) + } +} diff --git a/pkg/aa/file.go b/pkg/aa/file.go index dd828951b..d1ea214a1 100644 --- a/pkg/aa/file.go +++ b/pkg/aa/file.go @@ -37,23 +37,6 @@ func isOwner(log map[string]string) bool { return false } -// cmpFileAccess compares two access strings for file rules. -// It is aimed to be used in slices.SortFunc. -func cmpFileAccess(i, j string) int { - if slices.Contains(requirements[FILE]["access"], i) && - slices.Contains(requirements[FILE]["access"], j) { - return requirementsWeights[FILE]["access"][i] - requirementsWeights[FILE]["access"][j] - } - if slices.Contains(requirements[FILE]["transition"], i) && - slices.Contains(requirements[FILE]["transition"], j) { - return requirementsWeights[FILE]["transition"][i] - requirementsWeights[FILE]["transition"][j] - } - if slices.Contains(requirements[FILE]["access"], i) { - return -1 - } - return 1 -} - type File struct { RuleBase Qualifier diff --git a/pkg/aa/rules.go b/pkg/aa/rules.go index 5d37ef322..7aeb9752c 100644 --- a/pkg/aa/rules.go +++ b/pkg/aa/rules.go @@ -5,9 +5,7 @@ package aa import ( - "fmt" "slices" - "strings" ) type requirement map[string][]string @@ -237,91 +235,3 @@ func (r Rules) Format() Rules { } return r } - -// Must is a helper that wraps a call to a function returning (any, error) and -// panics if the error is non-nil. -func Must[T any](v T, err error) T { - if err != nil { - panic(err) - } - return v -} - -func validateValues(kind Kind, key string, values []string) error { - for _, v := range values { - if v == "" { - continue - } - if !slices.Contains(requirements[kind][key], v) { - return fmt.Errorf("invalid mode '%s'", v) - } - } - return nil -} - -// Helper function to convert a string to a slice of rule values according to -// the rule requirements as defined in the requirements map. -func toValues(kind Kind, key string, input string) ([]string, error) { - req, ok := requirements[kind][key] - if !ok { - return nil, fmt.Errorf("unrecognized requirement '%s' for rule %s", key, kind) - } - - res := tokenToSlice(input) - for idx := range res { - res[idx] = strings.Trim(res[idx], `" `) - if !slices.Contains(req, res[idx]) { - return nil, fmt.Errorf("unrecognized %s: %s", key, res[idx]) - } - } - slices.SortFunc(res, func(i, j string) int { - return requirementsWeights[kind][key][i] - requirementsWeights[kind][key][j] - }) - return slices.Compact(res), nil -} - -// Helper function to convert an access string to a slice of access according to -// the rule requirements as defined in the requirements map. -func toAccess(kind Kind, input string) ([]string, error) { - var res []string - - switch kind { - case FILE: - raw := strings.Split(input, "") - trans := []string{} - for _, access := range raw { - if slices.Contains(requirements[FILE]["access"], access) { - res = append(res, access) - } else { - trans = append(trans, access) - } - } - - transition := strings.Join(trans, "") - if len(transition) > 0 { - if slices.Contains(requirements[FILE]["transition"], transition) { - res = append(res, transition) - } else { - return nil, fmt.Errorf("unrecognized transition: %s", transition) - } - } - - case FILE + "-log": - raw := strings.Split(input, "") - for _, access := range raw { - if slices.Contains(requirements[FILE]["access"], access) { - res = append(res, access) - } else if maskToAccess[access] != "" { - res = append(res, maskToAccess[access]) - } else { - return nil, fmt.Errorf("toAccess: unrecognized file access '%s'", input) - } - } - - default: - return toValues(kind, "access", input) - } - - slices.SortFunc(res, cmpFileAccess) - return slices.Compact(res), nil -}