- Improve propfind handler

- remove path escapes in fileinfo
- other minor fixes
This commit is contained in:
Mukhtar Akere
2025-05-12 03:35:40 +01:00
parent ffb1745bf6
commit 9de7cfd73b
11 changed files with 160 additions and 88 deletions

View File

@@ -1,7 +1,6 @@
package webdav
import (
"github.com/sirrobot01/decypharr/internal/utils"
"os"
"time"
)
@@ -15,7 +14,7 @@ type FileInfo struct {
isDir bool
}
func (fi *FileInfo) Name() string { return utils.EscapePath(fi.name) } // uses minimal escaping
func (fi *FileInfo) Name() string { return fi.name } // uses minimal escaping
func (fi *FileInfo) Size() int64 { return fi.size }
func (fi *FileInfo) Mode() os.FileMode { return fi.mode }
func (fi *FileInfo) ModTime() time.Time { return fi.modTime }

View File

@@ -119,7 +119,7 @@ func (h *Handler) getChildren(name string) []os.FileInfo {
if name[0] != '/' {
name = "/" + name
}
name = utils.UnescapePath(path.Clean(name))
name = utils.PathUnescape(path.Clean(name))
root := path.Clean(h.getRootPath())
// toplevel “parents” (e.g. __all__, torrents)
@@ -133,9 +133,8 @@ func (h *Handler) getChildren(name string) []os.FileInfo {
// torrent-folder level (e.g. /root/parentFolder/torrentName)
rel := strings.TrimPrefix(name, root+string(os.PathSeparator))
parts := strings.Split(rel, string(os.PathSeparator))
parent, _ := url.PathUnescape(parts[0])
if len(parts) == 2 && utils.Contains(h.getParentItems(), parent) {
torrentName := utils.UnescapePath(parts[1])
if len(parts) == 2 && utils.Contains(h.getParentItems(), parts[0]) {
torrentName := parts[1]
if t := h.cache.GetTorrentByName(torrentName); t != nil {
return h.getFileInfos(t.Torrent)
}
@@ -147,7 +146,7 @@ func (h *Handler) OpenFile(ctx context.Context, name string, flag int, perm os.F
if !strings.HasPrefix(name, "/") {
name = "/" + name
}
name = utils.UnescapePath(path.Clean(name))
name = utils.PathUnescape(path.Clean(name))
rootDir := path.Clean(h.getRootPath())
metadataOnly := ctx.Value("metadataOnly") != nil
now := time.Now()
@@ -188,9 +187,8 @@ func (h *Handler) OpenFile(ctx context.Context, name string, flag int, perm os.F
rel := strings.TrimPrefix(name, rootDir+string(os.PathSeparator))
parts := strings.Split(rel, string(os.PathSeparator))
if len(parts) >= 2 {
parent, _ := url.PathUnescape(parts[0])
if utils.Contains(h.getParentItems(), parent) {
torrentName := utils.UnescapePath(parts[1])
if utils.Contains(h.getParentItems(), parts[0]) {
torrentName := parts[1]
cached := h.cache.GetTorrentByName(torrentName)
if cached != nil && len(parts) >= 3 {
filename := filepath.Join(parts[2:]...)

View File

@@ -38,3 +38,31 @@ func (h *Handler) getCacheTTL(urlPath string) time.Duration {
}
return 2 * time.Minute // Longer TTL for other paths
}
var pctHex = "0123456789ABCDEF"
// fastEscapePath returns a percent-encoded path, preserving '/'
// and only encoding bytes outside the unreserved set:
//
// ALPHA / DIGIT / '-' / '_' / '.' / '~' / '/'
func fastEscapePath(p string) string {
var b strings.Builder
for i := 0; i < len(p); i++ {
c := p[i]
// unreserved (plus '/')
if (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '_' ||
c == '.' || c == '~' ||
c == '/' {
b.WriteByte(c)
} else {
b.WriteByte('%')
b.WriteByte(pctHex[c>>4])
b.WriteByte(pctHex[c&0xF])
}
}
return b.String()
}

View File

@@ -2,14 +2,20 @@ package webdav
import (
"context"
"fmt"
"github.com/stanNthe5/stringbuf"
"net/http"
"net/url"
"os"
"path"
"strconv"
"strings"
"sync"
"time"
)
var builderPool = sync.Pool{
New: func() interface{} { return stringbuf.New("") },
}
func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request) {
// Setup context for metadata only
ctx := context.WithValue(r.Context(), "metadataOnly", true)
@@ -25,8 +31,12 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request) {
// Build the list of entries
type entry struct {
href string
fi os.FileInfo
href string
escHref string // already XML-safe + percent-escaped
escName string
size int64
isDir bool
modTime string
}
// Always include the resource itself
@@ -45,65 +55,81 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request) {
return
}
// Collect children if a directory and depth allows
children := make([]os.FileInfo, 0)
if fi.IsDir() && depth != "0" {
children = h.getChildren(cleanPath)
var rawEntries []os.FileInfo
if fi.IsDir() {
rawEntries = append(rawEntries, h.getChildren(cleanPath)...)
}
entries := make([]entry, 0, 1+len(children))
entries = append(entries, entry{href: cleanPath, fi: fi})
now := time.Now().UTC().Format("2006-01-02T15:04:05.000-07:00")
entries := make([]entry, 0, len(rawEntries)+1)
// Add the current file itself
entries = append(entries, entry{
escHref: xmlEscape(fastEscapePath(cleanPath)),
escName: xmlEscape(fi.Name()),
isDir: fi.IsDir(),
size: fi.Size(),
modTime: fi.ModTime().Format("2006-01-02T15:04:05.000-07:00"),
})
for _, info := range rawEntries {
for _, child := range children {
childHref := path.Join("/", cleanPath, child.Name())
if child.IsDir() {
childHref += "/"
nm := info.Name()
// build raw href
href := path.Join("/", cleanPath, nm)
if info.IsDir() {
href += "/"
}
entries = append(entries, entry{href: childHref, fi: child})
entries = append(entries, entry{
escHref: xmlEscape(fastEscapePath(href)),
escName: xmlEscape(nm),
isDir: info.IsDir(),
size: info.Size(),
modTime: info.ModTime().Format("2006-01-02T15:04:05.000-07:00"),
})
}
// Use a string builder for creating XML
var sb strings.Builder
sb := builderPool.Get().(stringbuf.StringBuf)
sb.Reset()
defer builderPool.Put(sb)
// XML header and main element
sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
sb.WriteString(`<d:multistatus xmlns:d="DAV:">`)
// Format time once
timeFormat := "2006-01-02T15:04:05.000-07:00"
_, _ = sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
_, _ = sb.WriteString(`<d:multistatus xmlns:d="DAV:">`)
// Add responses for each entry
for _, e := range entries {
// Format href path properly
u := &url.URL{Path: e.href}
escaped := u.EscapedPath()
_, _ = sb.WriteString(`<d:response>`)
_, _ = sb.WriteString(`<d:href>`)
_, _ = sb.WriteString(e.escHref)
_, _ = sb.WriteString(`</d:href>`)
_, _ = sb.WriteString(`<d:propstat>`)
_, _ = sb.WriteString(`<d:prop>`)
sb.WriteString(`<d:response>`)
sb.WriteString(fmt.Sprintf(`<d:href>%s</d:href>`, xmlEscape(escaped)))
sb.WriteString(`<d:propstat>`)
sb.WriteString(`<d:prop>`)
// Resource type differs based on directory vs file
if e.fi.IsDir() {
sb.WriteString(`<d:resourcetype><d:collection/></d:resourcetype>`)
if e.isDir {
_, _ = sb.WriteString(`<d:resourcetype><d:collection/></d:resourcetype>`)
} else {
sb.WriteString(`<d:resourcetype/>`)
sb.WriteString(fmt.Sprintf(`<d:getcontentlength>%d</d:getcontentlength>`, e.fi.Size()))
_, _ = sb.WriteString(`<d:resourcetype/>`)
_, _ = sb.WriteString(`<d:getcontentlength>`)
_, _ = sb.WriteString(strconv.FormatInt(e.size, 10))
_, _ = sb.WriteString(`</d:getcontentlength>`)
}
// Always add lastmodified and displayname
lastModified := e.fi.ModTime().Format(timeFormat)
sb.WriteString(fmt.Sprintf(`<d:getlastmodified>%s</d:getlastmodified>`, xmlEscape(lastModified)))
sb.WriteString(fmt.Sprintf(`<d:displayname>%s</d:displayname>`, xmlEscape(e.fi.Name())))
_, _ = sb.WriteString(`<d:getlastmodified>`)
_, _ = sb.WriteString(now)
_, _ = sb.WriteString(`</d:getlastmodified>`)
sb.WriteString(`</d:prop>`)
sb.WriteString(`<d:status>HTTP/1.1 200 OK</d:status>`)
sb.WriteString(`</d:propstat>`)
sb.WriteString(`</d:response>`)
_, _ = sb.WriteString(`<d:displayname>`)
_, _ = sb.WriteString(e.escName)
_, _ = sb.WriteString(`</d:displayname>`)
_, _ = sb.WriteString(`</d:prop>`)
_, _ = sb.WriteString(`<d:status>HTTP/1.1 200 OK</d:status>`)
_, _ = sb.WriteString(`</d:propstat>`)
_, _ = sb.WriteString(`</d:response>`)
}
// Close root element
sb.WriteString(`</d:multistatus>`)
_, _ = sb.WriteString(`</d:multistatus>`)
// Set headers
w.Header().Set("Content-Type", "application/xml; charset=utf-8")
@@ -111,18 +137,28 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request) {
// Set status code and write response
w.WriteHeader(http.StatusMultiStatus) // 207 MultiStatus
_, _ = w.Write([]byte(sb.String()))
_, _ = w.Write(sb.Bytes())
}
// Basic XML escaping function
func xmlEscape(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
s = strings.ReplaceAll(s, "'", "&apos;")
s = strings.ReplaceAll(s, "\"", "&quot;")
s = strings.ReplaceAll(s, "\n", "&#10;")
s = strings.ReplaceAll(s, "\r", "&#13;")
s = strings.ReplaceAll(s, "\t", "&#9;")
return s
var b strings.Builder
b.Grow(len(s))
for _, r := range s {
switch r {
case '&':
b.WriteString("&amp;")
case '<':
b.WriteString("&lt;")
case '>':
b.WriteString("&gt;")
case '"':
b.WriteString("&quot;")
case '\'':
b.WriteString("&apos;")
default:
b.WriteRune(r)
}
}
return b.String()
}