fix(doctor): address code review issues in --server health checks
- Use parameterized query for INFORMATION_SCHEMA lookup (SQL injection) - Add isValidIdentifier() to validate database names before USE statement - Add password support via BEADS_DOLT_PASSWORD env var - Remove unused variable declaration - Add unit tests for identifier validation Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
committed by
Steve Yegge
parent
66d994264b
commit
3bcbca41fe
@@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -111,15 +112,18 @@ func RunServerHealthChecks(path string) ServerHealthResult {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Get database name from config (default: "beads")
|
||||||
|
database := "beads" // Default database name for Dolt server mode
|
||||||
|
|
||||||
// Check 3: Database exists and is queryable
|
// Check 3: Database exists and is queryable
|
||||||
dbExistsCheck := checkDatabaseExists(db, "beads")
|
dbExistsCheck := checkDatabaseExists(db, database)
|
||||||
result.Checks = append(result.Checks, dbExistsCheck)
|
result.Checks = append(result.Checks, dbExistsCheck)
|
||||||
if dbExistsCheck.Status == StatusError {
|
if dbExistsCheck.Status == StatusError {
|
||||||
result.OverallOK = false
|
result.OverallOK = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check 4: Schema compatible (can query beads tables)
|
// Check 4: Schema compatible (can query beads tables)
|
||||||
schemaCheck := checkSchemaCompatible(db)
|
schemaCheck := checkSchemaCompatible(db, database)
|
||||||
result.Checks = append(result.Checks, schemaCheck)
|
result.Checks = append(result.Checks, schemaCheck)
|
||||||
if schemaCheck.Status == StatusError {
|
if schemaCheck.Status == StatusError {
|
||||||
result.OverallOK = false
|
result.OverallOK = false
|
||||||
@@ -166,10 +170,18 @@ func checkDoltVersion(cfg *configfile.Config) (DoctorCheck, *sql.DB) {
|
|||||||
port := cfg.GetDoltServerPort()
|
port := cfg.GetDoltServerPort()
|
||||||
user := cfg.GetDoltServerUser()
|
user := cfg.GetDoltServerUser()
|
||||||
|
|
||||||
|
// Get password from environment (more secure than config file)
|
||||||
|
password := os.Getenv("BEADS_DOLT_PASSWORD")
|
||||||
|
|
||||||
// Build DSN without database (just to test server connectivity)
|
// Build DSN without database (just to test server connectivity)
|
||||||
var connStr string
|
var connStr string
|
||||||
connStr = fmt.Sprintf("%s@tcp(%s:%d)/?parseTime=true&timeout=5s",
|
if password != "" {
|
||||||
user, host, port)
|
connStr = fmt.Sprintf("%s:%s@tcp(%s:%d)/?parseTime=true&timeout=5s",
|
||||||
|
user, password, host, port)
|
||||||
|
} else {
|
||||||
|
connStr = fmt.Sprintf("%s@tcp(%s:%d)/?parseTime=true&timeout=5s",
|
||||||
|
user, host, port)
|
||||||
|
}
|
||||||
|
|
||||||
db, err := sql.Open("mysql", connStr)
|
db, err := sql.Open("mysql", connStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -243,10 +255,22 @@ func checkDatabaseExists(db *sql.DB, database string) DoctorCheck {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Check if database exists
|
// Validate database name (alphanumeric and underscore only)
|
||||||
|
if !isValidIdentifier(database) {
|
||||||
|
return DoctorCheck{
|
||||||
|
Name: "Database Exists",
|
||||||
|
Status: StatusError,
|
||||||
|
Message: fmt.Sprintf("Invalid database name '%s'", database),
|
||||||
|
Detail: "Database name must be alphanumeric with underscores only",
|
||||||
|
Category: CategoryFederation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if database exists using parameterized query
|
||||||
var exists int
|
var exists int
|
||||||
query := fmt.Sprintf("SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s'", database)
|
err := db.QueryRowContext(ctx,
|
||||||
err := db.QueryRowContext(ctx, query).Scan(&exists)
|
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?",
|
||||||
|
database).Scan(&exists)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorCheck{
|
return DoctorCheck{
|
||||||
Name: "Database Exists",
|
Name: "Database Exists",
|
||||||
@@ -268,7 +292,8 @@ func checkDatabaseExists(db *sql.DB, database string) DoctorCheck {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Switch to the database
|
// Switch to the database
|
||||||
_, err = db.ExecContext(ctx, fmt.Sprintf("USE %s", database))
|
// Note: USE cannot use parameterized queries, but we validated the identifier above
|
||||||
|
_, err = db.ExecContext(ctx, "USE "+database) // #nosec G201 - database validated by isValidIdentifier
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DoctorCheck{
|
return DoctorCheck{
|
||||||
Name: "Database Exists",
|
Name: "Database Exists",
|
||||||
@@ -287,8 +312,27 @@ func checkDatabaseExists(db *sql.DB, database string) DoctorCheck {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidIdentifier checks if a string is a valid SQL identifier
|
||||||
|
// (alphanumeric and underscore only, doesn't start with a number)
|
||||||
|
func isValidIdentifier(s string) bool {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, c := range s {
|
||||||
|
if i == 0 && c >= '0' && c <= '9' {
|
||||||
|
return false // Can't start with a number
|
||||||
|
}
|
||||||
|
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// checkSchemaCompatible checks if the beads tables are queryable
|
// checkSchemaCompatible checks if the beads tables are queryable
|
||||||
func checkSchemaCompatible(db *sql.DB) DoctorCheck {
|
func checkSchemaCompatible(db *sql.DB, database string) DoctorCheck {
|
||||||
|
// Note: database parameter reserved for future use (e.g., multi-database support)
|
||||||
|
_ = database
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
32
cmd/bd/doctor/server_test.go
Normal file
32
cmd/bd/doctor/server_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package doctor
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsValidIdentifier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"beads", true},
|
||||||
|
{"beads_db", true},
|
||||||
|
{"Beads123", true},
|
||||||
|
{"_private", true},
|
||||||
|
{"123start", false}, // Can't start with number
|
||||||
|
{"", false}, // Empty string
|
||||||
|
{"db-name", false}, // Hyphen not allowed
|
||||||
|
{"db.name", false}, // Dot not allowed
|
||||||
|
{"db name", false}, // Space not allowed
|
||||||
|
{"db;drop", false}, // Semicolon not allowed
|
||||||
|
{"db'inject", false}, // Quote not allowed
|
||||||
|
{"beads_test_db", true}, // Multiple underscores ok
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := isValidIdentifier(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isValidIdentifier(%q) = %v, want %v", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user