diff --git a/internal/rpc/client.go b/internal/rpc/client.go index 1426feb5..86e83ac5 100644 --- a/internal/rpc/client.go +++ b/internal/rpc/client.go @@ -18,6 +18,7 @@ type Client struct { conn net.Conn socketPath string timeout time.Duration + dbPath string // Expected database path for validation } // TryConnect attempts to connect to the daemon socket @@ -92,21 +93,34 @@ func (c *Client) SetTimeout(timeout time.Duration) { c.timeout = timeout } +// SetDatabasePath sets the expected database path for validation +func (c *Client) SetDatabasePath(dbPath string) { + c.dbPath = dbPath +} + // Execute sends an RPC request and waits for a response func (c *Client) Execute(operation string, args interface{}) (*Response, error) { + return c.ExecuteWithCwd(operation, args, "") +} + +// ExecuteWithCwd sends an RPC request with an explicit cwd (or current dir if empty string) +func (c *Client) ExecuteWithCwd(operation string, args interface{}, cwd string) (*Response, error) { argsJSON, err := json.Marshal(args) if err != nil { return nil, fmt.Errorf("failed to marshal args: %w", err) } - // Get current working directory for database routing - cwd, _ := os.Getwd() + // Use provided cwd, or get current working directory for database routing + if cwd == "" { + cwd, _ = os.Getwd() + } req := Request{ Operation: operation, Args: argsJSON, ClientVersion: ClientVersion, Cwd: cwd, + ExpectedDB: c.dbPath, // Send expected database path for validation } reqJSON, err := json.Marshal(req) diff --git a/internal/rpc/protocol.go b/internal/rpc/protocol.go index 2ca12766..6cb1bc95 100644 --- a/internal/rpc/protocol.go +++ b/internal/rpc/protocol.go @@ -42,6 +42,7 @@ type Request struct { RequestID string `json:"request_id,omitempty"` Cwd string `json:"cwd,omitempty"` // Working directory for database discovery ClientVersion string `json:"client_version,omitempty"` // Client version for compatibility checks + ExpectedDB string `json:"expected_db,omitempty"` // Expected database path for validation (absolute) } // Response represents an RPC response from daemon to client diff --git a/internal/rpc/server.go b/internal/rpc/server.go index 91b9f24c..6a2c81bb 100644 --- a/internal/rpc/server.go +++ b/internal/rpc/server.go @@ -515,6 +515,54 @@ func (s *Server) checkVersionCompatibility(clientVersion string) error { return nil } +// validateDatabaseBinding validates that the client is connecting to the correct daemon +// Returns error if ExpectedDB is set and doesn't match the daemon's database path +func (s *Server) validateDatabaseBinding(req *Request) error { + // If client doesn't specify ExpectedDB, allow but log warning (old clients) + if req.ExpectedDB == "" { + // Log warning for audit trail + fmt.Fprintf(os.Stderr, "Warning: Client request without database binding validation (old client or missing ExpectedDB)\n") + return nil + } + + // For multi-database daemons: If a cwd is provided, verify the client expects + // the database that would be selected for that cwd + var daemonDB string + if req.Cwd != "" { + // Use the database discovery logic to find which DB would be used + dbPath := s.findDatabaseForCwd(req.Cwd) + if dbPath != "" { + daemonDB = dbPath + } else { + // No database found for cwd, will fall back to default storage + daemonDB = s.storage.Path() + } + } else { + // No cwd provided, use default storage + daemonDB = s.storage.Path() + } + + // Normalize both paths for comparison (resolve symlinks, clean paths) + expectedPath, err := filepath.EvalSymlinks(req.ExpectedDB) + if err != nil { + // If we can't resolve expected path, use it as-is + expectedPath = filepath.Clean(req.ExpectedDB) + } + daemonPath, err := filepath.EvalSymlinks(daemonDB) + if err != nil { + // If we can't resolve daemon path, use it as-is + daemonPath = filepath.Clean(daemonDB) + } + + // Compare paths + if expectedPath != daemonPath { + return fmt.Errorf("database mismatch: client expects %s but daemon serves %s. Wrong daemon connection - check socket path", + req.ExpectedDB, daemonDB) + } + + return nil +} + func (s *Server) handleRequest(req *Request) Response { // Track request timing start := time.Now() @@ -525,6 +573,17 @@ func (s *Server) handleRequest(req *Request) Response { s.metrics.RecordRequest(req.Operation, latency) }() + // Validate database binding (skip for health/metrics to allow diagnostics) + if req.Operation != OpHealth && req.Operation != OpMetrics { + if err := s.validateDatabaseBinding(req); err != nil { + s.metrics.RecordError(req.Operation) + return Response{ + Success: false, + Error: err.Error(), + } + } + } + // Check version compatibility (skip for ping/health to allow version checks) if req.Operation != OpPing && req.Operation != OpHealth { if err := s.checkVersionCompatibility(req.ClientVersion); err != nil { diff --git a/internal/storage/sqlite/sqlite.go b/internal/storage/sqlite/sqlite.go index 7d887edc..e48c8bdc 100644 --- a/internal/storage/sqlite/sqlite.go +++ b/internal/storage/sqlite/sqlite.go @@ -96,9 +96,15 @@ func New(path string) (*SQLiteStorage, error) { return nil, fmt.Errorf("failed to migrate compacted_at_commit column: %w", err) } + // Convert to absolute path for consistency + absPath, err := filepath.Abs(path) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path: %w", err) + } + return &SQLiteStorage{ db: db, - dbPath: path, + dbPath: absPath, }, nil } @@ -1862,3 +1868,8 @@ func (s *SQLiteStorage) GetIssueComments(ctx context.Context, issueID string) ([ func (s *SQLiteStorage) Close() error { return s.db.Close() } + +// Path returns the absolute path to the database file +func (s *SQLiteStorage) Path() string { + return s.dbPath +} diff --git a/internal/storage/sqlite/sqlite_test.go b/internal/storage/sqlite/sqlite_test.go index 3b418a5b..445be42e 100644 --- a/internal/storage/sqlite/sqlite_test.go +++ b/internal/storage/sqlite/sqlite_test.go @@ -1226,3 +1226,70 @@ func TestMetadataMultipleKeys(t *testing.T) { } } } + +func TestPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "beads-test-path-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Test with relative path + relPath := filepath.Join(tmpDir, "test.db") + store, err := New(relPath) + if err != nil { + t.Fatalf("failed to create storage: %v", err) + } + defer store.Close() + + // Path() should return absolute path + path := store.Path() + if !filepath.IsAbs(path) { + t.Errorf("Path() should return absolute path, got: %s", path) + } + + // Path should match the temp directory + expectedPath, _ := filepath.Abs(relPath) + if path != expectedPath { + t.Errorf("Path() returned %s, expected %s", path, expectedPath) + } +} + +func TestMultipleStorageDistinctPaths(t *testing.T) { + tmpDir1, err := os.MkdirTemp("", "beads-test-path1-*") + if err != nil { + t.Fatalf("failed to create temp dir 1: %v", err) + } + defer os.RemoveAll(tmpDir1) + + tmpDir2, err := os.MkdirTemp("", "beads-test-path2-*") + if err != nil { + t.Fatalf("failed to create temp dir 2: %v", err) + } + defer os.RemoveAll(tmpDir2) + + store1, err := New(filepath.Join(tmpDir1, "db1.db")) + if err != nil { + t.Fatalf("failed to create storage 1: %v", err) + } + defer store1.Close() + + store2, err := New(filepath.Join(tmpDir2, "db2.db")) + if err != nil { + t.Fatalf("failed to create storage 2: %v", err) + } + defer store2.Close() + + // Paths should be distinct + path1 := store1.Path() + path2 := store2.Path() + + if path1 == path2 { + t.Errorf("Multiple storage instances should have distinct paths, both returned: %s", path1) + } + + // Both should be absolute + if !filepath.IsAbs(path1) || !filepath.IsAbs(path2) { + t.Errorf("Both paths should be absolute: path1=%s, path2=%s", path1, path2) + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index e3a88099..b1a286d8 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -69,6 +69,9 @@ type Storage interface { // Lifecycle Close() error + + // Database path (for daemon validation) + Path() string } // Config holds database configuration