diff --git a/main.py b/main.py index 8b797d7..7a15dc3 100644 --- a/main.py +++ b/main.py @@ -2,43 +2,16 @@ import argparse import asyncio -import logging import shutil -from pathlib import Path from dotenv import load_dotenv +from src.utils.logging import configure_logging + load_dotenv() from src.masksql import MaskSQL # noqa: E402 -from src.utils.logging import configure_logging # noqa: E402 - - -logger = logging.getLogger(__name__) - - -def clean_cache_directory(cache_dir: str) -> None: - """Clean intermediate files from the data directory. - - Removes files matching the pattern [0-9]*_* but excludes files starting with 1_*. - This is used to clean up intermediate pipeline output files while preserving - the initial input files. - - Parameters - ---------- - cache_dir : str - Path to the cache directory to clean. - """ - cache_path = Path(cache_dir) - - if not cache_path.exists(): - logger.error(f"Data directory does not exist: {cache_dir}") - return - - shutil.rmtree(cache_path) - - logger.info("Cleanup complete") async def main() -> None: @@ -58,7 +31,7 @@ async def main() -> None: mask_sql = MaskSQL.from_config(args.config) if args.clean: - clean_cache_directory(mask_sql.conf.cache_dir) + shutil.rmtree(mask_sql.conf.cache_dir, ignore_errors=True) else: await mask_sql.evaluate() diff --git a/pyproject.toml b/pyproject.toml index 72e5896..9cac292 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ packages = ["src"] [dependency-groups] dev = [ "codecov>=2.1.13", + "loguru>=0.7.3", "mypy>=1.14.1", "nbqa>=1.9.1", "pip>=25.3", # Pinning version to address vulnerability GHSA-4xh5-x5gv-qwph diff --git a/src/data_models/masksql_input.py b/src/data_models/masksql_input.py index e59b49a..5351913 100644 --- a/src/data_models/masksql_input.py +++ b/src/data_models/masksql_input.py @@ -20,10 +20,10 @@ class MaskSqlInput(BaseObject): db_id: Identifier of the database the question is about question: Natural language question text query: Optional SQL query (may be empty for new inputs) - annotated_links: Dictionary of annotations for the question + gold_schema_links: Dictionary of annotations for the question """ db_id: str question: str query: str - annotated_links: dict[str, Any] + gold_schema_links: dict[str, Any] diff --git a/src/masksql.py b/src/masksql.py index 6821b72..05ecea3 100644 --- a/src/masksql.py +++ b/src/masksql.py @@ -21,10 +21,7 @@ from src.pipeline.exec_conc_sql import ExecuteConcreteSql from src.pipeline.gen_sql.gen_masked_sql import GenerateSymbolicSql from src.pipeline.init_data import InitData -from src.pipeline.link_schema.link_schema import ( - FilterSchemaLinksModel, - LinkSchema, -) +from src.pipeline.link_schema.link_schema import FilterSchemaLinksModel, LinkSchema from src.pipeline.link_values.link_values import FilterValueLinksModel, LinkValues from src.pipeline.pipeline import Pipeline from src.pipeline.rank_schema import RankSchemaResd @@ -74,10 +71,8 @@ def create_pipeline_stages(conf: MaskSqlConfig) -> list[JsonListProcessor]: LimitJson(), InitData(), *rank_schema, - # ResdItemCount(), AddFilteredSchema(conf.tables_path), AddSymbolTable(conf.tables_path), - # SlmSQL("slm_sql", conf.openai, model=conf.slm), DetectValues(conf.openai, model=conf.slm), LinkValues(conf.openai, model=conf.slm), CopyTransformer("value_links", "filtered_value_links", FilterValueLinksModel), @@ -94,7 +89,6 @@ def create_pipeline_stages(conf: MaskSqlConfig) -> list[JsonListProcessor]: RepairSQL(conf.openai, model=conf.slm), CalcExecAcc(conf.db_path, conf.policy), AddInferenceAttack(conf.openai, model=conf.llm), - # # PrintProps(['question', 'symbolic.question', 'attack']) Results(), ] @@ -175,7 +169,7 @@ async def query(self, db_id: str, question: str) -> MaskSqlOutput: db_id=db_id, question=question, query="", - annotated_links={}, + gold_schema_links={}, ) results = await self.pipeline.run([data]) return results[0] diff --git a/src/pipeline/add_symbolic_question.py b/src/pipeline/add_symbolic_question.py index d83759b..3833f53 100644 --- a/src/pipeline/add_symbolic_question.py +++ b/src/pipeline/add_symbolic_question.py @@ -1,15 +1,12 @@ """Deterministic masking of terms in questions.""" -import logging +from loguru import logger from src.pipeline.add_symb_schema import AddSymbolicSchema, SymbolicSchema from src.pipeline.base_processor.list_processor import JsonListProcessor from src.utils.strings import replace_str -logger = logging.getLogger(__name__) - - class SymbolicQuestion(SymbolicSchema): """ Data model for questions with symbolic representations. diff --git a/src/pipeline/base_processor/list_processor.py b/src/pipeline/base_processor/list_processor.py index bb06477..bc1f390 100644 --- a/src/pipeline/base_processor/list_processor.py +++ b/src/pipeline/base_processor/list_processor.py @@ -4,9 +4,12 @@ from abc import ABC, abstractmethod from typing import Generic, Type, TypeVar +from loguru import logger + from src.data_cache.json_cache import JsonCache from src.data_models.base_object import BaseObject from src.utils.async_utils import apply_async +from src.utils.logging import log T = TypeVar("T", bound=BaseObject) @@ -57,6 +60,7 @@ def get_cache_file_path(self, cache_dir: str, sequence: int) -> str: """ return os.path.join(cache_dir, f"{sequence}_{self.name}.json") + @logger.catch(message="Failed to process row", reraise=True) async def __process_row_internal(self, row: T) -> U: if self.cache and not self.force and row.idx in self.cache: return self.cache[row.idx] @@ -89,6 +93,7 @@ def _pre_run(self) -> None: # noqa: B027 def _post_run(self) -> None: # noqa: B027 """Override to add post-processing logic after run.""" + @log("Processor completed: {0}") async def run(self, input_data: list[T]) -> list[U]: """ Process input file and return output_data. @@ -111,3 +116,7 @@ async def run(self, input_data: list[T]) -> list[U]: self._post_run() return output_data + + def __repr__(self) -> str: + """Name of the processor.""" + return self.name diff --git a/src/pipeline/base_processor/list_transformer.py b/src/pipeline/base_processor/list_transformer.py index a899a76..6275ccb 100644 --- a/src/pipeline/base_processor/list_transformer.py +++ b/src/pipeline/base_processor/list_transformer.py @@ -1,6 +1,5 @@ """List transformation base classes.""" -import logging import os from abc import ABC @@ -8,9 +7,6 @@ from src.pipeline.base_processor.list_processor import JsonListProcessor -logger = logging.getLogger(__name__) - - FORCE = int(os.environ.get("FORCE", "0")) > 0 diff --git a/src/pipeline/base_processor/prompt_processor.py b/src/pipeline/base_processor/prompt_processor.py index ee5c9e6..f28bdf8 100644 --- a/src/pipeline/base_processor/prompt_processor.py +++ b/src/pipeline/base_processor/prompt_processor.py @@ -1,14 +1,16 @@ """Base base_processor for LLM-based value detection.""" +import uuid from abc import ABC, abstractmethod from json import JSONDecodeError from typing import Any, Generic, Type, TypeVar +from loguru import logger + from src.config import OpenAIConfig from src.pipeline.base_processor.list_processor import JsonListProcessor from src.pipeline.init_data import InitData from src.utils.llm_util import send_prompt -from src.utils.logging import logger from src.utils.timer import Timer @@ -48,10 +50,14 @@ def __init__( self.openai_config = openai_config self.model = model self.include_stats = include_stats + self.prompt_logger = logger.bind(type="prompt", name=self.name) async def _prompt_llm(self, row: T, prompt: str) -> tuple[Any, str]: + prompt_logger = self.prompt_logger.bind(prompt_id=uuid.uuid4()) try: + prompt_logger.bind(is_req=True).debug(prompt) res, toks = await send_prompt(prompt, self.openai_config, model=self.model) + prompt_logger.bind(is_req=False).debug(res) except JSONDecodeError as e: logger.error(f"Sending prompts failed: {e}") return "", "0" diff --git a/src/pipeline/exec_conc_sql.py b/src/pipeline/exec_conc_sql.py index 634a877..daf2538 100644 --- a/src/pipeline/exec_conc_sql.py +++ b/src/pipeline/exec_conc_sql.py @@ -2,11 +2,11 @@ from typing import Any, Optional +from loguru import logger from pydantic import BaseModel from src.pipeline.base_processor.list_processor import JsonListProcessor from src.pipeline.unmask import AddConcreteSql -from src.utils.logging import logger from src.utils.sqlite_facade import SqliteFacade diff --git a/src/pipeline/link_schema/link_schema.py b/src/pipeline/link_schema/link_schema.py index f2b9803..1919937 100644 --- a/src/pipeline/link_schema/link_schema.py +++ b/src/pipeline/link_schema/link_schema.py @@ -1,11 +1,12 @@ """Schema linking from questions to database schemas.""" +from loguru import logger + from src.config import OpenAIConfig from src.pipeline.base_processor.prompt_processor import PromptProcessor from src.pipeline.link_schema.prompts.v4 import SCHEMA_LINK_PROMPT_V4 from src.pipeline.link_values.link_values import FilterValueLinksModel from src.utils.llm_util import extract_object -from src.utils.logging import logger class LinkSchema(PromptProcessor[FilterValueLinksModel, "LinkSchema.Model"]): diff --git a/src/pipeline/pipeline.py b/src/pipeline/pipeline.py index 0bca29c..2b5afd4 100644 --- a/src/pipeline/pipeline.py +++ b/src/pipeline/pipeline.py @@ -7,9 +7,6 @@ from src.pipeline.base_processor.list_processor import JsonListProcessor from src.utils.logging import ( log_pipeline_summary, - log_stage_complete, - log_stage_start, - reset_stage_timings, ) from src.utils.mem import track_memory_async from src.utils.timer import Timer @@ -40,22 +37,9 @@ def __init__( stage.set_cache_file(pipeline_cache_dir, i + 1) async def __run_internal(self, input_data: list[T]) -> list[Any]: - # tmp_file: Any = input_file tmp_data = input_data - timer: Timer = Timer() - timer.start() - last_lap_time = 0.0 - for stage in self.stages: - log_stage_start(stage.name) tmp_data = await stage.run(tmp_data) - - # Get cumulative time and calculate stage time - cumulative_time = timer.lap() - stage_time = cumulative_time - last_lap_time - last_lap_time = cumulative_time - - log_stage_complete(stage.name, stage_time) return tmp_data async def run(self, input_data: list[T]) -> list[Any]: @@ -72,9 +56,6 @@ async def run(self, input_data: list[T]) -> list[Any]: tuple[Any, float, float] Tuple of (result, average_memory_mb, peak_memory_mb) """ - # Reset timing tracker for new pipeline run - reset_stage_timings() - timer: Timer = Timer.start() result, avg_mem, peak_mem = await track_memory_async( self.__run_internal, input_data diff --git a/src/pipeline/repair_sql/prompts/v3.py b/src/pipeline/repair_sql/prompts/v3.py index f1490ff..923ecdd 100644 --- a/src/pipeline/repair_sql/prompts/v3.py +++ b/src/pipeline/repair_sql/prompts/v3.py @@ -117,7 +117,7 @@ The execution result: no such table: store_locations -======= Your task ======= +******* Your task ******* ************************** Database schema: {schema} diff --git a/src/pipeline/repair_sql/prompts/v4.py b/src/pipeline/repair_sql/prompts/v4.py index f836f27..bad74f2 100644 --- a/src/pipeline/repair_sql/prompts/v4.py +++ b/src/pipeline/repair_sql/prompts/v4.py @@ -117,7 +117,7 @@ The execution result: no such table: store_locations -======= Your task ======= +******* Your task ******* ************************** Database schema: {schema} diff --git a/src/pipeline/repair_sql/repair_sql.py b/src/pipeline/repair_sql/repair_sql.py index f0a152f..19fa9e4 100644 --- a/src/pipeline/repair_sql/repair_sql.py +++ b/src/pipeline/repair_sql/repair_sql.py @@ -2,11 +2,12 @@ from typing import Any +from loguru import logger + from src.config import OpenAIConfig from src.pipeline.base_processor.prompt_processor import PromptProcessor from src.pipeline.exec_conc_sql import ExecuteConcreteSql from src.pipeline.repair_sql.prompts.v3 import REPAIR_SQL_PROMPT_V3 -from src.utils.logging import logger from src.utils.strings import extract_sql diff --git a/src/pipeline/repair_symb_sql/prompts/v1.py b/src/pipeline/repair_symb_sql/prompts/v1.py index 677f569..55803b4 100644 --- a/src/pipeline/repair_symb_sql/prompts/v1.py +++ b/src/pipeline/repair_symb_sql/prompts/v1.py @@ -117,7 +117,7 @@ The execution result: no such table: store_locations -======= Your task ======= +******* Your task ******* ************************** Database schema: {schema} diff --git a/src/pipeline/repair_symb_sql/prompts/v2.py b/src/pipeline/repair_symb_sql/prompts/v2.py index 10a864c..14787d2 100644 --- a/src/pipeline/repair_symb_sql/prompts/v2.py +++ b/src/pipeline/repair_symb_sql/prompts/v2.py @@ -108,7 +108,7 @@ The corrected SQL query is: SELECT COUNT(*) FROM [T1] WHERE [T1].[C1] = '[V1]' AND [C2] = 0 -======= Your task ======= +******* Your task ******* ************************** Database schema: {schema} diff --git a/src/pipeline/repair_symb_sql/raw_v2.py b/src/pipeline/repair_symb_sql/raw_v2.py index b916173..c043fa9 100644 --- a/src/pipeline/repair_symb_sql/raw_v2.py +++ b/src/pipeline/repair_symb_sql/raw_v2.py @@ -108,7 +108,7 @@ The corrected SQL query is: SELECT COUNT(*) FROM [T1] WHERE [T1].[C1] = '[V1]' AND [C2] = 0 -======= Your task ======= +******* Your task ******* ************************** NL Question and Database Schema: {symbolic_raw} diff --git a/src/pipeline/resd/run_resdsql.py b/src/pipeline/resd/run_resdsql.py index 53d81ad..b923043 100644 --- a/src/pipeline/resd/run_resdsql.py +++ b/src/pipeline/resd/run_resdsql.py @@ -6,10 +6,12 @@ from pathlib import Path from typing import Any +from loguru import logger + from src.pipeline.base_processor.list_processor import JsonListProcessor from src.pipeline.init_data import InitData from src.utils.json_io import write_json -from src.utils.logging import console, log_error, log_success, logger +from src.utils.logging import console class RunResdsql(JsonListProcessor[InitData.Model, InitData.Model]): @@ -85,7 +87,7 @@ async def run(self, input_data: list[InitData.Model]) -> list[InitData.Model]: # Skip if RESDSQL output already exists and force is not enabled if not self.force and self.resd_output_path.exists(): logger.info( - f"[dim]⏭ Skipping RESDSQL pipeline - output exists:[/dim] " + f"Skipping RESDSQL pipeline - output exists:" f"{self.resd_output_path.name}" ) return input_data @@ -107,13 +109,10 @@ async def run(self, input_data: list[InitData.Model]) -> list[InitData.Model]: self._generate_text2sql_data() self._add_question_ids() - log_success( - "RESDSQL pipeline completed", - output_file=str(self.resd_output_path), - ) + logger.info(f"RESDSQL pipeline completed: {self.resd_output_path}") except Exception as e: - log_error(f"RESDSQL pipeline failed: {e}") + logger.error(f"RESDSQL pipeline failed: {e}") raise return input_data @@ -147,7 +146,7 @@ def _run_step(self, step_name: str, script: str, args: list[str]) -> None: ) if result.returncode != 0: - log_error( + logger.error( f"RESDSQL step failed: {step_name}", script=script, exit_code=result.returncode, diff --git a/src/pipeline/results.py b/src/pipeline/results.py index f5a8da5..afde9bb 100644 --- a/src/pipeline/results.py +++ b/src/pipeline/results.py @@ -45,7 +45,7 @@ async def _process_row( # if "attack" in row and "annotated_links" in row: masked_terms = row.symbolic.masked_terms attack = row.attack - a_links = row.annotated_links + a_links = row.gold_schema_links ri_terms = 0 num_masks = len(masked_terms) diff --git a/src/utils/json_io.py b/src/utils/json_io.py index cc9c7d3..76e6abe 100644 --- a/src/utils/json_io.py +++ b/src/utils/json_io.py @@ -3,7 +3,8 @@ import json from typing import Any, Type, TypeVar -from pydantic import BaseModel +from loguru import logger +from pydantic import BaseModel, ValidationError T = TypeVar("T", bound=BaseModel) @@ -26,6 +27,9 @@ def read_json_raw(path: str) -> Any: return json.load(f) +@logger.catch( + message="Failed to validate data", reraise=True, exception=ValidationError +) def read_json(path: str, cls: Type[T]) -> list[T]: """ Read and parse a JSON file. diff --git a/src/utils/llm_util.py b/src/utils/llm_util.py index b6aa868..92ca004 100644 --- a/src/utils/llm_util.py +++ b/src/utils/llm_util.py @@ -2,18 +2,16 @@ import ast import json -import logging import os import re from typing import Any +from loguru import logger from openai import AsyncClient from src.config import OpenAIConfig -logger = logging.getLogger(__name__) - VLM_ARCH = os.environ.get("VLM_ARCH") MAX_COMPLETION_TOKENS = os.environ.get("MAX_COMPLETION_TOKENS") @@ -72,12 +70,6 @@ async def send_prompt( timeout=openai_config.timeout, ) - # Concise logging with rich markup - logger.debug( - f"[cyan]LLM Request[/cyan] → [bold]{model}[/bold] " - f"([dim]{len(prompt)} chars[/dim])" - ) - response = await client.chat.completions.create( model=model, messages=[ @@ -89,18 +81,12 @@ async def send_prompt( max_completion_tokens=openai_config.max_completion_tokens, ) if response.choices is None: - print(prompt) raise Exception(f"LM prompts failed: {response.model_extra}") usage = "0" if response.usage: usage = str(response.usage.total_tokens) content = response.choices[0].message.content or "" - logger.debug( - f"[green]LLM Response[/green] ← [bold]{usage}[/bold] tokens " - f"([dim]{len(content)} chars[/dim])" - ) - return content, usage diff --git a/src/utils/logging.py b/src/utils/logging.py index 3fc7dbf..7f19063 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,11 +1,12 @@ """Logging configuration utilities using rich library.""" -import logging import os -from typing import Any +import sys +from functools import wraps +from typing import Any, Awaitable, Callable, ParamSpec, TypeVar +from loguru import logger from rich.console import Console -from rich.logging import RichHandler from rich.panel import Panel from rich.table import Table from rich.text import Text @@ -25,52 +26,7 @@ install_rich_traceback(console=console, show_locals=False, width=100, word_wrap=True) -def configure_logging() -> None: - """Configure Python logging with rich formatting and custom handlers. - - This function sets up logging with: - - Rich colored console output - - Concise timestamp format - - Different colors for different log levels - - Enhanced traceback formatting for exceptions - """ - # Remove existing handlers to avoid duplicates - root_logger = logging.getLogger() - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - - # Create rich handler with custom formatting - rich_handler = RichHandler( - console=console, - show_time=True, - show_path=False, - show_level=True, - rich_tracebacks=True, - tracebacks_show_locals=False, - markup=True, - log_time_format="[%H:%M:%S]", - omit_repeated_times=False, - ) - - # Minimal formatter for cleaner output - rich_handler.setFormatter( - logging.Formatter( - fmt="%(message)s", - datefmt="[%X]", - ) - ) - - # Configure root logger - root_logger.addHandler(rich_handler) - root_logger.setLevel(LOG_LEVEL) - - # Silence verbose HTTP request logs from third-party libraries - logging.getLogger("httpx").setLevel(logging.WARNING) - logging.getLogger("openai").setLevel(logging.WARNING) - logging.getLogger("urllib3").setLevel(logging.WARNING) - - -def log_error(message: str, **kwargs: Any) -> None: +def log_panel(message: str, **kwargs: Any) -> None: """Log an error message in a styled box. Parameters @@ -101,143 +57,29 @@ def log_error(message: str, **kwargs: Any) -> None: console.print(panel) -def log_warning(message: str, **kwargs: Any) -> None: - """Log a warning message in a styled box. - - Parameters - ---------- - message : str - The warning message to display - **kwargs : Any - Additional context information to include in the warning box - """ - # Build warning content - warning_text = Text() - warning_text.append(message, style="bold yellow") - - if kwargs: - warning_text.append("\n\n", style="") - warning_text.append("Context:\n", style="bold cyan") - for key, value in kwargs.items(): - warning_text.append(f" {key}: ", style="cyan") - warning_text.append(f"{value}\n", style="white") - - # Display in a yellow panel - panel = Panel( - warning_text, - title="[bold yellow]WARNING", - border_style="yellow", - padding=(1, 2), - ) - console.print(panel) - - -def log_success(message: str, **kwargs: Any) -> None: - """Log a success message in a styled box. - - Parameters - ---------- - message : str - The success message to display - **kwargs : Any - Additional context information to include in the success box - """ - # Build success content - success_text = Text() - success_text.append(message, style="bold green") - - if kwargs: - success_text.append("\n\n", style="") - success_text.append("Details:\n", style="bold cyan") - for key, value in kwargs.items(): - success_text.append(f" {key}: ", style="cyan") - success_text.append(f"{value}\n", style="white") - - # Display in a green panel - panel = Panel( - success_text, - title="[bold green]SUCCESS", - border_style="green", - padding=(1, 2), - ) - console.print(panel) - - -def log_info(message: str, **kwargs: Any) -> None: - """Log an info message in a styled box. - - Parameters - ---------- - message : str - The info message to display - **kwargs : Any - Additional context information to include in the info box - """ - # Build info content - info_text = Text() - info_text.append(message, style="bold blue") - - if kwargs: - info_text.append("\n\n", style="") - info_text.append("Details:\n", style="bold cyan") - for key, value in kwargs.items(): - info_text.append(f" {key}: ", style="cyan") - info_text.append(f"{value}\n", style="white") - - # Display in a blue panel - panel = Panel( - info_text, - title="[bold blue]INFO", - border_style="blue", - padding=(1, 2), +def configure_logging() -> None: + """Configure Python logging with rich formatting and custom handlers.""" + logger.remove() + logger.add( + sys.stdout, + level=LOG_LEVEL, + colorize=True, + backtrace=False, + catch=False, + format="[{level:>7}]: {message}", ) - console.print(panel) - - -# Create a logger instance that can be imported -logger = logging.getLogger("masksql") - - -def log_stage_start(stage_name: str) -> None: - """Log the start of a pipeline stage. - - Parameters - ---------- - stage_name : str - Name of the pipeline stage starting - """ - console.print( - f"\n[bold cyan]▶ Starting Stage:[/bold cyan] [bold white]{stage_name}[/bold white]" + logger.add( + "logs/debug-{time:MMMD-HH-mm}.log", + level="DEBUG", + format="[{time:HH:mm:ss}]-[{level:<7}]-[{name:>20} | {function:<25}:{line:<3}]: {message}", ) - -def log_stage_complete(stage_name: str, elapsed_time: float) -> None: - """Log the completion of a pipeline stage with timing. - - Parameters - ---------- - stage_name : str - Name of the pipeline stage completed - elapsed_time : float - Time taken to complete the stage in seconds - """ - # Store timing for summary - _stage_timings.append((stage_name, elapsed_time)) - - # Format time with appropriate precision and color - if elapsed_time < 1.0: - time_str = f"{elapsed_time:.3f}s" - time_color = "green" - elif elapsed_time < 10.0: - time_str = f"{elapsed_time:.2f}s" - time_color = "yellow" - else: - time_str = f"{elapsed_time:.2f}s" - time_color = "red" - - console.print( - f"[bold green]✓ Done Stage:[/bold green] [bold white]{stage_name}[/bold white] " - f"[dim]│[/dim] [{time_color}]{time_str}[/{time_color}]" + logger.add( + "logs/prompts-{time:MMMD-HH-mm}.jsonl", + level="DEBUG", + filter=lambda record: record["extra"].get("type") == "prompt", + format="{message}", + serialize=True, ) @@ -339,10 +181,25 @@ def log_pipeline_summary( console.print("\n") -def reset_stage_timings() -> None: - """Reset the stage timings tracker. +P = ParamSpec("P") +R = TypeVar("R") +AR = Awaitable[R] - This should be called at the start of a pipeline run. - """ - global _stage_timings # noqa: PLW0603 - _stage_timings = [] + +def log( + message: str = "", before: str | None = None +) -> Callable[[Callable[P, AR]], Callable[P, AR]]: + """Log messages before and after an async function execution.""" + + def decorator(func: Callable[P, AR]) -> Callable[P, AR]: + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> AR: + if before is not None: + logger.info(before.format(*args, **kwargs)) + result = await func(*args, **kwargs) + logger.info(message.format(*args, **kwargs)) + return result + + return wrapper + + return decorator diff --git a/src/utils/mem.py b/src/utils/mem.py index 39c0241..e10585b 100644 --- a/src/utils/mem.py +++ b/src/utils/mem.py @@ -1,6 +1,5 @@ """Memory usage monitoring utilities.""" -import logging import os import threading import time @@ -8,9 +7,7 @@ from typing import Any import psutil - - -logger = logging.getLogger(__name__) +from loguru import logger def _monitor_memory( diff --git a/src/utils/sqlite_facade.py b/src/utils/sqlite_facade.py index ce947e6..36896d1 100644 --- a/src/utils/sqlite_facade.py +++ b/src/utils/sqlite_facade.py @@ -8,7 +8,7 @@ from sqlite3 import Connection from typing import Any -from src.utils.logging import logger +from loguru import logger DB_TIMEOUT = 10000 diff --git a/src/utils/strings.py b/src/utils/strings.py index 4d1637b..7adb4f7 100644 --- a/src/utils/strings.py +++ b/src/utils/strings.py @@ -4,7 +4,7 @@ from difflib import SequenceMatcher from enum import Enum -from src.utils.logging import logger +from loguru import logger def delete_whitespace(content: str) -> str: diff --git a/tests/e2e/test_data/1_input.json b/tests/e2e/test_data/1_input.json index 9d01a16..0cd5a28 100644 --- a/tests/e2e/test_data/1_input.json +++ b/tests/e2e/test_data/1_input.json @@ -14,11 +14,6 @@ "pets": "TABLE:pets", "weight": "COLUMN:pets.weight" }, - "annotated_links": { - "10": "VALUE:pets.weight", - "pets": "TABLE:pets", - "weight": "COLUMN:pets.weight" - }, "idx": "spider_45" } ] diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py deleted file mode 100644 index c5114fe..0000000 --- a/tests/utils/test_logging.py +++ /dev/null @@ -1,356 +0,0 @@ -"""Tests for logging configuration utilities.""" - -from unittest.mock import MagicMock, call, patch - -from rich.console import Console -from rich.panel import Panel -from rich.table import Table - -from src.utils import logging as logging_module -from src.utils.logging import ( - configure_logging, - log_error, - log_info, - log_pipeline_summary, - log_stage_complete, - log_stage_start, - log_success, - log_warning, - reset_stage_timings, -) - - -class TestConfigureLogging: - """Test suite for configure_logging function.""" - - def test_configure_logging_default_level(self, monkeypatch): - """Test logging configuration with default INFO level.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Should add handler to root logger - mock_logger.addHandler.assert_called_once() - # Check that root logger setLevel was called with INFO - assert call("INFO") in mock_logger.setLevel.call_args_list - - def test_configure_logging_custom_level(self, monkeypatch): - """Test logging configuration with custom LOG_LEVEL.""" - monkeypatch.setattr(logging_module, "LOG_LEVEL", "DEBUG") - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Should set logger to DEBUG level - assert call("DEBUG") in mock_logger.setLevel.call_args_list - - def test_configure_logging_error_level(self, monkeypatch): - """Test logging configuration with ERROR level.""" - monkeypatch.setattr(logging_module, "LOG_LEVEL", "ERROR") - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - assert call("ERROR") in mock_logger.setLevel.call_args_list - - def test_configure_logging_warning_level(self, monkeypatch): - """Test logging configuration with WARNING level.""" - monkeypatch.setattr(logging_module, "LOG_LEVEL", "WARNING") - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - assert call("WARNING") in mock_logger.setLevel.call_args_list - - def test_configure_logging_removes_existing_handlers(self, monkeypatch): - """Test that existing handlers are removed before adding new ones.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_handler1 = MagicMock() - mock_handler2 = MagicMock() - mock_logger.handlers = [mock_handler1, mock_handler2] - mock_get_logger.return_value = mock_logger - - configure_logging() - - # Should remove existing handlers - assert mock_logger.removeHandler.call_count == 2 - mock_logger.removeHandler.assert_any_call(mock_handler1) - mock_logger.removeHandler.assert_any_call(mock_handler2) - - def test_configure_logging_adds_rich_handler(self, monkeypatch): - """Test that RichHandler is added with correct configuration.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Should add a handler - mock_logger.addHandler.assert_called_once() - added_handler = mock_logger.addHandler.call_args[0][0] - - # Verify it's a RichHandler by checking its type name - assert added_handler.__class__.__name__ == "RichHandler" - - def test_configure_logging_formatter(self, monkeypatch): - """Test that the formatter is configured correctly.""" - monkeypatch.delenv("LOG_LEVEL", raising=False) - - with patch("logging.getLogger") as mock_get_logger: - mock_logger = MagicMock() - mock_get_logger.return_value = mock_logger - mock_logger.handlers = [] - - configure_logging() - - # Get the handler that was added - added_handler = mock_logger.addHandler.call_args[0][0] - - # Check the formatter is minimal - formatter = added_handler.formatter - assert formatter._fmt == "%(message)s" - - -class TestLogError: - """Test suite for log_error function.""" - - def test_log_error_simple_message(self): - """Test logging a simple error message.""" - with patch.object(Console, "print") as mock_print: - log_error("Test error message") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "red" - - def test_log_error_with_context(self): - """Test logging an error message with context.""" - with patch.object(Console, "print") as mock_print: - log_error("Test error", file="test.py", line=42) - - # Should print a Panel with context - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - - -class TestLogWarning: - """Test suite for log_warning function.""" - - def test_log_warning_simple_message(self): - """Test logging a simple warning message.""" - with patch.object(Console, "print") as mock_print: - log_warning("Test warning message") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "yellow" - - def test_log_warning_with_context(self): - """Test logging a warning message with context.""" - with patch.object(Console, "print") as mock_print: - log_warning("Deprecated function", function="old_func") - - # Should print a Panel with context - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - - -class TestLogSuccess: - """Test suite for log_success function.""" - - def test_log_success_simple_message(self): - """Test logging a simple success message.""" - with patch.object(Console, "print") as mock_print: - log_success("Operation completed") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "green" - - def test_log_success_with_details(self): - """Test logging a success message with details.""" - with patch.object(Console, "print") as mock_print: - log_success("File saved", path="/tmp/file.txt", size=1024) - - # Should print a Panel with details - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - - -class TestLogInfo: - """Test suite for log_info function.""" - - def test_log_info_simple_message(self): - """Test logging a simple info message.""" - with patch.object(Console, "print") as mock_print: - log_info("Processing data") - - # Should print a Panel - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - assert panel.border_style == "blue" - - def test_log_info_with_details(self): - """Test logging an info message with details.""" - with patch.object(Console, "print") as mock_print: - log_info("Processing", items=100, status="active") - - # Should print a Panel with details - mock_print.assert_called_once() - panel = mock_print.call_args[0][0] - assert isinstance(panel, Panel) - - -class TestLogStageStart: - """Test suite for log_stage_start function.""" - - def test_log_stage_start(self): - """Test logging stage start.""" - with patch.object(Console, "print") as mock_print: - log_stage_start("TestStage") - - # Should print formatted output - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "Starting Stage" in call_args - - -class TestLogStageComplete: - """Test suite for log_stage_complete function.""" - - def test_log_stage_complete_fast(self): - """Test logging stage completion with fast time (< 1s).""" - with patch.object(Console, "print") as mock_print: - log_stage_complete("TestStage", 0.5) - - # Should print formatted output with green timing - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "Done Stage" in call_args - assert "0.500s" in call_args - - def test_log_stage_complete_medium(self): - """Test logging stage completion with medium time (1-10s).""" - with patch.object(Console, "print") as mock_print: - log_stage_complete("TestStage", 5.0) - - # Should print formatted output with yellow timing - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "5.00s" in call_args - - def test_log_stage_complete_slow(self): - """Test logging stage completion with slow time (> 10s).""" - with patch.object(Console, "print") as mock_print: - log_stage_complete("TestStage", 15.5) - - # Should print formatted output with red timing - mock_print.assert_called_once() - call_args = mock_print.call_args[0][0] - assert "TestStage" in call_args - assert "15.50s" in call_args - - -class TestResetStageTimings: - """Test suite for reset_stage_timings function.""" - - def test_reset_stage_timings(self): - """Test resetting stage timings.""" - # Add some timings - with patch.object(Console, "print"): - log_stage_complete("Stage1", 1.0) - log_stage_complete("Stage2", 2.0) - - # Reset timings - reset_stage_timings() - - # Verify timings are empty by checking summary output - with patch.object(Console, "print") as mock_print: - log_pipeline_summary(10.0, 100.0, 150.0) - - # Should have been called multiple times (for table, memory, etc.) - assert mock_print.call_count > 0 - - -class TestLogPipelineSummary: - """Test suite for log_pipeline_summary function.""" - - def test_log_pipeline_summary_basic(self): - """Test logging pipeline summary without results.""" - reset_stage_timings() - - with patch.object(Console, "print") as mock_print: - # Add some stage timings first - log_stage_complete("Stage1", 1.0) - log_stage_complete("Stage2", 2.0) - - mock_print.reset_mock() - - # Log summary - log_pipeline_summary(3.0, 100.0, 150.0) - - # Should print multiple times (table, memory, etc.) - assert mock_print.call_count >= 2 - - # Check that at least one call contains a Table - has_table = False - for call_args in mock_print.call_args_list: - if call_args[0]: # Check positional args - arg = call_args[0][0] - if isinstance(arg, Table): - has_table = True - break - assert has_table - - def test_log_pipeline_summary_with_results(self): - """Test logging pipeline summary with results.""" - reset_stage_timings() - - with patch.object(Console, "print") as mock_print: - # Add some stage timings - log_stage_complete("Stage1", 1.0) - - mock_print.reset_mock() - - # Log summary with results - results = {"accuracy": 0.95, "latency": 2.5, "count": 100} - log_pipeline_summary(3.0, 100.0, 150.0, results) - - # Should print multiple times - assert mock_print.call_count >= 2 diff --git a/uv.lock b/uv.lock index daebed6..e663513 100644 --- a/uv.lock +++ b/uv.lock @@ -1337,6 +1337,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/40/791891d4c0c4dab4c5e187c17261cedc26285fd41541577f900470a45a4d/license_expression-30.4.4-py3-none-any.whl", hash = "sha256:421788fdcadb41f049d2dc934ce666626265aeccefddd25e162a26f23bcbf8a4", size = 120615 }, ] +[[package]] +name = "loguru" +version = "0.7.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "win32-setctime", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559, upload-time = "2024-12-06T11:20:56.608Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c", size = 61595, upload-time = "2024-12-06T11:20:54.538Z" }, +] + [[package]] name = "markdown" version = "3.10" @@ -1481,6 +1494,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "codecov" }, + { name = "loguru" }, { name = "mypy" }, { name = "nbqa" }, { name = "pip" }, @@ -1548,6 +1562,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "codecov", specifier = ">=2.1.13" }, + { name = "loguru", specifier = ">=0.7.3" }, { name = "mypy", specifier = ">=1.14.1" }, { name = "nbqa", specifier = ">=1.9.1" }, { name = "pip", specifier = ">=25.3" }, @@ -3943,9 +3958,9 @@ wheels = [ name = "urllib3" version = "2.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556 } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584 }, + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, ] [[package]] @@ -4055,6 +4070,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025 }, ] +[[package]] +name = "win32-setctime" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/8f/705086c9d734d3b663af0e9bb3d4de6578d08f46b1b101c2442fd9aecaa2/win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0", size = 4867, upload-time = "2024-12-07T15:28:28.314Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083, upload-time = "2024-12-07T15:28:26.465Z" }, +] + [[package]] name = "wrapt" version = "2.0.1"