-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp_fastapi.py
More file actions
106 lines (82 loc) · 3.81 KB
/
app_fastapi.py
File metadata and controls
106 lines (82 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import shutil
from pathlib import Path
from typing import Annotated
from urllib.parse import quote
import matplotlib
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# Select matplotlib's non-interactive "Agg" raster backend so figures can be
# rendered without a GUI/display.
# NOTE(review): kept above the `webapp` import, presumably because a webapp
# module imports matplotlib.pyplot (backend choice must precede that) —
# confirm before reordering.
matplotlib.use("Agg")
# Import from the shared webapp package
from webapp import config, logic, state, utils
# Executed once at import time, i.e. before the app serves any request.
# NOTE(review): presumably populates config.SAMPLES_FOLDER, which the "/"
# route lists — confirm in webapp.utils.
utils.download_sample_images()
# --- FastAPI App Initialization ---
app = FastAPI()
# Both directories are relative paths, resolved against the process's current
# working directory, so the server has to be launched from the project root.
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
# --- Web App Routes ---
@app.get("/", response_class=HTMLResponse)
async def home(request: Request, message: str = None, category: str = "success"):
    """Render the dashboard page (``index.html``).

    Args:
        request: Incoming request; placed in the template context as Jinja2Templates expects.
        message: Optional flash message, e.g. the one appended to the URL by
            the ``/train`` and ``/cancel_training`` redirects.
        category: Flash-message category; those redirects send ``"success"``
            or ``"error"``.

    Returns:
        The rendered template with the available models, the current
        ``state.training_log``, device info and the sample image names.
    """
    # Called for its side effects only; the (just_finished, message) tuple it
    # returns is not used here (the "/status" endpoint is what surfaces it).
    logic.check_training_status()
    # Path.glob yields entries in arbitrary, filesystem-dependent order; sort
    # so the sample gallery is listed in a stable order on every page load.
    sample_images = sorted(f.name for f in config.SAMPLES_FOLDER.glob("*.jpg"))
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "models": logic.get_available_models(),
            "training_status": state.training_log,
            "device_info": logic.get_device_info(),
            "message": message,
            "category": category,
            "sample_images": sample_images,
        },
    )
@app.get("/status", response_class=JSONResponse)
async def status():
    """Return the current training state as JSON.

    ``logic.check_training_status()`` is consulted first. When it reports
    that training has just finished, its message is stored under the
    ``"message"`` key of ``state.training_log``; otherwise any ``"message"``
    left over from an earlier call is removed. The module-level
    ``state.training_log`` dict is mutated in place and then serialized.
    """
    finished_now, completion_msg = logic.check_training_status()
    log = state.training_log
    if not finished_now:
        log.pop("message", None)
    else:
        log["message"] = completion_msg
    return JSONResponse(content=log)
@app.post("/train")
async def train(
    model: Annotated[str, Form()] = ...,
    data_name: Annotated[str, Form()] = ...,
    epochs: Annotated[int, Form()] = ...,
    learning_rate: Annotated[float, Form()] = ...,
    batch_size: Annotated[int, Form()] = ...,
):
    """Start a training run from the submitted form, then redirect to ``/``.

    The parsed form fields are passed positionally to
    ``logic.start_training_process``. Its message and a ``success``/``error``
    category travel back to the home page as query parameters; status 303
    makes the browser follow the redirect with a GET.
    """
    ok, feedback = logic.start_training_process(model, data_name, epochs, learning_rate, batch_size)
    if ok:
        category = "success"
    else:
        category = "error"
    target = "/?message=" + quote(feedback) + "&category=" + category
    return RedirectResponse(url=target, status_code=303)
@app.post("/cancel_training")
async def cancel_training():
    """Ask the training logic to cancel the current run, then redirect to ``/``.

    The outcome message and a ``success``/``error`` category are forwarded
    to the home page as query parameters (303 => follow-up GET).
    """
    ok, feedback = logic.cancel_current_training()
    if ok:
        category = "success"
    else:
        category = "error"
    target = "/?message=" + quote(feedback) + "&category=" + category
    return RedirectResponse(url=target, status_code=303)
@app.get("/predict_sample", response_class=HTMLResponse)
async def predict_sample(request: Request, model_path: str, image_name: str):
    """Run a prediction on one of the bundled sample images.

    Args:
        request: Incoming request, passed through to the template.
        model_path: Model identifier, forwarded as-is to
            ``logic.perform_prediction(model_path_str=...)``.
        image_name: Bare file name of an image inside ``config.SAMPLES_FOLDER``.

    Returns:
        ``result.html`` rendered with ``result_image`` and ``error``; if
        ``image_name`` is not a plain name of an existing sample file, only
        an ``error`` is rendered.
    """
    # image_name comes straight from the query string. Joined unchecked it
    # allows "../" traversal, and an absolute path would replace
    # SAMPLES_FOLDER entirely (pathlib join semantics). Accept only a bare
    # file name that refers to an existing file inside the samples folder.
    safe_name = Path(image_name).name
    image_path = config.SAMPLES_FOLDER / safe_name
    if safe_name != image_name or safe_name in {"", ".", ".."} or not image_path.is_file():
        return templates.TemplateResponse("result.html", {"request": request, "error": "Sample image not found."})
    # NOTE(review): model_path is also client-supplied and is not checked
    # against logic.get_available_models(). If perform_prediction deserializes
    # the file with a pickle-based loader, restrict it to known models.
    image_base64, error = logic.perform_prediction(model_path_str=model_path, image_path=image_path)
    return templates.TemplateResponse("result.html", {"request": request, "result_image": image_base64, "error": error})
@app.post("/predict", response_class=HTMLResponse)
async def predict(
    request: Request,
    model_path: Annotated[str, Form()] = ...,
    image_file: Annotated[UploadFile, File()] = ...,
):
    """Save an uploaded image into ``config.UPLOAD_FOLDER`` and predict on it.

    Args:
        request: Incoming request, passed through to the template.
        model_path: Model identifier from the form, forwarded as-is to
            ``logic.perform_prediction(model_path_str=...)``.
        image_file: The uploaded image.

    Returns:
        ``result.html`` rendered with ``result_image`` and ``error``, or with
        only ``error`` when no file was sent or its name/type is rejected.
    """
    if not image_file.filename:
        return templates.TemplateResponse("result.html", {"request": request, "error": "Please upload an image file."})
    # The upload's filename is chosen by the client. Keep only its final
    # component so "../x.png" or "/abs/x.png" cannot write outside
    # UPLOAD_FOLDER (an absolute path would otherwise replace it on join).
    filename = Path(image_file.filename).name
    # --- Image Format Validation ---
    if filename in {"", ".", ".."} or not logic.is_allowed_file(filename):
        error_msg = "Invalid file type. Please upload a PNG, JPG, JPEG, or GIF image."
        return templates.TemplateResponse("result.html", {"request": request, "error": error_msg})
    # NOTE(review): uploads with the same name overwrite each other, also
    # between concurrent requests; consider a unique per-request name.
    image_path = config.UPLOAD_FOLDER / filename
    with image_path.open("wb") as buffer:
        shutil.copyfileobj(image_file.file, buffer)
    # NOTE(review): model_path is client-supplied and not checked against
    # logic.get_available_models(). Since this route lets clients place files
    # on disk, a pickle-based model loader here would be exploitable.
    image_base64, error = logic.perform_prediction(model_path_str=model_path, image_path=image_path)
    return templates.TemplateResponse("result.html", {"request": request, "result_image": image_base64, "error": error})