208 lines
5.2 KiB
Go
208 lines
5.2 KiB
Go
package qbit
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/sirrobot01/decypharr/internal/config"
|
|
"github.com/sirrobot01/decypharr/pkg/arr"
|
|
"github.com/sirrobot01/decypharr/pkg/wire"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
type contextKey string
|
|
|
|
const (
|
|
categoryKey contextKey = "category"
|
|
hashesKey contextKey = "hashes"
|
|
arrKey contextKey = "arr"
|
|
)
|
|
|
|
func validateServiceURL(urlStr string) error {
|
|
if urlStr == "" {
|
|
return fmt.Errorf("URL cannot be empty")
|
|
}
|
|
|
|
// Try parsing as full URL first
|
|
u, err := url.Parse(urlStr)
|
|
if err == nil && u.Scheme != "" && u.Host != "" {
|
|
// It's a full URL, validate scheme
|
|
if u.Scheme != "http" && u.Scheme != "https" {
|
|
return fmt.Errorf("URL scheme must be http or https")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Check if it's a host:port format (no scheme)
|
|
if strings.Contains(urlStr, ":") && !strings.Contains(urlStr, "://") {
|
|
// Try parsing with http:// prefix
|
|
testURL := "http://" + urlStr
|
|
u, err := url.Parse(testURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid host:port format: %w", err)
|
|
}
|
|
|
|
if u.Host == "" {
|
|
return fmt.Errorf("host is required in host:port format")
|
|
}
|
|
|
|
// Validate port number
|
|
if u.Port() == "" {
|
|
return fmt.Errorf("port is required in host:port format")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("invalid URL format: %s", urlStr)
|
|
}
|
|
|
|
func getCategory(ctx context.Context) string {
|
|
if category, ok := ctx.Value(categoryKey).(string); ok {
|
|
return category
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func getHashes(ctx context.Context) []string {
|
|
if hashes, ok := ctx.Value(hashesKey).([]string); ok {
|
|
return hashes
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getArrFromContext(ctx context.Context) *arr.Arr {
|
|
if a, ok := ctx.Value(arrKey).(*arr.Arr); ok {
|
|
return a
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func decodeAuthHeader(header string) (string, string, error) {
|
|
encodedTokens := strings.Split(header, " ")
|
|
if len(encodedTokens) != 2 {
|
|
return "", "", nil
|
|
}
|
|
encodedToken := encodedTokens[1]
|
|
|
|
bytes, err := base64.StdEncoding.DecodeString(encodedToken)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
bearer := string(bytes)
|
|
|
|
colonIndex := strings.LastIndex(bearer, ":")
|
|
host := bearer[:colonIndex]
|
|
token := bearer[colonIndex+1:]
|
|
|
|
return host, token, nil
|
|
}
|
|
|
|
func (q *QBit) categoryContext(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
category := strings.Trim(r.URL.Query().Get("category"), "")
|
|
if category == "" {
|
|
// Get from form
|
|
_ = r.ParseForm()
|
|
category = r.Form.Get("category")
|
|
if category == "" {
|
|
// Get from multipart form
|
|
_ = r.ParseMultipartForm(32 << 20)
|
|
category = r.FormValue("category")
|
|
}
|
|
}
|
|
ctx := context.WithValue(r.Context(), categoryKey, strings.TrimSpace(category))
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// authContext creates a middleware that extracts the Arr host and token from the Authorization header
|
|
// and adds it to the request context.
|
|
// This is used to identify the Arr instance for the request.
|
|
// Only a valid host and token will be added to the context/config. The rest are manual
|
|
func (q *QBit) authContext(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
cfg := config.Get()
|
|
host, token, err := decodeAuthHeader(r.Header.Get("Authorization"))
|
|
category := getCategory(r.Context())
|
|
arrs := wire.Get().Arr()
|
|
// Check if arr exists
|
|
a := arrs.Get(category)
|
|
if a == nil {
|
|
// Arr is not configured, create a new one
|
|
downloadUncached := false
|
|
a = arr.New(category, "", "", false, false, &downloadUncached, "", "auto")
|
|
}
|
|
if err == nil {
|
|
host = strings.TrimSpace(host)
|
|
if host != "" {
|
|
a.Host = host
|
|
}
|
|
token = strings.TrimSpace(token)
|
|
if token != "" {
|
|
a.Token = token
|
|
}
|
|
}
|
|
if cfg.NeedsAuth() {
|
|
if a.Host == "" || a.Token == "" {
|
|
http.Error(w, "Unauthorized: Host and token are required for authentication", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
// try to use either Arr validate, or user auth validation
|
|
if err := a.Validate(); err != nil {
|
|
// If this failed, try to use user auth validation
|
|
if !verifyAuth(host, token) {
|
|
http.Error(w, "Unauthorized: Invalid host or token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
a.Source = "auto"
|
|
arrs.AddOrUpdate(a)
|
|
ctx := context.WithValue(r.Context(), arrKey, a)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func hashesContext(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_hashes := chi.URLParam(r, "hashes")
|
|
var hashes []string
|
|
if _hashes != "" {
|
|
hashes = strings.Split(_hashes, "|")
|
|
}
|
|
if hashes == nil {
|
|
// Get hashes from form
|
|
_ = r.ParseForm()
|
|
hashes = r.Form["hashes"]
|
|
}
|
|
for i, hash := range hashes {
|
|
hashes[i] = strings.TrimSpace(hash)
|
|
}
|
|
ctx := context.WithValue(r.Context(), hashesKey, hashes)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func verifyAuth(username, password string) bool {
|
|
// If you're storing hashed password, use bcrypt to compare
|
|
if username == "" {
|
|
return false
|
|
}
|
|
auth := config.Get().GetAuth()
|
|
if auth == nil {
|
|
return false
|
|
}
|
|
if username != auth.Username {
|
|
return false
|
|
}
|
|
err := bcrypt.CompareHashAndPassword([]byte(auth.Password), []byte(password))
|
|
return err == nil
|
|
}
|