Add UnderlyingConn(ctx) for safer scoped DB access
- Add UnderlyingConn method to Storage interface - Implement in SQLiteStorage for scoped connection access - Useful for migrations and DDL operations - Add comprehensive tests for basic access, DDL, context cancellation, and concurrent connections - Closes bd-66, bd-22, bd-24, bd-38, bd-39, bd-56 Amp-Thread-ID: https://ampcode.com/threads/T-e47963af-4ace-4914-a0ae-4737f77be6ff Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -1963,3 +1963,40 @@ func (s *SQLiteStorage) IsClosed() bool {
|
||||
func (s *SQLiteStorage) UnderlyingDB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// UnderlyingConn returns a single connection from the pool for scoped use.
|
||||
//
|
||||
// This provides a connection with explicit lifetime boundaries, useful for:
|
||||
// - One-time DDL operations (CREATE TABLE, ALTER TABLE)
|
||||
// - Migration scripts that need transaction control
|
||||
// - Operations that benefit from connection-level state
|
||||
//
|
||||
// IMPORTANT: The caller MUST close the connection when done:
|
||||
//
|
||||
// conn, err := storage.UnderlyingConn(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer conn.Close()
|
||||
//
|
||||
// For general queries and transactions, prefer UnderlyingDB() which manages
|
||||
// the connection pool automatically.
|
||||
//
|
||||
// EXAMPLE (extension table migration):
|
||||
//
|
||||
// conn, err := storage.UnderlyingConn(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer conn.Close()
|
||||
//
|
||||
// _, err = conn.ExecContext(ctx, `
|
||||
// CREATE TABLE IF NOT EXISTS vc_executions (
|
||||
// id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
// issue_id TEXT NOT NULL,
|
||||
// FOREIGN KEY (issue_id) REFERENCES issues(id) ON DELETE CASCADE
|
||||
// )
|
||||
// `)
|
||||
func (s *SQLiteStorage) UnderlyingConn(ctx context.Context) (*sql.Conn, error) {
|
||||
return s.db.Conn(ctx)
|
||||
}
|
||||
|
||||
@@ -287,3 +287,194 @@ func TestUnderlyingDB_LongTxDoesNotDeadlock(t *testing.T) {
|
||||
t.Error("CreateIssue deadlocked or timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnderlyingConn_BasicAccess tests that UnderlyingConn returns a usable connection
|
||||
func TestUnderlyingConn_BasicAccess(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "beads-conn-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create storage: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a scoped connection
|
||||
conn, err := store.UnderlyingConn(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("UnderlyingConn() failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Verify we can query it
|
||||
var count int
|
||||
err = conn.QueryRowContext(ctx, "SELECT COUNT(*) FROM issues").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query via UnderlyingConn: %v", err)
|
||||
}
|
||||
|
||||
if count != 0 {
|
||||
t.Errorf("Expected 0 issues, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnderlyingConn_DDLOperations tests using UnderlyingConn for DDL
|
||||
func TestUnderlyingConn_DDLOperations(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "beads-conn-ddl-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create storage: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a test issue first for FK reference
|
||||
issue := &types.Issue{
|
||||
Title: "Test issue",
|
||||
Description: "For extension testing",
|
||||
Status: types.StatusOpen,
|
||||
Priority: 1,
|
||||
IssueType: types.TypeTask,
|
||||
}
|
||||
if err := store.CreateIssue(ctx, issue, "test"); err != nil {
|
||||
t.Fatalf("Failed to create issue: %v", err)
|
||||
}
|
||||
|
||||
// Get a scoped connection for DDL
|
||||
conn, err := store.UnderlyingConn(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("UnderlyingConn() failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Create extension table using the scoped connection
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS vc_migrations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
issue_id TEXT NOT NULL,
|
||||
version TEXT NOT NULL,
|
||||
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (issue_id) REFERENCES issues(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_vc_migrations_issue ON vc_migrations(issue_id);
|
||||
`
|
||||
|
||||
if _, err := conn.ExecContext(ctx, schema); err != nil {
|
||||
t.Fatalf("Failed to create extension table: %v", err)
|
||||
}
|
||||
|
||||
// Insert using the same connection
|
||||
result, err := conn.ExecContext(ctx, `
|
||||
INSERT INTO vc_migrations (issue_id, version)
|
||||
VALUES (?, ?)
|
||||
`, issue.ID, "v1.0.0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert into extension table: %v", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
if id == 0 {
|
||||
t.Error("Expected non-zero insert ID")
|
||||
}
|
||||
|
||||
// Verify the data persists after connection close
|
||||
conn.Close()
|
||||
|
||||
// Use UnderlyingDB to verify
|
||||
db := store.UnderlyingDB()
|
||||
var version string
|
||||
err = db.QueryRowContext(ctx, `
|
||||
SELECT version FROM vc_migrations WHERE issue_id = ?
|
||||
`, issue.ID).Scan(&version)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query after connection close: %v", err)
|
||||
}
|
||||
|
||||
if version != "v1.0.0" {
|
||||
t.Errorf("Expected version 'v1.0.0', got %q", version)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnderlyingConn_ContextCancellation tests that context cancellation works
|
||||
func TestUnderlyingConn_ContextCancellation(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "beads-conn-ctx-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create storage: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Create a context that's already canceled
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
// Try to get connection with canceled context
|
||||
conn, err := store.UnderlyingConn(ctx)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
t.Error("Expected error with canceled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnderlyingConn_MultipleConnections tests multiple connections don't interfere
|
||||
func TestUnderlyingConn_MultipleConnections(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "beads-multi-conn-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create storage: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get multiple connections
|
||||
conn1, err := store.UnderlyingConn(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get conn1: %v", err)
|
||||
}
|
||||
defer conn1.Close()
|
||||
|
||||
conn2, err := store.UnderlyingConn(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get conn2: %v", err)
|
||||
}
|
||||
defer conn2.Close()
|
||||
|
||||
// Both should be able to query independently
|
||||
var count1, count2 int
|
||||
if err := conn1.QueryRowContext(ctx, "SELECT COUNT(*) FROM issues").Scan(&count1); err != nil {
|
||||
t.Errorf("conn1 query failed: %v", err)
|
||||
}
|
||||
if err := conn2.QueryRowContext(ctx, "SELECT COUNT(*) FROM issues").Scan(&count2); err != nil {
|
||||
t.Errorf("conn2 query failed: %v", err)
|
||||
}
|
||||
|
||||
if count1 != count2 {
|
||||
t.Errorf("Connections see different data: %d vs %d", count1, count2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,6 +79,12 @@ type Storage interface {
|
||||
// in the same database. Extensions should use foreign keys to reference core tables.
|
||||
// WARNING: Direct database access bypasses the storage layer. Use with caution.
|
||||
UnderlyingDB() *sql.DB
|
||||
|
||||
// UnderlyingConn returns a single connection from the pool for scoped use.
|
||||
// Useful for migrations and DDL operations that benefit from explicit connection lifetime.
|
||||
// The caller MUST close the connection when done to return it to the pool.
|
||||
// For general queries, prefer UnderlyingDB() which manages the pool automatically.
|
||||
UnderlyingConn(ctx context.Context) (*sql.Conn, error)
|
||||
}
|
||||
|
||||
// Config holds database configuration
|
||||
|
||||
Reference in New Issue
Block a user