From d0b5c4de68d967985ce2cb35ac4145e7b5c623e0 Mon Sep 17 00:00:00 2001 From: Kevin Kuehler Date: Tue, 18 Dec 2018 02:52:41 -0800 Subject: [PATCH] util: Patch CopyFs and add tests Fix CopyFs bug that resulted in a flattened destination directory. Added tests catch this bug, and ensure the data is in fact copied to the destination directory. --- util/afero.go | 20 +++++----- util/afero_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 100 insertions(+), 13 deletions(-) diff --git a/util/afero.go b/util/afero.go index 3b69dd85..5e54d7d4 100644 --- a/util/afero.go +++ b/util/afero.go @@ -22,7 +22,6 @@ import ( "io" "os" "path" - "path/filepath" "github.com/spf13/afero" ) @@ -89,19 +88,20 @@ func stringify(fs afero.Fs, name string, indent []bool) (string, error) { // currently undefined. // TODO: this should be made more rsync like and robust! func CopyFs(srcFs, dstFs afero.Fs, src, dst string, force bool) error { - if src == "" { - src = "/" - } - if dst == "" { - dst = "/" - } + src = path.Join("/", src) + dst = path.Join("/", dst) + + // TODO: clean this up with function that gets parent dir? + src = path.Clean(src) + parentDir, _ := path.Split(src) + srcFsLen := len(parentDir) + walkFn := func(name string, info os.FileInfo, err error) error { if err != nil { return err } - //perm := info.Perm() - perm := info.Mode() // TODO: is this correct? - p := path.Join(dst, filepath.Base(name)) + perm := info.Mode() // get file permissions + p := path.Join(dst, name[srcFsLen:]) if info.IsDir() { err := dstFs.Mkdir(p, perm) if os.IsExist(err) && (name == "/" || force) { diff --git a/util/afero_test.go b/util/afero_test.go index 5c36bcb6..f2c188d0 100644 --- a/util/afero_test.go +++ b/util/afero_test.go @@ -20,6 +20,7 @@ package util import ( + "bytes" "io/ioutil" "os" "testing" @@ -27,10 +28,96 @@ import ( "github.com/spf13/afero" ) -func TestCopyDiskToFs1(t *testing.T) { - if true { - return // XXX: remove me once this test passes +var dirInputs = []struct { + srcDirFull string + srcCopyRoot string + dstCopyRoot string + dstExpected string + force bool +}{ + {"/tmp/foo/bar/baz/", "/tmp/foo/", "/", "/foo/bar/baz", false}, + {"/tmp/zoo/zar/zaz/", "/tmp/zoo/zar", "/start/dir", "/start/dir/zar/zaz", false}, + {"/foo", "/foo", "/", "/foo", false}, + {"/foo", "/foo", "/", "/foo", true}, +} + +func TestCopyFs1(t *testing.T) { + for _, tt := range dirInputs { + src := afero.NewMemMapFs() + dst := afero.NewMemMapFs() + + t.Run(tt.srcDirFull, func(t *testing.T) { + err := src.MkdirAll(tt.srcDirFull, 0700) + if err != nil { + t.Errorf("could not MkdirAll %+v", err) + return + } + err = CopyFs(src, dst, tt.srcCopyRoot, tt.dstCopyRoot, tt.force) + if err != nil { + t.Errorf("error copying source %s to dest %s", tt.srcCopyRoot, tt.dstCopyRoot) + return + } + + isDir, err := afero.IsDir(dst, tt.dstExpected) + if err != nil { + t.Errorf("could not check IsDir: %+v", err) + return + } + if !isDir { + t.Errorf("expected directory tree %s to exist in dest", tt.dstExpected) + return + } + }) } +} + +func TestCopyFs2(t *testing.T) { + tree := "/foo/bar/baz/" + var files = []struct { + path string + content []byte + }{ + {"/foo/foo.txt", []byte("foo")}, + {"/foo/bar/bar.txt", []byte("bar")}, + {"/foo/bar/baz/baz.txt", []byte("baz")}, + } + + src := afero.NewMemMapFs() + dst := afero.NewMemMapFs() + + err := src.MkdirAll(tree, 0700) + if err != nil { + t.Errorf("could not MkdirAll: %+v", err) + return + } + + for _, f := range files { + err = afero.WriteFile(src, f.path, f.content, 0600) + if err != nil { + t.Errorf("could not WriteFile: %+v", err) + return + } + } + + if err = CopyFs(src, dst, "", "", false); err != nil { + t.Errorf("could not CopyFs: %+v", err) + return + } + + for _, f := range files { + content, err := afero.ReadFile(dst, f.path) + if err != nil { + t.Errorf("could not ReadFile: %+v", err) + return + } + if !bytes.Equal(content, f.content) { + t.Errorf("expected: %s, actual: %s, for file %s", string(f.content), string(content), f.path) + return + } + } +} + +func TestCopyDiskToFs1(t *testing.T) { dir, err := TestDirFull() if err != nil { t.Errorf("could not get tests directory: %+v", err)