Files
decypharr/internal/request/request.go
Mukhtar Akere 700d00b802 - Fix issues with new setup
- Fix arr setup getting the wrong credentials
- Add file link invalidator
- Other minor bug fixes
2025-10-08 08:13:13 +01:00

465 lines
11 KiB
Go

package request
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/sirrobot01/decypharr/internal/logger"
"go.uber.org/ratelimit"
"golang.org/x/net/proxy"
)
// JoinURL joins base with the given path elements using url.JoinPath and
// returns the resulting URL as a string.
//
// The last element of paths may carry a query string ("items?limit=10").
// Everything after its first "?" is split off before joining and appended
// verbatim to the joined URL, so the query is not path-escaped.
//
// With no path elements the base is passed to url.JoinPath unchanged. The
// caller's paths slice is never modified. An error is returned when
// url.JoinPath cannot parse base.
func JoinURL(base string, paths ...string) (string, error) {
	if len(paths) == 0 {
		return url.JoinPath(base)
	}

	// Work on a copy: paths may be a slice the caller still uses.
	elems := make([]string, len(paths))
	copy(elems, paths)

	// Cut at the first "?" only, so a query that itself contains "?" survives.
	lastIdx := len(elems) - 1
	pathPart, query, hasQuery := strings.Cut(elems[lastIdx], "?")
	elems[lastIdx] = pathPart

	joined, err := url.JoinPath(base, elems...)
	if err != nil {
		return "", err
	}
	if hasQuery {
		return joined + "?" + query, nil
	}
	return joined, nil
}
var (
// once guards the one-time construction of instance in Default.
once sync.Once
// instance is the shared Client returned by Default.
instance *Client
)
// ClientOption configures a Client. New applies every option it is given to
// the client under construction.
type ClientOption func(*Client)
// Client represents an HTTP client with additional capabilities:
// default headers, optional rate limiting, and retries with backoff.
type Client struct {
// client is the underlying http.Client that doRequest sends requests through.
client *http.Client
// rateLimiter is optional; when non-nil, doRequest calls Take on it before sending.
rateLimiter ratelimit.Limiter
// headers are set on every outgoing request by Do; guarded by headersMu.
headers map[string]string
// headersMu protects headers.
headersMu sync.RWMutex
// maxRetries is the number of retries Do may make after the first attempt.
maxRetries int
// timeout is used as the http.Client Timeout when New builds the default client.
timeout time.Duration
// skipTLSVerify is used as InsecureSkipVerify on the default transport built by New.
skipTLSVerify bool
// retryableStatus is the set of response status codes that make Do retry.
retryableStatus map[int]struct{}
// logger is used by MakeRequest to report a failure to close a response body.
logger zerolog.Logger
// proxy is the proxy URL handed to SetProxy when New builds the default transport.
proxy string
}
// WithMaxRetries returns an option that sets how many times Do may retry a
// request after its first attempt.
func WithMaxRetries(maxRetries int) ClientOption {
	apply := func(client *Client) {
		client.maxRetries = maxRetries
	}
	return apply
}
// WithTimeout sets the request timeout.
//
// New creates the underlying http.Client (with the default timeout) before it
// applies options, so recording the value on the Client alone would never
// reach the http.Client. The timeout is therefore also written to the
// http.Client when one is already present.
func WithTimeout(timeout time.Duration) ClientOption {
	return func(c *Client) {
		c.timeout = timeout
		if c.client != nil {
			c.client.Timeout = timeout
		}
	}
}
// WithRedirectPolicy returns an option that installs policy as the
// CheckRedirect function of the client's underlying http.Client.
func WithRedirectPolicy(policy func(req *http.Request, via []*http.Request) error) ClientOption {
	apply := func(client *Client) {
		httpClient := client.client
		httpClient.CheckRedirect = policy
	}
	return apply
}
// WithRateLimiter returns an option that installs rl as the client's rate
// limiter. When a limiter is set, doRequest calls Take on it before sending.
func WithRateLimiter(rl ratelimit.Limiter) ClientOption {
	apply := func(client *Client) {
		client.rateLimiter = rl
	}
	return apply
}
// WithHeaders sets the default headers applied to every request, replacing
// any headers configured earlier.
//
// The map is copied, so later SetHeader calls do not write into the caller's
// map (and the caller's later writes do not race with requests). A nil map
// yields an empty header set rather than a nil map that SetHeader would
// panic on.
func WithHeaders(headers map[string]string) ClientOption {
	return func(c *Client) {
		copied := make(map[string]string, len(headers))
		for key, value := range headers {
			copied[key] = value
		}

		c.headersMu.Lock()
		c.headers = copied
		c.headersMu.Unlock()
	}
}
// SetHeader sets a default header that Do applies to every subsequent
// request. It is safe to call concurrently with requests.
func (c *Client) SetHeader(key, value string) {
	c.headersMu.Lock()
	defer c.headersMu.Unlock()

	// The map can be nil on a zero-value Client or after a nil map was
	// supplied as the header set; writing to a nil map would panic.
	if c.headers == nil {
		c.headers = make(map[string]string)
	}
	c.headers[key] = value
}
// WithLogger returns an option that replaces the client's logger.
func WithLogger(logger zerolog.Logger) ClientOption {
	apply := func(client *Client) {
		client.logger = logger
	}
	return apply
}
// WithTransport makes the client use transport instead of the default
// transport that New would otherwise build.
//
// A nil transport is ignored. Assigning a nil *http.Transport to the
// http.Client's RoundTripper field would produce a non-nil interface holding
// a nil pointer: New's "Transport == nil" check would then skip building the
// default transport and requests would be sent through a nil transport.
func WithTransport(transport *http.Transport) ClientOption {
	return func(c *Client) {
		if transport == nil {
			return
		}
		c.client.Transport = transport
	}
}
// WithRetryableStatus returns an option that replaces the set of status
// codes that trigger a retry with exactly the given codes. Any previously
// configured codes, including the defaults, are discarded.
func WithRetryableStatus(statusCodes ...int) ClientOption {
	return func(client *Client) {
		codes := make(map[int]struct{})
		for i := range statusCodes {
			codes[statusCodes[i]] = struct{}{}
		}
		client.retryableStatus = codes
	}
}
// WithProxy returns an option that records proxyURL on the client. New hands
// it to SetProxy when it builds the default transport.
func WithProxy(proxyURL string) ClientOption {
	apply := func(client *Client) {
		client.proxy = proxyURL
	}
	return apply
}
// doRequest sends req once through the underlying http.Client. When a rate
// limiter is configured it first checks whether the request's context is
// already done (returning the context error if so) and otherwise takes a
// slot from the limiter before sending.
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
	if c.rateLimiter == nil {
		return c.client.Do(req)
	}

	ctx := req.Context()
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	c.rateLimiter.Take()
	return c.client.Do(req)
}
// Do sends req, retrying when the attempt fails with a retryable network
// error (see isRetryableError) or returns a status code listed in
// c.retryableStatus.
//
// At most maxRetries retries follow the first attempt; a negative maxRetries
// is treated as zero so the request is always sent at least once. Between
// attempts Do waits with exponential backoff (starting at 500ms, doubling
// each time) plus random jitter, and aborts with the context error if the
// request's context ends while waiting.
//
// The request body, if any, is read fully into memory up front so it can be
// replayed on every attempt. The client's default headers are set on the
// request before each attempt, overriding same-named request headers.
//
// On success the caller owns the returned response and must close its body.
// When the last allowed attempt still yields a retryable status, that
// response is returned as-is with a nil error.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
	// Cap on how much of a discarded response is read before closing it.
	const maxDrainBytes = 64 << 10

	// Save the request body for reuse in retries.
	var bodyBytes []byte
	if req.Body != nil {
		var err error
		bodyBytes, err = io.ReadAll(req.Body)
		// Close on both paths; the data (or the read error) is all we need,
		// so a close failure is deliberately ignored.
		_ = req.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("reading request body: %w", err)
		}
	}

	maxRetries := c.maxRetries
	if maxRetries < 0 {
		maxRetries = 0
	}

	// A hand-built http.Request may have a nil Header map.
	if req.Header == nil {
		req.Header = make(http.Header)
	}

	backoff := 500 * time.Millisecond
	for attempt := 0; ; attempt++ {
		// Give every attempt a fresh reader over the saved body.
		if bodyBytes != nil {
			req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}

		// Applied per attempt so headers changed via SetHeader between
		// retries are picked up.
		c.headersMu.RLock()
		for key, value := range c.headers {
			req.Header.Set(key, value)
		}
		c.headersMu.RUnlock()

		resp, err := c.doRequest(req)
		lastAttempt := attempt >= maxRetries

		if err != nil {
			if lastAttempt || !isRetryableError(err) {
				return nil, err
			}
		} else {
			if _, retryable := c.retryableStatus[resp.StatusCode]; !retryable || lastAttempt {
				return resp, nil
			}
			// Drain (bounded) before closing so the transport can reuse the
			// connection for the retry instead of tearing it down.
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
			_ = resp.Body.Close()
		}

		if err := retryBackoffWait(req.Context(), backoff); err != nil {
			return nil, err
		}
		backoff *= 2
	}
}

// retryBackoffWait blocks for backoff plus up to 25% random jitter, returning
// the context's error early if ctx is done first.
func retryBackoffWait(ctx context.Context, backoff time.Duration) error {
	wait := backoff
	// rand.Int63n panics for n <= 0, so only add jitter for a positive spread.
	if spread := int64(backoff / 4); spread > 0 {
		wait += time.Duration(rand.Int63n(spread))
	}

	timer := time.NewTimer(wait)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
// MakeRequest sends req through Do and returns the full response body.
//
// A response whose status is outside the 2xx range is turned into an error
// that contains the status code and the body text. A failure to close the
// response body is only logged.
func (c *Client) MakeRequest(req *http.Request) ([]byte, error) {
	resp, doErr := c.Do(req)
	if doErr != nil {
		return nil, doErr
	}
	defer func() {
		closeErr := resp.Body.Close()
		if closeErr != nil {
			c.logger.Printf("Failed to close response body: %v", closeErr)
		}
	}()

	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, fmt.Errorf("reading response body: %w", readErr)
	}

	if success := resp.StatusCode >= 200 && resp.StatusCode < 300; !success {
		return nil, fmt.Errorf("HTTP error %d: %s", resp.StatusCode, string(payload))
	}
	return payload, nil
}
// Get builds a GET request for url and sends it through Do. The caller is
// responsible for closing the returned response body.
func (c *Client) Get(url string) (*http.Response, error) {
	request, buildErr := http.NewRequest(http.MethodGet, url, nil)
	if buildErr == nil {
		return c.Do(request)
	}
	return nil, fmt.Errorf("creating GET request: %w", buildErr)
}
// New creates a new HTTP client with the specified options.
//
// Defaults: 3 retries, a 60 second timeout, TLS verification skipped on the
// default transport, and retries on 429, 500, 502, 503 and 504.
//
// The underlying http.Client is created before options run so that options
// such as WithTransport and WithRedirectPolicy can modify it. When no option
// installs a transport, a default one is built from the TLS and proxy
// settings.
func New(options ...ClientOption) *Client {
	client := &Client{
		maxRetries:    3,
		skipTLSVerify: true,
		retryableStatus: map[int]struct{}{
			http.StatusTooManyRequests:     {},
			http.StatusInternalServerError: {},
			http.StatusBadGateway:          {},
			http.StatusServiceUnavailable:  {},
			http.StatusGatewayTimeout:      {},
		},
		logger:  logger.New("request"),
		timeout: 60 * time.Second,
		proxy:   "",
		headers: make(map[string]string),
	}

	// Default http client; options below may adjust it.
	client.client = &http.Client{
		Timeout: client.timeout,
	}

	// Apply options before configuring the transport.
	for _, option := range options {
		if option == nil {
			continue
		}
		option(client)
	}

	// The http.Client above captured the default timeout before options ran;
	// propagate whatever the options left in client.timeout.
	client.client.Timeout = client.timeout

	// An option may have replaced the header map with nil; SetHeader needs a
	// writable map.
	client.headersMu.Lock()
	if client.headers == nil {
		client.headers = make(map[string]string)
	}
	client.headersMu.Unlock()

	// Build the default transport unless an option installed one.
	if client.client.Transport == nil {
		transport := &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: client.skipTLSVerify,
			},
			DisableKeepAlives: false,
		}
		// Configure proxy if needed.
		SetProxy(transport, client.proxy)
		client.client.Transport = transport
	}

	return client
}
// ParseRateLimit turns a "<count>/<unit>" string such as "250/minute" into a
// rate limiter allowing count operations per unit, with slack set to 10% of
// count.
//
// Recognised units (case-insensitive, with an optional trailing "s") are
// second/sec, minute/min, hour/hr and day/d. It returns nil when the string
// is empty, has no "/", has a non-numeric or non-positive count, or names an
// unknown unit.
func ParseRateLimit(rateStr string) ratelimit.Limiter {
	if rateStr == "" {
		return nil
	}

	countText, unitText, hasSlash := strings.Cut(rateStr, "/")
	if !hasSlash {
		return nil
	}

	count, convErr := strconv.Atoi(strings.TrimSpace(countText))
	if convErr != nil || count <= 0 {
		return nil
	}

	// Normalise the unit: trim, lower-case, drop a plural "s".
	normalized := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(unitText)), "s")

	var window time.Duration
	switch normalized {
	case "minute", "min":
		window = time.Minute
	case "second", "sec":
		window = time.Second
	case "hour", "hr":
		window = time.Hour
	case "day", "d":
		window = 24 * time.Hour
	default:
		return nil
	}

	// Slack is 10% of the count.
	return ratelimit.New(count, ratelimit.Per(window), ratelimit.WithSlack(count/10))
}
// JSONResponse writes data to w as a JSON body with the given status code and
// a Content-Type of application/json.
//
// The value is encoded into a buffer before anything is sent. If encoding
// fails, a 500 Internal Server Error is written instead, so a client never
// receives the success status with an empty or truncated body.
func JSONResponse(w http.ResponseWriter, data interface{}, code int) {
	var body bytes.Buffer
	if err := json.NewEncoder(&body).Encode(data); err != nil {
		http.Error(w, "failed to encode JSON response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// A write error here means the peer went away after the status was sent;
	// there is nothing left to report to it.
	_, _ = w.Write(body.Bytes())
}
// Default returns the shared Client, building it with New() and no options
// the first time it is called.
func Default() *Client {
	build := func() {
		instance = New()
	}
	once.Do(build)
	return instance
}
// isRetryableError reports whether err looks like a transient network
// failure worth retrying.
//
// An error qualifies when its message contains one of a fixed set of
// network-failure fragments (connection reset/refused, unreachable network,
// DNS failure, timeouts, unexpected EOF), or when it is, or wraps, a
// net.Error whose Timeout method returns true. A nil error is not retryable.
func isRetryableError(err error) bool {
	if err == nil {
		return false
	}

	// Matched against the message because these failures arrive wrapped in
	// several different error types.
	message := err.Error()
	for _, fragment := range [...]string{
		"connection reset by peer",
		"read: connection reset",
		"connection refused",
		"network is unreachable",
		"connection timed out",
		"no such host",
		"i/o timeout",
		"unexpected EOF",
		"TLS handshake timeout",
	} {
		if strings.Contains(message, fragment) {
			return true
		}
	}

	// Any other net.Error is retried only when it reports a timeout.
	var netErr net.Error
	if errors.As(err, &netErr) {
		return netErr.Timeout()
	}

	return false
}
// SetProxy configures transport to send requests through proxyURL.
//
//   - An empty proxyURL selects the proxy from the environment
//     (http.ProxyFromEnvironment).
//   - A URL starting with "socks5://" installs a SOCKS5 dialer as the
//     transport's DialContext, using the credentials from the URL if present.
//   - Any other URL is installed through http.ProxyURL.
//
// The function has no error return: when proxyURL cannot be parsed or the
// SOCKS5 dialer cannot be created, transport is left unchanged.
// NOTE(review): in that case requests are sent without the requested proxy;
// callers that must not bypass the proxy should validate proxyURL first.
func SetProxy(transport *http.Transport, proxyURL string) {
	if transport == nil {
		return
	}
	if proxyURL == "" {
		transport.Proxy = http.ProxyFromEnvironment
		return
	}

	parsed, err := url.Parse(proxyURL)
	if err != nil {
		return
	}

	if !strings.HasPrefix(proxyURL, "socks5://") {
		transport.Proxy = http.ProxyURL(parsed)
		return
	}

	// Only offer username/password authentication when a username was given;
	// a non-nil but empty Auth advertises that method with unusable
	// credentials.
	var auth *proxy.Auth
	if user := parsed.User; user != nil && user.Username() != "" {
		password, _ := user.Password() // absent password is sent as empty
		auth = &proxy.Auth{User: user.Username(), Password: password}
	}

	dialer, err := proxy.SOCKS5("tcp", parsed.Host, auth, proxy.Direct)
	if err != nil {
		return
	}

	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		// Prefer the context-aware dial so request cancellation and
		// deadlines also apply while connecting through the proxy.
		if ctxDialer, ok := dialer.(proxy.ContextDialer); ok {
			return ctxDialer.DialContext(ctx, network, addr)
		}
		return dialer.Dial(network, addr)
	}
}
// ValidateURL checks that urlStr is either a full http/https URL or a bare
// "host:port" pair.
//
// Accepted forms:
//   - "scheme://host[:port][/...]" where scheme is http or https;
//   - "host:port" with no scheme, where both host and port are present.
//
// Whenever a port is present it must be a number between 1 and 65535.
// A nil return means the value is acceptable; otherwise the error describes
// what is wrong.
func ValidateURL(urlStr string) error {
	if urlStr == "" {
		return fmt.Errorf("URL cannot be empty")
	}

	// Try parsing as a full URL first.
	u, err := url.Parse(urlStr)
	if err == nil && u.Scheme != "" && u.Host != "" {
		if u.Scheme != "http" && u.Scheme != "https" {
			return fmt.Errorf("URL scheme must be http or https")
		}
		if port := u.Port(); port != "" {
			return checkPortRange(port)
		}
		return nil
	}

	// Check if it's a host:port format (no scheme).
	if strings.Contains(urlStr, ":") && !strings.Contains(urlStr, "://") {
		// Parse with an http:// prefix so host and port can be separated.
		hostPort, err := url.Parse("http://" + urlStr)
		if err != nil {
			return fmt.Errorf("invalid host:port format: %w", err)
		}
		// Hostname (not Host) so that a bare ":8080" counts as missing a host.
		if hostPort.Hostname() == "" {
			return fmt.Errorf("host is required in host:port format")
		}
		port := hostPort.Port()
		if port == "" {
			return fmt.Errorf("port is required in host:port format")
		}
		return checkPortRange(port)
	}

	return fmt.Errorf("invalid URL format: %s", urlStr)
}

// checkPortRange verifies that port is a decimal number in 1..65535.
func checkPortRange(port string) error {
	n, err := strconv.Atoi(port)
	if err != nil || n < 1 || n > 65535 {
		return fmt.Errorf("port must be between 1 and 65535, got %q", port)
	}
	return nil
}