diff --git a/app/config/config.py b/app/config/config.py index 3aaf8e00..d150c10c 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -95,6 +95,7 @@ def __init__(self): # AI self.openai_api_key = self.load("OPENAI_API_KEY") self.deepseek_api_key = self.load("DEEPSEEK_API_KEY") + self.akash_chat_api_key = self.load("AKASH_CHAT_API_KEY") self.xai_api_key = self.load("XAI_API_KEY") self.eternal_api_key = self.load("ETERNAL_API_KEY") self.system_prompt = self.load("SYSTEM_PROMPT") diff --git a/app/core/engine.py b/app/core/engine.py index 18518db4..466e1aea 100644 --- a/app/core/engine.py +++ b/app/core/engine.py @@ -157,6 +157,19 @@ async def initialize_agent(aid, is_private=False): ) if input_token_limit > 60000: input_token_limit = 60000 + elif agent.model == "akashchat": + agent.model = "Meta-Llama-3-3-70B-Instruct" + llm = ChatOpenAI( + model_name=agent.model, + openai_api_key=config.akash_chat_api_key, + openai_api_base="https://chatapi.akash.network/api/v1", + frequency_penalty=agent.frequency_penalty, + presence_penalty=agent.presence_penalty, + temperature=agent.temperature, + timeout=1000, + ) + if input_token_limit > 60000: + input_token_limit = 60000 else: llm = ChatOpenAI( model_name=agent.model, diff --git a/example.env b/example.env index dfd9708e..d3358b94 100644 --- a/example.env +++ b/example.env @@ -6,6 +6,8 @@ OPENAI_API_KEY= DEEPSEEK_API_KEY= +AKASH_CHAT_API_KEY= + DB_HOST= DB_PORT= DB_USERNAME= diff --git a/models/agent.py b/models/agent.py index 7f4d0f64..b6977da0 100644 --- a/models/agent.py +++ b/models/agent.py @@ -685,6 +685,7 @@ class AgentUpdate(BaseModel): "deepseek-reasoner", "grok-2", "eternalai", + "akashchat", ], PydanticField( default="gpt-4o-mini", diff --git a/skills/akashchat/__init__.py b/skills/akashchat/__init__.py new file mode 100644 index 00000000..056c7a03 --- /dev/null +++ b/skills/akashchat/__init__.py @@ -0,0 +1,83 @@ +"""AkashChat skills.""" + +import logging +from typing import TypedDict + +from 
class SkillStates(TypedDict):
    """Enablement state of each skill in the AkashChat category."""

    akashgen_image_generation: SkillState


class Config(SkillConfig):
    """Settings for the AkashChat skill category."""

    states: SkillStates
    api_key: str


async def get_skills(
    config: "Config",
    is_private: bool,
    store: SkillStoreABC,
    **_,
) -> list[AkashChatBaseTool]:
    """Return the AkashChat skills that are enabled for the current caller.

    A skill is selected when its state is ``"public"``, or when its state is
    ``"private"`` and ``is_private`` is true. Any other state (including
    ``"disabled"``) leaves the skill out.

    Args:
        config: AkashChat configuration; only ``config["states"]`` is read here.
        is_private: Whether skills marked ``"private"`` should be included.
        store: Skill store handed to each skill when it is created.

    Returns:
        The selected skill instances, in the order their names appear in
        ``config["states"]``. Names that do not resolve to a skill are dropped.
    """

    def _is_selected(state) -> bool:
        return state == "public" or (state == "private" and is_private)

    # Pick the names first, then resolve them, so lookups happen in one pass
    # after all states have been inspected.
    selected_names = [
        skill_name
        for skill_name, state in config["states"].items()
        if _is_selected(state)
    ]
    resolved = [get_akashchat_skill(skill_name, store) for skill_name in selected_names]
    return [skill for skill in resolved if skill]
# Default GPU preference order; kept immutable so it can be shared safely by
# the input schema and by the tool's own fallback.
_DEFAULT_PREFERRED_GPUS = ("RTX4090", "A10", "A100", "V100-32Gi", "H100")


def _find_status_entry(status_data, job_id):
    """Pick the status record for ``job_id`` out of a status response.

    Looks for an entry keyed by the job id first, then falls back to the first
    element of a non-empty ``"statuses"`` list. Returns ``None`` if neither is
    present.
    """
    entry = status_data.get(job_id)
    if entry:
        return entry
    statuses = status_data.get("statuses")
    return statuses[0] if statuses else None


class AkashGenImageGenerationInput(BaseModel):
    """Input for AkashGenImageGeneration tool."""

    prompt: str = Field(description="Text prompt describing the image to generate.")
    negative: str = Field(
        default="",
        description="Negative prompt to exclude undesirable elements from the image.",
    )
    sampler: str = Field(
        default="dpmpp_2m",
        description="Sampling method to use for image generation. Default: dpmpp_2m.",
    )
    scheduler: str = Field(
        default="sgm_uniform",
        description="Scheduler to use for image generation. Default: sgm_uniform.",
    )
    preferred_gpu: list[str] = Field(
        default_factory=lambda: list(_DEFAULT_PREFERRED_GPUS),
        description="List of preferred GPU types for image generation.",
    )

    # mode="before" is required: in the default "after" mode the value has
    # already been validated as list[str], so a JSON string would be rejected
    # before this hook could parse it.
    @field_validator("preferred_gpu", mode="before")
    @classmethod
    def ensure_list(cls, v):
        """Accept ``preferred_gpu`` either as a list or as a JSON array string.

        Raises:
            ValueError: If a string is given that is not valid JSON or does not
                decode to a list.
        """
        if not isinstance(v, str):
            return v
        try:
            parsed = json.loads(v)
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            raise ValueError(
                "preferred_gpu must be a list or a JSON array string"
            ) from exc
        if not isinstance(parsed, list):
            raise ValueError("preferred_gpu must be a list or a JSON array string")
        return parsed


class AkashGenImageGeneration(AkashChatBaseTool):
    """Tool for generating images using the AkashGen model.

    This tool submits a text prompt to the AkashGen HTTP API, then polls the
    job status until it completes, fails, or the polling budget is exhausted.

    Attributes:
        name: The name of the tool.
        description: A description of what the tool does.
        args_schema: The schema for the tool's input arguments.
    """

    name: str = "akashgen_image_generation"
    description: str = (
        "Generate images using AkashGen model.\n"
        "Provide a text prompt describing the image you want to generate.\n"
        "AkashGen is a powerful image generation model capable of creating detailed, "
        "high-quality images from text descriptions.\n"
        "You can optionally provide a negative prompt, sampler, scheduler, and "
        "preferred GPU types for more control.\n"
    )
    args_schema: Type[BaseModel] = AkashGenImageGenerationInput

    async def _arun(
        self,
        prompt: str,
        negative: Optional[str] = "",
        sampler: Optional[str] = "dpmpp_2m",
        scheduler: Optional[str] = "sgm_uniform",
        preferred_gpu: Optional[list[str]] = None,
        config: RunnableConfig = None,
        **kwargs,
    ) -> str:
        """Generate an image with AkashGen and return the job's result.

        Args:
            prompt: Text prompt describing the image to generate.
            negative: Negative prompt to exclude undesirable elements from the image.
            sampler: Sampling method to use for image generation. Default: dpmpp_2m.
            scheduler: Scheduler to use for image generation. Default: sgm_uniform.
            preferred_gpu: List of preferred GPU types for image generation.
                When omitted, the default preference order is used.
            config: Configuration for the runnable.

        Returns:
            str: The ``result`` field of the completed job (presumably the URL
            of the generated image; the exact shape is defined by the API).

        Raises:
            Exception: If the API key is missing, an HTTP call fails, the job
                reports failure, the response is malformed, or polling times
                out. The original error is attached as ``__cause__``.
        """
        context = self.context_from_config(config)

        # The AkashChat API key comes from the skill's own configuration.
        api_key = context.config.get("api_key")

        # Never share a mutable default between calls: build a fresh list.
        gpus = list(_DEFAULT_PREFERRED_GPUS) if preferred_gpu is None else preferred_gpu

        max_polls = 40  # with a 3 second interval this is ~2 minutes
        poll_interval = 3

        try:
            if not api_key:
                # Fail early instead of sending "Bearer None" to the API.
                raise Exception("AkashChat API key is not configured")

            payload = {
                "prompt": prompt,
                "negative": negative,
                "sampler": sampler,
                "scheduler": scheduler,
                "preferred_gpu": gpus,
            }
            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }

            async with httpx.AsyncClient(timeout=60) as client:
                # Submit the generation job; the job id is assigned by the API.
                gen_response = await client.post(
                    "https://gen.akash.network/api/generate",
                    json=payload,
                    headers=headers,
                )
                gen_response.raise_for_status()
                gen_data = gen_response.json()
                job_id = gen_data.get("job_id")
                if not job_id:
                    raise Exception(f"No job_id returned from AkashGen: {gen_data}")

                status_url = f"https://gen.akash.network/api/status?ids={job_id}"
                for _ in range(max_polls):
                    status_response = await client.get(status_url, headers=headers)
                    status_response.raise_for_status()
                    status_data = status_response.json()
                    status_entry = _find_status_entry(status_data, job_id)
                    if not status_entry:
                        raise Exception(f"Malformed status response: {status_data}")

                    # A missing "status" key is treated like any other
                    # non-terminal state: keep polling until the budget runs out.
                    status = status_entry.get("status")
                    if status == "completed":
                        result = status_entry.get("result")
                        if not result:
                            raise Exception(
                                f"No result found in completed status: {status_entry}"
                            )
                        return result
                    if status == "failed":
                        raise Exception(f"Image generation failed: {status_entry}")
                    await asyncio.sleep(poll_interval)

                raise Exception(f"Image generation timed out for job_id {job_id}")
        except httpx.HTTPError as e:
            error_message = f"HTTP error during AkashGen image generation: {str(e)}"
            logger.error(error_message)
            raise Exception(error_message) from e
        except Exception as e:
            error_message = f"Error generating image with AkashGen: {str(e)}"
            logger.error(error_message)
            raise Exception(error_message) from e
"states": { + "type": "object", + "properties": { + "akashgen_image_generation": { + "type": "string", + "title": "Image Generation by AkashGen", + "enum": [ + "disabled", + "public", + "private" + ], + "x-enum-title": [ + "Disabled", + "Agent Owner + All Users", + "Agent Owner Only" + ], + "description": "Generate images using AkashGen model based on text prompts", + "default": "disabled" + } + }, + "description": "States for each AkashChat skill (disabled, public, or private)" + }, + "api_key": { + "type": "string", + "title": "API Key", + "x-link": "[Get your API key](https://chatapi.akash.network)", + "x-sensitive": true, + "description": "Akash Chat API key for authentication" + } + }, + "required": [ + "enabled", + "states" + ], + "if": { + "properties": { + "enabled": { "const": true } + } + }, + "then": { + "required": ["api_key"] + }, + "additionalProperties": true +} \ No newline at end of file