diff --git a/script.go b/script.go index 2d39bdb..1fd40bc 100644 --- a/script.go +++ b/script.go @@ -10,6 +10,7 @@ import ( "fmt" "hash" "io" + "io/fs" "math" "net/http" "os" @@ -76,8 +77,9 @@ func File(path string) *Pipe { } // FindFiles creates a pipe listing all the files in the directory dir and its -// subdirectories recursively, one per line, like Unix find(1). If dir doesn't -// exist or can't be read, the pipe's error status will be set. +// subdirectories recursively, one per line, like Unix find(1). +// Errors are ignored unless no files are found (in which case the pipe's error +// status will be set to the last error encountered). // // Each line of the output consists of a slash-separated path, starting with // the initial directory. For example, if the directory looks like this: @@ -92,17 +94,19 @@ func File(path string) *Pipe { // test/2.txt func FindFiles(dir string) *Pipe { var paths []string - err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + var innerErr error + fs.WalkDir(os.DirFS(dir), ".", func(path string, d fs.DirEntry, err error) error { if err != nil { - return err + innerErr = err + return fs.SkipDir } - if !info.IsDir() { - paths = append(paths, path) + if !d.IsDir() { + paths = append(paths, filepath.Join(dir, path)) } return nil }) - if err != nil { - return NewPipe().WithError(err) + if innerErr != nil && len(paths) == 0 { + return NewPipe().WithError(innerErr) } return Slice(paths) } diff --git a/script_unix_test.go b/script_unix_test.go index 40a98c8..71e2fae 100644 --- a/script_unix_test.go +++ b/script_unix_test.go @@ -3,6 +3,8 @@ package script_test import ( + "os" + "path/filepath" "testing" "github.com/bitfield/script" @@ -106,6 +108,30 @@ func TestExecPipesDataToExternalCommandAndGetsExpectedOutput(t *testing.T) { } } +func TestFindFiles_DoesNotErrorWhenSubDirectoryIsNotReadable(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + restrictedDirPath := filepath.Join(tmpDir, "a_restricted_dir") + if err := os.Mkdir(restrictedDirPath, 0o000); err != nil { + t.Fatal(err) + } + fileAPath := filepath.Join(tmpDir, "file_a.txt") + if err := os.WriteFile(fileAPath, []byte("hello world!"), os.ModePerm); err != nil { + t.Fatal(err) + } + got, err := script.FindFiles(tmpDir).String() + if err != nil { + t.Fatal(err) + } + want := fileAPath + "\n" + if err != nil { + t.Fatal(err) + } + if !cmp.Equal(want, got) { + t.Fatal(cmp.Diff(want, got)) + } +} + func ExampleExec_ok() { script.Exec("echo Hello, world!").Stdout() // Output: