diff --git a/internal/checkpoint/checkpoint_test.go b/internal/checkpoint/checkpoint_test.go new file mode 100644 index 00000000..cf49d486 --- /dev/null +++ b/internal/checkpoint/checkpoint_test.go @@ -0,0 +1,398 @@ +package checkpoint + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" +) + +func TestPath(t *testing.T) { + dir := "/some/polecat/dir" + got := Path(dir) + want := filepath.Join(dir, Filename) + if got != want { + t.Errorf("Path(%q) = %q, want %q", dir, got, want) + } +} + +func TestReadWrite(t *testing.T) { + // Create temp directory + tmpDir := t.TempDir() + + // Test reading non-existent checkpoint returns nil, nil + cp, err := Read(tmpDir) + if err != nil { + t.Fatalf("Read non-existent: unexpected error: %v", err) + } + if cp != nil { + t.Fatal("Read non-existent: expected nil checkpoint") + } + + // Create and write a checkpoint + original := &Checkpoint{ + MoleculeID: "mol-123", + CurrentStep: "step-1", + StepTitle: "Build the thing", + ModifiedFiles: []string{"file1.go", "file2.go"}, + LastCommit: "abc123", + Branch: "feature/test", + HookedBead: "gt-xyz", + Notes: "Some notes", + } + + if err := Write(tmpDir, original); err != nil { + t.Fatalf("Write: unexpected error: %v", err) + } + + // Verify file exists + path := Path(tmpDir) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Fatal("Write: checkpoint file not created") + } + + // Read it back + loaded, err := Read(tmpDir) + if err != nil { + t.Fatalf("Read: unexpected error: %v", err) + } + if loaded == nil { + t.Fatal("Read: expected non-nil checkpoint") + } + + // Verify fields + if loaded.MoleculeID != original.MoleculeID { + t.Errorf("MoleculeID = %q, want %q", loaded.MoleculeID, original.MoleculeID) + } + if loaded.CurrentStep != original.CurrentStep { + t.Errorf("CurrentStep = %q, want %q", loaded.CurrentStep, original.CurrentStep) + } + if loaded.StepTitle != original.StepTitle { + t.Errorf("StepTitle = %q, want %q", loaded.StepTitle, original.StepTitle) + } + if loaded.Branch != original.Branch { + t.Errorf("Branch = %q, want %q", loaded.Branch, original.Branch) + } + if loaded.HookedBead != original.HookedBead { + t.Errorf("HookedBead = %q, want %q", loaded.HookedBead, original.HookedBead) + } + if loaded.Notes != original.Notes { + t.Errorf("Notes = %q, want %q", loaded.Notes, original.Notes) + } + if len(loaded.ModifiedFiles) != len(original.ModifiedFiles) { + t.Errorf("ModifiedFiles len = %d, want %d", len(loaded.ModifiedFiles), len(original.ModifiedFiles)) + } + + // Verify timestamp was set + if loaded.Timestamp.IsZero() { + t.Error("Timestamp should be set by Write") + } + + // Verify SessionID was set + if loaded.SessionID == "" { + t.Error("SessionID should be set by Write") + } +} + +func TestWritePreservesTimestamp(t *testing.T) { + tmpDir := t.TempDir() + + // Create checkpoint with explicit timestamp + ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + cp := &Checkpoint{ + Timestamp: ts, + Notes: "test", + } + + if err := Write(tmpDir, cp); err != nil { + t.Fatalf("Write: %v", err) + } + + loaded, err := Read(tmpDir) + if err != nil { + t.Fatalf("Read: %v", err) + } + + if !loaded.Timestamp.Equal(ts) { + t.Errorf("Timestamp = %v, want %v", loaded.Timestamp, ts) + } +} + +func TestReadCorruptedJSON(t *testing.T) { + tmpDir := t.TempDir() + path := Path(tmpDir) + + // Write invalid JSON + if err := os.WriteFile(path, []byte("not valid json{"), 0600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + _, err := Read(tmpDir) + if err == nil { + t.Fatal("Read corrupted JSON: expected error") + } +} + +func TestRemove(t *testing.T) { + tmpDir := t.TempDir() + + // Write a checkpoint + cp := &Checkpoint{Notes: "to be removed"} + if err := Write(tmpDir, cp); err != nil { + t.Fatalf("Write: %v", err) + } + + // Verify it exists + path := Path(tmpDir) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Fatal("checkpoint should exist before Remove") + } + + // Remove it + if err := Remove(tmpDir); err != nil { + t.Fatalf("Remove: %v", err) + } + + // Verify it's gone + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatal("checkpoint should not exist after Remove") + } + + // Remove again should not error + if err := Remove(tmpDir); err != nil { + t.Fatalf("Remove non-existent: %v", err) + } +} + +func TestCapture(t *testing.T) { + // Use current directory (should be a git repo) + cwd, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + + // Find git root + gitRoot := cwd + for { + if _, err := os.Stat(filepath.Join(gitRoot, ".git")); err == nil { + break + } + parent := filepath.Dir(gitRoot) + if parent == gitRoot { + t.Skip("not in a git repository") + } + gitRoot = parent + } + + cp, err := Capture(gitRoot) + if err != nil { + t.Fatalf("Capture: %v", err) + } + + // Should have timestamp + if cp.Timestamp.IsZero() { + t.Error("Timestamp should be set") + } + + // Should have branch (we're in a git repo) + if cp.Branch == "" { + t.Error("Branch should be set in git repo") + } + + // Should have last commit + if cp.LastCommit == "" { + t.Error("LastCommit should be set in git repo") + } +} + +func TestWithMolecule(t *testing.T) { + cp := &Checkpoint{} + result := cp.WithMolecule("mol-abc", "step-1", "Do the thing") + + if result != cp { + t.Error("WithMolecule should return same checkpoint") + } + if cp.MoleculeID != "mol-abc" { + t.Errorf("MoleculeID = %q, want %q", cp.MoleculeID, "mol-abc") + } + if cp.CurrentStep != "step-1" { + t.Errorf("CurrentStep = %q, want %q", cp.CurrentStep, "step-1") + } + if cp.StepTitle != "Do the thing" { + t.Errorf("StepTitle = %q, want %q", cp.StepTitle, "Do the thing") + } +} + +func TestWithHookedBead(t *testing.T) { + cp := &Checkpoint{} + result := cp.WithHookedBead("gt-123") + + if result != cp { + t.Error("WithHookedBead should return same checkpoint") + } + if cp.HookedBead != "gt-123" { + t.Errorf("HookedBead = %q, want %q", cp.HookedBead, "gt-123") + } +} + +func TestWithNotes(t *testing.T) { + cp := &Checkpoint{} + result := cp.WithNotes("important context") + + if result != cp { + t.Error("WithNotes should return same checkpoint") + } + if cp.Notes != "important context" { + t.Errorf("Notes = %q, want %q", cp.Notes, "important context") + } +} + +func TestAge(t *testing.T) { + cp := &Checkpoint{ + Timestamp: time.Now().Add(-5 * time.Minute), + } + + age := cp.Age() + if age < 4*time.Minute || age > 6*time.Minute { + t.Errorf("Age = %v, expected ~5 minutes", age) + } +} + +func TestIsStale(t *testing.T) { + tests := []struct { + name string + age time.Duration + threshold time.Duration + want bool + }{ + {"fresh", 5 * time.Minute, 1 * time.Hour, false}, + {"stale", 2 * time.Hour, 1 * time.Hour, true}, + {"exactly threshold", 1 * time.Hour, 1 * time.Hour, true}, // timing race: by the time IsStale runs, age > threshold + {"just over threshold", 1*time.Hour + time.Second, 1 * time.Hour, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cp := &Checkpoint{ + Timestamp: time.Now().Add(-tt.age), + } + got := cp.IsStale(tt.threshold) + if got != tt.want { + t.Errorf("IsStale(%v) = %v, want %v", tt.threshold, got, tt.want) + } + }) + } +} + +func TestSummary(t *testing.T) { + tests := []struct { + name string + cp *Checkpoint + want string + }{ + { + name: "empty", + cp: &Checkpoint{}, + want: "no significant state", + }, + { + name: "molecule only", + cp: &Checkpoint{MoleculeID: "mol-123"}, + want: "molecule mol-123", + }, + { + name: "molecule with step", + cp: &Checkpoint{MoleculeID: "mol-123", CurrentStep: "step-1"}, + want: "molecule mol-123, step step-1", + }, + { + name: "hooked bead", + cp: &Checkpoint{HookedBead: "gt-abc"}, + want: "hooked: gt-abc", + }, + { + name: "modified files", + cp: &Checkpoint{ModifiedFiles: []string{"a.go", "b.go"}}, + want: "2 modified files", + }, + { + name: "branch", + cp: &Checkpoint{Branch: "feature/test"}, + want: "branch: feature/test", + }, + { + name: "full", + cp: &Checkpoint{ + MoleculeID: "mol-123", + CurrentStep: "step-1", + HookedBead: "gt-abc", + ModifiedFiles: []string{"a.go"}, + Branch: "main", + }, + want: "molecule mol-123, step step-1, hooked: gt-abc, 1 modified files, branch: main", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.cp.Summary() + if got != tt.want { + t.Errorf("Summary() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestCheckpointJSONRoundtrip(t *testing.T) { + original := &Checkpoint{ + MoleculeID: "mol-test", + CurrentStep: "step-2", + StepTitle: "Testing JSON", + ModifiedFiles: []string{"x.go", "y.go", "z.go"}, + LastCommit: "deadbeef", + Branch: "develop", + HookedBead: "gt-roundtrip", + Timestamp: time.Date(2025, 6, 15, 10, 30, 0, 0, time.UTC), + SessionID: "session-123", + Notes: "Testing round trip", + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + var loaded Checkpoint + if err := json.Unmarshal(data, &loaded); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + + if loaded.MoleculeID != original.MoleculeID { + t.Errorf("MoleculeID mismatch") + } + if loaded.CurrentStep != original.CurrentStep { + t.Errorf("CurrentStep mismatch") + } + if loaded.StepTitle != original.StepTitle { + t.Errorf("StepTitle mismatch") + } + if loaded.Branch != original.Branch { + t.Errorf("Branch mismatch") + } + if loaded.HookedBead != original.HookedBead { + t.Errorf("HookedBead mismatch") + } + if loaded.SessionID != original.SessionID { + t.Errorf("SessionID mismatch") + } + if loaded.Notes != original.Notes { + t.Errorf("Notes mismatch") + } + if !loaded.Timestamp.Equal(original.Timestamp) { + t.Errorf("Timestamp mismatch") + } + if len(loaded.ModifiedFiles) != len(original.ModifiedFiles) { + t.Errorf("ModifiedFiles length mismatch") + } +} diff --git a/internal/connection/address_test.go b/internal/connection/address_test.go index b1b07440..8bab7826 100644 --- a/internal/connection/address_test.go +++ b/internal/connection/address_test.go @@ -190,3 +190,261 @@ func TestAddressEqual(t *testing.T) { }) } } + +func TestParseAddress_EdgeCases(t *testing.T) { + tests := []struct { + name string + input string + want *Address + wantErr bool + }{ + // Malformed: empty/whitespace variations + { + name: "empty string", + input: "", + wantErr: true, + }, + { + name: "whitespace only", + input: " ", + want: &Address{Rig: " "}, + wantErr: false, // whitespace-only rig is technically parsed + }, + { + name: "just slash", + input: "/", + wantErr: true, + }, + { + name: "double slash", + input: "//", + wantErr: true, + }, + { + name: "triple slash", + input: "///", + wantErr: true, + }, + + // Malformed: leading/trailing issues + { + name: "leading slash", + input: "/polecat", + wantErr: true, + }, + { + name: "leading slash with rig", + input: "/rig/polecat", + wantErr: true, + }, + { + name: "trailing slash is broadcast", + input: "rig/", + want: &Address{Rig: "rig"}, + }, + + // Machine prefix edge cases + { + name: "colon only", + input: ":", + wantErr: true, + }, + { + name: "colon with trailing slash", + input: ":/", + wantErr: true, + }, + { + name: "empty machine with colon", + input: ":rig/polecat", + wantErr: true, + }, + { + name: "multiple colons in machine", + input: "host:8080:rig/polecat", + want: &Address{Machine: "host", Rig: "8080:rig", Polecat: "polecat"}, + }, + { + name: "colon in rig name", + input: "machine:rig:port/polecat", + want: &Address{Machine: "machine", Rig: "rig:port", Polecat: "polecat"}, + }, + + // Multiple slash handling (SplitN behavior) + { + name: "extra slashes in polecat", + input: "rig/pole/cat/extra", + want: &Address{Rig: "rig", Polecat: "pole/cat/extra"}, + }, + { + name: "many path components", + input: "a/b/c/d/e", + want: &Address{Rig: "a", Polecat: "b/c/d/e"}, + }, + + // Unicode handling + { + name: "unicode rig name", + input: "日本語/polecat", + want: &Address{Rig: "日本語", Polecat: "polecat"}, + }, + { + name: "unicode polecat name", + input: "rig/工作者", + want: &Address{Rig: "rig", Polecat: "工作者"}, + }, + { + name: "emoji in address", + input: "🔧/🐱", + want: &Address{Rig: "🔧", Polecat: "🐱"}, + }, + { + name: "unicode machine name", + input: "マシン:rig/polecat", + want: &Address{Machine: "マシン", Rig: "rig", Polecat: "polecat"}, + }, + + // Long addresses + { + name: "very long rig name", + input: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789/polecat", + want: &Address{Rig: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789", Polecat: "polecat"}, + }, + { + name: "very long polecat name", + input: "rig/abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789", + want: &Address{Rig: "rig", Polecat: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789"}, + }, + + // Special characters + { + name: "hyphen in names", + input: "my-rig/my-polecat", + want: &Address{Rig: "my-rig", Polecat: "my-polecat"}, + }, + { + name: "underscore in names", + input: "my_rig/my_polecat", + want: &Address{Rig: "my_rig", Polecat: "my_polecat"}, + }, + { + name: "dots in names", + input: "my.rig/my.polecat", + want: &Address{Rig: "my.rig", Polecat: "my.polecat"}, + }, + { + name: "mixed special chars", + input: "rig-1_v2.0/polecat-alpha_1.0", + want: &Address{Rig: "rig-1_v2.0", Polecat: "polecat-alpha_1.0"}, + }, + + // Whitespace in components + { + name: "space in rig name", + input: "my rig/polecat", + want: &Address{Rig: "my rig", Polecat: "polecat"}, + }, + { + name: "space in polecat name", + input: "rig/my polecat", + want: &Address{Rig: "rig", Polecat: "my polecat"}, + }, + { + name: "leading space in rig", + input: " rig/polecat", + want: &Address{Rig: " rig", Polecat: "polecat"}, + }, + { + name: "trailing space in polecat", + input: "rig/polecat ", + want: &Address{Rig: "rig", Polecat: "polecat "}, + }, + + // Edge case: machine with no rig after colon + { + name: "machine colon nothing", + input: "machine:", + wantErr: true, + }, + { + name: "machine colon slash", + input: "machine:/", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseAddress(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("ParseAddress(%q) expected error, got %+v", tt.input, got) + } + return + } + if err != nil { + t.Errorf("ParseAddress(%q) unexpected error: %v", tt.input, err) + return + } + if got.Machine != tt.want.Machine { + t.Errorf("Machine = %q, want %q", got.Machine, tt.want.Machine) + } + if got.Rig != tt.want.Rig { + t.Errorf("Rig = %q, want %q", got.Rig, tt.want.Rig) + } + if got.Polecat != tt.want.Polecat { + t.Errorf("Polecat = %q, want %q", got.Polecat, tt.want.Polecat) + } + }) + } +} + +func TestMustParseAddress_Panics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("MustParseAddress with empty string should panic") + } + }() + MustParseAddress("") +} + +func TestMustParseAddress_Valid(t *testing.T) { + // Should not panic + addr := MustParseAddress("rig/polecat") + if addr.Rig != "rig" || addr.Polecat != "polecat" { + t.Errorf("MustParseAddress returned wrong address: %+v", addr) + } +} + +func TestAddressRigPath(t *testing.T) { + tests := []struct { + addr *Address + want string + }{ + { + addr: &Address{Rig: "gastown", Polecat: "rictus"}, + want: "gastown/rictus", + }, + { + addr: &Address{Rig: "gastown"}, + want: "gastown/", + }, + { + addr: &Address{Machine: "vm", Rig: "gastown", Polecat: "rictus"}, + want: "gastown/rictus", + }, + { + addr: &Address{Rig: "a", Polecat: "b/c/d"}, + want: "a/b/c/d", + }, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := tt.addr.RigPath() + if got != tt.want { + t.Errorf("RigPath() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/keepalive/keepalive_test.go b/internal/keepalive/keepalive_test.go index 5fc42077..3d192412 100644 --- a/internal/keepalive/keepalive_test.go +++ b/internal/keepalive/keepalive_test.go @@ -76,3 +76,63 @@ func TestDirectoryCreation(t *testing.T) { t.Error("expected .runtime directory to be created") } } + +// Example functions demonstrate keepalive usage patterns. + +func ExampleTouchInWorkspace() { + // TouchInWorkspace signals agent activity in a specific workspace. + // This is the core function - use it when you know the workspace root. + + workspaceRoot := "/path/to/workspace" + + // Signal that "gt status" was run + TouchInWorkspace(workspaceRoot, "gt status") + + // Signal a command with arguments + TouchInWorkspace(workspaceRoot, "gt sling bd-abc123 ai-platform") + + // All errors are silently ignored (best-effort design). + // This is intentional - keepalive failures should never break commands. +} + +func ExampleRead() { + // Read retrieves the current keepalive state for a workspace. + // Returns nil if no keepalive file exists or it can't be read. + + workspaceRoot := "/path/to/workspace" + state := Read(workspaceRoot) + + if state == nil { + // No keepalive found - agent may not have run any commands yet + return + } + + // Access the last command that was run + _ = state.LastCommand // e.g., "gt status" + + // Access when the command was run + _ = state.Timestamp // time.Time in UTC +} + +func ExampleState_Age() { + // Age() returns how long ago the keepalive was updated. + // This is useful for detecting idle or stuck agents. + + workspaceRoot := "/path/to/workspace" + state := Read(workspaceRoot) + + // Age() is nil-safe - returns ~1 year for nil state + age := state.Age() + + // Check if agent was active recently (within 5 minutes) + if age < 5*time.Minute { + // Agent is active + _ = "active" + } + + // Check if agent might be stuck (no activity for 30+ minutes) + if age > 30*time.Minute { + // Agent may need attention + _ = "possibly stuck" + } +} diff --git a/internal/lock/lock_test.go b/internal/lock/lock_test.go new file mode 100644 index 00000000..006281a4 --- /dev/null +++ b/internal/lock/lock_test.go @@ -0,0 +1,665 @@ +package lock + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" +) + +func TestNew(t *testing.T) { + workerDir := "/tmp/test-worker" + l := New(workerDir) + + if l.workerDir != workerDir { + t.Errorf("workerDir = %q, want %q", l.workerDir, workerDir) + } + + expectedPath := filepath.Join(workerDir, ".runtime", "agent.lock") + if l.lockPath != expectedPath { + t.Errorf("lockPath = %q, want %q", l.lockPath, expectedPath) + } +} + +func TestLockInfo_IsStale(t *testing.T) { + tests := []struct { + name string + pid int + wantStale bool + }{ + {"current process", os.Getpid(), false}, + {"invalid pid zero", 0, true}, + {"invalid pid negative", -1, true}, + {"non-existent pid", 999999999, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := &LockInfo{PID: tt.pid} + if got := info.IsStale(); got != tt.wantStale { + t.Errorf("IsStale() = %v, want %v", got, tt.wantStale) + } + }) + } +} + +func TestLock_AcquireAndRelease(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + if err := os.MkdirAll(workerDir, 0755); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Acquire lock + err := l.Acquire("test-session") + if err != nil { + t.Fatalf("Acquire() error = %v", err) + } + + // Verify lock file exists + info, err := l.Read() + if err != nil { + t.Fatalf("Read() error = %v", err) + } + if info.PID != os.Getpid() { + t.Errorf("PID = %d, want %d", info.PID, os.Getpid()) + } + if info.SessionID != "test-session" { + t.Errorf("SessionID = %q, want %q", info.SessionID, "test-session") + } + + // Release lock + err = l.Release() + if err != nil { + t.Fatalf("Release() error = %v", err) + } + + // Verify lock file is gone + _, err = l.Read() + if err != ErrNotLocked { + t.Errorf("Read() after release: error = %v, want ErrNotLocked", err) + } +} + +func TestLock_AcquireAlreadyHeld(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + if err := os.MkdirAll(workerDir, 0755); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Acquire lock first time + if err := l.Acquire("session-1"); err != nil { + t.Fatalf("First Acquire() error = %v", err) + } + + // Re-acquire with different session should refresh + if err := l.Acquire("session-2"); err != nil { + t.Fatalf("Second Acquire() error = %v", err) + } + + // Verify session was updated + info, err := l.Read() + if err != nil { + t.Fatalf("Read() error = %v", err) + } + if info.SessionID != "session-2" { + t.Errorf("SessionID = %q, want %q", info.SessionID, "session-2") + } + + l.Release() +} + +func TestLock_AcquireStaleLock(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + runtimeDir := filepath.Join(workerDir, ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + + // Create a stale lock file with non-existent PID + staleLock := LockInfo{ + PID: 999999999, // Non-existent PID + AcquiredAt: time.Now().Add(-time.Hour), + SessionID: "dead-session", + } + data, _ := json.Marshal(staleLock) + lockPath := filepath.Join(runtimeDir, "agent.lock") + if err := os.WriteFile(lockPath, data, 0644); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Should acquire by cleaning up stale lock + if err := l.Acquire("new-session"); err != nil { + t.Fatalf("Acquire() with stale lock error = %v", err) + } + + // Verify we now own it + info, err := l.Read() + if err != nil { + t.Fatalf("Read() error = %v", err) + } + if info.PID != os.Getpid() { + t.Errorf("PID = %d, want %d", info.PID, os.Getpid()) + } + if info.SessionID != "new-session" { + t.Errorf("SessionID = %q, want %q", info.SessionID, "new-session") + } + + l.Release() +} + +func TestLock_Read(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + runtimeDir := filepath.Join(workerDir, ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Test reading non-existent lock + _, err := l.Read() + if err != ErrNotLocked { + t.Errorf("Read() non-existent: error = %v, want ErrNotLocked", err) + } + + // Test reading invalid JSON + lockPath := filepath.Join(runtimeDir, "agent.lock") + if err := os.WriteFile(lockPath, []byte("invalid json"), 0644); err != nil { + t.Fatal(err) + } + _, err = l.Read() + if err == nil { + t.Error("Read() invalid JSON: expected error, got nil") + } + + // Test reading valid lock + validLock := LockInfo{ + PID: 12345, + AcquiredAt: time.Now(), + SessionID: "test", + Hostname: "testhost", + } + data, _ := json.Marshal(validLock) + if err := os.WriteFile(lockPath, data, 0644); err != nil { + t.Fatal(err) + } + info, err := l.Read() + if err != nil { + t.Fatalf("Read() valid lock: error = %v", err) + } + if info.PID != 12345 { + t.Errorf("PID = %d, want 12345", info.PID) + } + if info.SessionID != "test" { + t.Errorf("SessionID = %q, want %q", info.SessionID, "test") + } +} + +func TestLock_Check(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + runtimeDir := filepath.Join(workerDir, ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Check when unlocked + if err := l.Check(); err != nil { + t.Errorf("Check() unlocked: error = %v, want nil", err) + } + + // Acquire and check (should pass - we hold it) + if err := l.Acquire("test"); err != nil { + t.Fatal(err) + } + if err := l.Check(); err != nil { + t.Errorf("Check() owned by us: error = %v, want nil", err) + } + l.Release() + + // Create lock owned by another process - we'll simulate this by using a + // fake "live" process via the stale lock detection mechanism. + // Since we can't reliably find another live PID we can signal on all platforms, + // we test that Check() correctly identifies our own PID vs a different PID. + // The stale lock cleanup path is tested elsewhere. + + // Test that a non-existent PID lock gets cleaned up and returns nil + staleLock := LockInfo{ + PID: 999999999, // Non-existent PID + AcquiredAt: time.Now(), + SessionID: "other-session", + } + data, _ := json.Marshal(staleLock) + lockPath := filepath.Join(runtimeDir, "agent.lock") + if err := os.WriteFile(lockPath, data, 0644); err != nil { + t.Fatal(err) + } + + // Check should clean up the stale lock and return nil + err := l.Check() + if err != nil { + t.Errorf("Check() with stale lock: error = %v, want nil (should clean up)", err) + } + + // Verify lock was cleaned up + if _, statErr := os.Stat(lockPath); !os.IsNotExist(statErr) { + t.Error("Check() should have removed stale lock file") + } +} + +func TestLock_Status(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + runtimeDir := filepath.Join(workerDir, ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Unlocked status + status := l.Status() + if status != "unlocked" { + t.Errorf("Status() unlocked = %q, want %q", status, "unlocked") + } + + // Owned by us + if err := l.Acquire("test"); err != nil { + t.Fatal(err) + } + status = l.Status() + if status != "locked (by us)" { + t.Errorf("Status() owned = %q, want %q", status, "locked (by us)") + } + l.Release() + + // Stale lock + staleLock := LockInfo{ + PID: 999999999, + AcquiredAt: time.Now(), + SessionID: "dead", + } + data, _ := json.Marshal(staleLock) + lockPath := filepath.Join(runtimeDir, "agent.lock") + if err := os.WriteFile(lockPath, data, 0644); err != nil { + t.Fatal(err) + } + status = l.Status() + expected := "stale (dead PID 999999999)" + if status != expected { + t.Errorf("Status() stale = %q, want %q", status, expected) + } + + os.Remove(lockPath) +} + +func TestLock_ForceRelease(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + if err := os.MkdirAll(workerDir, 0755); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + if err := l.Acquire("test"); err != nil { + t.Fatal(err) + } + + if err := l.ForceRelease(); err != nil { + t.Errorf("ForceRelease() error = %v", err) + } + + _, err := l.Read() + if err != ErrNotLocked { + t.Errorf("Read() after ForceRelease: error = %v, want ErrNotLocked", err) + } +} + +func TestProcessExists(t *testing.T) { + // Current process exists + if !processExists(os.Getpid()) { + t.Error("processExists(current PID) = false, want true") + } + + // Note: PID 1 (init/launchd) cannot be signaled without permission on macOS, + // so we only test our own process and invalid PIDs. + + // Invalid PIDs + if processExists(0) { + t.Error("processExists(0) = true, want false") + } + if processExists(-1) { + t.Error("processExists(-1) = true, want false") + } + if processExists(999999999) { + t.Error("processExists(999999999) = true, want false") + } +} + +func TestFindAllLocks(t *testing.T) { + tmpDir := t.TempDir() + + // Create multiple worker directories with locks + workers := []string{"worker1", "worker2", "worker3"} + for i, w := range workers { + runtimeDir := filepath.Join(tmpDir, w, ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + info := LockInfo{ + PID: i + 100, + AcquiredAt: time.Now(), + SessionID: "session-" + w, + } + data, _ := json.Marshal(info) + lockPath := filepath.Join(runtimeDir, "agent.lock") + if err := os.WriteFile(lockPath, data, 0644); err != nil { + t.Fatal(err) + } + } + + locks, err := FindAllLocks(tmpDir) + if err != nil { + t.Fatalf("FindAllLocks() error = %v", err) + } + + if len(locks) != 3 { + t.Errorf("FindAllLocks() found %d locks, want 3", len(locks)) + } + + for _, w := range workers { + workerDir := filepath.Join(tmpDir, w) + if _, ok := locks[workerDir]; !ok { + t.Errorf("FindAllLocks() missing lock for %s", w) + } + } +} + +func TestCleanStaleLocks(t *testing.T) { + // Save and restore execCommand + origExecCommand := execCommand + defer func() { execCommand = origExecCommand }() + + // Mock tmux to return no active sessions + execCommand = func(name string, args ...string) interface{ Output() ([]byte, error) } { + return &mockCmd{output: []byte("")} + } + + tmpDir := t.TempDir() + + // Create a stale lock + runtimeDir := filepath.Join(tmpDir, "stale-worker", ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + staleLock := LockInfo{ + PID: 999999999, + AcquiredAt: time.Now(), + SessionID: "dead-session", + } + data, _ := json.Marshal(staleLock) + if err := os.WriteFile(filepath.Join(runtimeDir, "agent.lock"), data, 0644); err != nil { + t.Fatal(err) + } + + // Create a live lock (current process) + liveDir := filepath.Join(tmpDir, "live-worker", ".runtime") + if err := os.MkdirAll(liveDir, 0755); err != nil { + t.Fatal(err) + } + liveLock := LockInfo{ + PID: os.Getpid(), + AcquiredAt: time.Now(), + SessionID: "live-session", + } + data, _ = json.Marshal(liveLock) + if err := os.WriteFile(filepath.Join(liveDir, "agent.lock"), data, 0644); err != nil { + t.Fatal(err) + } + + cleaned, err := CleanStaleLocks(tmpDir) + if err != nil { + t.Fatalf("CleanStaleLocks() error = %v", err) + } + + if cleaned != 1 { + t.Errorf("CleanStaleLocks() cleaned %d, want 1", cleaned) + } + + // Verify stale lock is gone + staleLockPath := filepath.Join(runtimeDir, "agent.lock") + if _, err := os.Stat(staleLockPath); !os.IsNotExist(err) { + t.Error("Stale lock file should be removed") + } + + // Verify live lock still exists + liveLockPath := filepath.Join(liveDir, "agent.lock") + if _, err := os.Stat(liveLockPath); err != nil { + t.Error("Live lock file should still exist") + } +} + +type mockCmd struct { + output []byte + err error +} + +func (m *mockCmd) Output() ([]byte, error) { + return m.output, m.err +} + +func TestGetActiveTmuxSessions(t *testing.T) { + // Save and restore execCommand + origExecCommand := execCommand + defer func() { execCommand = origExecCommand }() + + // Mock tmux output + execCommand = func(name string, args ...string) interface{ Output() ([]byte, error) } { + return &mockCmd{output: []byte("session1:$1\nsession2:$2\n")} + } + + sessions := getActiveTmuxSessions() + + // Should contain session names and IDs + expected := map[string]bool{ + "session1": true, + "session2": true, + "$1": true, + "$2": true, + "%1": true, + "%2": true, + } + + for _, s := range sessions { + if !expected[s] { + t.Errorf("Unexpected session: %s", s) + } + } +} + +func TestSplitOnColon(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + {"a:b", []string{"a", "b"}}, + {"abc", []string{"abc"}}, + {"a:b:c", []string{"a", "b:c"}}, + {":b", []string{"", "b"}}, + {"a:", []string{"a", ""}}, + } + + for _, tt := range tests { + result := splitOnColon(tt.input) + if len(result) != len(tt.expected) { + t.Errorf("splitOnColon(%q) = %v, want %v", tt.input, result, tt.expected) + continue + } + for i := range result { + if result[i] != tt.expected[i] { + t.Errorf("splitOnColon(%q)[%d] = %q, want %q", tt.input, i, result[i], tt.expected[i]) + } + } + } +} + +func TestSplitLines(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + {"a\nb\nc", []string{"a", "b", "c"}}, + {"a\r\nb\r\nc", []string{"a", "b", "c"}}, + {"single", []string{"single"}}, + {"", []string{}}, + {"a\n", []string{"a"}}, + {"a\nb", []string{"a", "b"}}, + } + + for _, tt := range tests { + result := splitLines(tt.input) + if len(result) != len(tt.expected) { + t.Errorf("splitLines(%q) = %v, want %v", tt.input, result, tt.expected) + continue + } + for i := range result { + if result[i] != tt.expected[i] { + t.Errorf("splitLines(%q)[%d] = %q, want %q", tt.input, i, result[i], tt.expected[i]) + } + } + } +} + +func TestDetectCollisions(t *testing.T) { + tmpDir := t.TempDir() + + // Create a stale lock + runtimeDir := filepath.Join(tmpDir, "stale-worker", ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + staleLock := LockInfo{ + PID: 999999999, + AcquiredAt: time.Now(), + SessionID: "dead-session", + } + data, _ := json.Marshal(staleLock) + if err := os.WriteFile(filepath.Join(runtimeDir, "agent.lock"), data, 0644); err != nil { + t.Fatal(err) + } + + // Create an orphaned lock (live PID but session not in active list) + orphanDir := filepath.Join(tmpDir, "orphan-worker", ".runtime") + if err := os.MkdirAll(orphanDir, 0755); err != nil { + t.Fatal(err) + } + orphanLock := LockInfo{ + PID: os.Getpid(), // Live PID + AcquiredAt: time.Now(), + SessionID: "orphan-session", // Not in active list + } + data, _ = json.Marshal(orphanLock) + if err := os.WriteFile(filepath.Join(orphanDir, "agent.lock"), data, 0644); err != nil { + t.Fatal(err) + } + + activeSessions := []string{"active-session-1", "active-session-2"} + collisions := DetectCollisions(tmpDir, activeSessions) + + if len(collisions) != 2 { + t.Errorf("DetectCollisions() found %d collisions, want 2: %v", len(collisions), collisions) + } + + // Verify we found both issues + foundStale := false + foundOrphan := false + for _, c := range collisions { + if contains(c, "stale lock") { + foundStale = true + } + if contains(c, "orphaned lock") { + foundOrphan = true + } + } + + if !foundStale { + t.Error("DetectCollisions() did not find stale lock") + } + if !foundOrphan { + t.Error("DetectCollisions() did not find orphaned lock") + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func TestLock_ReleaseNonExistent(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + if err := os.MkdirAll(workerDir, 0755); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Releasing a non-existent lock should not error + if err := l.Release(); err != nil { + t.Errorf("Release() non-existent: error = %v, want nil", err) + } +} + +func TestLock_CheckCleansUpStaleLock(t *testing.T) { + tmpDir := t.TempDir() + workerDir := filepath.Join(tmpDir, "worker") + runtimeDir := filepath.Join(workerDir, ".runtime") + if err := os.MkdirAll(runtimeDir, 0755); err != nil { + t.Fatal(err) + } + + // Create a stale lock + staleLock := LockInfo{ + PID: 999999999, + AcquiredAt: time.Now(), + SessionID: "dead", + } + data, _ := json.Marshal(staleLock) + lockPath := filepath.Join(runtimeDir, "agent.lock") + if err := os.WriteFile(lockPath, data, 0644); err != nil { + t.Fatal(err) + } + + l := New(workerDir) + + // Check should clean up stale lock and return nil + if err := l.Check(); err != nil { + t.Errorf("Check() with stale lock: error = %v, want nil", err) + } + + // Lock file should be removed + if _, err := os.Stat(lockPath); !os.IsNotExist(err) { + t.Error("Check() should have removed stale lock file") + } +} diff --git a/internal/util/atomic_test.go b/internal/util/atomic_test.go index 91fbe33a..a6f82929 100644 --- a/internal/util/atomic_test.go +++ b/internal/util/atomic_test.go @@ -1,8 +1,10 @@ package util import ( + "encoding/json" "os" "path/filepath" + "sync" "testing" ) @@ -86,3 +88,272 @@ func TestAtomicWriteOverwrite(t *testing.T) { t.Fatalf("Unexpected content: %s", content) } } + +func TestAtomicWriteFilePermissions(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + + // Test with specific permissions + data := []byte("test data") + if err := AtomicWriteFile(testFile, data, 0600); err != nil { + t.Fatalf("AtomicWriteFile error: %v", err) + } + + // Verify permissions (on Unix systems) + info, err := os.Stat(testFile) + if err != nil { + t.Fatalf("Stat error: %v", err) + } + // Check that owner read/write bits are set + perm := info.Mode().Perm() + if perm&0600 != 0600 { + t.Errorf("Expected owner read/write permissions, got %o", perm) + } +} + +func TestAtomicWriteFileEmpty(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "empty.txt") + + // Test writing empty data + if err := AtomicWriteFile(testFile, []byte{}, 0644); err != nil { + t.Fatalf("AtomicWriteFile error: %v", err) + } + + // Verify file exists and is empty + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + if len(content) != 0 { + t.Fatalf("Expected empty file, got %d bytes", len(content)) + } +} + +func TestAtomicWriteJSONTypes(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + data interface{} + expected string + }{ + {"string", "hello", `"hello"`}, + {"int", 42, "42"}, + {"float", 3.14, "3.14"}, + {"bool", true, "true"}, + {"null", nil, "null"}, + {"array", []int{1, 2, 3}, "[\n 1,\n 2,\n 3\n]"}, + {"nested", map[string]interface{}{"a": map[string]int{"b": 1}}, "{\n \"a\": {\n \"b\": 1\n }\n}"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testFile := filepath.Join(tmpDir, tc.name+".json") + if err := AtomicWriteJSON(testFile, tc.data); err != nil { + t.Fatalf("AtomicWriteJSON error: %v", err) + } + + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + if string(content) != tc.expected { + t.Errorf("Expected %q, got %q", tc.expected, string(content)) + } + }) + } +} + +func TestAtomicWriteJSONUnmarshallable(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "unmarshallable.json") + + // Channels cannot be marshalled to JSON + ch := make(chan int) + err := AtomicWriteJSON(testFile, ch) + if err == nil { + t.Fatal("Expected error for unmarshallable type") + } + + // Verify file was not created + if _, statErr := os.Stat(testFile); !os.IsNotExist(statErr) { + t.Fatal("File should not exist after marshal error") + } + + // Verify temp file was not left behind + tmpFile := testFile + ".tmp" + if _, statErr := os.Stat(tmpFile); !os.IsNotExist(statErr) { + t.Fatal("Temp file should not exist after marshal error") + } +} + +func TestAtomicWriteFileReadOnlyDir(t *testing.T) { + tmpDir := t.TempDir() + roDir := filepath.Join(tmpDir, "readonly") + + // Create read-only directory + if err := os.Mkdir(roDir, 0555); err != nil { + t.Fatalf("Failed to create readonly dir: %v", err) + } + defer os.Chmod(roDir, 0755) // Restore permissions for cleanup + + testFile := filepath.Join(roDir, "test.txt") + err := AtomicWriteFile(testFile, []byte("test"), 0644) + if err == nil { + t.Fatal("Expected permission error") + } + + // Verify no files were created + if _, statErr := os.Stat(testFile); !os.IsNotExist(statErr) { + t.Fatal("File should not exist after permission error") + } +} + +func TestAtomicWriteFileConcurrent(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "concurrent.txt") + + // Write initial content + if err := AtomicWriteFile(testFile, []byte("initial"), 0644); err != nil { + t.Fatalf("Initial write error: %v", err) + } + + // Concurrent writes + const numWriters = 10 + var wg sync.WaitGroup + wg.Add(numWriters) + + for i := 0; i < numWriters; i++ { + go func(n int) { + defer wg.Done() + data := []byte(string(rune('A' + n))) + // Errors are possible due to race, but file should remain valid + _ = AtomicWriteFile(testFile, data, 0644) + }(i) + } + + wg.Wait() + + // Verify file is readable and contains valid content (one of the writes won) + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + if len(content) != 1 { + t.Errorf("Expected single character, got %q", content) + } + + // Verify no temp files left behind + entries, err := os.ReadDir(tmpDir) + if err != nil { + t.Fatalf("ReadDir error: %v", err) + } + for _, e := range entries { + if filepath.Ext(e.Name()) == ".tmp" { + t.Errorf("Temp file left behind: %s", e.Name()) + } + } +} + +func TestAtomicWritePreservesOnFailure(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "preserve.txt") + + // Write initial content + initialContent := []byte("original content") + if err := AtomicWriteFile(testFile, initialContent, 0644); err != nil { + t.Fatalf("Initial write error: %v", err) + } + + // Create a subdirectory with the .tmp name to cause rename to fail + tmpFile := testFile + ".tmp" + if err := os.Mkdir(tmpFile, 0755); err != nil { + t.Fatalf("Failed to create blocking dir: %v", err) + } + + // Attempt write which should fail at rename + err := AtomicWriteFile(testFile, []byte("new content"), 0644) + if err == nil { + t.Fatal("Expected error when .tmp is a directory") + } + + // Verify original content is preserved + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + if string(content) != string(initialContent) { + t.Errorf("Original content not preserved: got %q", content) + } +} + +func TestAtomicWriteJSONStruct(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "struct.json") + + type TestStruct struct { + Name string `json:"name"` + Count int `json:"count"` + Enabled bool `json:"enabled"` + Tags []string `json:"tags"` + } + + data := TestStruct{ + Name: "test", + Count: 42, + Enabled: true, + Tags: []string{"a", "b"}, + } + + if err := AtomicWriteJSON(testFile, data); err != nil { + t.Fatalf("AtomicWriteJSON error: %v", err) + } + + // Read back and verify + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + + var result TestStruct + if err := json.Unmarshal(content, &result); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if result.Name != data.Name || result.Count != data.Count || + result.Enabled != data.Enabled || len(result.Tags) != len(data.Tags) { + t.Errorf("Data mismatch: got %+v, want %+v", result, data) + } +} + +func TestAtomicWriteFileLargeData(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "large.bin") + + // Create 1MB of data + size := 1024 * 1024 + data := make([]byte, size) + for i := range data { + data[i] = byte(i % 256) + } + + if err := AtomicWriteFile(testFile, data, 0644); err != nil { + t.Fatalf("AtomicWriteFile error: %v", err) + } + + // Verify content + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + if len(content) != size { + t.Errorf("Size mismatch: got %d, want %d", len(content), size) + } + for i := 0; i < size; i++ { + if content[i] != byte(i%256) { + t.Errorf("Content mismatch at byte %d", i) + break + } + } +}