From c7615e40754ce7967871580b5255aff94e09ad69 Mon Sep 17 00:00:00 2001 From: Doug Campos Date: Sat, 20 Dec 2025 18:07:48 -0500 Subject: [PATCH] fix(setup): preserve symlinks in atomicWriteFile Add ResolveForWrite helper that resolves symlinks before writing, so atomic writes go to the symlink target instead of replacing the symlink itself. --- cmd/bd/setup/utils.go | 12 +++++++-- cmd/bd/setup/utils_test.go | 39 ++++++++++++++++++++++++++++ internal/utils/path.go | 17 +++++++++++++ internal/utils/path_test.go | 51 +++++++++++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 2 deletions(-) diff --git a/cmd/bd/setup/utils.go b/cmd/bd/setup/utils.go index 788a051f..4dbca094 100644 --- a/cmd/bd/setup/utils.go +++ b/cmd/bd/setup/utils.go @@ -4,12 +4,20 @@ import ( "fmt" "os" "path/filepath" + + "github.com/steveyegge/beads/internal/utils" ) // atomicWriteFile writes data to a file atomically using a unique temporary file. // This prevents race conditions when multiple processes write to the same file. +// If path is a symlink, writes to the resolved target (preserving the symlink). func atomicWriteFile(path string, data []byte) error { - dir := filepath.Dir(path) + targetPath, err := utils.ResolveForWrite(path) + if err != nil { + return fmt.Errorf("resolve path: %w", err) + } + + dir := filepath.Dir(targetPath) // Create unique temp file in same directory tmpFile, err := os.CreateTemp(dir, ".*.tmp") @@ -38,7 +46,7 @@ func atomicWriteFile(path string, data []byte) error { } // Atomic rename - if err := os.Rename(tmpPath, path); err != nil { + if err := os.Rename(tmpPath, targetPath); err != nil { _ = os.Remove(tmpPath) // Best effort cleanup return fmt.Errorf("rename temp file: %w", err) } diff --git a/cmd/bd/setup/utils_test.go b/cmd/bd/setup/utils_test.go index dc68740a..4c5eeb03 100644 --- a/cmd/bd/setup/utils_test.go +++ b/cmd/bd/setup/utils_test.go @@ -67,6 +67,45 @@ func TestAtomicWriteFile(t *testing.T) { } } +func TestAtomicWriteFile_PreservesSymlink(t *testing.T) { + tmpDir := t.TempDir() + + // Create target file + target := filepath.Join(tmpDir, "target.txt") + if err := os.WriteFile(target, []byte("original"), 0644); err != nil { + t.Fatal(err) + } + + // Create symlink + link := filepath.Join(tmpDir, "link.txt") + if err := os.Symlink(target, link); err != nil { + t.Fatal(err) + } + + // Write via symlink + if err := atomicWriteFile(link, []byte("updated")); err != nil { + t.Fatalf("atomicWriteFile failed: %v", err) + } + + // Verify symlink still exists + info, err := os.Lstat(link) + if err != nil { + t.Fatalf("failed to lstat link: %v", err) + } + if info.Mode()&os.ModeSymlink == 0 { + t.Error("symlink was replaced with regular file") + } + + // Verify target was updated + data, err := os.ReadFile(target) + if err != nil { + t.Fatalf("failed to read target: %v", err) + } + if string(data) != "updated" { + t.Errorf("target content = %q, want %q", string(data), "updated") + } +} + func TestDirExists(t *testing.T) { tmpDir := t.TempDir() diff --git a/internal/utils/path.go b/internal/utils/path.go index d03c63e0..58590bbf 100644 --- a/internal/utils/path.go +++ b/internal/utils/path.go @@ -69,6 +69,23 @@ func FindMoleculesJSONLInDir(dbDir string) string { return "" } +// ResolveForWrite returns the path to write to, resolving symlinks. +// If path is a symlink, returns the resolved target path. +// If path doesn't exist, returns path unchanged (new file). +func ResolveForWrite(path string) (string, error) { + info, err := os.Lstat(path) + if err != nil { + if os.IsNotExist(err) { + return path, nil + } + return "", err + } + if info.Mode()&os.ModeSymlink != 0 { + return filepath.EvalSymlinks(path) + } + return path, nil +} + // CanonicalizePath converts a path to its canonical form by: // 1. Converting to absolute path // 2. Resolving symlinks diff --git a/internal/utils/path_test.go b/internal/utils/path_test.go index c70b64b3..c0530c3c 100644 --- a/internal/utils/path_test.go +++ b/internal/utils/path_test.go @@ -179,3 +179,54 @@ func TestCanonicalizePathSymlink(t *testing.T) { } } } + +func TestResolveForWrite(t *testing.T) { + t.Run("regular file", func(t *testing.T) { + tmpDir := t.TempDir() + file := filepath.Join(tmpDir, "regular.txt") + if err := os.WriteFile(file, []byte("test"), 0644); err != nil { + t.Fatal(err) + } + + got, err := ResolveForWrite(file) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != file { + t.Errorf("got %q, want %q", got, file) + } + }) + + t.Run("symlink", func(t *testing.T) { + tmpDir := t.TempDir() + target := filepath.Join(tmpDir, "target.txt") + if err := os.WriteFile(target, []byte("test"), 0644); err != nil { + t.Fatal(err) + } + link := filepath.Join(tmpDir, "link.txt") + if err := os.Symlink(target, link); err != nil { + t.Fatal(err) + } + + got, err := ResolveForWrite(link) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != target { + t.Errorf("got %q, want %q", got, target) + } + }) + + t.Run("non-existent", func(t *testing.T) { + tmpDir := t.TempDir() + newFile := filepath.Join(tmpDir, "new.txt") + + got, err := ResolveForWrite(newFile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != newFile { + t.Errorf("got %q, want %q", got, newFile) + } + }) +}