-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_example.py
More file actions
63 lines (50 loc) · 1.87 KB
/
inference_example.py
File metadata and controls
63 lines (50 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
NeoLLM - Developer API Integration Example

This script shows how to load your trained model and use it programmatically
in your own applications (e.g. a Discord bot, a Flask backend, or a FastAPI
backend).
"""
from neollm.inference.engine import InferenceEngine

# Point these at the checkpoint you want to use (e.g. your small model) and
# at the tokenizer it was trained with.
CHECKPOINT_PATH = "checkpoints/small/latest.pt"
TOKENIZER_PATH = "data/tokenizer.json"


def main() -> None:
    """Load the engine from CHECKPOINT_PATH and print one generated completion."""
    # 1. Build the inference engine directly from the saved checkpoint.
    print(f"Loading model into VRAM from {CHECKPOINT_PATH}...")
    engine = InferenceEngine.from_checkpoint(
        checkpoint_path=CHECKPOINT_PATH,
        tokenizer_path=TOKENIZER_PATH,
    )
    print("Model loaded successfully!\n")

    # 2. Generate text programmatically.
    prompt = "The future of artificial intelligence is"
    print(f"Prompt: {prompt}")
    print("Generating response...")

    # Generation parameters used below:
    # - max_new_tokens: upper bound on how many tokens to generate
    # - temperature: sampling temperature; lower values give more focused
    #   output, higher values give more varied output
    # - top_p: nucleus-sampling threshold (0.9 is a common choice); a top_k
    #   cut-off may also be available but is not set here
    # - greedy: presumably True switches to greedy decoding and ignores the
    #   sampling settings -- confirm against InferenceEngine.generate
    response = engine.generate(
        prompt=prompt,
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.9,
        greedy=False,
    )
    print(f"\nResponse:\n{response}")


# =====================================================================
# IF USING FASTAPI (Web Backend Example)
# =====================================================================
"""
import threading

from fastapi import FastAPI
from pydantic import BaseModel, Field

from neollm.inference.engine import InferenceEngine

app = FastAPI()

# Load the engine ONCE when the app starts, not per request.
engine = InferenceEngine.from_checkpoint(
    checkpoint_path="checkpoints/small/latest.pt",
    tokenizer_path="data/tokenizer.json",
)

# FastAPI runs plain `def` endpoints in a thread pool, so concurrent requests
# can reach the shared engine at the same time. Serialize access unless
# InferenceEngine is known to be thread-safe.
_engine_lock = threading.Lock()


class ChatRequest(BaseModel):
    prompt: str
    # Clients control this value: bound it so a single request cannot ask
    # for an arbitrarily long generation and monopolize the model.
    max_tokens: int = Field(default=100, ge=1, le=512)


@app.post("/generate")
def generate_text(request: ChatRequest):
    with _engine_lock:
        response = engine.generate(
            prompt=request.prompt,
            max_new_tokens=request.max_tokens,
        )
    return {"prompt": request.prompt, "response": response}
"""


if __name__ == "__main__":
    main()