-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
160 lines (136 loc) · 7.85 KB
/
app.py
File metadata and controls
160 lines (136 loc) · 7.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import logging
import os
import sys

import requests  # To call the inference endpoint
from flask import Flask, request, jsonify, render_template, url_for
from flask.logging import default_handler
app = Flask(__name__)

# Configure logging: app.logger writes INFO+ records to stdout with timestamps.
# Flask may attach its own `default_handler` (wsgi.errors / stderr) when
# app.logger is first accessed and no handler is configured yet. Remove it so
# each record is emitted once, via the stdout handler below, instead of twice.
app.logger.setLevel(logging.INFO)
app.logger.removeHandler(default_handler)  # no-op if Flask did not attach it
log_handler = logging.StreamHandler(sys.stdout)
log_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
app.logger.addHandler(log_handler)
# --- Configuration for your Hugging Face Inference Endpoint ---
# Both values are read once, at import time, from environment variables.
INFERENCE_ENDPOINT_URL = os.environ.get("INFERENCE_ENDPOINT_URL")
INFERENCE_ENDPOINT_API_KEY = os.environ.get("INFERENCE_ENDPOINT_API_KEY")  # Your HF API Token

# Warn at startup about missing configuration. The adjacent string literals
# inside each call are concatenated implicitly; no operator belongs between
# them (a "/" there would try to divide two str objects and raise TypeError).
if not INFERENCE_ENDPOINT_URL:
    app.logger.warning("INFERENCE_ENDPOINT_URL environment variable is not set. "
                       "The /generate endpoint will likely fail.")
if not INFERENCE_ENDPOINT_API_KEY:
    app.logger.warning("INFERENCE_ENDPOINT_API_KEY environment variable is not set. "
                       "Requests to a secured HF Inference Endpoint may fail.")
# Pathology labels handed to the index.html template (same list as the original
# app.py and handler.py). Keep it in sync with the labels the handler.py
# deployed on the HF endpoint accepts.
CHEXPERT_CLASSES = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
@app.route('/')
def home():
    """Render index.html, exposing the CheXpert label list as `chexpert_classes`."""
    template_context = {"chexpert_classes": CHEXPERT_CLASSES}
    return render_template('index.html', **template_context)
def _read_numeric_param(data, key, default, cast):
    """Return ``cast(data[key])``, using ``default`` when the key is absent, null or "".

    Treating ``None`` and ``""`` as "not provided" lets an empty field fall back
    to the default instead of failing the whole request.

    Args:
        data: The decoded JSON request body (a dict).
        key: Name of the parameter to read.
        default: Value used when the parameter is missing or blank.
        cast: Conversion callable, e.g. ``int`` or ``float``.

    Raises:
        ValueError: If a supplied value cannot be converted by ``cast``.
    """
    value = data.get(key)
    if value is None or value == "":
        value = default
    try:
        return cast(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Invalid value for '{key}': {value!r}") from err


@app.route('/generate', methods=['POST'])
def generate():
    """
    Receive parameters from the frontend, build a payload, and call the
    external Hugging Face Inference Endpoint.

    The request body must be a non-empty JSON object. Recognised keys (all
    optional): model_variant, pipeline_type, conditions, guidance_scale, seed,
    num_inference_steps, p_mask.

    Returns a JSON response:
        200 ``{"image_base64", "model_variant_used", "pipeline_used"}`` on success;
        400 for a missing/invalid JSON body or a non-numeric numeric parameter;
        504 when the endpoint call times out;
        the endpoint's own status code when it answers with an HTTP error;
        500 when the app is unconfigured, the endpoint reply has no image or is
        not JSON, a transport error occurs, or anything unexpected happens.
    """
    if not INFERENCE_ENDPOINT_URL:
        app.logger.error("INFERENCE_ENDPOINT_URL is not configured.")
        return jsonify({"error": "Application is not configured to call the inference service."}), 500

    # silent=True makes get_json() return None for a malformed body or a wrong
    # Content-Type instead of aborting with Flask's own non-JSON error page, so
    # the check below really produces the JSON 400. Non-dict JSON (e.g. a list)
    # is rejected too, since the code below calls data.get().
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        app.logger.warning("Received empty or invalid JSON payload for /generate.")
        return jsonify({"error": "Invalid JSON payload"}), 400
    app.logger.info(f"Received request for /generate with data: {data}")

    # --- Parameters from the frontend, matching index.html and handler.py expectations ---
    model_variant = data.get('model_variant', 'conditional')
    # Default pipeline_type to model_variant if not provided or empty by frontend
    pipeline_type = data.get('pipeline_type') or model_variant
    conditions = data.get('conditions', ["No Finding"])
    try:
        guidance_scale = _read_numeric_param(data, 'guidance_scale', 3.0, float)
        seed = _read_numeric_param(data, 'seed', 42, int)
        num_inference_steps = _read_numeric_param(data, 'num_inference_steps', 50, int)
        p_mask = _read_numeric_param(data, 'p_mask', 0.9, float)  # For ambient pipeline, if applicable
    except ValueError as err:
        app.logger.warning(f"Rejected /generate request: {err}")
        return jsonify({"error": str(err)}), 400

    # --- Prepare payload for your Hugging Face Inference Endpoint ---
    # This structure should match what handler.py (running on HF) expects in its
    # __call__ method: an "inputs" dictionary.
    payload = {
        "inputs": {
            "model_variant": model_variant,
            "pipeline_type": pipeline_type,
            "conditions": conditions,
            "guidance_scale": guidance_scale,
            "seed": seed,
            "num_inference_steps": num_inference_steps,
            "p_mask": p_mask,
            # Add any other parameters your HF endpoint's handler.py might require
        }
    }
    app.logger.info(f"Prepared payload for HF Inference Endpoint: {payload}")

    headers = {"Content-Type": "application/json"}  # Standard for HF Inference API
    if INFERENCE_ENDPOINT_API_KEY:
        headers["Authorization"] = f"Bearer {INFERENCE_ENDPOINT_API_KEY}"

    try:
        app.logger.info(f"Sending request to HF Inference Endpoint: {INFERENCE_ENDPOINT_URL}")
        response = requests.post(INFERENCE_ENDPOINT_URL, json=payload, headers=headers, timeout=120)  # 2 min timeout
        response.raise_for_status()  # Raise HTTPError for 4xx/5xx replies

        # handler.py is expected to return {"generated_image_base64": ..., ...}.
        try:
            result_data = response.json()
        except ValueError:
            # requests' JSON decode error subclasses ValueError in all versions.
            app.logger.error(f"Non-JSON response from inference service: {response.text[:200]}")
            return jsonify({"error": "Inference service returned a non-JSON response. Check endpoint logs."}), 500

        # Some HF endpoints wrap the result in a one-item list.
        if isinstance(result_data, list) and result_data:
            result_data = result_data[0]

        image_base64 = result_data.get("generated_image_base64") if isinstance(result_data, dict) else None
        if not image_base64:
            app.logger.error(f"No 'generated_image_base64' in response from inference service. Response: {result_data}")
            return jsonify({"error": "No image data in response from inference service. Check endpoint logs."}), 500

        # Echo back what the endpoint reports it used, falling back to the request values.
        model_variant_used = result_data.get("model_variant_used", model_variant)
        pipeline_used = result_data.get("pipeline_used", pipeline_type)

        app.logger.info("Successfully received response from HF Inference Endpoint.")
        return jsonify({
            "image_base64": image_base64,
            "model_variant_used": model_variant_used,
            "pipeline_used": pipeline_used
        })

    except requests.exceptions.Timeout:
        app.logger.error(f"Timeout when calling inference endpoint: {INFERENCE_ENDPOINT_URL}")
        return jsonify({"error": "Request to inference service timed out."}), 504  # Gateway Timeout
    except requests.exceptions.HTTPError as e:
        app.logger.error(f"HTTP error calling inference endpoint: {e.response.status_code} - {e.response.text}")
        error_detail = f"Inference service returned HTTP error {e.response.status_code}."
        try:
            # Try to surface a more specific message from the endpoint's error body.
            error_payload = e.response.json()
        except ValueError:  # Not JSON
            error_detail += f" Response: {e.response.text[:200]}"  # First 200 chars
        else:
            # Only a dict body can carry an "error" key we can index safely.
            if isinstance(error_payload, dict) and "error" in error_payload:
                error_detail += f" Message: {error_payload['error']}"
        return jsonify({"error": error_detail}), e.response.status_code
    except requests.exceptions.RequestException as e:
        app.logger.error(f"Error calling inference endpoint: {e}")
        return jsonify({"error": f"Error calling inference service: {str(e)}"}), 500
    except Exception as e:
        # Last-resort boundary: log the traceback, return a generic JSON error.
        app.logger.error(f"An unexpected error occurred: {e}", exc_info=True)
        return jsonify({"error": "An unexpected server error occurred."}), 500
if __name__ == '__main__':
    # Flask loads templates from <app.root_path>/<app.template_folder> (the
    # directory containing this module), not from the current working
    # directory. Create the placeholder there so it is found no matter where
    # the script is launched from.
    templates_dir = os.path.join(app.root_path, app.template_folder)
    if not os.path.isdir(templates_dir):
        os.makedirs(templates_dir, exist_ok=True)
        app.logger.info("Created 'templates' directory.")

    # Create a placeholder index.html if none exists so render_template() in
    # home() does not fail on startup; the real frontend should replace it.
    dummy_index_path = os.path.join(templates_dir, "index.html")
    if not os.path.exists(dummy_index_path):
        with open(dummy_index_path, "w", encoding="utf-8") as f:
            f.write("<h1>ML Model Demo (Basic Placeholder)</h1><p>This is a placeholder index.html. Replace it with the full frontend code.</p>")
        app.logger.info(f"Created dummy '{dummy_index_path}'. Please replace it with the full UI.")

    port = int(os.environ.get('PORT', 8080))
    # Debug mode is opt-in (FLASK_DEBUG=true/1/t). The server binds to all
    # interfaces, and the Werkzeug debugger allows executing arbitrary code
    # from the browser, so it must not be enabled by default.
    debug_mode = os.environ.get('FLASK_DEBUG', 'False').lower() in ['true', '1', 't']
    app.logger.info(f"Starting Flask app on host 0.0.0.0, port {port} with debug mode {'on' if debug_mode else 'off'}")
    app.run(debug=debug_mode, host='0.0.0.0', port=port)