diff --git a/internal/conf/const.go b/internal/conf/const.go index b99d8849c..46cf8d677 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -191,4 +191,6 @@ const ( PathKey SharingIDKey SkipHookKey + VirtualHostKey + VhostPrefixKey ) diff --git a/internal/db/db.go b/internal/db/db.go index 96529c15d..34e5fe25e 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -12,7 +12,7 @@ var db *gorm.DB func Init(d *gorm.DB) { db = d - err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.SharingDB)) + err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.SharingDB), new(model.VirtualHost)) if err != nil { log.Fatalf("failed migrate database: %s", err.Error()) } diff --git a/internal/db/virtual_host.go b/internal/db/virtual_host.go new file mode 100644 index 000000000..0ae3645c9 --- /dev/null +++ b/internal/db/virtual_host.go @@ -0,0 +1,45 @@ +package db + +import ( + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/pkg/errors" +) + +func GetVirtualHostByDomain(domain string) (*model.VirtualHost, error) { + var v model.VirtualHost + if err := db.Where("domain = ?", domain).First(&v).Error; err != nil { + return nil, errors.Wrapf(err, "failed to select virtual host") + } + return &v, nil +} + +func GetVirtualHostById(id uint) (*model.VirtualHost, error) { + var v model.VirtualHost + if err := db.First(&v, id).Error; err != nil { + return nil, errors.Wrapf(err, "failed get virtual host") + } + return &v, nil +} + +func CreateVirtualHost(v *model.VirtualHost) error { + return errors.WithStack(db.Create(v).Error) +} + +func UpdateVirtualHost(v *model.VirtualHost) error { + return errors.WithStack(db.Save(v).Error) +} + +func GetVirtualHosts(pageIndex, pageSize int) (vhosts []model.VirtualHost, count int64, err error) { + vhostDB := db.Model(&model.VirtualHost{}) + if err = vhostDB.Count(&count).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get virtual hosts count") + } + if err = vhostDB.Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&vhosts).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed find virtual hosts") + } + return vhosts, count, nil +} + +func DeleteVirtualHostById(id uint) error { + return errors.WithStack(db.Delete(&model.VirtualHost{}, id).Error) +} diff --git a/internal/model/virtual_host.go b/internal/model/virtual_host.go new file mode 100644 index 000000000..23bf5dab9 --- /dev/null +++ b/internal/model/virtual_host.go @@ -0,0 +1,9 @@ +package model + +type VirtualHost struct { + ID uint `json:"id" gorm:"primaryKey"` + Enabled bool `json:"enabled"` + Domain string `json:"domain" gorm:"unique" binding:"required"` + Path string `json:"path" binding:"required"` + WebHosting bool `json:"web_hosting"` +} diff --git a/internal/op/virtual_host.go b/internal/op/virtual_host.go new file mode 100644 index 000000000..345ab4554 --- /dev/null +++ b/internal/op/virtual_host.go @@ -0,0 +1,75 @@ +package op + +import ( + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/db" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/OpenListTeam/go-cache" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +var vhostCache = cache.NewMemCache(cache.WithShards[*model.VirtualHost](2)) + +// GetVirtualHostByDomain 根据域名获取虚拟主机配置(带缓存) +func GetVirtualHostByDomain(domain string) (*model.VirtualHost, error) { + if v, ok := vhostCache.Get(domain); ok { + if v == nil { + utils.Log.Debugf("[VirtualHost] cache hit (nil) for domain=%q", domain) + return nil, errors.New("virtual host not found") + } + utils.Log.Debugf("[VirtualHost] cache hit for domain=%q id=%d", domain, v.ID) + return v, nil + } + utils.Log.Debugf("[VirtualHost] cache miss for domain=%q, querying db...", domain) + v, err := db.GetVirtualHostByDomain(domain) + if err != nil { + if errors.Is(errors.Cause(err), gorm.ErrRecordNotFound) { + utils.Log.Debugf("[VirtualHost] domain=%q not found in db, caching nil", domain) + vhostCache.Set(domain, nil, cache.WithEx[*model.VirtualHost](time.Minute*5)) + return nil, errors.New("virtual host not found") + } + utils.Log.Errorf("[VirtualHost] db error for domain=%q: %v", domain, err) + return nil, err + } + utils.Log.Debugf("[VirtualHost] db found domain=%q id=%d enabled=%v web_hosting=%v", domain, v.ID, v.Enabled, v.WebHosting) + vhostCache.Set(domain, v, cache.WithEx[*model.VirtualHost](time.Hour)) + return v, nil +} + +func GetVirtualHostById(id uint) (*model.VirtualHost, error) { + return db.GetVirtualHostById(id) +} + +func CreateVirtualHost(v *model.VirtualHost) error { + v.Path = utils.FixAndCleanPath(v.Path) + vhostCache.Del(v.Domain) + return db.CreateVirtualHost(v) +} + +func UpdateVirtualHost(v *model.VirtualHost) error { + v.Path = utils.FixAndCleanPath(v.Path) + old, err := db.GetVirtualHostById(v.ID) + if err != nil { + return err + } + // 如果域名变更,清除旧域名缓存 + vhostCache.Del(old.Domain) + vhostCache.Del(v.Domain) + return db.UpdateVirtualHost(v) +} + +func DeleteVirtualHostById(id uint) error { + old, err := db.GetVirtualHostById(id) + if err != nil { + return err + } + vhostCache.Del(old.Domain) + return db.DeleteVirtualHostById(id) +} + +func GetVirtualHosts(pageIndex, pageSize int) ([]model.VirtualHost, int64, error) { + return db.GetVirtualHosts(pageIndex, pageSize) +} diff --git a/server/handles/fsread.go b/server/handles/fsread.go index a90fc1082..78afc2853 100644 --- a/server/handles/fsread.go +++ b/server/handles/fsread.go @@ -2,6 +2,7 @@ package handles import ( "fmt" + "net" stdpath "path" "strings" "time" @@ -69,6 +70,8 @@ func FsListSplit(c *gin.Context) { SharingList(c, &req) return } + // 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径 + req.Path = applyVhostPathMapping(c, req.Path) user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() && user.Disabled { common.ErrorStrResp(c, "Guest user is disabled, login please", 401) @@ -272,6 +275,11 @@ func FsGetSplit(c *gin.Context) { SharingGet(c, &req) return } + // 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径 + // 同时将 vhost.Path 前缀存入 context,供 FsGet 生成 /p/ 链接时去掉前缀 + var vhostPrefix string + req.Path, vhostPrefix = applyVhostPathMappingWithPrefix(c, req.Path) + common.GinWithValue(c, conf.VhostPrefixKey, vhostPrefix) user := c.Request.Context().Value(conf.UserKey).(*model.User) if user.IsGuest() && user.Disabled { common.ErrorStrResp(c, "Guest user is disabled, login please", 401) @@ -319,12 +327,14 @@ func FsGet(c *gin.Context, req *FsGetReq, user *model.User) { rawURL = common.GenerateDownProxyURL(storage.GetStorage(), reqPath) if rawURL == "" { query := "" + // 生成 /p/ 链接时,去掉 vhost 路径前缀,保持前端看到的路径一致 + downPath := stripVhostPrefix(c, reqPath) if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { query = "?sign=" + sign.Sign(reqPath) } rawURL = fmt.Sprintf("%s/p%s%s", common.GetApiUrl(c), - utils.EncodePath(reqPath, true), + utils.EncodePath(downPath, true), query) } } else { @@ -427,3 +437,66 @@ func FsOther(c *gin.Context) { } common.SuccessResp(c, res) } + +// applyVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则,将请求路径映射到实际路径。 +func applyVhostPathMapping(c *gin.Context, reqPath string) string { + mapped, _ := applyVhostPathMappingWithPrefix(c, reqPath) + return mapped +} + +// applyVhostPathMappingWithPrefix 根据请求的 Host 头匹配虚拟主机规则, +// 将请求路径映射到虚拟主机配置的实际路径,同时返回 vhost.Path 前缀(用于生成下载链接时去掉前缀)。 +// 例如:vhost.Path="/123pan/Downloads",reqPath="/",则返回 ("/123pan/Downloads", "/123pan/Downloads") +// 例如:vhost.Path="/123pan/Downloads",reqPath="/subdir",则返回 ("/123pan/Downloads/subdir", "/123pan/Downloads") +// 如果没有匹配的虚拟主机规则,则返回 (原始路径, "") +func applyVhostPathMappingWithPrefix(c *gin.Context, reqPath string) (string, string) { + rawHost := c.Request.Host + domain := stripHostPortForVhost(rawHost) + if domain == "" { + return reqPath, "" + } + vhost, err := op.GetVirtualHostByDomain(domain) + if err != nil || vhost == nil { + return reqPath, "" + } + if !vhost.Enabled || vhost.WebHosting { + // 未启用,或者是 Web 托管模式(Web 托管不做路径重映射) + return reqPath, "" + } + // Map request path into the vhost root and verify it does not escape via traversal. + // stdpath.Join calls Clean internally, which collapses ".." segments, so we only need + // to confirm the result still lives under vhost.Path. + mapped := stdpath.Join(vhost.Path, reqPath) + if !strings.HasPrefix(mapped, strings.TrimRight(vhost.Path, "/")+"/") && mapped != vhost.Path { + utils.Log.Warnf("[VirtualHost] path traversal rejected for API remapping: domain=%q reqPath=%q", domain, reqPath) + return reqPath, "" + } + utils.Log.Debugf("[VirtualHost] API path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped) + return mapped, vhost.Path +} + +// stripVhostPrefix 从 gin context 中取出 vhost 路径前缀,并从 path 中去掉该前缀。 +// 用于生成 /p/ 下载链接时,将真实路径还原为前端看到的路径。 +func stripVhostPrefix(c *gin.Context, path string) string { + prefix, ok := c.Request.Context().Value(conf.VhostPrefixKey).(string) + if !ok || prefix == "" { + return path + } + if strings.HasPrefix(path, prefix+"/") { + return path[len(prefix):] + } + if path == prefix { + return "/" + } + return path +} + +// stripHostPortForVhost removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6). +func stripHostPortForVhost(host string) string { + h, _, err := net.SplitHostPort(host) + if err != nil { + // No port present; return host as-is + return host + } + return h +} diff --git a/server/handles/virtual_host.go b/server/handles/virtual_host.go new file mode 100644 index 000000000..583c2b6cd --- /dev/null +++ b/server/handles/virtual_host.go @@ -0,0 +1,83 @@ +package handles + +import ( + "strconv" + + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" +) + +func ListVirtualHosts(c *gin.Context) { + var req model.PageReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + req.Validate() + vhosts, total, err := op.GetVirtualHosts(req.Page, req.PerPage) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, common.PageResp{ + Content: vhosts, + Total: total, + }) +} + +func GetVirtualHost(c *gin.Context) { + idStr := c.Query("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + vhost, err := op.GetVirtualHostById(uint(id)) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, vhost) +} + +func CreateVirtualHost(c *gin.Context) { + var req model.VirtualHost + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := op.CreateVirtualHost(&req); err != nil { + common.ErrorResp(c, err, 500, true) + } else { + common.SuccessResp(c) + } +} + +func UpdateVirtualHost(c *gin.Context) { + var req model.VirtualHost + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := op.UpdateVirtualHost(&req); err != nil { + common.ErrorResp(c, err, 500, true) + } else { + common.SuccessResp(c) + } +} + +func DeleteVirtualHost(c *gin.Context) { + idStr := c.Query("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := op.DeleteVirtualHostById(uint(id)); err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c) +} diff --git a/server/middlewares/down.go b/server/middlewares/down.go index c1f81b54b..6c63b96c1 100644 --- a/server/middlewares/down.go +++ b/server/middlewares/down.go @@ -1,6 +1,8 @@ package middlewares import ( + "net" + stdpath "path" "strings" "github.com/OpenListTeam/OpenList/v4/internal/conf" @@ -17,10 +19,46 @@ import ( func PathParse(c *gin.Context) { rawPath := parsePath(c.Param("path")) + // 虚拟主机路径重映射:根据 Host 头匹配虚拟主机规则,将请求路径映射到实际路径 + // 例如:vhost.Path="/123pan/Downloads",rawPath="/tests.html" -> "/123pan/Downloads/tests.html" + rawPath = applyDownVhostPathMapping(c, rawPath) common.GinWithValue(c, conf.PathKey, rawPath) c.Next() } +// applyDownVhostPathMapping 根据请求的 Host 头匹配虚拟主机规则, +// 将下载/预览路由的路径映射到虚拟主机配置的实际路径。 +// 仅在虚拟主机启用且非 Web 托管模式时生效。 +func applyDownVhostPathMapping(c *gin.Context, reqPath string) string { + rawHost := c.Request.Host + domain := stripDownHostPort(rawHost) + if domain == "" { + return reqPath + } + vhost, err := op.GetVirtualHostByDomain(domain) + if err != nil || vhost == nil { + return reqPath + } + if !vhost.Enabled || vhost.WebHosting { + // 未启用,或者是 Web 托管模式(Web 托管不做路径重映射) + return reqPath + } + // 路径重映射:将 reqPath 拼接到 vhost.Path 后面 + mapped := stdpath.Join(vhost.Path, reqPath) + utils.Log.Debugf("[VirtualHost] down path remapping: domain=%q reqPath=%q -> mappedPath=%q", domain, reqPath, mapped) + return mapped +} + +// stripDownHostPort removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6). +func stripDownHostPort(host string) string { + h, _, err := net.SplitHostPort(host) + if err != nil { + // No port present; return host as-is + return host + } + return h +} + func Down(verifyFunc func(string, string) error) func(c *gin.Context) { return func(c *gin.Context) { rawPath := c.Request.Context().Value(conf.PathKey).(string) diff --git a/server/middlewares/virtual_host.go b/server/middlewares/virtual_host.go new file mode 100644 index 000000000..3372e400f --- /dev/null +++ b/server/middlewares/virtual_host.go @@ -0,0 +1,4 @@ +package middlewares + +// Note: Virtual host resolution is handled by existing handlers/middlewares. +// This file intentionally contains no additional code to avoid unused/dead middleware. diff --git a/server/router.go b/server/router.go index 57d1166ae..8cc559714 100644 --- a/server/router.go +++ b/server/router.go @@ -122,6 +122,13 @@ func admin(g *gin.RouterGroup) { meta.POST("/update", handles.UpdateMeta) meta.POST("/delete", handles.DeleteMeta) + vhost := g.Group("/vhost") + vhost.GET("/list", handles.ListVirtualHosts) + vhost.GET("/get", handles.GetVirtualHost) + vhost.POST("/create", handles.CreateVirtualHost) + vhost.POST("/update", handles.UpdateVirtualHost) + vhost.POST("/delete", handles.DeleteVirtualHost) + user := g.Group("/user") user.GET("/list", handles.ListUsers) user.GET("/get", handles.GetUser) diff --git a/server/static/static.go b/server/static/static.go index 29f97ff74..81208365a 100644 --- a/server/static/static.go +++ b/server/static/static.go @@ -5,16 +5,23 @@ import ( "errors" "fmt" "io" - "io/fs" + iofs "io/fs" + stdnet "net" "net/http" "os" + stdpath "path" "strings" "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/conf" + internalfs "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/internal/setting" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/public" + "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/gin-gonic/gin" ) @@ -32,12 +39,12 @@ type Manifest struct { Icons []ManifestIcon `json:"icons"` } -var static fs.FS +var static iofs.FS func initStatic() { utils.Log.Debug("Initializing static file system...") if conf.Conf.DistDir == "" { - dist, err := fs.Sub(public.Public, "dist") + dist, err := iofs.Sub(public.Public, "dist") if err != nil { utils.Log.Fatalf("failed to read dist dir: %v", err) } @@ -76,7 +83,7 @@ func initIndex(siteConfig SiteConfig) { utils.Log.Debug("Reading index.html from static files system...") indexFile, err := static.Open("index.html") if err != nil { - if errors.Is(err, fs.ErrNotExist) { + if errors.Is(err, iofs.ErrNotExist) { utils.Log.Fatalf("index.html not exist, you may forget to put dist of frontend to public/dist") } utils.Log.Fatalf("failed to read index.html: %v", err) @@ -98,9 +105,9 @@ func initIndex(siteConfig SiteConfig) { manifestPath = siteConfig.BasePath + "/manifest.json" } replaceMap := map[string]string{ - "cdn: undefined": fmt.Sprintf("cdn: '%s'", siteConfig.Cdn), - "base_path: undefined": fmt.Sprintf("base_path: '%s'", siteConfig.BasePath), - `href="/manifest.json"`: fmt.Sprintf(`href="%s"`, manifestPath), + "cdn: undefined": fmt.Sprintf("cdn: '%s'", siteConfig.Cdn), + "base_path: undefined": fmt.Sprintf("base_path: '%s'", siteConfig.BasePath), + `href="/manifest.json"`: fmt.Sprintf(`href="%s"`, manifestPath), } conf.RawIndexHtml = replaceStrings(conf.RawIndexHtml, replaceMap) UpdateIndex() @@ -134,10 +141,10 @@ func UpdateIndex() { func ManifestJSON(c *gin.Context) { // Get site configuration to ensure consistent base path handling siteConfig := getSiteConfig() - + // Get site title from settings siteTitle := setting.GetStr(conf.SiteTitle) - + // Get logo from settings, use the first line (light theme logo) logoSetting := setting.GetStr(conf.Logo) logoUrl := strings.Split(logoSetting, "\n")[0] @@ -167,7 +174,7 @@ func ManifestJSON(c *gin.Context) { c.Header("Content-Type", "application/json") c.Header("Cache-Control", "public, max-age=3600") // cache for 1 hour - + if err := json.NewEncoder(c.Writer).Encode(manifest); err != nil { utils.Log.Errorf("Failed to encode manifest.json: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate manifest"}) @@ -181,7 +188,7 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { initStatic() initIndex(siteConfig) folders := []string{"assets", "images", "streamer", "static"} - + if conf.Conf.Cdn == "" { utils.Log.Debug("Setting up static file serving...") r.Use(func(c *gin.Context) { @@ -192,7 +199,7 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { } }) for _, folder := range folders { - sub, err := fs.Sub(static, folder) + sub, err := iofs.Sub(static, folder) if err != nil { utils.Log.Fatalf("can't find folder: %s", folder) } @@ -210,7 +217,49 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { } utils.Log.Debug("Setting up catch-all route...") - noRoute(func(c *gin.Context) { + + // virtualHostHandler 处理虚拟主机 Web 托管,以及默认的前端 SPA 路由 + virtualHostHandler := func(c *gin.Context) { + // 直接从 Host 头解析域名,检查是否匹配虚拟主机的 Web 托管请求 + rawHost := c.Request.Host + domain := stripHostPort(rawHost) + utils.Log.Debugf("[VirtualHost] handler triggered: method=%s path=%s host=%q domain=%q", + c.Request.Method, c.Request.URL.Path, rawHost, domain) + if domain != "" { + vhost, err := op.GetVirtualHostByDomain(domain) + if err != nil { + utils.Log.Debugf("[VirtualHost] domain=%q not found in db: %v", domain, err) + } else { + utils.Log.Debugf("[VirtualHost] domain=%q matched vhost: id=%d enabled=%v web_hosting=%v path=%q", + domain, vhost.ID, vhost.Enabled, vhost.WebHosting, vhost.Path) + if vhost.Enabled && vhost.WebHosting { + // Web 托管模式:直接返回文件内容 + // 注入 guest 用户到 context,供 internalfs.Get/Link 权限检查使用 + guest, guestErr := op.GetGuest() + if guestErr != nil { + utils.Log.Errorf("[VirtualHost] failed to get guest user: %v", guestErr) + c.Status(http.StatusInternalServerError) + return + } + common.GinWithValue(c, conf.UserKey, guest) + if handleWebHosting(c, vhost) { + return + } + } else if vhost.Enabled && !vhost.WebHosting { + // 路径重映射模式(伪静态):直接返回正常的 SPA 页面 + // 地址栏保持不变,面包屑显示用户访问的路径 + // 实际的路径映射由后端 API(fs/list、fs/get)在处理请求时完成 + utils.Log.Debugf("[VirtualHost] path remapping mode: serving SPA for domain=%q path=%q", domain, c.Request.URL.Path) + c.Header("Content-Type", "text/html") + c.Status(200) + _, _ = c.Writer.WriteString(conf.IndexHtml) + c.Writer.Flush() + c.Writer.WriteHeaderNow() + return + } + } + } + if c.Request.Method != "GET" && c.Request.Method != "POST" { c.Status(405) return @@ -224,5 +273,174 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { } c.Writer.Flush() c.Writer.WriteHeaderNow() + } + + // 显式注册根路径路由,确保 GET / 能被正确处理 + // gin 的 NoRoute 不会触发已注册路由前缀下的 GET / + r.GET("/", virtualHostHandler) + r.POST("/", virtualHostHandler) + // NoRoute 处理其他所有未匹配路径(如 /@manage、/d/... 等 SPA 路由) + noRoute(virtualHostHandler) +} + +// handleWebHosting 处理虚拟主机的 Web 托管请求 +// 直接将 HTML 文件内容返回给客户端,而不是走前端 SPA 路由 +// 返回 true 表示已处理,false 表示未处理(继续走默认逻辑) +func handleWebHosting(c *gin.Context, vhost *model.VirtualHost) bool { + if c.Request.Method != "GET" && c.Request.Method != "HEAD" { + utils.Log.Debugf("[VirtualHost] skip: method=%s not allowed for web hosting", c.Request.Method) + return false + } + + reqPath := c.Request.URL.Path + // Map request path into the vhost root and verify it does not escape via traversal. + // stdpath.Join calls Clean internally, which collapses ".." segments, so we only need + // to confirm the result still lives under vhost.Path. + filePath := stdpath.Join(vhost.Path, reqPath) + if !strings.HasPrefix(filePath, strings.TrimRight(vhost.Path, "/")+"/") && filePath != vhost.Path { + utils.Log.Warnf("[VirtualHost] path traversal rejected: vhost=%q reqPath=%q", vhost.Path, reqPath) + c.Status(http.StatusBadRequest) + return false + } + utils.Log.Debugf("[VirtualHost] handleWebHosting: reqPath=%q -> filePath=%q", reqPath, filePath) + + // 尝试获取文件 + obj, err := internalfs.Get(c.Request.Context(), filePath, &internalfs.GetArgs{NoLog: true}) + if err == nil && !obj.IsDir() { + // 找到文件,直接代理返回 + utils.Log.Debugf("[VirtualHost] serving file: %q", filePath) + serveWebHostingFile(c, filePath, obj.GetName()) + return true + } + utils.Log.Debugf("[VirtualHost] file not found or is dir at %q: %v", filePath, err) + + // 如果是目录或未找到,尝试 index.html + indexPath := stdpath.Join(filePath, "index.html") + obj, err = internalfs.Get(c.Request.Context(), indexPath, &internalfs.GetArgs{NoLog: true}) + if err == nil && !obj.IsDir() { + utils.Log.Debugf("[VirtualHost] serving index.html: %q", indexPath) + serveWebHostingFile(c, indexPath, "index.html") + return true + } + utils.Log.Debugf("[VirtualHost] index.html not found at %q: %v", indexPath, err) + + // 尝试 .html(SPA 友好路由) + if stdpath.Ext(reqPath) == "" && reqPath != "/" { + htmlPath := stdpath.Join(vhost.Path, reqPath+".html") + obj, err = internalfs.Get(c.Request.Context(), htmlPath, &internalfs.GetArgs{NoLog: true}) + if err == nil && !obj.IsDir() { + utils.Log.Debugf("[VirtualHost] serving .html fallback: %q", htmlPath) + serveWebHostingFile(c, htmlPath, stdpath.Base(htmlPath)) + return true + } + utils.Log.Debugf("[VirtualHost] .html fallback not found at %q: %v", htmlPath, err) + } + + utils.Log.Debugf("[VirtualHost] no file matched for reqPath=%q, falling through", reqPath) + return false +} + +// serveWebHostingFile 通过代理方式直接返回文件内容 +func serveWebHostingFile(c *gin.Context, filePath, filename string) { + link, file, err := internalfs.Link(c.Request.Context(), filePath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, }) + if err != nil { + utils.Log.Errorf("web hosting: failed to get link for %s: %v", filePath, err) + c.Status(http.StatusInternalServerError) + return + } + defer link.Close() + + // 根据文件扩展名确定正确的 Content-Type + ext := strings.ToLower(stdpath.Ext(filename)) + contentType := mimeTypeByExt(ext) + + // 使用包装的 ResponseWriter,在 WriteHeader 时强制覆盖 Content-Type 和 Content-Disposition + // 这样即使 Proxy 内部的 maps.Copy 将上游响应头复制进来,我们也能在最终发送前覆盖 + wrapped := &forceContentTypeWriter{ + ResponseWriter: c.Writer, + contentType: contentType, + contentDisp: "inline", + } + + // 同时注入到 link.Header,供 attachHeader 路径(RangeReader/Concurrency 模式)使用 + if link.Header == nil { + link.Header = make(http.Header) + } + link.Header.Set("Content-Type", contentType) + link.Header.Set("Content-Disposition", "inline") + + // 使用通用代理函数处理文件传输 + if err := common.Proxy(wrapped, c.Request, link, file); err != nil { + utils.Log.Errorf("web hosting: proxy error for %s: %v", filePath, err) + } +} + +// forceContentTypeWriter 包装 http.ResponseWriter, +// 在 WriteHeader 时强制覆盖 Content-Type 和 Content-Disposition, +// 确保 HTML 等文件以正确类型返回而不是被浏览器下载 +type forceContentTypeWriter struct { + http.ResponseWriter + contentType string + contentDisp string +} + +func (w *forceContentTypeWriter) WriteHeader(statusCode int) { + w.ResponseWriter.Header().Set("Content-Type", w.contentType) + w.ResponseWriter.Header().Set("Content-Disposition", w.contentDisp) + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *forceContentTypeWriter) Write(b []byte) (int, error) { + return w.ResponseWriter.Write(b) +} + +// mimeTypeByExt 根据文件扩展名返回 MIME 类型 +func mimeTypeByExt(ext string) string { + switch ext { + case ".html", ".htm": + return "text/html; charset=utf-8" + case ".css": + return "text/css; charset=utf-8" + case ".js", ".mjs": + return "application/javascript; charset=utf-8" + case ".json": + return "application/json; charset=utf-8" + case ".xml": + return "application/xml; charset=utf-8" + case ".svg": + return "image/svg+xml" + case ".png": + return "image/png" + case ".jpg", ".jpeg": + return "image/jpeg" + case ".gif": + return "image/gif" + case ".webp": + return "image/webp" + case ".ico": + return "image/x-icon" + case ".woff": + return "font/woff" + case ".woff2": + return "font/woff2" + case ".ttf": + return "font/ttf" + case ".txt": + return "text/plain; charset=utf-8" + default: + return "application/octet-stream" + } +} + +// stripHostPort removes the port from a host string (supports IPv4, IPv6, and bracketed IPv6). +func stripHostPort(host string) string { + h, _, err := stdnet.SplitHostPort(host) + if err != nil { + // No port present; return host as-is + return host + } + return h }