package request

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/rs/zerolog"
	"github.com/sirrobot01/decypharr/internal/logger"
	"go.uber.org/ratelimit"
	"golang.org/x/net/proxy"
)

// JoinURL joins base with the given path segments using url.JoinPath.
//
// A query string on the final segment (everything after its first '?') is
// kept out of the path joining and re-appended verbatim to the result, so
// JoinURL("http://h/api", "items?id=1") yields "http://h/api/items?id=1".
// With no segments, base is normalised by url.JoinPath and returned.
// The caller's paths slice is never modified.
func JoinURL(base string, paths ...string) (string, error) {
	if len(paths) == 0 {
		return url.JoinPath(base)
	}

	// Copy so a caller passing `slice...` does not see its last element rewritten.
	segments := make([]string, len(paths))
	copy(segments, paths)

	last := len(segments) - 1
	// SplitN keeps any further '?' characters inside the query intact.
	parts := strings.SplitN(segments[last], "?", 2)
	segments[last] = parts[0]

	joined, err := url.JoinPath(base, segments...)
	if err != nil {
		return "", err
	}
	if len(parts) == 2 {
		return joined + "?" + parts[1], nil
	}
	return joined, nil
}

var (
	// once guards the lazy construction of instance in Default.
	once     sync.Once
	instance *Client
)

// maxRetryBackoff caps the exponential delay between retries in Client.Do.
// Without a cap, repeated doubling would eventually overflow time.Duration
// and make the jitter computation panic.
const maxRetryBackoff = time.Minute

// ClientOption configures a Client while it is being built by New.
type ClientOption func(*Client)

// Client represents an HTTP client with additional capabilities: optional
// rate limiting, default headers applied to every request, and retries on
// transient network errors and configurable HTTP status codes.
type Client struct {
	client      *http.Client
	rateLimiter ratelimit.Limiter
	// headers is guarded by headersMu.
	headers         map[string]string
	headersMu       sync.RWMutex
	maxRetries      int
	timeout         time.Duration
	skipTLSVerify   bool
	retryableStatus map[int]struct{}
	logger          zerolog.Logger
	proxy           string
}

// WithMaxRetries sets how many times a request is retried after the first
// attempt (so up to maxRetries+1 attempts in total).
func WithMaxRetries(maxRetries int) ClientOption {
	return func(c *Client) {
		c.maxRetries = maxRetries
	}
}

// WithTimeout sets the timeout of the underlying http.Client, which bounds
// each individual attempt made by Do.
func WithTimeout(timeout time.Duration) ClientOption {
	return func(c *Client) {
		c.timeout = timeout
	}
}

// WithRedirectPolicy sets the CheckRedirect function of the underlying
// http.Client.
func WithRedirectPolicy(policy func(req *http.Request, via []*http.Request) error) ClientOption {
	return func(c *Client) {
		c.client.CheckRedirect = policy
	}
}

// WithRateLimiter sets a rate limiter consulted before every attempt.
func WithRateLimiter(rl ratelimit.Limiter) ClientOption {
	return func(c *Client) {
		c.rateLimiter = rl
	}
}

// WithHeaders replaces the default headers set on every request.
// The map is copied, so later changes by the caller do not affect the
// client and SetHeader never writes into the caller's map.
func WithHeaders(headers map[string]string) ClientOption {
	return func(c *Client) {
		copied := make(map[string]string, len(headers))
		for key, value := range headers {
			copied[key] = value
		}
		c.headersMu.Lock()
		c.headers = copied
		c.headersMu.Unlock()
	}
}

// SetHeader adds or replaces a default header applied to every subsequent
// request. It is safe to call concurrently with Do.
func (c *Client) SetHeader(key, value string) {
	c.headersMu.Lock()
	defer c.headersMu.Unlock()
	if c.headers == nil {
		c.headers = make(map[string]string)
	}
	c.headers[key] = value
}

// WithLogger sets the logger used by the client.
func WithLogger(logger zerolog.Logger) ClientOption {
	return func(c *Client) {
		c.logger = logger
	}
}

// WithTransport sets the transport of the underlying http.Client. When a
// transport is supplied, New does not build its own, so the TLS and proxy
// settings (WithProxy) are not applied to it.
func WithTransport(transport *http.Transport) ClientOption {
	return func(c *Client) {
		c.client.Transport = transport
	}
}

// WithRetryableStatus replaces the set of HTTP status codes that make Do
// retry a request.
func WithRetryableStatus(statusCodes ...int) ClientOption {
	return func(c *Client) {
		c.retryableStatus = make(map[int]struct{}, len(statusCodes)) // reset the set
		for _, code := range statusCodes {
			c.retryableStatus[code] = struct{}{}
		}
	}
}

// WithProxy sets the proxy URL ("socks5://..." or an HTTP(S) proxy URL) used
// by the transport that New builds. An empty value means the proxy is taken
// from the environment (see SetProxy).
func WithProxy(proxyURL string) ClientOption {
	return func(c *Client) {
		c.proxy = proxyURL
	}
}

// doRequest performs a single HTTP request, waiting on the rate limiter
// first if one is configured.
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
	if c.rateLimiter != nil {
		// Take has no context parameter, so cancellation is only observed
		// before we start waiting on the limiter.
		select {
		case <-req.Context().Done():
			return nil, req.Context().Err()
		default:
			c.rateLimiter.Take()
		}
	}
	return c.client.Do(req)
}

// Do performs req with retries.
//
// The request body is buffered in memory so it can be resent on each attempt.
// Default headers are set on req before every attempt, overriding headers
// of the same name. A failed attempt is retried, up to maxRetries times, when
// the error is a transient network error (see isRetryableError) or the
// response status is in the retryable set. Retries back off exponentially
// (starting at 500ms, capped at maxRetryBackoff) with up to 25% jitter, and
// stop early if req's context is done.
//
// On the final attempt a response with a retryable status is returned as-is.
// The caller must close the returned response body.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
	var bodyBytes []byte
	if req.Body != nil {
		data, err := io.ReadAll(req.Body)
		// Close regardless of the read outcome so the body is never leaked.
		req.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("reading request body: %w", err)
		}
		bodyBytes = data
		// Let net/http replay the buffered body on 307/308 redirects too.
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(bodyBytes)), nil
		}
	}

	backoff := 500 * time.Millisecond
	for attempt := 0; attempt <= c.maxRetries; attempt++ {
		// A body can only be consumed once; give each attempt a fresh reader.
		if bodyBytes != nil {
			req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}

		c.headersMu.RLock()
		if req.Header == nil && len(c.headers) > 0 {
			req.Header = make(http.Header, len(c.headers))
		}
		for key, value := range c.headers {
			req.Header.Set(key, value)
		}
		c.headersMu.RUnlock()

		resp, err := c.doRequest(req)
		lastAttempt := attempt == c.maxRetries
		if err != nil {
			if lastAttempt || !isRetryableError(err) {
				return nil, err
			}
		} else {
			_, retryable := c.retryableStatus[resp.StatusCode]
			if !retryable || lastAttempt {
				return resp, nil
			}
			// This response is being discarded in favour of a retry.
			discardBody(resp.Body)
		}

		if waitErr := sleepBackoff(req.Context(), backoff); waitErr != nil {
			return nil, waitErr
		}
		backoff *= 2
		if backoff > maxRetryBackoff {
			backoff = maxRetryBackoff
		}
	}
	// Only reachable when maxRetries is negative (no attempt is made).
	return nil, fmt.Errorf("max retries exceeded")
}

// sleepBackoff waits for backoff plus a random jitter in [0, backoff/4).
// It returns ctx.Err() if ctx is done before the wait elapses.
func sleepBackoff(ctx context.Context, backoff time.Duration) error {
	wait := backoff
	if quarter := int64(backoff / 4); quarter > 0 {
		wait += time.Duration(rand.Int63n(quarter))
	}
	timer := time.NewTimer(wait)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// discardBody drains up to 64 KiB of body and closes it. Reading to EOF
// before closing lets the transport reuse the connection for the next attempt.
func discardBody(body io.ReadCloser) {
	_, _ = io.Copy(io.Discard, io.LimitReader(body, 64<<10))
	_ = body.Close()
}

// MakeRequest performs req via Do and returns the full response body.
// A non-2xx status is returned as an error that includes the status code and
// the response body.
func (c *Client) MakeRequest(req *http.Request) ([]byte, error) {
	res, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := res.Body.Close(); closeErr != nil {
			c.logger.Printf("Failed to close response body: %v", closeErr)
		}
	}()

	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response body: %w", err)
	}

	if res.StatusCode < 200 || res.StatusCode >= 300 {
		return nil, fmt.Errorf("HTTP error %d: %s", res.StatusCode, string(bodyBytes))
	}

	return bodyBytes, nil
}

// Get issues a GET request for rawURL through Do, so retries, rate limiting
// and default headers apply. The caller must close the response body.
func (c *Client) Get(rawURL string) (*http.Response, error) {
	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
	if err != nil {
		return nil, fmt.Errorf("creating GET request: %w", err)
	}
	return c.Do(req)
}

// New creates a new HTTP client with the specified options.
//
// Defaults: 3 retries, a 60s timeout, retries on 429/500/502/503/504, a
// logger named "request", and no rate limiter. Unless WithTransport is given,
// a transport is built with the configured proxy (see SetProxy).
func New(options ...ClientOption) *Client {
	c := &Client{
		maxRetries: 3,
		// NOTE(review): TLS certificate verification is disabled by default and
		// no option re-enables it; use WithTransport to verify certificates.
		skipTLSVerify: true,
		retryableStatus: map[int]struct{}{
			http.StatusTooManyRequests:     {},
			http.StatusInternalServerError: {},
			http.StatusBadGateway:          {},
			http.StatusServiceUnavailable:  {},
			http.StatusGatewayTimeout:      {},
		},
		logger:  logger.New("request"),
		timeout: 60 * time.Second,
		proxy:   "",
		headers: make(map[string]string),
	}

	// The http.Client must exist before options run: WithRedirectPolicy and
	// WithTransport modify it directly.
	c.client = &http.Client{Timeout: c.timeout}

	for _, opt := range options {
		opt(c)
	}

	// WithTimeout only updates c.timeout, so propagate the final value to the
	// http.Client that was created before the options were applied.
	c.client.Timeout = c.timeout

	if c.client.Transport == nil {
		transport := &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: c.skipTLSVerify,
			},
			DisableKeepAlives: false,
		}
		SetProxy(transport, c.proxy)
		c.client.Transport = transport
	}

	return c
}

// ParseRateLimit parses a rate such as "250/minute" or "10/secs" into a
// ratelimit.Limiter.
//
// The unit is case-insensitive, may carry a trailing "s", and must be one of
// second/sec, minute/min, hour/hr or day/d. The limiter gets a slack of 10%
// of the count. An empty, malformed, non-positive or unknown-unit rate
// returns nil, which callers treat as "no rate limiting".
func ParseRateLimit(rateStr string) ratelimit.Limiter {
	if rateStr == "" {
		return nil
	}
	parts := strings.SplitN(rateStr, "/", 2)
	if len(parts) != 2 {
		return nil
	}

	count, err := strconv.Atoi(strings.TrimSpace(parts[0]))
	if err != nil || count <= 0 {
		return nil
	}

	unit := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(parts[1])), "s")
	var per time.Duration
	switch unit {
	case "minute", "min":
		per = time.Minute
	case "second", "sec":
		per = time.Second
	case "hour", "hr":
		per = time.Hour
	case "day", "d":
		per = 24 * time.Hour
	default:
		return nil
	}

	// Slack lets the limiter bank up to 10% of count in unused capacity.
	return ratelimit.New(count, ratelimit.Per(per), ratelimit.WithSlack(count/10))
}

// JSONResponse writes data as a JSON body with the given status code and a
// "Content-Type: application/json" header.
func JSONResponse(w http.ResponseWriter, data interface{}, code int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// The status line has already been sent, so an encoding failure can no
	// longer be reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(data)
}

// Default returns a shared Client created with New() on first use.
func Default() *Client {
	once.Do(func() {
		instance = New()
	})
	return instance
}

// retryableErrorSubstrings are fragments of error messages produced by
// transient network failures that Client.Do retries.
var retryableErrorSubstrings = []string{
	"connection reset by peer",
	"read: connection reset",
	"connection refused",
	"network is unreachable",
	"connection timed out",
	"no such host",
	"i/o timeout",
	"unexpected EOF",
	"TLS handshake timeout",
}

// isRetryableError reports whether err looks like a transient network
// failure. Its message must contain one of retryableErrorSubstrings, or it
// must wrap a net.Error whose Timeout() is true. A nil error is not retryable.
func isRetryableError(err error) bool {
	if err == nil {
		return false
	}

	msg := err.Error()
	for _, fragment := range retryableErrorSubstrings {
		if strings.Contains(msg, fragment) {
			return true
		}
	}

	var netErr net.Error
	if errors.As(err, &netErr) {
		return netErr.Timeout()
	}
	return false
}

// SetProxy configures transport to use proxyURL.
//
//   - "" uses the proxy from the environment (http.ProxyFromEnvironment).
//   - "socks5://[user:pass@]host:port" dials through a SOCKS5 proxy. Credentials
//     are offered only when present in the URL.
//   - any other value is used as an HTTP(S) proxy URL.
//
// NOTE(review): an unparsable proxy URL (or SOCKS5 setup failure) leaves the
// transport unchanged, i.e. requests go direct; the signature has no way to
// report this.
func SetProxy(transport *http.Transport, proxyURL string) {
	if proxyURL == "" {
		transport.Proxy = http.ProxyFromEnvironment
		return
	}

	parsed, err := url.Parse(proxyURL)
	if err != nil {
		return
	}

	if !strings.HasPrefix(proxyURL, "socks5://") {
		transport.Proxy = http.ProxyURL(parsed)
		return
	}

	// A nil Auth offers only "no authentication". A non-nil Auth with empty
	// credentials would let the server pick username/password auth and then fail.
	var auth *proxy.Auth
	if parsed.User != nil {
		password, _ := parsed.User.Password() // empty when no password is set
		auth = &proxy.Auth{User: parsed.User.Username(), Password: password}
	}

	dialer, err := proxy.SOCKS5("tcp", parsed.Host, auth, proxy.Direct)
	if err != nil {
		return
	}

	// Prefer the context-aware dialer so request cancellation aborts the dial.
	if ctxDialer, ok := dialer.(proxy.ContextDialer); ok {
		transport.DialContext = ctxDialer.DialContext
		return
	}
	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		return dialer.Dial(network, addr)
	}
}

// ValidateURL checks that urlStr is either an absolute http/https URL with a
// host, or a bare "host:port" without a scheme. It returns a descriptive
// error otherwise.
func ValidateURL(urlStr string) error {
	if urlStr == "" {
		return fmt.Errorf("URL cannot be empty")
	}

	// Full URL: scheme and host present.
	if u, err := url.Parse(urlStr); err == nil && u.Scheme != "" && u.Host != "" {
		if u.Scheme != "http" && u.Scheme != "https" {
			return fmt.Errorf("URL scheme must be http or https")
		}
		return nil
	}

	// Bare host:port: parse it as if it had an http:// prefix.
	if strings.Contains(urlStr, ":") && !strings.Contains(urlStr, "://") {
		hostURL, err := url.Parse("http://" + urlStr)
		if err != nil {
			return fmt.Errorf("invalid host:port format: %w", err)
		}
		if hostURL.Host == "" {
			return fmt.Errorf("host is required in host:port format")
		}
		if hostURL.Port() == "" {
			return fmt.Errorf("port is required in host:port format")
		}
		return nil
	}

	return fmt.Errorf("invalid URL format: %s", urlStr)
}