-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·63 lines (48 loc) · 2.06 KB
/
main.py
File metadata and controls
executable file
·63 lines (48 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import io
import os
import shutil
import tempfile
from urllib.parse import quote

from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from google.cloud import storage

import gcs_util

# The single application object; every @app.get / @app.post route below
# registers on it.
app = FastAPI()

# Reference for the status codes passed to HTTPException in the handlers:
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Status
@app.post("/train_yolo")
def train_yolo(dataset: UploadFile, model: str, epochs: int = 10, batch: int = 16, user_id: str = "default"):
    """Upload a zipped dataset to GCS and submit a training job for it.

    The uploaded file is spooled to a local temporary ``.zip``, uploaded via
    ``gcs_util.upload_to_gcs`` to the blob ``{user_id}/{model}.zip``, and then
    ``gcs_util.submit_training_job`` is called with that blob path and the
    training parameters. The local temporary file is always removed.

    Args:
        dataset: The uploaded dataset archive.
        model: Model name; used in the GCS blob path and passed to the job.
        epochs: Epoch count forwarded to ``submit_training_job``.
        batch: Batch size forwarded to ``submit_training_job``.
        user_id: Prefix namespace for the GCS blob path.

    Raises:
        HTTPException: 400 if the upload has no filename; 409 if
            ``gcs_util.check_gcs_unique_name`` reports the name as in use.
    """
    if not dataset.filename:
        raise HTTPException(status_code=400, detail="No dataset uploaded")

    # Checked before anything is written locally, so a rejected name costs no
    # disk I/O and leaves nothing to clean up.
    # NOTE(review): a truthy result is treated as "name already taken" —
    # confirm against gcs_util.check_gcs_unique_name.
    if gcs_util.check_gcs_unique_name(f"{user_id}/{model}"):
        raise HTTPException(status_code=409, detail="Model name already in use")

    blob_path = f"{user_id}/{model}.zip"

    # delete=False: the file is closed after copying and then handed to
    # upload_to_gcs by path, so it must outlive the handle; the `finally`
    # below removes it. The copy sits inside the `try` so a failed/aborted
    # copy no longer leaks the temp file.
    temp_file = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    temp_path = temp_file.name
    try:
        with temp_file:
            shutil.copyfileobj(dataset.file, temp_file)
        gcs_util.upload_to_gcs(temp_path, blob_path)
        gcs_util.submit_training_job(blob_path, model, epochs, batch)
    finally:
        try:
            os.remove(temp_path)
        except OSError:
            # Best-effort cleanup: never let a failed unlink mask the
            # request's real outcome or exception.
            pass
@app.get("/get_models")
def get_models(user_id: str):
    """Return the result of ``gcs_util.get_user_models`` for ``user_id``.

    The value is passed through unchanged, so the response shape is whatever
    ``gcs_util.get_user_models`` produces.
    """
    user_models = gcs_util.get_user_models(user_id)
    return user_models
@app.get("/download_model")
def download_model(user_id: str, model_name: str):
    """Return ``gcs_util.get_user_models`` for the path ``{user_id}/{model_name}``.

    Despite the route name, no file bytes are returned here; this delegates
    to the same helper as ``/get_models`` with a model-scoped path.
    """
    # NOTE(review): if get_user_models matches by plain string prefix,
    # "alice/yolo" would also match "alice/yolo2/..." — confirm whether
    # gcs_util appends a trailing "/".
    model_prefix = "/".join((user_id, model_name))
    return gcs_util.get_user_models(model_prefix)
def _attachment_header(filename: str) -> str:
    """Build a ``Content-Disposition: attachment`` value that is latin-1 safe.

    Names that ``urllib.parse.quote`` leaves unchanged are sent as a quoted
    ``filename="..."``. Anything else (spaces, quotes, ``;``, non-ASCII) is
    percent-encoded into the RFC 6266/5987 ``filename*=utf-8''...`` form,
    the same scheme Starlette's ``FileResponse`` uses. This keeps the header
    well-formed and encodable.
    """
    quoted = quote(filename)
    if quoted != filename:
        return f"attachment; filename*=utf-8''{quoted}"
    return f'attachment; filename="{filename}"'


def stream_blob(bucket_name: str, blob_name: str):
    """Fetch a GCS object and return it as a downloadable attachment.

    Args:
        bucket_name: GCS bucket holding the object.
        blob_name: Object name; its basename becomes the download filename.

    Returns:
        A ``StreamingResponse`` with ``application/octet-stream`` content and
        an attachment ``Content-Disposition`` header.

    Raises:
        HTTPException: 404 if the blob does not exist.
    """
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(blob_name)
    if not blob.exists():
        raise HTTPException(status_code=404, detail="Model not found")

    # The whole object is buffered in memory before the response is built, so
    # a download failure propagates from this function instead of truncating
    # an in-flight response. NOTE(review): this costs the artifact's full size
    # in RAM per request.
    buffer = io.BytesIO()
    blob.download_to_file(buffer)
    buffer.seek(0)

    filename = os.path.basename(blob_name)
    return StreamingResponse(
        buffer,
        media_type="application/octet-stream",
        headers={"Content-Disposition": _attachment_header(filename)},
    )
@app.get("/download_model_file")
def download_model_file(user_id: str, model: str, artifact: str = "final_model.quant.onnx"):
blob_name = f"{user_id}/{model}/{artifact}"