diff --git a/inference-ollama-gradio-chat/gradio_chat.py b/inference-ollama-gradio-chat/gradio_chat.py index 3e4ca79..980cfbd 100644 --- a/inference-ollama-gradio-chat/gradio_chat.py +++ b/inference-ollama-gradio-chat/gradio_chat.py @@ -3,25 +3,77 @@ import os import gradio as gr +import requests from openai import OpenAI +from sshtunnel import SSHTunnelForwarder OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://127.0.0.1:11434") MODEL_NAME = os.environ.get("MODEL_NAME", "llama2") +# SSH tunneling setup if remote Ollama +SSH_HOST = os.environ.get("SSH_HOST") +SSH_USER = os.environ.get("SSH_USER") +SSH_KEY_PATH = os.environ.get("SSH_KEY_PATH") +SSH_PORT = int(os.environ.get("SSH_PORT", 22)) + +tunnel = None +if SSH_HOST: + tunnel = SSHTunnelForwarder( + (SSH_HOST, SSH_PORT), + ssh_username=SSH_USER, + ssh_pkey=SSH_KEY_PATH, + remote_bind_address=("127.0.0.1", 11434), + local_bind_address=("127.0.0.1", 0), # 0 means auto-assign port + ) + tunnel.start() + local_port = tunnel.local_bind_port + OLLAMA_BASE_URL = f"http://127.0.0.1:{local_port}" + client = OpenAI(base_url=f"{OLLAMA_BASE_URL}/v1", api_key="ollama") def chat(message: str, history: list[dict]) -> str: messages = history + [{"role": "user", "content": message}] response = client.chat.completions.create(model=MODEL_NAME, messages=messages) - return response.choices[0].message.content + return response.choices[0].message.content or "" + + +def tokenize_text(text: str) -> str: + try: + response = requests.post( + f"{OLLAMA_BASE_URL}/api/tokenize", + json={"model": MODEL_NAME, "prompt": text}, + ) + if response.status_code == 200: + tokens = response.json()["tokens"] + return f"Tokens: {tokens}\n\nCount: {len(tokens)}" + else: + return f"Error: {response.status_code} - {response.text}" + except Exception as e: + return f"Error: {str(e)}" -demo = gr.ChatInterface( +chat_interface = gr.ChatInterface( fn=chat, title="Ollama Chat", description=f"Chatting with **{MODEL_NAME}** via Ollama", ) +tokenization_interface = gr.Interface( + fn=tokenize_text, + inputs=gr.Textbox(label="Input Text", placeholder="Enter text to tokenize"), + outputs=gr.Textbox(label="Tokenization Result"), + title="Tokenization Preview", + description=f"Preview how **{MODEL_NAME}** tokenizes text via Ollama", +) + +demo = gr.TabbedInterface( + [chat_interface, tokenization_interface], ["Chat", "Tokenization"] +) + if __name__ == "__main__": - demo.launch(server_name="0.0.0.0", server_port=7860) + try: + demo.launch(server_name="0.0.0.0", server_port=7860) + finally: + if tunnel: + tunnel.close() diff --git a/inference-ollama-gradio-chat/task.yaml b/inference-ollama-gradio-chat/task.yaml index 298d62d..4fd212c 100644 --- a/inference-ollama-gradio-chat/task.yaml +++ b/inference-ollama-gradio-chat/task.yaml @@ -5,7 +5,7 @@ resources: cpus: 2 memory: 4 setup: | - uv pip install --upgrade gradio openai + uv pip install --upgrade gradio openai sshtunnel requests python -c "import gradio; print('gradio installed at:', gradio.__file__)" if ! command -v ollama >/dev/null 2>&1; then curl -fsSL https://ollama.com/install.sh | sh; fi if [ "$(uname)" != "Darwin" ]; then $SUDO apt-get update && $SUDO apt-get install -y pciutils lshw; fi @@ -13,9 +13,7 @@ run: | export OLLAMA_HOST=0.0.0.0:11434 ollama serve > /tmp/ollama.log 2>&1 & sleep 3 - ollama pull $MODEL_NAME > /tmp/ollama-pull.log 2>&1 & - sleep 5 - python -c "import gradio; print('gradio found at:', gradio.__file__)" 2>&1 || echo "gradio NOT found in run python" + ollama pull $MODEL_NAME OLLAMA_BASE_URL=http://127.0.0.1:11434 MODEL_NAME=$MODEL_NAME python ~/inference-ollama-gradio-chat/gradio_chat.py > /tmp/gradio.log 2>&1 & sleep 5 tail -f /tmp/ollama.log /tmp/ollama-pull.log /tmp/gradio.log /tmp/ngrok.log