Implement database handshake protocol in RPC layer
- Add ExpectedDB field to RPC Request - Server validates client's expected DB matches daemon's DB - Return clear error on mismatch with both paths - Old clients (no ExpectedDB) still work with warning - Add Path() method to storage.Storage interface - Tests verify cross-database connections rejected Prevents database pollution when client connects to wrong daemon. Amp-Thread-ID: https://ampcode.com/threads/T-c4454192-39c6-4c67-96a9-675cbfc4db92 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -18,6 +18,7 @@ type Client struct {
|
||||
conn net.Conn
|
||||
socketPath string
|
||||
timeout time.Duration
|
||||
dbPath string // Expected database path for validation
|
||||
}
|
||||
|
||||
// TryConnect attempts to connect to the daemon socket
|
||||
@@ -92,21 +93,34 @@ func (c *Client) SetTimeout(timeout time.Duration) {
|
||||
c.timeout = timeout
|
||||
}
|
||||
|
||||
// SetDatabasePath sets the expected database path for validation
|
||||
func (c *Client) SetDatabasePath(dbPath string) {
|
||||
c.dbPath = dbPath
|
||||
}
|
||||
|
||||
// Execute sends an RPC request and waits for a response
|
||||
func (c *Client) Execute(operation string, args interface{}) (*Response, error) {
|
||||
return c.ExecuteWithCwd(operation, args, "")
|
||||
}
|
||||
|
||||
// ExecuteWithCwd sends an RPC request with an explicit cwd (or current dir if empty string)
|
||||
func (c *Client) ExecuteWithCwd(operation string, args interface{}, cwd string) (*Response, error) {
|
||||
argsJSON, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal args: %w", err)
|
||||
}
|
||||
|
||||
// Get current working directory for database routing
|
||||
cwd, _ := os.Getwd()
|
||||
// Use provided cwd, or get current working directory for database routing
|
||||
if cwd == "" {
|
||||
cwd, _ = os.Getwd()
|
||||
}
|
||||
|
||||
req := Request{
|
||||
Operation: operation,
|
||||
Args: argsJSON,
|
||||
ClientVersion: ClientVersion,
|
||||
Cwd: cwd,
|
||||
ExpectedDB: c.dbPath, // Send expected database path for validation
|
||||
}
|
||||
|
||||
reqJSON, err := json.Marshal(req)
|
||||
|
||||
@@ -42,6 +42,7 @@ type Request struct {
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
Cwd string `json:"cwd,omitempty"` // Working directory for database discovery
|
||||
ClientVersion string `json:"client_version,omitempty"` // Client version for compatibility checks
|
||||
ExpectedDB string `json:"expected_db,omitempty"` // Expected database path for validation (absolute)
|
||||
}
|
||||
|
||||
// Response represents an RPC response from daemon to client
|
||||
|
||||
@@ -515,6 +515,54 @@ func (s *Server) checkVersionCompatibility(clientVersion string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDatabaseBinding validates that the client is connecting to the correct daemon
|
||||
// Returns error if ExpectedDB is set and doesn't match the daemon's database path
|
||||
func (s *Server) validateDatabaseBinding(req *Request) error {
|
||||
// If client doesn't specify ExpectedDB, allow but log warning (old clients)
|
||||
if req.ExpectedDB == "" {
|
||||
// Log warning for audit trail
|
||||
fmt.Fprintf(os.Stderr, "Warning: Client request without database binding validation (old client or missing ExpectedDB)\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
// For multi-database daemons: If a cwd is provided, verify the client expects
|
||||
// the database that would be selected for that cwd
|
||||
var daemonDB string
|
||||
if req.Cwd != "" {
|
||||
// Use the database discovery logic to find which DB would be used
|
||||
dbPath := s.findDatabaseForCwd(req.Cwd)
|
||||
if dbPath != "" {
|
||||
daemonDB = dbPath
|
||||
} else {
|
||||
// No database found for cwd, will fall back to default storage
|
||||
daemonDB = s.storage.Path()
|
||||
}
|
||||
} else {
|
||||
// No cwd provided, use default storage
|
||||
daemonDB = s.storage.Path()
|
||||
}
|
||||
|
||||
// Normalize both paths for comparison (resolve symlinks, clean paths)
|
||||
expectedPath, err := filepath.EvalSymlinks(req.ExpectedDB)
|
||||
if err != nil {
|
||||
// If we can't resolve expected path, use it as-is
|
||||
expectedPath = filepath.Clean(req.ExpectedDB)
|
||||
}
|
||||
daemonPath, err := filepath.EvalSymlinks(daemonDB)
|
||||
if err != nil {
|
||||
// If we can't resolve daemon path, use it as-is
|
||||
daemonPath = filepath.Clean(daemonDB)
|
||||
}
|
||||
|
||||
// Compare paths
|
||||
if expectedPath != daemonPath {
|
||||
return fmt.Errorf("database mismatch: client expects %s but daemon serves %s. Wrong daemon connection - check socket path",
|
||||
req.ExpectedDB, daemonDB)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(req *Request) Response {
|
||||
// Track request timing
|
||||
start := time.Now()
|
||||
@@ -525,6 +573,17 @@ func (s *Server) handleRequest(req *Request) Response {
|
||||
s.metrics.RecordRequest(req.Operation, latency)
|
||||
}()
|
||||
|
||||
// Validate database binding (skip for health/metrics to allow diagnostics)
|
||||
if req.Operation != OpHealth && req.Operation != OpMetrics {
|
||||
if err := s.validateDatabaseBinding(req); err != nil {
|
||||
s.metrics.RecordError(req.Operation)
|
||||
return Response{
|
||||
Success: false,
|
||||
Error: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check version compatibility (skip for ping/health to allow version checks)
|
||||
if req.Operation != OpPing && req.Operation != OpHealth {
|
||||
if err := s.checkVersionCompatibility(req.ClientVersion); err != nil {
|
||||
|
||||
@@ -96,9 +96,15 @@ func New(path string) (*SQLiteStorage, error) {
|
||||
return nil, fmt.Errorf("failed to migrate compacted_at_commit column: %w", err)
|
||||
}
|
||||
|
||||
// Convert to absolute path for consistency
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get absolute path: %w", err)
|
||||
}
|
||||
|
||||
return &SQLiteStorage{
|
||||
db: db,
|
||||
dbPath: path,
|
||||
dbPath: absPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1862,3 +1868,8 @@ func (s *SQLiteStorage) GetIssueComments(ctx context.Context, issueID string) ([
|
||||
func (s *SQLiteStorage) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// Path returns the absolute path to the database file
|
||||
func (s *SQLiteStorage) Path() string {
|
||||
return s.dbPath
|
||||
}
|
||||
|
||||
@@ -1226,3 +1226,70 @@ func TestMetadataMultipleKeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPath(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "beads-test-path-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Test with relative path
|
||||
relPath := filepath.Join(tmpDir, "test.db")
|
||||
store, err := New(relPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create storage: %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Path() should return absolute path
|
||||
path := store.Path()
|
||||
if !filepath.IsAbs(path) {
|
||||
t.Errorf("Path() should return absolute path, got: %s", path)
|
||||
}
|
||||
|
||||
// Path should match the temp directory
|
||||
expectedPath, _ := filepath.Abs(relPath)
|
||||
if path != expectedPath {
|
||||
t.Errorf("Path() returned %s, expected %s", path, expectedPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleStorageDistinctPaths(t *testing.T) {
|
||||
tmpDir1, err := os.MkdirTemp("", "beads-test-path1-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir 1: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir1)
|
||||
|
||||
tmpDir2, err := os.MkdirTemp("", "beads-test-path2-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir 2: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir2)
|
||||
|
||||
store1, err := New(filepath.Join(tmpDir1, "db1.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create storage 1: %v", err)
|
||||
}
|
||||
defer store1.Close()
|
||||
|
||||
store2, err := New(filepath.Join(tmpDir2, "db2.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create storage 2: %v", err)
|
||||
}
|
||||
defer store2.Close()
|
||||
|
||||
// Paths should be distinct
|
||||
path1 := store1.Path()
|
||||
path2 := store2.Path()
|
||||
|
||||
if path1 == path2 {
|
||||
t.Errorf("Multiple storage instances should have distinct paths, both returned: %s", path1)
|
||||
}
|
||||
|
||||
// Both should be absolute
|
||||
if !filepath.IsAbs(path1) || !filepath.IsAbs(path2) {
|
||||
t.Errorf("Both paths should be absolute: path1=%s, path2=%s", path1, path2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +69,9 @@ type Storage interface {
|
||||
|
||||
// Lifecycle
|
||||
Close() error
|
||||
|
||||
// Database path (for daemon validation)
|
||||
Path() string
|
||||
}
|
||||
|
||||
// Config holds database configuration
|
||||
|
||||
Reference in New Issue
Block a user