Implementing a streaming setup with Usenet

This commit is contained in:
Mukhtar Akere
2025-08-01 15:27:24 +01:00
parent afe577bf2f
commit f9861e3b54
65 changed files with 9437 additions and 924 deletions

View File

@@ -47,7 +47,6 @@ type Debrid struct {
type QBitTorrent struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Port string `json:"port,omitempty"` // deprecated
DownloadFolder string `json:"download_folder,omitempty"`
Categories []string `json:"categories,omitempty"`
RefreshInterval int `json:"refresh_interval,omitempty"`
@@ -82,26 +81,55 @@ type Auth struct {
Password string `json:"password,omitempty"`
}
type SABnzbd struct {
DownloadFolder string `json:"download_folder,omitempty"`
RefreshInterval int `json:"refresh_interval,omitempty"`
Categories []string `json:"categories,omitempty"`
}
type Usenet struct {
Providers []UsenetProvider `json:"providers,omitempty"` // List of usenet providers
MountFolder string `json:"mount_folder,omitempty"` // Folder where usenet downloads are mounted
SkipPreCache bool `json:"skip_pre_cache,omitempty"`
Chunks int `json:"chunks,omitempty"` // Number of chunks to pre-cache
RcUrl string `json:"rc_url,omitempty"` // Rclone RC URL for the webdav
RcUser string `json:"rc_user,omitempty"` // Rclone RC username
RcPass string `json:"rc_pass,omitempty"` // Rclone RC password
}
type UsenetProvider struct {
Name string `json:"name,omitempty"`
Host string `json:"host,omitempty"` // Host of the usenet server
Port int `json:"port,omitempty"` // Port of the usenet server
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Connections int `json:"connections,omitempty"` // Number of connections to use
SSL bool `json:"ssl,omitempty"` // Use SSL for the connection
UseTLS bool `json:"use_tls,omitempty"` // Use TLS for the connection
}
type Config struct {
// server
BindAddress string `json:"bind_address,omitempty"`
URLBase string `json:"url_base,omitempty"`
Port string `json:"port,omitempty"`
LogLevel string `json:"log_level,omitempty"`
Debrids []Debrid `json:"debrids,omitempty"`
QBitTorrent QBitTorrent `json:"qbittorrent,omitempty"`
Arrs []Arr `json:"arrs,omitempty"`
Repair Repair `json:"repair,omitempty"`
WebDav WebDav `json:"webdav,omitempty"`
AllowedExt []string `json:"allowed_file_types,omitempty"`
MinFileSize string `json:"min_file_size,omitempty"` // Minimum file size to download, 10MB, 1GB, etc
MaxFileSize string `json:"max_file_size,omitempty"` // Maximum file size to download (0 means no limit)
Path string `json:"-"` // Path to save the config file
UseAuth bool `json:"use_auth,omitempty"`
Auth *Auth `json:"-"`
DiscordWebhook string `json:"discord_webhook_url,omitempty"`
RemoveStalledAfter string `json:"remove_stalled_after,omitzero"`
LogLevel string `json:"log_level,omitempty"`
Debrids []Debrid `json:"debrids,omitempty"`
QBitTorrent *QBitTorrent `json:"qbittorrent,omitempty"`
SABnzbd *SABnzbd `json:"sabnzbd,omitempty"`
Usenet *Usenet `json:"usenet,omitempty"` // Usenet configuration
Arrs []Arr `json:"arrs,omitempty"`
Repair Repair `json:"repair,omitempty"`
WebDav WebDav `json:"webdav,omitempty"`
AllowedExt []string `json:"allowed_file_types,omitempty"`
MinFileSize string `json:"min_file_size,omitempty"` // Minimum file size to download, 10MB, 1GB, etc
MaxFileSize string `json:"max_file_size,omitempty"` // Maximum file size to download (0 means no limit)
Path string `json:"-"` // Path to save the config file
UseAuth bool `json:"use_auth,omitempty"`
Auth *Auth `json:"-"`
DiscordWebhook string `json:"discord_webhook_url,omitempty"`
RemoveStalledAfter string `json:"remove_stalled_after,omitzero"`
}
func (c *Config) JsonFile() string {
@@ -115,6 +143,10 @@ func (c *Config) TorrentsFile() string {
return filepath.Join(c.Path, "torrents.json")
}
func (c *Config) NZBsPath() string {
return filepath.Join(c.Path, "cache/nzbs")
}
func (c *Config) loadConfig() error {
// Load the config file
if configPath == "" {
@@ -142,9 +174,6 @@ func (c *Config) loadConfig() error {
}
func validateDebrids(debrids []Debrid) error {
if len(debrids) == 0 {
return errors.New("no debrids configured")
}
for _, debrid := range debrids {
// Basic field validation
@@ -159,17 +188,51 @@ func validateDebrids(debrids []Debrid) error {
return nil
}
func validateQbitTorrent(config *QBitTorrent) error {
if config.DownloadFolder == "" {
return errors.New("qbittorent download folder is required")
func validateUsenet(usenet *Usenet) error {
if usenet == nil {
return nil // No usenet configuration provided
}
if _, err := os.Stat(config.DownloadFolder); os.IsNotExist(err) {
return fmt.Errorf("qbittorent download folder(%s) does not exist", config.DownloadFolder)
for _, usenet := range usenet.Providers {
// Basic field validation
if usenet.Host == "" {
return errors.New("usenet host is required")
}
if usenet.Username == "" {
return errors.New("usenet username is required")
}
if usenet.Password == "" {
return errors.New("usenet password is required")
}
}
return nil
}
func validateSabznbd(config *SABnzbd) error {
if config == nil {
return nil // No SABnzbd configuration provided
}
if config.DownloadFolder != "" {
if _, err := os.Stat(config.DownloadFolder); os.IsNotExist(err) {
return fmt.Errorf("sabnzbd download folder(%s) does not exist", config.DownloadFolder)
}
}
return nil
}
func validateRepair(config *Repair) error {
func validateQbitTorrent(config *QBitTorrent) error {
if config == nil {
return nil // No qBittorrent configuration provided
}
if config.DownloadFolder != "" {
if _, err := os.Stat(config.DownloadFolder); os.IsNotExist(err) {
return fmt.Errorf("qbittorent download folder(%s) does not exist", config.DownloadFolder)
}
}
return nil
}
func validateRepair(config Repair) error {
if !config.Enabled {
return nil
}
@@ -181,19 +244,34 @@ func validateRepair(config *Repair) error {
func ValidateConfig(config *Config) error {
// Run validations concurrently
// Check if there's at least one debrid or usenet configured
hasUsenet := false
if config.Usenet != nil && len(config.Usenet.Providers) > 0 {
hasUsenet = true
}
if len(config.Debrids) == 0 && !hasUsenet {
return errors.New("at least one debrid or usenet provider must be configured")
}
if err := validateDebrids(config.Debrids); err != nil {
return err
}
if err := validateQbitTorrent(&config.QBitTorrent); err != nil {
if err := validateUsenet(config.Usenet); err != nil {
return err
}
if err := validateRepair(&config.Repair); err != nil {
if err := validateSabznbd(config.SABnzbd); err != nil {
return err
}
if err := validateQbitTorrent(config.QBitTorrent); err != nil {
return err
}
if err := validateRepair(config.Repair); err != nil {
return err
}
return nil
}
@@ -299,6 +377,10 @@ func (c *Config) updateDebrid(d Debrid) Debrid {
}
d.DownloadAPIKeys = downloadKeys
if d.Workers == 0 {
d.Workers = perDebrid
}
if !d.UseWebDav {
return d
}
@@ -309,9 +391,6 @@ func (c *Config) updateDebrid(d Debrid) Debrid {
if d.WebDav.DownloadLinksRefreshInterval == "" {
d.DownloadLinksRefreshInterval = cmp.Or(c.WebDav.DownloadLinksRefreshInterval, "40m") // 40 minutes
}
if d.Workers == 0 {
d.Workers = perDebrid
}
if d.FolderNaming == "" {
d.FolderNaming = cmp.Or(c.WebDav.FolderNaming, "original_no_ext")
}
@@ -338,17 +417,47 @@ func (c *Config) updateDebrid(d Debrid) Debrid {
return d
}
func (c *Config) updateUsenet(u UsenetProvider) UsenetProvider {
if u.Name == "" {
parts := strings.Split(u.Host, ".")
if len(parts) >= 2 {
u.Name = parts[len(parts)-2] // Gets "example" from "news.example.com"
} else {
u.Name = u.Host // Fallback to host if it doesn't look like a domain
}
}
if u.Port == 0 {
u.Port = 119 // Default port for usenet
}
if u.Connections == 0 {
u.Connections = 30 // Default connections
}
if u.SSL && !u.UseTLS {
u.UseTLS = true // Use TLS if SSL is enabled
}
return u
}
func (c *Config) setDefaults() {
for i, debrid := range c.Debrids {
c.Debrids[i] = c.updateDebrid(debrid)
}
if c.SABnzbd != nil {
c.SABnzbd.RefreshInterval = cmp.Or(c.SABnzbd.RefreshInterval, 10) // Default to 10 seconds
}
if c.Usenet != nil {
c.Usenet.Chunks = cmp.Or(c.Usenet.Chunks, 5)
for i, provider := range c.Usenet.Providers {
c.Usenet.Providers[i] = c.updateUsenet(provider)
}
}
if len(c.AllowedExt) == 0 {
c.AllowedExt = getDefaultExtensions()
}
c.Port = cmp.Or(c.Port, c.QBitTorrent.Port)
if c.URLBase == "" {
c.URLBase = "/"
}
@@ -395,11 +504,6 @@ func (c *Config) createConfig(path string) error {
c.Port = "8282"
c.LogLevel = "info"
c.UseAuth = true
c.QBitTorrent = QBitTorrent{
DownloadFolder: filepath.Join(path, "downloads"),
Categories: []string{"sonarr", "radarr"},
RefreshInterval: 15,
}
return nil
}
@@ -408,7 +512,3 @@ func Reload() {
instance = nil
once = sync.Once{}
}
func DefaultFreeSlot() int {
return 10
}

178
internal/nntp/client.go Normal file
View File

@@ -0,0 +1,178 @@
package nntp
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/sirrobot01/decypharr/internal/config"
"github.com/sirrobot01/decypharr/internal/logger"
"sync/atomic"
"time"
)
// Client represents a failover NNTP client that manages multiple providers
type Client struct {
providers []config.UsenetProvider
pools *xsync.Map[string, *Pool]
logger zerolog.Logger
closed atomic.Bool
minimumMaxConns int // Minimum number of max connections across all pools
}
func NewClient(providers []config.UsenetProvider) (*Client, error) {
client := &Client{
providers: providers,
logger: logger.New("nntp"),
pools: xsync.NewMap[string, *Pool](),
}
if len(providers) == 0 {
return nil, fmt.Errorf("no NNTP providers configured")
}
return client, nil
}
func (c *Client) InitPools() error {
var initErrors []error
successfulPools := 0
for _, provider := range c.providers {
serverPool, err := NewPool(provider, c.logger)
if err != nil {
c.logger.Error().
Err(err).
Str("server", provider.Host).
Int("port", provider.Port).
Msg("Failed to initialize server pool")
initErrors = append(initErrors, err)
continue
}
if c.minimumMaxConns == 0 {
// Set minimumMaxConns to the max connections of the first successful pool
c.minimumMaxConns = serverPool.ConnectionCount()
} else {
c.minimumMaxConns = min(c.minimumMaxConns, serverPool.ConnectionCount())
}
c.pools.Store(provider.Name, serverPool)
successfulPools++
}
if successfulPools == 0 {
return fmt.Errorf("failed to initialize any server pools: %v", initErrors)
}
c.logger.Info().
Int("providers", len(c.providers)).
Msg("NNTP client created")
return nil
}
func (c *Client) Close() {
if c.closed.Load() {
c.logger.Warn().Msg("NNTP client already closed")
return
}
c.pools.Range(func(key string, value *Pool) bool {
if value != nil {
err := value.Close()
if err != nil {
return false
}
}
return true
})
c.closed.Store(true)
c.logger.Info().Msg("NNTP client closed")
}
func (c *Client) GetConnection(ctx context.Context) (*Connection, func(), error) {
if c.closed.Load() {
return nil, nil, fmt.Errorf("nntp client is closed")
}
// Prevent workers from waiting too long for connections
connCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
providerCount := len(c.providers)
for _, provider := range c.providers {
pool, ok := c.pools.Load(provider.Name)
if !ok {
return nil, nil, fmt.Errorf("no pool found for provider %s", provider.Name)
}
if !pool.IsFree() && providerCount > 1 {
continue
}
conn, err := pool.Get(connCtx) // Use timeout context
if err != nil {
if errors.Is(err, ErrNoAvailableConnection) || errors.Is(err, context.DeadlineExceeded) {
continue
}
return nil, nil, fmt.Errorf("error getting connection from provider %s: %w", provider.Name, err)
}
if conn == nil {
continue
}
return conn, func() { pool.Put(conn) }, nil
}
return nil, nil, ErrNoAvailableConnection
}
func (c *Client) DownloadHeader(ctx context.Context, messageID string) (*YencMetadata, error) {
conn, cleanup, err := c.GetConnection(ctx)
if err != nil {
return nil, err
}
defer cleanup()
data, err := conn.GetBody(messageID)
if err != nil {
return nil, err
}
// yEnc decode
part, err := DecodeYencHeaders(bytes.NewReader(data))
if err != nil || part == nil {
return nil, fmt.Errorf("failed to decode segment")
}
// Return both the filename and decoded data
return part, nil
}
func (c *Client) MinimumMaxConns() int {
return c.minimumMaxConns
}
func (c *Client) TotalActiveConnections() int {
total := 0
c.pools.Range(func(key string, value *Pool) bool {
if value != nil {
total += value.ActiveConnections()
}
return true
})
return total
}
func (c *Client) Pools() *xsync.Map[string, *Pool] {
return c.pools
}
func (c *Client) GetProviders() []config.UsenetProvider {
return c.providers
}

394
internal/nntp/conns.go Normal file
View File

@@ -0,0 +1,394 @@
package nntp
import (
"bufio"
"crypto/tls"
"fmt"
"github.com/chrisfarms/yenc"
"github.com/rs/zerolog"
"io"
"net"
"net/textproto"
"strconv"
"strings"
)
// Connection represents an NNTP connection
type Connection struct {
username, password, address string
port int
conn net.Conn
text *textproto.Conn
reader *bufio.Reader
writer *bufio.Writer
logger zerolog.Logger
}
func (c *Connection) authenticate() error {
// Send AUTHINFO USER command
if err := c.sendCommand(fmt.Sprintf("AUTHINFO USER %s", c.username)); err != nil {
return NewConnectionError(fmt.Errorf("failed to send username: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return NewConnectionError(fmt.Errorf("failed to read user response: %w", err))
}
if resp.Code != 381 {
return classifyNNTPError(resp.Code, fmt.Sprintf("unexpected response to AUTHINFO USER: %s", resp.Message))
}
// Send AUTHINFO PASS command
if err := c.sendCommand(fmt.Sprintf("AUTHINFO PASS %s", c.password)); err != nil {
return NewConnectionError(fmt.Errorf("failed to send password: %w", err))
}
resp, err = c.readResponse()
if err != nil {
return NewConnectionError(fmt.Errorf("failed to read password response: %w", err))
}
if resp.Code != 281 {
return classifyNNTPError(resp.Code, fmt.Sprintf("authentication failed: %s", resp.Message))
}
return nil
}
// startTLS initiates TLS encryption with proper error handling
func (c *Connection) startTLS() error {
if err := c.sendCommand("STARTTLS"); err != nil {
return NewConnectionError(fmt.Errorf("failed to send STARTTLS: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return NewConnectionError(fmt.Errorf("failed to read STARTTLS response: %w", err))
}
if resp.Code != 382 {
return classifyNNTPError(resp.Code, fmt.Sprintf("STARTTLS not supported: %s", resp.Message))
}
// Upgrade connection to TLS
tlsConn := tls.Client(c.conn, &tls.Config{
ServerName: c.address,
InsecureSkipVerify: false,
})
c.conn = tlsConn
c.reader = bufio.NewReader(tlsConn)
c.writer = bufio.NewWriter(tlsConn)
c.text = textproto.NewConn(tlsConn)
c.logger.Debug().Msg("TLS encryption enabled")
return nil
}
// ping sends a simple command to test the connection
func (c *Connection) ping() error {
if err := c.sendCommand("DATE"); err != nil {
return NewConnectionError(err)
}
_, err := c.readResponse()
if err != nil {
return NewConnectionError(err)
}
return nil
}
// sendCommand sends a command to the NNTP server
func (c *Connection) sendCommand(command string) error {
_, err := fmt.Fprintf(c.writer, "%s\r\n", command)
if err != nil {
return err
}
return c.writer.Flush()
}
// readResponse reads a response from the NNTP server
func (c *Connection) readResponse() (*Response, error) {
line, err := c.text.ReadLine()
if err != nil {
return nil, err
}
parts := strings.SplitN(line, " ", 2)
code, err := strconv.Atoi(parts[0])
if err != nil {
return nil, fmt.Errorf("invalid response code: %s", parts[0])
}
message := ""
if len(parts) > 1 {
message = parts[1]
}
return &Response{
Code: code,
Message: message,
}, nil
}
// readMultilineResponse reads a multiline response
func (c *Connection) readMultilineResponse() (*Response, error) {
resp, err := c.readResponse()
if err != nil {
return nil, err
}
// Check if this is a multiline response
if resp.Code < 200 || resp.Code >= 300 {
return resp, nil
}
lines, err := c.text.ReadDotLines()
if err != nil {
return nil, err
}
resp.Lines = lines
return resp, nil
}
// GetArticle retrieves an article by message ID with proper error classification
func (c *Connection) GetArticle(messageID string) (*Article, error) {
messageID = FormatMessageID(messageID)
if err := c.sendCommand(fmt.Sprintf("ARTICLE %s", messageID)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send ARTICLE command: %w", err))
}
resp, err := c.readMultilineResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read article response: %w", err))
}
if resp.Code != 220 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
return c.parseArticle(messageID, resp.Lines)
}
// GetBody retrieves article body by message ID with proper error classification
func (c *Connection) GetBody(messageID string) ([]byte, error) {
messageID = FormatMessageID(messageID)
if err := c.sendCommand(fmt.Sprintf("BODY %s", messageID)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send BODY command: %w", err))
}
// Read the initial response
resp, err := c.readResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read body response: %w", err))
}
if resp.Code != 222 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
// Read the raw body data directly using textproto to preserve exact formatting for yEnc
lines, err := c.text.ReadDotLines()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read body data: %w", err))
}
// Join with \r\n to preserve original line endings and add final \r\n
body := strings.Join(lines, "\r\n")
if len(lines) > 0 {
body += "\r\n"
}
return []byte(body), nil
}
// GetHead retrieves article headers by message ID
func (c *Connection) GetHead(messageID string) ([]byte, error) {
messageID = FormatMessageID(messageID)
if err := c.sendCommand(fmt.Sprintf("HEAD %s", messageID)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send HEAD command: %w", err))
}
// Read the initial response
resp, err := c.readResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read head response: %w", err))
}
if resp.Code != 221 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
// Read the header data using textproto
lines, err := c.text.ReadDotLines()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read header data: %w", err))
}
// Join with \r\n to preserve original line endings and add final \r\n
headers := strings.Join(lines, "\r\n")
if len(lines) > 0 {
headers += "\r\n"
}
return []byte(headers), nil
}
// GetSegment retrieves a specific segment with proper error handling
func (c *Connection) GetSegment(messageID string, segmentNumber int) (*Segment, error) {
messageID = FormatMessageID(messageID)
body, err := c.GetBody(messageID)
if err != nil {
return nil, err // GetBody already returns classified errors
}
return &Segment{
MessageID: messageID,
Number: segmentNumber,
Bytes: int64(len(body)),
Data: body,
}, nil
}
// Stat retrieves article statistics by message ID with proper error classification
func (c *Connection) Stat(messageID string) (articleNumber int, echoedID string, err error) {
messageID = FormatMessageID(messageID)
if err = c.sendCommand(fmt.Sprintf("STAT %s", messageID)); err != nil {
return 0, "", NewConnectionError(fmt.Errorf("failed to send STAT: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return 0, "", NewConnectionError(fmt.Errorf("failed to read STAT response: %w", err))
}
if resp.Code != 223 {
return 0, "", classifyNNTPError(resp.Code, resp.Message)
}
fields := strings.Fields(resp.Message)
if len(fields) < 2 {
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("unexpected STAT response format: %q", resp.Message))
}
if articleNumber, err = strconv.Atoi(fields[0]); err != nil {
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("invalid article number %q: %v", fields[0], err))
}
echoedID = fields[1]
return articleNumber, echoedID, nil
}
// SelectGroup selects a newsgroup and returns group information
func (c *Connection) SelectGroup(groupName string) (*GroupInfo, error) {
if err := c.sendCommand(fmt.Sprintf("GROUP %s", groupName)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send GROUP command: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read GROUP response: %w", err))
}
if resp.Code != 211 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
// Parse GROUP response: "211 number low high group-name"
fields := strings.Fields(resp.Message)
if len(fields) < 4 {
return nil, NewProtocolError(resp.Code, fmt.Sprintf("unexpected GROUP response format: %q", resp.Message))
}
groupInfo := &GroupInfo{
Name: groupName,
}
if count, err := strconv.Atoi(fields[0]); err == nil {
groupInfo.Count = count
}
if low, err := strconv.Atoi(fields[1]); err == nil {
groupInfo.Low = low
}
if high, err := strconv.Atoi(fields[2]); err == nil {
groupInfo.High = high
}
return groupInfo, nil
}
// parseArticle parses article data from response lines
func (c *Connection) parseArticle(messageID string, lines []string) (*Article, error) {
article := &Article{
MessageID: messageID,
Groups: []string{},
}
headerEnd := -1
for i, line := range lines {
if line == "" {
headerEnd = i
break
}
// Parse headers
if strings.HasPrefix(line, "Subject: ") {
article.Subject = strings.TrimPrefix(line, "Subject: ")
} else if strings.HasPrefix(line, "From: ") {
article.From = strings.TrimPrefix(line, "From: ")
} else if strings.HasPrefix(line, "Date: ") {
article.Date = strings.TrimPrefix(line, "Date: ")
} else if strings.HasPrefix(line, "Newsgroups: ") {
groups := strings.TrimPrefix(line, "Newsgroups: ")
article.Groups = strings.Split(groups, ",")
for i := range article.Groups {
article.Groups[i] = strings.TrimSpace(article.Groups[i])
}
}
}
// Join body lines
if headerEnd != -1 && headerEnd+1 < len(lines) {
body := strings.Join(lines[headerEnd+1:], "\n")
article.Body = []byte(body)
article.Size = int64(len(article.Body))
}
return article, nil
}
// close closes the NNTP connection
func (c *Connection) close() error {
if c.conn != nil {
return c.conn.Close()
}
return nil
}
func DecodeYenc(reader io.Reader) (*yenc.Part, error) {
part, err := yenc.Decode(reader)
if err != nil {
return nil, NewYencDecodeError(fmt.Errorf("failed to create yenc decoder: %w", err))
}
return part, nil
}
func IsValidMessageID(messageID string) bool {
if len(messageID) < 3 {
return false
}
return strings.Contains(messageID, "@")
}
// FormatMessageID ensures message ID has proper format
func FormatMessageID(messageID string) string {
messageID = strings.TrimSpace(messageID)
if !strings.HasPrefix(messageID, "<") {
messageID = "<" + messageID
}
if !strings.HasSuffix(messageID, ">") {
messageID = messageID + ">"
}
return messageID
}

116
internal/nntp/decoder.go Normal file
View File

@@ -0,0 +1,116 @@
package nntp
import (
"bufio"
"fmt"
"io"
"strconv"
"strings"
)
// YencMetadata contains just the header information
type YencMetadata struct {
Name string // filename
Size int64 // total file size
Part int // part number
Total int // total parts
Begin int64 // part start byte
End int64 // part end byte
LineSize int // line length
}
// DecodeYencHeaders extracts only yenc header metadata without decoding body
func DecodeYencHeaders(reader io.Reader) (*YencMetadata, error) {
buf := bufio.NewReader(reader)
metadata := &YencMetadata{}
// Find and parse =ybegin header
if err := parseYBeginHeader(buf, metadata); err != nil {
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ybegin header: %w", err))
}
// Parse =ypart header if this is a multipart file
if metadata.Part > 0 {
if err := parseYPartHeader(buf, metadata); err != nil {
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ypart header: %w", err))
}
}
return metadata, nil
}
func parseYBeginHeader(buf *bufio.Reader, metadata *YencMetadata) error {
var s string
var err error
// Find the =ybegin line
for {
s, err = buf.ReadString('\n')
if err != nil {
return err
}
if len(s) >= 7 && s[:7] == "=ybegin" {
break
}
}
// Parse the header line
parts := strings.SplitN(s[7:], "name=", 2)
if len(parts) > 1 {
metadata.Name = strings.TrimSpace(parts[1])
}
// Parse other parameters
for _, header := range strings.Split(parts[0], " ") {
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
if len(kv) < 2 {
continue
}
switch kv[0] {
case "size":
metadata.Size, _ = strconv.ParseInt(kv[1], 10, 64)
case "line":
metadata.LineSize, _ = strconv.Atoi(kv[1])
case "part":
metadata.Part, _ = strconv.Atoi(kv[1])
case "total":
metadata.Total, _ = strconv.Atoi(kv[1])
}
}
return nil
}
func parseYPartHeader(buf *bufio.Reader, metadata *YencMetadata) error {
var s string
var err error
// Find the =ypart line
for {
s, err = buf.ReadString('\n')
if err != nil {
return err
}
if len(s) >= 6 && s[:6] == "=ypart" {
break
}
}
// Parse part parameters
for _, header := range strings.Split(s[6:], " ") {
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
if len(kv) < 2 {
continue
}
switch kv[0] {
case "begin":
metadata.Begin, _ = strconv.ParseInt(kv[1], 10, 64)
case "end":
metadata.End, _ = strconv.ParseInt(kv[1], 10, 64)
}
}
return nil
}

195
internal/nntp/errors.go Normal file
View File

@@ -0,0 +1,195 @@
package nntp
import (
"errors"
"fmt"
)
// Error types for NNTP operations
type ErrorType int
const (
ErrorTypeUnknown ErrorType = iota
ErrorTypeConnection
ErrorTypeAuthentication
ErrorTypeTimeout
ErrorTypeArticleNotFound
ErrorTypeGroupNotFound
ErrorTypePermissionDenied
ErrorTypeServerBusy
ErrorTypeInvalidCommand
ErrorTypeProtocol
ErrorTypeYencDecode
ErrorTypeNoAvailableConnection
)
// Error represents an NNTP-specific error
type Error struct {
Type ErrorType
Code int // NNTP response code
Message string // Error message
Err error // Underlying error
}
// Predefined errors for common cases
var (
ErrArticleNotFound = &Error{Type: ErrorTypeArticleNotFound, Code: 430, Message: "article not found"}
ErrGroupNotFound = &Error{Type: ErrorTypeGroupNotFound, Code: 411, Message: "group not found"}
ErrPermissionDenied = &Error{Type: ErrorTypePermissionDenied, Code: 502, Message: "permission denied"}
ErrAuthenticationFail = &Error{Type: ErrorTypeAuthentication, Code: 482, Message: "authentication failed"}
ErrServerBusy = &Error{Type: ErrorTypeServerBusy, Code: 400, Message: "server busy"}
ErrPoolNotFound = &Error{Type: ErrorTypeUnknown, Code: 0, Message: "NNTP pool not found", Err: nil}
ErrNoAvailableConnection = &Error{Type: ErrorTypeNoAvailableConnection, Code: 0, Message: "no available connection in pool", Err: nil}
)
func (e *Error) Error() string {
if e.Err != nil {
return fmt.Sprintf("NNTP %s (code %d): %s - %v", e.Type.String(), e.Code, e.Message, e.Err)
}
return fmt.Sprintf("NNTP %s (code %d): %s", e.Type.String(), e.Code, e.Message)
}
func (e *Error) Unwrap() error {
return e.Err
}
func (e *Error) Is(target error) bool {
if t, ok := target.(*Error); ok {
return e.Type == t.Type
}
return false
}
// IsRetryable returns true if the error might be resolved by retrying
func (e *Error) IsRetryable() bool {
switch e.Type {
case ErrorTypeConnection, ErrorTypeTimeout, ErrorTypeServerBusy:
return true
case ErrorTypeArticleNotFound, ErrorTypeGroupNotFound, ErrorTypePermissionDenied, ErrorTypeAuthentication:
return false
default:
return false
}
}
// ShouldStopParsing returns true if this error should stop the entire parsing process
func (e *Error) ShouldStopParsing() bool {
switch e.Type {
case ErrorTypeAuthentication, ErrorTypePermissionDenied:
return true // Critical auth issues
case ErrorTypeConnection:
return false // Can continue with other connections
case ErrorTypeArticleNotFound:
return false // Can continue searching for other articles
case ErrorTypeServerBusy:
return false // Temporary issue
default:
return false
}
}
func (et ErrorType) String() string {
switch et {
case ErrorTypeConnection:
return "CONNECTION"
case ErrorTypeAuthentication:
return "AUTHENTICATION"
case ErrorTypeTimeout:
return "TIMEOUT"
case ErrorTypeArticleNotFound:
return "ARTICLE_NOT_FOUND"
case ErrorTypeGroupNotFound:
return "GROUP_NOT_FOUND"
case ErrorTypePermissionDenied:
return "PERMISSION_DENIED"
case ErrorTypeServerBusy:
return "SERVER_BUSY"
case ErrorTypeInvalidCommand:
return "INVALID_COMMAND"
case ErrorTypeProtocol:
return "PROTOCOL"
case ErrorTypeYencDecode:
return "YENC_DECODE"
default:
return "UNKNOWN"
}
}
// Helper functions to create specific errors
func NewConnectionError(err error) *Error {
return &Error{
Type: ErrorTypeConnection,
Message: "connection failed",
Err: err,
}
}
func NewTimeoutError(err error) *Error {
return &Error{
Type: ErrorTypeTimeout,
Message: "operation timed out",
Err: err,
}
}
func NewProtocolError(code int, message string) *Error {
return &Error{
Type: ErrorTypeProtocol,
Code: code,
Message: message,
}
}
func NewYencDecodeError(err error) *Error {
return &Error{
Type: ErrorTypeYencDecode,
Message: "yEnc decode failed",
Err: err,
}
}
// classifyNNTPError classifies an NNTP response code into an error type
func classifyNNTPError(code int, message string) *Error {
switch {
case code == 430 || code == 423:
return &Error{Type: ErrorTypeArticleNotFound, Code: code, Message: message}
case code == 411:
return &Error{Type: ErrorTypeGroupNotFound, Code: code, Message: message}
case code == 502 || code == 503:
return &Error{Type: ErrorTypePermissionDenied, Code: code, Message: message}
case code == 481 || code == 482:
return &Error{Type: ErrorTypeAuthentication, Code: code, Message: message}
case code == 400:
return &Error{Type: ErrorTypeServerBusy, Code: code, Message: message}
case code == 500 || code == 501:
return &Error{Type: ErrorTypeInvalidCommand, Code: code, Message: message}
case code >= 400:
return &Error{Type: ErrorTypeProtocol, Code: code, Message: message}
default:
return &Error{Type: ErrorTypeUnknown, Code: code, Message: message}
}
}
func IsArticleNotFoundError(err error) bool {
var nntpErr *Error
if errors.As(err, &nntpErr) {
return nntpErr.Type == ErrorTypeArticleNotFound
}
return false
}
func IsAuthenticationError(err error) bool {
var nntpErr *Error
if errors.As(err, &nntpErr) {
return nntpErr.Type == ErrorTypeAuthentication
}
return false
}
func IsRetryableError(err error) bool {
var nntpErr *Error
if errors.As(err, &nntpErr) {
return nntpErr.IsRetryable()
}
return false
}

299
internal/nntp/pool.go Normal file
View File

@@ -0,0 +1,299 @@
package nntp
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"github.com/rs/zerolog"
"github.com/sirrobot01/decypharr/internal/config"
"net"
"net/textproto"
"sync"
"sync/atomic"
"time"
)
// Pool manages a pool of NNTP connections
type Pool struct {
address, username, password string
maxConns, port int
ssl bool
useTLS bool
connections chan *Connection
logger zerolog.Logger
closed atomic.Bool
totalConnections atomic.Int32
activeConnections atomic.Int32
}
// Segment represents a usenet segment
type Segment struct {
MessageID string
Number int
Bytes int64
Data []byte
}
// Article represents a complete usenet article
type Article struct {
MessageID string
Subject string
From string
Date string
Groups []string
Body []byte
Size int64
}
// Response represents an NNTP server response
type Response struct {
Code int
Message string
Lines []string
}
// GroupInfo represents information about a newsgroup
type GroupInfo struct {
Name string
Count int // Number of articles in the group
Low int // Lowest article number
High int // Highest article number
}
// NewPool creates a new NNTP connection pool
func NewPool(provider config.UsenetProvider, logger zerolog.Logger) (*Pool, error) {
maxConns := provider.Connections
if maxConns <= 0 {
maxConns = 1
}
pool := &Pool{
address: provider.Host,
username: provider.Username,
password: provider.Password,
port: provider.Port,
maxConns: maxConns,
ssl: provider.SSL,
useTLS: provider.UseTLS,
connections: make(chan *Connection, maxConns),
logger: logger,
}
return pool.initializeConnections()
}
func (p *Pool) initializeConnections() (*Pool, error) {
var wg sync.WaitGroup
var mu sync.Mutex
var successfulConnections []*Connection
var errs []error
// Create connections concurrently
for i := 0; i < p.maxConns; i++ {
wg.Add(1)
go func(connIndex int) {
defer wg.Done()
conn, err := p.createConnection()
mu.Lock()
defer mu.Unlock()
if err != nil {
errs = append(errs, err)
} else {
successfulConnections = append(successfulConnections, conn)
}
}(i)
}
// Wait for all connection attempts to complete
wg.Wait()
// Add successful connections to the pool
for _, conn := range successfulConnections {
p.connections <- conn
}
p.totalConnections.Store(int32(len(successfulConnections)))
if len(successfulConnections) == 0 {
return nil, fmt.Errorf("failed to create any connections: %v", errs)
}
// Log results
p.logger.Info().
Str("server", p.address).
Int("port", p.port).
Int("requested_connections", p.maxConns).
Int("successful_connections", len(successfulConnections)).
Int("failed_connections", len(errs)).
Msg("NNTP connection pool created")
// If some connections failed, log a warning but continue
if len(errs) > 0 {
p.logger.Warn().
Int("failed_count", len(errs)).
Msg("Some connections failed during pool initialization")
}
return p, nil
}
// Get retrieves a connection from the pool
func (p *Pool) Get(ctx context.Context) (*Connection, error) {
if p.closed.Load() {
return nil, NewConnectionError(fmt.Errorf("connection pool is closed"))
}
select {
case conn := <-p.connections:
if conn == nil {
return nil, NewConnectionError(fmt.Errorf("received nil connection from pool"))
}
p.activeConnections.Add(1)
if err := conn.ping(); err != nil {
p.activeConnections.Add(-1)
err := conn.close()
if err != nil {
return nil, err
}
// Create a new connection
newConn, err := p.createConnection()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to create replacement connection: %w", err))
}
p.activeConnections.Add(1)
return newConn, nil
}
return conn, nil
case <-ctx.Done():
return nil, NewTimeoutError(ctx.Err())
}
}
// Put returns a connection to the pool
func (p *Pool) Put(conn *Connection) {
if conn == nil {
return
}
defer p.activeConnections.Add(-1)
if p.closed.Load() {
conn.close()
return
}
// Try non-blocking first
select {
case p.connections <- conn:
return
default:
}
// If pool is full, this usually means we have too many connections
// Force return by making space (close oldest connection)
select {
case oldConn := <-p.connections:
oldConn.close() // Close the old connection
p.connections <- conn // Put the new one back
case <-time.After(1 * time.Second):
// Still can't return - close this connection
conn.close()
}
}
// Close closes all connections in the pool
func (p *Pool) Close() error {
if p.closed.Load() {
return nil
}
p.closed.Store(true)
close(p.connections)
for conn := range p.connections {
err := conn.close()
if err != nil {
return err
}
}
p.logger.Info().Msg("NNTP connection pool closed")
return nil
}
// createConnection creates a new NNTP connection with proper error handling
func (p *Pool) createConnection() (*Connection, error) {
addr := fmt.Sprintf("%s:%d", p.address, p.port)
var conn net.Conn
var err error
if p.ssl {
conn, err = tls.DialWithDialer(&net.Dialer{}, "tcp", addr, &tls.Config{
InsecureSkipVerify: false,
})
} else {
conn, err = net.Dial("tcp", addr)
}
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to connect to %s: %w", addr, err))
}
reader := bufio.NewReaderSize(conn, 256*1024) // 256KB buffer for better performance
writer := bufio.NewWriterSize(conn, 256*1024) // 256KB buffer for better performance
text := textproto.NewConn(conn)
nntpConn := &Connection{
username: p.username,
password: p.password,
address: p.address,
port: p.port,
conn: conn,
text: text,
reader: reader,
writer: writer,
logger: p.logger,
}
// Read welcome message
_, err = nntpConn.readResponse()
if err != nil {
conn.Close()
return nil, NewConnectionError(fmt.Errorf("failed to read welcome message: %w", err))
}
// Authenticate if credentials are provided
if p.username != "" && p.password != "" {
if err := nntpConn.authenticate(); err != nil {
conn.Close()
return nil, err // authenticate() already returns NNTPError
}
}
// Enable TLS if requested (STARTTLS)
if p.useTLS && !p.ssl {
if err := nntpConn.startTLS(); err != nil {
conn.Close()
return nil, err // startTLS() already returns NNTPError
}
}
return nntpConn, nil
}
func (p *Pool) ConnectionCount() int {
return int(p.totalConnections.Load())
}
func (p *Pool) ActiveConnections() int {
return int(p.activeConnections.Load())
}
func (p *Pool) IsFree() bool {
return p.ActiveConnections() < p.maxConns
}

View File

@@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"github.com/sirrobot01/decypharr/internal/logger"
@@ -180,8 +179,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
resp, err = c.doRequest(req)
if err != nil {
// Check if this is a network error that might be worth retrying
if isRetryableError(err) && attempt < c.maxRetries {
if attempt < c.maxRetries {
// Apply backoff with jitter
jitter := time.Duration(rand.Int63n(int64(backoff / 4)))
sleepTime := backoff + jitter
@@ -390,30 +388,3 @@ func Default() *Client {
})
return instance
}
func isRetryableError(err error) bool {
errString := err.Error()
// Connection reset and other network errors
if strings.Contains(errString, "connection reset by peer") ||
strings.Contains(errString, "read: connection reset") ||
strings.Contains(errString, "connection refused") ||
strings.Contains(errString, "network is unreachable") ||
strings.Contains(errString, "connection timed out") ||
strings.Contains(errString, "no such host") ||
strings.Contains(errString, "i/o timeout") ||
strings.Contains(errString, "unexpected EOF") ||
strings.Contains(errString, "TLS handshake timeout") {
return true
}
// Check for net.Error type which can provide more information
var netErr net.Error
if errors.As(err, &netErr) {
// Retry on timeout errors and temporary errors
return netErr.Timeout()
}
// Not a retryable error
return false
}

View File

@@ -1,5 +1,16 @@
package utils
import (
"fmt"
"io"
"mime"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
)
func RemoveItem[S ~[]E, E comparable](s S, values ...E) S {
result := make(S, 0, len(s))
outer:
@@ -22,3 +33,131 @@ func Contains(slice []string, value string) bool {
}
return false
}
func GenerateHash(data string) string {
// Simple hash generation using a basic algorithm (for demonstration purposes)
_hash := 0
for _, char := range data {
_hash = (_hash*31 + int(char)) % 1000003 // Simple hash function
}
return string(rune(_hash))
}
func DownloadFile(url string) (string, []byte, error) {
resp, err := http.Get(url)
if err != nil {
return "", nil, fmt.Errorf("failed to download file: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", nil, fmt.Errorf("failed to download file: status code %d", resp.StatusCode)
}
filename := getFilenameFromResponse(resp, url)
data, err := io.ReadAll(resp.Body)
if err != nil {
return "", nil, fmt.Errorf("failed to read response body: %w", err)
}
return filename, data, nil
}
func getFilenameFromResponse(resp *http.Response, originalURL string) string {
// 1. Try Content-Disposition header
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
if _, params, err := mime.ParseMediaType(cd); err == nil {
if filename := params["filename"]; filename != "" {
return filename
}
}
}
// 2. Try to decode URL-encoded filename from Content-Disposition
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
if strings.Contains(cd, "filename*=") {
// Handle RFC 5987 encoded filenames
parts := strings.Split(cd, "filename*=")
if len(parts) > 1 {
encoded := strings.Trim(parts[1], `"`)
if strings.HasPrefix(encoded, "UTF-8''") {
if decoded, err := url.QueryUnescape(encoded[7:]); err == nil {
return decoded
}
}
}
}
}
// 3. Fall back to URL path
if parsedURL, err := url.Parse(originalURL); err == nil {
if filename := filepath.Base(parsedURL.Path); filename != "." && filename != "/" {
// URL decode the filename
if decoded, err := url.QueryUnescape(filename); err == nil {
return decoded
}
return filename
}
}
// 4. Default filename
return "downloaded_file"
}
func ValidateServiceURL(urlStr string) error {
if urlStr == "" {
return fmt.Errorf("URL cannot be empty")
}
// Try parsing as full URL first
u, err := url.Parse(urlStr)
if err == nil && u.Scheme != "" && u.Host != "" {
// It's a full URL, validate scheme
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("URL scheme must be http or https")
}
return nil
}
// Check if it's a host:port format (no scheme)
if strings.Contains(urlStr, ":") && !strings.Contains(urlStr, "://") {
// Try parsing with http:// prefix
testURL := "http://" + urlStr
u, err := url.Parse(testURL)
if err != nil {
return fmt.Errorf("invalid host:port format: %w", err)
}
if u.Host == "" {
return fmt.Errorf("host is required in host:port format")
}
// Validate port number
if u.Port() == "" {
return fmt.Errorf("port is required in host:port format")
}
return nil
}
return fmt.Errorf("invalid URL format: %s", urlStr)
}
func ExtractFilenameFromURL(rawURL string) string {
// Parse the URL
parsedURL, err := url.Parse(rawURL)
if err != nil {
return ""
}
// Get the base filename from path
filename := path.Base(parsedURL.Path)
// Handle edge cases
if filename == "/" || filename == "." || filename == "" {
return ""
}
return filename
}

View File

@@ -57,3 +57,15 @@ func IsSampleFile(path string) bool {
}
return RegexMatch(sampleRegex, path)
}
func IsParFile(path string) bool {
ext := filepath.Ext(path)
return strings.EqualFold(ext, ".par") || strings.EqualFold(ext, ".par2")
}
func IsRarFile(path string) bool {
ext := filepath.Ext(path)
return strings.EqualFold(ext, ".rar") || strings.EqualFold(ext, ".r00") ||
strings.EqualFold(ext, ".r01") || strings.EqualFold(ext, ".r02") ||
strings.EqualFold(ext, ".r03") || strings.EqualFold(ext, ".r04")
}