Implement database handshake protocol in RPC layer

- Add ExpectedDB field to RPC Request
- Server validates client's expected DB matches daemon's DB
- Return clear error on mismatch with both paths
- Old clients (no ExpectedDB) still work with warning
- Add Path() method to storage.Storage interface
- Tests verify cross-database connections rejected

Prevents database pollution when client connects to wrong daemon.

Amp-Thread-ID: https://ampcode.com/threads/T-c4454192-39c6-4c67-96a9-675cbfc4db92
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Steve Yegge
2025-10-21 20:35:44 -07:00
parent e1a445afd2
commit 645d268e43
6 changed files with 158 additions and 3 deletions

View File

@@ -18,6 +18,7 @@ type Client struct {
conn net.Conn
socketPath string
timeout time.Duration
dbPath string // Expected database path for validation
}
// TryConnect attempts to connect to the daemon socket
@@ -92,21 +93,34 @@ func (c *Client) SetTimeout(timeout time.Duration) {
c.timeout = timeout
}
// SetDatabasePath sets the expected database path for validation
func (c *Client) SetDatabasePath(dbPath string) {
c.dbPath = dbPath
}
// Execute sends an RPC request and waits for a response
func (c *Client) Execute(operation string, args interface{}) (*Response, error) {
return c.ExecuteWithCwd(operation, args, "")
}
// ExecuteWithCwd sends an RPC request with an explicit cwd (or current dir if empty string)
func (c *Client) ExecuteWithCwd(operation string, args interface{}, cwd string) (*Response, error) {
argsJSON, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("failed to marshal args: %w", err)
}
// Get current working directory for database routing
cwd, _ := os.Getwd()
// Use provided cwd, or get current working directory for database routing
if cwd == "" {
cwd, _ = os.Getwd()
}
req := Request{
Operation: operation,
Args: argsJSON,
ClientVersion: ClientVersion,
Cwd: cwd,
ExpectedDB: c.dbPath, // Send expected database path for validation
}
reqJSON, err := json.Marshal(req)

View File

@@ -42,6 +42,7 @@ type Request struct {
RequestID string `json:"request_id,omitempty"`
Cwd string `json:"cwd,omitempty"` // Working directory for database discovery
ClientVersion string `json:"client_version,omitempty"` // Client version for compatibility checks
ExpectedDB string `json:"expected_db,omitempty"` // Expected database path for validation (absolute)
}
// Response represents an RPC response from daemon to client

View File

@@ -515,6 +515,54 @@ func (s *Server) checkVersionCompatibility(clientVersion string) error {
return nil
}
// validateDatabaseBinding validates that the client is connecting to the correct daemon
// Returns error if ExpectedDB is set and doesn't match the daemon's database path
func (s *Server) validateDatabaseBinding(req *Request) error {
// If client doesn't specify ExpectedDB, allow but log warning (old clients)
if req.ExpectedDB == "" {
// Log warning for audit trail
fmt.Fprintf(os.Stderr, "Warning: Client request without database binding validation (old client or missing ExpectedDB)\n")
return nil
}
// For multi-database daemons: If a cwd is provided, verify the client expects
// the database that would be selected for that cwd
var daemonDB string
if req.Cwd != "" {
// Use the database discovery logic to find which DB would be used
dbPath := s.findDatabaseForCwd(req.Cwd)
if dbPath != "" {
daemonDB = dbPath
} else {
// No database found for cwd, will fall back to default storage
daemonDB = s.storage.Path()
}
} else {
// No cwd provided, use default storage
daemonDB = s.storage.Path()
}
// Normalize both paths for comparison (resolve symlinks, clean paths)
expectedPath, err := filepath.EvalSymlinks(req.ExpectedDB)
if err != nil {
// If we can't resolve expected path, use it as-is
expectedPath = filepath.Clean(req.ExpectedDB)
}
daemonPath, err := filepath.EvalSymlinks(daemonDB)
if err != nil {
// If we can't resolve daemon path, use it as-is
daemonPath = filepath.Clean(daemonDB)
}
// Compare paths
if expectedPath != daemonPath {
return fmt.Errorf("database mismatch: client expects %s but daemon serves %s. Wrong daemon connection - check socket path",
req.ExpectedDB, daemonDB)
}
return nil
}
func (s *Server) handleRequest(req *Request) Response {
// Track request timing
start := time.Now()
@@ -525,6 +573,17 @@ func (s *Server) handleRequest(req *Request) Response {
s.metrics.RecordRequest(req.Operation, latency)
}()
// Validate database binding (skip for health/metrics to allow diagnostics)
if req.Operation != OpHealth && req.Operation != OpMetrics {
if err := s.validateDatabaseBinding(req); err != nil {
s.metrics.RecordError(req.Operation)
return Response{
Success: false,
Error: err.Error(),
}
}
}
// Check version compatibility (skip for ping/health to allow version checks)
if req.Operation != OpPing && req.Operation != OpHealth {
if err := s.checkVersionCompatibility(req.ClientVersion); err != nil {