From 18fc292de5b963e71583a7a74f1456a37f7a1423 Mon Sep 17 00:00:00 2001 From: shijie <810833920@qq.com> Date: Tue, 6 Jan 2026 10:40:21 +0800 Subject: [PATCH] [feat][meta] Modelscope download and access --- cmd/wire_gen.go | 4 +- config/config.yaml | 9 + internal/handler/handler.go | 2 +- internal/handler/modelscope_handler.go | 88 +++++ internal/router/http_router.go | 34 +- internal/server/http.go | 1 + internal/service/modelscope_service.go | 484 +++++++++++++++++++++++++ internal/service/service.go | 2 +- pkg/config/config.go | 12 + pkg/middleware/queue_limit.go | 21 ++ pkg/util/modelscope_util.go | 212 +++++++++++ pkg/util/repo_util.go | 12 +- 12 files changed, 864 insertions(+), 17 deletions(-) create mode 100644 internal/handler/modelscope_handler.go create mode 100644 internal/service/modelscope_service.go create mode 100644 pkg/util/modelscope_util.go diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index c72e2d9..c80b519 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -40,7 +40,9 @@ func wireApp(configConfig *config.Config) (*app.App, func(), error) { sysHandler := handler.NewSysHandler(sysService) cacheJobService := service.NewCacheJobService(fileDao, metaDao, downloaderDao, schedulerDao) cacheJobHandler := handler.NewCacheJobHandler(cacheJobService) - httpRouter := router.NewHttpRouter(echo, fileHandler, metaHandler, sysHandler, cacheJobHandler) + modelscopeService := service.NewModelscopeService() + modelscopeHandler := handler.NewModelscopeHandler(modelscopeService) + httpRouter := router.NewHttpRouter(echo, fileHandler, metaHandler, sysHandler, cacheJobHandler, modelscopeHandler) httpServer := server.NewServer(configConfig, echo, httpRouter) schedulerService := service.NewSchedulerService(schedulerDao) schedulerServer := server.NewSchedulerServer(schedulerService, sysService, localOperationService) diff --git a/config/config.yaml b/config/config.yaml index 948bd7c..822a512 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -75,3 +75,12 @@ 
dynamicProxy: timePeriod: 60 #定期检测代理是否可用时间周期,单位秒(S) maxContinuousFails: 5 #连续失败次数超过该值,则认为代理不可用 webhook: https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=73662ac1-1055-48a7-8c89-37964b5f4fdc111 # 企业微信机器人Webhook地址 + +modelscope: + modelCacheRoot: ./repos/modelscope/models # root directory of the model cache + datasetCacheRoot: ./repos/modelscope/datasets # root directory of the dataset cache + officialBaseURL: https://www.modelscope.cn # base URL of the official ModelScope site + chunkSize: 8388608 # chunk size in bytes: 8 MiB (8*1024*1024) + maxRetry: 5 # maximum upstream request attempts when requests time out + retryDelay: 3 # retry interval in seconds (originally 5*time.Second in code; simplified to a plain number in YAML) + minFileSize: 1 # minimum cached file size in bytes \ No newline at end of file diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 4dafbe9..7726302 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -18,4 +18,4 @@ import ( "github.com/google/wire" ) -var HandlerProvider = wire.NewSet(NewFileHandler, NewMetaHandler, NewSysHandler, NewCacheJobHandler) +var HandlerProvider = wire.NewSet(NewFileHandler, NewMetaHandler, NewSysHandler, NewCacheJobHandler, NewModelscopeHandler) diff --git a/internal/handler/modelscope_handler.go b/internal/handler/modelscope_handler.go new file mode 100644 index 0000000..00cd15a --- /dev/null +++ b/internal/handler/modelscope_handler.go @@ -0,0 +1,88 @@ +package handler + +import ( + "strings" + + "dingospeed/internal/service" + "dingospeed/pkg/util" + + "github.com/labstack/echo/v4" +) + +// ModelscopeHandler handles proxied ModelScope HTTP requests. +type ModelscopeHandler struct { + ModelscopeService *service.ModelscopeService +} + +// NewModelscopeHandler creates a ModelscopeHandler backed by the given service. +func NewModelscopeHandler(ModelscopeService *service.ModelscopeService) *ModelscopeHandler { + return &ModelscopeHandler{ + ModelscopeService: ModelscopeService, + } +} + +// ModelInfoHandler handles repository info queries; org, repo and repoType are taken from the request path segments. +func (h *ModelscopeHandler) ModelInfoHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := 
h.ModelscopeService.ForwardModelInfo(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// RevisionsHandler 处理模型版本查询请求 +func (h *ModelscopeHandler) RevisionsHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := h.ModelscopeService.ForwardRevisions(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// FileListHandler 处理模型文件列表请求 +func (h *ModelscopeHandler) FileListHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := h.ModelscopeService.ForwardFileList(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// FileDownloadHandler 处理模型文件下载请求(支持续传) +func (h *ModelscopeHandler) FileDownloadHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := h.ModelscopeService.HandleFileDownload(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// FileTreeHandler 处理数据集文件列表请求 +func (h *ModelscopeHandler) FileTreeHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + org, repo, repoType := parts[3], parts[4], parts[2] + if err := h.ModelscopeService.ForwardRepoTree(c, org, repo, repoType); err != nil { + return util.ResponseError(c, err) + } + return nil +} + +// DatasetFileTreeHandler 处理数据集文件列表请求 +func (h *ModelscopeHandler) DatasetFileTreeHandler(c echo.Context) error { + parts := strings.Split(strings.Trim(c.Request().URL.Path, "/"), "/") + + datasetId := parts[3] + if err := h.ModelscopeService.ForwardRepoTreeByDatasetId(c, datasetId); err != nil { + return util.ResponseError(c, err) + } + return nil +} diff 
--git a/internal/router/http_router.go b/internal/router/http_router.go index 360fd45..eb2b178 100644 --- a/internal/router/http_router.go +++ b/internal/router/http_router.go @@ -23,21 +23,23 @@ import ( ) type HttpRouter struct { - echo *echo.Echo - fileHandler *handler.FileHandler - metaHandler *handler.MetaHandler - sysHandler *handler.SysHandler - cacheJobHandler *handler.CacheJobHandler + echo *echo.Echo + fileHandler *handler.FileHandler + metaHandler *handler.MetaHandler + sysHandler *handler.SysHandler + cacheJobHandler *handler.CacheJobHandler + modelscopeHandler *handler.ModelscopeHandler } func NewHttpRouter(echo *echo.Echo, fileHandler *handler.FileHandler, metaHandler *handler.MetaHandler, - sysHandler *handler.SysHandler, cacheJobHandler *handler.CacheJobHandler) *HttpRouter { + sysHandler *handler.SysHandler, cacheJobHandler *handler.CacheJobHandler, modelscopeHandler *handler.ModelscopeHandler) *HttpRouter { r := &HttpRouter{ - echo: echo, - fileHandler: fileHandler, - metaHandler: metaHandler, - sysHandler: sysHandler, - cacheJobHandler: cacheJobHandler, + echo: echo, + fileHandler: fileHandler, + metaHandler: metaHandler, + sysHandler: sysHandler, + cacheJobHandler: cacheJobHandler, + modelscopeHandler: modelscopeHandler, } r.initRouter() return r @@ -54,6 +56,7 @@ func (r *HttpRouter) initRouter() { r.routerForCacheJob() r.routerForSpeed() + r.routerForModelscope() } func (r *HttpRouter) routerForSpeed() { // alayanew @@ -91,3 +94,12 @@ func (r *HttpRouter) routerForCacheJob() { // alayanew r.echo.POST("/api/cacheJob/resume", r.cacheJobHandler.ResumeCacheJobHandler) r.echo.POST("/api/cacheJob/realtime", r.cacheJobHandler.RealtimeCacheJobHandler) } + +func (r *HttpRouter) routerForModelscope() { // modelscope + r.echo.GET("/api/v1/:repoType/:org/:repo", r.modelscopeHandler.ModelInfoHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/revisions", r.modelscopeHandler.RevisionsHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/repo/files", 
r.modelscopeHandler.FileListHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/repo", r.modelscopeHandler.FileDownloadHandler) + r.echo.GET("/api/v1/:repoType/:org/:repo/repo/tree", r.modelscopeHandler.FileTreeHandler) + r.echo.GET("/api/v1/datasets/:datasetId/repo/tree", r.modelscopeHandler.DatasetFileTreeHandler) +} diff --git a/internal/server/http.go b/internal/server/http.go index 4b8f290..9f72025 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -80,6 +80,7 @@ func NewEngine() *echo.Echo { r := echo.New() middleware.InitMiddlewareConfig() r.Use(middleware.QueueLimitMiddleware) + r.Use(middleware.CORSMiddleware()) t := &Template{ templates: template.Must(template.ParseFS(templatesFS, "templates/*.html")), diff --git a/internal/service/modelscope_service.go b/internal/service/modelscope_service.go new file mode 100644 index 0000000..1499685 --- /dev/null +++ b/internal/service/modelscope_service.go @@ -0,0 +1,484 @@ +package service + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + + "dingospeed/pkg/config" + "dingospeed/pkg/util" + + "github.com/labstack/echo/v4" + "go.uber.org/zap" +) + +type ModelscopeService struct{} + +func NewModelscopeService() *ModelscopeService { + return &ModelscopeService{} +} + +func (s *ModelscopeService) ForwardModelInfo(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + + zap.S().Infof("转发%s信息请求到官方: %s", apiPrefix, officialURL) + return s.forwardRequest(c, officialURL) +} + +func (s *ModelscopeService) ForwardRevisions(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/revisions?%s", + 
config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + + zap.S().Infof("转发%s版本请求到官方: %s", apiPrefix, officialURL) + return s.forwardRequest(c, officialURL) +} + +func (s *ModelscopeService) ForwardFileList(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/repo/files?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + + zap.S().Infof("转发%s文件列表请求到官方: %s", apiPrefix, officialURL) + return s.forwardRequest(c, officialURL) +} + +func (s *ModelscopeService) ForwardRepoTree(c echo.Context, owner, repo string, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/repo/tree?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + c.Request().URL.RawQuery) + + zap.S().Infof("转发%s文件树请求到官方: %s", apiPrefix, officialURL) + return s.forwardRequest(c, officialURL) +} + +func (s *ModelscopeService) ForwardRepoTreeByDatasetId(c echo.Context, datasetId string) error { + officialURL := fmt.Sprintf("%s/api/v1/datasets/%s/repo/tree?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + url.PathEscape(datasetId), + c.Request().URL.RawQuery) + + zap.S().Infof("转发文件树请求到官方: %s", officialURL) + return s.forwardRequest(c, officialURL) +} + +// HandleFileDownload 处理ModelScope文件下载请求 +func (s *ModelscopeService) HandleFileDownload(c echo.Context, owner, repo, repoType string) error { + repoId := fmt.Sprintf("%s/%s", owner, repo) + revision := c.Request().URL.Query().Get("Revision") + filePath := c.Request().URL.Query().Get("FilePath") + + if revision == "" { + revision = "master" + } + if filePath == "" { + zap.S().Error("请求参数缺失: FilePath为空") + return c.JSON(http.StatusBadRequest, 
map[string]string{ + "code": "400", + "error": "missing FilePath parameter", + }) + } + + msCfg := config.SysConfig.Modelscope + if err := util.EnsureDir(filepath.Join(msCfg.ModelCacheRoot, "dummy")); err != nil { + zap.S().Errorf("初始化模型缓存根目录失败: %v", err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "init model cache root dir failed", + "msg": err.Error(), + }) + } + + cachePath, cacheExists := util.GetCachePath(repoType, repoId, revision, filePath) + if cachePath == "" { + zap.S().Errorf("生成缓存路径失败: 无效的repoId格式 %s", repoId) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "get cache path failed", + "msg": "invalid repoId format, require org/repo", + }) + } + zap.S().Infof("生成缓存路径: %s (缓存文件是否存在: %t)", cachePath, cacheExists) + + // 创建缓存文件上级目录(util.EnsureDir需要传入文件路径,以创建其上级目录) + if err := util.EnsureDir(cachePath); err != nil { + zap.S().Errorf("初始化缓存文件上级目录失败: %s, err: %v", filepath.Dir(cachePath), err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "init cache file parent dir failed", + "msg": err.Error(), + }) + } + + cachedSize := util.GetFileSize(cachePath) + zap.S().Infof("缓存文件状态: %s (已下载: %d字节)", cachePath, cachedSize) + if cacheExists && cachedSize == 0 { + zap.S().Warnf("缓存文件存在但大小为0,视为无效缓存: %s", cachePath) + if err := os.Remove(cachePath); err != nil { + zap.S().Errorf("删除空缓存文件失败: %s, err: %v", cachePath, err) + } + cachedSize = 0 + cacheExists = false + } + + zap.S().Infof("缓存文件状态: %s (已下载: %d字节)", cachePath, cachedSize) + + clientStart, clientEnd, err := util.ParseRangeHeader(c.Request()) + if err != nil { + zap.S().Errorf("解析Range失败: %v", err) + return c.JSON(http.StatusBadRequest, map[string]string{ + "code": "400", + "error": "parse Range header failed", + "msg": err.Error(), + }) + } + + actualStart := clientStart + if cachedSize > 0 && actualStart < cachedSize { + actualStart = cachedSize + } + 
zap.S().Infof("续传起始位置: 客户端请求=%d, 缓存末尾=%d, 实际起始=%d", clientStart, cachedSize, actualStart) + + headerWritten := false + c.Response().Header().Set("Transfer-Encoding", "chunked") + c.Response().Header().Set("Content-Type", "application/octet-stream") + c.Response().Header().Set("Access-Control-Expose-Headers", "Content-Range, Content-Type") + + var cacheWritten int64 = 0 + if cachedSize > 0 && clientStart < cachedSize { + cacheWritten, headerWritten, err = s.writeCacheData(c, cachePath, clientStart, clientEnd, cachedSize, headerWritten) + if err != nil { + zap.S().Errorf("写入缓存数据失败: %v", err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "write cache data failed", + "msg": err.Error(), + }) + } + + if clientEnd != -1 && (clientStart+cacheWritten-1) >= clientEnd { + zap.S().Infof("缓存数据已满足客户端Range请求,无需续传") + return nil + } + } + + if err := s.downloadAndWriteRemaining(c, owner, repo, actualStart, clientEnd, cachePath, headerWritten, repoType); err != nil { + return err + } + + return nil +} + +// forwardRequest 通用请求转发逻辑 +func (s *ModelscopeService) forwardRequest(c echo.Context, officialURL string) error { + req, err := http.NewRequest(http.MethodGet, officialURL, nil) + if err != nil { + zap.S().Errorf("构建请求失败: %v", err) + return err + } + + util.AddCLIHeaders(req.Header, c.Request().Header.Get("User-Agent")) + + for k, v := range c.Request().Header { + req.Header[k] = v + } + + resp, err := util.DoRequestWithRetry(req) + if err != nil { + zap.S().Errorf("转发请求失败: %v", err) + return err + } + defer resp.Body.Close() + + for k, v := range resp.Header { + c.Response().Header()[k] = v + } + c.Response().WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Response(), resp.Body) + if err != nil { + zap.S().Errorf("复制响应体失败: %v", err) + return err + } + return nil +} + +// writeCacheData 写入缓存中的数据到响应 +func (s *ModelscopeService) writeCacheData(c echo.Context, cachePath string, clientStart, clientEnd, cachedSize int64, 
headerWritten bool) (int64, bool, error) { + cacheFile, err := os.Open(cachePath) + if err != nil { + zap.S().Errorf("打开缓存文件失败: %s, err: %v", cachePath, err) + return 0, headerWritten, fmt.Errorf("open cache file failed: %w", err) + } + defer cacheFile.Close() + + if _, err := cacheFile.Seek(clientStart, io.SeekStart); err != nil { + zap.S().Errorf("定位缓存文件失败: %s, err: %v", cachePath, err) + return 0, headerWritten, fmt.Errorf("seek cache file failed: %w", err) + } + + cacheEnd := cachedSize - 1 + if clientEnd != -1 && clientEnd < cacheEnd { + cacheEnd = clientEnd + } + cacheResponseSize := cacheEnd - clientStart + 1 + if cacheResponseSize <= 0 { + zap.S().Warnf("缓存响应大小无效: %d (clientStart: %d, cacheEnd: %d)", cacheResponseSize, clientStart, cacheEnd) + return 0, headerWritten, nil + } + + if !headerWritten { + contentRange := fmt.Sprintf("bytes %d-%d/%d", clientStart, cacheEnd, cachedSize) + c.Response().Header().Set("Content-Range", contentRange) + c.Response().WriteHeader(http.StatusPartialContent) + headerWritten = true + zap.S().Infof("设置缓存响应头: Content-Range=%s", contentRange) + } + + buf := make([]byte, config.SysConfig.Modelscope.ChunkSize) + written := int64(0) + for written < cacheResponseSize { + if c.Request().Context().Err() != nil { + zap.S().Warnf("客户端断开连接,停止返回缓存数据: %s", cachePath) + return written, headerWritten, nil + } + + readSize := cacheResponseSize - written + if readSize > int64(len(buf)) { + readSize = int64(len(buf)) + } + + n, err := cacheFile.Read(buf[:readSize]) + if n > 0 { + if _, writeErr := c.Response().Write(buf[:n]); writeErr != nil { + zap.S().Errorf("返回缓存数据失败: %s, err: %v", cachePath, writeErr) + return written, headerWritten, fmt.Errorf("write cache data to response failed: %w", writeErr) + } + written += int64(n) + + if f, ok := c.Response().Writer.(http.Flusher); ok { + f.Flush() + } + } + + if err == io.EOF { + zap.S().Infof("缓存文件读取到EOF: %s, 已读取%d字节", cachePath, written) + break + } + if err != nil { + zap.S().Errorf("读取缓存数据失败: 
%s, err: %v", cachePath, err) + return written, headerWritten, fmt.Errorf("read cache file failed: %w", err) + } + } + + zap.S().Infof("✅ 返回缓存数据完成: %s, 范围%d-%d (共%d字节)", cachePath, clientStart, cacheEnd, written) + return written, headerWritten, nil +} + +// downloadAndWriteRemaining requests the byte range starting at actualStart from the official ModelScope endpoint and streams it to both the response and the cache file. +func (s *ModelscopeService) downloadAndWriteRemaining(c echo.Context, owner, repo string, actualStart, clientEnd int64, cachePath string, headerWritten bool, repoType string) error { + apiPrefix := util.GetAPIPathPrefix(repoType) + query := c.Request().URL.RawQuery + officialURL := fmt.Sprintf("%s/api/v1/%s/%s/%s/repo?%s", + config.SysConfig.Modelscope.OfficialBaseURL, + apiPrefix, + url.PathEscape(owner), + url.PathEscape(repo), + query, + ) + zap.S().Infof("请求ModelScope官方地址: %s", officialURL) + + req, err := http.NewRequest(http.MethodGet, officialURL, nil) + if err != nil { + zap.S().Errorf("构建请求失败: %v", err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "build request failed", + "msg": err.Error(), + }) + } + + skipHeaders := map[string]bool{ + "range": true, + "user-agent": true, + "host": true, + } + for k, v := range c.Request().Header { + key := strings.ToLower(k) + if !skipHeaders[key] { + req.Header[k] = v + } + } + + // Set the ModelScope CLI headers (User-Agent, Accept-Encoding, X-Request-ID). + util.AddCLIHeaders(req.Header, c.Request().Header.Get("User-Agent")) + + // Set the Range header so only the remaining part is downloaded. + rangeHeader := fmt.Sprintf("bytes=%d-", actualStart) + if clientEnd != -1 { + rangeHeader = fmt.Sprintf("bytes=%d-%d", actualStart, clientEnd) + } + req.Header.Set("Range", rangeHeader) + zap.S().Infof("向官方请求剩余部分: %s", rangeHeader) + + resp, err := util.DoRequestWithRetry(req) + if err != nil { + zap.S().Errorf("下载剩余部分失败: %v", err) + return c.JSON(http.StatusBadGateway, map[string]string{ + "code": "502", + "error": "download remaining failed", + "msg": err.Error(), + }) + } + defer resp.Body.Close() + + // Validate the status code returned by the official ModelScope server. + if resp.StatusCode < 200 || resp.StatusCode >= 300 { 
+ zap.S().Errorf("ModelScope返回错误状态码: %d, URL: %s", resp.StatusCode, officialURL) + errorMsg := fmt.Sprintf("modelscope server return status code: %d", resp.StatusCode) + switch resp.StatusCode { + case http.StatusNotFound: + return c.JSON(http.StatusNotFound, map[string]string{ + "code": "404", + "error": "resource not found", + "msg": "model or file does not exist on ModelScope", + }) + case http.StatusForbidden: + return c.JSON(http.StatusForbidden, map[string]string{ + "code": "403", + "error": "forbidden", + "msg": "no permission to access the resource", + }) + default: + return c.JSON(http.StatusBadGateway, map[string]string{ + "code": "502", + "error": "modelscope server error", + "msg": errorMsg, + }) + } + } + + totalFileSize := int64(-1) + contentRange := resp.Header.Get("Content-Range") + if contentRange != "" { + parts := strings.Split(contentRange, "/") + if len(parts) == 2 { + parsedSize, err := strconv.ParseInt(parts[1], 10, 64) + if err == nil { + totalFileSize = parsedSize + } else { + zap.S().Warnf("解析Content-Range失败: %s, err: %v", contentRange, err) + } + } + } + + cacheFile, err := os.OpenFile(cachePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0664) + if err != nil { + zap.S().Errorf("打开缓存文件失败: %s, err: %v", cachePath, err) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "open cache file failed", + "msg": err.Error(), + }) + } + defer cacheFile.Close() + + if !headerWritten { + if resp.StatusCode == http.StatusPartialContent { + c.Response().Header().Set("Content-Range", resp.Header.Get("Content-Range")) + c.Response().WriteHeader(http.StatusPartialContent) + } else { + c.Response().WriteHeader(http.StatusOK) + } + headerWritten = true + zap.S().Infof("设置续传响应头,状态码: %d", resp.StatusCode) + } + + buf := make([]byte, config.SysConfig.Modelscope.ChunkSize) + written := int64(0) + for { + if c.Request().Context().Err() != nil { + zap.S().Warnf("客户端断开连接,停止续传: %s", cachePath) + return nil + } + + n, err := 
resp.Body.Read(buf) + if n > 0 { + if _, writeErr := cacheFile.Write(buf[:n]); writeErr != nil { + zap.S().Errorf("写入缓存失败: %s, err: %v", cachePath, writeErr) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "write cache failed", + "msg": writeErr.Error(), + }) + } + + if _, writeErr := c.Response().Write(buf[:n]); writeErr != nil { + if strings.Contains(writeErr.Error(), "http2: stream closed") || + strings.Contains(writeErr.Error(), "broken pipe") || + strings.Contains(writeErr.Error(), "connection reset by peer") { + zap.S().Warnf("客户端断开连接,停止返回续传数据: %s, err: %v", cachePath, writeErr) + return nil + } + zap.S().Errorf("返回续传数据失败: %s, err: %v", cachePath, writeErr) + return c.JSON(http.StatusInternalServerError, map[string]string{ + "code": "500", + "error": "write response failed", + "msg": writeErr.Error(), + }) + } + + written += int64(n) + if f, ok := c.Response().Writer.(http.Flusher); ok { + f.Flush() + } + + if written%(100*1024*1024) == 0 { + zap.S().Infof("续传进度: %dMB, 文件: %s", written/(1024*1024), cachePath) + } + } + + // 处理读取错误 + if err == io.EOF { + // 刷盘确保数据写入磁盘 + if syncErr := cacheFile.Sync(); syncErr != nil { + zap.S().Warnf("缓存文件刷盘失败: %s, err: %v", cachePath, syncErr) + } + zap.S().Infof("续传完成: %s, 共下载%d字节,完整文件%d字节", cachePath, written, totalFileSize) + break + } + if err != nil { + zap.S().Errorf("续传中断: %s, err: %v", cachePath, err) + return c.JSON(http.StatusBadGateway, map[string]string{ + "code": "502", + "error": "download interrupted", + "msg": err.Error(), + }) + } + } + + return nil +} diff --git a/internal/service/service.go b/internal/service/service.go index d26f4bd..bbb41a7 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -16,4 +16,4 @@ package service import "github.com/google/wire" -var ServiceProvider = wire.NewSet(NewFileService, NewMetaService, NewSysService, NewSchedulerService, NewCacheJobService, NewLocalOperationService) +var ServiceProvider = 
wire.NewSet(NewFileService, NewMetaService, NewSysService, NewSchedulerService, NewCacheJobService, NewLocalOperationService, NewModelscopeService) diff --git a/pkg/config/config.go b/pkg/config/config.go index 8fdb763..c518a6a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -46,6 +46,7 @@ type Config struct { DynamicProxy DynamicProxy `json:"dynamicProxy" yaml:"dynamicProxy"` Scheduler Scheduler `json:"scheduler" yaml:"scheduler"` mu sync.RWMutex + Modelscope Modelscope `yaml:"modelscope"` } type ServerConfig struct { @@ -154,6 +155,17 @@ type DynamicProxy struct { Webhook string `json:"webhook " yaml:"webhook"` } +type Modelscope struct { + ProxyPort string `yaml:"proxyPort"` + ModelCacheRoot string `yaml:"modelCacheRoot"` + DatasetCacheRoot string `yaml:"datasetCacheRoot"` + OfficialBaseURL string `yaml:"officialBaseURL"` + ChunkSize int64 `yaml:"chunkSize"` + MaxRetry int `yaml:"maxRetry"` + RetryDelay int `yaml:"retryDelay"` + MinFileSize int64 `yaml:"minFileSize"` +} + func (c *Config) GetHFURLBase() string { return fmt.Sprintf("%s://%s", c.GetHfScheme(), c.GetHfNetLoc()) } diff --git a/pkg/middleware/queue_limit.go b/pkg/middleware/queue_limit.go index d6a001a..2aa7894 100644 --- a/pkg/middleware/queue_limit.go +++ b/pkg/middleware/queue_limit.go @@ -2,6 +2,7 @@ package middleware import ( "net" + "net/http" "strings" "dingospeed/pkg/config" @@ -75,3 +76,23 @@ func nextRequest(c echo.Context, next echo.HandlerFunc) error { return util.ErrorTooManyRequest(c) } } + +// CORSMiddleware 跨域中间件(适配Echo框架) +func CORSMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // 设置跨域头 + c.Response().Header().Set("Access-Control-Allow-Origin", "*") + c.Response().Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS, HEAD") + c.Response().Header().Set("Access-Control-Allow-Headers", "*") + c.Response().Header().Set("Access-Control-Expose-Headers", "*") + + // 处理OPTIONS预检请求 + if 
c.Request().Method == http.MethodOptions { + return c.NoContent(http.StatusOK) + } + + return next(c) + } + } +} diff --git a/pkg/util/modelscope_util.go b/pkg/util/modelscope_util.go new file mode 100644 index 0000000..f50a988 --- /dev/null +++ b/pkg/util/modelscope_util.go @@ -0,0 +1,212 @@ +package util + +import ( + "crypto/tls" + "fmt" + "math/rand" + "net/http" + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "time" + + "dingospeed/pkg/config" + + "go.uber.org/zap" +) + +// Regular expressions that extract the ModelScope and Python versions from a User-Agent string. +var ( + msVersionRegex = regexp.MustCompile(`modelscope/(\d+\.\d+\.\d+)`) + pyVersionRegex = regexp.MustCompile(`python/(\d+\.\d+\.\d+)`) +) + +// generateReqID returns a request ID of the form "req-<decimal number>" built from a pseudo-random non-negative int64. +func generateReqID() string { + return "req-" + strconv.FormatInt(rand.Int63(), 10) +} + +// ParseClientEnv extracts the ModelScope and Python versions from the client User-Agent and reports this process's OS and architecture. +func ParseClientEnv(clientUA string) (msVersion, system, arch, pythonVer string) { + // 1. Extract the ModelScope version. + if msMatch := msVersionRegex.FindStringSubmatch(clientUA); len(msMatch) > 1 { + msVersion = msMatch[1] + } else { + msVersion = "1.33.0" // fall back to a mainstream client version + } + + // 2. Extract the Python version. + if pyMatch := pyVersionRegex.FindStringSubmatch(clientUA); len(pyMatch) > 1 { + pythonVer = pyMatch[1] + } else { + pythonVer = "3.13.2" // fall back to a mainstream client version + } + + // 3. 
获取服务端真实系统/架构(替代Unknown) + system = runtime.GOOS + switch system { + case "darwin": + system = "macOS" + case "windows": + system = "Windows" + case "linux": + system = "Linux" + } + + arch = runtime.GOARCH + switch arch { + case "amd64": + arch = "x86_64" + case "arm64": + arch = "aarch64" + } + + return +} + +// AddCLIHeaders sets a ModelScope-CLI-style User-Agent derived from clientUA, forces Accept-Encoding to identity, and adds an X-Request-ID when none is present. +func AddCLIHeaders(header http.Header, clientUA string) { + msVersion, system, arch, pythonVer := ParseClientEnv(clientUA) + + userAgent := fmt.Sprintf("modelscope/%s (%s; %s) Python/%s", msVersion, system, arch, pythonVer) + header.Set("User-Agent", userAgent) + zap.S().Infof("构建兼容的 User-Agent: %s (客户端原始 UA: %s)", userAgent, clientUA) + + header.Set("Accept-Encoding", "identity") + if header.Get("X-Request-ID") == "" { + reqID := generateReqID() + header.Set("X-Request-ID", reqID) + } +} + +// EnsureDir creates the parent directory of path, including any missing ancestors. +func EnsureDir(path string) error { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + zap.S().Errorf("创建目录失败: %s, 错误: %v", filepath.Dir(path), err) + return err + } + return nil +} + +func GetCachePath(repoType, repoId, revision, filePath string) (string, bool) { + parts := strings.Split(repoId, "/") + if len(parts) != 2 { + zap.S().Errorf("无效的repoId格式: %s,需为 org/repo 格式", repoId) + return "", false + } + + var cacheRoot string + switch repoType { + case "datasets": + cacheRoot = config.SysConfig.Modelscope.DatasetCacheRoot + case "models", "": // backward compatible: an empty value falls back to models + cacheRoot = config.SysConfig.Modelscope.ModelCacheRoot + default: + zap.S().Warnf("未知的repoType: %s,默认使用models缓存目录", repoType) + cacheRoot = config.SysConfig.Modelscope.ModelCacheRoot + } + + targetCachePath := filepath.Join(cacheRoot, filepath.Join(string(filepath.Separator), parts[0], parts[1], revision, filePath)) // rooting the relative part makes Join drop any leading ".." so request-supplied revision/filePath cannot escape cacheRoot + fileInfo, err := os.Stat(targetCachePath) + if err == nil { + zap.S().Debugf("缓存文件存在: %s, 大小: %d字节", targetCachePath, fileInfo.Size()) + return targetCachePath, true + } + + _ = EnsureDir(targetCachePath) + return targetCachePath, false 
+} + +// CreateHTTPClient builds an HTTP client with generous timeouts for large upstream downloads. +func CreateHTTPClient() *http.Client { + return &http.Client{ + Timeout: 30 * time.Minute, + Transport: &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 5 * time.Minute, + DisableCompression: true, + MaxConnsPerHost: 2, + DisableKeepAlives: false, + TLSHandshakeTimeout: 2 * time.Minute, + ResponseHeaderTimeout: 5 * time.Minute, + ExpectContinueTimeout: 1 * time.Minute, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS13, + InsecureSkipVerify: true, // NOTE(review): disables TLS certificate verification for upstream requests; confirm this is intentional + Renegotiation: tls.RenegotiateFreelyAsClient, + }, + }, + } +} + +// DoRequestWithRetry sends req and retries timeout errors, making at most MaxRetry attempts and sleeping RetryDelay seconds times the attempt number between them; other errors are returned immediately. +func DoRequestWithRetry(req *http.Request) (*http.Response, error) { + client := CreateHTTPClient() + var resp *http.Response + var err error + + for i := 0; i < config.SysConfig.Modelscope.MaxRetry; i++ { + resp, err = client.Do(req) + if err == nil { + return resp, nil + } + + if strings.Contains(err.Error(), "timeout") || strings.Contains(err.Error(), "deadline exceeded") { + zap.S().Warnf("⚠️ Retry %d/%d: request timeout - %v", i+1, config.SysConfig.Modelscope.MaxRetry, err) + time.Sleep(time.Duration(config.SysConfig.Modelscope.RetryDelay) * time.Duration(i+1) * time.Second) + continue + } + + return nil, err + } + + return nil, fmt.Errorf("failed after %d retries: %w", config.SysConfig.Modelscope.MaxRetry, err) +} + +// ParseRangeHeader parses the Range request header and returns the start and end bytes (-1 means up to the end). +func ParseRangeHeader(r *http.Request) (start int64, end int64, err error) { + rangeHeader := r.Header.Get("Range") + if rangeHeader == "" { + return 0, -1, nil + } + + // Expected Range header format: bytes=start-end + parts := strings.SplitN(rangeHeader, "=", 2) + if len(parts) != 2 || parts[0] != "bytes" { + return 0, -1, fmt.Errorf("invalid Range header: %s", rangeHeader) + } + + rangeParts := strings.SplitN(parts[1], "-", 2) + start, err = strconv.ParseInt(rangeParts[0], 10, 64) + if err != nil { + return 0, -1, fmt.Errorf("invalid start byte: %s, err: %v", rangeParts[0], err) + } + + if 
len(rangeParts) == 2 && rangeParts[1] != "" { + end, err = strconv.ParseInt(rangeParts[1], 10, 64) + if err != nil { + return 0, -1, fmt.Errorf("invalid end byte: %s, err: %v", rangeParts[1], err) + } + } else { + end = -1 + } + + return start, end, nil +} + +func GetAPIPathPrefix(repoType string) string { + repoType = strings.TrimSpace(strings.ToLower(repoType)) + switch repoType { + case "dataset", "datasets": + return "datasets" + case "model", "models": + return "models" + default: + zap.S().Warnf("无效的repoType: %s,默认使用models", repoType) + return "models" + } +} diff --git a/pkg/util/repo_util.go b/pkg/util/repo_util.go index 3c65585..cc18f02 100644 --- a/pkg/util/repo_util.go +++ b/pkg/util/repo_util.go @@ -30,6 +30,7 @@ import ( "dingospeed/pkg/common" "github.com/bytedance/sonic" + "go.uber.org/zap" "golang.org/x/sys/unix" ) @@ -141,11 +142,16 @@ func IsFile(path string) bool { // GetFileSize 获取文件大小 func GetFileSize(path string) int64 { - fh, err := os.Stat(path) + fileInfo, err := os.Stat(path) if err != nil { - fmt.Printf("读取文件%s失败, err: %s\n", path, err) + if os.IsNotExist(err) { + zap.S().Infof("文件不存在: %s", path) + return 0 + } + zap.S().Errorf("读取文件大小失败: %s, err: %v", path, err) + return 0 } - return fh.Size() + return fileInfo.Size() } func ReadDir(dir string) ([]string, error) {