diff --git a/pkg/util/tools.go b/pkg/util/tools.go index 1637144e8..64b71bd00 100644 --- a/pkg/util/tools.go +++ b/pkg/util/tools.go @@ -7,6 +7,8 @@ package util import ( "encoding/hex" "regexp" + + "github.com/arduino/go-paths-helper" ) type RegexReplList []RegexRepl @@ -67,3 +69,29 @@ func RemoveDuplicate[T comparable](inlist []T) []T { } return list } + +// CopyTo recursivelly copy all files from a source path to a destination path. +func CopyTo(src *paths.Path, dst *paths.Path) error { + files, err := src.ReadDirRecursiveFiltered(nil, + paths.FilterOutDirectories(), + paths.FilterOutNames("README.md"), + ) + if err != nil { + return err + } + for _, file := range files { + destination, err := file.RelFrom(src) + if err != nil { + return err + } + destination = dst.JoinPath(destination) + if err := destination.Parent().MkdirAll(); err != nil { + return err + } + if err := file.CopyTo(destination); err != nil { + return err + } + } + return nil +} + diff --git a/pkg/util/tools_test.go b/pkg/util/tools_test.go index 9b161cd36..dce3c461e 100644 --- a/pkg/util/tools_test.go +++ b/pkg/util/tools_test.go @@ -8,6 +8,8 @@ import ( "reflect" "regexp" "testing" + + "github.com/arduino/go-paths-helper" ) func TestDecodeHexInString(t *testing.T) { @@ -108,3 +110,44 @@ func TestRegexReplList_Replace(t *testing.T) { }) } } + +func TestCopyTo(t *testing.T) { + tests := []struct { + name string + src *paths.Path + dst *paths.Path + wantErr bool + }{ + { + name: "default", + src: paths.New("../../apparmor.d/groups/_full/"), + dst: paths.New("../../.build/apparmor.d/groups/_full/"), + wantErr: false, + }, + { + name: "issue-source", + src: paths.New("../../apparmor.d/groups/nope/"), + dst: paths.New("../../.build/apparmor.d/groups/_full/"), + wantErr: true, + }, + { + name: "issue-dest-1", + src: paths.New("../../apparmor.d/groups/_full/"), + dst: paths.New("/"), + wantErr: true, + }, + { + name: "issue-dest-2", + src: paths.New("../../apparmor.d/groups/_full/"), + dst: paths.New("/_full/"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CopyTo(tt.src, tt.dst); (err != nil) != tt.wantErr { + t.Errorf("CopyTo() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}