diff --git a/src/semgrep_mcp/semgrep.py b/src/semgrep_mcp/semgrep.py index f762947..ce6fb6d 100644 --- a/src/semgrep_mcp/semgrep.py +++ b/src/semgrep_mcp/semgrep.py @@ -179,7 +179,9 @@ def shutdown(self) -> None: ################################################################################ -async def run_semgrep(args: list[str]) -> asyncio.subprocess.Process: +async def run_semgrep( + top_level_span: trace.Span | None, args: list[str] +) -> asyncio.subprocess.Process: """ Runs semgrep with the given arguments as a subprocess, without waiting for it to finish. """ @@ -190,6 +192,13 @@ async def run_semgrep(args: list[str]) -> asyncio.subprocess.Process: # Just so we get the debug logs for the MCP server env = os.environ.copy() env["SEMGREP_LOG_SRCS"] = "mcp" + if top_level_span: + env["SEMGREP_TRACE_PARENT_SPAN_ID"] = trace.format_span_id( + top_level_span.get_span_context().span_id + ) + env["SEMGREP_TRACE_PARENT_TRACE_ID"] = trace.format_trace_id( + top_level_span.get_span_context().trace_id + ) # Execute semgrep command process = await asyncio.create_subprocess_exec( @@ -212,7 +221,7 @@ async def run_semgrep_daemon(top_level_span: trace.Span) -> SemgrepContext | Non Returns None if the user doesn't have the Pro Engine installed. """ - resp = await run_semgrep(["--pro", "--version"]) + resp = await run_semgrep(top_level_span, ["--pro", "--version"]) # wait for the command to exit so the exit code is set await resp.communicate() @@ -227,15 +236,15 @@ async def run_semgrep_daemon(top_level_span: trace.Span) -> SemgrepContext | Non return None else: - process = await run_semgrep(["mcp", "--pro"]) + process = await run_semgrep(top_level_span, ["mcp", "--pro", "--trace"]) return SemgrepContext(process=process, top_level_span=top_level_span) -async def run_semgrep_output(args: list[str]) -> str: +async def run_semgrep_output(top_level_span: trace.Span | None, args: list[str]) -> str: """ Runs `semgrep` with the given arguments and returns the stdout. """ - process = await run_semgrep(args) + process = await run_semgrep(top_level_span, args) stdout, stderr = await process.communicate() if process.returncode != 0: diff --git a/src/semgrep_mcp/server.py b/src/semgrep_mcp/server.py index b4a4c81..f721ea6 100755 --- a/src/semgrep_mcp/server.py +++ b/src/semgrep_mcp/server.py @@ -351,7 +351,7 @@ async def get_supported_languages() -> list[str]: args = ["show", "supported-languages", "--experimental"] # Parse output and return list of languages - languages = await run_semgrep_output(args) + languages = await run_semgrep_output(top_level_span=None, args=args) return [lang.strip() for lang in languages.strip().split("\n") if lang.strip()] @@ -587,7 +587,7 @@ async def semgrep_scan_with_custom_rule( # Run semgrep scan with custom rule args = get_semgrep_scan_args(temp_dir, rule_file_path) - output = await run_semgrep_output(args) + output = await run_semgrep_output(top_level_span=None, args=args) results: SemgrepScanResult = SemgrepScanResult.model_validate_json(output) remove_temp_dir_from_results(results, temp_dir) return results @@ -632,7 +632,7 @@ async def semgrep_scan( # Create temporary files from code content temp_dir = create_temp_files_from_code_content(code_files) args = get_semgrep_scan_args(temp_dir, config) - output = await run_semgrep_output(args) + output = await run_semgrep_output(top_level_span=None, args=args) results: SemgrepScanResult = SemgrepScanResult.model_validate_json(output) remove_temp_dir_from_results(results, temp_dir) return results @@ -740,7 +740,7 @@ async def semgrep_scan_local( results = [] for cf in code_files: args = get_semgrep_scan_args(cf.path, config) - output = await run_semgrep_output(args) + output = await run_semgrep_output(top_level_span=None, args=args) result: SemgrepScanResult = SemgrepScanResult.model_validate_json(output) results.append(result) return results @@ -796,7 +796,7 @@ async def security_check( # Create temporary files from code content temp_dir = create_temp_files_from_code_content(code_files) args = get_semgrep_scan_args(temp_dir) - output = await run_semgrep_output(args) + output = await run_semgrep_output(top_level_span=None, args=args) results: SemgrepScanResult = SemgrepScanResult.model_validate_json(output) remove_temp_dir_from_results(results, temp_dir) @@ -859,7 +859,7 @@ async def get_abstract_syntax_tree( "--json", temp_file_path, ] - return await run_semgrep_output(args) + return await run_semgrep_output(top_level_span=None, args=args) except McpError as e: raise e