diff --git a/cmd/bd/daemon_autostart.go b/cmd/bd/daemon_autostart.go index d858c7d8..4fdd0cc0 100644 --- a/cmd/bd/daemon_autostart.go +++ b/cmd/bd/daemon_autostart.go @@ -31,6 +31,19 @@ var ( daemonStartFailures int ) +var ( + executableFn = os.Executable + execCommandFn = exec.Command + openFileFn = os.OpenFile + findProcessFn = os.FindProcess + removeFileFn = os.Remove + configureDaemonProcessFn = configureDaemonProcess + waitForSocketReadinessFn = waitForSocketReadiness + startDaemonProcessFn = startDaemonProcess + isDaemonRunningFn = isDaemonRunning + sendStopSignalFn = sendStopSignal +) + // shouldAutoStartDaemon checks if daemon auto-start is enabled func shouldAutoStartDaemon() bool { // Check BEADS_NO_DAEMON first (escape hatch for single-user workflows) @@ -53,7 +66,6 @@ func shouldAutoStartDaemon() bool { return config.GetBool("auto-start-daemon") // Defaults to true } - // restartDaemonForVersionMismatch stops the old daemon and starts a new one // Returns true if restart was successful func restartDaemonForVersionMismatch() bool { @@ -67,17 +79,17 @@ func restartDaemonForVersionMismatch() bool { // Check if daemon is running and stop it forcedKill := false - if isRunning, pid := isDaemonRunning(pidFile); isRunning { + if isRunning, pid := isDaemonRunningFn(pidFile); isRunning { debug.Logf("stopping old daemon (PID %d)", pid) - process, err := os.FindProcess(pid) + process, err := findProcessFn(pid) if err != nil { debug.Logf("failed to find process: %v", err) return false } // Send stop signal - if err := sendStopSignal(process); err != nil { + if err := sendStopSignalFn(process); err != nil { debug.Logf("failed to signal daemon: %v", err) return false } @@ -85,14 +97,14 @@ func restartDaemonForVersionMismatch() bool { // Wait for daemon to stop, then force kill for i := 0; i < daemonShutdownAttempts; i++ { time.Sleep(daemonShutdownPollInterval) - if isRunning, _ := isDaemonRunning(pidFile); !isRunning { + if isRunning, _ := isDaemonRunningFn(pidFile); !isRunning { debug.Logf("old daemon stopped successfully") break } } // Force kill if still running - if isRunning, _ := isDaemonRunning(pidFile); isRunning { + if isRunning, _ := isDaemonRunningFn(pidFile); isRunning { debug.Logf("force killing old daemon") _ = process.Kill() forcedKill = true @@ -101,19 +113,19 @@ func restartDaemonForVersionMismatch() bool { // Clean up stale socket and PID file after force kill or if not running if forcedKill || !isDaemonRunningQuiet(pidFile) { - _ = os.Remove(socketPath) - _ = os.Remove(pidFile) + _ = removeFileFn(socketPath) + _ = removeFileFn(pidFile) } // Start new daemon with current binary version - exe, err := os.Executable() + exe, err := executableFn() if err != nil { debug.Logf("failed to get executable path: %v", err) return false } args := []string{"daemon", "--start"} - cmd := exec.Command(exe, args...) + cmd := execCommandFn(exe, args...) cmd.Env = append(os.Environ(), "BD_DAEMON_FOREGROUND=1") // Set working directory to database directory so daemon finds correct DB @@ -121,9 +133,9 @@ func restartDaemonForVersionMismatch() bool { cmd.Dir = filepath.Dir(dbPath) } - configureDaemonProcess(cmd) + configureDaemonProcessFn(cmd) - devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + devNull, err := openFileFn(os.DevNull, os.O_RDWR, 0) if err == nil { cmd.Stdin = devNull cmd.Stdout = devNull @@ -140,7 +152,7 @@ func restartDaemonForVersionMismatch() bool { go func() { _ = cmd.Wait() }() // Wait for daemon to be ready using shared helper - if waitForSocketReadiness(socketPath, 5*time.Second) { + if waitForSocketReadinessFn(socketPath, 5*time.Second) { debug.Logf("new daemon started successfully") return true } @@ -153,7 +165,7 @@ func restartDaemonForVersionMismatch() bool { // isDaemonRunningQuiet checks if daemon is running without output func isDaemonRunningQuiet(pidFile string) bool { - isRunning, _ := isDaemonRunning(pidFile) + isRunning, _ := isDaemonRunningFn(pidFile) return isRunning } @@ -185,7 +197,7 @@ func tryAutoStartDaemon(socketPath string) bool { } socketPath = determineSocketPath(socketPath) - return startDaemonProcess(socketPath) + return startDaemonProcessFn(socketPath) } func debugLog(msg string, args ...interface{}) { @@ -269,21 +281,21 @@ func determineSocketPath(socketPath string) string { } func startDaemonProcess(socketPath string) bool { - binPath, err := os.Executable() + binPath, err := executableFn() if err != nil { binPath = os.Args[0] } args := []string{"daemon", "--start"} - cmd := exec.Command(binPath, args...) + cmd := execCommandFn(binPath, args...) setupDaemonIO(cmd) if dbPath != "" { cmd.Dir = filepath.Dir(dbPath) } - configureDaemonProcess(cmd) + configureDaemonProcessFn(cmd) if err := cmd.Start(); err != nil { recordDaemonStartFailure() debugLog("failed to start daemon: %v", err) @@ -292,7 +304,7 @@ func startDaemonProcess(socketPath string) bool { go func() { _ = cmd.Wait() }() - if waitForSocketReadiness(socketPath, 5*time.Second) { + if waitForSocketReadinessFn(socketPath, 5*time.Second) { recordDaemonStartSuccess() return true } @@ -306,7 +318,7 @@ func startDaemonProcess(socketPath string) bool { } func setupDaemonIO(cmd *exec.Cmd) { - devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + devNull, err := openFileFn(os.DevNull, os.O_RDWR, 0) if err == nil { cmd.Stdout = devNull cmd.Stderr = devNull diff --git a/cmd/bd/daemon_autostart_unit_test.go b/cmd/bd/daemon_autostart_unit_test.go new file mode 100644 index 00000000..625cedf6 --- /dev/null +++ b/cmd/bd/daemon_autostart_unit_test.go @@ -0,0 +1,331 @@ +package main + +import ( + "bytes" + "context" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/steveyegge/beads/internal/config" +) + +func tempSockDir(t *testing.T) string { + t.Helper() + + base := "/tmp" + if runtime.GOOS == windowsOS { + base = os.TempDir() + } else if _, err := os.Stat(base); err != nil { + base = os.TempDir() + } + + d, err := os.MkdirTemp(base, "bd-sock-*") + if err != nil { + t.Fatalf("MkdirTemp: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(d) }) + return d +} + +func startTestRPCServer(t *testing.T) (socketPath string, cleanup func()) { + t.Helper() + + tmpDir := tempSockDir(t) + beadsDir := filepath.Join(tmpDir, ".beads") + if err := os.MkdirAll(beadsDir, 0o750); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + socketPath = filepath.Join(beadsDir, "bd.sock") + db := filepath.Join(beadsDir, "test.db") + store := newTestStore(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + log := newTestLogger() + + server, _, err := startRPCServer(ctx, socketPath, store, tmpDir, db, log) + if err != nil { + cancel() + t.Fatalf("startRPCServer: %v", err) + } + + cleanup = func() { + cancel() + if server != nil { + _ = server.Stop() + } + } + + return socketPath, cleanup +} + +func captureStderr(t *testing.T, fn func()) string { + t.Helper() + + old := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe: %v", err) + } + os.Stderr = w + + var buf bytes.Buffer + done := make(chan struct{}) + go func() { + _, _ = io.Copy(&buf, r) + close(done) + }() + + fn() + _ = w.Close() + os.Stderr = old + <-done + _ = r.Close() + + return buf.String() +} + +func TestDaemonAutostart_AcquireStartLock_CreatesAndCleansStale(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "bd.sock.startlock") + pid, err := readPIDFromFile(lockPath) + if err == nil || pid != 0 { + // lock doesn't exist yet; expect read to fail. + } + + if !acquireStartLock(lockPath, filepath.Join(tmpDir, "bd.sock")) { + t.Fatalf("expected acquireStartLock to succeed") + } + got, err := readPIDFromFile(lockPath) + if err != nil { + t.Fatalf("readPIDFromFile: %v", err) + } + if got != os.Getpid() { + t.Fatalf("expected lock PID %d, got %d", os.Getpid(), got) + } + + // Stale lock: dead/unreadable PID should be removed and recreated. + if err := os.WriteFile(lockPath, []byte("0\n"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + if !acquireStartLock(lockPath, filepath.Join(tmpDir, "bd.sock")) { + t.Fatalf("expected acquireStartLock to succeed on stale lock") + } + got, err = readPIDFromFile(lockPath) + if err != nil { + t.Fatalf("readPIDFromFile: %v", err) + } + if got != os.Getpid() { + t.Fatalf("expected recreated lock PID %d, got %d", os.Getpid(), got) + } +} + +func TestDaemonAutostart_SocketHealthAndReadiness(t *testing.T) { + socketPath, cleanup := startTestRPCServer(t) + defer cleanup() + + if !canDialSocket(socketPath, 500*time.Millisecond) { + t.Fatalf("expected canDialSocket to succeed") + } + if !isDaemonHealthy(socketPath) { + t.Fatalf("expected isDaemonHealthy to succeed") + } + if !waitForSocketReadiness(socketPath, 500*time.Millisecond) { + t.Fatalf("expected waitForSocketReadiness to succeed") + } + + missing := filepath.Join(tempSockDir(t), "missing.sock") + if canDialSocket(missing, 50*time.Millisecond) { + t.Fatalf("expected canDialSocket to fail") + } + if waitForSocketReadiness(missing, 200*time.Millisecond) { + t.Fatalf("expected waitForSocketReadiness to time out") + } +} + +func TestDaemonAutostart_HandleExistingSocket(t *testing.T) { + socketPath, cleanup := startTestRPCServer(t) + defer cleanup() + + if !handleExistingSocket(socketPath) { + t.Fatalf("expected handleExistingSocket true for running daemon") + } +} + +func TestDaemonAutostart_HandleExistingSocket_StaleCleansUp(t *testing.T) { + tmpDir := t.TempDir() + beadsDir := filepath.Join(tmpDir, ".beads") + if err := os.MkdirAll(beadsDir, 0o750); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + socketPath := filepath.Join(beadsDir, "bd.sock") + pidFile := filepath.Join(beadsDir, "daemon.pid") + if err := os.WriteFile(socketPath, []byte("not-a-socket"), 0o600); err != nil { + t.Fatalf("WriteFile socket: %v", err) + } + if err := os.WriteFile(pidFile, []byte("0\n"), 0o600); err != nil { + t.Fatalf("WriteFile pid: %v", err) + } + + if handleExistingSocket(socketPath) { + t.Fatalf("expected false for stale socket") + } + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Fatalf("expected socket removed") + } + if _, err := os.Stat(pidFile); !os.IsNotExist(err) { + t.Fatalf("expected pidfile removed") + } +} + +func TestDaemonAutostart_TryAutoStartDaemon_EarlyExits(t *testing.T) { + oldFailures := daemonStartFailures + oldLast := lastDaemonStartAttempt + defer func() { + daemonStartFailures = oldFailures + lastDaemonStartAttempt = oldLast + }() + + daemonStartFailures = 1 + lastDaemonStartAttempt = time.Now() + if tryAutoStartDaemon(filepath.Join(t.TempDir(), "bd.sock")) { + t.Fatalf("expected tryAutoStartDaemon to skip due to backoff") + } + + daemonStartFailures = 0 + lastDaemonStartAttempt = time.Time{} + socketPath, cleanup := startTestRPCServer(t) + defer cleanup() + if !tryAutoStartDaemon(socketPath) { + t.Fatalf("expected tryAutoStartDaemon true when daemon already healthy") + } +} + +func TestDaemonAutostart_MiscHelpers(t *testing.T) { + if determineSocketPath("/x") != "/x" { + t.Fatalf("determineSocketPath should be identity") + } + + if err := config.Initialize(); err != nil { + t.Fatalf("config.Initialize: %v", err) + } + old := config.GetDuration("flush-debounce") + defer config.Set("flush-debounce", old) + + config.Set("flush-debounce", 0) + if got := getDebounceDuration(); got != 5*time.Second { + t.Fatalf("expected default debounce 5s, got %v", got) + } + config.Set("flush-debounce", 2*time.Second) + if got := getDebounceDuration(); got != 2*time.Second { + t.Fatalf("expected debounce 2s, got %v", got) + } +} + +func TestDaemonAutostart_EmitVerboseWarning(t *testing.T) { + old := daemonStatus + defer func() { daemonStatus = old }() + + daemonStatus.SocketPath = "/tmp/bd.sock" + for _, tt := range []struct { + reason string + shouldWrite bool + }{ + {FallbackConnectFailed, true}, + {FallbackHealthFailed, true}, + {FallbackAutoStartDisabled, true}, + {FallbackAutoStartFailed, true}, + {FallbackDaemonUnsupported, true}, + {FallbackWorktreeSafety, false}, + {FallbackFlagNoDaemon, false}, + } { + t.Run(tt.reason, func(t *testing.T) { + daemonStatus.FallbackReason = tt.reason + out := captureStderr(t, emitVerboseWarning) + if tt.shouldWrite && out == "" { + t.Fatalf("expected output") + } + if !tt.shouldWrite && out != "" { + t.Fatalf("expected no output, got %q", out) + } + }) + } +} + +func TestDaemonAutostart_StartDaemonProcess_Stubbed(t *testing.T) { + oldExec := execCommandFn + oldWait := waitForSocketReadinessFn + oldCfg := configureDaemonProcessFn + defer func() { + execCommandFn = oldExec + waitForSocketReadinessFn = oldWait + configureDaemonProcessFn = oldCfg + }() + + execCommandFn = func(string, ...string) *exec.Cmd { + return exec.Command(os.Args[0], "-test.run=^$") + } + waitForSocketReadinessFn = func(string, time.Duration) bool { return true } + configureDaemonProcessFn = func(*exec.Cmd) {} + + if !startDaemonProcess(filepath.Join(t.TempDir(), "bd.sock")) { + t.Fatalf("expected startDaemonProcess true when readiness stubbed") + } +} + +func TestDaemonAutostart_RestartDaemonForVersionMismatch_Stubbed(t *testing.T) { + oldExec := execCommandFn + oldWait := waitForSocketReadinessFn + oldRun := isDaemonRunningFn + oldCfg := configureDaemonProcessFn + defer func() { + execCommandFn = oldExec + waitForSocketReadinessFn = oldWait + isDaemonRunningFn = oldRun + configureDaemonProcessFn = oldCfg + }() + + tmpDir := t.TempDir() + beadsDir := filepath.Join(tmpDir, ".beads") + if err := os.MkdirAll(beadsDir, 0o750); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + oldDB := dbPath + defer func() { dbPath = oldDB }() + dbPath = filepath.Join(beadsDir, "test.db") + + pidFile, err := getPIDFilePath() + if err != nil { + t.Fatalf("getPIDFilePath: %v", err) + } + sock := getSocketPath() + if err := os.WriteFile(pidFile, []byte("999999\n"), 0o600); err != nil { + t.Fatalf("WriteFile pid: %v", err) + } + if err := os.WriteFile(sock, []byte("stale"), 0o600); err != nil { + t.Fatalf("WriteFile sock: %v", err) + } + + execCommandFn = func(string, ...string) *exec.Cmd { + return exec.Command(os.Args[0], "-test.run=^$") + } + waitForSocketReadinessFn = func(string, time.Duration) bool { return true } + isDaemonRunningFn = func(string) (bool, int) { return false, 0 } + configureDaemonProcessFn = func(*exec.Cmd) {} + + if !restartDaemonForVersionMismatch() { + t.Fatalf("expected restartDaemonForVersionMismatch true when stubbed") + } + if _, err := os.Stat(pidFile); !os.IsNotExist(err) { + t.Fatalf("expected pidfile removed") + } + if _, err := os.Stat(sock); !os.IsNotExist(err) { + t.Fatalf("expected socket removed") + } +}