feat(aa): rule interface: replace less & equal by the compare method.

- set a new alphabet order to sort AARE based string.
- unify compare function for all rules
- handle some special sort order, eg: base include
This commit is contained in:
Alexandre Pujol 2024-06-19 18:34:58 +01:00
parent 747292e954
commit 4cbacc186c
No known key found for this signature in database
GPG key ID: C5469996F0DF68EC
22 changed files with 250 additions and 399 deletions

View file

@ -16,12 +16,8 @@ func (r *All) Validate() error {
return nil return nil
} }
func (r *All) Less(other any) bool { func (r *All) Compare(other Rule) int {
return false return 0
}
func (r *All) Equals(other any) bool {
return false
} }
func (r *All) String() string { func (r *All) String() string {

View file

@ -76,26 +76,6 @@ func newRuleFromLog(log map[string]string) RuleBase {
} }
} }
func (r RuleBase) Less(other any) bool {
return false
}
func (r RuleBase) Equals(other any) bool {
return false
}
func (r RuleBase) String() string {
return renderTemplate(r.Kind(), r)
}
func (r RuleBase) Constraint() constraint {
return anyKind
}
func (r RuleBase) Kind() Kind {
return COMMENT
}
type Qualifier struct { type Qualifier struct {
Audit bool Audit bool
AccessType string AccessType string
@ -109,13 +89,9 @@ func newQualifierFromLog(log map[string]string) Qualifier {
return Qualifier{Audit: audit} return Qualifier{Audit: audit}
} }
func (r Qualifier) Less(other Qualifier) bool { func (r Qualifier) Compare(o Qualifier) int {
if r.Audit != other.Audit { if r := compare(r.Audit, o.Audit); r != 0 {
return r.Audit return r
} }
return r.AccessType < other.AccessType return compare(r.AccessType, o.AccessType)
}
func (r Qualifier) Equals(other Qualifier) bool {
return r.Audit == other.Audit && r.AccessType == other.AccessType
} }

View file

@ -19,14 +19,9 @@ func (r *Hat) Validate() error {
return nil return nil
} }
func (p *Hat) Less(other any) bool { func (r *Hat) Compare(other Rule) int {
o, _ := other.(*Hat) o, _ := other.(*Hat)
return p.Name < o.Name return compare(r.Name, o.Name)
}
func (p *Hat) Equals(other any) bool {
o, _ := other.(*Hat)
return p.Name == o.Name
} }
func (p *Hat) String() string { func (p *Hat) String() string {

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
) )
const CAPABILITY Kind = "capability" const CAPABILITY Kind = "capability"
@ -47,19 +46,12 @@ func (r *Capability) Validate() error {
return nil return nil
} }
func (r *Capability) Less(other any) bool { func (r *Capability) Compare(other Rule) int {
o, _ := other.(*Capability) o, _ := other.(*Capability)
for i := 0; i < len(r.Names) && i < len(o.Names); i++ { if res := compare(r.Names, o.Names); res != 0 {
if r.Names[i] != o.Names[i] { return res
return r.Names[i] < o.Names[i]
} }
} return r.Qualifier.Compare(o.Qualifier)
return r.Qualifier.Less(o.Qualifier)
}
func (r *Capability) Equals(other any) bool {
o, _ := other.(*Capability)
return slices.Equal(r.Names, o.Names) && r.Qualifier.Equals(o.Qualifier)
} }
func (r *Capability) String() string { func (r *Capability) String() string {

View file

@ -39,24 +39,18 @@ func (r *ChangeProfile) Validate() error {
return nil return nil
} }
func (r *ChangeProfile) Less(other any) bool { func (r *ChangeProfile) Compare(other Rule) int {
o, _ := other.(*ChangeProfile) o, _ := other.(*ChangeProfile)
if r.ExecMode != o.ExecMode { if res := compare(r.ExecMode, o.ExecMode); res != 0 {
return r.ExecMode < o.ExecMode return res
} }
if r.Exec != o.Exec { if res := compare(r.Exec, o.Exec); res != 0 {
return r.Exec < o.Exec return res
} }
if r.ProfileName != o.ProfileName { if res := compare(r.ProfileName, o.ProfileName); res != 0 {
return r.ProfileName < o.ProfileName return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *ChangeProfile) Equals(other any) bool {
o, _ := other.(*ChangeProfile)
return r.ExecMode == o.ExecMode && r.Exec == o.Exec &&
r.ProfileName == o.ProfileName && r.Qualifier.Equals(o.Qualifier)
} }
func (r *ChangeProfile) String() string { func (r *ChangeProfile) String() string {

View file

@ -19,9 +19,43 @@ func Must[T any](v T, err error) T {
return v return v
} }
// cmpFileAccess compares two access strings for file rules. func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
func compare(a, b any) int {
switch a := a.(type) {
case int:
return a - b.(int)
case string:
a = strings.ToLower(a)
b := strings.ToLower(b.(string))
if a == b {
return 0
}
for i := 0; i < len(a) && i < len(b); i++ {
if a[i] != b[i] {
return stringWeights[a[i]] - stringWeights[b[i]]
}
}
return len(a) - len(b)
case []string:
return slices.CompareFunc(a, b.([]string), func(s1, s2 string) int {
return compare(s1, s2)
})
case bool:
return boolToInt(a) - boolToInt(b.(bool))
default:
panic("compare: unsupported type")
}
}
// compareFileAccess compares two access strings for file rules.
// It is aimed to be used in slices.SortFunc. // It is aimed to be used in slices.SortFunc.
func cmpFileAccess(i, j string) int { func compareFileAccess(i, j string) int {
if slices.Contains(requirements[FILE]["access"], i) && if slices.Contains(requirements[FILE]["access"], i) &&
slices.Contains(requirements[FILE]["access"], j) { slices.Contains(requirements[FILE]["access"], j) {
return requirementsWeights[FILE]["access"][i] - requirementsWeights[FILE]["access"][j] return requirementsWeights[FILE]["access"][i] - requirementsWeights[FILE]["access"][j]
@ -115,6 +149,6 @@ func toAccess(kind Kind, input string) ([]string, error) {
return toValues(kind, "access", input) return toValues(kind, "access", input)
} }
slices.SortFunc(res, cmpFileAccess) slices.SortFunc(res, compareFileAccess)
return slices.Compact(res), nil return slices.Compact(res), nil
} }

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
) )
const DBUS Kind = "dbus" const DBUS Kind = "dbus"
@ -63,43 +62,33 @@ func (r *Dbus) Validate() error {
return validateValues(r.Kind(), "bus", []string{r.Bus}) return validateValues(r.Kind(), "bus", []string{r.Bus})
} }
func (r *Dbus) Less(other any) bool { func (r *Dbus) Compare(other Rule) int {
o, _ := other.(*Dbus) o, _ := other.(*Dbus)
for i := 0; i < len(r.Access) && i < len(o.Access); i++ { if res := compare(r.Access, o.Access); res != 0 {
if r.Access[i] != o.Access[i] { return res
return r.Access[i] < o.Access[i]
} }
if res := compare(r.Bus, o.Bus); res != 0 {
return res
} }
if r.Bus != o.Bus { if res := compare(r.Name, o.Name); res != 0 {
return r.Bus < o.Bus return res
} }
if r.Name != o.Name { if res := compare(r.Path, o.Path); res != 0 {
return r.Name < o.Name return res
} }
if r.Path != o.Path { if res := compare(r.Interface, o.Interface); res != 0 {
return r.Path < o.Path return res
} }
if r.Interface != o.Interface { if res := compare(r.Member, o.Member); res != 0 {
return r.Interface < o.Interface return res
} }
if r.Member != o.Member { if res := compare(r.PeerName, o.PeerName); res != 0 {
return r.Member < o.Member return res
} }
if r.PeerName != o.PeerName { if res := compare(r.PeerLabel, o.PeerLabel); res != 0 {
return r.PeerName < o.PeerName return res
} }
if r.PeerLabel != o.PeerLabel { return r.Qualifier.Compare(o.Qualifier)
return r.PeerLabel < o.PeerLabel
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Dbus) Equals(other any) bool {
o, _ := other.(*Dbus)
return slices.Equal(r.Access, o.Access) && r.Bus == o.Bus && r.Name == o.Name &&
r.Path == o.Path && r.Interface == o.Interface &&
r.Member == o.Member && r.PeerName == o.PeerName &&
r.PeerLabel == o.PeerLabel && r.Qualifier.Equals(o.Qualifier)
} }
func (r *Dbus) String() string { func (r *Dbus) String() string {

View file

@ -68,32 +68,27 @@ func (r *File) Validate() error {
return nil return nil
} }
func (r *File) Less(other any) bool { func (r *File) Compare(other Rule) int {
o, _ := other.(*File) o, _ := other.(*File)
letterR := getLetterIn(fileAlphabet, r.Path) letterR := getLetterIn(fileAlphabet, r.Path)
letterO := getLetterIn(fileAlphabet, o.Path) letterO := getLetterIn(fileAlphabet, o.Path)
if fileWeights[letterR] != fileWeights[letterO] && letterR != "" && letterO != "" { if fileWeights[letterR] != fileWeights[letterO] && letterR != "" && letterO != "" {
return fileWeights[letterR] < fileWeights[letterO] return fileWeights[letterR] - fileWeights[letterO]
} }
if r.Path != o.Path { if res := compare(r.Owner, o.Owner); res != 0 {
return r.Path < o.Path return res
} }
if o.Owner != r.Owner { if res := compare(r.Path, o.Path); res != 0 {
return r.Owner return res
} }
if len(r.Access) != len(o.Access) { if res := compare(r.Access, o.Access); res != 0 {
return len(r.Access) < len(o.Access) return res
} }
if r.Target != o.Target { if res := compare(r.Target, o.Target); res != 0 {
return r.Target < o.Target return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *File) Equals(other any) bool {
o, _ := other.(*File)
return r.Path == o.Path && slices.Equal(r.Access, o.Access) && r.Owner == o.Owner &&
r.Target == o.Target && r.Qualifier.Equals(o.Qualifier)
} }
func (r *File) String() string { func (r *File) String() string {
@ -131,27 +126,22 @@ func (r *Link) Validate() error {
return nil return nil
} }
func (r *Link) Less(other any) bool { func (r *Link) Compare(other Rule) int {
o, _ := other.(*Link) o, _ := other.(*Link)
if r.Path != o.Path {
return r.Path < o.Path
}
if o.Owner != r.Owner {
return r.Owner
}
if r.Target != o.Target {
return r.Target < o.Target
}
if r.Subset != o.Subset {
return r.Subset
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Link) Equals(other any) bool { if res := compare(r.Owner, o.Owner); res != 0 {
o, _ := other.(*Link) return res
return r.Subset == o.Subset && r.Owner == o.Owner && r.Path == o.Path && }
r.Target == o.Target && r.Qualifier.Equals(o.Qualifier) if res := compare(r.Path, o.Path); res != 0 {
return res
}
if res := compare(r.Target, o.Target); res != 0 {
return res
}
if res := compare(r.Subset, o.Subset); res != 0 {
return res
}
return r.Qualifier.Compare(o.Qualifier)
} }
func (r *Link) String() string { func (r *Link) String() string {

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
) )
const IOURING Kind = "io_uring" const IOURING Kind = "io_uring"
@ -40,20 +39,15 @@ func (r *IOUring) Validate() error {
return nil return nil
} }
func (r *IOUring) Less(other any) bool { func (r *IOUring) Compare(other Rule) int {
o, _ := other.(*IOUring) o, _ := other.(*IOUring)
if len(r.Access) != len(o.Access) { if res := compare(r.Access, o.Access); res != 0 {
return len(r.Access) < len(o.Access) return res
} }
if r.Label != o.Label { if res := compare(r.Label, o.Label); res != 0 {
return r.Label < o.Label return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *IOUring) Equals(other any) bool {
o, _ := other.(*IOUring)
return slices.Equal(r.Access, o.Access) && r.Label == o.Label && r.Qualifier.Equals(o.Qualifier)
} }
func (r *IOUring) String() string { func (r *IOUring) String() string {

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
) )
const ( const (
@ -48,15 +47,11 @@ func (m MountConditions) Validate() error {
return validateValues(MOUNT, "flags", m.Options) return validateValues(MOUNT, "flags", m.Options)
} }
func (m MountConditions) Less(other MountConditions) bool { func (m MountConditions) Compare(other MountConditions) int {
if m.FsType != other.FsType { if res := compare(m.FsType, other.FsType); res != 0 {
return m.FsType < other.FsType return res
} }
return len(m.Options) < len(other.Options) return compare(m.Options, other.Options)
}
func (m MountConditions) Equals(other MountConditions) bool {
return m.FsType == other.FsType && slices.Equal(m.Options, other.Options)
} }
type Mount struct { type Mount struct {
@ -84,25 +79,18 @@ func (r *Mount) Validate() error {
return nil return nil
} }
func (r *Mount) Less(other any) bool { func (r *Mount) Compare(other Rule) int {
o, _ := other.(*Mount) o, _ := other.(*Mount)
if r.Source != o.Source { if res := compare(r.Source, o.Source); res != 0 {
return r.Source < o.Source return res
} }
if r.MountPoint != o.MountPoint { if res := compare(r.MountPoint, o.MountPoint); res != 0 {
return r.MountPoint < o.MountPoint return res
} }
if r.MountConditions.Equals(o.MountConditions) { if res := r.MountConditions.Compare(o.MountConditions); res != 0 {
return r.MountConditions.Less(o.MountConditions) return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Mount) Equals(other any) bool {
o, _ := other.(*Mount)
return r.Source == o.Source && r.MountPoint == o.MountPoint &&
r.MountConditions.Equals(o.MountConditions) &&
r.Qualifier.Equals(o.Qualifier)
} }
func (r *Mount) String() string { func (r *Mount) String() string {
@ -140,22 +128,15 @@ func (r *Umount) Validate() error {
return nil return nil
} }
func (r *Umount) Less(other any) bool { func (r *Umount) Compare(other Rule) int {
o, _ := other.(*Umount) o, _ := other.(*Umount)
if r.MountPoint != o.MountPoint { if res := compare(r.MountPoint, o.MountPoint); res != 0 {
return r.MountPoint < o.MountPoint return res
} }
if r.MountConditions.Equals(o.MountConditions) { if res := r.MountConditions.Compare(o.MountConditions); res != 0 {
return r.MountConditions.Less(o.MountConditions) return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Umount) Equals(other any) bool {
o, _ := other.(*Umount)
return r.MountPoint == o.MountPoint &&
r.MountConditions.Equals(o.MountConditions) &&
r.Qualifier.Equals(o.Qualifier)
} }
func (r *Umount) String() string { func (r *Umount) String() string {
@ -193,22 +174,15 @@ func (r *Remount) Validate() error {
return nil return nil
} }
func (r *Remount) Less(other any) bool { func (r *Remount) Compare(other Rule) int {
o, _ := other.(*Remount) o, _ := other.(*Remount)
if r.MountPoint != o.MountPoint { if res := compare(r.MountPoint, o.MountPoint); res != 0 {
return r.MountPoint < o.MountPoint return res
} }
if r.MountConditions.Equals(o.MountConditions) { if res := r.MountConditions.Compare(o.MountConditions); res != 0 {
return r.MountConditions.Less(o.MountConditions) return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Remount) Equals(other any) bool {
o, _ := other.(*Remount)
return r.MountPoint == o.MountPoint &&
r.MountConditions.Equals(o.MountConditions) &&
r.Qualifier.Equals(o.Qualifier)
} }
func (r *Remount) String() string { func (r *Remount) String() string {

View file

@ -58,24 +58,18 @@ func (r *Mqueue) Validate() error {
return nil return nil
} }
func (r *Mqueue) Less(other any) bool { func (r *Mqueue) Compare(other Rule) int {
o, _ := other.(*Mqueue) o, _ := other.(*Mqueue)
if len(r.Access) != len(o.Access) { if res := compare(r.Access, o.Access); res != 0 {
return len(r.Access) < len(o.Access) return res
} }
if r.Type != o.Type { if res := compare(r.Type, o.Type); res != 0 {
return r.Type < o.Type return res
} }
if r.Label != o.Label { if res := compare(r.Label, o.Label); res != 0 {
return r.Label < o.Label return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Mqueue) Equals(other any) bool {
o, _ := other.(*Mqueue)
return slices.Equal(r.Access, o.Access) && r.Type == o.Type && r.Label == o.Label &&
r.Name == o.Name && r.Qualifier.Equals(o.Qualifier)
} }
func (r *Mqueue) String() string { func (r *Mqueue) String() string {

View file

@ -46,14 +46,14 @@ func newAddressExprFromLog(log map[string]string) AddressExpr {
} }
} }
func (r AddressExpr) Less(other AddressExpr) bool { func (r AddressExpr) Compare(other AddressExpr) int {
if r.Source != other.Source { if res := compare(r.Source, other.Source); res != 0 {
return r.Source < other.Source return res
} }
if r.Destination != other.Destination { if res := compare(r.Destination, other.Destination); res != 0 {
return r.Destination < other.Destination return res
} }
return r.Port < other.Port return compare(r.Port, other.Port)
} }
func (r AddressExpr) Equals(other AddressExpr) bool { func (r AddressExpr) Equals(other AddressExpr) bool {
@ -94,28 +94,21 @@ func (r *Network) Validate() error {
return nil return nil
} }
func (r *Network) Less(other any) bool { func (r *Network) Compare(other Rule) int {
o, _ := other.(*Network) o, _ := other.(*Network)
if r.Domain != o.Domain { if res := compare(r.Domain, o.Domain); res != 0 {
return r.Domain < o.Domain return res
} }
if r.Type != o.Type { if res := compare(r.Type, o.Type); res != 0 {
return r.Type < o.Type return res
} }
if r.Protocol != o.Protocol { if res := compare(r.Protocol, o.Protocol); res != 0 {
return r.Protocol < o.Protocol return res
} }
if r.AddressExpr.Less(o.AddressExpr) { if res := r.AddressExpr.Compare(o.AddressExpr); res != 0 {
return r.AddressExpr.Less(o.AddressExpr) return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Network) Equals(other any) bool {
o, _ := other.(*Network)
return r.Domain == o.Domain && r.Type == o.Type &&
r.Protocol == o.Protocol && r.AddressExpr.Equals(o.AddressExpr) &&
r.Qualifier.Equals(o.Qualifier)
} }
func (r *Network) String() string { func (r *Network) String() string {

View file

@ -28,25 +28,18 @@ func (r *PivotRoot) Validate() error {
return nil return nil
} }
func (r *PivotRoot) Less(other any) bool { func (r *PivotRoot) Compare(other Rule) int {
o, _ := other.(*PivotRoot) o, _ := other.(*PivotRoot)
if r.OldRoot != o.OldRoot { if res := compare(r.OldRoot, o.OldRoot); res != 0 {
return r.OldRoot < o.OldRoot return res
} }
if r.NewRoot != o.NewRoot { if res := compare(r.NewRoot, o.NewRoot); res != 0 {
return r.NewRoot < o.NewRoot return res
} }
if r.TargetProfile != o.TargetProfile { if res := compare(r.TargetProfile, o.TargetProfile); res != 0 {
return r.TargetProfile < o.TargetProfile return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *PivotRoot) Equals(other any) bool {
o, _ := other.(*PivotRoot)
return r.OldRoot == o.OldRoot && r.NewRoot == o.NewRoot &&
r.TargetProfile == o.TargetProfile &&
r.Qualifier.Equals(o.Qualifier)
} }
func (r *PivotRoot) String() string { func (r *PivotRoot) String() string {

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
"strings" "strings"
) )
@ -34,12 +33,8 @@ func (r *Comment) Validate() error {
return nil return nil
} }
func (r *Comment) Less(other any) bool { func (r *Comment) Compare(other Rule) int {
return false return 0
}
func (r *Comment) Equals(other any) bool {
return false
} }
func (r *Comment) String() string { func (r *Comment) String() string {
@ -93,17 +88,12 @@ func (r *Abi) Validate() error {
return nil return nil
} }
func (r *Abi) Less(other any) bool { func (r *Abi) Compare(other Rule) int {
o, _ := other.(*Abi) o, _ := other.(*Abi)
if r.Path != o.Path { if res := compare(r.Path, o.Path); res != 0 {
return r.Path < o.Path return res
} }
return r.IsMagic == o.IsMagic return compare(r.IsMagic, o.IsMagic)
}
func (r *Abi) Equals(other any) bool {
o, _ := other.(*Abi)
return r.Path == o.Path && r.IsMagic == o.IsMagic
} }
func (r *Abi) String() string { func (r *Abi) String() string {
@ -145,17 +135,12 @@ func (r *Alias) Validate() error {
return nil return nil
} }
func (r Alias) Less(other any) bool { func (r *Alias) Compare(other Rule) int {
o, _ := other.(*Alias) o, _ := other.(*Alias)
if r.Path != o.Path { if res := compare(r.Path, o.Path); res != 0 {
return r.Path < o.Path return res
} }
return r.RewrittenPath < o.RewrittenPath return compare(r.RewrittenPath, o.RewrittenPath)
}
func (r Alias) Equals(other any) bool {
o, _ := other.(*Alias)
return r.Path == o.Path && r.RewrittenPath == o.RewrittenPath
} }
func (r *Alias) String() string { func (r *Alias) String() string {
@ -216,20 +201,22 @@ func (r *Include) Validate() error {
return nil return nil
} }
func (r *Include) Less(other any) bool { func (r *Include) Compare(other Rule) int {
const base = "abstractions/base"
o, _ := other.(*Include) o, _ := other.(*Include)
if r.Path == o.Path { if res := compare(r.Path, o.Path); res != 0 {
return r.Path < o.Path if r.Path == base {
return -1
} }
if r.IsMagic != o.IsMagic { if o.Path == base {
return r.IsMagic return 1
} }
return r.IfExists return res
} }
if res := compare(r.IsMagic, o.IsMagic); res != 0 {
func (r *Include) Equals(other any) bool { return res
o, _ := other.(*Include) }
return r.Path == o.Path && r.IsMagic == o.IsMagic && r.IfExists == o.IfExists return compare(r.IfExists, o.IfExists)
} }
func (r *Include) String() string { func (r *Include) String() string {
@ -284,17 +271,8 @@ func (r *Variable) Validate() error {
return nil return nil
} }
func (r *Variable) Less(other any) bool { func (r *Variable) Compare(other Rule) int {
o, _ := other.(*Variable) return 0
if r.Name != o.Name {
return r.Name < o.Name
}
return len(r.Values) < len(o.Values)
}
func (r *Variable) Equals(other any) bool {
o, _ := other.(*Variable)
return r.Name == o.Name && slices.Equal(r.Values, o.Values)
} }
func (r *Variable) String() string { func (r *Variable) String() string {

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"maps"
"slices" "slices"
"strings" "strings"
) )
@ -96,19 +95,12 @@ func (r *Profile) Validate() error {
return r.Rules.Validate() return r.Rules.Validate()
} }
func (p *Profile) Less(other any) bool { func (r *Profile) Compare(other Rule) int {
o, _ := other.(*Profile) o, _ := other.(*Profile)
if p.Name != o.Name { if res := compare(r.Name, o.Name); res != 0 {
return p.Name < o.Name return res
} }
return len(p.Attachments) < len(o.Attachments) return compare(r.Attachments, o.Attachments)
}
func (p *Profile) Equals(other any) bool {
o, _ := other.(*Profile)
return p.Name == o.Name && slices.Equal(p.Attachments, o.Attachments) &&
maps.Equal(p.Attributes, o.Attributes) &&
slices.Equal(p.Flags, o.Flags)
} }
func (p *Profile) String() string { func (p *Profile) String() string {

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
) )
const PTRACE Kind = "ptrace" const PTRACE Kind = "ptrace"
@ -42,21 +41,15 @@ func (r *Ptrace) Validate() error {
return nil return nil
} }
func (r *Ptrace) Less(other any) bool { func (r *Ptrace) Compare(other Rule) int {
o, _ := other.(*Ptrace) o, _ := other.(*Ptrace)
if len(r.Access) != len(o.Access) { if res := compare(r.Access, o.Access); res != 0 {
return len(r.Access) < len(o.Access) return res
} }
if r.Peer != o.Peer { if res := compare(r.Peer, o.Peer); res != 0 {
return r.Peer == o.Peer return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Ptrace) Equals(other any) bool {
o, _ := other.(*Ptrace)
return slices.Equal(r.Access, o.Access) && r.Peer == o.Peer &&
r.Qualifier.Equals(o.Qualifier)
} }
func (r *Ptrace) String() string { func (r *Ptrace) String() string {

View file

@ -43,20 +43,15 @@ func (r *Rlimit) Validate() error {
return nil return nil
} }
func (r *Rlimit) Less(other any) bool { func (r *Rlimit) Compare(other Rule) int {
o, _ := other.(*Rlimit) o, _ := other.(*Rlimit)
if r.Key != o.Key { if res := compare(r.Key, o.Key); res != 0 {
return r.Key < o.Key return res
} }
if r.Op != o.Op { if res := compare(r.Op, o.Op); res != 0 {
return r.Op < o.Op return res
} }
return r.Value < o.Value return compare(r.Value, o.Value)
}
func (r *Rlimit) Equals(other any) bool {
o, _ := other.(*Rlimit)
return r.Key == o.Key && r.Op == o.Op && r.Value == o.Value
} }
func (r *Rlimit) String() string { func (r *Rlimit) String() string {

View file

@ -35,8 +35,7 @@ func (k Kind) Tok() string {
// Rule generic interface for all AppArmor rules // Rule generic interface for all AppArmor rules
type Rule interface { type Rule interface {
Validate() error Validate() error
Less(other any) bool Compare(other Rule) int
Equals(other any) bool
String() string String() string
Constraint() constraint Constraint() constraint
Kind() Kind Kind() Kind
@ -66,7 +65,7 @@ func (r Rules) Index(item Rule) int {
if rule == nil { if rule == nil {
continue continue
} }
if rule.Kind() == item.Kind() && rule.Equals(item) { if rule.Kind() == item.Kind() && rule.Compare(item) == 0 {
return idx return idx
} }
} }
@ -153,7 +152,7 @@ func (r Rules) Merge() Rules {
} }
// If rules are identical, merge them // If rules are identical, merge them
if r[i].Equals(r[j]) { if r[i].Compare(r[j]) == 0 {
r = r.Delete(j) r = r.Delete(j)
j-- j--
continue continue
@ -166,7 +165,7 @@ func (r Rules) Merge() Rules {
fileJ := r[j].(*File) fileJ := r[j].(*File)
if fileI.Path == fileJ.Path { if fileI.Path == fileJ.Path {
fileI.Access = append(fileI.Access, fileJ.Access...) fileI.Access = append(fileI.Access, fileJ.Access...)
slices.SortFunc(fileI.Access, cmpFileAccess) slices.SortFunc(fileI.Access, compareFileAccess)
fileI.Access = slices.Compact(fileI.Access) fileI.Access = slices.Compact(fileI.Access)
r = r.Delete(j) r = r.Delete(j)
j-- j--
@ -192,13 +191,7 @@ func (r Rules) Sort() Rules {
} }
return ruleWeights[kindOfA] - ruleWeights[kindOfB] return ruleWeights[kindOfA] - ruleWeights[kindOfB]
} }
if a.Equals(b) { return a.Compare(b)
return 0
}
if a.Less(b) {
return -1
}
return 1
}) })
return r return r
} }

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
) )
const SIGNAL Kind = "signal" const SIGNAL Kind = "signal"
@ -60,24 +59,18 @@ func (r *Signal) Validate() error {
return nil return nil
} }
func (r *Signal) Less(other any) bool { func (r *Signal) Compare(other Rule) int {
o, _ := other.(*Signal) o, _ := other.(*Signal)
if len(r.Access) != len(o.Access) { if res := compare(r.Access, o.Access); res != 0 {
return len(r.Access) < len(o.Access) return res
} }
if len(r.Set) != len(o.Set) { if res := compare(r.Set, o.Set); res != 0 {
return len(r.Set) < len(o.Set) return res
} }
if r.Peer != o.Peer { if res := compare(r.Peer, o.Peer); res != 0 {
return r.Peer < o.Peer return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Signal) Equals(other any) bool {
o, _ := other.(*Signal)
return slices.Equal(r.Access, o.Access) && slices.Equal(r.Set, o.Set) &&
r.Peer == o.Peer && r.Qualifier.Equals(o.Qualifier)
} }
func (r *Signal) String() string { func (r *Signal) String() string {

View file

@ -117,6 +117,12 @@ var (
} }
fileWeights = generateWeights(fileAlphabet) fileWeights = generateWeights(fileAlphabet)
// The order AARE should be sorted
stringAlphabet = []byte(
"!\"#$%&'(){}[]*+,-./:;<=>?@\\^_`|~0123456789abcdefghijklmnopqrstuvwxyz",
)
stringWeights = generateWeights(stringAlphabet)
// The order the rule values (access, type, domains, etc) should be sorted // The order the rule values (access, type, domains, etc) should be sorted
requirements = map[Kind]requirement{} requirements = map[Kind]requirement{}
requirementsWeights map[Kind]map[string]map[string]int requirementsWeights map[Kind]map[string]map[string]int
@ -155,7 +161,7 @@ func renderTemplate(name Kind, data any) string {
return res.String() return res.String()
} }
func generateWeights[T Kind | string](alphabet []T) map[T]int { func generateWeights[T comparable](alphabet []T) map[T]int {
res := make(map[T]int, len(alphabet)) res := make(map[T]int, len(alphabet))
for i, r := range alphabet { for i, r := range alphabet {
res[r] = i res[r] = i

View file

@ -6,7 +6,6 @@ package aa
import ( import (
"fmt" "fmt"
"slices"
) )
const UNIX Kind = "unix" const UNIX Kind = "unix"
@ -58,45 +57,36 @@ func (r *Unix) Validate() error {
return nil return nil
} }
func (r *Unix) Less(other any) bool { func (r *Unix) Compare(other Rule) int {
o, _ := other.(*Unix) o, _ := other.(*Unix)
if len(r.Access) != len(o.Access) { if res := compare(r.Access, o.Access); res != 0 {
return len(r.Access) < len(o.Access) return res
} }
if r.Type != o.Type { if res := compare(r.Type, o.Type); res != 0 {
return r.Type < o.Type return res
} }
if r.Protocol != o.Protocol { if res := compare(r.Protocol, o.Protocol); res != 0 {
return r.Protocol < o.Protocol return res
} }
if r.Address != o.Address { if res := compare(r.Address, o.Address); res != 0 {
return r.Address < o.Address return res
} }
if r.Label != o.Label { if res := compare(r.Label, o.Label); res != 0 {
return r.Label < o.Label return res
} }
if r.Attr != o.Attr { if res := compare(r.Attr, o.Attr); res != 0 {
return r.Attr < o.Attr return res
} }
if r.Opt != o.Opt { if res := compare(r.Opt, o.Opt); res != 0 {
return r.Opt < o.Opt return res
} }
if r.PeerLabel != o.PeerLabel { if res := compare(r.PeerLabel, o.PeerLabel); res != 0 {
return r.PeerLabel < o.PeerLabel return res
} }
if r.PeerAddr != o.PeerAddr { if res := compare(r.PeerAddr, o.PeerAddr); res != 0 {
return r.PeerAddr < o.PeerAddr return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Unix) Equals(other any) bool {
o, _ := other.(*Unix)
return slices.Equal(r.Access, o.Access) && r.Type == o.Type &&
r.Protocol == o.Protocol && r.Address == o.Address &&
r.Label == o.Label && r.Attr == o.Attr && r.Opt == o.Opt &&
r.PeerLabel == o.PeerLabel && r.PeerAddr == o.PeerAddr &&
r.Qualifier.Equals(o.Qualifier)
} }
func (r *Unix) String() string { func (r *Unix) String() string {

View file

@ -4,6 +4,8 @@
package aa package aa
import "fmt"
const USERNS Kind = "userns" const USERNS Kind = "userns"
type Userns struct { type Userns struct {
@ -24,17 +26,12 @@ func (r *Userns) Validate() error {
return nil return nil
} }
func (r *Userns) Less(other any) bool { func (r *Userns) Compare(other Rule) int {
o, _ := other.(*Userns) o, _ := other.(*Userns)
if r.Create != o.Create { if res := compare(r.Create, o.Create); res != 0 {
return r.Create return res
} }
return r.Qualifier.Less(o.Qualifier) return r.Qualifier.Compare(o.Qualifier)
}
func (r *Userns) Equals(other any) bool {
o, _ := other.(*Userns)
return r.Create == o.Create && r.Qualifier.Equals(o.Qualifier)
} }
func (r *Userns) String() string { func (r *Userns) String() string {