Add health checks and reconnection logic for stale daemon sockets (bd-137)
- Add ping() and health() methods to BdDaemonClient for connection verification - Implement _health_check_client() to verify cached client connections - Add _reconnect_client() with exponential backoff (0.1s, 0.2s, 0.4s, max 3 retries) - Update _get_client() to health-check before returning cached clients - Automatically detect and remove stale connections from pool - Add comprehensive test suite with 14 tests covering all scenarios - Handle daemon restarts, upgrades, and long-idle connections gracefully Amp-Thread-ID: https://ampcode.com/threads/T-2366ef1b-389c-4293-8145-7613037c9dfa Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -200,10 +200,33 @@ class BdDaemonClient(BdClientBase):
|
||||
|
||||
Raises:
|
||||
DaemonNotRunningError: If daemon is not running
|
||||
DaemonConnectionError: If connection fails
|
||||
DaemonError: If request fails
|
||||
"""
|
||||
data = await self._send_request("ping", {})
|
||||
return json.loads(data) if isinstance(data, str) else data
|
||||
|
||||
async def health(self) -> Dict[str, Any]:
|
||||
"""Get daemon health status.
|
||||
|
||||
Returns:
|
||||
Dict with health info including:
|
||||
- status: "healthy" | "degraded" | "unhealthy"
|
||||
- version: daemon version string
|
||||
- uptime: uptime in seconds
|
||||
- cache_size: number of cached databases
|
||||
- db_response_time_ms: database ping time
|
||||
- active_connections: number of active connections
|
||||
- memory_bytes: memory usage
|
||||
|
||||
Raises:
|
||||
DaemonNotRunningError: If daemon is not running
|
||||
DaemonConnectionError: If connection fails
|
||||
DaemonError: If request fails
|
||||
"""
|
||||
data = await self._send_request("health", {})
|
||||
return json.loads(data) if isinstance(data, str) else data
|
||||
|
||||
async def init(self, params: Optional[InitParams] = None) -> str:
|
||||
"""Initialize new beads database (not typically used via daemon).
|
||||
|
||||
|
||||
@@ -110,12 +110,76 @@ def _canonicalize_path(path: str) -> str:
|
||||
return _resolve_workspace_root(real)
|
||||
|
||||
|
||||
async def _health_check_client(client: BdClientBase) -> bool:
|
||||
"""Check if a client is healthy and responsive.
|
||||
|
||||
Args:
|
||||
client: Client to health check
|
||||
|
||||
Returns:
|
||||
True if client is healthy, False otherwise
|
||||
"""
|
||||
# Only health check daemon clients
|
||||
if not hasattr(client, 'ping'):
|
||||
return True
|
||||
|
||||
try:
|
||||
await client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
# Any exception means the client is stale/unhealthy
|
||||
return False
|
||||
|
||||
|
||||
async def _reconnect_client(canonical: str, max_retries: int = 3) -> BdClientBase:
|
||||
"""Attempt to reconnect to daemon with exponential backoff.
|
||||
|
||||
Args:
|
||||
canonical: Canonical workspace path
|
||||
max_retries: Maximum number of retry attempts (default: 3)
|
||||
|
||||
Returns:
|
||||
New client instance
|
||||
|
||||
Raises:
|
||||
BdError: If all reconnection attempts fail
|
||||
"""
|
||||
use_daemon = os.environ.get("BEADS_USE_DAEMON", "1") == "1"
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
client = create_bd_client(
|
||||
prefer_daemon=use_daemon,
|
||||
working_dir=canonical
|
||||
)
|
||||
|
||||
# Verify new client works
|
||||
if await _health_check_client(client):
|
||||
_register_client_for_cleanup(client)
|
||||
return client
|
||||
|
||||
except Exception:
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff: 0.1s, 0.2s, 0.4s
|
||||
backoff = 0.1 * (2 ** attempt)
|
||||
await asyncio.sleep(backoff)
|
||||
continue
|
||||
|
||||
raise BdError(
|
||||
f"Failed to connect to daemon after {max_retries} attempts. "
|
||||
"The daemon may be stopped or unresponsive."
|
||||
)
|
||||
|
||||
|
||||
async def _get_client() -> BdClientBase:
|
||||
"""Get a BdClient instance for the current workspace.
|
||||
|
||||
Uses connection pool to manage per-project daemon sockets.
|
||||
Workspace is determined by current_workspace ContextVar or BEADS_WORKING_DIR env.
|
||||
|
||||
Performs health check before returning cached client.
|
||||
On failure, drops from pool and attempts reconnection with exponential backoff.
|
||||
|
||||
Performs version check on first connection to each workspace.
|
||||
Uses daemon client if available, falls back to CLI client.
|
||||
|
||||
@@ -137,7 +201,19 @@ async def _get_client() -> BdClientBase:
|
||||
|
||||
# Thread-safe connection pool access
|
||||
async with _pool_lock:
|
||||
if canonical not in _connection_pool:
|
||||
if canonical in _connection_pool:
|
||||
# Health check cached client before returning
|
||||
client = _connection_pool[canonical]
|
||||
if not await _health_check_client(client):
|
||||
# Stale connection - remove from pool and reconnect
|
||||
del _connection_pool[canonical]
|
||||
if canonical in _version_checked:
|
||||
_version_checked.remove(canonical)
|
||||
|
||||
# Attempt reconnection with backoff
|
||||
client = await _reconnect_client(canonical)
|
||||
_connection_pool[canonical] = client
|
||||
else:
|
||||
# Create new client for this workspace
|
||||
use_daemon = os.environ.get("BEADS_USE_DAEMON", "1") == "1"
|
||||
|
||||
@@ -151,8 +227,6 @@ async def _get_client() -> BdClientBase:
|
||||
|
||||
# Add to pool
|
||||
_connection_pool[canonical] = client
|
||||
|
||||
client = _connection_pool[canonical]
|
||||
|
||||
# Check version once per workspace (only for CLI client)
|
||||
if canonical not in _version_checked:
|
||||
|
||||
304
integrations/beads-mcp/tests/test_daemon_health_check.py
Normal file
304
integrations/beads-mcp/tests/test_daemon_health_check.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Tests for daemon health check and reconnection logic."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from beads_mcp.bd_client import BdError
|
||||
from beads_mcp.bd_daemon_client import (
|
||||
BdDaemonClient,
|
||||
DaemonConnectionError,
|
||||
DaemonError,
|
||||
DaemonNotRunningError,
|
||||
)
|
||||
from beads_mcp.tools import _get_client, _health_check_client, _reconnect_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_daemon_client_ping_success():
|
||||
"""Test successful ping to daemon."""
|
||||
client = BdDaemonClient(socket_path="/tmp/bd.sock", working_dir="/tmp/test")
|
||||
|
||||
with patch.object(client, '_send_request', new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {"message": "pong", "version": "0.9.10"}
|
||||
|
||||
result = await client.ping()
|
||||
|
||||
assert result["message"] == "pong"
|
||||
assert result["version"] == "0.9.10"
|
||||
mock_send.assert_called_once_with("ping", {})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_daemon_client_ping_connection_error():
|
||||
"""Test ping when daemon connection fails."""
|
||||
client = BdDaemonClient(socket_path="/tmp/bd.sock", working_dir="/tmp/test")
|
||||
|
||||
with patch.object(client, '_send_request', new_callable=AsyncMock) as mock_send:
|
||||
mock_send.side_effect = DaemonConnectionError("Connection failed")
|
||||
|
||||
with pytest.raises(DaemonConnectionError):
|
||||
await client.ping()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_daemon_client_health_success():
|
||||
"""Test successful health check to daemon."""
|
||||
client = BdDaemonClient(socket_path="/tmp/bd.sock", working_dir="/tmp/test")
|
||||
|
||||
with patch.object(client, '_send_request', new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {
|
||||
"status": "healthy",
|
||||
"version": "0.9.10",
|
||||
"uptime": 123.45,
|
||||
"cache_size": 5,
|
||||
"db_response_time_ms": 2.5,
|
||||
"active_connections": 3,
|
||||
"memory_bytes": 104857600,
|
||||
}
|
||||
|
||||
result = await client.health()
|
||||
|
||||
assert result["status"] == "healthy"
|
||||
assert result["version"] == "0.9.10"
|
||||
assert result["uptime"] == 123.45
|
||||
mock_send.assert_called_once_with("health", {})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_daemon_client_health_unhealthy():
|
||||
"""Test health check when daemon is unhealthy."""
|
||||
client = BdDaemonClient(socket_path="/tmp/bd.sock", working_dir="/tmp/test")
|
||||
|
||||
with patch.object(client, '_send_request', new_callable=AsyncMock) as mock_send:
|
||||
mock_send.return_value = {
|
||||
"status": "unhealthy",
|
||||
"error": "Database connection failed",
|
||||
}
|
||||
|
||||
result = await client.health()
|
||||
|
||||
assert result["status"] == "unhealthy"
|
||||
assert "error" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_client_daemon_client_healthy():
|
||||
"""Test health check for healthy daemon client."""
|
||||
client = BdDaemonClient(socket_path="/tmp/bd.sock", working_dir="/tmp/test")
|
||||
|
||||
with patch.object(client, 'ping', new_callable=AsyncMock) as mock_ping:
|
||||
mock_ping.return_value = {"message": "pong", "version": "0.9.10"}
|
||||
|
||||
result = await _health_check_client(client)
|
||||
|
||||
assert result is True
|
||||
mock_ping.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_client_daemon_client_unhealthy():
|
||||
"""Test health check for unhealthy daemon client."""
|
||||
client = BdDaemonClient(socket_path="/tmp/bd.sock", working_dir="/tmp/test")
|
||||
|
||||
with patch.object(client, 'ping', new_callable=AsyncMock) as mock_ping:
|
||||
mock_ping.side_effect = DaemonConnectionError("Connection failed")
|
||||
|
||||
result = await _health_check_client(client)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_client_cli_client():
|
||||
"""Test health check for CLI client (always returns True)."""
|
||||
from beads_mcp.bd_client import BdClient
|
||||
|
||||
client = BdClient(bd_path="/usr/bin/bd", beads_db="/tmp/test.db")
|
||||
|
||||
result = await _health_check_client(client)
|
||||
|
||||
# CLI clients don't have ping, so they're always considered healthy
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_client_success():
|
||||
"""Test successful reconnection after failure."""
|
||||
from beads_mcp.bd_client import create_bd_client
|
||||
|
||||
with (
|
||||
patch('beads_mcp.tools.create_bd_client') as mock_create,
|
||||
patch('beads_mcp.tools._health_check_client', new_callable=AsyncMock) as mock_health,
|
||||
patch('beads_mcp.tools._register_client_for_cleanup') as mock_register,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_create.return_value = mock_client
|
||||
mock_health.return_value = True
|
||||
|
||||
result = await _reconnect_client("/tmp/test")
|
||||
|
||||
assert result == mock_client
|
||||
mock_create.assert_called_once_with(prefer_daemon=True, working_dir="/tmp/test")
|
||||
mock_register.assert_called_once_with(mock_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_client_retry_with_backoff():
|
||||
"""Test reconnection with exponential backoff on failure."""
|
||||
# Need to patch asyncio.sleep in the actual module where it's called
|
||||
import beads_mcp.tools as tools_module
|
||||
|
||||
with (
|
||||
patch.object(tools_module, 'create_bd_client') as mock_create,
|
||||
patch.object(tools_module, '_health_check_client', new_callable=AsyncMock) as mock_health,
|
||||
patch.object(tools_module, '_register_client_for_cleanup') as mock_register,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
|
||||
# Raise exception first two times, succeed third time
|
||||
mock_create.side_effect = [
|
||||
Exception("Connection failed"),
|
||||
Exception("Connection failed"),
|
||||
mock_client,
|
||||
]
|
||||
mock_health.return_value = True
|
||||
|
||||
# Mock asyncio.sleep to track calls
|
||||
sleep_calls = []
|
||||
async def mock_sleep(duration):
|
||||
sleep_calls.append(duration)
|
||||
# Don't actually sleep in tests
|
||||
return
|
||||
|
||||
with patch.object(asyncio, 'sleep', side_effect=mock_sleep):
|
||||
result = await _reconnect_client("/tmp/test", max_retries=3)
|
||||
|
||||
assert result == mock_client
|
||||
assert mock_create.call_count == 3
|
||||
assert len(sleep_calls) == 2
|
||||
|
||||
# Verify exponential backoff: 0.1s, 0.2s
|
||||
assert sleep_calls[0] == 0.1
|
||||
assert sleep_calls[1] == 0.2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_client_max_retries_exceeded():
|
||||
"""Test reconnection failure after max retries."""
|
||||
with (
|
||||
patch('beads_mcp.tools.create_bd_client') as mock_create,
|
||||
patch('beads_mcp.tools._health_check_client', new_callable=AsyncMock) as mock_health,
|
||||
patch('asyncio.sleep', new_callable=AsyncMock),
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_create.return_value = mock_client
|
||||
mock_health.return_value = False # Always fail health check
|
||||
|
||||
with pytest.raises(BdError, match="Failed to connect to daemon after 3 attempts"):
|
||||
await _reconnect_client("/tmp/test", max_retries=3)
|
||||
|
||||
assert mock_create.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_uses_cached_healthy_client(monkeypatch):
|
||||
"""Test that _get_client returns cached client if healthy."""
|
||||
from beads_mcp import tools
|
||||
|
||||
# Set up environment
|
||||
monkeypatch.setenv("BEADS_WORKING_DIR", "/tmp/test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client._check_version = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('beads_mcp.tools._canonicalize_path', return_value="/tmp/test"),
|
||||
patch('beads_mcp.tools._health_check_client', new_callable=AsyncMock) as mock_health,
|
||||
):
|
||||
mock_health.return_value = True
|
||||
|
||||
# Add mock client to pool and mark as version checked
|
||||
tools._connection_pool["/tmp/test"] = mock_client
|
||||
tools._version_checked.add("/tmp/test")
|
||||
|
||||
result = await _get_client()
|
||||
|
||||
assert result == mock_client
|
||||
mock_health.assert_called_once_with(mock_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_reconnects_on_stale_connection(monkeypatch):
|
||||
"""Test that _get_client reconnects when cached client is stale."""
|
||||
from beads_mcp import tools
|
||||
|
||||
# Set up environment
|
||||
monkeypatch.setenv("BEADS_WORKING_DIR", "/tmp/test")
|
||||
|
||||
old_client = MagicMock()
|
||||
new_client = MagicMock()
|
||||
new_client._check_version = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('beads_mcp.tools._canonicalize_path', return_value="/tmp/test"),
|
||||
patch('beads_mcp.tools._health_check_client', new_callable=AsyncMock) as mock_health,
|
||||
patch('beads_mcp.tools._reconnect_client', new_callable=AsyncMock) as mock_reconnect,
|
||||
):
|
||||
# First health check fails (stale), reconnect returns new client
|
||||
mock_health.return_value = False
|
||||
mock_reconnect.return_value = new_client
|
||||
|
||||
# Add old client to pool
|
||||
tools._connection_pool["/tmp/test"] = old_client
|
||||
tools._version_checked.add("/tmp/test")
|
||||
|
||||
result = await _get_client()
|
||||
|
||||
assert result == new_client
|
||||
assert tools._connection_pool["/tmp/test"] == new_client
|
||||
# Version check is performed after reconnect, so it's back in the set
|
||||
assert "/tmp/test" in tools._version_checked
|
||||
mock_reconnect.assert_called_once_with("/tmp/test")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_creates_new_client_if_not_cached(monkeypatch):
|
||||
"""Test that _get_client creates new client if not in pool."""
|
||||
from beads_mcp import tools
|
||||
|
||||
# Clear pool
|
||||
tools._connection_pool.clear()
|
||||
tools._version_checked.clear()
|
||||
|
||||
# Set up environment
|
||||
monkeypatch.setenv("BEADS_WORKING_DIR", "/tmp/test")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client._check_version = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('beads_mcp.tools._canonicalize_path', return_value="/tmp/test"),
|
||||
patch('beads_mcp.tools.create_bd_client', return_value=mock_client) as mock_create,
|
||||
patch('beads_mcp.tools._register_client_for_cleanup') as mock_register,
|
||||
):
|
||||
result = await _get_client()
|
||||
|
||||
assert result == mock_client
|
||||
assert tools._connection_pool["/tmp/test"] == mock_client
|
||||
mock_create.assert_called_once_with(prefer_daemon=True, working_dir="/tmp/test")
|
||||
mock_register.assert_called_once_with(mock_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_no_workspace_error():
|
||||
"""Test that _get_client raises error if no workspace is set."""
|
||||
from beads_mcp import tools
|
||||
|
||||
# Clear context
|
||||
tools.current_workspace.set(None)
|
||||
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
with pytest.raises(BdError, match="No workspace set"):
|
||||
await _get_client()
|
||||
Reference in New Issue
Block a user