Skip to content
Merged
17 changes: 12 additions & 5 deletions backend/modules/evaluation/domain/service/expt_run_item_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type ExptItemEvalCtxExecutor struct {
evalSetItemSvc EvaluationSetItemService
}

const exptRunLogPersistTimeout = 5 * time.Second

func (e *ExptItemEvalCtxExecutor) Eval(ctx context.Context, eiec *entity.ExptItemEvalCtx) error {
event := eiec.Event

Expand Down Expand Up @@ -130,6 +132,8 @@ func (e *ExptItemEvalCtxExecutor) storeTurnRunResult(ctx context.Context, etec *
if result == nil {
return fmt.Errorf("StoreTurnRunResult with nil result")
}
persistCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), exptRunLogPersistTimeout)
defer cancel()

turn := etec.Turn
turnResultLog := etec.GetExistTurnResultLogs()[turn.ID]
Expand Down Expand Up @@ -173,7 +177,7 @@ func (e *ExptItemEvalCtxExecutor) storeTurnRunResult(ctx context.Context, etec *
if se, ok := errorx.FromStatusError(evalErr); ok && (se.Code() == errno.CustomEvalTargetInvokeFailCode || se.Code() == errno.CustomRPCEvaluatorRunFailedCode) {
errMsg = errorx.ErrorWithoutStack(evalErr)
} else {
errMsg = e.Configer.GetErrCtrl(ctx).ConvertErrMsg(evalErr.Error())
errMsg = e.Configer.GetErrCtrl(persistCtx).ConvertErrMsg(evalErr.Error())
}

logs.CtxWarn(ctx, "[ExptTurnEval] store turn run err, before: %v, after: %v", evalErr, errMsg)
Expand All @@ -195,7 +199,7 @@ func (e *ExptItemEvalCtxExecutor) storeTurnRunResult(ctx context.Context, etec *

result.SetEvalErr(evalErr)

if err := e.TurnResultRepo.SaveTurnRunLogs(ctx, []*entity.ExptTurnResultRunLog{clone}); err != nil {
if err := e.TurnResultRepo.SaveTurnRunLogs(persistCtx, []*entity.ExptTurnResultRunLog{clone}); err != nil {
return err
}

Expand Down Expand Up @@ -275,8 +279,11 @@ func (e *ExptItemEvalCtxExecutor) buildExptTurnEvalCtx(ctx context.Context, turn
}

func (e *ExptItemEvalCtxExecutor) CompleteItemRun(ctx context.Context, event *entity.ExptItemEvalEvent, evalErr error) error {
persistCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), exptRunLogPersistTimeout)
defer cancel()

if evalErr != nil {
if retry, _ := e.evalErrNeedRetry(ctx, event, evalErr); retry {
if retry, _ := e.evalErrNeedRetry(persistCtx, event, evalErr); retry {
return evalErr
}
}
Expand All @@ -292,11 +299,11 @@ func (e *ExptItemEvalCtxExecutor) CompleteItemRun(ctx context.Context, event *en
ufields["status"] = int32(entity.ItemRunState_Success)
}

if err := e.ItemResultRepo.UpdateItemRunLog(ctx, event.ExptID, event.ExptRunID, []int64{event.EvalSetItemID}, ufields, event.SpaceID); err != nil {
if err := e.ItemResultRepo.UpdateItemRunLog(persistCtx, event.ExptID, event.ExptRunID, []int64{event.EvalSetItemID}, ufields, event.SpaceID); err != nil {
return err
}

if e.evalErrNeedTerminateExpt(ctx, event.SpaceID, evalErr) {
if e.evalErrNeedTerminateExpt(persistCtx, event.SpaceID, evalErr) {
logs.CtxWarn(ctx, "[ExptTurnEval] found error which should terminate expt, expt_id: %v, expt_run_id: %v, item_id: %v, err: %v", event.ExptID, event.ExptRunID, event.EvalSetItemID, evalErr)
return evalErr
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/bytedance/gg/gptr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"

"github.com/coze-dev/coze-loop/backend/infra/external/benefit"
Expand Down Expand Up @@ -371,6 +372,23 @@ func Test_ExptItemEvalCtxExecutor_CompleteSetItemRun(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "mock updateitemrunlog error")
})

t.Run("ctx取消后仍落item失败状态", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

mockConfiger.EXPECT().GetErrRetryConf(gomock.Any(), int64(4), gomock.Any()).AnyTimes().Return(&entity.RetryConf{IsInDebt: false})
mockItemResultRepo.EXPECT().UpdateItemRunLog(gomock.Any(), int64(1), int64(2), []int64{3}, gomock.Any(), int64(4)).
DoAndReturn(func(ctx context.Context, _, _ int64, _ []int64, ufields map[string]any, _ int64) error {
require.NoError(t, ctx.Err())
assert.Equal(t, int32(entity.ItemRunState_Fail), ufields["status"])
return nil
})

event := &entity.ExptItemEvalEvent{ExptID: 1, ExptRunID: 2, EvalSetItemID: 3, SpaceID: 4, RetryTimes: 1}
err := executor.CompleteItemRun(ctx, event, errors.New("target timeout"))
assert.NoError(t, err)
})
}

func Test_ExptItemEvalCtxExecutor_storeTurnRunResult(t *testing.T) {
Expand Down Expand Up @@ -438,6 +456,37 @@ func Test_ExptItemEvalCtxExecutor_storeTurnRunResult(t *testing.T) {
err := executor.storeTurnRunResult(context.Background(), etec, result)
assert.NoError(t, err)
})

t.Run("ctx取消后仍落turn失败状态", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

turnResultLog := &entity.ExptTurnResultRunLog{ID: 1, TurnID: 1}
etec := &entity.ExptTurnEvalCtx{
Turn: &entity.Turn{ID: 1},
ExptItemEvalCtx: &entity.ExptItemEvalCtx{
Expt: &entity.Experiment{ID: 1, SourceID: "src", SpaceID: 2},
Event: &entity.ExptItemEvalEvent{ExptRunID: 3},
EvalSetItem: &entity.EvaluationSetItem{ItemID: 2},
ExistItemEvalResult: &entity.ExptItemEvalResult{TurnResultRunLogs: map[int64]*entity.ExptTurnResultRunLog{1: turnResultLog}},
},
}
result := &entity.ExptTurnRunResult{EvalErr: errors.New("target timeout")}

mockConfiger.EXPECT().GetErrCtrl(gomock.Any()).DoAndReturn(func(ctx context.Context) *entity.ExptErrCtrl {
require.NoError(t, ctx.Err())
return entity.DefaultExptErrCtrl()
})
mockTurnResultRepo.EXPECT().SaveTurnRunLogs(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, logs []*entity.ExptTurnResultRunLog) error {
require.NoError(t, ctx.Err())
require.Len(t, logs, 1)
assert.Equal(t, entity.TurnRunState_Fail, logs[0].Status)
return nil
})

err := executor.storeTurnRunResult(ctx, etec, result)
assert.NoError(t, err)
})
}

func Test_buildExptTurnEvalCtx(t *testing.T) {
Expand Down
13 changes: 9 additions & 4 deletions backend/modules/evaluation/domain/service/target_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ type EvalTargetServiceImpl struct {
configer component.IConfiger
}

const evalTargetRecordPersistTimeout = 5 * time.Second
Comment thread
caijialin0626 marked this conversation as resolved.

func NewEvalTargetServiceImpl(evalTargetRepo repo.IEvalTargetRepo,
idgen idgen.IIDGenerator,
metric metrics.EvalTargetMetrics,
Expand Down Expand Up @@ -354,12 +356,15 @@ func (e *EvalTargetServiceImpl) ExecuteTarget(ctx context.Context, spaceID, targ
}
}

recordID, err1 := e.idgen.GenID(ctx)
recordCtx, recordCancel := context.WithTimeout(context.WithoutCancel(ctx), evalTargetRecordPersistTimeout)
defer recordCancel()

recordID, err1 := e.idgen.GenID(recordCtx)
if err1 != nil {
err = err1
return
}
logID := logs.GetLogID(ctx)
logID := logs.GetLogID(recordCtx)

record = &entity.EvalTargetRecord{
ID: recordID,
Expand All @@ -385,9 +390,9 @@ func (e *EvalTargetServiceImpl) ExecuteTarget(ctx context.Context, spaceID, targ
UpdatedAt: gptr.Of(time.Now().UnixMilli()),
},
}
e.convEvalTargetRunErr(ctx, record)
e.convEvalTargetRunErr(recordCtx, record)

_, errCreate := e.evalTargetRepo.CreateEvalTargetRecord(ctx, record, nil)
_, errCreate := e.evalTargetRepo.CreateEvalTargetRecord(recordCtx, record, nil)
if errCreate != nil {
return
}
Expand Down
96 changes: 92 additions & 4 deletions backend/modules/evaluation/domain/service/target_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,10 +607,14 @@ func TestEvalTargetServiceImpl_ExecuteTarget(t *testing.T) {
deps.configer.EXPECT().GetTargetTrajectoryConf(gomock.Any()).AnyTimes().Return(&entity.TargetTrajectoryConf{})
// convEvalTargetRunErr (in ExecuteTarget defer) may call GetErrCtrl when record has EvalTargetRunError
deps.configer.EXPECT().GetErrCtrl(gomock.Any()).AnyTimes().Return(entity.DefaultExptErrCtrl())
deps.idgen.EXPECT().GenID(ctx).Return(int64(9999), nil)
deps.idgen.EXPECT().GenID(gomock.Any()).DoAndReturn(func(ctx context.Context) (int64, error) {
require.NoError(t, ctx.Err())
return int64(9999), nil
})

var savedRecord *entity.EvalTargetRecord
deps.repo.EXPECT().CreateEvalTargetRecord(ctx, gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, rec *entity.EvalTargetRecord, _ *bool) (int64, error) {
deps.repo.EXPECT().CreateEvalTargetRecord(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, rec *entity.EvalTargetRecord, _ *bool) (int64, error) {
require.NoError(t, ctx.Err())
savedRecord = rec
return rec.ID, nil
})
Expand Down Expand Up @@ -651,6 +655,86 @@ func TestEvalTargetServiceImpl_ExecuteTarget(t *testing.T) {
}
}

func TestEvalTargetServiceImpl_ExecuteTarget_PersistsFailRecordAfterContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
ctrl := gomock.NewController(t)
defer ctrl.Finish()

deps := &evalTargetServiceTestDeps{
repo: repomocks.NewMockIEvalTargetRepo(ctrl),
idgen: idgenmocks.NewMockIIDGenerator(ctrl),
metric: metricsmocks.NewMockEvalTargetMetrics(ctrl),
operator: servicemocks.NewMockISourceEvalTargetOperateService(ctrl),
configer: componentmocks.NewMockIConfiger(ctrl),
}

evalTarget := &entity.EvalTarget{
ID: 200,
SpaceID: 100,
SourceTargetID: "src-id",
EvalTargetType: entity.EvalTargetTypeLoopPrompt,
EvalTargetVersion: &entity.EvalTargetVersion{
ID: 300,
SourceTargetVersion: "v1",
InputSchema: []*entity.ArgsSchema{
{Key: gptr.Of("field")},
},
},
}
input := &entity.EvalTargetInputData{
InputFields: map[string]*entity.Content{
"field": {
ContentType: gptr.Of(entity.ContentTypeText),
Text: gptr.Of("hello"),
},
},
}
param := &entity.ExecuteTargetCtx{
ExperimentRunID: gptr.Of(int64(555)),
ItemID: 777,
TurnID: 888,
}

deps.repo.EXPECT().GetEvalTargetVersion(ctx, evalTarget.SpaceID, evalTarget.EvalTargetVersion.ID).Return(evalTarget, nil)
deps.metric.EXPECT().EmitRun(evalTarget.SpaceID, gomock.Any(), gomock.Any()).Times(1)
deps.configer.EXPECT().GetTargetTrajectoryConf(gomock.Any()).AnyTimes().Return(&entity.TargetTrajectoryConf{})
deps.configer.EXPECT().GetErrCtrl(gomock.Any()).AnyTimes().DoAndReturn(func(ctx context.Context) *entity.ExptErrCtrl {
require.NoError(t, ctx.Err())
return entity.DefaultExptErrCtrl()
})
deps.operator.EXPECT().ValidateInput(gomock.Any(), evalTarget.SpaceID, evalTarget.EvalTargetVersion.InputSchema, input).Return(nil)
deps.operator.EXPECT().Execute(gomock.Any(), evalTarget.SpaceID, gomock.Any()).DoAndReturn(func(context.Context, int64, *entity.ExecuteEvalTargetParam) (*entity.EvalTargetOutputData, entity.EvalTargetRunStatus, error) {
cancel()
return nil, entity.EvalTargetRunStatusFail, context.Canceled
})
deps.idgen.EXPECT().GenID(gomock.Any()).DoAndReturn(func(ctx context.Context) (int64, error) {
require.NoError(t, ctx.Err())
return int64(9999), nil
})
deps.repo.EXPECT().CreateEvalTargetRecord(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, rec *entity.EvalTargetRecord, _ *bool) (int64, error) {
require.NoError(t, ctx.Err())
require.NotNil(t, rec)
assert.Equal(t, entity.EvalTargetRunStatusFail, gptr.Indirect(rec.Status))
return rec.ID, nil
})

svc := &EvalTargetServiceImpl{
evalTargetRepo: deps.repo,
idgen: deps.idgen,
metric: deps.metric,
typedOperators: map[entity.EvalTargetType]ISourceEvalTargetOperateService{
evalTarget.EvalTargetType: deps.operator,
},
configer: deps.configer,
}

record, err := svc.ExecuteTarget(ctx, evalTarget.SpaceID, evalTarget.ID, evalTarget.EvalTargetVersion.ID, param, input)
require.NoError(t, err)
require.NotNil(t, record)
assert.Equal(t, int64(9999), record.ID)
assert.Equal(t, entity.EvalTargetRunStatusFail, gptr.Indirect(record.Status))
}

func TestEvalTargetServiceImpl_ExecuteTarget_TrajectoryExtraction(t *testing.T) {
// do not run in parallel, this test involves time.Sleep

Expand Down Expand Up @@ -752,7 +836,10 @@ func TestEvalTargetServiceImpl_ExecuteTarget_TrajectoryExtraction(t *testing.T)
spaceID: 1,
},
})
deps.idgen.EXPECT().GenID(ctx).Return(int64(9999), nil)
deps.idgen.EXPECT().GenID(gomock.Any()).DoAndReturn(func(ctx context.Context) (int64, error) {
require.NoError(t, ctx.Err())
return int64(9999), nil
})

trajectoryAdapter.EXPECT().
ListTrajectory(gomock.Any(), spaceID, gomock.Any(), gomock.AssignableToTypeOf((*int64)(nil))).
Expand All @@ -770,7 +857,8 @@ func TestEvalTargetServiceImpl_ExecuteTarget_TrajectoryExtraction(t *testing.T)
Return(outputData, entity.EvalTargetRunStatusSuccess, nil)

var savedRecord *entity.EvalTargetRecord
deps.repo.EXPECT().CreateEvalTargetRecord(ctx, gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, rec *entity.EvalTargetRecord, _ *bool) (int64, error) {
deps.repo.EXPECT().CreateEvalTargetRecord(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, rec *entity.EvalTargetRecord, _ *bool) (int64, error) {
require.NoError(t, ctx.Err())
savedRecord = rec
return rec.ID, nil
})
Expand Down
Loading