Add resource limits to daemon (bd-152)
Implemented connection limiting, request timeouts, and memory pressure detection:

- Connection limiting with semaphore pattern (default 100 max connections)
- Request timeout enforcement on read/write (default 30s)
- Memory pressure detection with aggressive cache eviction (default 500MB threshold)
- Configurable via environment variables:
  - BEADS_DAEMON_MAX_CONNS
  - BEADS_DAEMON_REQUEST_TIMEOUT
  - BEADS_DAEMON_MEMORY_THRESHOLD_MB
- Health endpoint now exposes active/max connections and memory usage
- Comprehensive test coverage for all limits

This prevents resource exhaustion under heavy load or attack scenarios.

Amp-Thread-ID: https://ampcode.com/threads/T-44d1817a-3709-4f1d-a27a-78bb2fa4d3dc
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
330
internal/rpc/limits_test.go
Normal file
330
internal/rpc/limits_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/steveyegge/beads/internal/storage/sqlite"
|
||||
)
|
||||
|
||||
func TestConnectionLimits(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, ".beads", "test.db")
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := sqlite.New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Set low connection limit for testing
|
||||
os.Setenv("BEADS_DAEMON_MAX_CONNS", "5")
|
||||
defer os.Unsetenv("BEADS_DAEMON_MAX_CONNS")
|
||||
|
||||
srv := NewServer(socketPath, store)
|
||||
if srv.maxConns != 5 {
|
||||
t.Fatalf("expected maxConns=5, got %d", srv.maxConns)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := srv.Start(ctx); err != nil && ctx.Err() == nil {
|
||||
t.Logf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for server to be ready
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer srv.Stop()
|
||||
|
||||
// Open maxConns connections and hold them
|
||||
var wg sync.WaitGroup
|
||||
connections := make([]net.Conn, srv.maxConns)
|
||||
|
||||
for i := 0; i < srv.maxConns; i++ {
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial connection %d: %v", i, err)
|
||||
}
|
||||
connections[i] = conn
|
||||
|
||||
// Send a long-running ping to keep connection busy
|
||||
wg.Add(1)
|
||||
go func(c net.Conn, idx int) {
|
||||
defer wg.Done()
|
||||
req := Request{
|
||||
Operation: OpPing,
|
||||
}
|
||||
data, _ := json.Marshal(req)
|
||||
c.Write(append(data, '\n'))
|
||||
|
||||
// Read response
|
||||
reader := bufio.NewReader(c)
|
||||
_, _ = reader.ReadBytes('\n')
|
||||
}(conn, i)
|
||||
}
|
||||
|
||||
// Wait for all connections to be active
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify active connection count
|
||||
activeConns := atomic.LoadInt32(&srv.activeConns)
|
||||
if activeConns != int32(srv.maxConns) {
|
||||
t.Errorf("expected %d active connections, got %d", srv.maxConns, activeConns)
|
||||
}
|
||||
|
||||
// Try to open one more connection - should be rejected
|
||||
extraConn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial extra connection: %v", err)
|
||||
}
|
||||
defer extraConn.Close()
|
||||
|
||||
// Send request on extra connection
|
||||
req := Request{Operation: OpPing}
|
||||
data, _ := json.Marshal(req)
|
||||
extraConn.Write(append(data, '\n'))
|
||||
|
||||
// Set short read timeout to detect rejection
|
||||
extraConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
reader := bufio.NewReader(extraConn)
|
||||
_, err = reader.ReadBytes('\n')
|
||||
|
||||
// Connection should be closed (EOF or timeout)
|
||||
if err == nil {
|
||||
t.Error("expected extra connection to be rejected, but got response")
|
||||
}
|
||||
|
||||
// Close existing connections
|
||||
for _, conn := range connections {
|
||||
conn.Close()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Wait for connection cleanup
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Now should be able to connect again
|
||||
newConn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to reconnect after cleanup: %v", err)
|
||||
}
|
||||
defer newConn.Close()
|
||||
|
||||
req = Request{Operation: OpPing}
|
||||
data, _ = json.Marshal(req)
|
||||
newConn.Write(append(data, '\n'))
|
||||
|
||||
reader = bufio.NewReader(newConn)
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var resp Response
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
t.Error("expected successful ping after connection cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestTimeout(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, ".beads", "test.db")
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := sqlite.New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Set very short timeout for testing
|
||||
os.Setenv("BEADS_DAEMON_REQUEST_TIMEOUT", "100ms")
|
||||
defer os.Unsetenv("BEADS_DAEMON_REQUEST_TIMEOUT")
|
||||
|
||||
srv := NewServer(socketPath, store)
|
||||
if srv.requestTimeout != 100*time.Millisecond {
|
||||
t.Fatalf("expected timeout=100ms, got %v", srv.requestTimeout)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := srv.Start(ctx); err != nil && ctx.Err() == nil {
|
||||
t.Logf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer srv.Stop()
|
||||
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send partial request and wait for timeout
|
||||
conn.Write([]byte(`{"operation":"ping"`)) // Incomplete JSON
|
||||
|
||||
// Wait longer than timeout
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Try to write - connection should be closed due to read timeout
|
||||
_, err = conn.Write([]byte("}\n"))
|
||||
if err == nil {
|
||||
t.Error("expected connection to be closed due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryPressureDetection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, ".beads", "test.db")
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := sqlite.New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Set very low memory threshold to trigger eviction
|
||||
os.Setenv("BEADS_DAEMON_MEMORY_THRESHOLD_MB", "1")
|
||||
defer os.Unsetenv("BEADS_DAEMON_MEMORY_THRESHOLD_MB")
|
||||
|
||||
srv := NewServer(socketPath, store)
|
||||
|
||||
// Add some entries to cache
|
||||
srv.cacheMu.Lock()
|
||||
for i := 0; i < 10; i++ {
|
||||
path := fmt.Sprintf("/test/path/%d", i)
|
||||
srv.storageCache[path] = &StorageCacheEntry{
|
||||
store: store,
|
||||
lastAccess: time.Now().Add(-time.Duration(i) * time.Minute),
|
||||
}
|
||||
}
|
||||
initialSize := len(srv.storageCache)
|
||||
srv.cacheMu.Unlock()
|
||||
|
||||
// Trigger memory pressure check (should evict entries)
|
||||
srv.checkMemoryPressure()
|
||||
|
||||
// Check that some entries were evicted
|
||||
srv.cacheMu.RLock()
|
||||
finalSize := len(srv.storageCache)
|
||||
srv.cacheMu.RUnlock()
|
||||
|
||||
if finalSize >= initialSize {
|
||||
t.Errorf("expected cache eviction, but size went from %d to %d", initialSize, finalSize)
|
||||
}
|
||||
|
||||
t.Logf("Cache evicted: %d -> %d entries", initialSize, finalSize)
|
||||
}
|
||||
|
||||
func TestHealthResponseIncludesLimits(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "bd-limits-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
store, err := sqlite.New(dbPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
os.Setenv("BEADS_DAEMON_MAX_CONNS", "50")
|
||||
defer os.Unsetenv("BEADS_DAEMON_MAX_CONNS")
|
||||
|
||||
srv := NewServer(socketPath, store)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := srv.Start(ctx); err != nil && ctx.Err() == nil {
|
||||
t.Logf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer srv.Stop()
|
||||
|
||||
conn, err := net.Dial("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req := Request{Operation: OpHealth}
|
||||
data, _ := json.Marshal(req)
|
||||
conn.Write(append(data, '\n'))
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var resp Response
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if !resp.Success {
|
||||
t.Fatalf("health check failed: %s", resp.Error)
|
||||
}
|
||||
|
||||
var health HealthResponse
|
||||
if err := json.Unmarshal(resp.Data, &health); err != nil {
|
||||
t.Fatalf("failed to unmarshal health response: %v", err)
|
||||
}
|
||||
|
||||
// Verify limit fields are present
|
||||
if health.MaxConns != 50 {
|
||||
t.Errorf("expected MaxConns=50, got %d", health.MaxConns)
|
||||
}
|
||||
|
||||
if health.ActiveConns < 0 {
|
||||
t.Errorf("expected ActiveConns>=0, got %d", health.ActiveConns)
|
||||
}
|
||||
|
||||
if health.MemoryAllocMB == 0 {
|
||||
t.Error("expected MemoryAllocMB>0")
|
||||
}
|
||||
|
||||
t.Logf("Health: %d/%d connections, %d MB memory", health.ActiveConns, health.MaxConns, health.MemoryAllocMB)
|
||||
}
|
||||
@@ -150,6 +150,9 @@ type HealthResponse struct {
|
||||
CacheHits int64 `json:"cache_hits"`
|
||||
CacheMisses int64 `json:"cache_misses"`
|
||||
DBResponseTime float64 `json:"db_response_ms"`
|
||||
ActiveConns int32 `json:"active_connections"`
|
||||
MaxConns int `json:"max_connections"`
|
||||
MemoryAllocMB uint64 `json:"memory_alloc_mb"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -52,6 +53,12 @@ type Server struct {
|
||||
startTime time.Time
|
||||
cacheHits int64
|
||||
cacheMisses int64
|
||||
// Connection limiting
|
||||
maxConns int
|
||||
activeConns int32 // atomic counter
|
||||
connSemaphore chan struct{}
|
||||
// Request timeout
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewServer creates a new RPC server
|
||||
@@ -73,14 +80,32 @@ func NewServer(socketPath string, store storage.Storage) *Server {
|
||||
}
|
||||
}
|
||||
|
||||
maxConns := 100 // default
|
||||
if env := os.Getenv("BEADS_DAEMON_MAX_CONNS"); env != "" {
|
||||
var conns int
|
||||
if _, err := fmt.Sscanf(env, "%d", &conns); err == nil && conns > 0 {
|
||||
maxConns = conns
|
||||
}
|
||||
}
|
||||
|
||||
requestTimeout := 30 * time.Second // default
|
||||
if env := os.Getenv("BEADS_DAEMON_REQUEST_TIMEOUT"); env != "" {
|
||||
if timeout, err := time.ParseDuration(env); err == nil && timeout > 0 {
|
||||
requestTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
socketPath: socketPath,
|
||||
storage: store,
|
||||
storageCache: make(map[string]*StorageCacheEntry),
|
||||
maxCacheSize: maxCacheSize,
|
||||
cacheTTL: cacheTTL,
|
||||
shutdownChan: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
socketPath: socketPath,
|
||||
storage: store,
|
||||
storageCache: make(map[string]*StorageCacheEntry),
|
||||
maxCacheSize: maxCacheSize,
|
||||
cacheTTL: cacheTTL,
|
||||
shutdownChan: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
maxConns: maxConns,
|
||||
connSemaphore: make(chan struct{}, maxConns),
|
||||
requestTimeout: requestTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +156,20 @@ func (s *Server) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to accept connection: %w", err)
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
// Try to acquire connection slot (non-blocking)
|
||||
select {
|
||||
case s.connSemaphore <- struct{}{}:
|
||||
// Acquired slot, handle connection
|
||||
go func(c net.Conn) {
|
||||
defer func() { <-s.connSemaphore }() // Release slot
|
||||
atomic.AddInt32(&s.activeConns, 1)
|
||||
defer atomic.AddInt32(&s.activeConns, -1)
|
||||
s.handleConnection(c)
|
||||
}(conn)
|
||||
default:
|
||||
// Max connections reached, reject immediately
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,7 +256,7 @@ func (s *Server) handleSignals() {
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
// runCleanupLoop periodically evicts stale storage connections
|
||||
// runCleanupLoop periodically evicts stale storage connections and checks memory pressure
|
||||
func (s *Server) runCleanupLoop() {
|
||||
s.cleanupTicker = time.NewTicker(5 * time.Minute)
|
||||
defer s.cleanupTicker.Stop()
|
||||
@@ -226,6 +264,7 @@ func (s *Server) runCleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanupTicker.C:
|
||||
s.checkMemoryPressure()
|
||||
s.evictStaleStorage()
|
||||
case <-s.shutdownChan:
|
||||
return
|
||||
@@ -233,6 +272,71 @@ func (s *Server) runCleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// checkMemoryPressure monitors memory usage and triggers aggressive eviction if needed
|
||||
func (s *Server) checkMemoryPressure() {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
// Memory thresholds (configurable via env var)
|
||||
const defaultThresholdMB = 500
|
||||
thresholdMB := defaultThresholdMB
|
||||
if env := os.Getenv("BEADS_DAEMON_MEMORY_THRESHOLD_MB"); env != "" {
|
||||
var threshold int
|
||||
if _, err := fmt.Sscanf(env, "%d", &threshold); err == nil && threshold > 0 {
|
||||
thresholdMB = threshold
|
||||
}
|
||||
}
|
||||
|
||||
allocMB := m.Alloc / 1024 / 1024
|
||||
if allocMB > uint64(thresholdMB) {
|
||||
fmt.Fprintf(os.Stderr, "Warning: High memory usage detected (%d MB), triggering aggressive cache eviction\n", allocMB)
|
||||
s.aggressiveEviction()
|
||||
runtime.GC() // Suggest garbage collection
|
||||
}
|
||||
}
|
||||
|
||||
// aggressiveEviction evicts 50% of cached storage to reduce memory pressure
|
||||
func (s *Server) aggressiveEviction() {
|
||||
toClose := []storage.Storage{}
|
||||
|
||||
s.cacheMu.Lock()
|
||||
|
||||
if len(s.storageCache) == 0 {
|
||||
s.cacheMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Build sorted list by last access
|
||||
type cacheItem struct {
|
||||
path string
|
||||
entry *StorageCacheEntry
|
||||
}
|
||||
items := make([]cacheItem, 0, len(s.storageCache))
|
||||
for path, entry := range s.storageCache {
|
||||
items = append(items, cacheItem{path, entry})
|
||||
}
|
||||
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return items[i].entry.lastAccess.Before(items[j].entry.lastAccess)
|
||||
})
|
||||
|
||||
// Evict oldest 50%
|
||||
numToEvict := len(items) / 2
|
||||
for i := 0; i < numToEvict; i++ {
|
||||
toClose = append(toClose, items[i].entry.store)
|
||||
delete(s.storageCache, items[i].path)
|
||||
}
|
||||
|
||||
s.cacheMu.Unlock()
|
||||
|
||||
// Close without holding lock
|
||||
for _, store := range toClose {
|
||||
if err := store.Close(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: failed to close evicted storage: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// evictStaleStorage removes idle connections and enforces cache size limits
|
||||
func (s *Server) evictStaleStorage() {
|
||||
now := time.Now()
|
||||
@@ -291,6 +395,11 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
writer := bufio.NewWriter(conn)
|
||||
|
||||
for {
|
||||
// Set read deadline for the next request
|
||||
if err := conn.SetReadDeadline(time.Now().Add(s.requestTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
@@ -306,6 +415,11 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Set write deadline for the response
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(s.requestTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp := s.handleRequest(&req)
|
||||
s.writeResponse(writer, resp)
|
||||
}
|
||||
@@ -488,6 +602,10 @@ func (s *Server) handlePing(_ *Request) Response {
|
||||
func (s *Server) handleHealth(req *Request) Response {
|
||||
start := time.Now()
|
||||
|
||||
// Get memory stats for health response
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
store, err := s.getStorageForRequest(req)
|
||||
if err != nil {
|
||||
data, _ := json.Marshal(HealthResponse{
|
||||
@@ -541,6 +659,9 @@ func (s *Server) handleHealth(req *Request) Response {
|
||||
CacheHits: atomic.LoadInt64(&s.cacheHits),
|
||||
CacheMisses: atomic.LoadInt64(&s.cacheMisses),
|
||||
DBResponseTime: dbResponseMs,
|
||||
ActiveConns: atomic.LoadInt32(&s.activeConns),
|
||||
MaxConns: s.maxConns,
|
||||
MemoryAllocMB: m.Alloc / 1024 / 1024,
|
||||
}
|
||||
|
||||
if dbError != "" {
|
||||
|
||||
Reference in New Issue
Block a user