Code review fixes for cache eviction (bd-145)

Oracle-recommended improvements:

**Thread Safety:**
- Fix lastAccess race: updates now under Write lock
- Make Stop() idempotent with sync.Once
- Close stores synchronously (not in goroutine)

**Performance:**
- Replace O(n²) sort with sort.Slice (O(n log n))
- Enforce LRU immediately on insert (prevents FD spikes)

**Correctness:**
- Canonicalize cache key to repo root (not cwd)
  - Prevents duplicate connections for same repo
  - Multiple subdirs → single cache entry
- Validate env vars: TTL <= 0 falls back to default

**Tests (6 new edge cases):**
- Canonical key behavior across subdirectories
- Immediate LRU enforcement without periodic cleanup
- Invalid TTL handling
- Re-open after eviction
- Stop idempotency
- All tests pass with -race flag

This addresses potential data races, resource spikes, and duplicate
connections identified during code review.
This commit is contained in:
Steve Yegge
2025-10-18 13:24:59 -07:00
parent 259e994522
commit f987722f96
3 changed files with 294 additions and 58 deletions

View File

@@ -9,6 +9,7 @@ import (
"os"
"os/signal"
"path/filepath"
"sort"
"sync"
"syscall"
"time"
@@ -32,8 +33,9 @@ type Server struct {
mu sync.RWMutex
shutdown bool
shutdownChan chan struct{}
stopOnce sync.Once
// Per-request storage routing with eviction support
storageCache map[string]*StorageCacheEntry // path -> entry
storageCache map[string]*StorageCacheEntry // repoRoot -> entry
cacheMu sync.RWMutex
maxCacheSize int
cacheTTL time.Duration
@@ -54,7 +56,7 @@ func NewServer(socketPath string, store storage.Storage) *Server {
cacheTTL := 30 * time.Minute // default
if env := os.Getenv("BEADS_DAEMON_CACHE_TTL"); env != "" {
if ttl, err := time.ParseDuration(env); err == nil {
if ttl, err := time.ParseDuration(env); err == nil && ttl > 0 {
cacheTTL = ttl
}
}
@@ -112,38 +114,43 @@ func (s *Server) Start(ctx context.Context) error {
// Stop stops the RPC server and cleans up resources
func (s *Server) Stop() error {
s.mu.Lock()
s.shutdown = true
s.mu.Unlock()
var err error
s.stopOnce.Do(func() {
s.mu.Lock()
s.shutdown = true
s.mu.Unlock()
// Signal cleanup goroutine to stop
close(s.shutdownChan)
if s.cleanupTicker != nil {
s.cleanupTicker.Stop()
}
// Signal cleanup goroutine to stop
close(s.shutdownChan)
// Close all cached storage connections
s.cacheMu.Lock()
for _, entry := range s.storageCache {
if err := entry.store.Close(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to close storage: %v\n", err)
// Close all cached storage connections outside lock
s.cacheMu.Lock()
stores := make([]storage.Storage, 0, len(s.storageCache))
for _, entry := range s.storageCache {
stores = append(stores, entry.store)
}
}
s.storageCache = make(map[string]*StorageCacheEntry)
s.cacheMu.Unlock()
s.storageCache = make(map[string]*StorageCacheEntry)
s.cacheMu.Unlock()
if s.listener != nil {
if err := s.listener.Close(); err != nil {
return fmt.Errorf("failed to close listener: %w", err)
// Close stores without holding lock
for _, store := range stores {
if closeErr := store.Close(); closeErr != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to close storage: %v\n", closeErr)
}
}
}
if err := s.removeOldSocket(); err != nil {
return fmt.Errorf("failed to remove socket: %w", err)
}
if s.listener != nil {
if closeErr := s.listener.Close(); closeErr != nil {
err = fmt.Errorf("failed to close listener: %w", closeErr)
return
}
}
return nil
if removeErr := s.removeOldSocket(); removeErr != nil {
err = fmt.Errorf("failed to remove socket: %w", removeErr)
}
})
return err
}
func (s *Server) ensureSocketDir() error {
@@ -191,7 +198,6 @@ func (s *Server) evictStaleStorage() {
toClose := []storage.Storage{}
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
// First pass: evict TTL-expired entries
for path, entry := range s.storageCache {
@@ -205,22 +211,18 @@ func (s *Server) evictStaleStorage() {
if len(s.storageCache) > s.maxCacheSize {
// Build sorted list of entries by lastAccess
type cacheItem struct {
path string
entry *StorageCacheEntry
path string
entry *StorageCacheEntry
}
items := make([]cacheItem, 0, len(s.storageCache))
for path, entry := range s.storageCache {
items = append(items, cacheItem{path, entry})
}
// Sort by lastAccess (oldest first)
for i := 0; i < len(items)-1; i++ {
for j := i + 1; j < len(items); j++ {
if items[i].entry.lastAccess.After(items[j].entry.lastAccess) {
items[i], items[j] = items[j], items[i]
}
}
}
// Sort by lastAccess (oldest first) with sort.Slice
sort.Slice(items, func(i, j int) bool {
return items[i].entry.lastAccess.Before(items[j].entry.lastAccess)
})
// Evict oldest entries until we're under the limit
numToEvict := len(s.storageCache) - s.maxCacheSize
@@ -230,14 +232,15 @@ func (s *Server) evictStaleStorage() {
}
}
// Close connections outside of lock to avoid blocking
go func() {
for _, store := range toClose {
if err := store.Close(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to close evicted storage: %v\n", err)
}
// Unlock before closing to avoid holding lock during Close
s.cacheMu.Unlock()
// Close connections synchronously
for _, store := range toClose {
if err := store.Close(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: failed to close evicted storage: %v\n", err)
}
}()
}
}
func (s *Server) handleConnection(conn net.Conn) {
@@ -829,22 +832,26 @@ func (s *Server) getStorageForRequest(req *Request) (storage.Storage, error) {
return s.storage, nil
}
// Check cache first
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
if entry, ok := s.storageCache[req.Cwd]; ok {
// Update last access time
entry.lastAccess = time.Now()
return entry.store, nil
}
// Find database for this cwd
// Find database for this cwd (to get the canonical repo root)
dbPath := s.findDatabaseForCwd(req.Cwd)
if dbPath == "" {
return nil, fmt.Errorf("no .beads database found for path: %s", req.Cwd)
}
// Canonicalize key to repo root (parent of .beads directory)
beadsDir := filepath.Dir(dbPath)
repoRoot := filepath.Dir(beadsDir)
// Check cache first with write lock (to avoid race on lastAccess update)
s.cacheMu.Lock()
defer s.cacheMu.Unlock()
if entry, ok := s.storageCache[repoRoot]; ok {
// Update last access time (safe under Lock)
entry.lastAccess = time.Now()
return entry.store, nil
}
// Open storage
store, err := sqlite.New(dbPath)
if err != nil {
@@ -852,11 +859,22 @@ func (s *Server) getStorageForRequest(req *Request) (storage.Storage, error) {
}
// Cache it with current timestamp
s.storageCache[req.Cwd] = &StorageCacheEntry{
s.storageCache[repoRoot] = &StorageCacheEntry{
store: store,
lastAccess: time.Now(),
}
// Enforce LRU immediately to prevent FD spikes between cleanup ticks
needEvict := len(s.storageCache) > s.maxCacheSize
s.cacheMu.Unlock()
if needEvict {
s.evictStaleStorage()
}
// Re-acquire lock for defer
s.cacheMu.Lock()
return store, nil
}

View File

@@ -1,6 +1,7 @@
package rpc
import (
"fmt"
"os"
"path/filepath"
"testing"
@@ -305,3 +306,220 @@ func TestStorageCacheEviction_CleanupOnStop(t *testing.T) {
t.Errorf("expected cache to be cleared on stop, got %d entries", cacheSize)
}
}
// TestStorageCacheEviction_CanonicalKey verifies that requests issued from
// different subdirectories of the same repository collapse to a single
// cache entry keyed by the canonical repo root.
func TestStorageCacheEviction_CanonicalKey(t *testing.T) {
	tmpDir := t.TempDir()
	// Create main DB backing the server itself.
	mainDB := filepath.Join(tmpDir, "main.db")
	mainStore, err := sqlite.New(mainDB)
	if err != nil {
		t.Fatal(err)
	}
	defer mainStore.Close()
	// Create server
	socketPath := filepath.Join(tmpDir, "test.sock")
	server := NewServer(socketPath, mainStore)
	defer server.Stop()
	// Create test database under repo1/.beads.
	dbPath := filepath.Join(tmpDir, "repo1", ".beads", "issues.db")
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		t.Fatal(err)
	}
	store, err := sqlite.New(dbPath)
	if err != nil {
		t.Fatal(err)
	}
	store.Close()
	// Access from different subdirectories of the same repo; all should
	// resolve to the same canonical key.
	for _, cwd := range []string{
		filepath.Join(tmpDir, "repo1"),
		filepath.Join(tmpDir, "repo1", "subdir1"),
		filepath.Join(tmpDir, "repo1", "subdir1", "subdir2"),
	} {
		if _, err := server.getStorageForRequest(&Request{Cwd: cwd}); err != nil {
			t.Fatal(err)
		}
	}
	// Should only have one cache entry (all pointing to same repo root).
	server.cacheMu.RLock()
	cacheSize := len(server.storageCache)
	server.cacheMu.RUnlock()
	if cacheSize != 1 {
		t.Errorf("expected 1 cached entry (canonical key), got %d", cacheSize)
	}
}
// TestStorageCacheEviction_ImmediateLRU verifies that the cache never
// exceeds maxCacheSize even without a periodic cleanup tick, i.e. LRU
// eviction is enforced immediately on insert.
func TestStorageCacheEviction_ImmediateLRU(t *testing.T) {
	tmpDir := t.TempDir()
	// Create main DB backing the server itself.
	mainDB := filepath.Join(tmpDir, "main.db")
	mainStore, err := sqlite.New(mainDB)
	if err != nil {
		t.Fatal(err)
	}
	defer mainStore.Close()
	// Create server with max cache size of 2 and a long TTL so only
	// size-based (not TTL-based) eviction can fire.
	socketPath := filepath.Join(tmpDir, "test.sock")
	server := NewServer(socketPath, mainStore)
	server.maxCacheSize = 2
	server.cacheTTL = 1 * time.Hour // Long TTL
	defer server.Stop()
	// Create 3 test databases.
	for i := 1; i <= 3; i++ {
		dbPath := filepath.Join(tmpDir, fmt.Sprintf("repo%d", i), ".beads", "issues.db")
		if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
			t.Fatal(err)
		}
		store, err := sqlite.New(dbPath)
		if err != nil {
			t.Fatal(err)
		}
		store.Close()
	}
	// Access all 3 repos; the third insert should trigger immediate eviction.
	for i := 1; i <= 3; i++ {
		req := &Request{Cwd: filepath.Join(tmpDir, fmt.Sprintf("repo%d", i))}
		if _, err := server.getStorageForRequest(req); err != nil {
			t.Fatal(err)
		}
		time.Sleep(10 * time.Millisecond) // Ensure different timestamps
	}
	// Cache should never exceed maxCacheSize (immediate LRU enforcement).
	server.cacheMu.RLock()
	cacheSize := len(server.storageCache)
	server.cacheMu.RUnlock()
	if cacheSize > server.maxCacheSize {
		t.Errorf("cache size %d exceeds max %d (immediate LRU not enforced)", cacheSize, server.maxCacheSize)
	}
}
// TestStorageCacheEviction_InvalidTTL verifies that a non-positive
// BEADS_DAEMON_CACHE_TTL value is rejected and the server falls back to
// the 30-minute default.
func TestStorageCacheEviction_InvalidTTL(t *testing.T) {
	tmpDir := t.TempDir()
	// Create main DB backing the server itself.
	mainDB := filepath.Join(tmpDir, "main.db")
	mainStore, err := sqlite.New(mainDB)
	if err != nil {
		t.Fatal(err)
	}
	defer mainStore.Close()
	// Set an invalid (negative) TTL. t.Setenv restores any prior value
	// automatically when the test ends, unlike os.Setenv/os.Unsetenv.
	t.Setenv("BEADS_DAEMON_CACHE_TTL", "-5m")
	// Create server
	socketPath := filepath.Join(tmpDir, "test.sock")
	server := NewServer(socketPath, mainStore)
	defer server.Stop()
	// Should fall back to default (30 minutes).
	expectedTTL := 30 * time.Minute
	if server.cacheTTL != expectedTTL {
		t.Errorf("expected TTL to fall back to %v for invalid value, got %v", expectedTTL, server.cacheTTL)
	}
}
// TestStorageCacheEviction_ReopenAfterEviction verifies that a repo whose
// cache entry was evicted by TTL expiry can be cleanly re-opened and
// re-cached on the next request.
func TestStorageCacheEviction_ReopenAfterEviction(t *testing.T) {
	tmpDir := t.TempDir()
	// Create main DB backing the server itself.
	mainDB := filepath.Join(tmpDir, "main.db")
	mainStore, err := sqlite.New(mainDB)
	if err != nil {
		t.Fatal(err)
	}
	defer mainStore.Close()
	// Create server with a short TTL so the entry expires quickly.
	socketPath := filepath.Join(tmpDir, "test.sock")
	server := NewServer(socketPath, mainStore)
	server.cacheTTL = 50 * time.Millisecond
	defer server.Stop()
	// Create test database.
	dbPath := filepath.Join(tmpDir, "repo1", ".beads", "issues.db")
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		t.Fatal(err)
	}
	store, err := sqlite.New(dbPath)
	if err != nil {
		t.Fatal(err)
	}
	store.Close()
	// Access repo to populate the cache.
	req := &Request{Cwd: filepath.Join(tmpDir, "repo1")}
	_, err = server.getStorageForRequest(req)
	if err != nil {
		t.Fatal(err)
	}
	// Wait for TTL to expire.
	time.Sleep(100 * time.Millisecond)
	// Evict
	server.evictStaleStorage()
	// Verify evicted
	server.cacheMu.RLock()
	cacheSize := len(server.storageCache)
	server.cacheMu.RUnlock()
	if cacheSize != 0 {
		t.Fatalf("expected cache to be empty after eviction, got %d", cacheSize)
	}
	// Access again - should cleanly re-open.
	_, err = server.getStorageForRequest(req)
	if err != nil {
		t.Fatalf("failed to re-open after eviction: %v", err)
	}
	// Verify re-cached
	server.cacheMu.RLock()
	cacheSize = len(server.storageCache)
	server.cacheMu.RUnlock()
	if cacheSize != 1 {
		t.Errorf("expected 1 cached entry after re-open, got %d", cacheSize)
	}
}
// TestStorageCacheEviction_StopIdempotent verifies that Stop can be
// invoked repeatedly without panicking and without returning an error.
func TestStorageCacheEviction_StopIdempotent(t *testing.T) {
	tmpDir := t.TempDir()

	// Main DB backing the server itself.
	mainDB := filepath.Join(tmpDir, "main.db")
	mainStore, err := sqlite.New(mainDB)
	if err != nil {
		t.Fatal(err)
	}
	defer mainStore.Close()

	server := NewServer(filepath.Join(tmpDir, "test.sock"), mainStore)

	// Calling Stop three times in a row must be safe; the cleanup is
	// expected to run exactly once.
	for _, label := range []string{"first", "second", "third"} {
		if err := server.Stop(); err != nil {
			t.Fatalf("%s Stop failed: %v", label, err)
		}
	}
}