Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cmd/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 20 additions & 43 deletions internal/dao/cache_job_dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"dingoscheduler/internal/data"
"dingoscheduler/internal/model"
"dingoscheduler/internal/model/query"
"dingoscheduler/pkg/common"
"dingoscheduler/pkg/consts"
"dingoscheduler/pkg/util"

Expand Down Expand Up @@ -78,38 +77,6 @@ func (c *CacheJobDao) GetCacheJob(condition *query.CacheJobQuery) (*model.CacheJ
return nil, nil
}

func (c *CacheJobDao) RemoteRequestPathsInfo(domain, dataType, org, repo, revision, token string, fileNames []string) ([]common.PathsInfo, error) {
var reqUri = "/api/getPathInfo"
headers := map[string]string{}
if token != "" {
headers["authorization"] = fmt.Sprintf("Bearer %s", token)
}
query := query.PathInfoQuery{
Datatype: dataType,
Org: org,
Repo: repo,
Revision: revision,
Token: token,
FileNames: fileNames,
}
b, err := sonic.Marshal(query)
if err != nil {
return nil, err
}
response, err := util.RetryRequest(func() (*common.Response, error) {
return util.PostForDomain(domain, reqUri, "application/json", b, headers)
})
if err != nil {
return nil, err
}
ret := make([]common.PathsInfo, 0)
err = sonic.Unmarshal(response.Body, &ret)
if err != nil {
return nil, err
}
return ret, nil
}

func (c *CacheJobDao) UpdateCacheStatus(statusReq *query.UpdateJobStatusReq) error {
var (
newMsgStr string
Expand Down Expand Up @@ -142,7 +109,7 @@ func (c *CacheJobDao) UpdateStatusAndRepo(jobStatusReq *query.UpdateJobStatusReq
if err != nil {
return err
}
if jobStatusReq.Status == consts.StatusCacheJobComplete {
if jobStatusReq.Status == consts.RunningStatusJobComplete {
err = c.repositoryDao.PersistRepo(&query.PersistRepoReq{InstanceIds: []string{jobStatusReq.InstanceId},
Org: jobStatusReq.Org, Repo: jobStatusReq.Repo, OffVerify: true})
if err != nil {
Expand All @@ -158,15 +125,6 @@ func (c *CacheJobDao) Delete(id int64) error {
return nil
}

func (c *CacheJobDao) UpdateMountCachePid(mountCachePidReq *query.UpdateMountCachePidReq) error {
sql := fmt.Sprintf("UPDATE mount_cache_job SET shell_pid = %d, updated_at = '%s' WHERE id = %d",
mountCachePidReq.Pid, util.GetCurrentTimeStr(), mountCachePidReq.Id)
if err := c.baseData.BizDB.Exec(sql).Error; err != nil {
return err
}
return nil
}

func (c *CacheJobDao) ListCacheJob(condition *query.CacheJobQuery) ([]*model.CacheJob, int64, error) {
var cacheJobs []*model.CacheJob
db := c.baseData.BizDB.Model(&model.CacheJob{})
Expand Down Expand Up @@ -200,3 +158,22 @@ func (c *CacheJobDao) ListCacheJob(condition *query.CacheJobQuery) ([]*model.Cac
}
return cacheJobs, count, nil
}

func (c *CacheJobDao) GetUnCacheJob(instanceId string, ids []int, runningStatus []int32, limit int) ([]*model.CacheJob, error) {
cacheJobs := make([]*model.CacheJob, 0)
db := c.baseData.BizDB.Table("cache_job t1")
if instanceId != "" {
db.Where("t1.instance_id = ?", instanceId)
}
if len(ids) > 0 {
db.Where("t1.id in (?)", ids)
}
if len(runningStatus) > 0 {
db.Where("t1.status in (?)", runningStatus)
}
if limit > 0 {
db.Limit(limit)
}
err := db.Find(&cacheJobs).Error // 中断或等待中的
return cacheJobs, err
}
19 changes: 19 additions & 0 deletions internal/dao/repository_dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,22 @@ func (r *RepositoryDao) UpdateRepositoryMountStatus(statusReq *query.UpdateMount
}
return nil
}

func (r *RepositoryDao) GetUnmountRepository(instanceId string, ids []int, runningStatus []int32, limit int) ([]*model.Repository, error) {
repositories := make([]*model.Repository, 0)
db := r.baseData.BizDB.Table("repository t1").Select("t1.id, t1.datatype, t1.org, t1.repo, t1.org_repo, t1.status")
if instanceId != "" {
db.Where("t1.instance_id = ?", instanceId)
}
if len(ids) > 0 {
db.Where("t1.id in (?)", ids)
}
if len(runningStatus) > 0 {
db.Where("t1.status in (?)", runningStatus)
}
if limit > 0 {
db.Limit(limit)
}
err := db.Find(&repositories).Error // 中断或等待中的
return repositories, err
}
21 changes: 20 additions & 1 deletion internal/handler/manager_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@ import (
"dingoscheduler/pkg/util"

"github.com/labstack/echo/v4"
"go.uber.org/zap"
)

type ManagerHandler struct {
schedulerService *service.SchedulerService
repositoryService *service.RepositoryService
hfTokenService *service.HfTokenService
managerService *service.ManagerService
}

func NewManagerHandler(schedulerService *service.SchedulerService, repositoryService *service.RepositoryService,
hfTokenService *service.HfTokenService) *ManagerHandler {
hfTokenService *service.HfTokenService, managerService *service.ManagerService) *ManagerHandler {
return &ManagerHandler{
schedulerService: schedulerService,
repositoryService: repositoryService,
hfTokenService: hfTokenService,
managerService: managerService,
}
}

Expand All @@ -38,3 +41,19 @@ func (handler *ManagerHandler) PersistRepoHandler(c echo.Context) error {
func (handler *ManagerHandler) RefreshToken(c echo.Context) error {
return util.NormalResponseData(c, handler.hfTokenService.RefreshToken())
}

func (handler *ManagerHandler) ExecWaitTaskHandler(c echo.Context) error {
waitTaskReq := new(query.WaitTaskReq)
if err := c.Bind(waitTaskReq); err != nil {
return util.ErrorRequestParamCN(c)
}
if waitTaskReq.Limit == 0 {
waitTaskReq.Limit = 30
}
err := handler.managerService.ExecWaitTask(waitTaskReq)
if err != nil {
zap.S().Errorf("GetRepositoryById err.%v", err)
return util.ResponseError(c)
}
return util.NormalResponseData(c, nil)
}
7 changes: 7 additions & 0 deletions internal/model/query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ type RepositoryReq struct {
Token string `json:"token"`
}

type WaitTaskReq struct {
InstanceId string `json:"instanceId"`
Ids []int `json:"ids"`
Type int `json:"type"`
Limit int `json:"limit"`
}

type TagQuery struct {
Id string
Types []string
Expand Down
9 changes: 5 additions & 4 deletions internal/router/http_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ func (r *HttpRouter) initRouter() {
if config.SysConfig.EnableMetric() {
r.echo.GET("/metrics", echo.WrapHandler(promhttp.Handler()))
}
r.echo.POST("/api/persistRepo", r.managerHandler.PersistRepoHandler) // 持久化仓库
r.echo.GET("/api/refreshToken", r.managerHandler.RefreshToken) // 持久化仓库
r.repositoryRouter() // repository接口
r.cacheJobRouter() // 模型缓存
r.echo.POST("/api/persistRepo", r.managerHandler.PersistRepoHandler) // 持久化仓库
r.echo.GET("/api/refreshToken", r.managerHandler.RefreshToken) // 刷新默认token
r.echo.POST("/api/execWaitTask", r.managerHandler.ExecWaitTaskHandler) // 执行等待中的缓存下载任务和挂载模型任务
r.repositoryRouter() // repository接口
r.cacheJobRouter() // 模型缓存
}

func (r *HttpRouter) repositoryRouter() {
Expand Down
12 changes: 7 additions & 5 deletions internal/service/cache_job_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (c *CacheJobService) ListCacheJob(instanceId, datatype string, page, pageSi
}
jobIds := make([]int64, 0)
for _, job := range cacheJobs {
if job.Status == consts.StatusCacheJobIng {
if job.Status == consts.RunningStatusJobIng {
jobIds = append(jobIds, job.ID)
}
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func (c *CacheJobService) StopCacheJob(jobStatusReq *query.JobStatusReq) error {
if cacheJob == nil {
return myerr.New(fmt.Sprintf("任务不存在。"))
}
if cacheJob.Status != consts.StatusCacheJobIng {
if cacheJob.Status != consts.RunningStatusJobIng {
return myerr.New(fmt.Sprintf("job is not running, Can't be stopped.%d", cacheJob.Status))
}
entity, err := c.dingospeedDao.GetEntity(jobStatusReq.InstanceId, true)
Expand All @@ -166,7 +166,7 @@ func (c *CacheJobService) StopCacheJob(jobStatusReq *query.JobStatusReq) error {
if entity == nil {
return myerr.New("该区域dingspeed未注册。")
}
err = c.cacheJobDao.UpdateCacheStatus(&query.UpdateJobStatusReq{Id: jobStatusReq.Id, Status: consts.StatusCacheJobStopping})
err = c.cacheJobDao.UpdateCacheStatus(&query.UpdateJobStatusReq{Id: jobStatusReq.Id, Status: consts.RunningStatusJobStopping})
if err != nil {
return err
}
Expand All @@ -193,7 +193,9 @@ func (c *CacheJobService) ResumeCacheJob(resumeCacheJobReq *query.ResumeCacheJob
if cacheJob == nil {
return myerr.New(fmt.Sprintf("job is not exist.jobId:%d", resumeCacheJobReq.Id))
}
if cacheJob.Status != consts.StatusCacheJobBreak {
if cacheJob.Status != consts.RunningStatusJobBreak &&
cacheJob.Status != consts.RunningStatusJobStop &&
cacheJob.Status != consts.RunningStatusJobWait {
return myerr.New("当前状态不可执行该操作。")
}
entity, err := c.dingospeedDao.GetEntity(resumeCacheJobReq.InstanceId, true)
Expand Down Expand Up @@ -235,7 +237,7 @@ func (c *CacheJobService) DeleteCacheJob(id int64) error {
if cacheJob == nil {
return myerr.New(fmt.Sprintf("记录不存在。"))
}
if cacheJob.Status == consts.StatusCacheJobIng || cacheJob.Status == consts.StatusCacheJobComplete {
if cacheJob.Status == consts.RunningStatusJobIng || cacheJob.Status == consts.RunningStatusJobComplete {
return myerr.New(fmt.Sprintf("当前缓存任务不能删除。"))
}
return c.cacheJobDao.Delete(id)
Expand Down
61 changes: 61 additions & 0 deletions internal/service/mamager_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package service

import (
"fmt"

"dingoscheduler/internal/dao"
"dingoscheduler/internal/model/query"
"dingoscheduler/pkg/consts"
)

type ManagerService struct {
repositoryDao *dao.RepositoryDao
repositoryService *RepositoryService
cacheJobDao *dao.CacheJobDao
cacheJobService *CacheJobService
}

func NewManagerService(repositoryDao *dao.RepositoryDao, repositoryService *RepositoryService, cacheJobDao *dao.CacheJobDao,
cacheJobService *CacheJobService) *ManagerService {
return &ManagerService{
repositoryService: repositoryService,
repositoryDao: repositoryDao,
cacheJobDao: cacheJobDao,
cacheJobService: cacheJobService,
}
}

func (s *ManagerService) ExecWaitTask(waitTaskReq *query.WaitTaskReq) error {
execStatus := []int32{consts.RunningStatusJobBreak, consts.RunningStatusJobWait}
if waitTaskReq.Type == consts.CacheTypePreheat {
unCacheJobs, err := s.cacheJobDao.GetUnCacheJob(waitTaskReq.InstanceId, waitTaskReq.Ids, execStatus, waitTaskReq.Limit)
if err != nil {
return err
}
for _, i := range unCacheJobs {
err = s.cacheJobService.ResumeCacheJob(&query.ResumeCacheJobReq{
Id: i.ID,
InstanceId: waitTaskReq.InstanceId,
})
if err != nil {
return err
}
}
} else if waitTaskReq.Type == consts.CacheTypeMount {
repositories, err := s.repositoryDao.GetUnmountRepository(waitTaskReq.InstanceId, waitTaskReq.Ids, execStatus, waitTaskReq.Limit)
if err != nil {
return err
}
for _, i := range repositories {
err = s.repositoryService.MountRepository(&query.RepositoryReq{
Id: i.ID,
})
if err != nil {
return err
}
}
} else {
return fmt.Errorf("type is invalid")
}
return nil
}
4 changes: 2 additions & 2 deletions internal/service/repository_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func (s *RepositoryService) MountRepository(repoReq *query.RepositoryReq) error
if repository == nil {
return myerr.New(fmt.Sprintf("记录不存在。编号:%d", repoReq.Id))
}
if repository.Status == consts.StatusCacheJobIng || repository.Status == consts.StatusCacheJobComplete {
if repository.Status == consts.RunningStatusJobIng || repository.Status == consts.RunningStatusJobComplete {
return myerr.New("当前状态不可执行该操作。")
}
entity, err := s.dingospeedDao.GetEntity(repository.InstanceId, false) // 挂载到公共目录,通过离线模式处理
Expand All @@ -240,7 +240,7 @@ func (s *RepositoryService) MountRepository(repoReq *query.RepositoryReq) error
speedDomain := fmt.Sprintf("http://%s:%d", entity.Host, entity.Port)
if err = s.repositoryDao.UpdateRepositoryMountStatus(&query.UpdateMountStatusReq{
Id: repository.ID,
Status: consts.StatusCacheJobIng,
Status: consts.RunningStatusJobIng,
}); err != nil {
zap.S().Errorf("UpdateRepositoryMountStatus err.%v", err)
return myerr.New("更新状态错误。")
Expand Down
2 changes: 1 addition & 1 deletion internal/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ package service
import "github.com/google/wire"

var ServiceProvider = wire.NewSet(NewSchedulerService, NewSysService, NewCacheJobService, NewRepositoryService,
NewTagService, NewOrganizationService, NewHfTokenService)
NewTagService, NewOrganizationService, NewHfTokenService, NewManagerService)
Loading