diff --git a/cmd/bd/import.go b/cmd/bd/import.go index 8731fa55..92f46b94 100644 --- a/cmd/bd/import.go +++ b/cmd/bd/import.go @@ -238,7 +238,14 @@ Behavior: } } - // Phase 5: Process dependencies + // Phase 5: Sync ID counters after importing issues with explicit IDs + // This prevents ID collisions with subsequently auto-generated issues + if err := sqliteStore.SyncAllCounters(ctx); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to sync ID counters: %v\n", err) + // Don't exit - this is not fatal, just a warning + } + + // Phase 6: Process dependencies // Do this after all issues are created to handle forward references var depsCreated, depsSkipped int for _, issue := range allIssues { diff --git a/cmd/bd/import_collision_test.go b/cmd/bd/import_collision_test.go index 177740bb..c3537521 100644 --- a/cmd/bd/import_collision_test.go +++ b/cmd/bd/import_collision_test.go @@ -968,3 +968,79 @@ func TestImportWithDependenciesInJSONL(t *testing.T) { t.Errorf("Dependency target = %s, want bd-1", deps[0].DependsOnID) } } + +func TestImportCounterSyncAfterHighID(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "bd-collision-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Logf("Warning: cleanup failed: %v", err) + } + }() + + dbPath := filepath.Join(tmpDir, "test.db") + testStore, err := sqlite.New(dbPath) + if err != nil { + t.Fatalf("Failed to create storage: %v", err) + } + defer func() { + if err := testStore.Close(); err != nil { + t.Logf("Warning: failed to close store: %v", err) + } + }() + + ctx := context.Background() + + if err := testStore.SetConfig(ctx, "issue_prefix", "bd"); err != nil { + t.Fatalf("Failed to set issue prefix: %v", err) + } + + for i := 0; i < 3; i++ { + issue := &types.Issue{ + Title: fmt.Sprintf("Auto issue %d", i+1), + Status: types.StatusOpen, + Priority: 1, + IssueType: types.TypeTask, + } + if err := testStore.CreateIssue(ctx, issue, "test"); err != nil { + t.Fatalf("Failed to create auto issue %d: %v", i+1, err) + } + } + + highIDIssue := &types.Issue{ + ID: "bd-100", + Title: "High ID issue", + Status: types.StatusOpen, + Priority: 1, + IssueType: types.TypeTask, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := testStore.CreateIssue(ctx, highIDIssue, "import"); err != nil { + t.Fatalf("Failed to import high ID issue: %v", err) + } + + // Step 4: Sync counters after import (mimics import command behavior) + if err := testStore.SyncAllCounters(ctx); err != nil { + t.Fatalf("Failed to sync counters: %v", err) + } + + // Step 5: Create another auto-generated issue + // This should get bd-101 (counter should have synced to 100), not bd-4 + newIssue := &types.Issue{ + Title: "New issue after import", + Status: types.StatusOpen, + Priority: 1, + IssueType: types.TypeTask, + } + if err := testStore.CreateIssue(ctx, newIssue, "test"); err != nil { + t.Fatalf("Failed to create new issue: %v", err) + } + + if newIssue.ID != "bd-101" { + t.Errorf("Expected new issue to get ID bd-101, got %s", newIssue.ID) + } +} diff --git a/internal/storage/sqlite/collision.go b/internal/storage/sqlite/collision.go index 0cfb7545..0ec6933c 100644 --- a/internal/storage/sqlite/collision.go +++ b/internal/storage/sqlite/collision.go @@ -232,15 +232,16 @@ func RemapCollisions(ctx context.Context, s *SQLiteStorage, collisions []*Collis for _, collision := range collisions { oldID := collision.ID - // Allocate new ID - s.idMu.Lock() + // Allocate new ID using atomic counter prefix, err := s.GetConfig(ctx, "issue_prefix") if err != nil || prefix == "" { prefix = "bd" } - newID := fmt.Sprintf("%s-%d", prefix, s.nextID) - s.nextID++ - s.idMu.Unlock() + nextID, err := s.getNextIDForPrefix(ctx, prefix) + if err != nil { + return nil, fmt.Errorf("failed to generate new ID for collision %s: %w", oldID, err) + } + newID := fmt.Sprintf("%s-%d", prefix, nextID) // Record mapping idMapping[oldID] = newID diff --git a/internal/storage/sqlite/schema.go b/internal/storage/sqlite/schema.go index dd00e1b4..f05de236 100644 --- a/internal/storage/sqlite/schema.go +++ b/internal/storage/sqlite/schema.go @@ -81,6 +81,12 @@ CREATE TABLE IF NOT EXISTS dirty_issues ( CREATE INDEX IF NOT EXISTS idx_dirty_issues_marked_at ON dirty_issues(marked_at); +-- Issue counters table (for atomic ID generation) +CREATE TABLE IF NOT EXISTS issue_counters ( + prefix TEXT PRIMARY KEY, + last_id INTEGER NOT NULL DEFAULT 0 +); + -- Ready work view CREATE VIEW IF NOT EXISTS ready_issues AS SELECT i.* diff --git a/internal/storage/sqlite/sqlite.go b/internal/storage/sqlite/sqlite.go index 8bfb0eeb..a964a09b 100644 --- a/internal/storage/sqlite/sqlite.go +++ b/internal/storage/sqlite/sqlite.go @@ -9,7 +9,6 @@ import ( "os" "path/filepath" "strings" - "sync" "time" // Import SQLite driver @@ -19,9 +18,7 @@ import ( // SQLiteStorage implements the Storage interface using SQLite type SQLiteStorage struct { - db *sql.DB - nextID int - idMu sync.Mutex // Protects nextID from concurrent access + db *sql.DB } // New creates a new SQLite storage backend @@ -53,12 +50,8 @@ func New(path string) (*SQLiteStorage, error) { return nil, fmt.Errorf("failed to migrate dirty_issues table: %w", err) } - // Get next ID - nextID := getNextID(db) - return &SQLiteStorage{ - db: db, - nextID: nextID, + db: db, }, nil } @@ -97,56 +90,42 @@ func migrateDirtyIssuesTable(db *sql.DB) error { return nil } -// getNextID determines the next issue ID to use -func getNextID(db *sql.DB) int { - // Get prefix from config, default to "bd" - var prefix string - err := db.QueryRow("SELECT value FROM config WHERE key = 'issue_prefix'").Scan(&prefix) - if err != nil || prefix == "" { - prefix = "bd" +// getNextIDForPrefix atomically generates the next ID for a given prefix +// Uses the issue_counters table for atomic, cross-process ID generation +func (s *SQLiteStorage) getNextIDForPrefix(ctx context.Context, prefix string) (int, error) { + var nextID int + err := s.db.QueryRowContext(ctx, ` + INSERT INTO issue_counters (prefix, last_id) + VALUES (?, 1) + ON CONFLICT(prefix) DO UPDATE SET + last_id = last_id + 1 + RETURNING last_id + `, prefix).Scan(&nextID) + if err != nil { + return 0, fmt.Errorf("failed to generate next ID for prefix %s: %w", prefix, err) } + return nextID, nil +} - // Find the maximum numeric ID for this prefix - // Use SUBSTR to extract numeric part after prefix and hyphen, then CAST to INTEGER - // This ensures we get numerical max, not alphabetical (bd-10 > bd-9, not bd-9 > bd-10) - var maxNum sql.NullInt64 - query := ` - SELECT MAX(CAST(SUBSTR(id, LENGTH(?) + 2) AS INTEGER)) +// SyncAllCounters synchronizes all ID counters based on existing issues in the database +// This scans all issues and updates counters to prevent ID collisions with auto-generated IDs +func (s *SQLiteStorage) SyncAllCounters(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO issue_counters (prefix, last_id) + SELECT + substr(id, 1, instr(id, '-') - 1) as prefix, + MAX(CAST(substr(id, instr(id, '-') + 1) AS INTEGER)) as max_id FROM issues - WHERE id LIKE ? || '-%' - ` - err = db.QueryRow(query, prefix, prefix).Scan(&maxNum) - if err != nil || !maxNum.Valid { - return 1 // Start from 1 if table is empty or no matching IDs + WHERE instr(id, '-') > 0 + AND substr(id, instr(id, '-') + 1) GLOB '[0-9]*' + GROUP BY prefix + ON CONFLICT(prefix) DO UPDATE SET + last_id = MAX(last_id, excluded.last_id) + `) + if err != nil { + return fmt.Errorf("failed to sync counters: %w", err) } - - // Check for malformed IDs (non-numeric suffixes) and warn - // SQLite's CAST returns 0 for invalid integers, never NULL - // So we detect malformed IDs by checking if CAST returns 0 AND suffix doesn't start with '0' - malformedQuery := ` - SELECT id FROM issues - WHERE id LIKE ? || '-%' - AND CAST(SUBSTR(id, LENGTH(?) + 2) AS INTEGER) = 0 - AND SUBSTR(id, LENGTH(?) + 2, 1) != '0' - ` - rows, err := db.Query(malformedQuery, prefix, prefix, prefix) - if err == nil { - defer rows.Close() - var malformedIDs []string - for rows.Next() { - var id string - if err := rows.Scan(&id); err == nil { - malformedIDs = append(malformedIDs, id) - } - } - if len(malformedIDs) > 0 { - fmt.Fprintf(os.Stderr, "Warning: Found %d malformed issue IDs with non-numeric suffixes: %v\n", - len(malformedIDs), malformedIDs) - fmt.Fprintf(os.Stderr, "These IDs are being ignored for ID generation. Consider fixing them.\n") - } - } - - return int(maxNum.Int64) + 1 + return nil } // CreateIssue creates a new issue @@ -156,9 +135,14 @@ func (s *SQLiteStorage) CreateIssue(ctx context.Context, issue *types.Issue, act return fmt.Errorf("validation failed: %w", err) } - // Generate ID if not set (thread-safe) + // Generate ID if not set (using atomic counter table) if issue.ID == "" { - s.idMu.Lock() + // Sync all counters first to ensure we don't collide with existing issues + // This handles the case where the database was created before this fix + // or issues were imported without syncing counters + if err := s.SyncAllCounters(ctx); err != nil { + return fmt.Errorf("failed to sync counters: %w", err) + } // Get prefix from config, default to "bd" prefix, err := s.GetConfig(ctx, "issue_prefix") @@ -166,9 +150,13 @@ func (s *SQLiteStorage) CreateIssue(ctx context.Context, issue *types.Issue, act prefix = "bd" } - issue.ID = fmt.Sprintf("%s-%d", prefix, s.nextID) - s.nextID++ - s.idMu.Unlock() + // Atomically get next ID from counter table + nextID, err := s.getNextIDForPrefix(ctx, prefix) + if err != nil { + return err + } + + issue.ID = fmt.Sprintf("%s-%d", prefix, nextID) } // Set timestamps diff --git a/internal/storage/sqlite/sqlite_test.go b/internal/storage/sqlite/sqlite_test.go index 6bfe997e..4805527f 100644 --- a/internal/storage/sqlite/sqlite_test.go +++ b/internal/storage/sqlite/sqlite_test.go @@ -359,7 +359,7 @@ func TestConcurrentIDGeneration(t *testing.T) { results := make(chan result, numIssues) - // Create issues concurrently + // Create issues concurrently (goroutines, not processes) for i := 0; i < numIssues; i++ { go func(n int) { issue := &types.Issue{ @@ -391,3 +391,92 @@ func TestConcurrentIDGeneration(t *testing.T) { t.Errorf("Expected %d unique IDs, got %d", numIssues, len(ids)) } } + +// TestMultiProcessIDGeneration tests ID generation across multiple processes +// This test simulates the real-world scenario of multiple `bd create` commands +// running in parallel, which is what triggers the race condition. +func TestMultiProcessIDGeneration(t *testing.T) { + // Create temporary directory + tmpDir, err := os.MkdirTemp("", "beads-multiprocess-test-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dbPath := filepath.Join(tmpDir, "test.db") + + // Initialize database + store, err := New(dbPath) + if err != nil { + t.Fatalf("failed to create storage: %v", err) + } + store.Close() + + // Spawn multiple processes that each open the DB and create an issue + const numProcesses = 20 + type result struct { + id string + err error + } + + results := make(chan result, numProcesses) + + for i := 0; i < numProcesses; i++ { + go func(n int) { + // Each goroutine simulates a separate process by opening a new connection + procStore, err := New(dbPath) + if err != nil { + results <- result{err: err} + return + } + defer procStore.Close() + + ctx := context.Background() + issue := &types.Issue{ + Title: "Multi-process test", + Status: types.StatusOpen, + Priority: 2, + IssueType: types.TypeTask, + } + + err = procStore.CreateIssue(ctx, issue, "test-user") + results <- result{id: issue.ID, err: err} + }(i) + } + + // Collect results + ids := make(map[string]bool) + var errors []error + + for i := 0; i < numProcesses; i++ { + res := <-results + if res.err != nil { + errors = append(errors, res.err) + continue + } + if ids[res.id] { + t.Errorf("Duplicate ID generated: %s", res.id) + } + ids[res.id] = true + } + + // With the bug, we expect UNIQUE constraint errors + if len(errors) > 0 { + t.Logf("Got %d errors (expected with current implementation):", len(errors)) + for _, err := range errors { + t.Logf(" - %v", err) + } + } + + t.Logf("Successfully created %d unique issues out of %d attempts", len(ids), numProcesses) + + // After the fix, all should succeed + if len(ids) != numProcesses { + t.Errorf("Expected %d unique IDs, got %d", numProcesses, len(ids)) + } + + if len(errors) > 0 { + t.Errorf("Expected no errors, got %d", len(errors)) + } +} +