Files
beads/internal/storage/sqlite/sqlite.go
Steve Yegge 3aeeeb752c Fix malformed ID detection to actually work (bd-54)
SQLite's CAST to INTEGER never returns NULL - it returns 0 for
invalid strings. This meant the malformed ID detection query was
completely broken and never found any malformed IDs.

The Problem:
- Query used: CAST(suffix AS INTEGER) IS NULL
- SQLite behavior: CAST('abc' AS INTEGER) = 0 (not NULL!)
- Result: Malformed IDs were never detected

The Fix:
- Check if CAST returns 0 AND suffix doesn't start with '0'
- This catches non-numeric suffixes like 'abc', 'foo123'
- Avoids false positives on legitimate IDs like 'test-0', 'test-007'

Changes:
- internal/storage/sqlite/sqlite.go:126-131
  * Updated malformed ID query logic
  * Check: CAST = 0 AND first char != '0'
  * Added third parameter for prefix (used 3 times now)

Testing:
- Created test DB with test-abc, test-1, test-foo123
- Warning correctly shows: [test-abc test-foo123] ✓
- Added test-0, test-007 (zero-prefixed IDs)
- No false positives ✓
- All existing tests pass ✓

Impact:
- Malformed IDs are now properly detected and warned about
- Helps maintain data quality
- Prevents confusion when auto-incrementing IDs

Closes bd-54

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-14 00:32:42 -07:00

528 lines
14 KiB
Go

// Package sqlite implements the storage interface using SQLite.
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
// Import SQLite driver
_ "github.com/mattn/go-sqlite3"
"github.com/steveyegge/beads/internal/types"
)
// SQLiteStorage implements the Storage interface using SQLite.
//
// It holds a sync.Mutex, so values must not be copied after first use;
// New returns a pointer for that reason.
type SQLiteStorage struct {
// db is the database handle that every method issues its queries through.
db *sql.DB
// nextID is the numeric suffix that CreateIssue gives to the next issue
// created without an explicit ID. New seeds it from getNextID.
nextID int
// idMu protects nextID from concurrent access.
idMu sync.Mutex
}
// New creates a new SQLite storage backend for the database file at path.
//
// It creates the parent directory if needed, opens the database with WAL
// journaling and foreign-key enforcement enabled, applies the schema, runs the
// dirty_issues migration, and seeds the in-memory issue ID counter.
//
// On any failure after the database has been opened, the handle is closed
// before the error is returned so that no connection is leaked.
func New(path string) (*SQLiteStorage, error) {
	// Ensure directory exists
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return nil, fmt.Errorf("failed to create directory: %w", err)
	}

	// Open database with WAL mode for better concurrency
	db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=ON")
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// From here on every early return must release the handle; the flag is
	// flipped only once the storage has been fully constructed.
	initialized := false
	defer func() {
		if !initialized {
			// The original error is the one worth reporting; a failure to
			// close an already-broken handle adds nothing actionable.
			_ = db.Close()
		}
	}()

	// Test connection
	if err := db.Ping(); err != nil {
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	// Initialize schema
	if _, err := db.Exec(schema); err != nil {
		return nil, fmt.Errorf("failed to initialize schema: %w", err)
	}

	// Migrate existing databases to add dirty_issues table if missing
	if err := migrateDirtyIssuesTable(db); err != nil {
		return nil, fmt.Errorf("failed to migrate dirty_issues table: %w", err)
	}

	storage := &SQLiteStorage{
		db:     db,
		nextID: getNextID(db),
	}
	initialized = true
	return storage, nil
}
// migrateDirtyIssuesTable makes sure the dirty_issues table is present.
//
// Databases created before the incremental export feature lack the table, so
// it is looked up in sqlite_master and created (together with its marked_at
// index) when absent. An existing table is left untouched.
func migrateDirtyIssuesTable(db *sql.DB) error {
	var found string
	lookupErr := db.QueryRow(`
SELECT name FROM sqlite_master
WHERE type='table' AND name='dirty_issues'
`).Scan(&found)

	switch {
	case lookupErr == nil:
		// Already migrated; nothing to do.
		return nil
	case lookupErr != sql.ErrNoRows:
		return fmt.Errorf("failed to check for dirty_issues table: %w", lookupErr)
	}

	// No row came back, so the table is missing: create it. This happens
	// silently; there is nothing useful to log on success.
	if _, createErr := db.Exec(`
CREATE TABLE dirty_issues (
issue_id TEXT PRIMARY KEY,
marked_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (issue_id) REFERENCES issues(id) ON DELETE CASCADE
);
CREATE INDEX idx_dirty_issues_marked_at ON dirty_issues(marked_at);
`); createErr != nil {
		return fmt.Errorf("failed to create dirty_issues table: %w", createErr)
	}
	return nil
}
// getNextID determines the numeric suffix for the next auto-generated issue ID.
//
// It reads the configured issue prefix (defaulting to "bd"), finds the largest
// purely numeric suffix among IDs of the form "<prefix>-<digits>", and returns
// that value plus one. It returns 1 when no well-formed ID exists or when the
// lookup fails.
//
// IDs whose suffix is empty or contains any non-digit character are treated as
// malformed: they are excluded from the maximum and reported on stderr.
func getNextID(db *sql.DB) int {
	// Get prefix from config, default to "bd"
	var prefix string
	err := db.QueryRow("SELECT value FROM config WHERE key = 'issue_prefix'").Scan(&prefix)
	if err != nil || prefix == "" {
		prefix = "bd"
	}

	// SUBSTR(id, LENGTH(prefix) + 2) is everything after "<prefix>-".
	// CAST makes MAX numeric rather than alphabetical (bd-10 > bd-9).
	// CAST alone cannot validate the suffix: SQLite yields 0 for 'abc' and the
	// leading digits for '12abc', never NULL. The GLOB test therefore restricts
	// the rows to suffixes made only of digits.
	var maxNum sql.NullInt64
	query := `
		SELECT MAX(CAST(SUBSTR(id, LENGTH(?) + 2) AS INTEGER))
		FROM issues
		WHERE id LIKE ? || '-%'
		  AND SUBSTR(id, LENGTH(?) + 2) != ''
		  AND SUBSTR(id, LENGTH(?) + 2) NOT GLOB '*[^0-9]*'
	`
	if err := db.QueryRow(query, prefix, prefix, prefix, prefix).Scan(&maxNum); err != nil {
		return 1 // Lookup failed; start from 1 as a best effort
	}

	// Warn even when no well-formed ID exists: a table holding only malformed
	// IDs is exactly the case the user needs to hear about.
	warnMalformedIDs(db, prefix)

	if !maxNum.Valid {
		return 1 // No well-formed IDs for this prefix yet
	}
	return int(maxNum.Int64) + 1
}

// warnMalformedIDs prints a warning to stderr listing every issue ID that
// starts with "<prefix>-" but whose suffix is empty or contains a non-digit.
// It is purely diagnostic: failures are reported on stderr and never abort.
func warnMalformedIDs(db *sql.DB, prefix string) {
	malformedQuery := `
		SELECT id FROM issues
		WHERE id LIKE ? || '-%'
		  AND (SUBSTR(id, LENGTH(?) + 2) = ''
		       OR SUBSTR(id, LENGTH(?) + 2) GLOB '*[^0-9]*')
	`
	rows, err := db.Query(malformedQuery, prefix, prefix, prefix)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Warning: malformed issue ID check did not complete: %v\n", err)
		return
	}
	defer rows.Close()

	var malformedIDs []string
	for rows.Next() {
		var id string
		// A row that cannot be scanned is skipped so the remaining IDs are
		// still reported; this list is advisory only.
		if err := rows.Scan(&id); err == nil {
			malformedIDs = append(malformedIDs, id)
		}
	}
	if err := rows.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Warning: malformed issue ID check did not complete: %v\n", err)
	}

	if len(malformedIDs) > 0 {
		fmt.Fprintf(os.Stderr, "Warning: Found %d malformed issue IDs with non-numeric suffixes: %v\n",
			len(malformedIDs), malformedIDs)
		fmt.Fprintln(os.Stderr, "These IDs are being ignored for ID generation. Consider fixing them.")
	}
}
// CreateIssue validates and stores a new issue.
//
// When issue.ID is empty an ID of the form "<prefix>-<n>" is assigned from the
// in-memory counter, using the configured issue prefix (default "bd").
// CreatedAt and UpdatedAt are overwritten with the current time. The issue
// row, a creation event attributed to actor, and a dirty_issues marker are
// written in a single transaction.
//
// It returns an error if validation fails or any database step fails; in the
// latter case the transaction is rolled back. A counter value that was already
// consumed for the ID is not returned to the pool.
func (s *SQLiteStorage) CreateIssue(ctx context.Context, issue *types.Issue, actor string) error {
	// Validate issue before creating
	if err := issue.Validate(); err != nil {
		return fmt.Errorf("validation failed: %w", err)
	}

	// Generate ID if not set (thread-safe)
	if issue.ID == "" {
		s.idMu.Lock()
		// Get prefix from config, default to "bd"
		prefix, err := s.GetConfig(ctx, "issue_prefix")
		if err != nil || prefix == "" {
			prefix = "bd"
		}
		issue.ID = fmt.Sprintf("%s-%d", prefix, s.nextID)
		s.nextID++
		s.idMu.Unlock()
	}

	// Set timestamps
	now := time.Now()
	issue.CreatedAt = now
	issue.UpdatedAt = now

	// Start transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rolls back on every early return; after a successful Commit it only
	// returns an error, which is ignored.
	defer tx.Rollback()

	// Insert issue
	_, err = tx.ExecContext(ctx, `
INSERT INTO issues (
id, title, description, design, acceptance_criteria, notes,
status, priority, issue_type, assignee, estimated_minutes,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`,
		issue.ID, issue.Title, issue.Description, issue.Design,
		issue.AcceptanceCriteria, issue.Notes, issue.Status,
		issue.Priority, issue.IssueType, issue.Assignee,
		issue.EstimatedMinutes, issue.CreatedAt, issue.UpdatedAt,
	)
	if err != nil {
		return fmt.Errorf("failed to insert issue: %w", err)
	}

	// Record creation event
	eventData, err := json.Marshal(issue)
	if err != nil {
		// Fall back to a minimal description if marshaling fails. It is built
		// with json.Marshal rather than string formatting so that quotes or
		// backslashes in the title cannot produce invalid JSON. Marshaling a
		// map[string]string cannot fail, so the error is deliberately ignored.
		eventData, _ = json.Marshal(map[string]string{
			"id":    issue.ID,
			"title": issue.Title,
		})
	}
	_, err = tx.ExecContext(ctx, `
INSERT INTO events (issue_id, event_type, actor, new_value)
VALUES (?, ?, ?, ?)
`, issue.ID, types.EventCreated, actor, string(eventData))
	if err != nil {
		return fmt.Errorf("failed to record event: %w", err)
	}

	// Mark issue as dirty for incremental export
	_, err = tx.ExecContext(ctx, `
INSERT INTO dirty_issues (issue_id, marked_at)
VALUES (?, ?)
ON CONFLICT (issue_id) DO UPDATE SET marked_at = excluded.marked_at
`, issue.ID, time.Now())
	if err != nil {
		return fmt.Errorf("failed to mark issue dirty: %w", err)
	}

	return tx.Commit()
}
// GetIssue looks up a single issue by its ID.
//
// It returns (nil, nil) when no row matches, so callers must check the
// returned pointer as well as the error. The nullable columns (assignee,
// estimated_minutes, closed_at) are copied onto the issue only when they hold
// a value.
func (s *SQLiteStorage) GetIssue(ctx context.Context, id string) (*types.Issue, error) {
	var (
		result   types.Issue
		closed   sql.NullTime
		estimate sql.NullInt64
		owner    sql.NullString
	)

	row := s.db.QueryRowContext(ctx, `
SELECT id, title, description, design, acceptance_criteria, notes,
status, priority, issue_type, assignee, estimated_minutes,
created_at, updated_at, closed_at
FROM issues
WHERE id = ?
`, id)
	err := row.Scan(
		&result.ID, &result.Title, &result.Description, &result.Design,
		&result.AcceptanceCriteria, &result.Notes, &result.Status,
		&result.Priority, &result.IssueType, &owner, &estimate,
		&result.CreatedAt, &result.UpdatedAt, &closed,
	)
	switch {
	case err == sql.ErrNoRows:
		// Absence is not an error for this lookup.
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to get issue: %w", err)
	}

	// Transfer nullable columns only when the database held a value.
	if owner.Valid {
		result.Assignee = owner.String
	}
	if estimate.Valid {
		minutes := int(estimate.Int64)
		result.EstimatedMinutes = &minutes
	}
	if closed.Valid {
		result.ClosedAt = &closed.Time
	}
	return &result, nil
}
// allowedUpdateFields is the set of issue column names that UpdateIssue
// accepts as keys of its updates map. UpdateIssue interpolates those keys
// directly into the UPDATE statement, so any key missing from this set is
// rejected to prevent SQL injection.
var allowedUpdateFields = map[string]bool{
"status": true,
"priority": true,
"title": true,
"assignee": true,
"description": true,
"design": true,
"acceptance_criteria": true,
"notes": true,
"issue_type": true,
"estimated_minutes": true,
}
// UpdateIssue applies the given column updates to the issue identified by id.
//
// Every key in updates must appear in allowedUpdateFields, and a handful of
// fields are range/enum checked when the value has the expected Go type.
// updated_at is always refreshed. The row update, an event carrying the old
// issue and the update map as JSON, and a dirty_issues marker are written in
// one transaction. The event type is EventClosed when status is set to the
// closed status, EventStatusChanged for any other status update, and
// EventUpdated otherwise.
//
// It returns an error when the issue does not exist, when a field or value is
// rejected, or when any database step fails.
func (s *SQLiteStorage) UpdateIssue(ctx context.Context, id string, updates map[string]interface{}, actor string) error {
	// Snapshot the current state; it becomes the event's old_value.
	before, err := s.GetIssue(ctx, id)
	if err != nil {
		return err
	}
	if before == nil {
		return fmt.Errorf("issue %s not found", id)
	}

	// Assemble "col = ?" fragments and their positional arguments.
	assignments := []string{"updated_at = ?"}
	params := []interface{}{time.Now()}
	for field, val := range updates {
		if err := checkUpdateField(field, val); err != nil {
			return err
		}
		assignments = append(assignments, field+" = ?")
		params = append(params, val)
	}
	params = append(params, id) // binds the trailing WHERE id = ?

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer tx.Rollback()

	// Field names are safe to interpolate: each one passed the allow-list.
	stmt := fmt.Sprintf("UPDATE issues SET %s WHERE id = ?", strings.Join(assignments, ", "))
	if _, err := tx.ExecContext(ctx, stmt, params...); err != nil {
		return fmt.Errorf("failed to update issue: %w", err)
	}

	// Serialize both sides of the change, degrading to minimal payloads if
	// marshaling fails.
	oldJSON, err := json.Marshal(before)
	if err != nil {
		oldJSON = []byte(fmt.Sprintf(`{"id":"%s"}`, id))
	}
	newJSON, err := json.Marshal(updates)
	if err != nil {
		newJSON = []byte(`{}`)
	}

	// A status change outranks a plain update; closing outranks both.
	kind := types.EventUpdated
	if newStatus, changed := updates["status"]; changed {
		kind = types.EventStatusChanged
		if newStatus == string(types.StatusClosed) {
			kind = types.EventClosed
		}
	}

	if _, err := tx.ExecContext(ctx, `
INSERT INTO events (issue_id, event_type, actor, old_value, new_value)
VALUES (?, ?, ?, ?, ?)
`, id, kind, actor, string(oldJSON), string(newJSON)); err != nil {
		return fmt.Errorf("failed to record event: %w", err)
	}

	// Flag the issue for the next incremental export.
	if _, err := tx.ExecContext(ctx, `
INSERT INTO dirty_issues (issue_id, marked_at)
VALUES (?, ?)
ON CONFLICT (issue_id) DO UPDATE SET marked_at = excluded.marked_at
`, id, time.Now()); err != nil {
		return fmt.Errorf("failed to mark issue dirty: %w", err)
	}

	return tx.Commit()
}

// checkUpdateField rejects field names outside allowedUpdateFields and
// validates the value for fields that have constraints. A value whose Go type
// differs from the one a check expects is passed through unvalidated.
func checkUpdateField(field string, val interface{}) error {
	if !allowedUpdateFields[field] {
		return fmt.Errorf("invalid field for update: %s", field)
	}

	switch field {
	case "priority":
		if p, isInt := val.(int); isInt && (p < 0 || p > 4) {
			return fmt.Errorf("priority must be between 0 and 4 (got %d)", p)
		}
	case "status":
		if st, isStr := val.(string); isStr && !types.Status(st).IsValid() {
			return fmt.Errorf("invalid status: %s", st)
		}
	case "issue_type":
		if it, isStr := val.(string); isStr && !types.IssueType(it).IsValid() {
			return fmt.Errorf("invalid issue type: %s", it)
		}
	case "title":
		if t, isStr := val.(string); isStr && (len(t) == 0 || len(t) > 500) {
			return fmt.Errorf("title must be 1-500 characters")
		}
	case "estimated_minutes":
		if m, isInt := val.(int); isInt && m < 0 {
			return fmt.Errorf("estimated_minutes cannot be negative")
		}
	}
	return nil
}
// CloseIssue marks the issue identified by id as closed.
//
// In one transaction it sets status to closed and stamps closed_at and
// updated_at, records an EventClosed event attributed to actor with reason as
// its comment, and marks the issue dirty for incremental export.
//
// It returns an "issue ... not found" error when no row has the given id (the
// same error UpdateIssue reports) and a wrapped error when any database step
// fails; nothing is written in either case.
func (s *SQLiteStorage) CloseIssue(ctx context.Context, id string, reason string, actor string) error {
	now := time.Now()

	// Update with special event handling
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	defer tx.Rollback()

	result, err := tx.ExecContext(ctx, `
UPDATE issues SET status = ?, closed_at = ?, updated_at = ?
WHERE id = ?
`, types.StatusClosed, now, now, id)
	if err != nil {
		return fmt.Errorf("failed to close issue: %w", err)
	}

	// An UPDATE that matches nothing is not an SQL error, so check explicitly.
	// Otherwise a bad ID would either surface later as an opaque constraint
	// failure or leave an event behind for an issue that does not exist.
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to close issue: %w", err)
	}
	if affected == 0 {
		return fmt.Errorf("issue %s not found", id)
	}

	_, err = tx.ExecContext(ctx, `
INSERT INTO events (issue_id, event_type, actor, comment)
VALUES (?, ?, ?, ?)
`, id, types.EventClosed, actor, reason)
	if err != nil {
		return fmt.Errorf("failed to record event: %w", err)
	}

	// Mark issue as dirty for incremental export
	_, err = tx.ExecContext(ctx, `
INSERT INTO dirty_issues (issue_id, marked_at)
VALUES (?, ?)
ON CONFLICT (issue_id) DO UPDATE SET marked_at = excluded.marked_at
`, id, time.Now())
	if err != nil {
		return fmt.Errorf("failed to mark issue dirty: %w", err)
	}

	return tx.Commit()
}
// SearchIssues returns the issues that match a free-text query and a filter.
//
// A non-empty query is matched as a substring (SQL LIKE) against title,
// description and id. Each non-nil filter field adds an equality condition,
// and all conditions are combined with AND. Results are ordered by ascending
// priority and then newest first; a positive filter.Limit caps the row count.
func (s *SQLiteStorage) SearchIssues(ctx context.Context, query string, filter types.IssueFilter) ([]*types.Issue, error) {
	conditions := []string{}
	params := []interface{}{}
	// require registers one WHERE condition together with its bound values.
	require := func(condition string, values ...interface{}) {
		conditions = append(conditions, condition)
		params = append(params, values...)
	}

	if query != "" {
		like := "%" + query + "%"
		require("(title LIKE ? OR description LIKE ? OR id LIKE ?)", like, like, like)
	}
	if filter.Status != nil {
		require("status = ?", *filter.Status)
	}
	if filter.Priority != nil {
		require("priority = ?", *filter.Priority)
	}
	if filter.IssueType != nil {
		require("issue_type = ?", *filter.IssueType)
	}
	if filter.Assignee != nil {
		require("assignee = ?", *filter.Assignee)
	}

	var whereSQL, limitSQL string
	if len(conditions) > 0 {
		whereSQL = "WHERE " + strings.Join(conditions, " AND ")
	}
	if filter.Limit > 0 {
		// The limit placeholder comes last in the statement, so its value
		// must be appended after every filter value.
		limitSQL = " LIMIT ?"
		params = append(params, filter.Limit)
	}

	statement := fmt.Sprintf(`
SELECT id, title, description, design, acceptance_criteria, notes,
status, priority, issue_type, assignee, estimated_minutes,
created_at, updated_at, closed_at
FROM issues
%s
ORDER BY priority ASC, created_at DESC
%s
`, whereSQL, limitSQL)

	rows, err := s.db.QueryContext(ctx, statement, params...)
	if err != nil {
		return nil, fmt.Errorf("failed to search issues: %w", err)
	}
	defer rows.Close()

	return scanIssues(rows)
}
// SetConfig stores value under key in the config table, replacing any value
// already stored for that key. The database error, if any, is returned as is.
func (s *SQLiteStorage) SetConfig(ctx context.Context, key, value string) error {
	const upsert = `
INSERT INTO config (key, value) VALUES (?, ?)
ON CONFLICT (key) DO UPDATE SET value = excluded.value
`
	if _, err := s.db.ExecContext(ctx, upsert, key, value); err != nil {
		return err
	}
	return nil
}
// GetConfig reads the value stored under key in the config table.
//
// A missing key is not an error: it yields the empty string and a nil error.
// Any other database error is returned unchanged.
func (s *SQLiteStorage) GetConfig(ctx context.Context, key string) (string, error) {
	var stored string
	row := s.db.QueryRowContext(ctx, `SELECT value FROM config WHERE key = ?`, key)
	switch err := row.Scan(&stored); err {
	case sql.ErrNoRows:
		return "", nil
	default:
		return stored, err
	}
}
// Close releases the underlying database handle and reports any error from
// doing so.
func (s *SQLiteStorage) Close() error {
	closeErr := s.db.Close()
	return closeErr
}