425 lines
9.8 KiB
Go
425 lines
9.8 KiB
Go
package request
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/rs/zerolog"
|
|
"github.com/sirrobot01/decypharr/internal/logger"
|
|
"go.uber.org/ratelimit"
|
|
"golang.org/x/net/proxy"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// JoinURL appends paths to base and returns the resulting URL as a string.
//
// The last element of paths may carry a query string ("items?page=2"). The
// text after the first "?" is kept out of the path join, so the "?" is not
// percent-encoded, and is re-attached verbatim to the joined URL. When no
// paths are given, base is returned as processed by url.JoinPath. The
// caller's slice is never modified.
//
// An error is returned only when url.JoinPath fails (for example when base
// cannot be parsed).
func JoinURL(base string, paths ...string) (string, error) {
	if len(paths) == 0 {
		return url.JoinPath(base)
	}

	// Work on a copy: paths aliases the caller's backing array when the
	// function is invoked as JoinURL(base, elems...).
	elems := make([]string, len(paths))
	copy(elems, paths)

	// Cut at the first "?" only, so a query that itself contains "?" survives.
	last, query, hasQuery := strings.Cut(elems[len(elems)-1], "?")
	elems[len(elems)-1] = last

	joined, err := url.JoinPath(base, elems...)
	if err != nil {
		return "", err
	}
	if hasQuery {
		return joined + "?" + query, nil
	}
	return joined, nil
}
|
|
|
|
var (
|
|
once sync.Once
|
|
instance *Client
|
|
)
|
|
|
|
// ClientOption mutates a Client while it is being built by New.
type ClientOption func(c *Client)
|
|
|
|
// Client represents an HTTP client with additional capabilities
|
|
type Client struct {
|
|
client *http.Client
|
|
rateLimiter ratelimit.Limiter
|
|
headers map[string]string
|
|
headersMu sync.RWMutex
|
|
maxRetries int
|
|
timeout time.Duration
|
|
skipTLSVerify bool
|
|
retryableStatus map[int]struct{}
|
|
logger zerolog.Logger
|
|
proxy string
|
|
}
|
|
|
|
// WithMaxRetries sets the maximum number of retry attempts
|
|
func WithMaxRetries(maxRetries int) ClientOption {
|
|
return func(c *Client) {
|
|
c.maxRetries = maxRetries
|
|
}
|
|
}
|
|
|
|
// WithTimeout sets the request timeout
|
|
func WithTimeout(timeout time.Duration) ClientOption {
|
|
return func(c *Client) {
|
|
c.timeout = timeout
|
|
}
|
|
}
|
|
|
|
func WithRedirectPolicy(policy func(req *http.Request, via []*http.Request) error) ClientOption {
|
|
return func(c *Client) {
|
|
c.client.CheckRedirect = policy
|
|
}
|
|
}
|
|
|
|
// WithRateLimiter sets a rate limiter
|
|
func WithRateLimiter(rl ratelimit.Limiter) ClientOption {
|
|
return func(c *Client) {
|
|
c.rateLimiter = rl
|
|
}
|
|
}
|
|
|
|
// WithHeaders sets default headers
|
|
func WithHeaders(headers map[string]string) ClientOption {
|
|
return func(c *Client) {
|
|
c.headersMu.Lock()
|
|
c.headers = headers
|
|
c.headersMu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (c *Client) SetHeader(key, value string) {
|
|
c.headersMu.Lock()
|
|
c.headers[key] = value
|
|
c.headersMu.Unlock()
|
|
}
|
|
|
|
func WithLogger(logger zerolog.Logger) ClientOption {
|
|
return func(c *Client) {
|
|
c.logger = logger
|
|
}
|
|
}
|
|
|
|
func WithTransport(transport *http.Transport) ClientOption {
|
|
return func(c *Client) {
|
|
c.client.Transport = transport
|
|
}
|
|
}
|
|
|
|
// WithRetryableStatus adds status codes that should trigger a retry
|
|
func WithRetryableStatus(statusCodes ...int) ClientOption {
|
|
return func(c *Client) {
|
|
c.retryableStatus = make(map[int]struct{}) // reset the map
|
|
for _, code := range statusCodes {
|
|
c.retryableStatus[code] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithProxy(proxyURL string) ClientOption {
|
|
return func(c *Client) {
|
|
c.proxy = proxyURL
|
|
}
|
|
}
|
|
|
|
// doRequest performs a single HTTP request with rate limiting
|
|
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
|
|
if c.rateLimiter != nil {
|
|
select {
|
|
case <-req.Context().Done():
|
|
return nil, req.Context().Err()
|
|
default:
|
|
c.rateLimiter.Take()
|
|
}
|
|
}
|
|
|
|
return c.client.Do(req)
|
|
}
|
|
|
|
// Do performs an HTTP request with retries for certain status codes
|
|
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
|
// Save the request body for reuse in retries
|
|
var bodyBytes []byte
|
|
var err error
|
|
|
|
if req.Body != nil {
|
|
bodyBytes, err = io.ReadAll(req.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading request body: %w", err)
|
|
}
|
|
req.Body.Close()
|
|
}
|
|
|
|
backoff := time.Millisecond * 500
|
|
var resp *http.Response
|
|
|
|
for attempt := 0; attempt <= c.maxRetries; attempt++ {
|
|
// Reset the request body if it exists
|
|
if bodyBytes != nil {
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
}
|
|
|
|
// Apply headers
|
|
c.headersMu.RLock()
|
|
if c.headers != nil {
|
|
for key, value := range c.headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
}
|
|
c.headersMu.RUnlock()
|
|
|
|
resp, err = c.doRequest(req)
|
|
if err != nil {
|
|
// Check if this is a network error that might be worth retrying
|
|
if isRetryableError(err) && attempt < c.maxRetries {
|
|
// Apply backoff with jitter
|
|
jitter := time.Duration(rand.Int63n(int64(backoff / 4)))
|
|
sleepTime := backoff + jitter
|
|
|
|
select {
|
|
case <-req.Context().Done():
|
|
return nil, req.Context().Err()
|
|
case <-time.After(sleepTime):
|
|
// Continue to next retry attempt
|
|
}
|
|
|
|
// Exponential backoff
|
|
backoff *= 2
|
|
continue
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Check if the status code is retryable
|
|
if _, ok := c.retryableStatus[resp.StatusCode]; !ok || attempt == c.maxRetries {
|
|
return resp, nil
|
|
}
|
|
|
|
// Close the response body before retrying
|
|
resp.Body.Close()
|
|
|
|
// Apply backoff with jitter
|
|
jitter := time.Duration(rand.Int63n(int64(backoff / 4)))
|
|
sleepTime := backoff + jitter
|
|
|
|
select {
|
|
case <-req.Context().Done():
|
|
return nil, req.Context().Err()
|
|
case <-time.After(sleepTime):
|
|
// Continue to next retry attempt
|
|
}
|
|
|
|
// Exponential backoff
|
|
backoff *= 2
|
|
}
|
|
|
|
return nil, fmt.Errorf("max retries exceeded")
|
|
}
|
|
|
|
// MakeRequest performs an HTTP request and returns the response body as bytes
|
|
func (c *Client) MakeRequest(req *http.Request) ([]byte, error) {
|
|
res, err := c.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer func() {
|
|
if err := res.Body.Close(); err != nil {
|
|
c.logger.Printf("Failed to close response body: %v", err)
|
|
}
|
|
}()
|
|
|
|
bodyBytes, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading response body: %w", err)
|
|
}
|
|
|
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("HTTP error %d: %s", res.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
return bodyBytes, nil
|
|
}
|
|
|
|
func (c *Client) Get(url string) (*http.Response, error) {
|
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating GET request: %w", err)
|
|
}
|
|
|
|
return c.Do(req)
|
|
}
|
|
|
|
// New creates a new HTTP client with the specified options
|
|
func New(options ...ClientOption) *Client {
|
|
client := &Client{
|
|
maxRetries: 3,
|
|
skipTLSVerify: true,
|
|
retryableStatus: map[int]struct{}{
|
|
http.StatusTooManyRequests: struct{}{},
|
|
http.StatusInternalServerError: struct{}{},
|
|
http.StatusBadGateway: struct{}{},
|
|
http.StatusServiceUnavailable: struct{}{},
|
|
http.StatusGatewayTimeout: struct{}{},
|
|
},
|
|
logger: logger.New("request"),
|
|
timeout: 60 * time.Second,
|
|
proxy: "",
|
|
headers: make(map[string]string),
|
|
}
|
|
|
|
// default http client
|
|
client.client = &http.Client{
|
|
Timeout: client.timeout,
|
|
}
|
|
|
|
// Apply options before configuring transport
|
|
for _, option := range options {
|
|
option(client)
|
|
}
|
|
|
|
// Check if transport was set by WithTransport option
|
|
if client.client.Transport == nil {
|
|
transport := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: client.skipTLSVerify,
|
|
},
|
|
DisableKeepAlives: false,
|
|
}
|
|
|
|
// Configure proxy if needed
|
|
SetProxy(transport, client.proxy)
|
|
|
|
// Set the transport to the client
|
|
client.client.Transport = transport
|
|
}
|
|
|
|
return client
|
|
}
|
|
|
|
func ParseRateLimit(rateStr string) ratelimit.Limiter {
|
|
if rateStr == "" {
|
|
return nil
|
|
}
|
|
parts := strings.SplitN(rateStr, "/", 2)
|
|
if len(parts) != 2 {
|
|
return nil
|
|
}
|
|
|
|
// parse count
|
|
count, err := strconv.Atoi(strings.TrimSpace(parts[0]))
|
|
if err != nil || count <= 0 {
|
|
return nil
|
|
}
|
|
|
|
// Set slack size to 10%
|
|
slackSize := count / 10
|
|
|
|
// normalize unit
|
|
unit := strings.ToLower(strings.TrimSpace(parts[1]))
|
|
unit = strings.TrimSuffix(unit, "s")
|
|
switch unit {
|
|
case "minute", "min":
|
|
return ratelimit.New(count, ratelimit.Per(time.Minute), ratelimit.WithSlack(slackSize))
|
|
case "second", "sec":
|
|
return ratelimit.New(count, ratelimit.Per(time.Second), ratelimit.WithSlack(slackSize))
|
|
case "hour", "hr":
|
|
return ratelimit.New(count, ratelimit.Per(time.Hour), ratelimit.WithSlack(slackSize))
|
|
case "day", "d":
|
|
return ratelimit.New(count, ratelimit.Per(24*time.Hour), ratelimit.WithSlack(slackSize))
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// JSONResponse writes data to w as JSON with the given status code and a
// Content-Type of application/json.
//
// The status line has already been written by the time encoding starts, so an
// encoding failure cannot be reported to the client and is dropped.
func JSONResponse(w http.ResponseWriter, data interface{}, code int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	_ = json.NewEncoder(w).Encode(data)
}
|
|
|
|
func Default() *Client {
|
|
once.Do(func() {
|
|
instance = New()
|
|
})
|
|
return instance
|
|
}
|
|
|
|
// isRetryableError reports whether err looks like a transient network failure
// worth retrying.
//
// An error qualifies when its message contains one of a fixed list of
// network-failure phrases, or when it wraps a net.Error that reports a
// timeout. err must be non-nil.
func isRetryableError(err error) bool {
	message := err.Error()

	transientFragments := [...]string{
		"connection reset by peer",
		"read: connection reset",
		"connection refused",
		"network is unreachable",
		"connection timed out",
		"no such host",
		"i/o timeout",
		"unexpected EOF",
		"TLS handshake timeout",
	}
	for _, fragment := range transientFragments {
		if strings.Contains(message, fragment) {
			return true
		}
	}

	// Fall back to the typed check: only timeouts count as retryable.
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}
|
|
|
|
func SetProxy(transport *http.Transport, proxyURL string) {
|
|
if proxyURL != "" {
|
|
if strings.HasPrefix(proxyURL, "socks5://") {
|
|
// Handle SOCKS5 proxy
|
|
socksURL, err := url.Parse(proxyURL)
|
|
if err != nil {
|
|
return
|
|
} else {
|
|
auth := &proxy.Auth{}
|
|
if socksURL.User != nil {
|
|
auth.User = socksURL.User.Username()
|
|
password, _ := socksURL.User.Password()
|
|
auth.Password = password
|
|
}
|
|
|
|
dialer, err := proxy.SOCKS5("tcp", socksURL.Host, auth, proxy.Direct)
|
|
if err != nil {
|
|
return
|
|
} else {
|
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return dialer.Dial(network, addr)
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
_proxy, err := url.Parse(proxyURL)
|
|
if err != nil {
|
|
return
|
|
} else {
|
|
transport.Proxy = http.ProxyURL(_proxy)
|
|
}
|
|
}
|
|
} else {
|
|
transport.Proxy = http.ProxyFromEnvironment
|
|
}
|
|
return
|
|
}
|