diff --git a/cmd/generate/context.go b/cmd/generate/context.go index c20a7023..85c4a318 100644 --- a/cmd/generate/context.go +++ b/cmd/generate/context.go @@ -26,7 +26,7 @@ func (h *generateCommandHandler) CreateContextFromPrompt() (*PromptPexContext, e } runID := fmt.Sprintf("run_%d", time.Now().Unix()) - context := &PromptPexContext{ + promptContext := &PromptPexContext{ // Unique identifier for the run RunID: runID, // The prompt content and metadata @@ -50,21 +50,21 @@ func (h *generateCommandHandler) CreateContextFromPrompt() (*PromptPexContext, e } else { sessionInfo = fmt.Sprintf("reloading session file at %s", *h.sessionFile) // Check if prompt hashes match - if existingContext.PromptHash != context.PromptHash { + if existingContext.PromptHash != promptContext.PromptHash { return nil, fmt.Errorf("prompt changed unable to reuse session file") } // Merge existing context data if existingContext != nil { - context = mergeContexts(existingContext, context) + promptContext = mergeContexts(existingContext, promptContext) } } } - h.WriteToParagraph(RenderMessagesToString(context.Prompt.Messages)) + h.WriteToParagraph(RenderMessagesToString(promptContext.Prompt.Messages)) h.WriteEndBox(sessionInfo) - return context, nil + return promptContext, nil } // loadContextFromFile loads a PromptPexContext from a JSON file diff --git a/cmd/generate/generate.go b/cmd/generate/generate.go index be2cf91f..6610bbd2 100644 --- a/cmd/generate/generate.go +++ b/cmd/generate/generate.go @@ -76,14 +76,14 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { sessionFile: util.Ptr(sessionFile), } - // Create context - context, err := handler.CreateContextFromPrompt() + // Create prompt context + promptContext, err := handler.CreateContextFromPrompt() if err != nil { return fmt.Errorf("failed to create context: %w", err) } // Run the PromptPex pipeline - if err := handler.RunTestGenerationPipeline(context); err != nil { + if err := handler.RunTestGenerationPipeline(promptContext); err != nil { // Disable usage help for pipeline failures cmd.SilenceUsage = true return fmt.Errorf("pipeline failed: %w", err) diff --git a/cmd/generate/llm.go b/cmd/generate/llm.go index c539bfc8..f679f397 100644 --- a/cmd/generate/llm.go +++ b/cmd/generate/llm.go @@ -28,11 +28,10 @@ func (h *generateCommandHandler) callModelWithRetry(step string, req azuremodels for attempt := 0; attempt <= maxRetries; attempt++ { sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(h.cfg.ErrOut)) sp.Start() - //nolint:gocritic,revive // TODO - defer sp.Stop() resp, err := h.client.GetChatCompletionStream(ctx, req, h.org) if err != nil { + sp.Stop() var rateLimitErr *azuremodels.RateLimitError if errors.As(err, &rateLimitErr) { if attempt < maxRetries { @@ -53,8 +52,6 @@ func (h *generateCommandHandler) callModelWithRetry(step string, req azuremodels return "", err } reader := resp.Reader - //nolint:gocritic,revive // TODO - defer reader.Close() var content strings.Builder for { @@ -63,6 +60,11 @@ func (h *generateCommandHandler) callModelWithRetry(step string, req azuremodels if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") { break } + if closeErr := reader.Close(); closeErr != nil { + // Log close error but don't override the original error + fmt.Fprintf(h.cfg.ErrOut, "Warning: failed to close reader: %v\n", closeErr) + } + sp.Stop() return "", err } for _, choice := range completion.Choices { @@ -75,6 +77,13 @@ func (h *generateCommandHandler) callModelWithRetry(step string, req azuremodels } } + // Properly close reader and stop spinner before returning success + err = reader.Close() + sp.Stop() + if err != nil { + return "", fmt.Errorf("failed to close reader: %w", err) + } + res := strings.TrimSpace(content.String()) h.LogLLMResponse(res) return res, nil