Improve type safety and fix minor issues in beads-mcp

Type Safety Improvements:
- Change dict → dict[str, Any] throughout codebase for explicit typing
- Add PEP 561 py.typed marker file to export type information
- Add types-requests to dev dependencies
- Improve signal handler typing (FrameType | None)
- Improve decorator typing (Callable[..., Awaitable[T]])
- Add quickstart() abstract method to BdClientBase for interface completeness

Bug Fixes:
- Fix variable shadowing: beads_dir → local_beads_dir in bd_client.py
- Improve error handling in mail.py:_call_agent_mail() to prevent undefined error
- Make working_dir required (not Optional) in BdDaemonClient
- Remove unnecessary 'or' defaults for required Pydantic fields

Validation:
- mypy passes with no errors
- All unit tests passing
- Daemon quickstart returns helpful static text (RPC doesn't support this command)
This commit is contained in:
Steve Yegge
2025-11-20 19:26:44 -05:00
parent e1c8853748
commit 9e57cb69d8
9 changed files with 96 additions and 54 deletions

View File

@@ -89,4 +89,5 @@ dev = [
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"ruff>=0.14.0",
"types-requests>=2.31.0",
]

View File

@@ -5,7 +5,7 @@ import json
import os
import re
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Any, List, Optional
from .config import load_config
from .models import (
@@ -114,6 +114,11 @@ class BdClientBase(ABC):
"""Add a dependency between issues."""
pass
@abstractmethod
async def quickstart(self) -> str:
"""Get quickstart guide."""
pass
@abstractmethod
async def stats(self) -> Stats:
"""Get repository statistics."""
@@ -130,17 +135,17 @@ class BdClientBase(ABC):
pass
@abstractmethod
async def inspect_migration(self) -> dict:
async def inspect_migration(self) -> dict[str, Any]:
"""Get migration plan and database state for agent analysis."""
pass
@abstractmethod
async def get_schema_info(self) -> dict:
async def get_schema_info(self) -> dict[str, Any]:
"""Get current database schema for inspection."""
pass
@abstractmethod
async def repair_deps(self, fix: bool = False) -> dict:
async def repair_deps(self, fix: bool = False) -> dict[str, Any]:
"""Find and optionally fix orphaned dependency references.
Args:
@@ -152,7 +157,7 @@ class BdClientBase(ABC):
pass
@abstractmethod
async def detect_pollution(self, clean: bool = False) -> dict:
async def detect_pollution(self, clean: bool = False) -> dict[str, Any]:
"""Detect test issues that leaked into production database.
Args:
@@ -164,7 +169,7 @@ class BdClientBase(ABC):
pass
@abstractmethod
async def validate(self, checks: str | None = None, fix_all: bool = False) -> dict:
async def validate(self, checks: str | None = None, fix_all: bool = False) -> dict[str, Any]:
"""Run database validation checks.
Args:
@@ -246,7 +251,7 @@ class BdCliClient(BdClientBase):
flags.append("--no-auto-import")
return flags
async def _run_command(self, *args: str, cwd: str | None = None) -> object:
async def _run_command(self, *args: str, cwd: str | None = None) -> Any:
"""Run bd command and parse JSON output.
Args:
@@ -638,7 +643,7 @@ class BdCliClient(BdClientBase):
return [BlockedIssue.model_validate(issue) for issue in data]
async def inspect_migration(self) -> dict:
async def inspect_migration(self) -> dict[str, Any]:
"""Get migration plan and database state for agent analysis.
Returns:
@@ -649,7 +654,7 @@ class BdCliClient(BdClientBase):
raise BdCommandError("Invalid response for inspect_migration")
return data
async def get_schema_info(self) -> dict:
async def get_schema_info(self) -> dict[str, Any]:
"""Get current database schema for inspection.
Returns:
@@ -660,7 +665,7 @@ class BdCliClient(BdClientBase):
raise BdCommandError("Invalid response for get_schema_info")
return data
async def repair_deps(self, fix: bool = False) -> dict:
async def repair_deps(self, fix: bool = False) -> dict[str, Any]:
"""Find and optionally fix orphaned dependency references.
Args:
@@ -678,7 +683,7 @@ class BdCliClient(BdClientBase):
raise BdCommandError("Invalid response for repair-deps")
return data
async def detect_pollution(self, clean: bool = False) -> dict:
async def detect_pollution(self, clean: bool = False) -> dict[str, Any]:
"""Detect test issues that leaked into production database.
Args:
@@ -696,7 +701,7 @@ class BdCliClient(BdClientBase):
raise BdCommandError("Invalid response for detect-pollution")
return data
async def validate(self, checks: str | None = None, fix_all: bool = False) -> dict:
async def validate(self, checks: str | None = None, fix_all: bool = False) -> dict[str, Any]:
"""Run database validation checks.
Args:
@@ -804,9 +809,9 @@ def create_bd_client(
current = search_dir.resolve()
while True:
beads_dir = current / ".beads"
if beads_dir.is_dir():
sock_path = beads_dir / "bd.sock"
local_beads_dir = current / ".beads"
if local_beads_dir.is_dir():
sock_path = local_beads_dir / "bd.sock"
if sock_path.exists():
socket_found = True
break

View File

@@ -46,7 +46,7 @@ class BdDaemonClient(BdClientBase):
"""Client for calling bd daemon via RPC over Unix socket."""
socket_path: str | None
working_dir: str | None
working_dir: str
actor: str | None
timeout: float
@@ -113,7 +113,7 @@ class BdDaemonClient(BdClientBase):
"Daemon socket not found. Is the daemon running? Try: bd daemon (local) or bd daemon --global"
)
async def _send_request(self, operation: str, args: Dict[str, Any]) -> Dict[str, Any]:
async def _send_request(self, operation: str, args: Dict[str, Any]) -> Any:
"""Send RPC request to daemon and get response.
Args:
@@ -192,7 +192,7 @@ class BdDaemonClient(BdClientBase):
writer.close()
await writer.wait_closed()
async def ping(self) -> Dict[str, str]:
async def ping(self) -> Dict[str, Any]:
"""Ping daemon to check if it's running.
Returns:
@@ -204,7 +204,8 @@ class BdDaemonClient(BdClientBase):
DaemonError: If request fails
"""
data = await self._send_request("ping", {})
return json.loads(data) if isinstance(data, str) else data
result = json.loads(data) if isinstance(data, str) else data
return dict(result)
async def health(self) -> Dict[str, Any]:
"""Get daemon health status.
@@ -224,7 +225,24 @@ class BdDaemonClient(BdClientBase):
DaemonError: If request fails
"""
data = await self._send_request("health", {})
return json.loads(data) if isinstance(data, str) else data
result = json.loads(data) if isinstance(data, str) else data
return dict(result)
async def quickstart(self) -> str:
"""Get quickstart guide.
Note: Daemon RPC doesn't support quickstart command.
Returns static guide text pointing users to CLI.
Returns:
Quickstart guide text
"""
return (
"Beads (bd) Quickstart\n\n"
"To get started with beads, please refer to the documentation or use the CLI:\n"
" bd quickstart\n\n"
"For MCP usage, try 'beads list' or 'beads create'."
)
async def init(self, params: Optional[InitParams] = None) -> str:
"""Initialize new beads database (not typically used via daemon).
@@ -256,7 +274,7 @@ class BdDaemonClient(BdClientBase):
"""
args = {
"title": params.title,
"issue_type": params.issue_type or "task",
"issue_type": params.issue_type,
"priority": params.priority if params.priority is not None else 2,
}
if params.id:
@@ -430,7 +448,7 @@ class BdDaemonClient(BdClientBase):
# This is a placeholder for when it's added
raise NotImplementedError("Blocked operation not yet supported via daemon")
async def inspect_migration(self) -> dict:
async def inspect_migration(self) -> dict[str, Any]:
"""Get migration plan and database state for agent analysis.
Returns:
@@ -441,7 +459,7 @@ class BdDaemonClient(BdClientBase):
"""
raise NotImplementedError("inspect_migration not supported via daemon - use CLI client")
async def get_schema_info(self) -> dict:
async def get_schema_info(self) -> dict[str, Any]:
"""Get current database schema for inspection.
Returns:
@@ -452,7 +470,7 @@ class BdDaemonClient(BdClientBase):
"""
raise NotImplementedError("get_schema_info not supported via daemon - use CLI client")
async def repair_deps(self, fix: bool = False) -> dict:
async def repair_deps(self, fix: bool = False) -> dict[str, Any]:
"""Find and optionally fix orphaned dependency references.
Args:
@@ -466,7 +484,7 @@ class BdDaemonClient(BdClientBase):
"""
raise NotImplementedError("repair_deps not supported via daemon - use CLI client")
async def detect_pollution(self, clean: bool = False) -> dict:
async def detect_pollution(self, clean: bool = False) -> dict[str, Any]:
"""Detect test issues that leaked into production database.
Args:
@@ -480,7 +498,7 @@ class BdDaemonClient(BdClientBase):
"""
raise NotImplementedError("detect_pollution not supported via daemon - use CLI client")
async def validate(self, checks: str | None = None, fix_all: bool = False) -> dict:
async def validate(self, checks: str | None = None, fix_all: bool = False) -> dict[str, Any]:
"""Run database validation checks.
Args:
@@ -504,7 +522,7 @@ class BdDaemonClient(BdClientBase):
args = {
"from_id": params.issue_id,
"to_id": params.depends_on_id,
"dep_type": params.dep_type or "blocks",
"dep_type": params.dep_type,
}
await self._send_request("dep_add", args)

View File

@@ -21,7 +21,7 @@ AGENT_MAIL_RETRIES = 2
class MailError(Exception):
"""Base exception for Agent Mail errors."""
def __init__(self, code: str, message: str, data: Optional[dict] = None):
def __init__(self, code: str, message: str, data: Optional[dict[str, Any]] = None):
self.code = code
self.message = message
self.data = data or {}
@@ -97,9 +97,9 @@ def _get_project_key() -> str:
def _call_agent_mail(
method: str,
endpoint: str,
json_data: Optional[dict] = None,
params: Optional[dict] = None,
) -> dict[str, Any]:
json_data: Optional[dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
) -> Any:
"""Make HTTP request to Agent Mail server with retries.
Args:
@@ -205,7 +205,9 @@ def _call_agent_mail(
time.sleep(0.5 * (2**attempt))
# All retries exhausted
raise last_error
if last_error:
raise last_error
raise MailError("INTERNAL_ERROR", "Request failed with no error details")
def mail_send(
@@ -350,7 +352,7 @@ def mail_inbox(
)
# Agent Mail returns list of messages directly
messages = result if isinstance(result, list) else []
messages: list[dict[str, Any]] = result if isinstance(result, list) else []
# Transform to our format and filter unread if requested
formatted_messages = []

View File

@@ -155,7 +155,7 @@ def beads_mail_reply(params: MailReplyParams) -> dict[str, Any]:
return {"error": e.code, "message": e.message, "data": e.data}
def beads_mail_ack(params: MailAckParams) -> dict[str, bool]:
def beads_mail_ack(params: MailAckParams) -> dict[str, Any]:
"""Acknowledge a message (for ack_required messages).
Safe to call even if message doesn't require acknowledgement.
@@ -183,7 +183,7 @@ def beads_mail_ack(params: MailAckParams) -> dict[str, bool]:
return {"error": e.code, "acknowledged": False, "message": e.message}
def beads_mail_delete(params: MailDeleteParams) -> dict[str, bool]:
def beads_mail_delete(params: MailDeleteParams) -> dict[str, Any]:
"""Delete (archive) a message from Agent Mail inbox.
Note: Agent Mail archives messages rather than permanently deleting them.

View File

@@ -0,0 +1 @@

View File

@@ -9,7 +9,8 @@ import signal
import subprocess
import sys
from functools import wraps
from typing import Callable, TypeVar
from types import FrameType
from typing import Any, Awaitable, Callable, TypeVar
from fastmcp import FastMCP
@@ -46,7 +47,7 @@ logging.basicConfig(
T = TypeVar("T")
# Global state for cleanup
_daemon_clients: list = []
_daemon_clients: list[Any] = []
_cleanup_done = False
# Persistent workspace context (survives across MCP tool calls)
@@ -92,7 +93,7 @@ def cleanup() -> None:
logger.info("Cleanup complete")
def signal_handler(signum: int, frame) -> None:
def signal_handler(signum: int, frame: FrameType | None) -> None:
"""Handle termination signals gracefully."""
sig_name = signal.Signals(signum).name
logger.info(f"Received {sig_name}, shutting down gracefully...")
@@ -114,7 +115,7 @@ except importlib.metadata.PackageNotFoundError:
logger.info(f"beads-mcp v{__version__} initialized with lifecycle management")
def with_workspace(func: Callable[..., T]) -> Callable[..., T]:
def with_workspace(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
"""Decorator to set workspace context for the duration of a tool call.
Extracts workspace_root parameter from tool call kwargs, resolves it,
@@ -124,7 +125,7 @@ def with_workspace(func: Callable[..., T]) -> Callable[..., T]:
This enables per-request workspace routing for multi-project support.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
async def wrapper(*args: Any, **kwargs: Any) -> T:
# Extract workspace_root parameter (if provided)
workspace_root = kwargs.get('workspace_root')
@@ -148,7 +149,7 @@ def with_workspace(func: Callable[..., T]) -> Callable[..., T]:
return wrapper
def require_context(func: Callable[..., T]) -> Callable[..., T]:
def require_context(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
"""Decorator to enforce context has been set before write operations.
Passes if either:
@@ -159,7 +160,7 @@ def require_context(func: Callable[..., T]) -> Callable[..., T]:
This allows backward compatibility while adding safety for multi-repo setups.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
async def wrapper(*args: Any, **kwargs: Any) -> T:
# Only enforce if explicitly enabled
if os.environ.get("BEADS_REQUIRE_CONTEXT") == "1":
# Check ContextVar or environment
@@ -453,7 +454,7 @@ async def update_issue(
notes: str | None = None,
external_ref: str | None = None,
workspace_root: str | None = None,
) -> Issue:
) -> Issue | list[Issue] | None:
"""Update an existing issue."""
# If trying to close via update, redirect to close_issue to preserve approval workflow
if status == "closed":
@@ -577,7 +578,7 @@ async def debug_env(workspace_root: str | None = None) -> str:
description="Get migration plan and database state for agent analysis.",
)
@with_workspace
async def inspect_migration(workspace_root: str | None = None) -> dict:
async def inspect_migration(workspace_root: str | None = None) -> dict[str, Any]:
"""Get migration plan and database state for agent analysis.
AI agents should:
@@ -596,7 +597,7 @@ async def inspect_migration(workspace_root: str | None = None) -> dict:
description="Get current database schema for inspection.",
)
@with_workspace
async def get_schema_info(workspace_root: str | None = None) -> dict:
async def get_schema_info(workspace_root: str | None = None) -> dict[str, Any]:
"""Get current database schema for inspection.
Returns tables, schema version, config, sample issue IDs, and detected prefix.
@@ -610,7 +611,7 @@ async def get_schema_info(workspace_root: str | None = None) -> dict:
description="Find and optionally fix orphaned dependency references.",
)
@with_workspace
async def repair_deps(fix: bool = False, workspace_root: str | None = None) -> dict:
async def repair_deps(fix: bool = False, workspace_root: str | None = None) -> dict[str, Any]:
"""Find and optionally fix orphaned dependency references.
Scans all issues for dependencies pointing to non-existent issues.
@@ -624,7 +625,7 @@ async def repair_deps(fix: bool = False, workspace_root: str | None = None) -> d
description="Detect test issues that leaked into production database.",
)
@with_workspace
async def detect_pollution(clean: bool = False, workspace_root: str | None = None) -> dict:
async def detect_pollution(clean: bool = False, workspace_root: str | None = None) -> dict[str, Any]:
"""Detect test issues that leaked into production database.
Detects test issues using pattern matching (titles starting with 'test', etc.).
@@ -642,7 +643,7 @@ async def validate(
checks: str | None = None,
fix_all: bool = False,
workspace_root: str | None = None,
) -> dict:
) -> dict[str, Any]:
"""Run comprehensive database health checks.
Available checks: orphans, duplicates, pollution, conflicts.

View File

@@ -7,7 +7,7 @@ import subprocess
import sys
from contextvars import ContextVar
from functools import lru_cache
from typing import Annotated, TYPE_CHECKING
from typing import Annotated, Any, TYPE_CHECKING
from .bd_client import create_bd_client, BdClientBase, BdError
@@ -516,7 +516,7 @@ async def beads_blocked() -> list[BlockedIssue]:
return await client.blocked()
async def beads_inspect_migration() -> dict:
async def beads_inspect_migration() -> dict[str, Any]:
"""Get migration plan and database state for agent analysis.
AI agents should:
@@ -531,7 +531,7 @@ async def beads_inspect_migration() -> dict:
return await client.inspect_migration()
async def beads_get_schema_info() -> dict:
async def beads_get_schema_info() -> dict[str, Any]:
"""Get current database schema for inspection.
Returns tables, schema version, config, sample issue IDs, and detected prefix.
@@ -543,7 +543,7 @@ async def beads_get_schema_info() -> dict:
async def beads_repair_deps(
fix: Annotated[bool, "If True, automatically remove orphaned dependencies"] = False,
) -> dict:
) -> dict[str, Any]:
"""Find and optionally fix orphaned dependency references.
Scans all issues for dependencies pointing to non-existent issues.
@@ -560,7 +560,7 @@ async def beads_repair_deps(
async def beads_detect_pollution(
clean: Annotated[bool, "If True, delete detected test issues"] = False,
) -> dict:
) -> dict[str, Any]:
"""Detect test issues that leaked into production database.
Detects test issues using pattern matching:
@@ -578,7 +578,7 @@ async def beads_detect_pollution(
async def beads_validate(
checks: Annotated[str | None, "Comma-separated list of checks (orphans,duplicates,pollution,conflicts)"] = None,
fix_all: Annotated[bool, "If True, auto-fix all fixable issues"] = False,
) -> dict:
) -> dict[str, Any]:
"""Run comprehensive database health checks.
Available checks:

View File

@@ -82,6 +82,7 @@ dev = [
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "ruff" },
{ name = "types-requests" },
]
[package.metadata]
@@ -98,6 +99,7 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
{ name = "pytest-cov", specifier = ">=7.0.0" },
{ name = "ruff", specifier = ">=0.14.0" },
{ name = "types-requests", specifier = ">=2.31.0" },
]
[[package]]
@@ -1637,6 +1639,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408 },
]
[[package]]
name = "types-requests"
version = "2.32.4.20250913"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/36/27/489922f4505975b11de2b5ad07b4fe1dca0bca9be81a703f26c5f3acfce5/types_requests-2.32.4.20250913.tar.gz", hash = "sha256:abd6d4f9ce3a9383f269775a9835a4c24e5cd6b9f647d64f88aa4613c33def5d", size = 23113 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/20/9a227ea57c1285986c4cf78400d0a91615d25b24e257fd9e2969606bdfae/types_requests-2.32.4.20250913-py3-none-any.whl", hash = "sha256:78c9c1fffebbe0fa487a418e0fa5252017e9c60d1a2da394077f1780f655d7e1", size = 20658 },
]
[[package]]
name = "typing-extensions"
version = "4.15.0"