diff --git a/.github/workflows/zcp-golangci-lint.yml b/.github/workflows/zcp-golangci-lint.yml new file mode 120000 index 0000000..a6147f8 --- /dev/null +++ b/.github/workflows/zcp-golangci-lint.yml @@ -0,0 +1 @@ +../../go/tools/zcp/.github/workflows/zcp-golangci-lint.yml \ No newline at end of file diff --git a/.github/workflows/zcp-tests.yml b/.github/workflows/zcp-tests.yml new file mode 120000 index 0000000..1f97800 --- /dev/null +++ b/.github/workflows/zcp-tests.yml @@ -0,0 +1 @@ +../../go/tools/zcp/.github/workflows/zcp-tests.yml \ No newline at end of file diff --git a/go/tools/zcp/.github/workflows/zcp-golangci-lint.yml b/go/tools/zcp/.github/workflows/zcp-golangci-lint.yml new file mode 100644 index 0000000..cf8ed80 --- /dev/null +++ b/go/tools/zcp/.github/workflows/zcp-golangci-lint.yml @@ -0,0 +1,83 @@ +name: Format & Lint + +on: + push: + paths: &paths + - "go/tools/zcp/**/*.go" + - "go/tools/zcp/go.mod" + - "go/tools/zcp/go.sum" + - "go/tools/zcp/.golangci.yml" + - "go/tools/zcp/.github/workflows/zcp-golangci-lint.yml" + - "go/tools/zcp/.github/workflows/zcp-tests.yml" + - ".github/workflows/zcp-golangci-lint.yml" + - ".github/workflows/zcp-tests.yml" + pull_request: + paths: *paths + workflow_dispatch: + +concurrency: + group: zcp-format-lint-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + format: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup Go with module cache + uses: actions/setup-go@v6 + with: + go-version-file: go/tools/zcp/go.mod + cache: true + cache-dependency-path: | + go/tools/zcp/go.mod + go/tools/zcp/go.sum + + - name: Cache golangci-lint data + uses: actions/cache@v5 + with: + path: ~/.cache/golangci-lint + key: ${{ runner.os }}-golangci-format-${{ hashFiles('go/tools/zcp/go.mod', 'go/tools/zcp/go.sum', 'go/tools/zcp/.golangci.yml') }} + restore-keys: | + ${{ runner.os }}-golangci-format- + + - name: Check formatting via golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: latest + working-directory: go/tools/zcp + args: fmt --diff --timeout=5m ./... + skip-cache: false + + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup Go with module cache + uses: actions/setup-go@v6 + with: + go-version-file: go/tools/zcp/go.mod + cache: true + cache-dependency-path: | + go/tools/zcp/go.mod + go/tools/zcp/go.sum + + - name: Cache golangci-lint data + uses: actions/cache@v5 + with: + path: ~/.cache/golangci-lint + key: ${{ runner.os }}-golangci-${{ hashFiles('go/tools/zcp/go.mod', 'go/tools/zcp/go.sum', 'go/tools/zcp/.golangci.yml') }} + restore-keys: | + ${{ runner.os }}-golangci- + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: latest + working-directory: go/tools/zcp + args: --timeout=5m ./... + skip-cache: false diff --git a/go/tools/zcp/.github/workflows/zcp-tests.yml b/go/tools/zcp/.github/workflows/zcp-tests.yml new file mode 100644 index 0000000..3e39f8a --- /dev/null +++ b/go/tools/zcp/.github/workflows/zcp-tests.yml @@ -0,0 +1,39 @@ +name: Tests + +on: + push: + paths: &paths + - "go/tools/zcp/**/*.go" + - "go/tools/zcp/go.mod" + - "go/tools/zcp/go.sum" + - "go/tools/zcp/.github/workflows/zcp-tests.yml" + - "go/tools/zcp/.github/workflows/zcp-golangci-lint.yml" + - ".github/workflows/zcp-tests.yml" + - ".github/workflows/zcp-golangci-lint.yml" + pull_request: + paths: *paths + workflow_dispatch: + +concurrency: + group: zcp-tests-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup Go with module cache + uses: actions/setup-go@v6 + with: + go-version-file: go/tools/zcp/go.mod + cache: true + cache-dependency-path: | + go/tools/zcp/go.mod + go/tools/zcp/go.sum + + - name: Run tests + working-directory: go/tools/zcp + run: go test ./... diff --git a/go/tools/zcp/.golangci.yml b/go/tools/zcp/.golangci.yml new file mode 100644 index 0000000..481ba53 --- /dev/null +++ b/go/tools/zcp/.golangci.yml @@ -0,0 +1,20 @@ +version: "2" + +run: + timeout: 5m + tests: true + +linters: + disable-all: true + enable: + - errcheck + - govet + - ineffassign + - staticcheck + - unused + +formatters: + enable: + - gofumpt + - goimports + - golines diff --git a/go/tools/zcp/README.md b/go/tools/zcp/README.md new file mode 100644 index 0000000..69b3fe9 --- /dev/null +++ b/go/tools/zcp/README.md @@ -0,0 +1,108 @@ +# zcp + +Like `cp`, but with a progress bar. The `z` is there because I liked it and to avoid name collisions. + +Inspired by the original `gcp` utility: +https://manpages.ubuntu.com/manpages/focal/man1/gcp.1.html + +## Features + +- Copy files and directories +- Recursive directory copy (`-r`) +- Per-copy progress bar with: + - percent complete + - bytes copied / total bytes + - transfer speed + - ETA +- Optional metadata preservation (`-p` mode + mtime) +- Optional overwrite (`-f`) +- Optional verbose output (`-v`) to print created file names + +## Usage + +```bash +zcp [options] SOURCE... DEST +``` + +### Options + +- `-r`, `--recursive`: copy directories recursively +- `-f`, `--force`: overwrite destination files +- `-p`, `--preserve`: preserve mode and modification time +- `-q`, `--quiet`: disable progress output +- `-v`, `--verbose`: print created file names +- `--buffer-size`: copy buffer size in bytes (default `1048576`) + +### Examples + +Copy a single file: + +```bash +zcp movie.mkv /mnt/backup/movie.mkv +``` + +Copy a directory recursively: + +```bash +zcp -r photos /mnt/backup/ +``` + +Copy multiple sources into an existing destination directory: + +```bash +zcp -r folder_a folder_b file.txt /mnt/backup/ +``` + +Overwrite existing files: + +```bash +zcp -f large.iso /mnt/backup/large.iso +``` + +Preserve source mode + mtime: + +```bash +zcp -p -r assets ./assets-copy +``` + +Disable progress output: + +```bash +zcp -q -r logs /tmp/logs-copy +``` + +Verbose file listing: + +```bash +zcp -v -r photos /mnt/backup/ +``` + +## Build + +From this directory: + +```bash +mkdir -p bin +go build -o bin/zcp ./cmd/zcp +``` + +### Cross-compile + +Linux: + +```bash +mkdir -p bin +GOOS=linux GOARCH=amd64 go build -o bin/zcp-linux-amd64 ./cmd/zcp +``` + +Windows: + +```bash +mkdir -p bin +GOOS=windows GOARCH=amd64 go build -o bin/zcp-windows-amd64.exe ./cmd/zcp +``` + +## Notes + +- Symbolic links are currently not copied. +- For multiple sources, destination must already exist as a directory. diff --git a/go/tools/zcp/cmd/zcp/e2e_test.go b/go/tools/zcp/cmd/zcp/e2e_test.go new file mode 100644 index 0000000..62d1c68 --- /dev/null +++ b/go/tools/zcp/cmd/zcp/e2e_test.go @@ -0,0 +1,369 @@ +package main_test + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +var ( + binaryBuildOnce sync.Once + binaryPath string + binaryBuildErr error +) + +func TestCLIFlagsE2E(t *testing.T) { + t.Parallel() + + runRecursiveCase := func(t *testing.T, recursiveFlag string) { + t.Helper() + + tempDir := t.TempDir() + sourceDir := filepath.Join(tempDir, "src") + if err := os.MkdirAll(filepath.Join(sourceDir, "nested"), 0o755); err != nil { + t.Fatalf("create source directory: %v", err) + } + + expectedContents := "recursive-copy" + sourceFile := filepath.Join(sourceDir, "nested", "payload.txt") + if err := os.WriteFile(sourceFile, []byte(expectedContents), 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + + destinationDir := filepath.Join(tempDir, "dest") + stdout, stderr, err := runCLI(t, tempDir, recursiveFlag, sourceDir, destinationDir) + if err != nil { + t.Fatalf("recursive copy failed: %v (stdout=%q, stderr=%q)", err, stdout, stderr) + } + if stderr != "" { + t.Fatalf("expected empty stderr, got %q", stderr) + } + if !strings.Contains(stdout, "Copied 1 file(s)") { + t.Fatalf("expected copy summary in stdout, got %q", stdout) + } + + copiedFile := filepath.Join(destinationDir, "nested", "payload.txt") + actualBytes, err := os.ReadFile(copiedFile) + if err != nil { + t.Fatalf("read copied file: %v", err) + } + if string(actualBytes) != expectedContents { + t.Fatalf("unexpected copied contents: got %q, want %q", actualBytes, expectedContents) + } + } + + runForceCase := func(t *testing.T, forceFlag string) { + t.Helper() + + tempDir := t.TempDir() + sourceFile := filepath.Join(tempDir, "source.txt") + destinationFile := filepath.Join(tempDir, "destination.txt") + + if err := os.WriteFile(sourceFile, []byte("new-content"), 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + if err := os.WriteFile(destinationFile, []byte("old-content"), 0o644); err != nil { + t.Fatalf("write destination file: %v", err) + } + + _, stderrWithoutForce, errWithoutForce := runCLI(t, tempDir, sourceFile, destinationFile) + if errWithoutForce == nil { + t.Fatalf("expected copy to fail without force when destination exists") + } + if !strings.Contains(stderrWithoutForce, "destination file exists") { + t.Fatalf("expected overwrite error, got stderr=%q", stderrWithoutForce) + } + + stdout, stderr, err := runCLI(t, tempDir, forceFlag, sourceFile, destinationFile) + if err != nil { + t.Fatalf("force copy failed: %v (stdout=%q, stderr=%q)", err, stdout, stderr) + } + if stderr != "" { + t.Fatalf("expected empty stderr, got %q", stderr) + } + + actualBytes, err := os.ReadFile(destinationFile) + if err != nil { + t.Fatalf("read destination file: %v", err) + } + if string(actualBytes) != "new-content" { + t.Fatalf("destination was not overwritten, got %q", string(actualBytes)) + } + } + + runPreserveCase := func(t *testing.T, preserveFlag string) { + t.Helper() + + tempDir := t.TempDir() + sourceFile := filepath.Join(tempDir, "source.txt") + destinationFile := filepath.Join(tempDir, "destination.txt") + + if err := os.WriteFile(sourceFile, []byte("metadata"), 0o640); err != nil { + t.Fatalf("write source file: %v", err) + } + + expectedModTime := time.Now().Add(-3 * time.Hour).Truncate(time.Second) + if err := os.Chtimes(sourceFile, expectedModTime, expectedModTime); err != nil { + t.Fatalf("set source times: %v", err) + } + + stdout, stderr, err := runCLI(t, tempDir, preserveFlag, sourceFile, destinationFile) + if err != nil { + t.Fatalf("preserve copy failed: %v (stdout=%q, stderr=%q)", err, stdout, stderr) + } + if stderr != "" { + t.Fatalf("expected empty stderr, got %q", stderr) + } + + info, err := os.Stat(destinationFile) + if err != nil { + t.Fatalf("stat destination file: %v", err) + } + + if runtime.GOOS != "windows" && info.Mode().Perm() != 0o640 { + t.Fatalf("expected destination mode 0640, got %o", info.Mode().Perm()) + } + + diff := info.ModTime().Sub(expectedModTime) + if diff < 0 { + diff = -diff + } + if diff > time.Second { + t.Fatalf("expected destination modtime near %v, got %v", expectedModTime, info.ModTime()) + } + } + + runQuietCase := func(t *testing.T, quietFlag string) { + t.Helper() + + tempDir := t.TempDir() + sourceFile := filepath.Join(tempDir, "source.bin") + destinationFile := filepath.Join(tempDir, "destination.bin") + + if err := os.WriteFile(sourceFile, bytes.Repeat([]byte("z"), 128*1024), 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + + stdout, stderr, err := runCLI(t, tempDir, quietFlag, sourceFile, destinationFile) + if err != nil { + t.Fatalf("quiet copy failed: %v (stdout=%q, stderr=%q)", err, stdout, stderr) + } + if stderr != "" { + t.Fatalf("expected empty stderr, got %q", stderr) + } + if !strings.Contains(stdout, "Copied 1 file(s)") { + t.Fatalf("expected summary in stdout, got %q", stdout) + } + if strings.Contains(stdout, "[") { + t.Fatalf("expected no progress bar output with quiet flag, got %q", stdout) + } + } + + runVerboseCase := func(t *testing.T, verboseFlag string) { + t.Helper() + + tempDir := t.TempDir() + sourceFile := filepath.Join(tempDir, "source.txt") + destinationFile := filepath.Join(tempDir, "destination.txt") + + if err := os.WriteFile(sourceFile, []byte("verbose-output"), 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + + stdout, stderr, err := runCLI(t, tempDir, "-q", verboseFlag, sourceFile, destinationFile) + if err != nil { + t.Fatalf("verbose copy failed: %v (stdout=%q, stderr=%q)", err, stdout, stderr) + } + if stderr != "" { + t.Fatalf("expected empty stderr, got %q", stderr) + } + if !strings.Contains(stdout, "created: "+destinationFile) { + t.Fatalf("expected verbose created-file line, got %q", stdout) + } + if !strings.Contains(stdout, "Copied 1 file(s)") { + t.Fatalf("expected summary in stdout, got %q", stdout) + } + } + + type flagVariant struct { + name string + flag string + } + + type shortLongFlagSuite struct { + name string + variants []flagVariant + run func(t *testing.T, flag string) + } + + shortLongSuites := []shortLongFlagSuite{ + { + name: "recursive", + variants: []flagVariant{ + {name: "short", flag: "-r"}, + {name: "long", flag: "--recursive"}, + }, + run: runRecursiveCase, + }, + { + name: "force", + variants: []flagVariant{ + {name: "short", flag: "-f"}, + {name: "long", flag: "--force"}, + }, + run: runForceCase, + }, + { + name: "preserve", + variants: []flagVariant{ + {name: "short", flag: "-p"}, + {name: "long", flag: "--preserve"}, + }, + run: runPreserveCase, + }, + { + name: "quiet", + variants: []flagVariant{ + {name: "short", flag: "-q"}, + {name: "long", flag: "--quiet"}, + }, + run: runQuietCase, + }, + { + name: "verbose", + variants: []flagVariant{ + {name: "short", flag: "-v"}, + {name: "long", flag: "--verbose"}, + }, + run: runVerboseCase, + }, + } + + for _, suite := range shortLongSuites { + suite := suite + for _, variant := range suite.variants { + variant := variant + t.Run(fmt.Sprintf("%s_%s_flag", suite.name, variant.name), func(t *testing.T) { + t.Parallel() + suite.run(t, variant.flag) + }) + } + } + + t.Run("buffer_size_flag", func(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + sourceFile := filepath.Join(tempDir, "source.bin") + destinationFile := filepath.Join(tempDir, "destination.bin") + + expectedBytes := bytes.Repeat([]byte("a"), 257*1024) + if err := os.WriteFile(sourceFile, expectedBytes, 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + + stdout, stderr, err := runCLI(t, tempDir, "--buffer-size", "17", sourceFile, destinationFile) + if err != nil { + t.Fatalf("buffer-size copy failed: %v (stdout=%q, stderr=%q)", err, stdout, stderr) + } + if stderr != "" { + t.Fatalf("expected empty stderr, got %q", stderr) + } + + actualBytes, err := os.ReadFile(destinationFile) + if err != nil { + t.Fatalf("read destination file: %v", err) + } + if !bytes.Equal(actualBytes, expectedBytes) { + t.Fatalf("destination content mismatch after buffer-size copy") + } + }) + + t.Run("buffer_size_validation", func(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + sourceFile := filepath.Join(tempDir, "source.txt") + destinationFile := filepath.Join(tempDir, "destination.txt") + + if err := os.WriteFile(sourceFile, []byte("invalid-buffer-size"), 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + + _, stderr, err := runCLI(t, tempDir, "--buffer-size", "0", sourceFile, destinationFile) + if err == nil { + t.Fatalf("expected invalid buffer-size to fail") + } + if !strings.Contains(stderr, "buffer-size must be greater than 0") { + t.Fatalf("expected buffer-size validation error, got stderr=%q", stderr) + } + }) +} + +func runCLI(t *testing.T, workingDirectory string, args ...string) (string, string, error) { + t.Helper() + + command := exec.Command(zcpBinary(t), args...) + command.Dir = workingDirectory + + var stdout bytes.Buffer + var stderr bytes.Buffer + command.Stdout = &stdout + command.Stderr = &stderr + + err := command.Run() + return stdout.String(), stderr.String(), err +} + +func zcpBinary(t *testing.T) string { + t.Helper() + + binaryBuildOnce.Do(func() { + moduleRoot, err := moduleRootPath() + if err != nil { + binaryBuildErr = err + return + } + + tempDir, err := os.MkdirTemp("", "zcp-e2e-binary-*") + if err != nil { + binaryBuildErr = fmt.Errorf("create temp directory for binary: %w", err) + return + } + + binaryName := "zcp" + if runtime.GOOS == "windows" { + binaryName += ".exe" + } + binaryPath = filepath.Join(tempDir, binaryName) + + buildCommand := exec.Command("go", "build", "-o", binaryPath, "./cmd/zcp") + buildCommand.Dir = moduleRoot + + buildOutput, err := buildCommand.CombinedOutput() + if err != nil { + binaryBuildErr = fmt.Errorf("build zcp binary: %w\n%s", err, string(buildOutput)) + } + }) + + if binaryBuildErr != nil { + t.Fatalf("prepare zcp binary: %v", binaryBuildErr) + } + + return binaryPath +} + +func moduleRootPath() (string, error) { + _, currentFile, _, ok := runtime.Caller(0) + if !ok { + return "", fmt.Errorf("resolve current file for module path") + } + return filepath.Clean(filepath.Join(filepath.Dir(currentFile), "..", "..")), nil +} diff --git a/go/tools/zcp/cmd/zcp/main.go b/go/tools/zcp/cmd/zcp/main.go new file mode 100644 index 0000000..da72f36 --- /dev/null +++ b/go/tools/zcp/cmd/zcp/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + "os" + + "github.com/BoscoDomingo/utils/go/tools/zcp/internal/zcp" +) + +func main() { + if err := zcp.Run(os.Args[1:], os.Stdout, os.Stderr); err != nil { + fmt.Fprintf(os.Stderr, "zcp: %v\n", err) + os.Exit(1) + } +} diff --git a/go/tools/zcp/go.mod b/go/tools/zcp/go.mod new file mode 100644 index 0000000..ef0348f --- /dev/null +++ b/go/tools/zcp/go.mod @@ -0,0 +1,3 @@ +module github.com/BoscoDomingo/utils/go/tools/zcp + +go 1.26 diff --git a/go/tools/zcp/internal/zcp/app.go b/go/tools/zcp/internal/zcp/app.go new file mode 100644 index 0000000..a8cc896 --- /dev/null +++ b/go/tools/zcp/internal/zcp/app.go @@ -0,0 +1,111 @@ +package zcp + +import ( + "errors" + "flag" + "fmt" + "io" +) + +const defaultBufferSize = 1024 * 1024 + +type options struct { + recursive bool + force bool + preserve bool + quiet bool + verbose bool + bufferSize int +} + +func Run(args []string, stdout io.Writer, stderr io.Writer) error { + opts, sources, destination, err := parseArgs(args, stderr) + if err != nil { + if errors.Is(err, flag.ErrHelp) { + return nil + } + return err + } + + plan, totalBytes, err := buildCopyPlan(sources, destination, opts.recursive) + if err != nil { + return err + } + + progress := newProgressBar(totalBytes, !opts.quiet, stdout) + progress.start() + defer progress.stop() + + if err := executePlan(plan, opts, progress); err != nil { + return err + } + progress.stop() + + if opts.verbose { + for _, op := range plan { + if op.kind == operationCopyFile { + fmt.Fprintf(stdout, "created: %s\n", op.destination) + } + } + } + + fmt.Fprintf(stdout, "Copied %d file(s), %s total.\n", countFiles(plan), humanizeBytes(totalBytes)) + return nil +} + +func parseArgs(args []string, stderr io.Writer) (options, []string, string, error) { + opts := options{ + bufferSize: defaultBufferSize, + } + + fs := flag.NewFlagSet("zcp", flag.ContinueOnError) + fs.SetOutput(stderr) + + fs.BoolVar(&opts.recursive, "r", false, "copy directories recursively") + fs.BoolVar(&opts.recursive, "recursive", false, "copy directories recursively") + fs.BoolVar(&opts.force, "f", false, "overwrite destination files if they already exist") + fs.BoolVar(&opts.force, "force", false, "overwrite destination files if they already exist") + fs.BoolVar(&opts.preserve, "p", false, "preserve file mode and modification time") + fs.BoolVar(&opts.preserve, "preserve", false, "preserve file mode and modification time") + fs.BoolVar(&opts.quiet, "q", false, "disable progress output") + fs.BoolVar(&opts.quiet, "quiet", false, "disable progress output") + fs.BoolVar(&opts.verbose, "v", false, "print created file names") + fs.BoolVar(&opts.verbose, "verbose", false, "print created file names") + fs.IntVar(&opts.bufferSize, "buffer-size", defaultBufferSize, "copy buffer size in bytes") + + fs.Usage = func() { + fmt.Fprintln(stderr, "zcp: copy files and directories with a progress bar") + fmt.Fprintln(stderr) + fmt.Fprintln(stderr, "Usage:") + fmt.Fprintln(stderr, " zcp [options] SOURCE... DEST") + fmt.Fprintln(stderr) + fmt.Fprintln(stderr, "Options:") + fs.PrintDefaults() + } + + if err := fs.Parse(args); err != nil { + return options{}, nil, "", err + } + + if opts.bufferSize <= 0 { + return options{}, nil, "", fmt.Errorf("buffer-size must be greater than 0") + } + + remaining := fs.Args() + if len(remaining) < 2 { + fs.Usage() + return options{}, nil, "", fmt.Errorf("expected at least one SOURCE and one DEST") + } + + return opts, remaining[:len(remaining)-1], remaining[len(remaining)-1], nil +} + +func countFiles(plan []copyOperation) int { + count := 0 + for _, op := range plan { + if op.kind == operationCopyFile { + count++ + } + } + return count +} diff --git a/go/tools/zcp/internal/zcp/copy.go b/go/tools/zcp/internal/zcp/copy.go new file mode 100644 index 0000000..e46fa84 --- /dev/null +++ b/go/tools/zcp/internal/zcp/copy.go @@ -0,0 +1,324 @@ +package zcp + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "time" +) + +type operationType int + +const ( + operationCreateDirectory operationType = iota + operationCopyFile +) + +type copyOperation struct { + kind operationType + source string + destination string + mode fs.FileMode + modTime time.Time + size uint64 +} + +func buildCopyPlan(sources []string, destination string, recursive bool) ([]copyOperation, uint64, error) { + destInfo, destErr := os.Stat(destination) + destExists := destErr == nil + if destErr != nil && !errors.Is(destErr, os.ErrNotExist) { + return nil, 0, fmt.Errorf("stat destination %q: %w", destination, destErr) + } + + destIsDir := destExists && destInfo.IsDir() + if len(sources) > 1 && !destIsDir { + return nil, 0, fmt.Errorf("destination %q must be an existing directory when copying multiple sources", destination) + } + + plan := make([]copyOperation, 0, len(sources)) + var totalBytes uint64 + + for _, source := range sources { + source = filepath.Clean(source) + + sourceInfo, err := os.Lstat(source) + if err != nil { + return nil, 0, fmt.Errorf("stat source %q: %w", source, err) + } + + if sourceInfo.Mode()&os.ModeSymlink != 0 { + return nil, 0, fmt.Errorf("symbolic links are not supported: %q", source) + } + + target := destination + if len(sources) > 1 || destIsDir { + target = filepath.Join(destination, filepath.Base(source)) + } + + if sourceInfo.IsDir() { + if !recursive { + return nil, 0, fmt.Errorf("omitting directory %q (use -r or --recursive)", source) + } + + if destExists && !destIsDir && len(sources) == 1 { + return nil, 0, fmt.Errorf("cannot overwrite non-directory %q with directory %q", destination, source) + } + + if err := ensureDestinationOutsideSource(source, target); err != nil { + return nil, 0, err + } + + directoryOps, directoryBytes, err := collectDirectoryOperations(source, target) + if err != nil { + return nil, 0, err + } + + plan = append(plan, directoryOps...) + totalBytes += directoryBytes + continue + } + + sameFile, err := refersToSameFile(source, target) + if err != nil { + return nil, 0, err + } + if sameFile { + return nil, 0, fmt.Errorf("%q and %q are the same file", source, target) + } + + size := sourceInfo.Size() + if size < 0 { + size = 0 + } + + plan = append(plan, copyOperation{ + kind: operationCopyFile, + source: source, + destination: target, + mode: sourceInfo.Mode(), + modTime: sourceInfo.ModTime(), + size: uint64(size), + }) + totalBytes += uint64(size) + } + + return plan, totalBytes, nil +} + +func collectDirectoryOperations(sourceRoot string, destinationRoot string) ([]copyOperation, uint64, error) { + operations := make([]copyOperation, 0, 16) + var totalBytes uint64 + + err := filepath.WalkDir(sourceRoot, func(path string, entry fs.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + + if entry.Type()&os.ModeSymlink != 0 { + return fmt.Errorf("symbolic links are not supported: %q", path) + } + + entryInfo, err := entry.Info() + if err != nil { + return err + } + + relativePath, err := filepath.Rel(sourceRoot, path) + if err != nil { + return err + } + + destinationPath := destinationRoot + if relativePath != "." { + destinationPath = filepath.Join(destinationRoot, relativePath) + } + + if entry.IsDir() { + operations = append(operations, copyOperation{ + kind: operationCreateDirectory, + source: path, + destination: destinationPath, + mode: entryInfo.Mode(), + modTime: entryInfo.ModTime(), + }) + return nil + } + + size := entryInfo.Size() + if size < 0 { + size = 0 + } + + operations = append(operations, copyOperation{ + kind: operationCopyFile, + source: path, + destination: destinationPath, + mode: entryInfo.Mode(), + modTime: entryInfo.ModTime(), + size: uint64(size), + }) + totalBytes += uint64(size) + return nil + }) + if err != nil { + return nil, 0, fmt.Errorf("walk source directory %q: %w", sourceRoot, err) + } + + return operations, totalBytes, nil +} + +func ensureDestinationOutsideSource(sourceDirectory string, destinationPath string) error { + sourceAbs, err := filepath.Abs(sourceDirectory) + if err != nil { + return fmt.Errorf("resolve source path %q: %w", sourceDirectory, err) + } + + destinationAbs, err := filepath.Abs(destinationPath) + if err != nil { + return fmt.Errorf("resolve destination path %q: %w", destinationPath, err) + } + + relative, err := filepath.Rel(sourceAbs, destinationAbs) + if err != nil { + return fmt.Errorf("check destination relation: %w", err) + } + + if relative == "." || relative == "" { + return fmt.Errorf("cannot copy %q to itself", sourceDirectory) + } + + parentPrefix := ".." + string(os.PathSeparator) + if relative == ".." || strings.HasPrefix(relative, parentPrefix) { + return nil + } + + return fmt.Errorf("cannot copy directory %q into itself (%q)", sourceDirectory, destinationPath) +} + +func refersToSameFile(source string, destination string) (bool, error) { + sourceInfo, err := os.Stat(source) + if err != nil { + return false, fmt.Errorf("stat source %q: %w", source, err) + } + + destinationInfo, err := os.Stat(destination) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + return false, fmt.Errorf("stat destination %q: %w", destination, err) + } + + return os.SameFile(sourceInfo, destinationInfo), nil +} + +func executePlan(plan []copyOperation, opts options, progress *progressBar) error { + directoriesToPreserve := make([]copyOperation, 0) + + for _, op := range plan { + switch op.kind { + case operationCreateDirectory: + if err := os.MkdirAll(op.destination, op.mode.Perm()); err != nil { + return fmt.Errorf("create directory %q: %w", op.destination, err) + } + if opts.preserve { + directoriesToPreserve = append(directoriesToPreserve, op) + } + + case operationCopyFile: + if err := copyFile(op, opts, progress); err != nil { + return err + } + + default: + return fmt.Errorf("unsupported copy operation: %v", op.kind) + } + } + + if opts.preserve { + for i := len(directoriesToPreserve) - 1; i >= 0; i-- { + directory := directoriesToPreserve[i] + if err := setMetadata(directory.destination, directory.mode, directory.modTime); err != nil { + return err + } + } + } + + return nil +} + +func copyFile(op copyOperation, opts options, progress *progressBar) error { + if err := os.MkdirAll(filepath.Dir(op.destination), 0o755); err != nil { + return fmt.Errorf("create destination parent for %q: %w", op.destination, err) + } + + sourceFile, err := os.Open(op.source) + if err != nil { + return fmt.Errorf("open source file %q: %w", op.source, err) + } + defer sourceFile.Close() + + flags := os.O_CREATE | os.O_WRONLY | os.O_TRUNC + if !opts.force { + flags |= os.O_EXCL + } + + destinationFile, err := os.OpenFile(op.destination, flags, op.mode.Perm()) + if err != nil { + if errors.Is(err, os.ErrExist) { + return fmt.Errorf("destination file exists (use -f to overwrite): %q", op.destination) + } + return fmt.Errorf("open destination file %q: %w", op.destination, err) + } + + buffer := make([]byte, opts.bufferSize) + for { + readBytes, readErr := sourceFile.Read(buffer) + if readBytes > 0 { + writtenBytes, writeErr := destinationFile.Write(buffer[:readBytes]) + if writeErr != nil { + destinationFile.Close() + return fmt.Errorf("write destination file %q: %w", op.destination, writeErr) + } + if writtenBytes != readBytes { + destinationFile.Close() + return fmt.Errorf("write destination file %q: short write", op.destination) + } + progress.add(uint64(writtenBytes)) + } + + if errors.Is(readErr, io.EOF) { + break + } + if readErr != nil { + destinationFile.Close() + return fmt.Errorf("read source file %q: %w", op.source, readErr) + } + } + + if err := destinationFile.Close(); err != nil { + return fmt.Errorf("close destination file %q: %w", op.destination, err) + } + + if opts.preserve { + if err := setMetadata(op.destination, op.mode, op.modTime); err != nil { + return err + } + } + + return nil +} + +func setMetadata(path string, mode fs.FileMode, modTime time.Time) error { + if err := os.Chmod(path, mode.Perm()); err != nil { + return fmt.Errorf("set mode on %q: %w", path, err) + } + if err := os.Chtimes(path, modTime, modTime); err != nil { + return fmt.Errorf("set modification time on %q: %w", path, err) + } + return nil +} diff --git a/go/tools/zcp/internal/zcp/copy_test.go b/go/tools/zcp/internal/zcp/copy_test.go new file mode 100644 index 0000000..b3c945d --- /dev/null +++ b/go/tools/zcp/internal/zcp/copy_test.go @@ -0,0 +1,164 @@ +package zcp + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestCopyPlanAndExecution(t *testing.T) { + t.Parallel() + + t.Run("requires_recursive_for_directories", func(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + sourceDir := filepath.Join(tempDir, "source") + if err := os.MkdirAll(sourceDir, 0o755); err != nil { + t.Fatalf("mkdir source: %v", err) + } + if err := os.WriteFile(filepath.Join(sourceDir, "file.txt"), []byte("hello"), 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + + _, _, err := buildCopyPlan([]string{sourceDir}, filepath.Join(tempDir, "dest"), false) + if err == nil { + t.Fatalf("expected error for missing recursive flag") + } + if !strings.Contains(err.Error(), "use -r") { + t.Fatalf("expected recursive hint, got: %v", err) + } + }) + + t.Run("requires_directory_for_multiple_sources", func(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + first := filepath.Join(tempDir, "first.txt") + second := filepath.Join(tempDir, "second.txt") + if err := os.WriteFile(first, []byte("a"), 0o644); err != nil { + t.Fatalf("write first source: %v", err) + } + if err := os.WriteFile(second, []byte("b"), 0o644); err != nil { + t.Fatalf("write second source: %v", err) + } + + notDirectory := filepath.Join(tempDir, "dest.txt") + if err := os.WriteFile(notDirectory, []byte("existing"), 0o644); err != nil { + t.Fatalf("write destination file: %v", err) + } + + _, _, err := buildCopyPlan([]string{first, second}, notDirectory, false) + if err == nil { + t.Fatalf("expected error for multiple sources to non-directory destination") + } + }) + + t.Run("copies_nested_directory_with_preserve", func(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + sourceRoot := filepath.Join(tempDir, "source") + nestedDir := filepath.Join(sourceRoot, "nested") + if err := os.MkdirAll(nestedDir, 0o755); err != nil { + t.Fatalf("mkdir nested: %v", err) + } + + sourceFile := filepath.Join(nestedDir, "payload.txt") + expectedContents := "copy me" + if err := os.WriteFile(sourceFile, []byte(expectedContents), 0o640); err != nil { + t.Fatalf("write source file: %v", err) + } + + originalModTime := time.Now().Add(-2 * time.Hour).Truncate(time.Second) + if err := os.Chtimes(sourceFile, originalModTime, originalModTime); err != nil { + t.Fatalf("set source modtime: %v", err) + } + + destinationRoot := filepath.Join(tempDir, "destination") + plan, totalBytes, err := buildCopyPlan([]string{sourceRoot}, destinationRoot, true) + if err != nil { + t.Fatalf("build copy plan: %v", err) + } + if totalBytes == 0 { + t.Fatalf("expected non-zero total bytes") + } + + opts := options{ + recursive: true, + force: false, + preserve: true, + quiet: true, + bufferSize: 8, + } + progress := newProgressBar(totalBytes, false, io.Discard) + if err := executePlan(plan, opts, progress); err != nil { + t.Fatalf("execute plan: %v", err) + } + + destinationFile := filepath.Join(destinationRoot, "nested", "payload.txt") + actualBytes, err := os.ReadFile(destinationFile) + if err != nil { + t.Fatalf("read destination file: %v", err) + } + if string(actualBytes) != expectedContents { + t.Fatalf("unexpected destination content: %q", actualBytes) + } + + info, err := os.Stat(destinationFile) + if err != nil { + t.Fatalf("stat destination file: %v", err) + } + + if !info.ModTime().Equal(originalModTime) { + t.Fatalf("expected modtime %v, got %v", originalModTime, info.ModTime()) + } + }) + + t.Run("honors_force_flag", func(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + sourceFile := filepath.Join(tempDir, "source.txt") + destinationFile := filepath.Join(tempDir, "dest.txt") + + if err := os.WriteFile(sourceFile, []byte("new"), 0o644); err != nil { + t.Fatalf("write source file: %v", err) + } + if err := os.WriteFile(destinationFile, []byte("old"), 0o644); err != nil { + t.Fatalf("write destination file: %v", err) + } + + plan, totalBytes, err := buildCopyPlan([]string{sourceFile}, destinationFile, false) + if err != nil { + t.Fatalf("build copy plan: %v", err) + } + + noForce := options{ + force: false, + bufferSize: 4, + } + if err := executePlan(plan, noForce, newProgressBar(totalBytes, false, io.Discard)); err == nil { + t.Fatalf("expected overwrite error without -f") + } + + withForce := options{ + force: true, + bufferSize: 4, + } + if err := executePlan(plan, withForce, newProgressBar(totalBytes, false, io.Discard)); err != nil { + t.Fatalf("force overwrite failed: %v", err) + } + + actual, err := os.ReadFile(destinationFile) + if err != nil { + t.Fatalf("read destination file: %v", err) + } + if string(actual) != "new" { + t.Fatalf("expected destination content to be overwritten, got %q", string(actual)) + } + }) +} diff --git a/go/tools/zcp/internal/zcp/progress.go b/go/tools/zcp/internal/zcp/progress.go new file mode 100644 index 0000000..7a844a5 --- /dev/null +++ b/go/tools/zcp/internal/zcp/progress.go @@ -0,0 +1,234 @@ +package zcp + +import ( + "fmt" + "io" + "math" + "os" + "strings" + "sync" + "sync/atomic" + "time" +) + +type progressBar struct { + total uint64 + startedAt time.Time + completed atomic.Uint64 + enabled bool + writer io.Writer + terminal bool + stopCh chan struct{} + stopOnce sync.Once + waitGroup sync.WaitGroup + lastRender int +} + +func newProgressBar(total uint64, enabled bool, writer io.Writer) *progressBar { + bar := &progressBar{ + total: total, + enabled: enabled && total > 0, + writer: writer, + terminal: isTerminalWriter(writer), + } + return bar +} + +func (p *progressBar) start() { + if !p.enabled { + return + } + + p.stopCh = make(chan struct{}) + p.startedAt = time.Now() + p.waitGroup.Add(1) + + go func() { + defer p.waitGroup.Done() + + interval := 120 * time.Millisecond + if !p.terminal { + interval = time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.render(false) + case <-p.stopCh: + p.render(true) + return + } + } + }() +} + +func (p *progressBar) stop() { + if !p.enabled { + return + } + + p.stopOnce.Do(func() { + close(p.stopCh) + p.waitGroup.Wait() + }) +} + +func (p *progressBar) add(value uint64) { + if !p.enabled || value == 0 { + return + } + p.completed.Add(value) +} + +func (p *progressBar) render(final bool) { + done := p.completed.Load() + if done > p.total { + done = p.total + } + if final { + done = p.total + } + + elapsed := time.Since(p.startedAt) + if elapsed <= 0 { + elapsed = time.Millisecond + } + + bytesPerSecond := float64(done) / elapsed.Seconds() + line := formatProgressLine(done, p.total, bytesPerSecond) + + if p.terminal { + padding := "" + if len(line) < p.lastRender { + padding = strings.Repeat(" ", p.lastRender-len(line)) + } + fmt.Fprintf(p.writer, "\r%s%s", line, padding) + p.lastRender = len(line) + if final { + fmt.Fprint(p.writer, "\n") + } + return + } + + fmt.Fprintln(p.writer, line) +} + +func formatProgressLine(done uint64, total uint64, bytesPerSecond float64) string { + if total == 0 { + return "[==============================] 100.00% 0 B/0 B 0 B/s ETA 00:00" + } + + percentage := float64(done) / float64(total) * 100 + if percentage > 100 { + percentage = 100 + } + + eta := "00:00" + if done < total && bytesPerSecond > 0 { + remainingSeconds := float64(total-done) / bytesPerSecond + eta = formatDuration(time.Duration(remainingSeconds * float64(time.Second))) + } + + return fmt.Sprintf( + "[%s] %6.2f%% %s/%s %s/s ETA %s", + buildBar(percentage, 30), + percentage, + humanizeBytes(done), + humanizeBytes(total), + humanizeRate(bytesPerSecond), + eta, + ) +} + +func buildBar(percentage float64, width int) string { + if width <= 0 { + return "" + } + + filled := int(math.Round((percentage / 100) * float64(width))) + if filled < 0 { + filled = 0 + } + if filled > width { + filled = width + } + + switch { + case filled <= 0: + return strings.Repeat(" ", width) + case filled >= width: + return strings.Repeat("=", width) + default: + return strings.Repeat("=", filled-1) + ">" + strings.Repeat(" ", width-filled) + } +} + +func humanizeBytes(value uint64) string { + const unit = 1024.0 + if value < 1024 { + return fmt.Sprintf("%d B", value) + } + + size := float64(value) + units := []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB"} + unitIndex := 0 + for size >= unit && unitIndex < len(units)-1 { + size /= unit + unitIndex++ + } + + return fmt.Sprintf("%.1f %s", size, units[unitIndex]) +} + +func humanizeRate(bytesPerSecond float64) string { + if bytesPerSecond <= 0 { + return "0 B" + } + + units := []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB"} + value := bytesPerSecond + unitIndex := 0 + for value >= 1024 && unitIndex < len(units)-1 { + value /= 1024 + unitIndex++ + } + + if unitIndex == 0 { + return fmt.Sprintf("%.0f %s", value, units[unitIndex]) + } + return fmt.Sprintf("%.1f %s", value, units[unitIndex]) +} + +func formatDuration(duration time.Duration) string { + if duration < 0 { + duration = 0 + } + + totalSeconds := int64(duration.Round(time.Second).Seconds()) + hours := totalSeconds / 3600 + minutes := (totalSeconds % 3600) / 60 + seconds := totalSeconds % 60 + + if hours > 0 { + return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + } + return fmt.Sprintf("%02d:%02d", minutes, seconds) +} + +func isTerminalWriter(writer io.Writer) bool { + file, ok := writer.(*os.File) + if !ok { + return false + } + + info, err := file.Stat() + if err != nil { + return false + } + + return info.Mode()&os.ModeCharDevice != 0 +} diff --git a/go/tools/zcp/internal/zcp/progress_test.go b/go/tools/zcp/internal/zcp/progress_test.go new file mode 100644 index 0000000..fb1fa41 --- /dev/null +++ b/go/tools/zcp/internal/zcp/progress_test.go @@ -0,0 +1,201 @@ +package zcp + +import ( + "bytes" + "strings" + "testing" + "time" +) + +func TestProgressFormattingHelpers(t *testing.T) { + t.Parallel() + + t.Run("build_bar", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + percentage float64 + width int + want string + }{ + { + name: "zero_percent", + percentage: 0, + width: 10, + want: " ", + }, + { + name: "half_percent", + percentage: 50, + width: 10, + want: "====> ", + }, + { + name: "full_percent", + percentage: 100, + width: 10, + want: "==========", + }, + { + name: "clamps_above_hundred", + percentage: 150, + width: 10, + want: "==========", + }, + { + name: "zero_width", + percentage: 50, + width: 0, + want: "", + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := buildBar(testCase.percentage, testCase.width) + if got != testCase.want { + t.Fatalf("buildBar(%v, %d) = %q, want %q", testCase.percentage, testCase.width, got, testCase.want) + } + }) + } + }) + + t.Run("humanize_bytes", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + value uint64 + want string + }{ + {name: "bytes", value: 512, want: "512 B"}, + {name: "kib", value: 1024, want: "1.0 KiB"}, + {name: "fractional_kib", value: 1536, want: "1.5 KiB"}, + {name: "mib", value: 1024 * 1024, want: "1.0 MiB"}, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := humanizeBytes(testCase.value) + if got != testCase.want { + t.Fatalf("humanizeBytes(%d) = %q, want %q", testCase.value, got, testCase.want) + } + }) + } + }) + + t.Run("humanize_rate", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + value float64 + want string + }{ + {name: "zero", value: 0, want: "0 B"}, + {name: "bytes", value: 256, want: "256 B"}, + {name: "kib", value: 1024, want: "1.0 KiB"}, + {name: "fractional_kib", value: 1536, want: "1.5 KiB"}, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := humanizeRate(testCase.value) + if got != testCase.want { + t.Fatalf("humanizeRate(%f) = %q, want %q", testCase.value, got, testCase.want) + } + }) + } + }) + + t.Run("format_duration", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + value time.Duration + want string + }{ + {name: "seconds", value: 59 * time.Second, want: "00:59"}, + {name: "minutes", value: 2*time.Minute + 5*time.Second, want: "02:05"}, + {name: "hours", value: time.Hour + 2*time.Minute + 3*time.Second, want: "01:02:03"}, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + got := formatDuration(testCase.value) + if got != testCase.want { + t.Fatalf("formatDuration(%v) = %q, want %q", testCase.value, got, testCase.want) + } + }) + } + }) + + t.Run("format_progress_line", func(t *testing.T) { + t.Parallel() + + line := formatProgressLine(512, 1024, 256) + expectedFragments := []string{ + "50.00%", + "512 B/1.0 KiB", + "256 B/s", + "ETA 00:02", + } + for _, fragment := range expectedFragments { + if !strings.Contains(line, fragment) { + t.Fatalf("expected fragment %q in line %q", fragment, line) + } + } + + zeroTotalLine := formatProgressLine(0, 0, 0) + if !strings.Contains(zeroTotalLine, "100.00% 0 B/0 B") { + t.Fatalf("unexpected zero-total line: %q", zeroTotalLine) + } + }) +} + +func TestProgressBarLifecycle(t *testing.T) { + t.Parallel() + + t.Run("writes_final_line_on_stop", func(t *testing.T) { + t.Parallel() + + var output bytes.Buffer + bar := newProgressBar(1024, true, &output) + bar.start() + bar.add(1024) + bar.stop() + + got := output.String() + if !strings.Contains(got, "100.00%") { + t.Fatalf("expected final progress output, got %q", got) + } + }) + + t.Run("disabled_for_zero_total", func(t *testing.T) { + t.Parallel() + + var output bytes.Buffer + bar := newProgressBar(0, true, &output) + bar.start() + bar.add(100) + bar.stop() + + if output.Len() != 0 { + t.Fatalf("expected no output for zero-total progress, got %q", output.String()) + } + }) +}