Implementing a streaming setup with Usenet
This commit is contained in:
@@ -0,0 +1,178 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sirrobot01/decypharr/internal/config"
|
||||
"github.com/sirrobot01/decypharr/internal/logger"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client represents a failover NNTP client that manages multiple providers
|
||||
type Client struct {
|
||||
providers []config.UsenetProvider
|
||||
pools *xsync.Map[string, *Pool]
|
||||
logger zerolog.Logger
|
||||
closed atomic.Bool
|
||||
minimumMaxConns int // Minimum number of max connections across all pools
|
||||
}
|
||||
|
||||
func NewClient(providers []config.UsenetProvider) (*Client, error) {
|
||||
|
||||
client := &Client{
|
||||
providers: providers,
|
||||
logger: logger.New("nntp"),
|
||||
pools: xsync.NewMap[string, *Pool](),
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return nil, fmt.Errorf("no NNTP providers configured")
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) InitPools() error {
|
||||
|
||||
var initErrors []error
|
||||
successfulPools := 0
|
||||
|
||||
for _, provider := range c.providers {
|
||||
serverPool, err := NewPool(provider, c.logger)
|
||||
if err != nil {
|
||||
c.logger.Error().
|
||||
Err(err).
|
||||
Str("server", provider.Host).
|
||||
Int("port", provider.Port).
|
||||
Msg("Failed to initialize server pool")
|
||||
initErrors = append(initErrors, err)
|
||||
continue
|
||||
}
|
||||
if c.minimumMaxConns == 0 {
|
||||
// Set minimumMaxConns to the max connections of the first successful pool
|
||||
c.minimumMaxConns = serverPool.ConnectionCount()
|
||||
} else {
|
||||
c.minimumMaxConns = min(c.minimumMaxConns, serverPool.ConnectionCount())
|
||||
}
|
||||
|
||||
c.pools.Store(provider.Name, serverPool)
|
||||
successfulPools++
|
||||
}
|
||||
|
||||
if successfulPools == 0 {
|
||||
return fmt.Errorf("failed to initialize any server pools: %v", initErrors)
|
||||
}
|
||||
|
||||
c.logger.Info().
|
||||
Int("providers", len(c.providers)).
|
||||
Msg("NNTP client created")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
if c.closed.Load() {
|
||||
c.logger.Warn().Msg("NNTP client already closed")
|
||||
return
|
||||
}
|
||||
|
||||
c.pools.Range(func(key string, value *Pool) bool {
|
||||
if value != nil {
|
||||
err := value.Close()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
c.closed.Store(true)
|
||||
c.logger.Info().Msg("NNTP client closed")
|
||||
}
|
||||
|
||||
func (c *Client) GetConnection(ctx context.Context) (*Connection, func(), error) {
|
||||
if c.closed.Load() {
|
||||
return nil, nil, fmt.Errorf("nntp client is closed")
|
||||
}
|
||||
|
||||
// Prevent workers from waiting too long for connections
|
||||
connCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
providerCount := len(c.providers)
|
||||
|
||||
for _, provider := range c.providers {
|
||||
pool, ok := c.pools.Load(provider.Name)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("no pool found for provider %s", provider.Name)
|
||||
}
|
||||
|
||||
if !pool.IsFree() && providerCount > 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := pool.Get(connCtx) // Use timeout context
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoAvailableConnection) || errors.Is(err, context.DeadlineExceeded) {
|
||||
continue
|
||||
}
|
||||
return nil, nil, fmt.Errorf("error getting connection from provider %s: %w", provider.Name, err)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return conn, func() { pool.Put(conn) }, nil
|
||||
}
|
||||
|
||||
return nil, nil, ErrNoAvailableConnection
|
||||
}
|
||||
|
||||
func (c *Client) DownloadHeader(ctx context.Context, messageID string) (*YencMetadata, error) {
|
||||
conn, cleanup, err := c.GetConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
data, err := conn.GetBody(messageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// yEnc decode
|
||||
part, err := DecodeYencHeaders(bytes.NewReader(data))
|
||||
if err != nil || part == nil {
|
||||
return nil, fmt.Errorf("failed to decode segment")
|
||||
}
|
||||
|
||||
// Return both the filename and decoded data
|
||||
return part, nil
|
||||
}
|
||||
|
||||
func (c *Client) MinimumMaxConns() int {
|
||||
return c.minimumMaxConns
|
||||
}
|
||||
|
||||
func (c *Client) TotalActiveConnections() int {
|
||||
total := 0
|
||||
c.pools.Range(func(key string, value *Pool) bool {
|
||||
if value != nil {
|
||||
total += value.ActiveConnections()
|
||||
}
|
||||
return true
|
||||
})
|
||||
return total
|
||||
}
|
||||
|
||||
func (c *Client) Pools() *xsync.Map[string, *Pool] {
|
||||
return c.pools
|
||||
}
|
||||
|
||||
func (c *Client) GetProviders() []config.UsenetProvider {
|
||||
return c.providers
|
||||
}
|
||||
@@ -0,0 +1,394 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/chrisfarms/yenc"
|
||||
"github.com/rs/zerolog"
|
||||
"io"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Connection represents an NNTP connection
|
||||
type Connection struct {
|
||||
username, password, address string
|
||||
port int
|
||||
conn net.Conn
|
||||
text *textproto.Conn
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
func (c *Connection) authenticate() error {
|
||||
// Send AUTHINFO USER command
|
||||
if err := c.sendCommand(fmt.Sprintf("AUTHINFO USER %s", c.username)); err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to send username: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to read user response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 381 {
|
||||
return classifyNNTPError(resp.Code, fmt.Sprintf("unexpected response to AUTHINFO USER: %s", resp.Message))
|
||||
}
|
||||
|
||||
// Send AUTHINFO PASS command
|
||||
if err := c.sendCommand(fmt.Sprintf("AUTHINFO PASS %s", c.password)); err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to send password: %w", err))
|
||||
}
|
||||
|
||||
resp, err = c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to read password response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 281 {
|
||||
return classifyNNTPError(resp.Code, fmt.Sprintf("authentication failed: %s", resp.Message))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startTLS initiates TLS encryption with proper error handling
|
||||
func (c *Connection) startTLS() error {
|
||||
if err := c.sendCommand("STARTTLS"); err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to send STARTTLS: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(fmt.Errorf("failed to read STARTTLS response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 382 {
|
||||
return classifyNNTPError(resp.Code, fmt.Sprintf("STARTTLS not supported: %s", resp.Message))
|
||||
}
|
||||
|
||||
// Upgrade connection to TLS
|
||||
tlsConn := tls.Client(c.conn, &tls.Config{
|
||||
ServerName: c.address,
|
||||
InsecureSkipVerify: false,
|
||||
})
|
||||
|
||||
c.conn = tlsConn
|
||||
c.reader = bufio.NewReader(tlsConn)
|
||||
c.writer = bufio.NewWriter(tlsConn)
|
||||
c.text = textproto.NewConn(tlsConn)
|
||||
|
||||
c.logger.Debug().Msg("TLS encryption enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ping sends a simple command to test the connection
|
||||
func (c *Connection) ping() error {
|
||||
if err := c.sendCommand("DATE"); err != nil {
|
||||
return NewConnectionError(err)
|
||||
}
|
||||
_, err := c.readResponse()
|
||||
if err != nil {
|
||||
return NewConnectionError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendCommand sends a command to the NNTP server
|
||||
func (c *Connection) sendCommand(command string) error {
|
||||
_, err := fmt.Fprintf(c.writer, "%s\r\n", command)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.writer.Flush()
|
||||
}
|
||||
|
||||
// readResponse reads a response from the NNTP server
|
||||
func (c *Connection) readResponse() (*Response, error) {
|
||||
line, err := c.text.ReadLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, " ", 2)
|
||||
code, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid response code: %s", parts[0])
|
||||
}
|
||||
|
||||
message := ""
|
||||
if len(parts) > 1 {
|
||||
message = parts[1]
|
||||
}
|
||||
|
||||
return &Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// readMultilineResponse reads a multiline response
|
||||
func (c *Connection) readMultilineResponse() (*Response, error) {
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if this is a multiline response
|
||||
if resp.Code < 200 || resp.Code >= 300 {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
lines, err := c.text.ReadDotLines()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Lines = lines
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GetArticle retrieves an article by message ID with proper error classification
|
||||
func (c *Connection) GetArticle(messageID string) (*Article, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
if err := c.sendCommand(fmt.Sprintf("ARTICLE %s", messageID)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send ARTICLE command: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readMultilineResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read article response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 220 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
return c.parseArticle(messageID, resp.Lines)
|
||||
}
|
||||
|
||||
// GetBody retrieves article body by message ID with proper error classification
|
||||
func (c *Connection) GetBody(messageID string) ([]byte, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
if err := c.sendCommand(fmt.Sprintf("BODY %s", messageID)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send BODY command: %w", err))
|
||||
}
|
||||
|
||||
// Read the initial response
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read body response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 222 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
// Read the raw body data directly using textproto to preserve exact formatting for yEnc
|
||||
lines, err := c.text.ReadDotLines()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read body data: %w", err))
|
||||
}
|
||||
|
||||
// Join with \r\n to preserve original line endings and add final \r\n
|
||||
body := strings.Join(lines, "\r\n")
|
||||
if len(lines) > 0 {
|
||||
body += "\r\n"
|
||||
}
|
||||
|
||||
return []byte(body), nil
|
||||
}
|
||||
|
||||
// GetHead retrieves article headers by message ID
|
||||
func (c *Connection) GetHead(messageID string) ([]byte, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
if err := c.sendCommand(fmt.Sprintf("HEAD %s", messageID)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send HEAD command: %w", err))
|
||||
}
|
||||
|
||||
// Read the initial response
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read head response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 221 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
// Read the header data using textproto
|
||||
lines, err := c.text.ReadDotLines()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read header data: %w", err))
|
||||
}
|
||||
|
||||
// Join with \r\n to preserve original line endings and add final \r\n
|
||||
headers := strings.Join(lines, "\r\n")
|
||||
if len(lines) > 0 {
|
||||
headers += "\r\n"
|
||||
}
|
||||
|
||||
return []byte(headers), nil
|
||||
}
|
||||
|
||||
// GetSegment retrieves a specific segment with proper error handling
|
||||
func (c *Connection) GetSegment(messageID string, segmentNumber int) (*Segment, error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
body, err := c.GetBody(messageID)
|
||||
if err != nil {
|
||||
return nil, err // GetBody already returns classified errors
|
||||
}
|
||||
|
||||
return &Segment{
|
||||
MessageID: messageID,
|
||||
Number: segmentNumber,
|
||||
Bytes: int64(len(body)),
|
||||
Data: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stat retrieves article statistics by message ID with proper error classification
|
||||
func (c *Connection) Stat(messageID string) (articleNumber int, echoedID string, err error) {
|
||||
messageID = FormatMessageID(messageID)
|
||||
|
||||
if err = c.sendCommand(fmt.Sprintf("STAT %s", messageID)); err != nil {
|
||||
return 0, "", NewConnectionError(fmt.Errorf("failed to send STAT: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return 0, "", NewConnectionError(fmt.Errorf("failed to read STAT response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 223 {
|
||||
return 0, "", classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
fields := strings.Fields(resp.Message)
|
||||
if len(fields) < 2 {
|
||||
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("unexpected STAT response format: %q", resp.Message))
|
||||
}
|
||||
|
||||
if articleNumber, err = strconv.Atoi(fields[0]); err != nil {
|
||||
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("invalid article number %q: %v", fields[0], err))
|
||||
}
|
||||
echoedID = fields[1]
|
||||
|
||||
return articleNumber, echoedID, nil
|
||||
}
|
||||
|
||||
// SelectGroup selects a newsgroup and returns group information
|
||||
func (c *Connection) SelectGroup(groupName string) (*GroupInfo, error) {
|
||||
if err := c.sendCommand(fmt.Sprintf("GROUP %s", groupName)); err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to send GROUP command: %w", err))
|
||||
}
|
||||
|
||||
resp, err := c.readResponse()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read GROUP response: %w", err))
|
||||
}
|
||||
|
||||
if resp.Code != 211 {
|
||||
return nil, classifyNNTPError(resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
// Parse GROUP response: "211 number low high group-name"
|
||||
fields := strings.Fields(resp.Message)
|
||||
if len(fields) < 4 {
|
||||
return nil, NewProtocolError(resp.Code, fmt.Sprintf("unexpected GROUP response format: %q", resp.Message))
|
||||
}
|
||||
|
||||
groupInfo := &GroupInfo{
|
||||
Name: groupName,
|
||||
}
|
||||
|
||||
if count, err := strconv.Atoi(fields[0]); err == nil {
|
||||
groupInfo.Count = count
|
||||
}
|
||||
if low, err := strconv.Atoi(fields[1]); err == nil {
|
||||
groupInfo.Low = low
|
||||
}
|
||||
if high, err := strconv.Atoi(fields[2]); err == nil {
|
||||
groupInfo.High = high
|
||||
}
|
||||
|
||||
return groupInfo, nil
|
||||
}
|
||||
|
||||
// parseArticle parses article data from response lines
|
||||
func (c *Connection) parseArticle(messageID string, lines []string) (*Article, error) {
|
||||
article := &Article{
|
||||
MessageID: messageID,
|
||||
Groups: []string{},
|
||||
}
|
||||
|
||||
headerEnd := -1
|
||||
for i, line := range lines {
|
||||
if line == "" {
|
||||
headerEnd = i
|
||||
break
|
||||
}
|
||||
|
||||
// Parse headers
|
||||
if strings.HasPrefix(line, "Subject: ") {
|
||||
article.Subject = strings.TrimPrefix(line, "Subject: ")
|
||||
} else if strings.HasPrefix(line, "From: ") {
|
||||
article.From = strings.TrimPrefix(line, "From: ")
|
||||
} else if strings.HasPrefix(line, "Date: ") {
|
||||
article.Date = strings.TrimPrefix(line, "Date: ")
|
||||
} else if strings.HasPrefix(line, "Newsgroups: ") {
|
||||
groups := strings.TrimPrefix(line, "Newsgroups: ")
|
||||
article.Groups = strings.Split(groups, ",")
|
||||
for i := range article.Groups {
|
||||
article.Groups[i] = strings.TrimSpace(article.Groups[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Join body lines
|
||||
if headerEnd != -1 && headerEnd+1 < len(lines) {
|
||||
body := strings.Join(lines[headerEnd+1:], "\n")
|
||||
article.Body = []byte(body)
|
||||
article.Size = int64(len(article.Body))
|
||||
}
|
||||
|
||||
return article, nil
|
||||
}
|
||||
|
||||
// close closes the NNTP connection
|
||||
func (c *Connection) close() error {
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DecodeYenc(reader io.Reader) (*yenc.Part, error) {
|
||||
part, err := yenc.Decode(reader)
|
||||
if err != nil {
|
||||
return nil, NewYencDecodeError(fmt.Errorf("failed to create yenc decoder: %w", err))
|
||||
}
|
||||
return part, nil
|
||||
}
|
||||
|
||||
func IsValidMessageID(messageID string) bool {
|
||||
if len(messageID) < 3 {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(messageID, "@")
|
||||
}
|
||||
|
||||
// FormatMessageID ensures message ID has proper format
|
||||
func FormatMessageID(messageID string) string {
|
||||
messageID = strings.TrimSpace(messageID)
|
||||
if !strings.HasPrefix(messageID, "<") {
|
||||
messageID = "<" + messageID
|
||||
}
|
||||
if !strings.HasSuffix(messageID, ">") {
|
||||
messageID = messageID + ">"
|
||||
}
|
||||
return messageID
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// YencMetadata contains just the header information
|
||||
type YencMetadata struct {
|
||||
Name string // filename
|
||||
Size int64 // total file size
|
||||
Part int // part number
|
||||
Total int // total parts
|
||||
Begin int64 // part start byte
|
||||
End int64 // part end byte
|
||||
LineSize int // line length
|
||||
}
|
||||
|
||||
// DecodeYencHeaders extracts only yenc header metadata without decoding body
|
||||
func DecodeYencHeaders(reader io.Reader) (*YencMetadata, error) {
|
||||
buf := bufio.NewReader(reader)
|
||||
metadata := &YencMetadata{}
|
||||
|
||||
// Find and parse =ybegin header
|
||||
if err := parseYBeginHeader(buf, metadata); err != nil {
|
||||
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ybegin header: %w", err))
|
||||
}
|
||||
|
||||
// Parse =ypart header if this is a multipart file
|
||||
if metadata.Part > 0 {
|
||||
if err := parseYPartHeader(buf, metadata); err != nil {
|
||||
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ypart header: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func parseYBeginHeader(buf *bufio.Reader, metadata *YencMetadata) error {
|
||||
var s string
|
||||
var err error
|
||||
|
||||
// Find the =ybegin line
|
||||
for {
|
||||
s, err = buf.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(s) >= 7 && s[:7] == "=ybegin" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the header line
|
||||
parts := strings.SplitN(s[7:], "name=", 2)
|
||||
if len(parts) > 1 {
|
||||
metadata.Name = strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
// Parse other parameters
|
||||
for _, header := range strings.Split(parts[0], " ") {
|
||||
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
|
||||
if len(kv) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch kv[0] {
|
||||
case "size":
|
||||
metadata.Size, _ = strconv.ParseInt(kv[1], 10, 64)
|
||||
case "line":
|
||||
metadata.LineSize, _ = strconv.Atoi(kv[1])
|
||||
case "part":
|
||||
metadata.Part, _ = strconv.Atoi(kv[1])
|
||||
case "total":
|
||||
metadata.Total, _ = strconv.Atoi(kv[1])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseYPartHeader(buf *bufio.Reader, metadata *YencMetadata) error {
|
||||
var s string
|
||||
var err error
|
||||
|
||||
// Find the =ypart line
|
||||
for {
|
||||
s, err = buf.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(s) >= 6 && s[:6] == "=ypart" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Parse part parameters
|
||||
for _, header := range strings.Split(s[6:], " ") {
|
||||
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
|
||||
if len(kv) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch kv[0] {
|
||||
case "begin":
|
||||
metadata.Begin, _ = strconv.ParseInt(kv[1], 10, 64)
|
||||
case "end":
|
||||
metadata.End, _ = strconv.ParseInt(kv[1], 10, 64)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Error types for NNTP operations
|
||||
type ErrorType int
|
||||
|
||||
const (
|
||||
ErrorTypeUnknown ErrorType = iota
|
||||
ErrorTypeConnection
|
||||
ErrorTypeAuthentication
|
||||
ErrorTypeTimeout
|
||||
ErrorTypeArticleNotFound
|
||||
ErrorTypeGroupNotFound
|
||||
ErrorTypePermissionDenied
|
||||
ErrorTypeServerBusy
|
||||
ErrorTypeInvalidCommand
|
||||
ErrorTypeProtocol
|
||||
ErrorTypeYencDecode
|
||||
ErrorTypeNoAvailableConnection
|
||||
)
|
||||
|
||||
// Error represents an NNTP-specific error
|
||||
type Error struct {
|
||||
Type ErrorType
|
||||
Code int // NNTP response code
|
||||
Message string // Error message
|
||||
Err error // Underlying error
|
||||
}
|
||||
|
||||
// Predefined errors for common cases
|
||||
var (
|
||||
ErrArticleNotFound = &Error{Type: ErrorTypeArticleNotFound, Code: 430, Message: "article not found"}
|
||||
ErrGroupNotFound = &Error{Type: ErrorTypeGroupNotFound, Code: 411, Message: "group not found"}
|
||||
ErrPermissionDenied = &Error{Type: ErrorTypePermissionDenied, Code: 502, Message: "permission denied"}
|
||||
ErrAuthenticationFail = &Error{Type: ErrorTypeAuthentication, Code: 482, Message: "authentication failed"}
|
||||
ErrServerBusy = &Error{Type: ErrorTypeServerBusy, Code: 400, Message: "server busy"}
|
||||
ErrPoolNotFound = &Error{Type: ErrorTypeUnknown, Code: 0, Message: "NNTP pool not found", Err: nil}
|
||||
ErrNoAvailableConnection = &Error{Type: ErrorTypeNoAvailableConnection, Code: 0, Message: "no available connection in pool", Err: nil}
|
||||
)
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("NNTP %s (code %d): %s - %v", e.Type.String(), e.Code, e.Message, e.Err)
|
||||
}
|
||||
return fmt.Sprintf("NNTP %s (code %d): %s", e.Type.String(), e.Code, e.Message)
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e *Error) Is(target error) bool {
|
||||
if t, ok := target.(*Error); ok {
|
||||
return e.Type == t.Type
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsRetryable returns true if the error might be resolved by retrying
|
||||
func (e *Error) IsRetryable() bool {
|
||||
switch e.Type {
|
||||
case ErrorTypeConnection, ErrorTypeTimeout, ErrorTypeServerBusy:
|
||||
return true
|
||||
case ErrorTypeArticleNotFound, ErrorTypeGroupNotFound, ErrorTypePermissionDenied, ErrorTypeAuthentication:
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldStopParsing returns true if this error should stop the entire parsing process
|
||||
func (e *Error) ShouldStopParsing() bool {
|
||||
switch e.Type {
|
||||
case ErrorTypeAuthentication, ErrorTypePermissionDenied:
|
||||
return true // Critical auth issues
|
||||
case ErrorTypeConnection:
|
||||
return false // Can continue with other connections
|
||||
case ErrorTypeArticleNotFound:
|
||||
return false // Can continue searching for other articles
|
||||
case ErrorTypeServerBusy:
|
||||
return false // Temporary issue
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (et ErrorType) String() string {
|
||||
switch et {
|
||||
case ErrorTypeConnection:
|
||||
return "CONNECTION"
|
||||
case ErrorTypeAuthentication:
|
||||
return "AUTHENTICATION"
|
||||
case ErrorTypeTimeout:
|
||||
return "TIMEOUT"
|
||||
case ErrorTypeArticleNotFound:
|
||||
return "ARTICLE_NOT_FOUND"
|
||||
case ErrorTypeGroupNotFound:
|
||||
return "GROUP_NOT_FOUND"
|
||||
case ErrorTypePermissionDenied:
|
||||
return "PERMISSION_DENIED"
|
||||
case ErrorTypeServerBusy:
|
||||
return "SERVER_BUSY"
|
||||
case ErrorTypeInvalidCommand:
|
||||
return "INVALID_COMMAND"
|
||||
case ErrorTypeProtocol:
|
||||
return "PROTOCOL"
|
||||
case ErrorTypeYencDecode:
|
||||
return "YENC_DECODE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to create specific errors
|
||||
func NewConnectionError(err error) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeConnection,
|
||||
Message: "connection failed",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTimeoutError(err error) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeTimeout,
|
||||
Message: "operation timed out",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProtocolError(code int, message string) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeProtocol,
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
func NewYencDecodeError(err error) *Error {
|
||||
return &Error{
|
||||
Type: ErrorTypeYencDecode,
|
||||
Message: "yEnc decode failed",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// classifyNNTPError classifies an NNTP response code into an error type
|
||||
func classifyNNTPError(code int, message string) *Error {
|
||||
switch {
|
||||
case code == 430 || code == 423:
|
||||
return &Error{Type: ErrorTypeArticleNotFound, Code: code, Message: message}
|
||||
case code == 411:
|
||||
return &Error{Type: ErrorTypeGroupNotFound, Code: code, Message: message}
|
||||
case code == 502 || code == 503:
|
||||
return &Error{Type: ErrorTypePermissionDenied, Code: code, Message: message}
|
||||
case code == 481 || code == 482:
|
||||
return &Error{Type: ErrorTypeAuthentication, Code: code, Message: message}
|
||||
case code == 400:
|
||||
return &Error{Type: ErrorTypeServerBusy, Code: code, Message: message}
|
||||
case code == 500 || code == 501:
|
||||
return &Error{Type: ErrorTypeInvalidCommand, Code: code, Message: message}
|
||||
case code >= 400:
|
||||
return &Error{Type: ErrorTypeProtocol, Code: code, Message: message}
|
||||
default:
|
||||
return &Error{Type: ErrorTypeUnknown, Code: code, Message: message}
|
||||
}
|
||||
}
|
||||
|
||||
func IsArticleNotFoundError(err error) bool {
|
||||
var nntpErr *Error
|
||||
if errors.As(err, &nntpErr) {
|
||||
return nntpErr.Type == ErrorTypeArticleNotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsAuthenticationError(err error) bool {
|
||||
var nntpErr *Error
|
||||
if errors.As(err, &nntpErr) {
|
||||
return nntpErr.Type == ErrorTypeAuthentication
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsRetryableError(err error) bool {
|
||||
var nntpErr *Error
|
||||
if errors.As(err, &nntpErr) {
|
||||
return nntpErr.IsRetryable()
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
package nntp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/sirrobot01/decypharr/internal/config"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pool manages a pool of NNTP connections
|
||||
type Pool struct {
|
||||
address, username, password string
|
||||
maxConns, port int
|
||||
ssl bool
|
||||
useTLS bool
|
||||
connections chan *Connection
|
||||
logger zerolog.Logger
|
||||
closed atomic.Bool
|
||||
totalConnections atomic.Int32
|
||||
activeConnections atomic.Int32
|
||||
}
|
||||
|
||||
// Segment represents a usenet segment
|
||||
type Segment struct {
|
||||
MessageID string
|
||||
Number int
|
||||
Bytes int64
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Article represents a complete usenet article
|
||||
type Article struct {
|
||||
MessageID string
|
||||
Subject string
|
||||
From string
|
||||
Date string
|
||||
Groups []string
|
||||
Body []byte
|
||||
Size int64
|
||||
}
|
||||
|
||||
// Response represents an NNTP server response
|
||||
type Response struct {
|
||||
Code int
|
||||
Message string
|
||||
Lines []string
|
||||
}
|
||||
|
||||
// GroupInfo represents information about a newsgroup
|
||||
type GroupInfo struct {
|
||||
Name string
|
||||
Count int // Number of articles in the group
|
||||
Low int // Lowest article number
|
||||
High int // Highest article number
|
||||
}
|
||||
|
||||
// NewPool creates a new NNTP connection pool
|
||||
func NewPool(provider config.UsenetProvider, logger zerolog.Logger) (*Pool, error) {
|
||||
maxConns := provider.Connections
|
||||
if maxConns <= 0 {
|
||||
maxConns = 1
|
||||
}
|
||||
|
||||
pool := &Pool{
|
||||
address: provider.Host,
|
||||
username: provider.Username,
|
||||
password: provider.Password,
|
||||
port: provider.Port,
|
||||
maxConns: maxConns,
|
||||
ssl: provider.SSL,
|
||||
useTLS: provider.UseTLS,
|
||||
connections: make(chan *Connection, maxConns),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return pool.initializeConnections()
|
||||
}
|
||||
|
||||
func (p *Pool) initializeConnections() (*Pool, error) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
var successfulConnections []*Connection
|
||||
var errs []error
|
||||
|
||||
// Create connections concurrently
|
||||
for i := 0; i < p.maxConns; i++ {
|
||||
wg.Add(1)
|
||||
go func(connIndex int) {
|
||||
defer wg.Done()
|
||||
|
||||
conn, err := p.createConnection()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
} else {
|
||||
successfulConnections = append(successfulConnections, conn)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all connection attempts to complete
|
||||
wg.Wait()
|
||||
|
||||
// Add successful connections to the pool
|
||||
for _, conn := range successfulConnections {
|
||||
p.connections <- conn
|
||||
}
|
||||
p.totalConnections.Store(int32(len(successfulConnections)))
|
||||
|
||||
if len(successfulConnections) == 0 {
|
||||
return nil, fmt.Errorf("failed to create any connections: %v", errs)
|
||||
}
|
||||
|
||||
// Log results
|
||||
p.logger.Info().
|
||||
Str("server", p.address).
|
||||
Int("port", p.port).
|
||||
Int("requested_connections", p.maxConns).
|
||||
Int("successful_connections", len(successfulConnections)).
|
||||
Int("failed_connections", len(errs)).
|
||||
Msg("NNTP connection pool created")
|
||||
|
||||
// If some connections failed, log a warning but continue
|
||||
if len(errs) > 0 {
|
||||
p.logger.Warn().
|
||||
Int("failed_count", len(errs)).
|
||||
Msg("Some connections failed during pool initialization")
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Get retrieves a connection from the pool
|
||||
func (p *Pool) Get(ctx context.Context) (*Connection, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, NewConnectionError(fmt.Errorf("connection pool is closed"))
|
||||
}
|
||||
|
||||
select {
|
||||
case conn := <-p.connections:
|
||||
if conn == nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("received nil connection from pool"))
|
||||
}
|
||||
p.activeConnections.Add(1)
|
||||
|
||||
if err := conn.ping(); err != nil {
|
||||
p.activeConnections.Add(-1)
|
||||
err := conn.close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create a new connection
|
||||
newConn, err := p.createConnection()
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to create replacement connection: %w", err))
|
||||
}
|
||||
p.activeConnections.Add(1)
|
||||
return newConn, nil
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
return nil, NewTimeoutError(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// Put returns a connection to the pool
|
||||
func (p *Pool) Put(conn *Connection) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer p.activeConnections.Add(-1)
|
||||
|
||||
if p.closed.Load() {
|
||||
conn.close()
|
||||
return
|
||||
}
|
||||
|
||||
// Try non-blocking first
|
||||
select {
|
||||
case p.connections <- conn:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// If pool is full, this usually means we have too many connections
|
||||
// Force return by making space (close oldest connection)
|
||||
select {
|
||||
case oldConn := <-p.connections:
|
||||
oldConn.close() // Close the old connection
|
||||
p.connections <- conn // Put the new one back
|
||||
case <-time.After(1 * time.Second):
|
||||
// Still can't return - close this connection
|
||||
conn.close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all connections in the pool
|
||||
func (p *Pool) Close() error {
|
||||
|
||||
if p.closed.Load() {
|
||||
return nil
|
||||
}
|
||||
p.closed.Store(true)
|
||||
|
||||
close(p.connections)
|
||||
for conn := range p.connections {
|
||||
err := conn.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.Info().Msg("NNTP connection pool closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createConnection creates a new NNTP connection with proper error handling
|
||||
func (p *Pool) createConnection() (*Connection, error) {
|
||||
addr := fmt.Sprintf("%s:%d", p.address, p.port)
|
||||
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
||||
if p.ssl {
|
||||
conn, err = tls.DialWithDialer(&net.Dialer{}, "tcp", addr, &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
})
|
||||
} else {
|
||||
conn, err = net.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to connect to %s: %w", addr, err))
|
||||
}
|
||||
|
||||
reader := bufio.NewReaderSize(conn, 256*1024) // 256KB buffer for better performance
|
||||
writer := bufio.NewWriterSize(conn, 256*1024) // 256KB buffer for better performance
|
||||
text := textproto.NewConn(conn)
|
||||
|
||||
nntpConn := &Connection{
|
||||
username: p.username,
|
||||
password: p.password,
|
||||
address: p.address,
|
||||
port: p.port,
|
||||
conn: conn,
|
||||
text: text,
|
||||
reader: reader,
|
||||
writer: writer,
|
||||
logger: p.logger,
|
||||
}
|
||||
|
||||
// Read welcome message
|
||||
_, err = nntpConn.readResponse()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, NewConnectionError(fmt.Errorf("failed to read welcome message: %w", err))
|
||||
}
|
||||
|
||||
// Authenticate if credentials are provided
|
||||
if p.username != "" && p.password != "" {
|
||||
if err := nntpConn.authenticate(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err // authenticate() already returns NNTPError
|
||||
}
|
||||
}
|
||||
|
||||
// Enable TLS if requested (STARTTLS)
|
||||
if p.useTLS && !p.ssl {
|
||||
if err := nntpConn.startTLS(); err != nil {
|
||||
conn.Close()
|
||||
return nil, err // startTLS() already returns NNTPError
|
||||
}
|
||||
}
|
||||
return nntpConn, nil
|
||||
}
|
||||
|
||||
func (p *Pool) ConnectionCount() int {
|
||||
return int(p.totalConnections.Load())
|
||||
}
|
||||
|
||||
func (p *Pool) ActiveConnections() int {
|
||||
return int(p.activeConnections.Load())
|
||||
}
|
||||
|
||||
func (p *Pool) IsFree() bool {
|
||||
return p.ActiveConnections() < p.maxConns
|
||||
}
|
||||
Reference in New Issue
Block a user