Implementing a streaming setup with Usenet
This commit is contained in:
@@ -47,7 +47,6 @@ type Debrid struct {
|
||||
type QBitTorrent struct {
|
||||
Username string `json:"username,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Port string `json:"port,omitempty"` // deprecated
|
||||
DownloadFolder string `json:"download_folder,omitempty"`
|
||||
Categories []string `json:"categories,omitempty"`
|
||||
RefreshInterval int `json:"refresh_interval,omitempty"`
|
||||
@@ -82,26 +81,55 @@ type Auth struct {
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
type SABnzbd struct {
|
||||
DownloadFolder string `json:"download_folder,omitempty"`
|
||||
RefreshInterval int `json:"refresh_interval,omitempty"`
|
||||
Categories []string `json:"categories,omitempty"`
|
||||
}
|
||||
|
||||
type Usenet struct {
|
||||
Providers []UsenetProvider `json:"providers,omitempty"` // List of usenet providers
|
||||
MountFolder string `json:"mount_folder,omitempty"` // Folder where usenet downloads are mounted
|
||||
SkipPreCache bool `json:"skip_pre_cache,omitempty"`
|
||||
Chunks int `json:"chunks,omitempty"` // Number of chunks to pre-cache
|
||||
RcUrl string `json:"rc_url,omitempty"` // Rclone RC URL for the webdav
|
||||
RcUser string `json:"rc_user,omitempty"` // Rclone RC username
|
||||
RcPass string `json:"rc_pass,omitempty"` // Rclone RC password
|
||||
}
|
||||
|
||||
type UsenetProvider struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Host string `json:"host,omitempty"` // Host of the usenet server
|
||||
Port int `json:"port,omitempty"` // Port of the usenet server
|
||||
Username string `json:"username,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Connections int `json:"connections,omitempty"` // Number of connections to use
|
||||
SSL bool `json:"ssl,omitempty"` // Use SSL for the connection
|
||||
UseTLS bool `json:"use_tls,omitempty"` // Use TLS for the connection
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// server
|
||||
BindAddress string `json:"bind_address,omitempty"`
|
||||
URLBase string `json:"url_base,omitempty"`
|
||||
Port string `json:"port,omitempty"`
|
||||
|
||||
LogLevel string `json:"log_level,omitempty"`
|
||||
Debrids []Debrid `json:"debrids,omitempty"`
|
||||
QBitTorrent QBitTorrent `json:"qbittorrent,omitempty"`
|
||||
Arrs []Arr `json:"arrs,omitempty"`
|
||||
Repair Repair `json:"repair,omitempty"`
|
||||
WebDav WebDav `json:"webdav,omitempty"`
|
||||
AllowedExt []string `json:"allowed_file_types,omitempty"`
|
||||
MinFileSize string `json:"min_file_size,omitempty"` // Minimum file size to download, 10MB, 1GB, etc
|
||||
MaxFileSize string `json:"max_file_size,omitempty"` // Maximum file size to download (0 means no limit)
|
||||
Path string `json:"-"` // Path to save the config file
|
||||
UseAuth bool `json:"use_auth,omitempty"`
|
||||
Auth *Auth `json:"-"`
|
||||
DiscordWebhook string `json:"discord_webhook_url,omitempty"`
|
||||
RemoveStalledAfter string `json:"remove_stalled_after,omitzero"`
|
||||
LogLevel string `json:"log_level,omitempty"`
|
||||
Debrids []Debrid `json:"debrids,omitempty"`
|
||||
QBitTorrent *QBitTorrent `json:"qbittorrent,omitempty"`
|
||||
SABnzbd *SABnzbd `json:"sabnzbd,omitempty"`
|
||||
Usenet *Usenet `json:"usenet,omitempty"` // Usenet configuration
|
||||
Arrs []Arr `json:"arrs,omitempty"`
|
||||
Repair Repair `json:"repair,omitempty"`
|
||||
WebDav WebDav `json:"webdav,omitempty"`
|
||||
AllowedExt []string `json:"allowed_file_types,omitempty"`
|
||||
MinFileSize string `json:"min_file_size,omitempty"` // Minimum file size to download, 10MB, 1GB, etc
|
||||
MaxFileSize string `json:"max_file_size,omitempty"` // Maximum file size to download (0 means no limit)
|
||||
Path string `json:"-"` // Path to save the config file
|
||||
UseAuth bool `json:"use_auth,omitempty"`
|
||||
Auth *Auth `json:"-"`
|
||||
DiscordWebhook string `json:"discord_webhook_url,omitempty"`
|
||||
RemoveStalledAfter string `json:"remove_stalled_after,omitzero"`
|
||||
}
|
||||
|
||||
func (c *Config) JsonFile() string {
|
||||
@@ -115,6 +143,10 @@ func (c *Config) TorrentsFile() string {
|
||||
return filepath.Join(c.Path, "torrents.json")
|
||||
}
|
||||
|
||||
func (c *Config) NZBsPath() string {
|
||||
return filepath.Join(c.Path, "cache/nzbs")
|
||||
}
|
||||
|
||||
func (c *Config) loadConfig() error {
|
||||
// Load the config file
|
||||
if configPath == "" {
|
||||
@@ -142,9 +174,6 @@ func (c *Config) loadConfig() error {
|
||||
}
|
||||
|
||||
func validateDebrids(debrids []Debrid) error {
|
||||
if len(debrids) == 0 {
|
||||
return errors.New("no debrids configured")
|
||||
}
|
||||
|
||||
for _, debrid := range debrids {
|
||||
// Basic field validation
|
||||
@@ -159,17 +188,51 @@ func validateDebrids(debrids []Debrid) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateQbitTorrent(config *QBitTorrent) error {
|
||||
if config.DownloadFolder == "" {
|
||||
return errors.New("qbittorent download folder is required")
|
||||
func validateUsenet(usenet *Usenet) error {
|
||||
if usenet == nil {
|
||||
return nil // No usenet configuration provided
|
||||
}
|
||||
if _, err := os.Stat(config.DownloadFolder); os.IsNotExist(err) {
|
||||
return fmt.Errorf("qbittorent download folder(%s) does not exist", config.DownloadFolder)
|
||||
for _, usenet := range usenet.Providers {
|
||||
// Basic field validation
|
||||
if usenet.Host == "" {
|
||||
return errors.New("usenet host is required")
|
||||
}
|
||||
if usenet.Username == "" {
|
||||
return errors.New("usenet username is required")
|
||||
}
|
||||
if usenet.Password == "" {
|
||||
return errors.New("usenet password is required")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSabznbd(config *SABnzbd) error {
|
||||
if config == nil {
|
||||
return nil // No SABnzbd configuration provided
|
||||
}
|
||||
if config.DownloadFolder != "" {
|
||||
if _, err := os.Stat(config.DownloadFolder); os.IsNotExist(err) {
|
||||
return fmt.Errorf("sabnzbd download folder(%s) does not exist", config.DownloadFolder)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRepair(config *Repair) error {
|
||||
func validateQbitTorrent(config *QBitTorrent) error {
|
||||
if config == nil {
|
||||
return nil // No qBittorrent configuration provided
|
||||
}
|
||||
if config.DownloadFolder != "" {
|
||||
if _, err := os.Stat(config.DownloadFolder); os.IsNotExist(err) {
|
||||
return fmt.Errorf("qbittorent download folder(%s) does not exist", config.DownloadFolder)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRepair(config Repair) error {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
@@ -181,19 +244,34 @@ func validateRepair(config *Repair) error {
|
||||
|
||||
func ValidateConfig(config *Config) error {
|
||||
// Run validations concurrently
|
||||
// Check if there's at least one debrid or usenet configured
|
||||
hasUsenet := false
|
||||
if config.Usenet != nil && len(config.Usenet.Providers) > 0 {
|
||||
hasUsenet = true
|
||||
}
|
||||
if len(config.Debrids) == 0 && !hasUsenet {
|
||||
return errors.New("at least one debrid or usenet provider must be configured")
|
||||
}
|
||||
|
||||
if err := validateDebrids(config.Debrids); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateQbitTorrent(&config.QBitTorrent); err != nil {
|
||||
if err := validateUsenet(config.Usenet); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateRepair(&config.Repair); err != nil {
|
||||
if err := validateSabznbd(config.SABnzbd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateQbitTorrent(config.QBitTorrent); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateRepair(config.Repair); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -299,6 +377,10 @@ func (c *Config) updateDebrid(d Debrid) Debrid {
|
||||
}
|
||||
d.DownloadAPIKeys = downloadKeys
|
||||
|
||||
if d.Workers == 0 {
|
||||
d.Workers = perDebrid
|
||||
}
|
||||
|
||||
if !d.UseWebDav {
|
||||
return d
|
||||
}
|
||||
@@ -309,9 +391,6 @@ func (c *Config) updateDebrid(d Debrid) Debrid {
|
||||
if d.WebDav.DownloadLinksRefreshInterval == "" {
|
||||
d.DownloadLinksRefreshInterval = cmp.Or(c.WebDav.DownloadLinksRefreshInterval, "40m") // 40 minutes
|
||||
}
|
||||
if d.Workers == 0 {
|
||||
d.Workers = perDebrid
|
||||
}
|
||||
if d.FolderNaming == "" {
|
||||
d.FolderNaming = cmp.Or(c.WebDav.FolderNaming, "original_no_ext")
|
||||
}
|
||||
@@ -338,17 +417,47 @@ func (c *Config) updateDebrid(d Debrid) Debrid {
|
||||
return d
|
||||
}
|
||||
|
||||
func (c *Config) updateUsenet(u UsenetProvider) UsenetProvider {
|
||||
if u.Name == "" {
|
||||
parts := strings.Split(u.Host, ".")
|
||||
if len(parts) >= 2 {
|
||||
u.Name = parts[len(parts)-2] // Gets "example" from "news.example.com"
|
||||
} else {
|
||||
u.Name = u.Host // Fallback to host if it doesn't look like a domain
|
||||
}
|
||||
}
|
||||
if u.Port == 0 {
|
||||
u.Port = 119 // Default port for usenet
|
||||
}
|
||||
if u.Connections == 0 {
|
||||
u.Connections = 30 // Default connections
|
||||
}
|
||||
if u.SSL && !u.UseTLS {
|
||||
u.UseTLS = true // Use TLS if SSL is enabled
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func (c *Config) setDefaults() {
|
||||
for i, debrid := range c.Debrids {
|
||||
c.Debrids[i] = c.updateDebrid(debrid)
|
||||
}
|
||||
|
||||
if c.SABnzbd != nil {
|
||||
c.SABnzbd.RefreshInterval = cmp.Or(c.SABnzbd.RefreshInterval, 10) // Default to 10 seconds
|
||||
}
|
||||
|
||||
if c.Usenet != nil {
|
||||
c.Usenet.Chunks = cmp.Or(c.Usenet.Chunks, 5)
|
||||
for i, provider := range c.Usenet.Providers {
|
||||
c.Usenet.Providers[i] = c.updateUsenet(provider)
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.AllowedExt) == 0 {
|
||||
c.AllowedExt = getDefaultExtensions()
|
||||
}
|
||||
|
||||
c.Port = cmp.Or(c.Port, c.QBitTorrent.Port)
|
||||
|
||||
if c.URLBase == "" {
|
||||
c.URLBase = "/"
|
||||
}
|
||||
@@ -395,11 +504,6 @@ func (c *Config) createConfig(path string) error {
|
||||
c.Port = "8282"
|
||||
c.LogLevel = "info"
|
||||
c.UseAuth = true
|
||||
c.QBitTorrent = QBitTorrent{
|
||||
DownloadFolder: filepath.Join(path, "downloads"),
|
||||
Categories: []string{"sonarr", "radarr"},
|
||||
RefreshInterval: 15,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -408,7 +512,3 @@ func Reload() {
|
||||
instance = nil
|
||||
once = sync.Once{}
|
||||
}
|
||||
|
||||
func DefaultFreeSlot() int {
|
||||
return 10
|
||||
}
|
||||
|
||||
178
internal/nntp/client.go
Normal file
178
internal/nntp/client.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sirrobot01/decypharr/internal/config"
|
||||
"github.com/sirrobot01/decypharr/internal/logger"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client represents a failover NNTP client that manages multiple providers
|
||||
type Client struct {
|
||||
providers []config.UsenetProvider
|
||||
pools *xsync.Map[string, *Pool]
|
||||
logger zerolog.Logger
|
||||
closed atomic.Bool
|
||||
minimumMaxConns int // Minimum number of max connections across all pools
|
||||
}
|
||||
|
||||
func NewClient(providers []config.UsenetProvider) (*Client, error) {
|
||||
|
||||
client := &Client{
|
||||
providers: providers,
|
||||
logger: logger.New("nntp"),
|
||||
pools: xsync.NewMap[string, *Pool](),
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return nil, fmt.Errorf("no NNTP providers configured")
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) InitPools() error {
|
||||
|
||||
var initErrors []error
|
||||
successfulPools := 0
|
||||
|
||||
for _, provider := range c.providers {
|
||||
serverPool, err := NewPool(provider, c.logger)
|
||||
if err != nil {
|
||||
c.logger.Error().
|
||||
Err(err).
|
||||
Str("server", provider.Host).
|
||||
Int("port", provider.Port).
|
||||
Msg("Failed to initialize server pool")
|
||||
initErrors = append(initErrors, err)
|
||||
continue
|
||||
}
|
||||
if c.minimumMaxConns == 0 {
|
||||
// Set minimumMaxConns to the max connections of the first successful pool
|
||||
c.minimumMaxConns = serverPool.ConnectionCount()
|
||||
} else {
|
||||
c.minimumMaxConns = min(c.minimumMaxConns, serverPool.ConnectionCount())
|
||||
}
|
||||
|
||||
c.pools.Store(provider.Name, serverPool)
|
||||
successfulPools++
|
||||
}
|
||||
|
||||
if successfulPools == 0 {
|
||||
return fmt.Errorf("failed to initialize any server pools: %v", initErrors)
|
||||
}
|
||||
|
||||
c.logger.Info().
|
||||
Int("providers", len(c.providers)).
|
||||
Msg("NNTP client created")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
if c.closed.Load() {
|
||||
c.logger.Warn().Msg("NNTP client already closed")
|
||||
return
|
||||
}
|
||||
|
||||
c.pools.Range(func(key string, value *Pool) bool {
|
||||
if value != nil {
|
||||
err := value.Close()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
c.closed.Store(true)
|
||||
c.logger.Info().Msg("NNTP client closed")
|
||||
}
|
||||
|
||||
func (c *Client) GetConnection(ctx context.Context) (*Connection, func(), error) {
|
||||
if c.closed.Load() {
|
||||
return nil, nil, fmt.Errorf("nntp client is closed")
|
||||
}
|
||||
|
||||
// Prevent workers from waiting too long for connections
|
||||
connCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
providerCount := len(c.providers)
|
||||
|
||||
for _, provider := range c.providers {
|
||||
pool, ok := c.pools.Load(provider.Name)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("no pool found for provider %s", provider.Name)
|
||||
}
|
||||
|
||||
if !pool.IsFree() && providerCount > 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := pool.Get(connCtx) // Use timeout context
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoAvailableConnection) || errors.Is(err, context.DeadlineExceeded) {
|
||||
continue
|
||||
}
|
||||
return nil, nil, fmt.Errorf("error getting connection from provider %s: %w", provider.Name, err)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return conn, func() { pool.Put(conn) }, nil
|
||||
}
|
||||
|
||||
return nil, nil, ErrNoAvailableConnection
|
||||
}
|
||||
|
||||
func (c *Client) DownloadHeader(ctx context.Context, messageID string) (*YencMetadata, error) {
|
||||
conn, cleanup, err := c.GetConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
data, err := conn.GetBody(messageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// yEnc decode
|
||||
part, err := DecodeYencHeaders(bytes.NewReader(data))
|
||||
if err != nil || part == nil {
|
||||
return nil, fmt.Errorf("failed to decode segment")
|
||||
}
|
||||
|
||||
// Return both the filename and decoded data
|
||||
return part, nil
|
||||
}
|
||||
|
||||
func (c *Client) MinimumMaxConns() int {
|
||||
return c.minimumMaxConns
|
||||
}
|
||||
|
||||
func (c *Client) TotalActiveConnections() int {
|
||||
total := 0
|
||||
c.pools.Range(func(key string, value *Pool) bool {
|
||||
if value != nil {
|
||||
total += value.ActiveConnections()
|
||||
}
|
||||
return true
|
||||
})
|
||||
return total
|
||||
}
|
||||
|
||||
func (c *Client) Pools() *xsync.Map[string, *Pool] {
|
||||
return c.pools
|
||||
}
|
||||
|
||||
func (c *Client) GetProviders() []config.UsenetProvider {
|
||||
return c.providers
|
||||
}
|
||||
394
internal/nntp/conns.go
Normal file
394
internal/nntp/conns.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/chrisfarms/yenc"
|
||||
"github.com/rs/zerolog"
|
||||
"io"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Connection represents an NNTP connection
|
||||
type Connection struct {
|
||||
username, password, address string
|
||||
port int
|
||||
conn net.Conn
|
||||
text *textproto.Conn
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
func (c *Connection) authenticate() error {
|
||||
// Send AUTHINFO USER command
|
||||
if err := c.sendCommand(fmt.Sprintf("AUTHINFO USER %s", c.username)); err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to send username: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to read user response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 381 {
|
||||
return classifyNNTPError(resp.Code, fmt.Sprintf("unexpected response to AUTHINFO USER: %s", resp.Message))
|
||||
}
|
||||
|
||||
// Send AUTHINFO PASS command
|
||||
if err := c.sendCommand(fmt.Sprintf("AUTHINFO PASS %s", c.password)); err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to send password: %w", err))
|
||||
}
|
||||
|
||||
resp, err = c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to read password response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 281 {
|
||||
return classifyNNTPError(resp.Code, fmt.Sprintf("authentication failed: %s", resp.Message))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startTLS initiates TLS encryption with proper error handling
|
||||
func (c *Connection) startTLS() error {
|
||||
if err := c.sendCommand("STARTTLS"); err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to send STARTTLS: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to read STARTTLS response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 382 {
|
||||
return classifyNNTPError(resp.Code, fmt.Sprintf("STARTTLS not supported: %s", resp.Message))
|
||||
}
|
||||
|
||||
// Upgrade connection to TLS
|
||||
tlsConn := tls.Client(c.conn, &tls.Config{
|
||||
ServerName: c.address,
|
||||
InsecureSkipVerify: false,
|
||||
})
|
||||
|
||||
c.conn = tlsConn
|
||||
c.reader = bufio.NewReader(tlsConn)
|
||||
c.writer = bufio.NewWriter(tlsConn)
|
||||
c.text = textproto.NewConn(tlsConn)
|
||||
|
||||
c.logger.Debug().Msg("TLS encryption enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ping sends a simple command to test the connection
|
||||
func (c *Connection) ping() error {
|
||||
if err := c.sendCommand("DATE"); err != nil {
|
||||
return NewConnectionError(err)
|
||||
}
|
||||
_, err := c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendCommand sends a command to the NNTP server
|
||||
func (c *Connection) sendCommand(command string) error {
|
||||
_, err := fmt.Fprintf(c.writer, "%s\r\n", command)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.writer.Flush()
|
||||
}
|
||||
|
||||
// readResponse reads a response from the NNTP server
|
||||
func (c *Connection) readResponse() (*Response, error) {
|
||||
line, err := c.text.ReadLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, " ", 2)
|
||||
code, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid response code: %s", parts[0])
|
||||
}
|
||||
|
||||
message := ""
|
||||
if len(parts) > 1 {
|
||||
message = parts[1]
|
||||
}
|
||||
|
||||
return &Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// readMultilineResponse reads a multiline response
|
||||
func (c *Connection) readMultilineResponse() (*Response, error) {
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if this is a multiline response
|
||||
if resp.Code < 200 || resp.Code >= 300 {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
lines, err := c.text.ReadDotLines()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Lines = lines
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GetArticle retrieves an article by message ID with proper error classification
|
||||
func (c *Connection) GetArticle(messageID string) (*Article, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
if err := c.sendCommand(fmt.Sprintf("ARTICLE %s", messageID)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send ARTICLE command: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readMultilineResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read article response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 220 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
return c.parseArticle(messageID, resp.Lines)
|
||||
}
|
||||
|
||||
// GetBody retrieves article body by message ID with proper error classification
|
||||
func (c *Connection) GetBody(messageID string) ([]byte, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
if err := c.sendCommand(fmt.Sprintf("BODY %s", messageID)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send BODY command: %w", err))
|
||||
}
|
||||
|
||||
// Read the initial response
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read body response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 222 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
// Read the raw body data directly using textproto to preserve exact formatting for yEnc
|
||||
lines, err := c.text.ReadDotLines()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read body data: %w", err))
|
||||
}
|
||||
|
||||
// Join with \r\n to preserve original line endings and add final \r\n
|
||||
body := strings.Join(lines, "\r\n")
|
||||
if len(lines) > 0 {
|
||||
body += "\r\n"
|
||||
}
|
||||
|
||||
return []byte(body), nil
|
||||
}
|
||||
|
||||
// GetHead retrieves article headers by message ID
|
||||
func (c *Connection) GetHead(messageID string) ([]byte, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
if err := c.sendCommand(fmt.Sprintf("HEAD %s", messageID)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send HEAD command: %w", err))
|
||||
}
|
||||
|
||||
// Read the initial response
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read head response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 221 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
// Read the header data using textproto
|
||||
lines, err := c.text.ReadDotLines()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read header data: %w", err))
|
||||
}
|
||||
|
||||
// Join with \r\n to preserve original line endings and add final \r\n
|
||||
headers := strings.Join(lines, "\r\n")
|
||||
if len(lines) > 0 {
|
||||
headers += "\r\n"
|
||||
}
|
||||
|
||||
return []byte(headers), nil
|
||||
}
|
||||
|
||||
// GetSegment retrieves a specific segment with proper error handling
|
||||
func (c *Connection) GetSegment(messageID string, segmentNumber int) (*Segment, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
body, err := c.GetBody(messageID)
|
||||
if err != nil {
|
||||
return nil, err // GetBody already returns classified errors
|
||||
}
|
||||
|
||||
return &Segment{
|
||||
MessageID: messageID,
|
||||
Number: segmentNumber,
|
||||
Bytes: int64(len(body)),
|
||||
Data: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stat retrieves article statistics by message ID with proper error classification
|
||||
func (c *Connection) Stat(messageID string) (articleNumber int, echoedID string, err error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
|
||||
if err = c.sendCommand(fmt.Sprintf("STAT %s", messageID)); err != nil {
|
||||
return 0, "", NewConnectionError(fmt.Errorf("failed to send STAT: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return 0, "", NewConnectionError(fmt.Errorf("failed to read STAT response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 223 {
|
||||
return 0, "", classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
fields := strings.Fields(resp.Message)
|
||||
if len(fields) < 2 {
|
||||
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("unexpected STAT response format: %q", resp.Message))
|
||||
}
|
||||
|
||||
if articleNumber, err = strconv.Atoi(fields[0]); err != nil {
|
||||
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("invalid article number %q: %v", fields[0], err))
|
||||
}
|
||||
echoedID = fields[1]
|
||||
|
||||
return articleNumber, echoedID, nil
|
||||
}
|
||||
|
||||
// SelectGroup selects a newsgroup and returns group information
|
||||
func (c *Connection) SelectGroup(groupName string) (*GroupInfo, error) {
|
||||
if err := c.sendCommand(fmt.Sprintf("GROUP %s", groupName)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send GROUP command: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read GROUP response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 211 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
// Parse GROUP response: "211 number low high group-name"
|
||||
fields := strings.Fields(resp.Message)
|
||||
if len(fields) < 4 {
|
||||
return nil, NewProtocolError(resp.Code, fmt.Sprintf("unexpected GROUP response format: %q", resp.Message))
|
||||
}
|
||||
|
||||
groupInfo := &GroupInfo{
|
||||
Name: groupName,
|
||||
}
|
||||
|
||||
if count, err := strconv.Atoi(fields[0]); err == nil {
|
||||
groupInfo.Count = count
|
||||
}
|
||||
if low, err := strconv.Atoi(fields[1]); err == nil {
|
||||
groupInfo.Low = low
|
||||
}
|
||||
if high, err := strconv.Atoi(fields[2]); err == nil {
|
||||
groupInfo.High = high
|
||||
}
|
||||
|
||||
return groupInfo, nil
|
||||
}
|
||||
|
||||
// parseArticle parses article data from response lines
|
||||
func (c *Connection) parseArticle(messageID string, lines []string) (*Article, error) {
|
||||
article := &Article{
|
||||
MessageID: messageID,
|
||||
Groups: []string{},
|
||||
}
|
||||
|
||||
headerEnd := -1
|
||||
for i, line := range lines {
|
||||
if line == "" {
|
||||
headerEnd = i
|
||||
break
|
||||
}
|
||||
|
||||
// Parse headers
|
||||
if strings.HasPrefix(line, "Subject: ") {
|
||||
article.Subject = strings.TrimPrefix(line, "Subject: ")
|
||||
} else if strings.HasPrefix(line, "From: ") {
|
||||
article.From = strings.TrimPrefix(line, "From: ")
|
||||
} else if strings.HasPrefix(line, "Date: ") {
|
||||
article.Date = strings.TrimPrefix(line, "Date: ")
|
||||
} else if strings.HasPrefix(line, "Newsgroups: ") {
|
||||
groups := strings.TrimPrefix(line, "Newsgroups: ")
|
||||
article.Groups = strings.Split(groups, ",")
|
||||
for i := range article.Groups {
|
||||
article.Groups[i] = strings.TrimSpace(article.Groups[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Join body lines
|
||||
if headerEnd != -1 && headerEnd+1 < len(lines) {
|
||||
body := strings.Join(lines[headerEnd+1:], "\n")
|
||||
article.Body = []byte(body)
|
||||
article.Size = int64(len(article.Body))
|
||||
}
|
||||
|
||||
return article, nil
|
||||
}
|
||||
|
||||
// close closes the NNTP connection
|
||||
func (c *Connection) close() error {
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DecodeYenc(reader io.Reader) (*yenc.Part, error) {
|
||||
part, err := yenc.Decode(reader)
|
||||
if err != nil {
|
||||
return nil, NewYencDecodeError(fmt.Errorf("failed to create yenc decoder: %w", err))
|
||||
}
|
||||
return part, nil
|
||||
}
|
||||
|
||||
func IsValidMessageID(messageID string) bool {
|
||||
if len(messageID) < 3 {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(messageID, "@")
|
||||
}
|
||||
|
||||
// FormatMessageID ensures message ID has proper format
|
||||
func FormatMessageID(messageID string) string {
|
||||
messageID = strings.TrimSpace(messageID)
|
||||
if !strings.HasPrefix(messageID, "<") {
|
||||
messageID = "<" + messageID
|
||||
}
|
||||
if !strings.HasSuffix(messageID, ">") {
|
||||
messageID = messageID + ">"
|
||||
}
|
||||
return messageID
|
||||
}
|
||||
116
internal/nntp/decoder.go
Normal file
116
internal/nntp/decoder.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// YencMetadata contains just the header information
|
||||
type YencMetadata struct {
|
||||
Name string // filename
|
||||
Size int64 // total file size
|
||||
Part int // part number
|
||||
Total int // total parts
|
||||
Begin int64 // part start byte
|
||||
End int64 // part end byte
|
||||
LineSize int // line length
|
||||
}
|
||||
|
||||
// DecodeYencHeaders extracts only yenc header metadata without decoding body
|
||||
func DecodeYencHeaders(reader io.Reader) (*YencMetadata, error) {
|
||||
buf := bufio.NewReader(reader)
|
||||
metadata := &YencMetadata{}
|
||||
|
||||
// Find and parse =ybegin header
|
||||
if err := parseYBeginHeader(buf, metadata); err != nil {
|
||||
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ybegin header: %w", err))
|
||||
}
|
||||
|
||||
// Parse =ypart header if this is a multipart file
|
||||
if metadata.Part > 0 {
|
||||
if err := parseYPartHeader(buf, metadata); err != nil {
|
||||
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ypart header: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func parseYBeginHeader(buf *bufio.Reader, metadata *YencMetadata) error {
|
||||
var s string
|
||||
var err error
|
||||
|
||||
// Find the =ybegin line
|
||||
for {
|
||||
s, err = buf.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(s) >= 7 && s[:7] == "=ybegin" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the header line
|
||||
parts := strings.SplitN(s[7:], "name=", 2)
|
||||
if len(parts) > 1 {
|
||||
metadata.Name = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
// Parse other parameters
|
||||
for _, header := range strings.Split(parts[0], " ") {
|
||||
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
|
||||
if len(kv) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch kv[0] {
|
||||
case "size":
|
||||
metadata.Size, _ = strconv.ParseInt(kv[1], 10, 64)
|
||||
case "line":
|
||||
metadata.LineSize, _ = strconv.Atoi(kv[1])
|
||||
case "part":
|
||||
metadata.Part, _ = strconv.Atoi(kv[1])
|
||||
case "total":
|
||||
metadata.Total, _ = strconv.Atoi(kv[1])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseYPartHeader(buf *bufio.Reader, metadata *YencMetadata) error {
|
||||
var s string
|
||||
var err error
|
||||
|
||||
// Find the =ypart line
|
||||
for {
|
||||
s, err = buf.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(s) >= 6 && s[:6] == "=ypart" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Parse part parameters
|
||||
for _, header := range strings.Split(s[6:], " ") {
|
||||
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
|
||||
if len(kv) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch kv[0] {
|
||||
case "begin":
|
||||
metadata.Begin, _ = strconv.ParseInt(kv[1], 10, 64)
|
||||
case "end":
|
||||
metadata.End, _ = strconv.ParseInt(kv[1], 10, 64)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
195
internal/nntp/errors.go
Normal file
195
internal/nntp/errors.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Error types for NNTP operations
|
||||
type ErrorType int
|
||||
|
||||
const (
|
||||
ErrorTypeUnknown ErrorType = iota
|
||||
ErrorTypeConnection
|
||||
ErrorTypeAuthentication
|
||||
ErrorTypeTimeout
|
||||
ErrorTypeArticleNotFound
|
||||
ErrorTypeGroupNotFound
|
||||
ErrorTypePermissionDenied
|
||||
ErrorTypeServerBusy
|
||||
ErrorTypeInvalidCommand
|
||||
ErrorTypeProtocol
|
||||
ErrorTypeYencDecode
|
||||
ErrorTypeNoAvailableConnection
|
||||
)
|
||||
|
||||
// Error represents an NNTP-specific error
|
||||
type Error struct {
|
||||
Type ErrorType
|
||||
Code int // NNTP response code
|
||||
Message string // Error message
|
||||
Err error // Underlying error
|
||||
}
|
||||
|
||||
// Predefined errors for common cases
|
||||
var (
|
||||
ErrArticleNotFound = &Error{Type: ErrorTypeArticleNotFound, Code: 430, Message: "article not found"}
|
||||
ErrGroupNotFound = &Error{Type: ErrorTypeGroupNotFound, Code: 411, Message: "group not found"}
|
||||
ErrPermissionDenied = &Error{Type: ErrorTypePermissionDenied, Code: 502, Message: "permission denied"}
|
||||
ErrAuthenticationFail = &Error{Type: ErrorTypeAuthentication, Code: 482, Message: "authentication failed"}
|
||||
ErrServerBusy = &Error{Type: ErrorTypeServerBusy, Code: 400, Message: "server busy"}
|
||||
ErrPoolNotFound = &Error{Type: ErrorTypeUnknown, Code: 0, Message: "NNTP pool not found", Err: nil}
|
||||
ErrNoAvailableConnection = &Error{Type: ErrorTypeNoAvailableConnection, Code: 0, Message: "no available connection in pool", Err: nil}
|
||||
)
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("NNTP %s (code %d): %s - %v", e.Type.String(), e.Code, e.Message, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("NNTP %s (code %d): %s", e.Type.String(), e.Code, e.Message)
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e *Error) Is(target error) bool {
|
||||
if t, ok := target.(*Error); ok {
|
||||
return e.Type == t.Type
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsRetryable returns true if the error might be resolved by retrying
|
||||
func (e *Error) IsRetryable() bool {
|
||||
switch e.Type {
|
||||
case ErrorTypeConnection, ErrorTypeTimeout, ErrorTypeServerBusy:
|
||||
return true
|
||||
case ErrorTypeArticleNotFound, ErrorTypeGroupNotFound, ErrorTypePermissionDenied, ErrorTypeAuthentication:
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldStopParsing returns true if this error should stop the entire parsing process
|
||||
func (e *Error) ShouldStopParsing() bool {
|
||||
switch e.Type {
|
||||
case ErrorTypeAuthentication, ErrorTypePermissionDenied:
|
||||
return true // Critical auth issues
|
||||
case ErrorTypeConnection:
|
||||
return false // Can continue with other connections
|
||||
case ErrorTypeArticleNotFound:
|
||||
return false // Can continue searching for other articles
|
||||
case ErrorTypeServerBusy:
|
||||
return false // Temporary issue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (et ErrorType) String() string {
|
||||
switch et {
|
||||
case ErrorTypeConnection:
|
||||
return "CONNECTION"
|
||||
case ErrorTypeAuthentication:
|
||||
return "AUTHENTICATION"
|
||||
case ErrorTypeTimeout:
|
||||
return "TIMEOUT"
|
||||
case ErrorTypeArticleNotFound:
|
||||
return "ARTICLE_NOT_FOUND"
|
||||
case ErrorTypeGroupNotFound:
|
||||
return "GROUP_NOT_FOUND"
|
||||
case ErrorTypePermissionDenied:
|
||||
return "PERMISSION_DENIED"
|
||||
case ErrorTypeServerBusy:
|
||||
return "SERVER_BUSY"
|
||||
case ErrorTypeInvalidCommand:
|
||||
return "INVALID_COMMAND"
|
||||
case ErrorTypeProtocol:
|
||||
return "PROTOCOL"
|
||||
case ErrorTypeYencDecode:
|
||||
return "YENC_DECODE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to create specific errors
|
||||
func NewConnectionError(err error) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeConnection,
|
||||
Message: "connection failed",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTimeoutError(err error) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeTimeout,
|
||||
Message: "operation timed out",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProtocolError(code int, message string) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeProtocol,
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
func NewYencDecodeError(err error) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeYencDecode,
|
||||
Message: "yEnc decode failed",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// classifyNNTPError classifies an NNTP response code into an error type
|
||||
func classifyNNTPError(code int, message string) *Error {
|
||||
switch {
|
||||
case code == 430 || code == 423:
|
||||
return &Error{Type: ErrorTypeArticleNotFound, Code: code, Message: message}
|
||||
case code == 411:
|
||||
return &Error{Type: ErrorTypeGroupNotFound, Code: code, Message: message}
|
||||
case code == 502 || code == 503:
|
||||
return &Error{Type: ErrorTypePermissionDenied, Code: code, Message: message}
|
||||
case code == 481 || code == 482:
|
||||
return &Error{Type: ErrorTypeAuthentication, Code: code, Message: message}
|
||||
case code == 400:
|
||||
return &Error{Type: ErrorTypeServerBusy, Code: code, Message: message}
|
||||
case code == 500 || code == 501:
|
||||
return &Error{Type: ErrorTypeInvalidCommand, Code: code, Message: message}
|
||||
case code >= 400:
|
||||
return &Error{Type: ErrorTypeProtocol, Code: code, Message: message}
|
||||
default:
|
||||
return &Error{Type: ErrorTypeUnknown, Code: code, Message: message}
|
||||
}
|
||||
}
|
||||
|
||||
func IsArticleNotFoundError(err error) bool {
|
||||
var nntpErr *Error
|
||||
if errors.As(err, &nntpErr) {
|
||||
return nntpErr.Type == ErrorTypeArticleNotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsAuthenticationError(err error) bool {
|
||||
var nntpErr *Error
|
||||
if errors.As(err, &nntpErr) {
|
||||
return nntpErr.Type == ErrorTypeAuthentication
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsRetryableError(err error) bool {
|
||||
var nntpErr *Error
|
||||
if errors.As(err, &nntpErr) {
|
||||
return nntpErr.IsRetryable()
|
||||
}
|
||||
return false
|
||||
}
|
||||
299
internal/nntp/pool.go
Normal file
299
internal/nntp/pool.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sirrobot01/decypharr/internal/config"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pool manages a pool of NNTP connections
|
||||
type Pool struct {
|
||||
address, username, password string
|
||||
maxConns, port int
|
||||
ssl bool
|
||||
useTLS bool
|
||||
connections chan *Connection
|
||||
logger zerolog.Logger
|
||||
closed atomic.Bool
|
||||
totalConnections atomic.Int32
|
||||
activeConnections atomic.Int32
|
||||
}
|
||||
|
||||
// Segment represents a usenet segment
|
||||
type Segment struct {
|
||||
MessageID string
|
||||
Number int
|
||||
Bytes int64
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Article represents a complete usenet article
|
||||
type Article struct {
|
||||
MessageID string
|
||||
Subject string
|
||||
From string
|
||||
Date string
|
||||
Groups []string
|
||||
Body []byte
|
||||
Size int64
|
||||
}
|
||||
|
||||
// Response represents an NNTP server response
|
||||
type Response struct {
|
||||
Code int
|
||||
Message string
|
||||
Lines []string
|
||||
}
|
||||
|
||||
// GroupInfo represents information about a newsgroup
|
||||
type GroupInfo struct {
|
||||
Name string
|
||||
Count int // Number of articles in the group
|
||||
Low int // Lowest article number
|
||||
High int // Highest article number
|
||||
}
|
||||
|
||||
// NewPool creates a new NNTP connection pool
|
||||
func NewPool(provider config.UsenetProvider, logger zerolog.Logger) (*Pool, error) {
|
||||
maxConns := provider.Connections
|
||||
if maxConns <= 0 {
|
||||
maxConns = 1
|
||||
}
|
||||
|
||||
pool := &Pool{
|
||||
address: provider.Host,
|
||||
username: provider.Username,
|
||||
password: provider.Password,
|
||||
port: provider.Port,
|
||||
maxConns: maxConns,
|
||||
ssl: provider.SSL,
|
||||
useTLS: provider.UseTLS,
|
||||
connections: make(chan *Connection, maxConns),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return pool.initializeConnections()
|
||||
}
|
||||
|
||||
func (p *Pool) initializeConnections() (*Pool, error) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
var successfulConnections []*Connection
|
||||
var errs []error
|
||||
|
||||
// Create connections concurrently
|
||||
for i := 0; i < p.maxConns; i++ {
|
||||
wg.Add(1)
|
||||
go func(connIndex int) {
|
||||
defer wg.Done()
|
||||
|
||||
conn, err := p.createConnection()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
} else {
|
||||
successfulConnections = append(successfulConnections, conn)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all connection attempts to complete
|
||||
wg.Wait()
|
||||
|
||||
// Add successful connections to the pool
|
||||
for _, conn := range successfulConnections {
|
||||
p.connections <- conn
|
||||
}
|
||||
p.totalConnections.Store(int32(len(successfulConnections)))
|
||||
|
||||
if len(successfulConnections) == 0 {
|
||||
return nil, fmt.Errorf("failed to create any connections: %v", errs)
|
||||
}
|
||||
|
||||
// Log results
|
||||
p.logger.Info().
|
||||
Str("server", p.address).
|
||||
Int("port", p.port).
|
||||
Int("requested_connections", p.maxConns).
|
||||
Int("successful_connections", len(successfulConnections)).
|
||||
Int("failed_connections", len(errs)).
|
||||
Msg("NNTP connection pool created")
|
||||
|
||||
// If some connections failed, log a warning but continue
|
||||
if len(errs) > 0 {
|
||||
p.logger.Warn().
|
||||
Int("failed_count", len(errs)).
|
||||
Msg("Some connections failed during pool initialization")
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Get retrieves a connection from the pool
|
||||
func (p *Pool) Get(ctx context.Context) (*Connection, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, NewConnectionError(fmt.Errorf("connection pool is closed"))
|
||||
}
|
||||
|
||||
select {
|
||||
case conn := <-p.connections:
|
||||
if conn == nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("received nil connection from pool"))
|
||||
}
|
||||
p.activeConnections.Add(1)
|
||||
|
||||
if err := conn.ping(); err != nil {
|
||||
p.activeConnections.Add(-1)
|
||||
err := conn.close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create a new connection
|
||||
newConn, err := p.createConnection()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to create replacement connection: %w", err))
|
||||
}
|
||||
p.activeConnections.Add(1)
|
||||
return newConn, nil
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, NewTimeoutError(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// Put returns a connection to the pool
|
||||
func (p *Pool) Put(conn *Connection) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer p.activeConnections.Add(-1)
|
||||
|
||||
if p.closed.Load() {
|
||||
conn.close()
|
||||
return
|
||||
}
|
||||
|
||||
// Try non-blocking first
|
||||
select {
|
||||
case p.connections <- conn:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// If pool is full, this usually means we have too many connections
|
||||
// Force return by making space (close oldest connection)
|
||||
select {
|
||||
case oldConn := <-p.connections:
|
||||
oldConn.close() // Close the old connection
|
||||
p.connections <- conn // Put the new one back
|
||||
case <-time.After(1 * time.Second):
|
||||
// Still can't return - close this connection
|
||||
conn.close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all connections in the pool
|
||||
func (p *Pool) Close() error {
|
||||
|
||||
if p.closed.Load() {
|
||||
return nil
|
||||
}
|
||||
p.closed.Store(true)
|
||||
|
||||
close(p.connections)
|
||||
for conn := range p.connections {
|
||||
err := conn.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.Info().Msg("NNTP connection pool closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createConnection creates a new NNTP connection with proper error handling
|
||||
func (p *Pool) createConnection() (*Connection, error) {
|
||||
addr := fmt.Sprintf("%s:%d", p.address, p.port)
|
||||
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
||||
if p.ssl {
|
||||
conn, err = tls.DialWithDialer(&net.Dialer{}, "tcp", addr, &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
})
|
||||
} else {
|
||||
conn, err = net.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to connect to %s: %w", addr, err))
|
||||
}
|
||||
|
||||
reader := bufio.NewReaderSize(conn, 256*1024) // 256KB buffer for better performance
|
||||
writer := bufio.NewWriterSize(conn, 256*1024) // 256KB buffer for better performance
|
||||
text := textproto.NewConn(conn)
|
||||
|
||||
nntpConn := &Connection{
|
||||
username: p.username,
|
||||
password: p.password,
|
||||
address: p.address,
|
||||
port: p.port,
|
||||
conn: conn,
|
||||
text: text,
|
||||
reader: reader,
|
||||
writer: writer,
|
||||
logger: p.logger,
|
||||
}
|
||||
|
||||
// Read welcome message
|
||||
_, err = nntpConn.readResponse()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read welcome message: %w", err))
|
||||
}
|
||||
|
||||
// Authenticate if credentials are provided
|
||||
if p.username != "" && p.password != "" {
|
||||
if err := nntpConn.authenticate(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err // authenticate() already returns NNTPError
|
||||
}
|
||||
}
|
||||
|
||||
// Enable TLS if requested (STARTTLS)
|
||||
if p.useTLS && !p.ssl {
|
||||
if err := nntpConn.startTLS(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err // startTLS() already returns NNTPError
|
||||
}
|
||||
}
|
||||
return nntpConn, nil
|
||||
}
|
||||
|
||||
func (p *Pool) ConnectionCount() int {
|
||||
return int(p.totalConnections.Load())
|
||||
}
|
||||
|
||||
func (p *Pool) ActiveConnections() int {
|
||||
return int(p.activeConnections.Load())
|
||||
}
|
||||
|
||||
func (p *Pool) IsFree() bool {
|
||||
return p.ActiveConnections() < p.maxConns
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sirrobot01/decypharr/internal/logger"
|
||||
@@ -180,8 +179,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
|
||||
resp, err = c.doRequest(req)
|
||||
if err != nil {
|
||||
// Check if this is a network error that might be worth retrying
|
||||
if isRetryableError(err) && attempt < c.maxRetries {
|
||||
if attempt < c.maxRetries {
|
||||
// Apply backoff with jitter
|
||||
jitter := time.Duration(rand.Int63n(int64(backoff / 4)))
|
||||
sleepTime := backoff + jitter
|
||||
@@ -390,30 +388,3 @@ func Default() *Client {
|
||||
})
|
||||
return instance
|
||||
}
|
||||
|
||||
func isRetryableError(err error) bool {
|
||||
errString := err.Error()
|
||||
|
||||
// Connection reset and other network errors
|
||||
if strings.Contains(errString, "connection reset by peer") ||
|
||||
strings.Contains(errString, "read: connection reset") ||
|
||||
strings.Contains(errString, "connection refused") ||
|
||||
strings.Contains(errString, "network is unreachable") ||
|
||||
strings.Contains(errString, "connection timed out") ||
|
||||
strings.Contains(errString, "no such host") ||
|
||||
strings.Contains(errString, "i/o timeout") ||
|
||||
strings.Contains(errString, "unexpected EOF") ||
|
||||
strings.Contains(errString, "TLS handshake timeout") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for net.Error type which can provide more information
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
// Retry on timeout errors and temporary errors
|
||||
return netErr.Timeout()
|
||||
}
|
||||
|
||||
// Not a retryable error
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,5 +1,16 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RemoveItem[S ~[]E, E comparable](s S, values ...E) S {
|
||||
result := make(S, 0, len(s))
|
||||
outer:
|
||||
@@ -22,3 +33,131 @@ func Contains(slice []string, value string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func GenerateHash(data string) string {
|
||||
// Simple hash generation using a basic algorithm (for demonstration purposes)
|
||||
_hash := 0
|
||||
for _, char := range data {
|
||||
_hash = (_hash*31 + int(char)) % 1000003 // Simple hash function
|
||||
}
|
||||
return string(rune(_hash))
|
||||
}
|
||||
|
||||
func DownloadFile(url string) (string, []byte, error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to download file: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", nil, fmt.Errorf("failed to download file: status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
filename := getFilenameFromResponse(resp, url)
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
return filename, data, nil
|
||||
}
|
||||
|
||||
func getFilenameFromResponse(resp *http.Response, originalURL string) string {
|
||||
// 1. Try Content-Disposition header
|
||||
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
|
||||
if _, params, err := mime.ParseMediaType(cd); err == nil {
|
||||
if filename := params["filename"]; filename != "" {
|
||||
return filename
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Try to decode URL-encoded filename from Content-Disposition
|
||||
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
|
||||
if strings.Contains(cd, "filename*=") {
|
||||
// Handle RFC 5987 encoded filenames
|
||||
parts := strings.Split(cd, "filename*=")
|
||||
if len(parts) > 1 {
|
||||
encoded := strings.Trim(parts[1], `"`)
|
||||
if strings.HasPrefix(encoded, "UTF-8''") {
|
||||
if decoded, err := url.QueryUnescape(encoded[7:]); err == nil {
|
||||
return decoded
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fall back to URL path
|
||||
if parsedURL, err := url.Parse(originalURL); err == nil {
|
||||
if filename := filepath.Base(parsedURL.Path); filename != "." && filename != "/" {
|
||||
// URL decode the filename
|
||||
if decoded, err := url.QueryUnescape(filename); err == nil {
|
||||
return decoded
|
||||
}
|
||||
return filename
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Default filename
|
||||
return "downloaded_file"
|
||||
}
|
||||
|
||||
func ValidateServiceURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("URL cannot be empty")
|
||||
}
|
||||
|
||||
// Try parsing as full URL first
|
||||
u, err := url.Parse(urlStr)
|
||||
if err == nil && u.Scheme != "" && u.Host != "" {
|
||||
// It's a full URL, validate scheme
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return fmt.Errorf("URL scheme must be http or https")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if it's a host:port format (no scheme)
|
||||
if strings.Contains(urlStr, ":") && !strings.Contains(urlStr, "://") {
|
||||
// Try parsing with http:// prefix
|
||||
testURL := "http://" + urlStr
|
||||
u, err := url.Parse(testURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid host:port format: %w", err)
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("host is required in host:port format")
|
||||
}
|
||||
|
||||
// Validate port number
|
||||
if u.Port() == "" {
|
||||
return fmt.Errorf("port is required in host:port format")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid URL format: %s", urlStr)
|
||||
}
|
||||
|
||||
func ExtractFilenameFromURL(rawURL string) string {
|
||||
// Parse the URL
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Get the base filename from path
|
||||
filename := path.Base(parsedURL.Path)
|
||||
|
||||
// Handle edge cases
|
||||
if filename == "/" || filename == "." || filename == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return filename
|
||||
}
|
||||
|
||||
@@ -57,3 +57,15 @@ func IsSampleFile(path string) bool {
|
||||
}
|
||||
return RegexMatch(sampleRegex, path)
|
||||
}
|
||||
|
||||
func IsParFile(path string) bool {
|
||||
ext := filepath.Ext(path)
|
||||
return strings.EqualFold(ext, ".par") || strings.EqualFold(ext, ".par2")
|
||||
}
|
||||
|
||||
func IsRarFile(path string) bool {
|
||||
ext := filepath.Ext(path)
|
||||
return strings.EqualFold(ext, ".rar") || strings.EqualFold(ext, ".r00") ||
|
||||
strings.EqualFold(ext, ".r01") || strings.EqualFold(ext, ".r02") ||
|
||||
strings.EqualFold(ext, ".r03") || strings.EqualFold(ext, ".r04")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user