forked from belalanne/tabpfn-cloud-function
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path main.py
More file actions
144 lines (120 loc) · 4.85 KB
/
main.py
File metadata and controls
144 lines (120 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import functions_framework
import os
from datetime import datetime
import json
import logging
from dotenv import load_dotenv
from predictor import TransactionPredictor
from google.cloud import storage
from google.api_core import retry
# Configure logging: one module-level logger, timestamped records to stdout
# (Cloud Functions captures stdout/stderr into Cloud Logging).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Load environment variables from a local .env file if present.
# NOTE: must run before the os.getenv() calls below so .env values are seen.
load_dotenv()
# Constants for GCS
# Bucket holding the serialized model; 'your-bucket-name' is a placeholder
# default — deployments are expected to set GCS_BUCKET explicitly.
GCS_BUCKET = os.getenv('GCS_BUCKET', 'your-bucket-name')
# Path (within the bucket / local FS) where the model artifacts live.
MODEL_PATH = "models/tabpfn-client"
# Global predictor instance — lazily created on first request so cold starts
# that never receive traffic don't pay the model-loading cost.
predictor = None
def initialize_predictor(request_id="init"):
    """Create and initialize the module-level TransactionPredictor singleton.

    No-op if the predictor already exists. Configuration comes from the
    environment: USE_MOCK and USE_GCS are treated as True only when their
    value is the string 'true' (case-insensitive); anything else — including
    unset — is False.

    Args:
        request_id: Correlation id included in log lines (default "init").

    Raises:
        Exception: re-raises whatever TransactionPredictor construction or
            its initialize() call raises, after logging the error.
    """
    global predictor
    if predictor is not None:
        return  # already initialized; nothing to do
    try:
        raw_use_mock = os.getenv('USE_MOCK', '')
        raw_use_gcs = os.getenv('USE_GCS', '')
        # Log the actual environment variables for debugging
        logger.info(f"[{request_id}] Environment variables - USE_MOCK: '{raw_use_mock}', USE_GCS: '{raw_use_gcs}'")
        use_mock = raw_use_mock.lower() == 'true'
        use_gcs = raw_use_gcs.lower() == 'true'
        logger.info(f"[{request_id}] Using mock: {use_mock}, Using GCS: {use_gcs}")
        # Build into a local first and publish to the global only after
        # initialize() succeeds. (Previously the global was assigned before
        # initialize(); if initialize() raised, later requests saw a non-None
        # predictor and skipped re-initialization, then used a broken
        # instance.)
        candidate = TransactionPredictor(
            model_dir=MODEL_PATH,
            use_mock=use_mock,
            use_gcs=use_gcs,
            gcs_bucket=GCS_BUCKET
        )
        candidate.initialize()
        predictor = candidate
        logger.info("Predictor initialization completed successfully")
    except Exception as e:
        logger.error(f"Failed to initialize predictor: {str(e)}")
        raise
def _json_response(payload, status, headers):
    """Serialize *payload* to JSON and return the (body, status, headers) tuple
    shape that functions_framework accepts as an HTTP response."""
    return (json.dumps(payload), status, headers)


@functions_framework.http
def infer_category(request):
    """HTTP Cloud Function to infer transaction category.

    Expects a POST with JSON body {"transactions": [...]}; returns a JSON
    object with `success`, `results`, `request_id`, and `mode` on success,
    or `error`/`success`/`request_id` with a 400 (client error) or 500
    (server error) status on failure. Handles CORS preflight (OPTIONS).

    Args:
        request: flask.Request provided by functions_framework.

    Returns:
        A (body, status, headers) tuple.
    """
    # Timestamp-based correlation id threaded through all log lines.
    # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+; the
    # format is preserved here — consider datetime.now(timezone.utc) later.
    request_id = datetime.utcnow().strftime('%Y%m%d_%H%M%S_%f')
    logger.info(f"Processing request {request_id}")
    # Set CORS headers for the preflight request
    if request.method == 'OPTIONS':
        headers = {
            'Access-Control-Allow-Origin': '*',
            'Access-Control-Allow-Methods': 'POST',
            'Access-Control-Allow-Headers': 'Content-Type',
            'Access-Control-Max-Age': '3600'
        }
        return ('', 204, headers)
    # Set CORS headers for the main request
    headers = {
        'Access-Control-Allow-Origin': '*',
        'Content-Type': 'application/json'
    }
    try:
        # Initialize predictor if needed (lazy, once per instance).
        if predictor is None:
            logger.info(f"[{request_id}] Initializing predictor...")
            initialize_predictor(request_id)
        # silent=True returns None for malformed/absent JSON instead of
        # raising, so bad request bodies take the 400 path below rather
        # than falling into the generic 500 handler.
        request_json = request.get_json(silent=True)
        if not request_json:
            logger.warning(f"[{request_id}] No JSON data in request")
            return _json_response({
                'error': 'No JSON data provided',
                'success': False,
                'request_id': request_id
            }, 400, headers)
        if 'transactions' not in request_json:
            logger.warning(f"[{request_id}] No transactions in request data")
            return _json_response({
                'error': 'No transactions provided',
                'success': False,
                'request_id': request_id
            }, 400, headers)
        transactions = request_json['transactions']
        if not transactions:
            logger.warning(f"[{request_id}] Empty transactions list")
            return _json_response({
                'error': 'Empty transactions list',
                'success': False,
                'request_id': request_id
            }, 400, headers)
        logger.info(f"[{request_id}] Processing {len(transactions)} transactions")
        # Get predictions; prediction errors are reported separately from
        # request-handling errors so the log distinguishes the two.
        try:
            results = predictor.predict(transactions)
            response_data = {
                'success': True,
                'results': results,
                'request_id': request_id,
                'mode': 'mock' if predictor.use_mock else 'smart-categories'
            }
            logger.info(f"[{request_id}] Successfully processed {len(results)} transactions")
            return _json_response(response_data, 200, headers)
        except Exception as e:
            logger.error(f"[{request_id}] Error during prediction: {str(e)}")
            return _json_response({
                'error': str(e),
                'success': False,
                'request_id': request_id
            }, 500, headers)
    except Exception as e:
        # Last-resort boundary: initialization failures and anything else
        # unexpected become a 500 with the error echoed to the caller.
        logger.error(f"[{request_id}] Error in infer_category: {str(e)}")
        return _json_response({
            'error': str(e),
            'success': False,
            'request_id': request_id
        }, 500, headers)