diff --git a/pyproject.toml b/pyproject.toml index d3ffacb..71a5a56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ server = [ "Flask", "Werkzeug", "gunicorn", + "Flask-Cors", + "Flask-Compress", + "Flask-Caching", ] [tool.ruff] diff --git a/sign_language_segmentation/server.py b/sign_language_segmentation/server.py index 97e0d5c..8223762 100644 --- a/sign_language_segmentation/server.py +++ b/sign_language_segmentation/server.py @@ -4,11 +4,19 @@ from pathlib import Path from flask import Flask, request, abort, make_response, jsonify +from flask_caching import Cache +from flask_compress import Compress +from flask_cors import CORS from pose_format import Pose from sign_language_segmentation.bin import segment_pose app = Flask(__name__) +compress = Compress(app) +cache = Cache(app, config={"CACHE_TYPE": "SimpleCache"}) +CORS(app, resources={r"/": {"methods": ["GET", "OPTIONS"]}}) + +CACHE_TTL = 86400 # 1 day in seconds def resolve_path(uri: str): @@ -16,6 +24,22 @@ def resolve_path(uri: str): return uri.replace("gs://", "/mnt/") +def load_pose(uri: str) -> Pose: + pose_file_path = Path(resolve_path(uri)) + if not pose_file_path.exists(): + raise FileNotFoundError(f"File does not exist: {uri}") + with pose_file_path.open("rb") as f: + return Pose.read(f) + + +def tiers_to_seconds(tiers: dict, fps: float) -> dict: + return { + tier: [{"start": round(seg["start"] / fps, 3), "end": round(seg["end"] / fps, 3)} + for seg in segments] + for tier, segments in tiers.items() + } + + @app.errorhandler(Exception) def handle_exception(e): print("Exception", e) @@ -38,6 +62,30 @@ def health_check(): return make_response(jsonify(body), 200) +@app.route("/", methods=['GET', 'OPTIONS']) +@compress.compressed() +@cache.cached(timeout=CACHE_TTL, query_string=True) +def get_segments(): + if request.method == 'OPTIONS': + return make_response("", 204) + + pose_uri = request.args.get("pose") + if not pose_uri: + abort(make_response(jsonify(message="Missing `pose` query parameter"), 400)) + + pose = load_pose(pose_uri) + + if len(pose.body.data) == 1: + return jsonify(sign=[], sentence=[]) + + _eaf, tiers = segment_pose(pose) + result = tiers_to_seconds(tiers, pose.body.fps) + + response = make_response(jsonify(sign=result["SIGN"], sentence=result["SENTENCE"])) + response.headers["Cache-Control"] = f"public, max-age={CACHE_TTL}" + return response + + @app.route("/", methods=['POST']) def pose_segmentation(): body = request.get_json() @@ -49,18 +97,12 @@ def pose_segmentation(): if output_file_path.exists(): return make_response(jsonify(message="Output file already exists", path=body["output"]), 208) - pose_file_path = Path(resolve_path(body["input"])) - if not pose_file_path.exists(): - raise FileNotFoundError("File does not exist") - - with pose_file_path.open("rb") as f: - pose = Pose.read(f) + pose = load_pose(body["input"]) if len(pose.body.data) == 1: - # segment_pose would error on a single-frame pose return make_response(jsonify(message="Pose has only one frame, no segmentation needed", path=body["output"]), 200) - eaf, tiers = segment_pose(pose) + eaf, _tiers = segment_pose(pose) output_file_path.parent.mkdir(parents=True, exist_ok=True) print("Saving .eaf to disk ...")