diff --git a/internal/gitvolume/context_test.go b/internal/gitvolume/context_test.go index eab4dd6..859793a 100644 --- a/internal/gitvolume/context_test.go +++ b/internal/gitvolume/context_test.go @@ -511,7 +511,9 @@ func TestCheckStatus(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tc.setup(t, tc.vol) - defer os.RemoveAll(tc.vol.TargetPath) + t.Cleanup(func() { + require.NoError(t, os.RemoveAll(tc.vol.TargetPath)) + }) status := tc.vol.CheckStatus() assert.Equal(t, tc.wantStatus, status.Status) diff --git a/internal/gitvolume/edit_test.go b/internal/gitvolume/edit_test.go index b25c74b..cc38b34 100644 --- a/internal/gitvolume/edit_test.go +++ b/internal/gitvolume/edit_test.go @@ -33,8 +33,10 @@ func TestGitVolume_GlobalEdit(t *testing.T) { // Mock EDITOR to a script that modifies the file // We use printf to avoid portability issues with echo -n originalEditor := os.Getenv("EDITOR") - defer os.Setenv("EDITOR", originalEditor) - os.Setenv("EDITOR", "sh -c 'printf \" - edited\" >> \"$1\"' --") + require.NoError(t, os.Setenv("EDITOR", "sh -c 'printf \" - edited\" >> \"$1\"' --")) + t.Cleanup(func() { + require.NoError(t, os.Setenv("EDITOR", originalEditor)) + }) err = gv.GlobalEdit("config.txt") assert.NoError(t, err) diff --git a/internal/gitvolume/git_repro_test.go b/internal/gitvolume/git_repro_test.go index 649f203..b5991b3 100644 --- a/internal/gitvolume/git_repro_test.go +++ b/internal/gitvolume/git_repro_test.go @@ -14,7 +14,9 @@ func TestFindCommonDir_Repro(t *testing.T) { // Create a temporary directory for our test environment tmpDir, err := os.MkdirTemp("", "git-volume-repro-*") require.NoError(t, err) - defer os.RemoveAll(tmpDir) + t.Cleanup(func() { + require.NoError(t, os.RemoveAll(tmpDir)) + }) // 1. Test Case: Bare Repository + Worktree t.Run("Bare Repository with Worktree", func(t *testing.T) { diff --git a/internal/gitvolume/init.go b/internal/gitvolume/init.go index 3ddc593..9348813 100644 --- a/internal/gitvolume/init.go +++ b/internal/gitvolume/init.go @@ -27,7 +27,7 @@ func (g *GitVolume) beforeInit(state *initState) error { if err != nil { return fmt.Errorf("failed to get current directory: %w", err) } - gitRoot, err := FindWorktreeRoot(cwd) + gitRoot, err := findInitRoot(cwd) if err != nil { return fmt.Errorf("failed to find git repository root: %w", err) } @@ -35,6 +35,23 @@ func (g *GitVolume) beforeInit(state *initState) error { return nil } +// findInitRoot returns the directory where git-volume.yaml should be created. +// In a normal repository or worktree, that is the current worktree root. +// In a bare repository, that is the bare repository root itself. +func findInitRoot(startDir string) (string, error) { + worktreeRoot, err := FindWorktreeRoot(startDir) + if err == nil { + return worktreeRoot, nil + } + + isBare, bareErr := isBareRepository(startDir) + if bareErr == nil && isBare { + return findCommonDir(startDir) + } + + return "", err +} + func (g *GitVolume) init(state *initState) error { if err := os.MkdirAll(g.ctx.GlobalDir, DefaultDirPerm); err != nil { return fmt.Errorf("failed to create global directory %s: %w", g.ctx.GlobalDir, err) diff --git a/internal/gitvolume/init_test.go b/internal/gitvolume/init_test.go index c669eb1..ea4bd1e 100644 --- a/internal/gitvolume/init_test.go +++ b/internal/gitvolume/init_test.go @@ -21,7 +21,7 @@ func TestInit(t *testing.T) { // Change CWD to repo wd, err := os.Getwd() require.NoError(t, err) - defer func() { _ = os.Chdir(wd) }() // Restore CWD + t.Cleanup(func() { require.NoError(t, os.Chdir(wd)) }) // Restore CWD require.NoError(t, os.Chdir(repoDir)) // Setup GitVolume (GlobalDir will be in tmp) @@ -60,7 +60,7 @@ func TestInit_OutsideGit(t *testing.T) { // Change CWD wd, err := os.Getwd() require.NoError(t, err) - defer func() { _ = os.Chdir(wd) }() + t.Cleanup(func() { require.NoError(t, os.Chdir(wd)) }) require.NoError(t, os.Chdir(tmpDir)) gv := createTestGitVolume(tmpDir, tmpDir, filepath.Join(tmpDir, "global"), nil) @@ -72,6 +72,32 @@ func TestInit_OutsideGit(t *testing.T) { assert.Contains(t, err.Error(), "failed to find git repository root") } +func TestInit_BareRepository(t *testing.T) { + tmpDir := t.TempDir() + bareDir := filepath.Join(tmpDir, "repo.git") + + cmd := exec.Command("git", "init", "--bare", bareDir) + require.NoError(t, cmd.Run()) + + wd, err := os.Getwd() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, os.Chdir(wd)) }) + require.NoError(t, os.Chdir(bareDir)) + + globalDir := filepath.Join(tmpDir, "global") + gv := createTestGitVolume(bareDir, bareDir, globalDir, nil) + + err = gv.Init() + require.NoError(t, err) + + assert.DirExists(t, globalDir) + assert.FileExists(t, filepath.Join(bareDir, "git-volume.yaml")) + + content, err := os.ReadFile(filepath.Join(bareDir, "git-volume.yaml")) + require.NoError(t, err) + assert.Contains(t, string(content), "volumes:") +} + func TestInit_NonQuiet(t *testing.T) { tmpDir := t.TempDir() repoDir := filepath.Join(tmpDir, "repo") @@ -81,7 +107,7 @@ func TestInit_NonQuiet(t *testing.T) { wd, err := os.Getwd() require.NoError(t, err) - defer func() { _ = os.Chdir(wd) }() + t.Cleanup(func() { require.NoError(t, os.Chdir(wd)) }) require.NoError(t, os.Chdir(repoDir)) globalDir := filepath.Join(tmpDir, "global") @@ -104,7 +130,7 @@ func TestInit_NonQuiet_Error(t *testing.T) { wd, err := os.Getwd() require.NoError(t, err) - defer func() { _ = os.Chdir(wd) }() + t.Cleanup(func() { require.NoError(t, os.Chdir(wd)) }) require.NoError(t, os.Chdir(tmpDir)) gv := createTestGitVolume(tmpDir, tmpDir, filepath.Join(tmpDir, "global"), nil) diff --git a/internal/gitvolume/remove_test.go b/internal/gitvolume/remove_test.go index b8776a1..18a2391 100644 --- a/internal/gitvolume/remove_test.go +++ b/internal/gitvolume/remove_test.go @@ -61,7 +61,9 @@ func TestGlobalRemove(t *testing.T) { func TestGlobalRemove_NotInitialized(t *testing.T) { tmpDir, err := os.MkdirTemp("", "git-volume-test") require.NoError(t, err) - defer os.RemoveAll(tmpDir) + t.Cleanup(func() { + require.NoError(t, os.RemoveAll(tmpDir)) + }) // Point to a non-existent directory globalDir := filepath.Join(tmpDir, "non_existent_global")