diff --git a/cmd/bd/doctor/server.go b/cmd/bd/doctor/server.go index 53692353..646e0202 100644 --- a/cmd/bd/doctor/server.go +++ b/cmd/bd/doctor/server.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "net" + "os" "strings" "time" @@ -111,15 +112,18 @@ func RunServerHealthChecks(path string) ServerHealthResult { } }() + // Get database name from config (default: "beads") + database := "beads" // Default database name for Dolt server mode + // Check 3: Database exists and is queryable - dbExistsCheck := checkDatabaseExists(db, "beads") + dbExistsCheck := checkDatabaseExists(db, database) result.Checks = append(result.Checks, dbExistsCheck) if dbExistsCheck.Status == StatusError { result.OverallOK = false } // Check 4: Schema compatible (can query beads tables) - schemaCheck := checkSchemaCompatible(db) + schemaCheck := checkSchemaCompatible(db, database) result.Checks = append(result.Checks, schemaCheck) if schemaCheck.Status == StatusError { result.OverallOK = false @@ -166,10 +170,18 @@ func checkDoltVersion(cfg *configfile.Config) (DoctorCheck, *sql.DB) { port := cfg.GetDoltServerPort() user := cfg.GetDoltServerUser() + // Get password from environment (more secure than config file) + password := os.Getenv("BEADS_DOLT_PASSWORD") + // Build DSN without database (just to test server connectivity) var connStr string - connStr = fmt.Sprintf("%s@tcp(%s:%d)/?parseTime=true&timeout=5s", - user, host, port) + if password != "" { + connStr = fmt.Sprintf("%s:%s@tcp(%s:%d)/?parseTime=true&timeout=5s", + user, password, host, port) + } else { + connStr = fmt.Sprintf("%s@tcp(%s:%d)/?parseTime=true&timeout=5s", + user, host, port) + } db, err := sql.Open("mysql", connStr) if err != nil { @@ -243,10 +255,22 @@ func checkDatabaseExists(db *sql.DB, database string) DoctorCheck { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Check if database exists + // Validate database name (alphanumeric and underscore only) + if !isValidIdentifier(database) { + return DoctorCheck{ + Name: "Database Exists", + Status: StatusError, + Message: fmt.Sprintf("Invalid database name '%s'", database), + Detail: "Database name must be alphanumeric with underscores only", + Category: CategoryFederation, + } + } + + // Check if database exists using parameterized query var exists int - query := fmt.Sprintf("SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'", database) - err := db.QueryRowContext(ctx, query).Scan(&exists) + err := db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?", + database).Scan(&exists) if err != nil { return DoctorCheck{ Name: "Database Exists", @@ -268,7 +292,8 @@ func checkDatabaseExists(db *sql.DB, database string) DoctorCheck { } // Switch to the database - _, err = db.ExecContext(ctx, fmt.Sprintf("USE %s", database)) + // Note: USE cannot use parameterized queries, but we validated the identifier above + _, err = db.ExecContext(ctx, "USE "+database) // #nosec G201 - database validated by isValidIdentifier if err != nil { return DoctorCheck{ Name: "Database Exists", @@ -287,8 +312,27 @@ func checkDatabaseExists(db *sql.DB, database string) DoctorCheck { } } +// isValidIdentifier checks if a string is a valid SQL identifier +// (alphanumeric and underscore only, doesn't start with a number) +func isValidIdentifier(s string) bool { + if len(s) == 0 { + return false + } + for i, c := range s { + if i == 0 && c >= '0' && c <= '9' { + return false // Can't start with a number + } + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { + return false + } + } + return true +} + // checkSchemaCompatible checks if the beads tables are queryable -func checkSchemaCompatible(db *sql.DB) DoctorCheck { +func checkSchemaCompatible(db *sql.DB, database string) DoctorCheck { + // Note: database parameter reserved for future use (e.g., multi-database support) + _ = database ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/cmd/bd/doctor/server_test.go b/cmd/bd/doctor/server_test.go new file mode 100644 index 00000000..561b70de --- /dev/null +++ b/cmd/bd/doctor/server_test.go @@ -0,0 +1,32 @@ +package doctor + +import "testing" + +func TestIsValidIdentifier(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"beads", true}, + {"beads_db", true}, + {"Beads123", true}, + {"_private", true}, + {"123start", false}, // Can't start with number + {"", false}, // Empty string + {"db-name", false}, // Hyphen not allowed + {"db.name", false}, // Dot not allowed + {"db name", false}, // Space not allowed + {"db;drop", false}, // Semicolon not allowed + {"db'inject", false}, // Quote not allowed + {"beads_test_db", true}, // Multiple underscores ok + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := isValidIdentifier(tt.input) + if got != tt.want { + t.Errorf("isValidIdentifier(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +}