diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00b4daa..6abc6a3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,6 +47,9 @@ jobs: - name: Check code style with ruff run: uv run ruff format --check --diff + - name: Check type hints with ty + run: uv run ty check think/ + - name: Test with pytest run: uv run pytest --cov=think --cov-report=lcov timeout-minutes: 5 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 803deb1..96efb43 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,6 +8,15 @@ repos: args: [--fix] # Run the formatter. - id: ruff-format + - repo: local + hooks: + - id: ty + name: ty check + stages: [pre-commit] + types: [python] + entry: uv run ty check think/ + language: python + pass_filenames: false - repo: local hooks: # Run the tests diff --git a/README.md b/README.md index e1cf845..3c36862 100644 --- a/README.md +++ b/README.md @@ -14,40 +14,48 @@ LLM API provider. Ask a question: ```python +# example: ask.py from asyncio import run from think import LLM, ask llm = LLM.from_url("anthropic:///claude-3-haiku-20240307") + async def haiku(topic): return await ask(llm, "Write a haiku about {{ topic }}", topic=topic) + print(run(haiku("computers"))) ``` Get answers as structured data: ```python +# example: structured.py from asyncio import run from think import LLM, LLMQuery llm = LLM.from_url("openai:///gpt-4o-mini") + class CityInfo(LLMQuery): """ Give me basic information about {{ city }}. """ + name: str country: str population: int latitude: float longitude: float + async def city_info(city): return await CityInfo.run(llm, city=city) + info = run(city_info("Paris")) print(f"{info.name} is a city in {info.country} with {info.population} inhabitants.") ``` @@ -55,6 +63,7 @@ print(f"{info.name} is a city in {info.country} with {info.population} inhabitan Integrate AI with custom tools: ```python +# example: tools.py from asyncio import run from datetime import date @@ -62,6 +71,7 @@ from think import LLM, Chat llm = LLM.from_url("openai:///gpt-4o-mini") + def current_date() -> str: """ Get the current date. @@ -70,50 +80,59 @@ def current_date() -> str: """ return date.today().isoformat() + async def days_to_xmas() -> str: chat = Chat("How many days are left until Christmas?") return await llm(chat, tools=[current_date]) + print(run(days_to_xmas())) ``` Use vision (with models that support it): ```python +# example: vision.py from asyncio import run from think import LLM, Chat llm = LLM.from_url("openai:///gpt-4o-mini") + async def describe_image(path): image_data = open(path, "rb").read() chat = Chat().user("Describe the image in detail", images=[image_data]) return await llm(chat) + print(run(describe_image("path/to/image.jpg"))) ``` This also works with PDF documents (with models that support PDFs): ```python +# example: pdf.py from asyncio import run from think import LLM, Chat llm = LLM.from_url("google:///gemini-2.0-flash") + async def read_pdf(path): pdf_data = open(path, "rb").read() chat = Chat().user("Read the document", documents=[pdf_data]) return await llm(chat) + print(run(read_pdf("path/to/document.pdf"))) ``` Use Pydantic or custom parsers for structured data: ```python +# example: parsing.py from asyncio import run from ast import parse @@ -123,6 +142,7 @@ from think.prompt import JinjaStringTemplate llm = LLM.from_url("openai:///gpt-4o-mini") + def parse_python(text): # extract code block from the text block_parser = CodeBlockParser() @@ -134,6 +154,7 @@ def parse_python(text): except SyntaxError as err: raise ValueError(f"Invalid Python code: {err}") from err + async def generate_python_script(task): system = "You always output the requested code in a single Markdown code block" prompt = "Write a Python script for the following task: {{ task }}" @@ -141,6 +162,7 @@ async def generate_python_script(task): chat = Chat(system).user(tpl(prompt, task=task)) return await llm(chat, parser=parse_python) + print(run(generate_python_script("sort a list of numbers"))) ``` @@ -155,6 +177,7 @@ provides scaffolding to integrate other RAG providers. Example usage: ```python +# example: rag.py from asyncio import run from think import LLM @@ -163,6 +186,7 @@ from think.rag.base import RAG, RagDocument llm = LLM.from_url("openai:///gpt-4o-mini") rag = RAG.for_provider("txtai")(llm) + async def index_documents(): data = [ RagDocument(id="a", text="Titanic: A sweeping romantic epic"), @@ -171,6 +195,7 @@ async def index_documents(): ] await rag.add_documents(data) + run(index_documents()) query = "A movie about a ship that sinks" result = run(rag(query)) @@ -191,6 +216,7 @@ use tools, and integrate with RAG. Example: ```python +# example: agent.py from asyncio import run from datetime import datetime @@ -199,6 +225,7 @@ from think.agent import BaseAgent, tool llm = LLM.from_url("openai:///gpt-4o-mini") + class Chatbot(BaseAgent): """You are a helpful assistant. Today is {{today}}.""" @@ -211,6 +238,7 @@ class Chatbot(BaseAgent): print(response) return input("> ").strip() + agent = Chatbot(llm, today=datetime.today()) run(agent.run()) ``` diff --git a/examples/agent.py b/examples/agent.py new file mode 100644 index 0000000..bef0243 --- /dev/null +++ b/examples/agent.py @@ -0,0 +1,25 @@ +# example: agent.py +from asyncio import run +from datetime import datetime + +from think import LLM +from think.agent import BaseAgent, tool + +llm = LLM.from_url("openai:///gpt-4o-mini") + + +class Chatbot(BaseAgent): + """You are a helpful assistant. Today is {{today}}.""" + + @tool + def get_time(self) -> str: + """Get the current time.""" + return datetime.now().strftime("%H:%M") + + async def interact(self, response: str) -> str: + print(response) + return input("> ").strip() + + +agent = Chatbot(llm, today=datetime.today()) +run(agent.run()) diff --git a/examples/ai_functions.py b/examples/ai_functions.py deleted file mode 100644 index 9dc7228..0000000 --- a/examples/ai_functions.py +++ /dev/null @@ -1,31 +0,0 @@ -import sys -import click - -sys.path.append(".") - -from think.llm.openai import ChatGPT # noqa E402 -from think.ai import ai # noqa E402 - - -@ai -def haiku(topic: str) -> str: - """ - Write a haiku about {{ topic }} - """ - - -@click.command() -@click.option("--api-key", "-k", default=None) -@click.argument("topic") -def main(topic, api_key=None): - """ - Write and output a haiku about a given TOPIC using GPT-4. - - API key, if not provided, will be read from OPENAI_API_KEY environment variable. - """ - llm = ChatGPT(api_key=api_key) - print(haiku(llm, topic=topic)) - - -if __name__ == "__main__": - main() diff --git a/examples/ask.py b/examples/ask.py new file mode 100644 index 0000000..b617b97 --- /dev/null +++ b/examples/ask.py @@ -0,0 +1,13 @@ +# example: ask.py +from asyncio import run + +from think import LLM, ask + +llm = LLM.from_url("anthropic:///claude-3-haiku-20240307") + + +async def haiku(topic): + return await ask(llm, "Write a haiku about {{ topic }}", topic=topic) + + +print(run(haiku("computers"))) diff --git a/examples/chatbot.py b/examples/chatbot.py deleted file mode 100644 index ddc3206..0000000 --- a/examples/chatbot.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python3 -from asyncio import run -import sys -import os -from typing import Optional - -import click -from dotenv import load_dotenv - -# Add parent directory to path for easy local development -sys.path.append(".") - -from think import LLM -from think.agent import BaseAgent - -load_dotenv() - - -class Chatbot(BaseAgent): - """You are a helpful assistant. Be concise and friendly in your responses.""" - - async def interact(self, response: str) -> str: - print("AI: ", response, "\n") - - # Wait for user input - while True: - user_input = input("> ").strip() - if user_input: - break - - # Check for exit command - if user_input.lower() in ("exit", "quit", "bye"): - print("Goodbye!") - return None - - return user_input - - -@click.command() -@click.option( - "--model-url", - "-m", - default=None, - help="LLM URL (e.g., 'openai:///gpt-4o-mini'). Defaults to LLM_URL env variable.", -) -@click.option( - "--system", - "-s", - default="You are a helpful assistant. Be concise and friendly in your responses.", - help="System prompt to initialize the chat.", -) -def main(model_url: Optional[str], system: str): - """ - Interactive chatbot using the Think library. - - Start a conversation with an LLM in your terminal. Type your messages - and receive AI responses. Use Ctrl+C or type 'exit' to end the conversation. - """ - # Get model URL from argument or environment - model_url = model_url or os.environ.get("LLM_URL") - if not model_url: - print( - "Error: Model URL not provided. Use --model-url option or set LLM_URL environment variable." - ) - sys.exit(1) - - try: - # Initialize LLM from URL - llm = LLM.from_url(model_url) - print(f"Connected to {model_url}") - print("Type your messages (type 'exit' to quit)") - print("-" * 50) - - agent = Chatbot(llm=llm, system=system) - run(agent.run("Hello!")) - - except KeyboardInterrupt: - print("\nGoodbye!") - except Exception as e: - print(f"\nError: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/examples/extract.py b/examples/extract.py new file mode 100644 index 0000000..623fb14 --- /dev/null +++ b/examples/extract.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# +# Python script to extract examples from README.md and store them in +# `examples` directory. + +import os +import re +import pathlib +from typing import Dict, List + + +def extract_code_blocks(markdown_content: str) -> List[str]: + """ + Extract Python code blocks from markdown content. + + Args: + markdown_content: Markdown text containing code blocks + + Returns: + List of extracted code blocks + """ + # Pattern to match python code blocks: ```python ... ``` + pattern = r"```python\n(.*?)```" + + # Find all matches using re.DOTALL to match across multiple lines + matches = re.findall(pattern, markdown_content, re.DOTALL) + + return matches + + +def parse_example_files(code_blocks: List[str]) -> Dict[str, str]: + """ + Parse code blocks to find examples with filename comments. + + Args: + code_blocks: List of code blocks extracted from markdown + + Returns: + Dictionary mapping filenames to code content + """ + examples = {} + + for block in code_blocks: + match = re.search(r"# example: ([a-zA-Z0-9_\-]+\.py)", block) + if match: + filename = match.group(1) + examples[filename] = block.strip() + "\n" + + return examples + + +def save_examples(examples: Dict[str, str], output_dir: pathlib.Path) -> None: + """ + Save examples to files in the output directory. + + Args: + examples: Dictionary mapping filenames to code content + output_dir: Directory where examples will be saved + """ + os.makedirs(output_dir, exist_ok=True) + + for filename, content in examples.items(): + file_path = output_dir / filename + with open(file_path, "w") as f: + f.write(content) + print(f"Saved example: {file_path}") + + +def main(): + examples_dir = pathlib.Path(__file__).parent + readme_path = examples_dir.parent / "README.md" + + try: + with open(readme_path, "r") as f: + readme_content = f.read() + except FileNotFoundError: + print(f"Error: README.md not found at {readme_path}") + return + + code_blocks = extract_code_blocks(readme_content) + examples = parse_example_files(code_blocks) + + if examples: + save_examples(examples, examples_dir) + print(f"Extracted {len(examples)} examples to {examples_dir}") + else: + print("No examples found in README.md") + + +if __name__ == "__main__": + main() diff --git a/examples/parsing.py b/examples/parsing.py new file mode 100644 index 0000000..653c79b --- /dev/null +++ b/examples/parsing.py @@ -0,0 +1,32 @@ +# example: parsing.py +from asyncio import run +from ast import parse + +from think import LLM, Chat +from think.parser import CodeBlockParser +from think.prompt import JinjaStringTemplate + +llm = LLM.from_url("openai:///gpt-4o-mini") + + +def parse_python(text): + # extract code block from the text + block_parser = CodeBlockParser() + code = block_parser(text) + # check if the code is valid Python syntax + try: + parse(code) + return code + except SyntaxError as err: + raise ValueError(f"Invalid Python code: {err}") from err + + +async def generate_python_script(task): + system = "You always output the requested code in a single Markdown code block" + prompt = "Write a Python script for the following task: {{ task }}" + tpl = JinjaStringTemplate() + chat = Chat(system).user(tpl(prompt, task=task)) + return await llm(chat, parser=parse_python) + + +print(run(generate_python_script("sort a list of numbers"))) diff --git a/examples/parsing_output.py b/examples/parsing_output.py deleted file mode 100644 index 6855762..0000000 --- a/examples/parsing_output.py +++ /dev/null @@ -1,48 +0,0 @@ -import json -import sys -from pydantic import BaseModel -import click - -sys.path.append(".") - -from think.llm.openai import ChatGPT # noqa E402 -from think.chat import Chat # noqa E402 -from think.parser import JSONParser # noqa E402 - - -class CityInfo(BaseModel): - name: str - country: str - population: int - latitude: float - longitude: float - - -@click.command() -@click.option("--api-key", "-k", default=None) -@click.argument("city", default="Zagreb", required=False) -def main(city, api_key=None): - """ - Ask GPT-3 to answer information about a city in a structured format. - - API key, if not provided, will be read from OPENAI_API_KEY environment variable. - """ - llm = ChatGPT(model="gpt-3.5-turbo-16k") - parser = JSONParser(spec=CityInfo) - chat = Chat( - "You are a hepful assistant. Your task is to answer questions about cities, " - "to the best of your knowledge. Your output must be a valid JSON conforming to " - "this JSON schema:\n" + json.dumps(parser.schema) - ).user(city) - - answer = llm(chat, parser=parser) - print( - f"{answer.name} is a city in {answer.country} with {answer.population} inhabitants." - ) - print( - f"It is located at {answer.latitude:.2f}° latitude and {answer.longitude:.2f}° longitude." - ) - - -if __name__ == "__main__": - main() diff --git a/examples/pdf.py b/examples/pdf.py new file mode 100644 index 0000000..ae23212 --- /dev/null +++ b/examples/pdf.py @@ -0,0 +1,15 @@ +# example: pdf.py +from asyncio import run + +from think import LLM, Chat + +llm = LLM.from_url("google:///gemini-2.0-flash") + + +async def read_pdf(path): + pdf_data = open(path, "rb").read() + chat = Chat().user("Read the document", documents=[pdf_data]) + return await llm(chat) + + +print(run(read_pdf("path/to/document.pdf"))) diff --git a/examples/rag.py b/examples/rag.py new file mode 100644 index 0000000..45622fb --- /dev/null +++ b/examples/rag.py @@ -0,0 +1,23 @@ +# example: rag.py +from asyncio import run + +from think import LLM +from think.rag.base import RAG, RagDocument + +llm = LLM.from_url("openai:///gpt-4o-mini") +rag = RAG.for_provider("txtai")(llm) + + +async def index_documents(): + data = [ + RagDocument(id="a", text="Titanic: A sweeping romantic epic"), + RagDocument(id="b", text="The Godfather: A gripping mafia saga"), + RagDocument(id="c", text="Forrest Gump: A heartwarming tale of a simple man"), + ] + await rag.add_documents(data) + + +run(index_documents()) +query = "A movie about a ship that sinks" +result = run(rag(query)) +print(result) diff --git a/examples/structured.py b/examples/structured.py new file mode 100644 index 0000000..21c5a82 --- /dev/null +++ b/examples/structured.py @@ -0,0 +1,26 @@ +# example: structured.py +from asyncio import run + +from think import LLM, LLMQuery + +llm = LLM.from_url("openai:///gpt-4o-mini") + + +class CityInfo(LLMQuery): + """ + Give me basic information about {{ city }}. + """ + + name: str + country: str + population: int + latitude: float + longitude: float + + +async def city_info(city): + return await CityInfo.run(llm, city=city) + + +info = run(city_info("Paris")) +print(f"{info.name} is a city in {info.country} with {info.population} inhabitants.") diff --git a/examples/tool_usage.py b/examples/tool_usage.py deleted file mode 100644 index cb8c25b..0000000 --- a/examples/tool_usage.py +++ /dev/null @@ -1,37 +0,0 @@ -from datetime import date -import click -import sys - -sys.path.append(".") - -from think.llm.openai import ChatGPT # noqa E402 -from think.chat import Chat # noqa E402 -from think.tool import tool # noqa E402 - - -@tool -def current_date() -> str: - """ - Get the current date. - - :returns: current date in YYYY-MM-DD format - """ - return date.today().isoformat() - - -@click.command() -@click.option("--api-key", "-k", default=None) -def main(api_key=None): - """ - Ask GPT-4 how old it is, providing the current date as a tool. - - API key, if not provided, will be read from OPENAI_API_KEY environment variable. - """ - llm = ChatGPT() - chat = Chat("You are a helpful assistant.") - chat.user("How old are you (in days since your knowledge cutoff)?") - print(llm(chat, tools=[current_date])) - - -if __name__ == "__main__": - main() diff --git a/examples/tools.py b/examples/tools.py new file mode 100644 index 0000000..d0beb3e --- /dev/null +++ b/examples/tools.py @@ -0,0 +1,24 @@ +# example: tools.py +from asyncio import run +from datetime import date + +from think import LLM, Chat + +llm = LLM.from_url("openai:///gpt-4o-mini") + + +def current_date() -> str: + """ + Get the current date. + + :returns: current date in YYYY-MM-DD format + """ + return date.today().isoformat() + + +async def days_to_xmas() -> str: + chat = Chat("How many days are left until Christmas?") + return await llm(chat, tools=[current_date]) + + +print(run(days_to_xmas())) diff --git a/examples/vision.py b/examples/vision.py new file mode 100644 index 0000000..f273d67 --- /dev/null +++ b/examples/vision.py @@ -0,0 +1,15 @@ +# example: vision.py +from asyncio import run + +from think import LLM, Chat + +llm = LLM.from_url("openai:///gpt-4o-mini") + + +async def describe_image(path): + image_data = open(path, "rb").read() + chat = Chat().user("Describe the image in detail", images=[image_data]) + return await llm(chat) + + +print(run(describe_image("path/to/image.jpg"))) diff --git a/pyproject.toml b/pyproject.toml index 0e742f0..64a7c58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,8 @@ dev = [ "chromadb>=0.6.2", "pinecone>=5.4.2", "pinecone-client>=4.1.2", + "aioboto3>=13.2.0", + "ty>=0.0.1a1", ] [build-system] diff --git a/think/llm/anthropic.py b/think/llm/anthropic.py index a3a3779..d761f5d 100644 --- a/think/llm/anthropic.py +++ b/think/llm/anthropic.py @@ -7,6 +7,7 @@ try: from anthropic import ( NOT_GIVEN, + NotGiven, AsyncAnthropic, AsyncStream, AuthenticationError, @@ -22,7 +23,7 @@ from .base import LLM, BaseAdapter, ConfigError, BadRequestError, PydanticResultT -from .chat import Chat, ContentPart, ContentType, Message, Role +from .chat import Chat, ContentPart, ContentType, Message, Role, image_url, document_url from .tool import ToolCall, ToolDefinition, ToolResponse log = getLogger(__name__) @@ -45,42 +46,42 @@ def dump_role(self, role: Role) -> Literal["user", "assistant"]: def dump_content_part(self, part: ContentPart) -> dict: match part: case ContentPart(type=ContentType.text, text=text): - return dict( - type="text", - text=text, - ) + return { + "type": "text", + "text": text, + } case ContentPart(type=ContentType.image): - return dict( - type="image", - source=dict( - type="base64", - data=part.image_data, - media_type=part.image_mime_type, - ), - ) + return { + "type": "image", + "source": { + "type": "base64", + "data": part.image_data, + "media_type": part.image_mime_type, + }, + } case ContentPart(type=ContentType.document): if part.is_document_url: - source = dict( - type="url", - url=part.document_url, - ) + source = { + "type": "url", + "url": part.document_url, + } else: - source = dict( - type="base64", - data=part.document_data, - media_type=part.document_mime_type, - ) - return dict(type="document", source=source) + source = { + "type": "base64", + "data": part.document_data, + "media_type": part.document_mime_type, + } + return {"type": "document", "source": source} case ContentPart( type=ContentType.tool_call, tool_call=ToolCall(id=id, name=name, arguments=arguments), ): - return dict( - type="tool_use", - id=id, - name=name, - input=arguments, - ) + return { + "type": "tool_use", + "id": id, + "name": name, + "input": arguments, + } case ContentPart( type=ContentType.tool_response, tool_response=ToolResponse( @@ -89,11 +90,11 @@ def dump_content_part(self, part: ContentPart) -> dict: error=error, ), ): - return dict( - type="tool_result", - tool_use_id=id, - content=response if response is not None else (error or ""), - ) + return { + "type": "tool_result", + "tool_use_id": id, + "content": response if response is not None else (error or ""), + } case _: raise ValueError(f"Unknown content type for: {part}") @@ -104,12 +105,12 @@ def parse_content_part(self, part: dict) -> ContentPart: case {"type": "image", "source": {"data": data}}: return ContentPart( type=ContentType.image, - image=b64decode(data.encode("ascii")), + image=image_url(b64decode(data.encode("ascii"))), ) case {"type": "document", "source": {"data": data}}: return ContentPart( type=ContentType.document, - document=b64decode(data.encode("ascii")), + document=document_url(b64decode(data.encode("ascii"))), ) case {"type": "document", "source": {"url": url}}: return ContentPart( @@ -130,7 +131,7 @@ def parse_content_part(self, part: dict) -> ContentPart: ), ) case _: - raise ValueError(f"Unknown content type: {part.type}") + raise ValueError(f"Unknown content type: {part}") def dump_message(self, message: Message) -> dict: if len(message.content) == 1 and message.content[0].type == ContentType.text: @@ -138,10 +139,10 @@ def dump_message(self, message: Message) -> dict: else: content = [self.dump_content_part(part) for part in message.content] - return dict( - role=self.dump_role(message.role), - content=content, - ) + return { + "role": self.dump_role(message.role), + "content": content, + } def parse_message(self, message: dict | AnthropicMessage) -> Message: if isinstance(message, AnthropicMessage): @@ -156,15 +157,17 @@ def parse_message(self, message: dict | AnthropicMessage) -> Message: ContentPart(type=ContentType.text, text=content), ], ) + elif isinstance(content, list): + parts = [self.parse_content_part(part) for part in content] + if any(part.type == ContentType.tool_response for part in parts): + role = Role.tool + return Message(role=role, content=parts) + else: + raise ValueError(f"Cannot handle message content: {content}") - parts = [self.parse_content_part(part) for part in content] - if any(part.type == ContentType.tool_response for part in parts): - role = Role.tool - return Message(role=role, content=parts) - - def dump_chat(self, chat: Chat) -> tuple[str, list[dict]]: + def dump_chat(self, chat: Chat) -> tuple[str | NotGiven, list[dict]]: system_messages = [] - other_messages = [] + other_messages: list[dict] = [] offset = 0 # If the first message is a system one, extract it as a separate diff --git a/think/llm/base.py b/think/llm/base.py index b866e6a..710a29a 100644 --- a/think/llm/base.py +++ b/think/llm/base.py @@ -4,7 +4,7 @@ from json import JSONDecodeError from logging import getLogger from time import time -from typing import TYPE_CHECKING, AsyncGenerator, Callable, TypeVar, overload +from typing import TYPE_CHECKING, AsyncGenerator, Callable, TypeVar, overload, cast from urllib.parse import parse_qs, urlparse from pydantic import BaseModel, ValidationError @@ -12,7 +12,7 @@ from think.parser import JSONParser from .chat import Chat, ContentPart, ContentType, Message, Role -from .tool import ToolDefinition, ToolKit, ToolResponse +from .tool import ToolDefinition, ToolKit, ToolCall, ToolResponse CustomParserResultT = TypeVar("CustomParserResultT") PydanticResultT = TypeVar("PydanticResultT", bound=BaseModel) @@ -473,12 +473,14 @@ async def _process_message( response_list = [] for part in message.content: if part.type == ContentType.text: - text += part.text + text += cast(str, part.text) elif part.type == ContentType.tool_call: if toolkit is None: log.warning("Tool call with no toolkit defined, ignoring") continue - tool_response = await toolkit.execute_tool_call(part.tool_call) + tool_response = await toolkit.execute_tool_call( + cast(ToolCall, part.tool_call) + ) response_list.append(tool_response) if response_list: diff --git a/think/llm/bedrock.py b/think/llm/bedrock.py index c805df7..f704920 100644 --- a/think/llm/bedrock.py +++ b/think/llm/bedrock.py @@ -19,7 +19,7 @@ from .base import LLM, BadRequestError, BaseAdapter, ConfigError, PydanticResultT -from .chat import Chat, ContentPart, ContentType, Message, Role +from .chat import Chat, ContentPart, ContentType, Message, Role, image_url from .tool import ToolCall, ToolDefinition, ToolResponse log = getLogger(__name__) @@ -71,29 +71,29 @@ def dump_role(self, role: Role) -> Literal["user", "assistant"]: def dump_content_part(self, part: ContentPart) -> dict: match part: case ContentPart(type=ContentType.text, text=text): - return dict( - text=text, - ) + return { + "text": text, + } case ContentPart(type=ContentType.image): - return dict( - image=dict( - source=dict( - bytes=part.image_bytes, - ), - format=part.image_mime_type.split("/")[1], - ) - ) + return { + "image": { + "source": { + "bytes": part.image_bytes, + }, + "format": part.image_mime_type.split("/")[1], + } + } case ContentPart( type=ContentType.tool_call, tool_call=ToolCall(id=id, name=name, arguments=arguments), ): - return dict( - toolUse=dict( - toolUseId=id, - name=name, - input=arguments, - ), - ) + return { + "toolUse": { + "toolUseId": id, + "name": name, + "input": arguments, + }, + } case ContentPart( type=ContentType.tool_response, tool_response=ToolResponse( @@ -102,20 +102,20 @@ def dump_content_part(self, part: ContentPart) -> dict: error=error, ), ): - return dict( - toolResult=dict( - toolUseId=id, - content=[ - dict( - text=response + return { + "toolResult": { + "toolUseId": id, + "content": [ + { + "text": response if response is not None else (error or ""), - ) + } ], - ), - ) + }, + } case _: - raise ValueError(f"Unknown content type: {part.type}") + raise ValueError(f"Unknown content type: {part}") def parse_content_part(self, part: dict) -> ContentPart: match part: @@ -124,7 +124,7 @@ def parse_content_part(self, part: dict) -> ContentPart: case {"type": "image", "source": {"data": data}}: return ContentPart( type=ContentType.image, - image=b64decode(data.encode("ascii")), + image=image_url(b64decode(data.encode("ascii"))), ) case {"toolUse": {"toolUseId": id, "name": name, "input": input}}: return ContentPart( @@ -143,10 +143,10 @@ def parse_content_part(self, part: dict) -> ContentPart: raise ValueError(f"Unknown content type for {part}") def dump_message(self, message: Message) -> dict: - return dict( - role=self.dump_role(message.role), - content=[self.dump_content_part(part) for part in message.content], - ) + return { + "role": self.dump_role(message.role), + "content": [self.dump_content_part(part) for part in message.content], + } def parse_message(self, message: dict) -> Message: role = Role.assistant if message.get("role") == "assistant" else Role.user @@ -158,13 +158,15 @@ def parse_message(self, message: dict) -> Message: ContentPart(type=ContentType.text, text=content), ], ) + elif isinstance(content, list): + parts = [self.parse_content_part(part) for part in content] + if any(part.type == ContentType.tool_response for part in parts): + role = Role.tool + return Message(role=role, content=parts) + else: + raise ValueError(f"Unknown message content: {content}") - parts = [self.parse_content_part(part) for part in content] - if any(part.type == ContentType.tool_response for part in parts): - role = Role.tool - return Message(role=role, content=parts) - - def dump_chat(self, chat: Chat) -> tuple[str, list[dict]]: + def dump_chat(self, chat: Chat) -> tuple[str | None, list[dict]]: system_messages = [] other_messages = [] offset = 0 @@ -241,11 +243,11 @@ async def _internal_call( async with self.session.client("bedrock-runtime") as client: try: - kwargs = dict( - modelId=self.model, - messages=messages, - system=system_block, - ) + kwargs = { + "modelId": self.model, + "messages": messages, + "system": system_block, + } if cfg: cfg["inferenceConfig"] = cfg if adapter.spec: diff --git a/think/llm/chat.py b/think/llm/chat.py index b18ce03..c92c7b0 100644 --- a/think/llm/chat.py +++ b/think/llm/chat.py @@ -45,7 +45,11 @@ class ContentType(str, Enum): tool_response = "tool_response" -def _validate_file(value: Any, type_desc: str, magic_bytes: dict[str, bytes]) -> str: +def _validate_file( + value: Any, + type_desc: str, + magic_bytes: dict[str, bytes], +) -> str | None: """Generic file validator/converter for image/document fields.""" if not value: @@ -78,6 +82,26 @@ def _validate_file(value: Any, type_desc: str, magic_bytes: dict[str, bytes]) -> return f"data:{mime_type};base64,{b64encode(value).decode('ascii')}" +def image_url(value: Any) -> str | None: + """ + Converts raw image data to a data URL. + + :param value: The raw image data or URL to be converted. + :return: A data URL representing the image. + """ + return _validate_file(value, "image", IMAGE_MAGIC_BYTES) + + +def document_url(value: Any) -> str | None: + """ + Converts raw document data to a data URL. + + :param value: The raw document data or URL to be converted. + :return: A data URL representing the document. + """ + return _validate_file(value, "document", DOCUMENT_MAGIC_BYTES) + + def _get_file_b64(data: str) -> str | None: """ Return base64-encoded file data if possible. @@ -148,7 +172,7 @@ class ContentPart(BaseModel): @classmethod def validate_image(cls, v): """Pydantic validator/converter for the image field.""" - return _validate_file(v, "image", IMAGE_MAGIC_BYTES) + return image_url(v) @property def is_image_url(self) -> bool: @@ -195,7 +219,7 @@ def image_mime_type(self) -> str | None: @classmethod def validate_document(cls, v): """Pydantic validator/converter for the document field.""" - return _validate_file(v, "document", DOCUMENT_MAGIC_BYTES) + return document_url(v) @property def is_document_url(self) -> bool: @@ -255,7 +279,7 @@ class Message(BaseModel): """ role: Role - content: list[ContentPart] | None = None + content: list[ContentPart] parsed: Any | None = None @classmethod @@ -505,7 +529,7 @@ def load(cls, data: list[dict[str, Any]]) -> "Chat": :return: Chat instance. """ c = cls() - c.messages = [Message(**m) for m in data] + c.messages = [Message.model_validate(m) for m in data] return c def clone(self) -> "Chat": diff --git a/think/llm/openai.py b/think/llm/openai.py index 5422d16..770146c 100644 --- a/think/llm/openai.py +++ b/think/llm/openai.py @@ -23,7 +23,7 @@ ) from err from .base import LLM, BadRequestError, BaseAdapter, ConfigError, PydanticResultT -from .chat import Chat, ContentPart, ContentType, Message, Role +from .chat import Chat, ContentPart, ContentType, Message, Role, image_url, document_url from .tool import ToolCall, ToolDefinition, ToolResponse log = getLogger(__name__) @@ -51,14 +51,14 @@ def dump_message(self, message: Message) -> list[dict]: match part: case ContentPart(type=ContentType.tool_call, tool_call=tool_call): tool_calls.append( - dict( - id=tool_call.id, - type="function", - function=dict( - name=tool_call.name, - arguments=json.dumps(tool_call.arguments), - ), - ) + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.arguments), + }, + } ) case ContentPart( type=ContentType.tool_response, @@ -73,17 +73,17 @@ def dump_message(self, message: Message) -> list[dict]: ) case ContentPart(type=ContentType.text, text=text): text_parts.append( - dict( - type="text", - text=text, - ) + { + "type": "text", + "text": text, + } ) case ContentPart(type=ContentType.image, image=image): image_parts.append( - dict( - type="image_url", - image_url=dict(url=image), - ) + { + "type": "image_url", + "image_url": {"url": image}, + } ) case ContentPart(type=ContentType.document): @@ -95,20 +95,20 @@ def dump_message(self, message: Message) -> list[dict]: raise ValueError(f"Unsupported document MIME type: {mime_type}") doc_parts.append( - dict( - type="input_file", - file_name="document.pdf", - file_data=part.document_data, - ) + { + "type": "input_file", + "file_name": "document.pdf", + "file_data": part.document_data, + } ) if tool_responses: return [ - dict( - role="tool", - tool_call_id=call_id, - content=response, - ) + { + "role": "tool", + "tool_call_id": call_id, + "content": response, + } for call_id, response in tool_responses.items() ] @@ -116,27 +116,27 @@ def dump_message(self, message: Message) -> list[dict]: if len(text_parts) == 1: text_parts = text_parts[0]["text"] return [ - dict( - role="assistant", - content=text_parts or None, - tool_calls=tool_calls or None, - ) + { + "role": "assistant", + "content": text_parts or None, + "tool_calls": tool_calls or None, + } ] if message.role == Role.system: if len(text_parts) == 1: text_parts = text_parts[0]["text"] - return [dict(role="system", content=text_parts)] + return [{"role": "system", "content": text_parts}] if message.role == Role.user: content = text_parts + image_parts + doc_parts if len(content) == 1 and content[0]["type"] == "text": content = content[0]["text"] return [ - dict( - role="user", - content=content, - ) + { + "role": "user", + "content": content, + } ] raise ValueError(f"Unsupported message role: {message.role}") @@ -263,14 +263,14 @@ def parse_message(self, message: dict[str, Any]) -> Message: content.append( ContentPart( type=ContentType.image, - image=part.get("image_url", {}).get("url"), + image=image_url(part.get("image_url", {}).get("url")), ) ) elif part_type == "input_file": content.append( ContentPart( type=ContentType.document, - document=part.get("file_data"), + document=document_url(part.get("file_data")), ) ) else: diff --git a/think/llm/tool.py b/think/llm/tool.py index ea65004..495e5de 100644 --- a/think/llm/tool.py +++ b/think/llm/tool.py @@ -39,7 +39,7 @@ def __init__(self, func: Callable, name: str | None = None): :param name: The name of the tool, exposed to the LLM. Defaults to the function name. """ - self.name = name or func.__name__ + self.name = name or getattr(func, "__name__", "tool") self.func = func self.model = self.create_model_from_function(func) @@ -93,8 +93,9 @@ def create_model_from_function(cls, func: Callable) -> type[BaseModel]: else: fields[name] = (annotation, default) + func_name = getattr(func, "__name__", "tool") model_name = ( - "".join(part.capitalize() for part in func.__name__.split("_")) + "Args" + "".join(part.capitalize() for part in func_name.split("_")) + "Args" ) model = create_model( @@ -131,8 +132,8 @@ class ToolResponse: """ call: ToolCall - response: str = None - error: str = None + response: str | None = None + error: str | None = None class ToolError(Exception): @@ -231,7 +232,7 @@ async def execute_tool_call(self, call: ToolCall) -> ToolResponse: call=call, error=f"ERROR: Error running tool {call.name}: {err}" ) - def add_tool(self, func: Callable, name: str = None) -> None: + def add_tool(self, func: Callable, name: str | None = None) -> None: """ Add a single tool to the toolkit. diff --git a/think/parser.py b/think/parser.py index 3413c8c..b35ee16 100644 --- a/think/parser.py +++ b/think/parser.py @@ -1,7 +1,7 @@ import json import re from enum import Enum -from typing import Optional, Union, Type +from typing import Optional, Union, Type, overload from pydantic import BaseModel @@ -91,6 +91,15 @@ def __init__(self, spec: Optional[Type[BaseModel]] = None, strict: bool = True): def schema(self): return self.spec.model_json_schema() if self.spec else None + @overload + def __call__(self, text: str) -> BaseModel: ... + + @overload + def __call__(self, text: str) -> dict: ... + + @overload + def __call__(self, text: str) -> None: ... + def __call__(self, text: str) -> Union[BaseModel, dict, None]: text = text.strip() if text.startswith("```"): @@ -132,7 +141,7 @@ class EnumParser: any of the Enum values. """ - def __init__(self, spec: Enum, ignore_case: bool = True): + def __init__(self, spec: Type[Enum], ignore_case: bool = True): self.spec = spec self.ignore_case = ignore_case