diff --git a/Dockerfile b/Dockerfile index 728a1a9..9ddd2f7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,8 @@ ARG COG_VERSION FROM r8.im/${COG_REPO}/${COG_MODEL}@sha256:${COG_VERSION} +ENV REQUEST_TIMEOUT=600 + # Install necessary packages and Python 3.10 RUN apt-get update && apt-get upgrade -y && \ apt-get install -y --no-install-recommends software-properties-common curl git openssh-server && \ diff --git a/src/handler.py b/src/handler.py index 669f2e0..4a73942 100644 --- a/src/handler.py +++ b/src/handler.py @@ -1,15 +1,18 @@ import time import subprocess +import os import runpod import requests from requests.adapters import HTTPAdapter, Retry +TIMEOUT = int(os.environ.get("_RUNPOD_REQUEST_TIMEOUT", os.environ.get("REQUEST_TIMEOUT", "600"))) + LOCAL_URL = "http://127.0.0.1:5000" cog_session = requests.Session() retries = Retry(total=10, backoff_factor=0.1, status_forcelist=[502, 503, 504]) -cog_session.mount('http://', HTTPAdapter(max_retries=retries)) +cog_session.mount("http://", HTTPAdapter(max_retries=retries)) # ----------------------------- Start API Service ---------------------------- # @@ -21,9 +24,9 @@ # Automatic Functions # # ---------------------------------------------------------------------------- # def wait_for_service(url): - ''' + """ Check if the service is ready to receive requests. - ''' + """ while True: try: health = requests.get(url, timeout=120) @@ -42,11 +45,18 @@ def wait_for_service(url): def run_inference(inference_request): - ''' + """ Run inference on a request. 
- ''' - response = cog_session.post(url=f'{LOCAL_URL}/predictions', - json=inference_request, timeout=600) + """ + response = cog_session.post( + url=f"{LOCAL_URL}/predictions", + json=inference_request, + timeout=TIMEOUT, + ) + if response.status_code != 200: + print(response.status_code) + print(response.text) + # non-200 body already logged above; caller still receives the parsed JSON return response.json() @@ -54,9 +64,9 @@ def run_inference(inference_request): # RunPod Handler # # ---------------------------------------------------------------------------- # def handler(event): - ''' + """ This is the handler function that will be called by the serverless. - ''' + """ json = run_inference({"input": event["input"]}) @@ -64,8 +74,9 @@ def handler(event): if __name__ == "__main__": - wait_for_service(url=f'{LOCAL_URL}/health-check') + wait_for_service(url=f"{LOCAL_URL}/health-check") print("Cog API Service is ready. Starting RunPod serverless handler...") + print(f"Using request timeout of {TIMEOUT}s") runpod.serverless.start({"handler": handler})