diff --git a/devops/internal/service/call_option.go b/devops/internal/service/call_option.go index c8134f737..66297edc0 100644 --- a/devops/internal/service/call_option.go +++ b/devops/internal/service/call_option.go @@ -140,7 +140,7 @@ func (c *callbackHandler) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou ext := c.ConvCallbackOutput(output) if ext != nil && ext.TokenUsage != nil { state.Metrics.PromptTokens = int64(ext.TokenUsage.PromptTokens) - state.Metrics.CompletionTokens = int64(ext.TokenUsage.PromptTokens) + state.Metrics.CompletionTokens = int64(ext.TokenUsage.CompletionTokens) } // append result diff --git a/devops/internal/service/call_option_test.go b/devops/internal/service/call_option_test.go index 0be116668..91f6b7bb7 100644 --- a/devops/internal/service/call_option_test.go +++ b/devops/internal/service/call_option_test.go @@ -148,6 +148,28 @@ func Test_OnEnd(t *testing.T) { }) assert.Equal(t, actualCtx, ctx) }) + PatchConvey("Test token usage is correctly assigned to metrics", t, func() { + cb := &callbackHandler{ + nodeKey: "nodeKey", + threadID: "threadID", + } + cb.stateCh = make(chan *model.NodeDebugState, 1) + Mock(getNodeDebugStateCtx).Return(&nodeDebugStateCtxValue{invokeTimeMS: int64(1728630000), callbackInput: "input"}, true).Build() + output := &einomodel.CallbackOutput{ + Message: &schema.Message{}, + TokenUsage: &einomodel.TokenUsage{ + PromptTokens: 42, + CompletionTokens: 420, + }, + } + actualCtx := cb.OnEnd(ctx, info, output) + safego.Go(ctx, func() { + res, _ := <-cb.stateCh + assert.Equal(t, int64(42), res.Metrics.PromptTokens) + assert.Equal(t, int64(420), res.Metrics.CompletionTokens) + }) + assert.Equal(t, actualCtx, ctx) + }) } func Test_OnError(t *testing.T) {