diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index b602d96..88d750d 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -41,14 +41,15 @@ func wireApp(configConfig *config.Config) (*app.App, func(), error) { schedulerService := service.NewSchedulerService(baseData, dingospeedDao, modelFileRecordDao, modelFileProcessDao, repositoryDao, cacheJobDao) repositoryService := service.NewRepositoryService(dingospeedDao, repositoryDao, baseData, organizationDao, tagDao, hfTokenDao) hfTokenService := service.NewHfTokenService(hfTokenDao) - managerHandler := handler.NewManagerHandler(schedulerService, repositoryService, hfTokenService) - sysService := service.NewSysService(repositoryDao) + lockDao := dao.NewLockDao(baseData) + cacheJobService := service.NewCacheJobService(dingospeedDao, modelFileProcessDao, cacheJobDao, hfTokenDao, lockDao) + managerService := service.NewManagerService(repositoryDao, repositoryService, cacheJobDao, cacheJobService) + managerHandler := handler.NewManagerHandler(schedulerService, repositoryService, hfTokenService, managerService) + sysService := service.NewSysService(repositoryDao, cacheJobDao) sysHandler := handler.NewSysHandler(sysService) repositoryHandler := handler.NewRepositoryHandler(repositoryService) tagService := service.NewTagService(tagDao) tagHandler := handler.NewTagHandler(tagService) - lockDao := dao.NewLockDao(baseData) - cacheJobService := service.NewCacheJobService(dingospeedDao, modelFileProcessDao, cacheJobDao, hfTokenDao, lockDao) cacheJobHandler := handler.NewCacheJobHandler(cacheJobService) httpRouter := router.NewHttpRouter(echo, managerHandler, sysHandler, repositoryHandler, tagHandler, cacheJobHandler) httpServer := server.NewHTTPServer(configConfig, httpRouter) diff --git a/internal/dao/cache_job_dao.go b/internal/dao/cache_job_dao.go index 23497a1..e731120 100644 --- a/internal/dao/cache_job_dao.go +++ b/internal/dao/cache_job_dao.go @@ -21,7 +21,6 @@ import ( "dingoscheduler/internal/data" "dingoscheduler/internal/model" "dingoscheduler/internal/model/query" - "dingoscheduler/pkg/common" "dingoscheduler/pkg/consts" "dingoscheduler/pkg/util" @@ -78,38 +77,6 @@ func (c *CacheJobDao) GetCacheJob(condition *query.CacheJobQuery) (*model.CacheJ return nil, nil } -func (c *CacheJobDao) RemoteRequestPathsInfo(domain, dataType, org, repo, revision, token string, fileNames []string) ([]common.PathsInfo, error) { - var reqUri = "/api/getPathInfo" - headers := map[string]string{} - if token != "" { - headers["authorization"] = fmt.Sprintf("Bearer %s", token) - } - query := query.PathInfoQuery{ - Datatype: dataType, - Org: org, - Repo: repo, - Revision: revision, - Token: token, - FileNames: fileNames, - } - b, err := sonic.Marshal(query) - if err != nil { - return nil, err - } - response, err := util.RetryRequest(func() (*common.Response, error) { - return util.PostForDomain(domain, reqUri, "application/json", b, headers) - }) - if err != nil { - return nil, err - } - ret := make([]common.PathsInfo, 0) - err = sonic.Unmarshal(response.Body, &ret) - if err != nil { - return nil, err - } - return ret, nil -} - func (c *CacheJobDao) UpdateCacheStatus(statusReq *query.UpdateJobStatusReq) error { var ( newMsgStr string @@ -142,7 +109,7 @@ func (c *CacheJobDao) UpdateStatusAndRepo(jobStatusReq *query.UpdateJobStatusReq if err != nil { return err } - if jobStatusReq.Status == consts.StatusCacheJobComplete { + if jobStatusReq.Status == consts.RunningStatusJobComplete { err = c.repositoryDao.PersistRepo(&query.PersistRepoReq{InstanceIds: []string{jobStatusReq.InstanceId}, Org: jobStatusReq.Org, Repo: jobStatusReq.Repo, OffVerify: true}) if err != nil { @@ -158,15 +125,6 @@ func (c *CacheJobDao) Delete(id int64) error { return nil } -func (c *CacheJobDao) UpdateMountCachePid(mountCachePidReq *query.UpdateMountCachePidReq) error { - sql := fmt.Sprintf("UPDATE mount_cache_job SET shell_pid = %d, updated_at = '%s' WHERE id = %d", - mountCachePidReq.Pid, util.GetCurrentTimeStr(), mountCachePidReq.Id) - if err := c.baseData.BizDB.Exec(sql).Error; err != nil { - return err - } - return nil -} - func (c *CacheJobDao) ListCacheJob(condition *query.CacheJobQuery) ([]*model.CacheJob, int64, error) { var cacheJobs []*model.CacheJob db := c.baseData.BizDB.Model(&model.CacheJob{}) @@ -200,3 +158,22 @@ func (c *CacheJobDao) ListCacheJob(condition *query.CacheJobQuery) ([]*model.Cac } return cacheJobs, count, nil } + +func (c *CacheJobDao) GetUnCacheJob(instanceId string, ids []int, runningStatus []int32, limit int) ([]*model.CacheJob, error) { + cacheJobs := make([]*model.CacheJob, 0) + db := c.baseData.BizDB.Table("cache_job t1") + if instanceId != "" { + db.Where("t1.instance_id = ?", instanceId) + } + if len(ids) > 0 { + db.Where("t1.id in (?)", ids) + } + if len(runningStatus) > 0 { + db.Where("t1.status in (?)", runningStatus) + } + if limit > 0 { + db.Limit(limit) + } + err := db.Find(&cacheJobs).Error // 中断或等待中的 + return cacheJobs, err +} diff --git a/internal/dao/repository_dao.go b/internal/dao/repository_dao.go index aedd85c..c65d915 100644 --- a/internal/dao/repository_dao.go +++ b/internal/dao/repository_dao.go @@ -354,3 +354,22 @@ func (r *RepositoryDao) UpdateRepositoryMountStatus(statusReq *query.UpdateMount } return nil } + +func (r *RepositoryDao) GetUnmountRepository(instanceId string, ids []int, runningStatus []int32, limit int) ([]*model.Repository, error) { + repositories := make([]*model.Repository, 0) + db := r.baseData.BizDB.Table("repository t1").Select("t1.id, t1.datatype, t1.org, t1.repo, t1.org_repo, t1.status") + if instanceId != "" { + db.Where("t1.instance_id = ?", instanceId) + } + if len(ids) > 0 { + db.Where("t1.id in (?)", ids) + } + if len(runningStatus) > 0 { + db.Where("t1.status in (?)", runningStatus) + } + if limit > 0 { + db.Limit(limit) + } + err := db.Find(&repositories).Error // 中断或等待中的 + return repositories, err +} diff --git a/internal/handler/manager_handler.go b/internal/handler/manager_handler.go index b0a2869..df70b5a 100644 --- a/internal/handler/manager_handler.go +++ b/internal/handler/manager_handler.go @@ -6,20 +6,23 @@ import ( "dingoscheduler/pkg/util" "github.com/labstack/echo/v4" + "go.uber.org/zap" ) type ManagerHandler struct { schedulerService *service.SchedulerService repositoryService *service.RepositoryService hfTokenService *service.HfTokenService + managerService *service.ManagerService } func NewManagerHandler(schedulerService *service.SchedulerService, repositoryService *service.RepositoryService, - hfTokenService *service.HfTokenService) *ManagerHandler { + hfTokenService *service.HfTokenService, managerService *service.ManagerService) *ManagerHandler { return &ManagerHandler{ schedulerService: schedulerService, repositoryService: repositoryService, hfTokenService: hfTokenService, + managerService: managerService, } } @@ -38,3 +41,19 @@ func (handler *ManagerHandler) PersistRepoHandler(c echo.Context) error { func (handler *ManagerHandler) RefreshToken(c echo.Context) error { return util.NormalResponseData(c, handler.hfTokenService.RefreshToken()) } + +func (handler *ManagerHandler) ExecWaitTaskHandler(c echo.Context) error { + waitTaskReq := new(query.WaitTaskReq) + if err := c.Bind(waitTaskReq); err != nil { + return util.ErrorRequestParamCN(c) + } + if waitTaskReq.Limit == 0 { + waitTaskReq.Limit = 30 + } + err := handler.managerService.ExecWaitTask(waitTaskReq) + if err != nil { + zap.S().Errorf("GetRepositoryById err.%v", err) + return util.ResponseError(c) + } + return util.NormalResponseData(c, nil) +} diff --git a/internal/model/query/query.go b/internal/model/query/query.go index 625f279..cf01271 100644 --- a/internal/model/query/query.go +++ b/internal/model/query/query.go @@ -118,6 +118,13 @@ type RepositoryReq struct { Token string `json:"token"` } +type WaitTaskReq struct { + InstanceId string `json:"instanceId"` + Ids []int `json:"ids"` + Type int `json:"type"` + Limit int `json:"limit"` +} + type TagQuery struct { Id string Types []string diff --git a/internal/router/http_router.go b/internal/router/http_router.go index 3874c94..adc90ad 100644 --- a/internal/router/http_router.go +++ b/internal/router/http_router.go @@ -55,10 +55,11 @@ func (r *HttpRouter) initRouter() { if config.SysConfig.EnableMetric() { r.echo.GET("/metrics", echo.WrapHandler(promhttp.Handler())) } - r.echo.POST("/api/persistRepo", r.managerHandler.PersistRepoHandler) // 持久化仓库 - r.echo.GET("/api/refreshToken", r.managerHandler.RefreshToken) // 持久化仓库 - r.repositoryRouter() // repository接口 - r.cacheJobRouter() // 模型缓存 + r.echo.POST("/api/persistRepo", r.managerHandler.PersistRepoHandler) // 持久化仓库 + r.echo.GET("/api/refreshToken", r.managerHandler.RefreshToken) // 刷新默认token + r.echo.POST("/api/execWaitTask", r.managerHandler.ExecWaitTaskHandler) // 执行等待中的缓存下载任务和挂载模型任务 + r.repositoryRouter() // repository接口 + r.cacheJobRouter() // 模型缓存 } func (r *HttpRouter) repositoryRouter() { diff --git a/internal/service/cache_job_service.go b/internal/service/cache_job_service.go index 61b5009..9d43963 100644 --- a/internal/service/cache_job_service.go +++ b/internal/service/cache_job_service.go @@ -62,7 +62,7 @@ func (c *CacheJobService) ListCacheJob(instanceId, datatype string, page, pageSi } jobIds := make([]int64, 0) for _, job := range cacheJobs { - if job.Status == consts.StatusCacheJobIng { + if job.Status == consts.RunningStatusJobIng { jobIds = append(jobIds, job.ID) } } @@ -156,7 +156,7 @@ func (c *CacheJobService) StopCacheJob(jobStatusReq *query.JobStatusReq) error { if cacheJob == nil { return myerr.New(fmt.Sprintf("任务不存在。")) } - if cacheJob.Status != consts.StatusCacheJobIng { + if cacheJob.Status != consts.RunningStatusJobIng { return myerr.New(fmt.Sprintf("job is not running, Can't be stopped.%d", cacheJob.Status)) } entity, err := c.dingospeedDao.GetEntity(jobStatusReq.InstanceId, true) @@ -166,7 +166,7 @@ func (c *CacheJobService) StopCacheJob(jobStatusReq *query.JobStatusReq) error { if entity == nil { return myerr.New("该区域dingspeed未注册。") } - err = c.cacheJobDao.UpdateCacheStatus(&query.UpdateJobStatusReq{Id: jobStatusReq.Id, Status: consts.StatusCacheJobStopping}) + err = c.cacheJobDao.UpdateCacheStatus(&query.UpdateJobStatusReq{Id: jobStatusReq.Id, Status: consts.RunningStatusJobStopping}) if err != nil { return err } @@ -193,7 +193,9 @@ func (c *CacheJobService) ResumeCacheJob(resumeCacheJobReq *query.ResumeCacheJob if cacheJob == nil { return myerr.New(fmt.Sprintf("job is not exist.jobId:%d", resumeCacheJobReq.Id)) } - if cacheJob.Status != consts.StatusCacheJobBreak { + if cacheJob.Status != consts.RunningStatusJobBreak && + cacheJob.Status != consts.RunningStatusJobStop && + cacheJob.Status != consts.RunningStatusJobWait { return myerr.New("当前状态不可执行该操作。") } entity, err := c.dingospeedDao.GetEntity(resumeCacheJobReq.InstanceId, true) @@ -235,7 +237,7 @@ func (c *CacheJobService) DeleteCacheJob(id int64) error { if cacheJob == nil { return myerr.New(fmt.Sprintf("记录不存在。")) } - if cacheJob.Status == consts.StatusCacheJobIng || cacheJob.Status == consts.StatusCacheJobComplete { + if cacheJob.Status == consts.RunningStatusJobIng || cacheJob.Status == consts.RunningStatusJobComplete { return myerr.New(fmt.Sprintf("当前缓存任务不能删除。")) } return c.cacheJobDao.Delete(id) diff --git a/internal/service/mamager_service.go b/internal/service/mamager_service.go new file mode 100644 index 0000000..2fece20 --- /dev/null +++ b/internal/service/mamager_service.go @@ -0,0 +1,61 @@ +package service + +import ( + "fmt" + + "dingoscheduler/internal/dao" + "dingoscheduler/internal/model/query" + "dingoscheduler/pkg/consts" +) + +type ManagerService struct { + repositoryDao *dao.RepositoryDao + repositoryService *RepositoryService + cacheJobDao *dao.CacheJobDao + cacheJobService *CacheJobService +} + +func NewManagerService(repositoryDao *dao.RepositoryDao, repositoryService *RepositoryService, cacheJobDao *dao.CacheJobDao, + cacheJobService *CacheJobService) *ManagerService { + return &ManagerService{ + repositoryService: repositoryService, + repositoryDao: repositoryDao, + cacheJobDao: cacheJobDao, + cacheJobService: cacheJobService, + } +} + +func (s *ManagerService) ExecWaitTask(waitTaskReq *query.WaitTaskReq) error { + execStatus := []int32{consts.RunningStatusJobBreak, consts.RunningStatusJobWait} + if waitTaskReq.Type == consts.CacheTypePreheat { + unCacheJobs, err := s.cacheJobDao.GetUnCacheJob(waitTaskReq.InstanceId, waitTaskReq.Ids, execStatus, waitTaskReq.Limit) + if err != nil { + return err + } + for _, i := range unCacheJobs { + err = s.cacheJobService.ResumeCacheJob(&query.ResumeCacheJobReq{ + Id: i.ID, + InstanceId: waitTaskReq.InstanceId, + }) + if err != nil { + return err + } + } + } else if waitTaskReq.Type == consts.CacheTypeMount { + repositories, err := s.repositoryDao.GetUnmountRepository(waitTaskReq.InstanceId, waitTaskReq.Ids, execStatus, waitTaskReq.Limit) + if err != nil { + return err + } + for _, i := range repositories { + err = s.repositoryService.MountRepository(&query.RepositoryReq{ + Id: i.ID, + }) + if err != nil { + return err + } + } + } else { + return fmt.Errorf("type is invalid") + } + return nil +} diff --git a/internal/service/repository_service.go b/internal/service/repository_service.go index 103c7c2..c3c4257 100644 --- a/internal/service/repository_service.go +++ b/internal/service/repository_service.go @@ -227,7 +227,7 @@ func (s *RepositoryService) MountRepository(repoReq *query.RepositoryReq) error if repository == nil { return myerr.New(fmt.Sprintf("记录不存在。编号:%d", repoReq.Id)) } - if repository.Status == consts.StatusCacheJobIng || repository.Status == consts.StatusCacheJobComplete { + if repository.Status == consts.RunningStatusJobIng || repository.Status == consts.RunningStatusJobComplete { return myerr.New("当前状态不可执行该操作。") } entity, err := s.dingospeedDao.GetEntity(repository.InstanceId, false) // 挂载到公共目录,通过离线模式处理 @@ -240,7 +240,7 @@ func (s *RepositoryService) MountRepository(repoReq *query.RepositoryReq) error speedDomain := fmt.Sprintf("http://%s:%d", entity.Host, entity.Port) if err = s.repositoryDao.UpdateRepositoryMountStatus(&query.UpdateMountStatusReq{ Id: repository.ID, - Status: consts.StatusCacheJobIng, + Status: consts.RunningStatusJobIng, }); err != nil { zap.S().Errorf("UpdateRepositoryMountStatus err.%v", err) return myerr.New("更新状态错误。") diff --git a/internal/service/service.go b/internal/service/service.go index 2c2d4b5..b231fd3 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -17,4 +17,4 @@ package service import "github.com/google/wire" var ServiceProvider = wire.NewSet(NewSchedulerService, NewSysService, NewCacheJobService, NewRepositoryService, - NewTagService, NewOrganizationService, NewHfTokenService) + NewTagService, NewOrganizationService, NewHfTokenService, NewManagerService) diff --git a/internal/service/sys_service.go b/internal/service/sys_service.go index 280c8cb..a10f750 100644 --- a/internal/service/sys_service.go +++ b/internal/service/sys_service.go @@ -7,6 +7,7 @@ import ( "dingoscheduler/internal/dao" "dingoscheduler/internal/model/query" "dingoscheduler/pkg/config" + "dingoscheduler/pkg/consts" "github.com/robfig/cron/v3" "go.uber.org/zap" @@ -16,17 +17,22 @@ var once sync.Once type SysService struct { repositoryDao *dao.RepositoryDao + cacheJobDao *dao.CacheJobDao } -func NewSysService(repositoryDao *dao.RepositoryDao) *SysService { +func NewSysService(repositoryDao *dao.RepositoryDao, cacheJobDao *dao.CacheJobDao) *SysService { sysSvc := &SysService{} sysSvc.repositoryDao = repositoryDao + sysSvc.cacheJobDao = cacheJobDao once.Do( func() { if config.SysConfig.GetEnablePersistRepo() { go sysSvc.startPersistRepo() } }) + if err := sysSvc.repairJobRunStatus(); err != nil { + panic(err) + } return sysSvc } @@ -58,3 +64,33 @@ func (s SysService) startPersistRepo() { defer c.Stop() select {} } + +func (s SysService) repairJobRunStatus() error { + if unCacheJobs, err := s.cacheJobDao.GetUnCacheJob("", []int{}, []int32{consts.RunningStatusJobStopping}, 0); err != nil { + zap.S().Errorf("GetUnmountRepository err.%v", err) + } else { + for _, i := range unCacheJobs { + err = s.cacheJobDao.UpdateStatusAndRepo(&query.UpdateJobStatusReq{ + Id: i.ID, + Status: consts.RunningStatusJobStop, + }) + if err != nil { + return err + } + } + } + if repositories, err := s.repositoryDao.GetUnmountRepository("", []int{}, []int32{consts.RunningStatusJobStopping}, 0); err != nil { + zap.S().Errorf("GetUnmountRepository err.%v", err) + } else { + for _, i := range repositories { + err = s.repositoryDao.UpdateRepositoryMountStatus(&query.UpdateMountStatusReq{ + Id: i.ID, + Status: consts.RunningStatusJobStop, + }) + if err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/consts/const.go b/pkg/consts/const.go index 3d16167..bb15371 100644 --- a/pkg/consts/const.go +++ b/pkg/consts/const.go @@ -58,11 +58,12 @@ const ( CacheTypePreheat = 1 CacheTypeMount = 2 - StatusCacheJobDefault = 0 - StatusCacheJobIng = 1 - StatusCacheJobBreak = 2 - StatusCacheJobComplete = 3 - StatusCacheJobStopping = 4 + RunningStatusJobIng = 1 + RunningStatusJobBreak = 2 + RunningStatusJobComplete = 3 + RunningStatusJobStopping = 4 + RunningStatusJobStop = 5 + RunningStatusJobWait = 6 ) const (