Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import datetime as dt
from logging import getLogger, Logger
from logging import DEBUG, Logger, getLogger
from typing import Any, Callable, cast

from langchain.agents import create_agent
Expand All @@ -12,6 +12,7 @@
from .utils import split_hiraganas_alphabets_symbols, splitted_hiraganas_alphabets_symbols_to_typing_target # noqa
from ..utils.japanese_string_utils import delete_space_between_hiraganas
from ..utils.rerun import rerun_deco
from ..utils.stopwatch import stopwatch


class _OutputSchema(BaseModel):
Expand Down Expand Up @@ -52,11 +53,8 @@ def __init__(
model=model,
temperature=temperature,
max_retries=max_retry,
api_key=(
(lambda: openai_api_key)
if isinstance(openai_api_key, str)
else openai_api_key
),
api_key=((lambda: openai_api_key) if isinstance(openai_api_key, str) else openai_api_key),
reasoning_effort="minimal" if model.startswith("gpt-5") else None,
seed=seed,
)
self._agent = create_agent(
Expand Down Expand Up @@ -97,11 +95,12 @@ async def generate(
"content": self._user_prompt,
},
]
self._logger.debug(f'agent input messages: {messages}')
ret: dict[str, Any] = await self._agent.ainvoke(
{"messages": messages}, # type: ignore
)
self._logger.debug(f'agent response: {ret}')
self._logger.debug(f"agent input messages: {messages}")
with stopwatch(level=DEBUG, logger=self._logger, prefix="OpenAI agent invocation"):
Comment thread
hmasdev marked this conversation as resolved.
ret: dict[str, Any] = await self._agent.ainvoke(
{"messages": messages},
) # type: ignore
self._logger.debug(f"agent response: {ret}")

# store to memory
output = cast(_OutputSchema, ret["structured_response"])
Expand Down
121 changes: 121 additions & 0 deletions simple_typing_application/utils/stopwatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import logging
Comment thread
hmasdev marked this conversation as resolved.
import time
from contextlib import contextmanager
from functools import wraps, partial
from logging import Logger, getLogger
from typing import Callable, Generator, TypeVar, ParamSpec

T = TypeVar("T")
P = ParamSpec("P")

logger: Logger = getLogger(__name__)


@contextmanager
def stopwatch(
level: int = logging.INFO,
prefix: str | None = None,
postfix: str = "",
logger: Logger = logger,
) -> Generator[None, None, None]:
"""Context manager to measure the execution time of a code block.

Args:
level (int, optional): log level. Defaults to logging.INFO.
Must be one of logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL.
prefix (str | None, optional): prefix of the log message. Defaults to None.
postfix (str, optional): postfix of the log message. Defaults to "".
logger (Logger, optional): logger. Defaults to logger.

Yields:
None: None
""" # noqa

# validation
if level not in {
logging.DEBUG,
logging.INFO,
logging.WARNING,
logging.ERROR,
logging.CRITICAL,
}:
raise ValueError(
f"Invalid log level: {level}. "
"Must be one of logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL.", # noqa
)

# preparation
if prefix is None:
prefix = "Execution time"

log_msg_fmt: str = f"{prefix}: {{elapsed_time:.6f}} seconds{postfix}"

log_func = {
logging.DEBUG: logger.debug,
logging.INFO: logger.info,
logging.WARNING: logger.warning,
logging.ERROR: logger.error,
logging.CRITICAL: logger.critical,
}[level]

# execution
start_time = time.perf_counter()
try:
yield
finally:
end_time = time.perf_counter()
log_func(log_msg_fmt.format(elapsed_time=end_time - start_time))


def stopwatch_deco(
func: Callable[P, T] | None = None,
*,
level: int = logging.INFO,
prefix: str | None = None,
postfix: str = "",
logger: Logger = logger,
) -> Callable[P, T]:
"""Decorator to measure the execution time of a function.

Args:
func (Callable[P, T], optional): function to be decorated. Defaults to None.
*,
level (int, optional): log level. Defaults to logging.INFO.
Must be one of logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL.
prefix (str | None, optional): prefix of the log message. Defaults to None.
postfix (str, optional): postfix of the log message. Defaults to "".
logger (Logger, optional): logger. Defaults to logger.

Returns:
Callable[P, T]: decorated function.

Note:
If func is None, return a decorator. Otherwise, return a wrapper.
""" # noqa

# if func is None, return a decorator
if func is None:
return partial(
stopwatch_deco,
level=level,
prefix=prefix,
postfix=postfix,
logger=logger,
)

if prefix is None:
prefix = f"Execution time of {getattr(func, '__name__', str(func))}"

# if func is not None, return a wrapper
@wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
with stopwatch(
level=level,
prefix=prefix,
postfix=postfix,
logger=logger,
):
result = func(*args, **kwargs)
return result

return wrapped
100 changes: 100 additions & 0 deletions tests/utils/test_stopwatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
import re

import pytest

from simple_typing_application.utils.stopwatch import stopwatch, stopwatch_deco


LOGGER_NAME = "simple_typing_application.utils.stopwatch"


def _latest_record(caplog):
assert caplog.records, "Expected stopwatch to emit a log record"
return caplog.records[-1]


def test_stopwatch_context_manager_logs_default_prefix(caplog):
caplog.set_level(logging.INFO, logger=LOGGER_NAME)

with stopwatch():
sum(range(5))

record = _latest_record(caplog)
assert record.levelno == logging.INFO
assert re.match(r"Execution time: \d+\.\d{6} seconds", record.message)


def test_stopwatch_context_manager_custom_prefix_postfix_and_level(caplog):
custom_logger = logging.getLogger("tests.utils.stopwatch.ctx")
caplog.set_level(logging.WARNING, logger=custom_logger.name)

with stopwatch(
level=logging.WARNING,
prefix="Block",
postfix=" !!!",
logger=custom_logger,
):
sum(range(10))

record = _latest_record(caplog)
assert record.levelno == logging.WARNING
assert record.message.startswith("Block: ")
assert record.message.endswith(" seconds !!!")


def test_stopwatch_context_manager_invalid_level_raises_value_error():
with pytest.raises(ValueError):
with stopwatch(level=123):
pass


def test_stopwatch_decorator_without_parentheses_uses_func_name_in_prefix(caplog):
caplog.set_level(logging.INFO, logger=LOGGER_NAME)

@stopwatch_deco
def greet(name: str) -> str:
return f"hello {name}"

assert greet("world") == "hello world"
record = _latest_record(caplog)
assert record.levelno == logging.INFO
assert re.match(r"Execution time of greet: \d+\.\d{6} seconds", record.message)


def test_stopwatch_decorator_called_with_parentheses(caplog):
caplog.set_level(logging.INFO, logger=LOGGER_NAME)

@stopwatch_deco()
def add(a: int, b: int) -> int:
return a + b

assert add(1, 2) == 3
record = _latest_record(caplog)
assert re.match(r"Execution time of add: \d+\.\d{6} seconds", record.message)


def test_stopwatch_decorator_custom_prefix_and_invalid_level(caplog):
caplog.set_level(logging.ERROR, logger=LOGGER_NAME)

@stopwatch_deco(
prefix="Manual",
postfix=" !!!",
level=logging.ERROR,
)
def work() -> None:
return None

work()

record = _latest_record(caplog)
assert record.levelno == logging.ERROR
assert record.message.startswith("Manual: ")
assert record.message.endswith(" seconds !!!")

@stopwatch_deco(level=123)
def broken():
return None

with pytest.raises(ValueError):
broken()
Loading