diff --git a/simple_typing_application/sentence_generator/openai_sentence_generator.py b/simple_typing_application/sentence_generator/openai_sentence_generator.py index 73b9d99..e110b33 100644 --- a/simple_typing_application/sentence_generator/openai_sentence_generator.py +++ b/simple_typing_application/sentence_generator/openai_sentence_generator.py @@ -1,6 +1,6 @@ from __future__ import annotations from datetime import datetime as dt -from logging import getLogger, Logger +from logging import DEBUG, Logger, getLogger from typing import Any, Callable, cast from langchain.agents import create_agent @@ -12,6 +12,7 @@ from .utils import split_hiraganas_alphabets_symbols, splitted_hiraganas_alphabets_symbols_to_typing_target # noqa from ..utils.japanese_string_utils import delete_space_between_hiraganas from ..utils.rerun import rerun_deco +from ..utils.stopwatch import stopwatch class _OutputSchema(BaseModel): @@ -52,11 +53,8 @@ def __init__( model=model, temperature=temperature, max_retries=max_retry, - api_key=( - (lambda: openai_api_key) - if isinstance(openai_api_key, str) - else openai_api_key - ), + api_key=((lambda: openai_api_key) if isinstance(openai_api_key, str) else openai_api_key), + reasoning_effort="minimal" if model.startswith("gpt-5") else None, seed=seed, ) self._agent = create_agent( @@ -97,11 +95,12 @@ async def generate( "content": self._user_prompt, }, ] - self._logger.debug(f'agent input messages: {messages}') - ret: dict[str, Any] = await self._agent.ainvoke( - {"messages": messages}, # type: ignore - ) - self._logger.debug(f'agent response: {ret}') + self._logger.debug(f"agent input messages: {messages}") + with stopwatch(level=DEBUG, logger=self._logger, prefix="OpenAI agent invocation"): + ret: dict[str, Any] = await self._agent.ainvoke( + {"messages": messages}, + ) # type: ignore + self._logger.debug(f"agent response: {ret}") # store to memory output = cast(_OutputSchema, ret["structured_response"]) diff --git a/simple_typing_application/utils/stopwatch.py b/simple_typing_application/utils/stopwatch.py new file mode 100644 index 0000000..2dee7f4 --- /dev/null +++ b/simple_typing_application/utils/stopwatch.py @@ -0,0 +1,121 @@ +import logging +import time +from contextlib import contextmanager +from functools import wraps, partial +from logging import Logger, getLogger +from typing import Callable, Generator, TypeVar, ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + +logger: Logger = getLogger(__name__) + + +@contextmanager +def stopwatch( + level: int = logging.INFO, + prefix: str | None = None, + postfix: str = "", + logger: Logger = logger, +) -> Generator[None, None, None]: + """Context manager to measure the execution time of a code block. + + Args: + level (int, optional): log level. Defaults to logging.INFO. + Must be one of logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL. + prefix (str | None, optional): prefix of the log message. Defaults to None. + postfix (str, optional): postfix of the log message. Defaults to "". + logger (Logger, optional): logger. Defaults to logger. + + Yields: + None: None + """ # noqa + + # validation + if level not in { + logging.DEBUG, + logging.INFO, + logging.WARNING, + logging.ERROR, + logging.CRITICAL, + }: + raise ValueError( + f"Invalid log level: {level}. " + "Must be one of logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL.", # noqa + ) + + # preparation + if prefix is None: + prefix = "Execution time" + + log_msg_fmt: str = f"{prefix}: {{elapsed_time:.6f}} seconds{postfix}" + + log_func = { + logging.DEBUG: logger.debug, + logging.INFO: logger.info, + logging.WARNING: logger.warning, + logging.ERROR: logger.error, + logging.CRITICAL: logger.critical, + }[level] + + # execution + start_time = time.perf_counter() + try: + yield + finally: + end_time = time.perf_counter() + log_func(log_msg_fmt.format(elapsed_time=end_time - start_time)) + + +def stopwatch_deco( + func: Callable[P, T] | None = None, + *, + level: int = logging.INFO, + prefix: str | None = None, + postfix: str = "", + logger: Logger = logger, +) -> Callable[P, T]: + """Decorator to measure the execution time of a function. + + Args: + func (Callable[P, T], optional): function to be decorated. Defaults to None. + *, + level (int, optional): log level. Defaults to logging.INFO. + Must be one of logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL. + prefix (str | None, optional): prefix of the log message. Defaults to None. + postfix (str, optional): postfix of the log message. Defaults to "". + logger (Logger, optional): logger. Defaults to logger. + + Returns: + Callable[P, T]: decorated function. + + Note: + If func is None, return a decorator. Otherwise, return a wrapper. + """ # noqa + + # if func is None, return a decorator + if func is None: + return partial( + stopwatch_deco, + level=level, + prefix=prefix, + postfix=postfix, + logger=logger, + ) + + if prefix is None: + prefix = f"Execution time of {getattr(func, '__name__', str(func))}" + + # if func is not None, return a wrapper + @wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + with stopwatch( + level=level, + prefix=prefix, + postfix=postfix, + logger=logger, + ): + result = func(*args, **kwargs) + return result + + return wrapped diff --git a/tests/utils/test_stopwatch.py b/tests/utils/test_stopwatch.py new file mode 100644 index 0000000..1225879 --- /dev/null +++ b/tests/utils/test_stopwatch.py @@ -0,0 +1,100 @@ +import logging +import re + +import pytest + +from simple_typing_application.utils.stopwatch import stopwatch, stopwatch_deco + + +LOGGER_NAME = "simple_typing_application.utils.stopwatch" + + +def _latest_record(caplog): + assert caplog.records, "Expected stopwatch to emit a log record" + return caplog.records[-1] + + +def test_stopwatch_context_manager_logs_default_prefix(caplog): + caplog.set_level(logging.INFO, logger=LOGGER_NAME) + + with stopwatch(): + sum(range(5)) + + record = _latest_record(caplog) + assert record.levelno == logging.INFO + assert re.match(r"Execution time: \d+\.\d{6} seconds", record.message) + + +def test_stopwatch_context_manager_custom_prefix_postfix_and_level(caplog): + custom_logger = logging.getLogger("tests.utils.stopwatch.ctx") + caplog.set_level(logging.WARNING, logger=custom_logger.name) + + with stopwatch( + level=logging.WARNING, + prefix="Block", + postfix=" !!!", + logger=custom_logger, + ): + sum(range(10)) + + record = _latest_record(caplog) + assert record.levelno == logging.WARNING + assert record.message.startswith("Block: ") + assert record.message.endswith(" seconds !!!") + + +def test_stopwatch_context_manager_invalid_level_raises_value_error(): + with pytest.raises(ValueError): + with stopwatch(level=123): + pass + + +def test_stopwatch_decorator_without_parentheses_uses_func_name_in_prefix(caplog): + caplog.set_level(logging.INFO, logger=LOGGER_NAME) + + @stopwatch_deco + def greet(name: str) -> str: + return f"hello {name}" + + assert greet("world") == "hello world" + record = _latest_record(caplog) + assert record.levelno == logging.INFO + assert re.match(r"Execution time of greet: \d+\.\d{6} seconds", record.message) + + +def test_stopwatch_decorator_called_with_parentheses(caplog): + caplog.set_level(logging.INFO, logger=LOGGER_NAME) + + @stopwatch_deco() + def add(a: int, b: int) -> int: + return a + b + + assert add(1, 2) == 3 + record = _latest_record(caplog) + assert re.match(r"Execution time of add: \d+\.\d{6} seconds", record.message) + + +def test_stopwatch_decorator_custom_prefix_and_invalid_level(caplog): + caplog.set_level(logging.ERROR, logger=LOGGER_NAME) + + @stopwatch_deco( + prefix="Manual", + postfix=" !!!", + level=logging.ERROR, + ) + def work() -> None: + return None + + work() + + record = _latest_record(caplog) + assert record.levelno == logging.ERROR + assert record.message.startswith("Manual: ") + assert record.message.endswith(" seconds !!!") + + @stopwatch_deco(level=123) + def broken(): + return None + + with pytest.raises(ValueError): + broken()