Implementing a streaming setup with Usenet

This commit is contained in:
Mukhtar Akere
2025-08-01 15:27:24 +01:00
parent afe577bf2f
commit f9861e3b54
65 changed files with 9437 additions and 924 deletions
+178
View File
@@ -0,0 +1,178 @@
package nntp
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/sirrobot01/decypharr/internal/config"
"github.com/sirrobot01/decypharr/internal/logger"
"sync/atomic"
"time"
)
// Client represents a failover NNTP client that manages multiple providers
type Client struct {
providers []config.UsenetProvider
pools *xsync.Map[string, *Pool]
logger zerolog.Logger
closed atomic.Bool
minimumMaxConns int // Minimum number of max connections across all pools
}
func NewClient(providers []config.UsenetProvider) (*Client, error) {
client := &Client{
providers: providers,
logger: logger.New("nntp"),
pools: xsync.NewMap[string, *Pool](),
}
if len(providers) == 0 {
return nil, fmt.Errorf("no NNTP providers configured")
}
return client, nil
}
func (c *Client) InitPools() error {
var initErrors []error
successfulPools := 0
for _, provider := range c.providers {
serverPool, err := NewPool(provider, c.logger)
if err != nil {
c.logger.Error().
Err(err).
Str("server", provider.Host).
Int("port", provider.Port).
Msg("Failed to initialize server pool")
initErrors = append(initErrors, err)
continue
}
if c.minimumMaxConns == 0 {
// Set minimumMaxConns to the max connections of the first successful pool
c.minimumMaxConns = serverPool.ConnectionCount()
} else {
c.minimumMaxConns = min(c.minimumMaxConns, serverPool.ConnectionCount())
}
c.pools.Store(provider.Name, serverPool)
successfulPools++
}
if successfulPools == 0 {
return fmt.Errorf("failed to initialize any server pools: %v", initErrors)
}
c.logger.Info().
Int("providers", len(c.providers)).
Msg("NNTP client created")
return nil
}
func (c *Client) Close() {
if c.closed.Load() {
c.logger.Warn().Msg("NNTP client already closed")
return
}
c.pools.Range(func(key string, value *Pool) bool {
if value != nil {
err := value.Close()
if err != nil {
return false
}
}
return true
})
c.closed.Store(true)
c.logger.Info().Msg("NNTP client closed")
}
func (c *Client) GetConnection(ctx context.Context) (*Connection, func(), error) {
if c.closed.Load() {
return nil, nil, fmt.Errorf("nntp client is closed")
}
// Prevent workers from waiting too long for connections
connCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
providerCount := len(c.providers)
for _, provider := range c.providers {
pool, ok := c.pools.Load(provider.Name)
if !ok {
return nil, nil, fmt.Errorf("no pool found for provider %s", provider.Name)
}
if !pool.IsFree() && providerCount > 1 {
continue
}
conn, err := pool.Get(connCtx) // Use timeout context
if err != nil {
if errors.Is(err, ErrNoAvailableConnection) || errors.Is(err, context.DeadlineExceeded) {
continue
}
return nil, nil, fmt.Errorf("error getting connection from provider %s: %w", provider.Name, err)
}
if conn == nil {
continue
}
return conn, func() { pool.Put(conn) }, nil
}
return nil, nil, ErrNoAvailableConnection
}
func (c *Client) DownloadHeader(ctx context.Context, messageID string) (*YencMetadata, error) {
conn, cleanup, err := c.GetConnection(ctx)
if err != nil {
return nil, err
}
defer cleanup()
data, err := conn.GetBody(messageID)
if err != nil {
return nil, err
}
// yEnc decode
part, err := DecodeYencHeaders(bytes.NewReader(data))
if err != nil || part == nil {
return nil, fmt.Errorf("failed to decode segment")
}
// Return both the filename and decoded data
return part, nil
}
func (c *Client) MinimumMaxConns() int {
return c.minimumMaxConns
}
func (c *Client) TotalActiveConnections() int {
total := 0
c.pools.Range(func(key string, value *Pool) bool {
if value != nil {
total += value.ActiveConnections()
}
return true
})
return total
}
func (c *Client) Pools() *xsync.Map[string, *Pool] {
return c.pools
}
func (c *Client) GetProviders() []config.UsenetProvider {
return c.providers
}
+394
View File
@@ -0,0 +1,394 @@
package nntp
import (
"bufio"
"crypto/tls"
"fmt"
"github.com/chrisfarms/yenc"
"github.com/rs/zerolog"
"io"
"net"
"net/textproto"
"strconv"
"strings"
)
// Connection represents an NNTP connection
type Connection struct {
username, password, address string
port int
conn net.Conn
text *textproto.Conn
reader *bufio.Reader
writer *bufio.Writer
logger zerolog.Logger
}
func (c *Connection) authenticate() error {
// Send AUTHINFO USER command
if err := c.sendCommand(fmt.Sprintf("AUTHINFO USER %s", c.username)); err != nil {
return NewConnectionError(fmt.Errorf("failed to send username: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return NewConnectionError(fmt.Errorf("failed to read user response: %w", err))
}
if resp.Code != 381 {
return classifyNNTPError(resp.Code, fmt.Sprintf("unexpected response to AUTHINFO USER: %s", resp.Message))
}
// Send AUTHINFO PASS command
if err := c.sendCommand(fmt.Sprintf("AUTHINFO PASS %s", c.password)); err != nil {
return NewConnectionError(fmt.Errorf("failed to send password: %w", err))
}
resp, err = c.readResponse()
if err != nil {
return NewConnectionError(fmt.Errorf("failed to read password response: %w", err))
}
if resp.Code != 281 {
return classifyNNTPError(resp.Code, fmt.Sprintf("authentication failed: %s", resp.Message))
}
return nil
}
// startTLS initiates TLS encryption with proper error handling
func (c *Connection) startTLS() error {
if err := c.sendCommand("STARTTLS"); err != nil {
return NewConnectionError(fmt.Errorf("failed to send STARTTLS: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return NewConnectionError(fmt.Errorf("failed to read STARTTLS response: %w", err))
}
if resp.Code != 382 {
return classifyNNTPError(resp.Code, fmt.Sprintf("STARTTLS not supported: %s", resp.Message))
}
// Upgrade connection to TLS
tlsConn := tls.Client(c.conn, &tls.Config{
ServerName: c.address,
InsecureSkipVerify: false,
})
c.conn = tlsConn
c.reader = bufio.NewReader(tlsConn)
c.writer = bufio.NewWriter(tlsConn)
c.text = textproto.NewConn(tlsConn)
c.logger.Debug().Msg("TLS encryption enabled")
return nil
}
// ping sends a simple command to test the connection
func (c *Connection) ping() error {
if err := c.sendCommand("DATE"); err != nil {
return NewConnectionError(err)
}
_, err := c.readResponse()
if err != nil {
return NewConnectionError(err)
}
return nil
}
// sendCommand sends a command to the NNTP server
func (c *Connection) sendCommand(command string) error {
_, err := fmt.Fprintf(c.writer, "%s\r\n", command)
if err != nil {
return err
}
return c.writer.Flush()
}
// readResponse reads a response from the NNTP server
func (c *Connection) readResponse() (*Response, error) {
line, err := c.text.ReadLine()
if err != nil {
return nil, err
}
parts := strings.SplitN(line, " ", 2)
code, err := strconv.Atoi(parts[0])
if err != nil {
return nil, fmt.Errorf("invalid response code: %s", parts[0])
}
message := ""
if len(parts) > 1 {
message = parts[1]
}
return &Response{
Code: code,
Message: message,
}, nil
}
// readMultilineResponse reads a multiline response
func (c *Connection) readMultilineResponse() (*Response, error) {
resp, err := c.readResponse()
if err != nil {
return nil, err
}
// Check if this is a multiline response
if resp.Code < 200 || resp.Code >= 300 {
return resp, nil
}
lines, err := c.text.ReadDotLines()
if err != nil {
return nil, err
}
resp.Lines = lines
return resp, nil
}
// GetArticle retrieves an article by message ID with proper error classification
func (c *Connection) GetArticle(messageID string) (*Article, error) {
messageID = FormatMessageID(messageID)
if err := c.sendCommand(fmt.Sprintf("ARTICLE %s", messageID)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send ARTICLE command: %w", err))
}
resp, err := c.readMultilineResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read article response: %w", err))
}
if resp.Code != 220 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
return c.parseArticle(messageID, resp.Lines)
}
// GetBody retrieves article body by message ID with proper error classification
func (c *Connection) GetBody(messageID string) ([]byte, error) {
messageID = FormatMessageID(messageID)
if err := c.sendCommand(fmt.Sprintf("BODY %s", messageID)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send BODY command: %w", err))
}
// Read the initial response
resp, err := c.readResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read body response: %w", err))
}
if resp.Code != 222 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
// Read the raw body data directly using textproto to preserve exact formatting for yEnc
lines, err := c.text.ReadDotLines()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read body data: %w", err))
}
// Join with \r\n to preserve original line endings and add final \r\n
body := strings.Join(lines, "\r\n")
if len(lines) > 0 {
body += "\r\n"
}
return []byte(body), nil
}
// GetHead retrieves article headers by message ID
func (c *Connection) GetHead(messageID string) ([]byte, error) {
messageID = FormatMessageID(messageID)
if err := c.sendCommand(fmt.Sprintf("HEAD %s", messageID)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send HEAD command: %w", err))
}
// Read the initial response
resp, err := c.readResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read head response: %w", err))
}
if resp.Code != 221 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
// Read the header data using textproto
lines, err := c.text.ReadDotLines()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read header data: %w", err))
}
// Join with \r\n to preserve original line endings and add final \r\n
headers := strings.Join(lines, "\r\n")
if len(lines) > 0 {
headers += "\r\n"
}
return []byte(headers), nil
}
// GetSegment retrieves a specific segment with proper error handling
func (c *Connection) GetSegment(messageID string, segmentNumber int) (*Segment, error) {
messageID = FormatMessageID(messageID)
body, err := c.GetBody(messageID)
if err != nil {
return nil, err // GetBody already returns classified errors
}
return &Segment{
MessageID: messageID,
Number: segmentNumber,
Bytes: int64(len(body)),
Data: body,
}, nil
}
// Stat retrieves article statistics by message ID with proper error classification
func (c *Connection) Stat(messageID string) (articleNumber int, echoedID string, err error) {
messageID = FormatMessageID(messageID)
if err = c.sendCommand(fmt.Sprintf("STAT %s", messageID)); err != nil {
return 0, "", NewConnectionError(fmt.Errorf("failed to send STAT: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return 0, "", NewConnectionError(fmt.Errorf("failed to read STAT response: %w", err))
}
if resp.Code != 223 {
return 0, "", classifyNNTPError(resp.Code, resp.Message)
}
fields := strings.Fields(resp.Message)
if len(fields) < 2 {
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("unexpected STAT response format: %q", resp.Message))
}
if articleNumber, err = strconv.Atoi(fields[0]); err != nil {
return 0, "", NewProtocolError(resp.Code, fmt.Sprintf("invalid article number %q: %v", fields[0], err))
}
echoedID = fields[1]
return articleNumber, echoedID, nil
}
// SelectGroup selects a newsgroup and returns group information
func (c *Connection) SelectGroup(groupName string) (*GroupInfo, error) {
if err := c.sendCommand(fmt.Sprintf("GROUP %s", groupName)); err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to send GROUP command: %w", err))
}
resp, err := c.readResponse()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to read GROUP response: %w", err))
}
if resp.Code != 211 {
return nil, classifyNNTPError(resp.Code, resp.Message)
}
// Parse GROUP response: "211 number low high group-name"
fields := strings.Fields(resp.Message)
if len(fields) < 4 {
return nil, NewProtocolError(resp.Code, fmt.Sprintf("unexpected GROUP response format: %q", resp.Message))
}
groupInfo := &GroupInfo{
Name: groupName,
}
if count, err := strconv.Atoi(fields[0]); err == nil {
groupInfo.Count = count
}
if low, err := strconv.Atoi(fields[1]); err == nil {
groupInfo.Low = low
}
if high, err := strconv.Atoi(fields[2]); err == nil {
groupInfo.High = high
}
return groupInfo, nil
}
// parseArticle parses article data from response lines
func (c *Connection) parseArticle(messageID string, lines []string) (*Article, error) {
article := &Article{
MessageID: messageID,
Groups: []string{},
}
headerEnd := -1
for i, line := range lines {
if line == "" {
headerEnd = i
break
}
// Parse headers
if strings.HasPrefix(line, "Subject: ") {
article.Subject = strings.TrimPrefix(line, "Subject: ")
} else if strings.HasPrefix(line, "From: ") {
article.From = strings.TrimPrefix(line, "From: ")
} else if strings.HasPrefix(line, "Date: ") {
article.Date = strings.TrimPrefix(line, "Date: ")
} else if strings.HasPrefix(line, "Newsgroups: ") {
groups := strings.TrimPrefix(line, "Newsgroups: ")
article.Groups = strings.Split(groups, ",")
for i := range article.Groups {
article.Groups[i] = strings.TrimSpace(article.Groups[i])
}
}
}
// Join body lines
if headerEnd != -1 && headerEnd+1 < len(lines) {
body := strings.Join(lines[headerEnd+1:], "\n")
article.Body = []byte(body)
article.Size = int64(len(article.Body))
}
return article, nil
}
// close closes the NNTP connection
func (c *Connection) close() error {
if c.conn != nil {
return c.conn.Close()
}
return nil
}
func DecodeYenc(reader io.Reader) (*yenc.Part, error) {
part, err := yenc.Decode(reader)
if err != nil {
return nil, NewYencDecodeError(fmt.Errorf("failed to create yenc decoder: %w", err))
}
return part, nil
}
func IsValidMessageID(messageID string) bool {
if len(messageID) < 3 {
return false
}
return strings.Contains(messageID, "@")
}
// FormatMessageID ensures message ID has proper format
func FormatMessageID(messageID string) string {
messageID = strings.TrimSpace(messageID)
if !strings.HasPrefix(messageID, "<") {
messageID = "<" + messageID
}
if !strings.HasSuffix(messageID, ">") {
messageID = messageID + ">"
}
return messageID
}
+116
View File
@@ -0,0 +1,116 @@
package nntp
import (
"bufio"
"fmt"
"io"
"strconv"
"strings"
)
// YencMetadata contains just the header information
type YencMetadata struct {
Name string // filename
Size int64 // total file size
Part int // part number
Total int // total parts
Begin int64 // part start byte
End int64 // part end byte
LineSize int // line length
}
// DecodeYencHeaders extracts only yenc header metadata without decoding body
func DecodeYencHeaders(reader io.Reader) (*YencMetadata, error) {
buf := bufio.NewReader(reader)
metadata := &YencMetadata{}
// Find and parse =ybegin header
if err := parseYBeginHeader(buf, metadata); err != nil {
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ybegin header: %w", err))
}
// Parse =ypart header if this is a multipart file
if metadata.Part > 0 {
if err := parseYPartHeader(buf, metadata); err != nil {
return nil, NewYencDecodeError(fmt.Errorf("failed to parse ypart header: %w", err))
}
}
return metadata, nil
}
func parseYBeginHeader(buf *bufio.Reader, metadata *YencMetadata) error {
var s string
var err error
// Find the =ybegin line
for {
s, err = buf.ReadString('\n')
if err != nil {
return err
}
if len(s) >= 7 && s[:7] == "=ybegin" {
break
}
}
// Parse the header line
parts := strings.SplitN(s[7:], "name=", 2)
if len(parts) > 1 {
metadata.Name = strings.TrimSpace(parts[1])
}
// Parse other parameters
for _, header := range strings.Split(parts[0], " ") {
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
if len(kv) < 2 {
continue
}
switch kv[0] {
case "size":
metadata.Size, _ = strconv.ParseInt(kv[1], 10, 64)
case "line":
metadata.LineSize, _ = strconv.Atoi(kv[1])
case "part":
metadata.Part, _ = strconv.Atoi(kv[1])
case "total":
metadata.Total, _ = strconv.Atoi(kv[1])
}
}
return nil
}
func parseYPartHeader(buf *bufio.Reader, metadata *YencMetadata) error {
var s string
var err error
// Find the =ypart line
for {
s, err = buf.ReadString('\n')
if err != nil {
return err
}
if len(s) >= 6 && s[:6] == "=ypart" {
break
}
}
// Parse part parameters
for _, header := range strings.Split(s[6:], " ") {
kv := strings.SplitN(strings.TrimSpace(header), "=", 2)
if len(kv) < 2 {
continue
}
switch kv[0] {
case "begin":
metadata.Begin, _ = strconv.ParseInt(kv[1], 10, 64)
case "end":
metadata.End, _ = strconv.ParseInt(kv[1], 10, 64)
}
}
return nil
}
+195
View File
@@ -0,0 +1,195 @@
package nntp
import (
"errors"
"fmt"
)
// Error types for NNTP operations
type ErrorType int
const (
ErrorTypeUnknown ErrorType = iota
ErrorTypeConnection
ErrorTypeAuthentication
ErrorTypeTimeout
ErrorTypeArticleNotFound
ErrorTypeGroupNotFound
ErrorTypePermissionDenied
ErrorTypeServerBusy
ErrorTypeInvalidCommand
ErrorTypeProtocol
ErrorTypeYencDecode
ErrorTypeNoAvailableConnection
)
// Error represents an NNTP-specific error
type Error struct {
Type ErrorType
Code int // NNTP response code
Message string // Error message
Err error // Underlying error
}
// Predefined errors for common cases
var (
ErrArticleNotFound = &Error{Type: ErrorTypeArticleNotFound, Code: 430, Message: "article not found"}
ErrGroupNotFound = &Error{Type: ErrorTypeGroupNotFound, Code: 411, Message: "group not found"}
ErrPermissionDenied = &Error{Type: ErrorTypePermissionDenied, Code: 502, Message: "permission denied"}
ErrAuthenticationFail = &Error{Type: ErrorTypeAuthentication, Code: 482, Message: "authentication failed"}
ErrServerBusy = &Error{Type: ErrorTypeServerBusy, Code: 400, Message: "server busy"}
ErrPoolNotFound = &Error{Type: ErrorTypeUnknown, Code: 0, Message: "NNTP pool not found", Err: nil}
ErrNoAvailableConnection = &Error{Type: ErrorTypeNoAvailableConnection, Code: 0, Message: "no available connection in pool", Err: nil}
)
func (e *Error) Error() string {
if e.Err != nil {
return fmt.Sprintf("NNTP %s (code %d): %s - %v", e.Type.String(), e.Code, e.Message, e.Err)
}
return fmt.Sprintf("NNTP %s (code %d): %s", e.Type.String(), e.Code, e.Message)
}
func (e *Error) Unwrap() error {
return e.Err
}
func (e *Error) Is(target error) bool {
if t, ok := target.(*Error); ok {
return e.Type == t.Type
}
return false
}
// IsRetryable returns true if the error might be resolved by retrying
func (e *Error) IsRetryable() bool {
switch e.Type {
case ErrorTypeConnection, ErrorTypeTimeout, ErrorTypeServerBusy:
return true
case ErrorTypeArticleNotFound, ErrorTypeGroupNotFound, ErrorTypePermissionDenied, ErrorTypeAuthentication:
return false
default:
return false
}
}
// ShouldStopParsing returns true if this error should stop the entire parsing process
func (e *Error) ShouldStopParsing() bool {
switch e.Type {
case ErrorTypeAuthentication, ErrorTypePermissionDenied:
return true // Critical auth issues
case ErrorTypeConnection:
return false // Can continue with other connections
case ErrorTypeArticleNotFound:
return false // Can continue searching for other articles
case ErrorTypeServerBusy:
return false // Temporary issue
default:
return false
}
}
func (et ErrorType) String() string {
switch et {
case ErrorTypeConnection:
return "CONNECTION"
case ErrorTypeAuthentication:
return "AUTHENTICATION"
case ErrorTypeTimeout:
return "TIMEOUT"
case ErrorTypeArticleNotFound:
return "ARTICLE_NOT_FOUND"
case ErrorTypeGroupNotFound:
return "GROUP_NOT_FOUND"
case ErrorTypePermissionDenied:
return "PERMISSION_DENIED"
case ErrorTypeServerBusy:
return "SERVER_BUSY"
case ErrorTypeInvalidCommand:
return "INVALID_COMMAND"
case ErrorTypeProtocol:
return "PROTOCOL"
case ErrorTypeYencDecode:
return "YENC_DECODE"
default:
return "UNKNOWN"
}
}
// Helper functions to create specific errors
func NewConnectionError(err error) *Error {
return &Error{
Type: ErrorTypeConnection,
Message: "connection failed",
Err: err,
}
}
func NewTimeoutError(err error) *Error {
return &Error{
Type: ErrorTypeTimeout,
Message: "operation timed out",
Err: err,
}
}
func NewProtocolError(code int, message string) *Error {
return &Error{
Type: ErrorTypeProtocol,
Code: code,
Message: message,
}
}
func NewYencDecodeError(err error) *Error {
return &Error{
Type: ErrorTypeYencDecode,
Message: "yEnc decode failed",
Err: err,
}
}
// classifyNNTPError classifies an NNTP response code into an error type
func classifyNNTPError(code int, message string) *Error {
switch {
case code == 430 || code == 423:
return &Error{Type: ErrorTypeArticleNotFound, Code: code, Message: message}
case code == 411:
return &Error{Type: ErrorTypeGroupNotFound, Code: code, Message: message}
case code == 502 || code == 503:
return &Error{Type: ErrorTypePermissionDenied, Code: code, Message: message}
case code == 481 || code == 482:
return &Error{Type: ErrorTypeAuthentication, Code: code, Message: message}
case code == 400:
return &Error{Type: ErrorTypeServerBusy, Code: code, Message: message}
case code == 500 || code == 501:
return &Error{Type: ErrorTypeInvalidCommand, Code: code, Message: message}
case code >= 400:
return &Error{Type: ErrorTypeProtocol, Code: code, Message: message}
default:
return &Error{Type: ErrorTypeUnknown, Code: code, Message: message}
}
}
func IsArticleNotFoundError(err error) bool {
var nntpErr *Error
if errors.As(err, &nntpErr) {
return nntpErr.Type == ErrorTypeArticleNotFound
}
return false
}
func IsAuthenticationError(err error) bool {
var nntpErr *Error
if errors.As(err, &nntpErr) {
return nntpErr.Type == ErrorTypeAuthentication
}
return false
}
func IsRetryableError(err error) bool {
var nntpErr *Error
if errors.As(err, &nntpErr) {
return nntpErr.IsRetryable()
}
return false
}
+299
View File
@@ -0,0 +1,299 @@
package nntp
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"github.com/rs/zerolog"
"github.com/sirrobot01/decypharr/internal/config"
"net"
"net/textproto"
"sync"
"sync/atomic"
"time"
)
// Pool manages a pool of NNTP connections
type Pool struct {
address, username, password string
maxConns, port int
ssl bool
useTLS bool
connections chan *Connection
logger zerolog.Logger
closed atomic.Bool
totalConnections atomic.Int32
activeConnections atomic.Int32
}
// Segment represents a usenet segment
type Segment struct {
MessageID string
Number int
Bytes int64
Data []byte
}
// Article represents a complete usenet article
type Article struct {
MessageID string
Subject string
From string
Date string
Groups []string
Body []byte
Size int64
}
// Response represents an NNTP server response
type Response struct {
Code int
Message string
Lines []string
}
// GroupInfo represents information about a newsgroup
type GroupInfo struct {
Name string
Count int // Number of articles in the group
Low int // Lowest article number
High int // Highest article number
}
// NewPool creates a new NNTP connection pool
func NewPool(provider config.UsenetProvider, logger zerolog.Logger) (*Pool, error) {
maxConns := provider.Connections
if maxConns <= 0 {
maxConns = 1
}
pool := &Pool{
address: provider.Host,
username: provider.Username,
password: provider.Password,
port: provider.Port,
maxConns: maxConns,
ssl: provider.SSL,
useTLS: provider.UseTLS,
connections: make(chan *Connection, maxConns),
logger: logger,
}
return pool.initializeConnections()
}
func (p *Pool) initializeConnections() (*Pool, error) {
var wg sync.WaitGroup
var mu sync.Mutex
var successfulConnections []*Connection
var errs []error
// Create connections concurrently
for i := 0; i < p.maxConns; i++ {
wg.Add(1)
go func(connIndex int) {
defer wg.Done()
conn, err := p.createConnection()
mu.Lock()
defer mu.Unlock()
if err != nil {
errs = append(errs, err)
} else {
successfulConnections = append(successfulConnections, conn)
}
}(i)
}
// Wait for all connection attempts to complete
wg.Wait()
// Add successful connections to the pool
for _, conn := range successfulConnections {
p.connections <- conn
}
p.totalConnections.Store(int32(len(successfulConnections)))
if len(successfulConnections) == 0 {
return nil, fmt.Errorf("failed to create any connections: %v", errs)
}
// Log results
p.logger.Info().
Str("server", p.address).
Int("port", p.port).
Int("requested_connections", p.maxConns).
Int("successful_connections", len(successfulConnections)).
Int("failed_connections", len(errs)).
Msg("NNTP connection pool created")
// If some connections failed, log a warning but continue
if len(errs) > 0 {
p.logger.Warn().
Int("failed_count", len(errs)).
Msg("Some connections failed during pool initialization")
}
return p, nil
}
// Get retrieves a connection from the pool
func (p *Pool) Get(ctx context.Context) (*Connection, error) {
if p.closed.Load() {
return nil, NewConnectionError(fmt.Errorf("connection pool is closed"))
}
select {
case conn := <-p.connections:
if conn == nil {
return nil, NewConnectionError(fmt.Errorf("received nil connection from pool"))
}
p.activeConnections.Add(1)
if err := conn.ping(); err != nil {
p.activeConnections.Add(-1)
err := conn.close()
if err != nil {
return nil, err
}
// Create a new connection
newConn, err := p.createConnection()
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to create replacement connection: %w", err))
}
p.activeConnections.Add(1)
return newConn, nil
}
return conn, nil
case <-ctx.Done():
return nil, NewTimeoutError(ctx.Err())
}
}
// Put returns a connection to the pool
func (p *Pool) Put(conn *Connection) {
if conn == nil {
return
}
defer p.activeConnections.Add(-1)
if p.closed.Load() {
conn.close()
return
}
// Try non-blocking first
select {
case p.connections <- conn:
return
default:
}
// If pool is full, this usually means we have too many connections
// Force return by making space (close oldest connection)
select {
case oldConn := <-p.connections:
oldConn.close() // Close the old connection
p.connections <- conn // Put the new one back
case <-time.After(1 * time.Second):
// Still can't return - close this connection
conn.close()
}
}
// Close closes all connections in the pool
func (p *Pool) Close() error {
if p.closed.Load() {
return nil
}
p.closed.Store(true)
close(p.connections)
for conn := range p.connections {
err := conn.close()
if err != nil {
return err
}
}
p.logger.Info().Msg("NNTP connection pool closed")
return nil
}
// createConnection creates a new NNTP connection with proper error handling
func (p *Pool) createConnection() (*Connection, error) {
addr := fmt.Sprintf("%s:%d", p.address, p.port)
var conn net.Conn
var err error
if p.ssl {
conn, err = tls.DialWithDialer(&net.Dialer{}, "tcp", addr, &tls.Config{
InsecureSkipVerify: false,
})
} else {
conn, err = net.Dial("tcp", addr)
}
if err != nil {
return nil, NewConnectionError(fmt.Errorf("failed to connect to %s: %w", addr, err))
}
reader := bufio.NewReaderSize(conn, 256*1024) // 256KB buffer for better performance
writer := bufio.NewWriterSize(conn, 256*1024) // 256KB buffer for better performance
text := textproto.NewConn(conn)
nntpConn := &Connection{
username: p.username,
password: p.password,
address: p.address,
port: p.port,
conn: conn,
text: text,
reader: reader,
writer: writer,
logger: p.logger,
}
// Read welcome message
_, err = nntpConn.readResponse()
if err != nil {
conn.Close()
return nil, NewConnectionError(fmt.Errorf("failed to read welcome message: %w", err))
}
// Authenticate if credentials are provided
if p.username != "" && p.password != "" {
if err := nntpConn.authenticate(); err != nil {
conn.Close()
return nil, err // authenticate() already returns NNTPError
}
}
// Enable TLS if requested (STARTTLS)
if p.useTLS && !p.ssl {
if err := nntpConn.startTLS(); err != nil {
conn.Close()
return nil, err // startTLS() already returns NNTPError
}
}
return nntpConn, nil
}
func (p *Pool) ConnectionCount() int {
return int(p.totalConnections.Load())
}
func (p *Pool) ActiveConnections() int {
return int(p.activeConnections.Load())
}
func (p *Pool) IsFree() bool {
return p.ActiveConnections() < p.maxConns
}