feat(aa): add initial profile validation structure.

This commit is contained in:
Alexandre Pujol 2024-05-25 22:36:39 +01:00
parent 2dd6046697
commit 92641e7e28
No known key found for this signature in database
GPG key ID: C5469996F0DF68EC
20 changed files with 222 additions and 2 deletions

View file

@ -12,6 +12,10 @@ type All struct {
RuleBase RuleBase
} }
func (r *All) Validate() error {
return nil
}
func (r *All) Less(other any) bool { func (r *All) Less(other any) bool {
return false return false
} }

View file

@ -49,6 +49,19 @@ func (f *AppArmorProfileFile) String() string {
return renderTemplate("apparmor", f) return renderTemplate("apparmor", f)
} }
// Validate the profile file
func (f *AppArmorProfileFile) Validate() error {
if err := f.Preamble.Validate(); err != nil {
return err
}
for _, p := range f.Profiles {
if err := p.Validate(); err != nil {
return err
}
}
return nil
}
// GetDefaultProfile ensure a profile is always present in the profile file and // GetDefaultProfile ensure a profile is always present in the profile file and
// return it, as a default profile. // return it, as a default profile.
func (f *AppArmorProfileFile) GetDefaultProfile() *Profile { func (f *AppArmorProfileFile) GetDefaultProfile() *Profile {

View file

@ -16,6 +16,10 @@ type Hat struct {
Rules Rules Rules Rules
} }
func (r *Hat) Validate() error {
return nil
}
func (p *Hat) Less(other any) bool { func (p *Hat) Less(other any) bool {
o, _ := other.(*Hat) o, _ := other.(*Hat)
return p.Name < o.Name return p.Name < o.Name

View file

@ -5,6 +5,7 @@
package aa package aa
import ( import (
"fmt"
"slices" "slices"
) )
@ -39,6 +40,13 @@ func newCapabilityFromLog(log map[string]string) Rule {
} }
} }
func (r *Capability) Validate() error {
if err := validateValues(r.Kind(), "name", r.Names); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Capability) Less(other any) bool { func (r *Capability) Less(other any) bool {
o, _ := other.(*Capability) o, _ := other.(*Capability)
for i := 0; i < len(r.Names) && i < len(o.Names); i++ { for i := 0; i < len(r.Names) && i < len(o.Names); i++ {

View file

@ -30,6 +30,13 @@ func newChangeProfileFromLog(log map[string]string) Rule {
} }
} }
func (r *ChangeProfile) Validate() error {
if err := validateValues(r.Kind(), "mode", []string{r.ExecMode}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *ChangeProfile) Less(other any) bool { func (r *ChangeProfile) Less(other any) bool {
o, _ := other.(*ChangeProfile) o, _ := other.(*ChangeProfile)
if r.ExecMode != o.ExecMode { if r.ExecMode != o.ExecMode {

View file

@ -5,6 +5,7 @@
package aa package aa
import ( import (
"fmt"
"slices" "slices"
) )
@ -55,6 +56,13 @@ func newDbusFromLog(log map[string]string) Rule {
} }
} }
func (r *Dbus) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return validateValues(r.Kind(), "bus", []string{r.Bus})
}
func (r *Dbus) Less(other any) bool { func (r *Dbus) Less(other any) bool {
o, _ := other.(*Dbus) o, _ := other.(*Dbus)
for i := 0; i < len(r.Access) && i < len(o.Access); i++ { for i := 0; i < len(r.Access) && i < len(o.Access); i++ {

View file

@ -81,6 +81,10 @@ func newFileFromLog(log map[string]string) Rule {
} }
} }
func (r *File) Validate() error {
return nil
}
func (r *File) Less(other any) bool { func (r *File) Less(other any) bool {
o, _ := other.(*File) o, _ := other.(*File)
letterR := getLetterIn(fileAlphabet, r.Path) letterR := getLetterIn(fileAlphabet, r.Path)
@ -140,6 +144,10 @@ func newLinkFromLog(log map[string]string) Rule {
} }
} }
func (r *Link) Validate() error {
return nil
}
func (r *Link) Less(other any) bool { func (r *Link) Less(other any) bool {
o, _ := other.(*Link) o, _ := other.(*Link)
if r.Path != o.Path { if r.Path != o.Path {

View file

@ -4,7 +4,10 @@
package aa package aa
import "slices" import (
"fmt"
"slices"
)
const tokIOURING = "io_uring" const tokIOURING = "io_uring"
@ -30,6 +33,13 @@ func newIOUringFromLog(log map[string]string) Rule {
} }
} }
func (r *IOUring) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *IOUring) Less(other any) bool { func (r *IOUring) Less(other any) bool {
o, _ := other.(*IOUring) o, _ := other.(*IOUring)
if len(r.Access) != len(o.Access) { if len(r.Access) != len(o.Access) {

View file

@ -42,6 +42,10 @@ func newMountConditionsFromLog(log map[string]string) MountConditions {
return MountConditions{FsType: log["fstype"]} return MountConditions{FsType: log["fstype"]}
} }
func (m MountConditions) Validate() error {
return validateValues(tokMOUNT, "flags", m.Options)
}
func (m MountConditions) Less(other MountConditions) bool { func (m MountConditions) Less(other MountConditions) bool {
if m.FsType != other.FsType { if m.FsType != other.FsType {
return m.FsType < other.FsType return m.FsType < other.FsType
@ -71,6 +75,13 @@ func newMountFromLog(log map[string]string) Rule {
} }
} }
func (r *Mount) Validate() error {
if err := r.MountConditions.Validate(); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Mount) Less(other any) bool { func (r *Mount) Less(other any) bool {
o, _ := other.(*Mount) o, _ := other.(*Mount)
if r.Source != o.Source { if r.Source != o.Source {
@ -120,6 +131,13 @@ func newUmountFromLog(log map[string]string) Rule {
} }
} }
func (r *Umount) Validate() error {
if err := r.MountConditions.Validate(); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Umount) Less(other any) bool { func (r *Umount) Less(other any) bool {
o, _ := other.(*Umount) o, _ := other.(*Umount)
if r.MountPoint != o.MountPoint { if r.MountPoint != o.MountPoint {
@ -166,6 +184,13 @@ func newRemountFromLog(log map[string]string) Rule {
} }
} }
func (r *Remount) Validate() error {
if err := r.MountConditions.Validate(); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Remount) Less(other any) bool { func (r *Remount) Less(other any) bool {
o, _ := other.(*Remount) o, _ := other.(*Remount)
if r.MountPoint != o.MountPoint { if r.MountPoint != o.MountPoint {

View file

@ -5,6 +5,7 @@
package aa package aa
import ( import (
"fmt"
"slices" "slices"
"strings" "strings"
) )
@ -47,6 +48,16 @@ func newMqueueFromLog(log map[string]string) Rule {
} }
} }
func (r *Mqueue) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "type", []string{r.Type}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Mqueue) Less(other any) bool { func (r *Mqueue) Less(other any) bool {
o, _ := other.(*Mqueue) o, _ := other.(*Mqueue)
if len(r.Access) != len(o.Access) { if len(r.Access) != len(o.Access) {

View file

@ -4,6 +4,10 @@
package aa package aa
import (
"fmt"
)
const tokNETWORK = "network" const tokNETWORK = "network"
func init() { func init() {
@ -77,6 +81,19 @@ func newNetworkFromLog(log map[string]string) Rule {
} }
} }
func (r *Network) Validate() error {
if err := validateValues(r.Kind(), "domains", []string{r.Domain}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "type", []string{r.Type}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "protocol", []string{r.Protocol}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Network) Less(other any) bool { func (r *Network) Less(other any) bool {
o, _ := other.(*Network) o, _ := other.(*Network)
if r.Domain != o.Domain { if r.Domain != o.Domain {

View file

@ -24,6 +24,10 @@ func newPivotRootFromLog(log map[string]string) Rule {
} }
} }
func (r *PivotRoot) Validate() error {
return nil
}
func (r *PivotRoot) Less(other any) bool { func (r *PivotRoot) Less(other any) bool {
o, _ := other.(*PivotRoot) o, _ := other.(*PivotRoot)
if r.OldRoot != o.OldRoot { if r.OldRoot != o.OldRoot {

View file

@ -21,6 +21,10 @@ type Comment struct {
RuleBase RuleBase
} }
func (r *Comment) Validate() error {
return nil
}
func (r *Comment) Less(other any) bool { func (r *Comment) Less(other any) bool {
return false return false
} }
@ -51,6 +55,10 @@ type Abi struct {
IsMagic bool IsMagic bool
} }
func (r *Abi) Validate() error {
return nil
}
func (r *Abi) Less(other any) bool { func (r *Abi) Less(other any) bool {
o, _ := other.(*Abi) o, _ := other.(*Abi)
if r.Path != o.Path { if r.Path != o.Path {
@ -82,6 +90,10 @@ type Alias struct {
RewrittenPath string RewrittenPath string
} }
func (r *Alias) Validate() error {
return nil
}
func (r Alias) Less(other any) bool { func (r Alias) Less(other any) bool {
o, _ := other.(*Alias) o, _ := other.(*Alias)
if r.Path != o.Path { if r.Path != o.Path {
@ -114,6 +126,10 @@ type Include struct {
IsMagic bool IsMagic bool
} }
func (r *Include) Validate() error {
return nil
}
func (r *Include) Less(other any) bool { func (r *Include) Less(other any) bool {
o, _ := other.(*Include) o, _ := other.(*Include)
if r.Path == o.Path { if r.Path == o.Path {
@ -149,6 +165,10 @@ type Variable struct {
Define bool Define bool
} }
func (r *Variable) Validate() error {
return nil
}
func (r *Variable) Less(other any) bool { func (r *Variable) Less(other any) bool {
o, _ := other.(*Variable) o, _ := other.(*Variable)
if r.Name != o.Name { if r.Name != o.Name {

View file

@ -5,6 +5,7 @@
package aa package aa
import ( import (
"fmt"
"maps" "maps"
"reflect" "reflect"
"slices" "slices"
@ -18,6 +19,17 @@ const (
tokPROFILE = "profile" tokPROFILE = "profile"
) )
func init() {
requirements[tokPROFILE] = requirement{
tokFLAGS: {
"enforce", "complain", "kill", "default_allow", "unconfined",
"prompt", "audit", "mediate_deleted", "attach_disconnected",
"attach_disconneced.path=", "chroot_relative", "debug",
"interruptible", "kill", "kill.signal=",
},
}
}
// Profile represents a single AppArmor profile. // Profile represents a single AppArmor profile.
type Profile struct { type Profile struct {
RuleBase RuleBase
@ -33,6 +45,13 @@ type Header struct {
Flags []string Flags []string
} }
func (r *Profile) Validate() error {
if err := validateValues(r.Kind(), tokFLAGS, r.Flags); err != nil {
return fmt.Errorf("profile %s: %w", r.Name, err)
}
return r.Rules.Validate()
}
func (p *Profile) Less(other any) bool { func (p *Profile) Less(other any) bool {
o, _ := other.(*Profile) o, _ := other.(*Profile)
if p.Name != o.Name { if p.Name != o.Name {

View file

@ -5,6 +5,7 @@
package aa package aa
import ( import (
"fmt"
"slices" "slices"
) )
@ -34,6 +35,13 @@ func newPtraceFromLog(log map[string]string) Rule {
} }
} }
func (r *Ptrace) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Ptrace) Less(other any) bool { func (r *Ptrace) Less(other any) bool {
o, _ := other.(*Ptrace) o, _ := other.(*Ptrace)
if len(r.Access) != len(o.Access) { if len(r.Access) != len(o.Access) {

View file

@ -35,6 +35,13 @@ func newRlimitFromLog(log map[string]string) Rule {
} }
} }
func (r *Rlimit) Validate() error {
if err := validateValues(r.Kind(), "keys", []string{r.Key}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Rlimit) Less(other any) bool { func (r *Rlimit) Less(other any) bool {
o, _ := other.(*Rlimit) o, _ := other.(*Rlimit)
if r.Key != o.Key { if r.Key != o.Key {

View file

@ -28,6 +28,7 @@ const (
// Rule generic interface for all AppArmor rules // Rule generic interface for all AppArmor rules
type Rule interface { type Rule interface {
Validate() error
Less(other any) bool Less(other any) bool
Equals(other any) bool Equals(other any) bool
String() string String() string
@ -37,6 +38,15 @@ type Rule interface {
type Rules []Rule type Rules []Rule
func (r Rules) Validate() error {
for _, rule := range r {
if err := rule.Validate(); err != nil {
return err
}
}
return nil
}
func (r Rules) String() string { func (r Rules) String() string {
return renderTemplate("rules", r) return renderTemplate("rules", r)
} }
@ -82,6 +92,18 @@ func Must[T any](v T, err error) T {
return v return v
} }
func validateValues(rule string, key string, values []string) error {
for _, v := range values {
if v == "" {
continue
}
if !slices.Contains(requirements[rule][key], v) {
return fmt.Errorf("invalid mode '%s'", v)
}
}
return nil
}
// Helper function to convert a string to a slice of rule values according to // Helper function to convert a string to a slice of rule values according to
// the rule requirements as defined in the requirements map. // the rule requirements as defined in the requirements map.
func toValues(rule string, key string, input string) ([]string, error) { func toValues(rule string, key string, input string) ([]string, error) {

View file

@ -5,6 +5,7 @@
package aa package aa
import ( import (
"fmt"
"slices" "slices"
) )
@ -49,6 +50,16 @@ func newSignalFromLog(log map[string]string) Rule {
} }
} }
func (r *Signal) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "set", r.Set); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Signal) Less(other any) bool { func (r *Signal) Less(other any) bool {
o, _ := other.(*Signal) o, _ := other.(*Signal)
if len(r.Access) != len(o.Access) { if len(r.Access) != len(o.Access) {

View file

@ -4,7 +4,10 @@
package aa package aa
import "slices" import (
"fmt"
"slices"
)
const tokUNIX = "unix" const tokUNIX = "unix"
@ -48,6 +51,13 @@ func newUnixFromLog(log map[string]string) Rule {
} }
} }
func (r *Unix) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Unix) Less(other any) bool { func (r *Unix) Less(other any) bool {
o, _ := other.(*Unix) o, _ := other.(*Unix)
if len(r.Access) != len(o.Access) { if len(r.Access) != len(o.Access) {

View file

@ -20,6 +20,10 @@ func newUsernsFromLog(log map[string]string) Rule {
} }
} }
func (r *Userns) Validate() error {
return nil
}
func (r *Userns) Less(other any) bool { func (r *Userns) Less(other any) bool {
o, _ := other.(*Userns) o, _ := other.(*Userns)
if r.Create != o.Create { if r.Create != o.Create {