From 4cbacc186c6037e58ce248715942e0a7b008f871 Mon Sep 17 00:00:00 2001 From: Alexandre Pujol Date: Wed, 19 Jun 2024 18:34:58 +0100 Subject: [PATCH] feat(aa): rule interface: replace less & equal by the compare method. - set a new alphabet order to sort AARE based string. - unify compare function for all rules - handle some special sort order, eg: base include --- pkg/aa/all.go | 8 ++--- pkg/aa/base.go | 32 +++-------------- pkg/aa/blocks.go | 9 ++--- pkg/aa/capability.go | 16 +++------ pkg/aa/change_profile.go | 22 +++++------- pkg/aa/convert.go | 40 ++++++++++++++++++++-- pkg/aa/dbus.go | 47 ++++++++++--------------- pkg/aa/file.go | 62 ++++++++++++++------------------- pkg/aa/io_uring.go | 18 ++++------ pkg/aa/mount.go | 74 +++++++++++++--------------------------- pkg/aa/mqueue.go | 22 +++++------- pkg/aa/network.go | 39 +++++++++------------ pkg/aa/pivot_root.go | 23 +++++-------- pkg/aa/preamble.go | 72 ++++++++++++++------------------------ pkg/aa/profile.go | 16 +++------ pkg/aa/ptrace.go | 19 ++++------- pkg/aa/rlimit.go | 17 ++++----- pkg/aa/rules.go | 17 +++------ pkg/aa/signal.go | 23 +++++-------- pkg/aa/template.go | 8 ++++- pkg/aa/unix.go | 50 +++++++++++---------------- pkg/aa/userns.go | 15 ++++---- 22 files changed, 250 insertions(+), 399 deletions(-) diff --git a/pkg/aa/all.go b/pkg/aa/all.go index ba23aa10e..130d25c04 100644 --- a/pkg/aa/all.go +++ b/pkg/aa/all.go @@ -16,12 +16,8 @@ func (r *All) Validate() error { return nil } -func (r *All) Less(other any) bool { - return false -} - -func (r *All) Equals(other any) bool { - return false +func (r *All) Compare(other Rule) int { + return 0 } func (r *All) String() string { diff --git a/pkg/aa/base.go b/pkg/aa/base.go index 8fdae72c3..b65e81bb6 100644 --- a/pkg/aa/base.go +++ b/pkg/aa/base.go @@ -76,26 +76,6 @@ func newRuleFromLog(log map[string]string) RuleBase { } } -func (r RuleBase) Less(other any) bool { - return false -} - -func (r RuleBase) Equals(other any) bool { - return false -} - -func (r RuleBase) String() string { - return renderTemplate(r.Kind(), r) -} - -func (r RuleBase) Constraint() constraint { - return anyKind -} - -func (r RuleBase) Kind() Kind { - return COMMENT -} - type Qualifier struct { Audit bool AccessType string @@ -109,13 +89,9 @@ func newQualifierFromLog(log map[string]string) Qualifier { return Qualifier{Audit: audit} } -func (r Qualifier) Less(other Qualifier) bool { - if r.Audit != other.Audit { - return r.Audit +func (r Qualifier) Compare(o Qualifier) int { + if r := compare(r.Audit, o.Audit); r != 0 { + return r } - return r.AccessType < other.AccessType -} - -func (r Qualifier) Equals(other Qualifier) bool { - return r.Audit == other.Audit && r.AccessType == other.AccessType + return compare(r.AccessType, o.AccessType) } diff --git a/pkg/aa/blocks.go b/pkg/aa/blocks.go index 6d1079ace..b3ce0ba77 100644 --- a/pkg/aa/blocks.go +++ b/pkg/aa/blocks.go @@ -19,14 +19,9 @@ func (r *Hat) Validate() error { return nil } -func (p *Hat) Less(other any) bool { +func (r *Hat) Compare(other Rule) int { o, _ := other.(*Hat) - return p.Name < o.Name -} - -func (p *Hat) Equals(other any) bool { - o, _ := other.(*Hat) - return p.Name == o.Name + return compare(r.Name, o.Name) } func (p *Hat) String() string { diff --git a/pkg/aa/capability.go b/pkg/aa/capability.go index 44508e0e0..285f473e5 100644 --- a/pkg/aa/capability.go +++ b/pkg/aa/capability.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" ) const CAPABILITY Kind = "capability" @@ -47,19 +46,12 @@ func (r *Capability) Validate() error { return nil } -func (r *Capability) Less(other any) bool { +func (r *Capability) Compare(other Rule) int { o, _ := other.(*Capability) - for i := 0; i < len(r.Names) && i < len(o.Names); i++ { - if r.Names[i] != o.Names[i] { - return r.Names[i] < o.Names[i] - } + if res := compare(r.Names, o.Names); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Capability) Equals(other any) bool { - o, _ := other.(*Capability) - return slices.Equal(r.Names, o.Names) && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Capability) String() string { diff --git a/pkg/aa/change_profile.go b/pkg/aa/change_profile.go index d5cc618ce..611fe2ce3 100644 --- a/pkg/aa/change_profile.go +++ b/pkg/aa/change_profile.go @@ -39,24 +39,18 @@ func (r *ChangeProfile) Validate() error { return nil } -func (r *ChangeProfile) Less(other any) bool { +func (r *ChangeProfile) Compare(other Rule) int { o, _ := other.(*ChangeProfile) - if r.ExecMode != o.ExecMode { - return r.ExecMode < o.ExecMode + if res := compare(r.ExecMode, o.ExecMode); res != 0 { + return res } - if r.Exec != o.Exec { - return r.Exec < o.Exec + if res := compare(r.Exec, o.Exec); res != 0 { + return res } - if r.ProfileName != o.ProfileName { - return r.ProfileName < o.ProfileName + if res := compare(r.ProfileName, o.ProfileName); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *ChangeProfile) Equals(other any) bool { - o, _ := other.(*ChangeProfile) - return r.ExecMode == o.ExecMode && r.Exec == o.Exec && - r.ProfileName == o.ProfileName && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *ChangeProfile) String() string { diff --git a/pkg/aa/convert.go b/pkg/aa/convert.go index b78dc00b0..0918d04ae 100644 --- a/pkg/aa/convert.go +++ b/pkg/aa/convert.go @@ -19,9 +19,43 @@ func Must[T any](v T, err error) T { return v } -// cmpFileAccess compares two access strings for file rules. +func boolToInt(b bool) int { + if b { + return 1 + } + return 0 +} + +func compare(a, b any) int { + switch a := a.(type) { + case int: + return a - b.(int) + case string: + a = strings.ToLower(a) + b := strings.ToLower(b.(string)) + if a == b { + return 0 + } + for i := 0; i < len(a) && i < len(b); i++ { + if a[i] != b[i] { + return stringWeights[a[i]] - stringWeights[b[i]] + } + } + return len(a) - len(b) + case []string: + return slices.CompareFunc(a, b.([]string), func(s1, s2 string) int { + return compare(s1, s2) + }) + case bool: + return boolToInt(a) - boolToInt(b.(bool)) + default: + panic("compare: unsupported type") + } +} + +// compareFileAccess compares two access strings for file rules. // It is aimed to be used in slices.SortFunc. -func cmpFileAccess(i, j string) int { +func compareFileAccess(i, j string) int { if slices.Contains(requirements[FILE]["access"], i) && slices.Contains(requirements[FILE]["access"], j) { return requirementsWeights[FILE]["access"][i] - requirementsWeights[FILE]["access"][j] @@ -115,6 +149,6 @@ func toAccess(kind Kind, input string) ([]string, error) { return toValues(kind, "access", input) } - slices.SortFunc(res, cmpFileAccess) + slices.SortFunc(res, compareFileAccess) return slices.Compact(res), nil } diff --git a/pkg/aa/dbus.go b/pkg/aa/dbus.go index 56edd7977..afddd3ef5 100644 --- a/pkg/aa/dbus.go +++ b/pkg/aa/dbus.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" ) const DBUS Kind = "dbus" @@ -63,43 +62,33 @@ func (r *Dbus) Validate() error { return validateValues(r.Kind(), "bus", []string{r.Bus}) } -func (r *Dbus) Less(other any) bool { +func (r *Dbus) Compare(other Rule) int { o, _ := other.(*Dbus) - for i := 0; i < len(r.Access) && i < len(o.Access); i++ { - if r.Access[i] != o.Access[i] { - return r.Access[i] < o.Access[i] - } + if res := compare(r.Access, o.Access); res != 0 { + return res } - if r.Bus != o.Bus { - return r.Bus < o.Bus + if res := compare(r.Bus, o.Bus); res != 0 { + return res } - if r.Name != o.Name { - return r.Name < o.Name + if res := compare(r.Name, o.Name); res != 0 { + return res } - if r.Path != o.Path { - return r.Path < o.Path + if res := compare(r.Path, o.Path); res != 0 { + return res } - if r.Interface != o.Interface { - return r.Interface < o.Interface + if res := compare(r.Interface, o.Interface); res != 0 { + return res } - if r.Member != o.Member { - return r.Member < o.Member + if res := compare(r.Member, o.Member); res != 0 { + return res } - if r.PeerName != o.PeerName { - return r.PeerName < o.PeerName + if res := compare(r.PeerName, o.PeerName); res != 0 { + return res } - if r.PeerLabel != o.PeerLabel { - return r.PeerLabel < o.PeerLabel + if res := compare(r.PeerLabel, o.PeerLabel); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Dbus) Equals(other any) bool { - o, _ := other.(*Dbus) - return slices.Equal(r.Access, o.Access) && r.Bus == o.Bus && r.Name == o.Name && - r.Path == o.Path && r.Interface == o.Interface && - r.Member == o.Member && r.PeerName == o.PeerName && - r.PeerLabel == o.PeerLabel && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Dbus) String() string { diff --git a/pkg/aa/file.go b/pkg/aa/file.go index d1ea214a1..66a577fe0 100644 --- a/pkg/aa/file.go +++ b/pkg/aa/file.go @@ -68,32 +68,27 @@ func (r *File) Validate() error { return nil } -func (r *File) Less(other any) bool { +func (r *File) Compare(other Rule) int { o, _ := other.(*File) + letterR := getLetterIn(fileAlphabet, r.Path) letterO := getLetterIn(fileAlphabet, o.Path) if fileWeights[letterR] != fileWeights[letterO] && letterR != "" && letterO != "" { - return fileWeights[letterR] < fileWeights[letterO] + return fileWeights[letterR] - fileWeights[letterO] } - if r.Path != o.Path { - return r.Path < o.Path + if res := compare(r.Owner, o.Owner); res != 0 { + return res } - if o.Owner != r.Owner { - return r.Owner + if res := compare(r.Path, o.Path); res != 0 { + return res } - if len(r.Access) != len(o.Access) { - return len(r.Access) < len(o.Access) + if res := compare(r.Access, o.Access); res != 0 { + return res } - if r.Target != o.Target { - return r.Target < o.Target + if res := compare(r.Target, o.Target); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *File) Equals(other any) bool { - o, _ := other.(*File) - return r.Path == o.Path && slices.Equal(r.Access, o.Access) && r.Owner == o.Owner && - r.Target == o.Target && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *File) String() string { @@ -131,27 +126,22 @@ func (r *Link) Validate() error { return nil } -func (r *Link) Less(other any) bool { +func (r *Link) Compare(other Rule) int { o, _ := other.(*Link) - if r.Path != o.Path { - return r.Path < o.Path - } - if o.Owner != r.Owner { - return r.Owner - } - if r.Target != o.Target { - return r.Target < o.Target - } - if r.Subset != o.Subset { - return r.Subset - } - return r.Qualifier.Less(o.Qualifier) -} -func (r *Link) Equals(other any) bool { - o, _ := other.(*Link) - return r.Subset == o.Subset && r.Owner == o.Owner && r.Path == o.Path && - r.Target == o.Target && r.Qualifier.Equals(o.Qualifier) + if res := compare(r.Owner, o.Owner); res != 0 { + return res + } + if res := compare(r.Path, o.Path); res != 0 { + return res + } + if res := compare(r.Target, o.Target); res != 0 { + return res + } + if res := compare(r.Subset, o.Subset); res != 0 { + return res + } + return r.Qualifier.Compare(o.Qualifier) } func (r *Link) String() string { diff --git a/pkg/aa/io_uring.go b/pkg/aa/io_uring.go index 42297a1f1..78c3b1227 100644 --- a/pkg/aa/io_uring.go +++ b/pkg/aa/io_uring.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" ) const IOURING Kind = "io_uring" @@ -40,20 +39,15 @@ func (r *IOUring) Validate() error { return nil } -func (r *IOUring) Less(other any) bool { +func (r *IOUring) Compare(other Rule) int { o, _ := other.(*IOUring) - if len(r.Access) != len(o.Access) { - return len(r.Access) < len(o.Access) + if res := compare(r.Access, o.Access); res != 0 { + return res } - if r.Label != o.Label { - return r.Label < o.Label + if res := compare(r.Label, o.Label); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *IOUring) Equals(other any) bool { - o, _ := other.(*IOUring) - return slices.Equal(r.Access, o.Access) && r.Label == o.Label && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *IOUring) String() string { diff --git a/pkg/aa/mount.go b/pkg/aa/mount.go index e131e54cf..a81401e3c 100644 --- a/pkg/aa/mount.go +++ b/pkg/aa/mount.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" ) const ( @@ -48,15 +47,11 @@ func (m MountConditions) Validate() error { return validateValues(MOUNT, "flags", m.Options) } -func (m MountConditions) Less(other MountConditions) bool { - if m.FsType != other.FsType { - return m.FsType < other.FsType +func (m MountConditions) Compare(other MountConditions) int { + if res := compare(m.FsType, other.FsType); res != 0 { + return res } - return len(m.Options) < len(other.Options) -} - -func (m MountConditions) Equals(other MountConditions) bool { - return m.FsType == other.FsType && slices.Equal(m.Options, other.Options) + return compare(m.Options, other.Options) } type Mount struct { @@ -84,25 +79,18 @@ func (r *Mount) Validate() error { return nil } -func (r *Mount) Less(other any) bool { +func (r *Mount) Compare(other Rule) int { o, _ := other.(*Mount) - if r.Source != o.Source { - return r.Source < o.Source + if res := compare(r.Source, o.Source); res != 0 { + return res } - if r.MountPoint != o.MountPoint { - return r.MountPoint < o.MountPoint + if res := compare(r.MountPoint, o.MountPoint); res != 0 { + return res } - if r.MountConditions.Equals(o.MountConditions) { - return r.MountConditions.Less(o.MountConditions) + if res := r.MountConditions.Compare(o.MountConditions); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Mount) Equals(other any) bool { - o, _ := other.(*Mount) - return r.Source == o.Source && r.MountPoint == o.MountPoint && - r.MountConditions.Equals(o.MountConditions) && - r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Mount) String() string { @@ -140,22 +128,15 @@ func (r *Umount) Validate() error { return nil } -func (r *Umount) Less(other any) bool { +func (r *Umount) Compare(other Rule) int { o, _ := other.(*Umount) - if r.MountPoint != o.MountPoint { - return r.MountPoint < o.MountPoint + if res := compare(r.MountPoint, o.MountPoint); res != 0 { + return res } - if r.MountConditions.Equals(o.MountConditions) { - return r.MountConditions.Less(o.MountConditions) + if res := r.MountConditions.Compare(o.MountConditions); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Umount) Equals(other any) bool { - o, _ := other.(*Umount) - return r.MountPoint == o.MountPoint && - r.MountConditions.Equals(o.MountConditions) && - r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Umount) String() string { @@ -193,22 +174,15 @@ func (r *Remount) Validate() error { return nil } -func (r *Remount) Less(other any) bool { +func (r *Remount) Compare(other Rule) int { o, _ := other.(*Remount) - if r.MountPoint != o.MountPoint { - return r.MountPoint < o.MountPoint + if res := compare(r.MountPoint, o.MountPoint); res != 0 { + return res } - if r.MountConditions.Equals(o.MountConditions) { - return r.MountConditions.Less(o.MountConditions) + if res := r.MountConditions.Compare(o.MountConditions); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Remount) Equals(other any) bool { - o, _ := other.(*Remount) - return r.MountPoint == o.MountPoint && - r.MountConditions.Equals(o.MountConditions) && - r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Remount) String() string { diff --git a/pkg/aa/mqueue.go b/pkg/aa/mqueue.go index 9fc5f2607..9809039c2 100644 --- a/pkg/aa/mqueue.go +++ b/pkg/aa/mqueue.go @@ -58,24 +58,18 @@ func (r *Mqueue) Validate() error { return nil } -func (r *Mqueue) Less(other any) bool { +func (r *Mqueue) Compare(other Rule) int { o, _ := other.(*Mqueue) - if len(r.Access) != len(o.Access) { - return len(r.Access) < len(o.Access) + if res := compare(r.Access, o.Access); res != 0 { + return res } - if r.Type != o.Type { - return r.Type < o.Type + if res := compare(r.Type, o.Type); res != 0 { + return res } - if r.Label != o.Label { - return r.Label < o.Label + if res := compare(r.Label, o.Label); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Mqueue) Equals(other any) bool { - o, _ := other.(*Mqueue) - return slices.Equal(r.Access, o.Access) && r.Type == o.Type && r.Label == o.Label && - r.Name == o.Name && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Mqueue) String() string { diff --git a/pkg/aa/network.go b/pkg/aa/network.go index 8d01b0bab..4970bc972 100644 --- a/pkg/aa/network.go +++ b/pkg/aa/network.go @@ -46,14 +46,14 @@ func newAddressExprFromLog(log map[string]string) AddressExpr { } } -func (r AddressExpr) Less(other AddressExpr) bool { - if r.Source != other.Source { - return r.Source < other.Source +func (r AddressExpr) Compare(other AddressExpr) int { + if res := compare(r.Source, other.Source); res != 0 { + return res } - if r.Destination != other.Destination { - return r.Destination < other.Destination + if res := compare(r.Destination, other.Destination); res != 0 { + return res } - return r.Port < other.Port + return compare(r.Port, other.Port) } func (r AddressExpr) Equals(other AddressExpr) bool { @@ -94,28 +94,21 @@ func (r *Network) Validate() error { return nil } -func (r *Network) Less(other any) bool { +func (r *Network) Compare(other Rule) int { o, _ := other.(*Network) - if r.Domain != o.Domain { - return r.Domain < o.Domain + if res := compare(r.Domain, o.Domain); res != 0 { + return res } - if r.Type != o.Type { - return r.Type < o.Type + if res := compare(r.Type, o.Type); res != 0 { + return res } - if r.Protocol != o.Protocol { - return r.Protocol < o.Protocol + if res := compare(r.Protocol, o.Protocol); res != 0 { + return res } - if r.AddressExpr.Less(o.AddressExpr) { - return r.AddressExpr.Less(o.AddressExpr) + if res := r.AddressExpr.Compare(o.AddressExpr); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Network) Equals(other any) bool { - o, _ := other.(*Network) - return r.Domain == o.Domain && r.Type == o.Type && - r.Protocol == o.Protocol && r.AddressExpr.Equals(o.AddressExpr) && - r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Network) String() string { diff --git a/pkg/aa/pivot_root.go b/pkg/aa/pivot_root.go index 93847bf65..0adeb62cb 100644 --- a/pkg/aa/pivot_root.go +++ b/pkg/aa/pivot_root.go @@ -28,25 +28,18 @@ func (r *PivotRoot) Validate() error { return nil } -func (r *PivotRoot) Less(other any) bool { +func (r *PivotRoot) Compare(other Rule) int { o, _ := other.(*PivotRoot) - if r.OldRoot != o.OldRoot { - return r.OldRoot < o.OldRoot + if res := compare(r.OldRoot, o.OldRoot); res != 0 { + return res } - if r.NewRoot != o.NewRoot { - return r.NewRoot < o.NewRoot + if res := compare(r.NewRoot, o.NewRoot); res != 0 { + return res } - if r.TargetProfile != o.TargetProfile { - return r.TargetProfile < o.TargetProfile + if res := compare(r.TargetProfile, o.TargetProfile); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *PivotRoot) Equals(other any) bool { - o, _ := other.(*PivotRoot) - return r.OldRoot == o.OldRoot && r.NewRoot == o.NewRoot && - r.TargetProfile == o.TargetProfile && - r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *PivotRoot) String() string { diff --git a/pkg/aa/preamble.go b/pkg/aa/preamble.go index d8cb58131..552b16481 100644 --- a/pkg/aa/preamble.go +++ b/pkg/aa/preamble.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" "strings" ) @@ -34,12 +33,8 @@ func (r *Comment) Validate() error { return nil } -func (r *Comment) Less(other any) bool { - return false -} - -func (r *Comment) Equals(other any) bool { - return false +func (r *Comment) Compare(other Rule) int { + return 0 } func (r *Comment) String() string { @@ -93,17 +88,12 @@ func (r *Abi) Validate() error { return nil } -func (r *Abi) Less(other any) bool { +func (r *Abi) Compare(other Rule) int { o, _ := other.(*Abi) - if r.Path != o.Path { - return r.Path < o.Path + if res := compare(r.Path, o.Path); res != 0 { + return res } - return r.IsMagic == o.IsMagic -} - -func (r *Abi) Equals(other any) bool { - o, _ := other.(*Abi) - return r.Path == o.Path && r.IsMagic == o.IsMagic + return compare(r.IsMagic, o.IsMagic) } func (r *Abi) String() string { @@ -145,17 +135,12 @@ func (r *Alias) Validate() error { return nil } -func (r Alias) Less(other any) bool { +func (r *Alias) Compare(other Rule) int { o, _ := other.(*Alias) - if r.Path != o.Path { - return r.Path < o.Path + if res := compare(r.Path, o.Path); res != 0 { + return res } - return r.RewrittenPath < o.RewrittenPath -} - -func (r Alias) Equals(other any) bool { - o, _ := other.(*Alias) - return r.Path == o.Path && r.RewrittenPath == o.RewrittenPath + return compare(r.RewrittenPath, o.RewrittenPath) } func (r *Alias) String() string { @@ -216,20 +201,22 @@ func (r *Include) Validate() error { return nil } -func (r *Include) Less(other any) bool { +func (r *Include) Compare(other Rule) int { + const base = "abstractions/base" o, _ := other.(*Include) - if r.Path == o.Path { - return r.Path < o.Path + if res := compare(r.Path, o.Path); res != 0 { + if r.Path == base { + return -1 + } + if o.Path == base { + return 1 + } + return res } - if r.IsMagic != o.IsMagic { - return r.IsMagic + if res := compare(r.IsMagic, o.IsMagic); res != 0 { + return res } - return r.IfExists -} - -func (r *Include) Equals(other any) bool { - o, _ := other.(*Include) - return r.Path == o.Path && r.IsMagic == o.IsMagic && r.IfExists == o.IfExists + return compare(r.IfExists, o.IfExists) } func (r *Include) String() string { @@ -284,17 +271,8 @@ func (r *Variable) Validate() error { return nil } -func (r *Variable) Less(other any) bool { - o, _ := other.(*Variable) - if r.Name != o.Name { - return r.Name < o.Name - } - return len(r.Values) < len(o.Values) -} - -func (r *Variable) Equals(other any) bool { - o, _ := other.(*Variable) - return r.Name == o.Name && slices.Equal(r.Values, o.Values) +func (r *Variable) Compare(other Rule) int { + return 0 } func (r *Variable) String() string { diff --git a/pkg/aa/profile.go b/pkg/aa/profile.go index 97349a456..138ce6578 100644 --- a/pkg/aa/profile.go +++ b/pkg/aa/profile.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "maps" "slices" "strings" ) @@ -96,19 +95,12 @@ func (r *Profile) Validate() error { return r.Rules.Validate() } -func (p *Profile) Less(other any) bool { +func (r *Profile) Compare(other Rule) int { o, _ := other.(*Profile) - if p.Name != o.Name { - return p.Name < o.Name + if res := compare(r.Name, o.Name); res != 0 { + return res } - return len(p.Attachments) < len(o.Attachments) -} - -func (p *Profile) Equals(other any) bool { - o, _ := other.(*Profile) - return p.Name == o.Name && slices.Equal(p.Attachments, o.Attachments) && - maps.Equal(p.Attributes, o.Attributes) && - slices.Equal(p.Flags, o.Flags) + return compare(r.Attachments, o.Attachments) } func (p *Profile) String() string { diff --git a/pkg/aa/ptrace.go b/pkg/aa/ptrace.go index 00eca5888..5276d315f 100644 --- a/pkg/aa/ptrace.go +++ b/pkg/aa/ptrace.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" ) const PTRACE Kind = "ptrace" @@ -42,21 +41,15 @@ func (r *Ptrace) Validate() error { return nil } -func (r *Ptrace) Less(other any) bool { +func (r *Ptrace) Compare(other Rule) int { o, _ := other.(*Ptrace) - if len(r.Access) != len(o.Access) { - return len(r.Access) < len(o.Access) + if res := compare(r.Access, o.Access); res != 0 { + return res } - if r.Peer != o.Peer { - return r.Peer == o.Peer + if res := compare(r.Peer, o.Peer); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Ptrace) Equals(other any) bool { - o, _ := other.(*Ptrace) - return slices.Equal(r.Access, o.Access) && r.Peer == o.Peer && - r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Ptrace) String() string { diff --git a/pkg/aa/rlimit.go b/pkg/aa/rlimit.go index ddb70710c..415f443bf 100644 --- a/pkg/aa/rlimit.go +++ b/pkg/aa/rlimit.go @@ -43,20 +43,15 @@ func (r *Rlimit) Validate() error { return nil } -func (r *Rlimit) Less(other any) bool { +func (r *Rlimit) Compare(other Rule) int { o, _ := other.(*Rlimit) - if r.Key != o.Key { - return r.Key < o.Key + if res := compare(r.Key, o.Key); res != 0 { + return res } - if r.Op != o.Op { - return r.Op < o.Op + if res := compare(r.Op, o.Op); res != 0 { + return res } - return r.Value < o.Value -} - -func (r *Rlimit) Equals(other any) bool { - o, _ := other.(*Rlimit) - return r.Key == o.Key && r.Op == o.Op && r.Value == o.Value + return compare(r.Value, o.Value) } func (r *Rlimit) String() string { diff --git a/pkg/aa/rules.go b/pkg/aa/rules.go index 7aeb9752c..cdd36fc7a 100644 --- a/pkg/aa/rules.go +++ b/pkg/aa/rules.go @@ -35,8 +35,7 @@ func (k Kind) Tok() string { // Rule generic interface for all AppArmor rules type Rule interface { Validate() error - Less(other any) bool - Equals(other any) bool + Compare(other Rule) int String() string Constraint() constraint Kind() Kind @@ -66,7 +65,7 @@ func (r Rules) Index(item Rule) int { if rule == nil { continue } - if rule.Kind() == item.Kind() && rule.Equals(item) { + if rule.Kind() == item.Kind() && rule.Compare(item) == 0 { return idx } } @@ -153,7 +152,7 @@ func (r Rules) Merge() Rules { } // If rules are identical, merge them - if r[i].Equals(r[j]) { + if r[i].Compare(r[j]) == 0 { r = r.Delete(j) j-- continue @@ -166,7 +165,7 @@ func (r Rules) Merge() Rules { fileJ := r[j].(*File) if fileI.Path == fileJ.Path { fileI.Access = append(fileI.Access, fileJ.Access...) - slices.SortFunc(fileI.Access, cmpFileAccess) + slices.SortFunc(fileI.Access, compareFileAccess) fileI.Access = slices.Compact(fileI.Access) r = r.Delete(j) j-- @@ -192,13 +191,7 @@ func (r Rules) Sort() Rules { } return ruleWeights[kindOfA] - ruleWeights[kindOfB] } - if a.Equals(b) { - return 0 - } - if a.Less(b) { - return -1 - } - return 1 + return a.Compare(b) }) return r } diff --git a/pkg/aa/signal.go b/pkg/aa/signal.go index 4e7ce91cd..53dcc3a5a 100644 --- a/pkg/aa/signal.go +++ b/pkg/aa/signal.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" ) const SIGNAL Kind = "signal" @@ -60,24 +59,18 @@ func (r *Signal) Validate() error { return nil } -func (r *Signal) Less(other any) bool { +func (r *Signal) Compare(other Rule) int { o, _ := other.(*Signal) - if len(r.Access) != len(o.Access) { - return len(r.Access) < len(o.Access) + if res := compare(r.Access, o.Access); res != 0 { + return res } - if len(r.Set) != len(o.Set) { - return len(r.Set) < len(o.Set) + if res := compare(r.Set, o.Set); res != 0 { + return res } - if r.Peer != o.Peer { - return r.Peer < o.Peer + if res := compare(r.Peer, o.Peer); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Signal) Equals(other any) bool { - o, _ := other.(*Signal) - return slices.Equal(r.Access, o.Access) && slices.Equal(r.Set, o.Set) && - r.Peer == o.Peer && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Signal) String() string { diff --git a/pkg/aa/template.go b/pkg/aa/template.go index e20388cb6..28a8d3e4d 100644 --- a/pkg/aa/template.go +++ b/pkg/aa/template.go @@ -117,6 +117,12 @@ var ( } fileWeights = generateWeights(fileAlphabet) + // The order AARE should be sorted + stringAlphabet = []byte( + "!\"#$%&'(){}[]*+,-./:;<=>?@\\^_`|~0123456789abcdefghijklmnopqrstuvwxyz", + ) + stringWeights = generateWeights(stringAlphabet) + // The order the rule values (access, type, domains, etc) should be sorted requirements = map[Kind]requirement{} requirementsWeights map[Kind]map[string]map[string]int @@ -155,7 +161,7 @@ func renderTemplate(name Kind, data any) string { return res.String() } -func generateWeights[T Kind | string](alphabet []T) map[T]int { +func generateWeights[T comparable](alphabet []T) map[T]int { res := make(map[T]int, len(alphabet)) for i, r := range alphabet { res[r] = i diff --git a/pkg/aa/unix.go b/pkg/aa/unix.go index b868459b0..41afbad6c 100644 --- a/pkg/aa/unix.go +++ b/pkg/aa/unix.go @@ -6,7 +6,6 @@ package aa import ( "fmt" - "slices" ) const UNIX Kind = "unix" @@ -58,45 +57,36 @@ func (r *Unix) Validate() error { return nil } -func (r *Unix) Less(other any) bool { +func (r *Unix) Compare(other Rule) int { o, _ := other.(*Unix) - if len(r.Access) != len(o.Access) { - return len(r.Access) < len(o.Access) + if res := compare(r.Access, o.Access); res != 0 { + return res } - if r.Type != o.Type { - return r.Type < o.Type + if res := compare(r.Type, o.Type); res != 0 { + return res } - if r.Protocol != o.Protocol { - return r.Protocol < o.Protocol + if res := compare(r.Protocol, o.Protocol); res != 0 { + return res } - if r.Address != o.Address { - return r.Address < o.Address + if res := compare(r.Address, o.Address); res != 0 { + return res } - if r.Label != o.Label { - return r.Label < o.Label + if res := compare(r.Label, o.Label); res != 0 { + return res } - if r.Attr != o.Attr { - return r.Attr < o.Attr + if res := compare(r.Attr, o.Attr); res != 0 { + return res } - if r.Opt != o.Opt { - return r.Opt < o.Opt + if res := compare(r.Opt, o.Opt); res != 0 { + return res } - if r.PeerLabel != o.PeerLabel { - return r.PeerLabel < o.PeerLabel + if res := compare(r.PeerLabel, o.PeerLabel); res != 0 { + return res } - if r.PeerAddr != o.PeerAddr { - return r.PeerAddr < o.PeerAddr + if res := compare(r.PeerAddr, o.PeerAddr); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Unix) Equals(other any) bool { - o, _ := other.(*Unix) - return slices.Equal(r.Access, o.Access) && r.Type == o.Type && - r.Protocol == o.Protocol && r.Address == o.Address && - r.Label == o.Label && r.Attr == o.Attr && r.Opt == o.Opt && - r.PeerLabel == o.PeerLabel && r.PeerAddr == o.PeerAddr && - r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Unix) String() string { diff --git a/pkg/aa/userns.go b/pkg/aa/userns.go index 4c678f3d0..8a1338550 100644 --- a/pkg/aa/userns.go +++ b/pkg/aa/userns.go @@ -4,6 +4,8 @@ package aa +import "fmt" + const USERNS Kind = "userns" type Userns struct { @@ -24,17 +26,12 @@ func (r *Userns) Validate() error { return nil } -func (r *Userns) Less(other any) bool { +func (r *Userns) Compare(other Rule) int { o, _ := other.(*Userns) - if r.Create != o.Create { - return r.Create + if res := compare(r.Create, o.Create); res != 0 { + return res } - return r.Qualifier.Less(o.Qualifier) -} - -func (r *Userns) Equals(other any) bool { - o, _ := other.(*Userns) - return r.Create == o.Create && r.Qualifier.Equals(o.Qualifier) + return r.Qualifier.Compare(o.Qualifier) } func (r *Userns) String() string {