Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ server = [
"Flask",
"Werkzeug",
"gunicorn",
"Flask-Cors",
"Flask-Compress",
"Flask-Caching",
]

[tool.ruff]
Expand Down
58 changes: 50 additions & 8 deletions sign_language_segmentation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,42 @@
from pathlib import Path

from flask import Flask, request, abort, make_response, jsonify
from flask_caching import Cache
from flask_compress import Compress
from flask_cors import CORS
from pose_format import Pose

from sign_language_segmentation.bin import segment_pose

app = Flask(__name__)
compress = Compress(app)
cache = Cache(app, config={"CACHE_TYPE": "SimpleCache"})
CORS(app, resources={r"/": {"methods": ["GET", "OPTIONS"]}})

CACHE_TTL = 86400 # 1 day in seconds


def resolve_path(uri: str):
# Map gs:// URIs to the gcsfuse mount point, or return as-is
return uri.replace("gs://", "/mnt/")


def load_pose(uri: str) -> Pose:
pose_file_path = Path(resolve_path(uri))
if not pose_file_path.exists():
raise FileNotFoundError(f"File does not exist: {uri}")
with pose_file_path.open("rb") as f:
return Pose.read(f)


def tiers_to_seconds(tiers: dict, fps: float) -> dict:
return {
tier: [{"start": round(seg["start"] / fps, 3), "end": round(seg["end"] / fps, 3)}
for seg in segments]
for tier, segments in tiers.items()
}


@app.errorhandler(Exception)
def handle_exception(e):
print("Exception", e)
Expand All @@ -38,6 +62,30 @@ def health_check():
return make_response(jsonify(body), 200)


@app.route("/", methods=['GET', 'OPTIONS'])
@compress.compressed()
@cache.cached(timeout=CACHE_TTL, query_string=True)
def get_segments():
if request.method == 'OPTIONS':
return make_response("", 204)

pose_uri = request.args.get("pose")
if not pose_uri:
abort(make_response(jsonify(message="Missing `pose` query parameter"), 400))

pose = load_pose(pose_uri)

if len(pose.body.data) == 1:
return jsonify(sign=[], sentence=[])

_eaf, tiers = segment_pose(pose)
result = tiers_to_seconds(tiers, pose.body.fps)

response = make_response(jsonify(sign=result["SIGN"], sentence=result["SENTENCE"]))
response.headers["Cache-Control"] = f"public, max-age={CACHE_TTL}"
return response


@app.route("/", methods=['POST'])
def pose_segmentation():
body = request.get_json()
Expand All @@ -49,18 +97,12 @@ def pose_segmentation():
if output_file_path.exists():
return make_response(jsonify(message="Output file already exists", path=body["output"]), 208)

pose_file_path = Path(resolve_path(body["input"]))
if not pose_file_path.exists():
raise FileNotFoundError("File does not exist")

with pose_file_path.open("rb") as f:
pose = Pose.read(f)
pose = load_pose(body["input"])

if len(pose.body.data) == 1:
# segment_pose would error on a single-frame pose
return make_response(jsonify(message="Pose has only one frame, no segmentation needed", path=body["output"]), 200)

eaf, tiers = segment_pose(pose)
eaf, _tiers = segment_pose(pose)

output_file_path.parent.mkdir(parents=True, exist_ok=True)
print("Saving .eaf to disk ...")
Expand Down
Loading