From a22e475a7b98492a1e653f440587ad52bf0b8701 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 14:03:51 -0500 Subject: [PATCH 01/34] Fix prediction pipeline robustness and checkpoint compatibility --- catpred/data/cache_utils.py | 9 ++- catpred/data/esm_utils.py | 7 +- catpred/data/utils.py | 76 ++++++++++++++++++---- catpred/models/model.py | 18 ++++-- catpred/utils.py | 20 ++++-- demo_run.py | 122 +++++++++++++++++++++++++++-------- predict.sh | 12 ++-- scripts/create_pdbrecords.py | 90 ++++++++++++++++++-------- 8 files changed, 271 insertions(+), 83 deletions(-) diff --git a/catpred/data/cache_utils.py b/catpred/data/cache_utils.py index 7ee9ede..90cac3a 100644 --- a/catpred/data/cache_utils.py +++ b/catpred/data/cache_utils.py @@ -27,6 +27,13 @@ def md5_hash_fn(s): encoded = s.encode('utf-8') return hashlib.md5(encoded).hexdigest() + +def _torch_load_compat(path): + try: + return torch.load(path, weights_only=False) + except TypeError: + return torch.load(path) + # run once function GLOBAL_RUN_RECORDS = dict() @@ -92,7 +99,7 @@ def inner(t, *args, __cache_key = None, **kwargs): if entry_path.exists(): log(f'cache hit: fetching {t} from {str(entry_path)}') - return torch.load(str(entry_path)) + return _torch_load_compat(str(entry_path)) out = fn(t, *args, **kwargs) diff --git a/catpred/data/esm_utils.py b/catpred/data/esm_utils.py index f45bd65..690c106 100644 --- a/catpred/data/esm_utils.py +++ b/catpred/data/esm_utils.py @@ -157,8 +157,11 @@ def get_single_esm_repr(protein_str): def get_esm_repr(proteins, name, device): if isinstance(proteins, torch.Tensor): proteins = tensor_to_aa_str(proteins) - - get_protein_repr_fn = cache_fn(get_single_esm_repr, path = 'esm/proteins', name = name) + + # Cache by sequence content to avoid collisions when different proteins + # are accidentally given the same pdb/name identifier. 
+ _ = name + get_protein_repr_fn = cache_fn(get_single_esm_repr, path = 'esm/proteins') return calc_protein_representations_with_subunits([proteins], get_protein_repr_fn, device = device) diff --git a/catpred/data/utils.py b/catpred/data/utils.py index a2baaee..341791a 100644 --- a/catpred/data/utils.py +++ b/catpred/data/utils.py @@ -27,6 +27,33 @@ # Increase maximum size of field in the csv processing for the current architecture csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2)) + +def _load_protein_records(protein_records_path: str): + """ + Load protein records JSON from gzip, plain JSON, or accidentally double-gzipped files. + """ + try: + with gzip.open(protein_records_path, "rt", encoding="utf-8") as handle: + return json.load(handle) + except Exception: + try: + with gzip.open(protein_records_path, "rb") as handle: + payload = handle.read() + if payload[:2] == b"\x1f\x8b": + payload = gzip.decompress(payload) + return json.loads(payload.decode("utf-8")) + except Exception: + pass + + try: + with open(protein_records_path, "rt", encoding="utf-8") as handle: + return json.load(handle) + except Exception as plain_err: + raise ValueError( + f'Failed to load protein records from "{protein_records_path}". ' + "Expected a JSON mapping in .json or .json.gz format." + ) from plain_err + def get_header(path: str) -> List[str]: """ Returns the header of a data CSV file. 
@@ -400,8 +427,7 @@ def get_data(path: str, if protein_records_path is None: protein_records = None else: - with gzip.open(protein_records_path, "rt", encoding="utf-8") as f: - protein_records = json.load(f) + protein_records = _load_protein_records(protein_records_path) if args is not None: # Prefer explicit function arguments but default to args if not provided @@ -498,16 +524,42 @@ def get_data(path: str, if args.smoke_test: if smoke_test_counter>100: break smiles = [row[c] for c in smiles_columns] - pdbpath = row['pdbpath'] - pdbname = pdbpath.split("/")[-1] - protein_record = protein_records[pdbname] - if not protein_records is None: - sequence_features, _ = sequence_feat_getter(protein_record['seq'], - name = pdbname, - device = 'cpu') - protein_record['esm2_feats'] = sequence_features[0] #batch dim - else: - protein_record = None + protein_record = None + + if protein_records is not None: + if 'pdbpath' not in row: + raise ValueError( + f'Missing required "pdbpath" column in {path} at row {i + 2}.' + ) + + pdbpath = (row['pdbpath'] or '').strip() + if pdbpath == '': + raise ValueError( + f'Empty pdbpath found in {path} at row {i + 2}.' + ) + + pdbname = os.path.basename(pdbpath) + protein_record = protein_records.get(pdbname) or protein_records.get(pdbpath) + if protein_record is None: + raise KeyError( + f'No protein record found for pdbpath "{pdbpath}" (basename "{pdbname}") ' + f'in row {i + 2}. Ensure create_pdbrecords.py was run on the same input file.' + ) + + row_sequence = row.get('sequence') + if row_sequence is not None and protein_record.get('seq') != row_sequence: + raise ValueError( + f'Sequence mismatch for pdbpath "{pdbpath}" in row {i + 2}. ' + 'This usually means multiple sequences share the same pdbpath identifier.' 
+ ) + + if 'esm2_feats' not in protein_record: + sequence_features, _ = sequence_feat_getter( + protein_record['seq'], + name=pdbname, + device='cpu' + ) + protein_record['esm2_feats'] = sequence_features[0] # batch dim targets, atom_targets, bond_targets = [], [], [] for column in target_columns: diff --git a/catpred/models/model.py b/catpred/models/model.py index 58dabe6..d9e8a88 100644 --- a/catpred/models/model.py +++ b/catpred/models/model.py @@ -24,6 +24,14 @@ from torch import nn from torch import einsum + +def _torch_load_compat(path): + try: + return torch.load(path, weights_only=False) + except TypeError: + return torch.load(path) + + def exists(val): return val is not None @@ -131,7 +139,7 @@ def create_protein_model(self, args: TrainArgs) -> None: self.seq_embedder = nn.Embedding(21, args.seq_embed_dim, padding_idx=20) #last index is for padding if self.args.add_pretrained_egnn_feats: - self.pretrained_egnn_feats_dict = torch.load(self.args.pretrained_egnn_feats_path) + self.pretrained_egnn_feats_dict = _torch_load_compat(self.args.pretrained_egnn_feats_path) x = list(self.pretrained_egnn_feats_dict.values()) self.pretrained_egnn_feats_avg = torch.stack(x).mean(dim=0) @@ -417,8 +425,10 @@ def seq_to_tensor(seq): esm_feature_arr = [each['esm2_feats'] for each in protein_records] esm_feature_arr = pad_sequence(esm_feature_arr, batch_first=True).to(self.device) - if seq_arr.shape[1]!=esm_feature_arr.shape[1]: - seq_arr = seq_arr[:,:esm_feature_arr.shape[1]:] + if seq_arr.shape[1] != esm_feature_arr.shape[1]: + common_len = min(seq_arr.shape[1], esm_feature_arr.shape[1]) + seq_arr = seq_arr[:, :common_len] + esm_feature_arr = esm_feature_arr[:, :common_len] # project sequence to embed dim seq_outs = self.seq_embedder(seq_arr) @@ -522,4 +532,4 @@ def seq_to_tensor(seq): else: output = nn.functional.softplus(output) + 1 - return output \ No newline at end of file + return output diff --git a/catpred/utils.py b/catpred/utils.py index 7aaa9cf..ecde744 100644 
--- a/catpred/utils.py +++ b/catpred/utils.py @@ -25,6 +25,18 @@ from catpred.models.ffn import MultiReadout +def _torch_load_compat(path, map_location=None): + """ + Explicitly disable weights-only loading for backward compatibility with + CatPred checkpoints that store non-tensor objects (e.g., argparse.Namespace). + """ + try: + return torch.load(path, map_location=map_location, weights_only=False) + except TypeError: + # For older torch versions that do not support weights_only. + return torch.load(path, map_location=map_location) + + def makedirs(path: str, isfile: bool = False) -> None: """ Creates a directory given a path to either a directory or file. @@ -110,7 +122,7 @@ def load_checkpoint( debug = info = print # Load model and args - state = torch.load(path, map_location=lambda storage, loc: storage) + state = _torch_load_compat(path, map_location=lambda storage, loc: storage) args = TrainArgs() args.from_dict(vars(state["args"]), skip_unsettable=True) if not pretrained_egnn_feats_path is None: @@ -221,7 +233,7 @@ def load_frzn_model( """ debug = logger.debug if logger is not None else print - loaded_mpnn_model = torch.load(path, map_location=lambda storage, loc: storage) + loaded_mpnn_model = _torch_load_compat(path, map_location=lambda storage, loc: storage) loaded_state_dict = loaded_mpnn_model["state_dict"] loaded_args = loaded_mpnn_model["args"] @@ -443,7 +455,7 @@ def load_scalers( :return: A tuple with the data :class:`~catpred.data.scaler.StandardScaler` and features :class:`~catpred.data.scaler.StandardScaler`. 
""" - state = torch.load(path, map_location=lambda storage, loc: storage) + state = _torch_load_compat(path, map_location=lambda storage, loc: storage) if state["data_scaler"] is not None: scaler = StandardScaler(state["data_scaler"]["means"], state["data_scaler"]["stds"]) @@ -498,7 +510,7 @@ def load_args(path: str) -> TrainArgs: """ args = TrainArgs() args.from_dict( - vars(torch.load(path, map_location=lambda storage, loc: storage)["args"]), + vars(_torch_load_compat(path, map_location=lambda storage, loc: storage)["args"]), skip_unsettable=True, ) diff --git a/demo_run.py b/demo_run.py index 8a78739..5b77e38 100644 --- a/demo_run.py +++ b/demo_run.py @@ -13,6 +13,7 @@ import time import os +import subprocess import pandas as pd import numpy as np from IPython.display import Image, display @@ -20,8 +21,30 @@ from IPython.display import display, Latex, Math import argparse -def create_csv_sh(parameter, input_file_path, checkpoint_dir): + +def prepare_prediction_inputs(parameter, input_file_path): df = pd.read_csv(input_file_path) + required_columns = {"SMILES", "sequence", "pdbpath"} + missing_columns = required_columns.difference(df.columns) + if missing_columns: + print( + f'Missing required column(s) in input file: {", ".join(sorted(missing_columns))}.' + ) + return None + + conflicting_pdbpaths = ( + df.groupby("pdbpath")["sequence"] + .nunique(dropna=False) + .loc[lambda value: value > 1] + ) + if len(conflicting_pdbpaths) > 0: + preview = ", ".join(conflicting_pdbpaths.index.astype(str).tolist()[:5]) + print( + "Found pdbpath values mapped to multiple sequences. " + f"Each unique sequence must have a unique pdbpath. Examples: {preview}" + ) + return None + smiles_list = df.SMILES seq_list = df.sequence smiles_list_new = [] @@ -45,21 +68,17 @@ def create_csv_sh(parameter, input_file_path, checkpoint_dir): print('Correct your input! 
Exiting..') return None - input_file_new_path = f'{input_file_path[:-4]}_input.csv' + input_file_base, _ = os.path.splitext(input_file_path) + input_file_new_path = f'{input_file_base}_input.csv' df['SMILES'] = smiles_list_new - df.to_csv(input_file_new_path) - - with open('predict.sh', 'w') as f: - f.write(f''' - TEST_FILE_PREFIX={input_file_new_path[:-4]} - RECORDS_FILE=${{TEST_FILE_PREFIX}}.json - CHECKPOINT_DIR={checkpoint_dir} - - python ./scripts/create_pdbrecords.py --data_file ${{TEST_FILE_PREFIX}}.csv --out_file ${{RECORDS_FILE}} - python predict.py --test_path ${{TEST_FILE_PREFIX}}.csv --preds_path ${{TEST_FILE_PREFIX}}_output.csv --checkpoint_dir $CHECKPOINT_DIR --uncertainty_method mve --smiles_column SMILES --individual_ensemble_predictions --protein_records_path $RECORDS_FILE - ''') + df.to_csv(input_file_new_path, index=False) - return input_file_new_path[:-4]+'_output.csv' + test_file_prefix = input_file_new_path[:-4] + return { + "input_csv": input_file_new_path, + "records_file": f"{test_file_prefix}.json.gz", + "output_csv": f"{test_file_prefix}_output.csv", + } def get_predictions(parameter, outfile): """ @@ -86,6 +105,12 @@ def get_predictions(parameter, outfile): unc_col = f'{target_col}_mve_uncal_var' + missing_cols = [col for col in [target_col, unc_col] if col not in df.columns] + if missing_cols: + raise ValueError( + f'Prediction output is missing required column(s): {", ".join(missing_cols)}' + ) + for _, row in df.iterrows(): model_cols = [col for col in row.index if col.startswith(target_col) and 'model_' in col] @@ -93,12 +118,15 @@ def get_predictions(parameter, outfile): prediction = row[target_col] prediction_linear = np.power(10, prediction) - model_outs = np.array([row[col] for col in model_cols]) - epi_unc = np.var(model_outs) - alea_unc = unc - epi_unc + if model_cols: + model_outs = np.array([row[col] for col in model_cols]) + epi_unc = np.var(model_outs) + else: + epi_unc = 0.0 + alea_unc = max(unc - epi_unc, 0.0) epi_unc = 
np.sqrt(epi_unc) alea_unc = np.sqrt(alea_unc) - unc = np.sqrt(unc) + unc = np.sqrt(max(unc, 0.0)) pred_col.append(prediction_linear) pred_logcol.append(prediction) @@ -115,22 +143,62 @@ def get_predictions(parameter, outfile): return df def main(args): - print(os.getcwd()) - - outfile = create_csv_sh(args.parameter, args.input_file, args.checkpoint_dir) - if outfile is None: + run_paths = prepare_prediction_inputs(args.parameter, args.input_file) + if run_paths is None: return + outfile = run_paths["output_csv"] print('Predicting.. This will take a while..') - if args.use_gpu: - os.system("export PROTEIN_EMBED_USE_CPU=0;./predict.sh") - else: - os.system("export PROTEIN_EMBED_USE_CPU=1;./predict.sh") + env = os.environ.copy() + env["PROTEIN_EMBED_USE_CPU"] = "0" if args.use_gpu else "1" + + create_records_cmd = [ + "python", + "./scripts/create_pdbrecords.py", + "--data_file", + run_paths["input_csv"], + "--out_file", + run_paths["records_file"], + ] + predict_cmd = [ + "python", + "predict.py", + "--test_path", + run_paths["input_csv"], + "--preds_path", + outfile, + "--checkpoint_dir", + args.checkpoint_dir, + "--uncertainty_method", + "mve", + "--smiles_column", + "SMILES", + "--individual_ensemble_predictions", + "--protein_records_path", + run_paths["records_file"], + ] + + create_records_result = subprocess.run(create_records_cmd, env=env) + if create_records_result.returncode != 0: + print( + f"Protein record generation failed with exit code {create_records_result.returncode}." 
+ ) + return + + predict_result = subprocess.run(predict_cmd, env=env) + if predict_result.returncode != 0: + print(f"Prediction command failed with exit code {predict_result.returncode}.") + return + + if not os.path.exists(outfile): + print(f'Prediction output file was not generated: {outfile}') + return output_final = get_predictions(args.parameter, outfile) filename = outfile.split('/')[-1] - output_final.to_csv(f'../results/{filename}') + os.makedirs('../results', exist_ok=True) + output_final.to_csv(f'../results/{filename}', index=False) print('Output saved to results/', filename) if __name__ == "__main__": diff --git a/predict.sh b/predict.sh index 2caad26..93d8036 100644 --- a/predict.sh +++ b/predict.sh @@ -1,7 +1,9 @@ -TEST_FILE_PREFIX=./demo/batch_kcat_input -RECORDS_FILE=${TEST_FILE_PREFIX}.json -CHECKPOINT_DIR=../data/pretrained/production/kcat/ +#!/usr/bin/env bash +set -euo pipefail -python ./scripts/create_pdbrecords.py --data_file ${TEST_FILE_PREFIX}.csv --out_file ${RECORDS_FILE}.gz +TEST_FILE_PREFIX="${TEST_FILE_PREFIX:-./demo/batch_kcat_input}" +RECORDS_FILE="${RECORDS_FILE:-${TEST_FILE_PREFIX}.json.gz}" +CHECKPOINT_DIR="${CHECKPOINT_DIR:-../data/pretrained/production/kcat}" -python predict.py --test_path ${TEST_FILE_PREFIX}.csv --preds_path ${TEST_FILE_PREFIX}_output.csv --checkpoint_dir $CHECKPOINT_DIR --uncertainty_method mve --smiles_column SMILES --individual_ensemble_predictions --protein_records_path ${RECORDS_FILE}.gz +python ./scripts/create_pdbrecords.py --data_file "${TEST_FILE_PREFIX}.csv" --out_file "${RECORDS_FILE}" +python predict.py --test_path "${TEST_FILE_PREFIX}.csv" --preds_path "${TEST_FILE_PREFIX}_output.csv" --checkpoint_dir "${CHECKPOINT_DIR}" --uncertainty_method mve --smiles_column SMILES --individual_ensemble_predictions --protein_records_path "${RECORDS_FILE}" diff --git a/scripts/create_pdbrecords.py b/scripts/create_pdbrecords.py index 0360b81..54c2736 100644 --- a/scripts/create_pdbrecords.py +++ 
b/scripts/create_pdbrecords.py @@ -1,42 +1,76 @@ -import pandas as pd import argparse +import gzip import json -def parse_args(): - """Prepare argument parser. +import os - Args: +import pandas as pd - Return: - """ +def parse_args(): parser = argparse.ArgumentParser( - description="Generate json records for test file with only sequences" + description="Generate gzipped JSON protein records from an input CSV file." ) + parser.add_argument("--data_file", required=True, help="Path to CSV file") parser.add_argument( - "--data_file", - help="Path to csv file", + "--out_file", required=True, + help="Output path for gzipped JSON protein records", ) - - parser.add_argument("--out_file", help="output file for json records") + return parser.parse_args() - args = parser.parse_args() - return args -args = parse_args() -df = pd.read_csv(args.data_file) -assert('pdbpath' in df.columns) -assert('sequence' in df.columns) +def _as_clean_str(value): + return value.strip() if isinstance(value, str) else value -import json -dic_full = {} -for ind, row in df.iterrows(): - dic = {} - dic['name'] = row.pdbpath - dic['seq'] = row.sequence - dic_full[row.pdbpath] = dic -import gzip -# Writing the dictionary to a gzipped file -with gzip.open(args.out_file, 'wb') as f: - f.write(json.dumps(dic_full).encode('utf-8')) +def build_records(df: pd.DataFrame, data_file: str) -> dict: + required = {"pdbpath", "sequence"} + missing = required.difference(df.columns) + if missing: + raise ValueError( + f'Missing required column(s) in "{data_file}": {", ".join(sorted(missing))}' + ) + + records = {} + conflicts = [] + + for index, row in df.iterrows(): + row_num = index + 2 # account for header row + pdbpath = _as_clean_str(row["pdbpath"]) + sequence = _as_clean_str(row["sequence"]) + + if not pdbpath: + raise ValueError(f'Empty "pdbpath" in row {row_num} of "{data_file}".') + if not sequence: + raise ValueError(f'Empty "sequence" in row {row_num} of "{data_file}".') + + key = os.path.basename(pdbpath) 
+ existing = records.get(key) + if existing is not None and existing["seq"] != sequence: + conflicts.append((row_num, key)) + continue + + records[key] = {"name": key, "seq": sequence} + + if conflicts: + preview = ", ".join( + [f'{key} (row {row_num})' for row_num, key in conflicts[:5]] + ) + raise ValueError( + "Found pdbpath basenames reused for different sequences. " + f"Each unique sequence must have a unique pdbpath. Examples: {preview}" + ) + + return records + + +def main(): + args = parse_args() + df = pd.read_csv(args.data_file) + records = build_records(df, args.data_file) + with gzip.open(args.out_file, "wt", encoding="utf-8") as handle: + json.dump(records, handle) + + +if __name__ == "__main__": + main() From 00416bb80c0e1f23ef8968ce1b0eaa8d7efbeb25 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 14:04:06 -0500 Subject: [PATCH 02/34] Update setup and usage docs for resilient local workflows --- README.md | 58 ++++++++++++++++++++++++++++++++++++++++++++++++- environment.yml | 1 - 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1933ecd..2a92ed8 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,8 @@ Then proceed to either option below to complete the installation. If installing ```bash mkdir catpred_pipeline catpred_pipeline/results cd catpred_pipeline -wget https://catpred.s3.us-east-1.amazonaws.com/capsule_data_update.tar.gz +wget -c --tries=5 --timeout=30 https://catpred.s3.us-east-1.amazonaws.com/capsule_data_update.tar.gz || \ +wget -c --tries=5 --timeout=30 https://catpred.s3.amazonaws.com/capsule_data_update.tar.gz tar -xzf capsule_data_update.tar.gz git clone https://github.com/maranasgroup/catpred.git cd catpred @@ -74,10 +75,65 @@ conda activate catpred pip install -e . ```` +`stride` is Linux-only and optional for the default demos. 
If needed for your workflow, install it separately on Linux: + +```bash +conda install -c kimlab stride +``` + ### 🔮 Prediction The Jupyter Notebook `batch_demo.ipynb` and the Python script `demo_run.py` show the usage of pre-trained models for prediction. +Input CSV requirements for `demo_run.py` and batch prediction: +- Required columns: `SMILES`, `sequence`, `pdbpath`. +- `pdbpath` must be unique per unique sequence. Reusing the same `pdbpath` for different sequences can produce incorrect cached embeddings. +- Reusing the same `pdbpath` for repeated measurements of the same sequence is supported. + +The helper script used to build protein records is: + +```bash +python ./scripts/create_pdbrecords.py --data_file --out_file +``` + +CatPred currently expects one sequence per row. Multi-protein complexes (e.g., heteromers/homodimers) are not explicitly modeled as separate chains in the default prediction workflow. + +For released benchmark datasets, the number of entries with 3D structure can be smaller than the total sequence/substrate pairs; 3D-derived artifacts are available only for the subset with valid structure mapping. + +### 🧪 Fine-Tuning On Custom Data + +You can fine-tune CatPred on your own regression targets using `train.py`. + +1. Prepare train/val/test CSVs with at least: +- `SMILES` +- `sequence` +- `pdbpath` (unique per unique sequence) +- one numeric target column (for example: `log10kcat_max`) + +2. Build a protein-records file that covers all `pdbpath` values in your splits: + +```bash +python ./scripts/create_pdbrecords.py --data_file --out_file +``` + +3. Train: + +```bash +python train.py \ + --protein_records_path \ + --data_path \ + --separate_val_path \ + --separate_test_path \ + --dataset_type regression \ + --smiles_columns SMILES \ + --target_columns \ + --add_esm_feats \ + --loss_function mve \ + --save_dir +``` + +For working end-to-end examples, see the training commands in scripts such as `scripts/reproduce_figS10_catpred.sh`. 
+ ### 🔄 Reproducing Publication Results We provide three separate ways for reproducing the results of the publication. diff --git a/environment.yml b/environment.yml index 082e55c..890460c 100644 --- a/environment.yml +++ b/environment.yml @@ -20,7 +20,6 @@ dependencies: - faiss-cpu - pytorch-scatter - pyg - - kimlab::stride - pip: - ipdb - fair-esm From 56c15d44f781ec40584fd00d13deed1dc9f1ff2c Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 14:20:00 -0500 Subject: [PATCH 03/34] Phase 0: harden inference runtime and packaging entrypoints --- .gitignore | 12 ++ catpred/data/esm_utils.py | 57 +++++---- catpred/data/utils.py | 1 - catpred/features/__init__.py | 6 +- catpred/features/featurization.py | 28 +++++ catpred/models/model.py | 26 ++-- catpred/train/make_predictions.py | 196 +++++++++++++++--------------- catpred/train/run_training.py | 2 +- catpred/utils.py | 5 +- demo_run.py | 8 +- setup.cfg | 5 - setup.py | 1 + 12 files changed, 194 insertions(+), 153 deletions(-) diff --git a/.gitignore b/.gitignore index b155533..a3b307a 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,15 @@ catpred/data/__pycache__/esm_utils.cpython-312.pyc catpred/data/__pycache__/data.cpython-312.pyc catpred/data/__pycache__/cache_utils.cpython-312.pyc catpred/data/__pycache__/__init__.cpython-312.pyc + +# Generic Python/OS artifacts +__pycache__/ +*.pyo +*.pyd +.DS_Store +*.egg-info/ +.ipynb_checkpoints/ + +# Local validation artifacts +.e2e-assets/ +.e2e-tests/ diff --git a/catpred/data/esm_utils.py b/catpred/data/esm_utils.py index 690c106..359613a 100644 --- a/catpred/data/esm_utils.py +++ b/catpred/data/esm_utils.py @@ -1,12 +1,9 @@ import torch import os -import re -from pathlib import Path from functools import partial import esm from torch.nn.utils.rnn import pad_sequence -from .cache_utils import cache_fn, run_once, md5_hash_fn -# import esm.inverse_folding as esm_if +from .cache_utils import cache_fn, run_once def exists(val): return val is not 
None @@ -20,7 +17,12 @@ def to_device(t, *, device): def cast_tuple(t): return (t,) if not isinstance(t, tuple) else t -PROTEIN_EMBED_USE_CPU = os.getenv('PROTEIN_EMBED_USE_CPU', None) is not None +def _env_flag(name: str, default: str = "0") -> bool: + raw = os.getenv(name, default) + return str(raw).strip().lower() in {"1", "true", "yes", "on"} + + +PROTEIN_EMBED_USE_CPU = _env_flag("PROTEIN_EMBED_USE_CPU", "0") if PROTEIN_EMBED_USE_CPU: print('calculating protein embed only on cpu') @@ -32,9 +34,6 @@ def cast_tuple(t): 'tokenizer': None } -# general helper functions -import ipdb - def calc_protein_representations_with_subunits(proteins, get_repr_fn, *, device): representations = [] for subunits in proteins: @@ -102,8 +101,6 @@ def calc_protein_representations_with_subunits(proteins, get_repr_fn, *, device) } AA_STR_TO_INT_MAP = {v:k for k,v in INT_TO_AA_STR_MAP.items()} -import ipdb - def tensor_to_aa_str(t): str_seqs = [] #ipdb.set_trace() @@ -116,6 +113,7 @@ def tensor_to_aa_str(t): def init_esm(): model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() batch_converter = alphabet.get_batch_converter() + model.eval() if not PROTEIN_EMBED_USE_CPU: model = model.cuda() @@ -124,6 +122,8 @@ def init_esm(): @run_once('init_esm_if') def init_esm_if(): + import esm.inverse_folding as esm_if + model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50() batch_converter = esm_if.util.CoordBatchConverter(alphabet, 2048) @@ -145,7 +145,7 @@ def get_single_esm_repr(protein_str): batch_tokens = batch_tokens[:, :ESM_MAX_LENGTH] if not PROTEIN_EMBED_USE_CPU: - batch_tokens = batch_tokens.cuda() + batch_tokens = batch_tokens.to(next(model.parameters()).device) with torch.no_grad(): results = model(batch_tokens, repr_layers=[33]) @@ -161,20 +161,30 @@ def get_esm_repr(proteins, name, device): # Cache by sequence content to avoid collisions when different proteins # are accidentally given the same pdb/name identifier. 
_ = name - get_protein_repr_fn = cache_fn(get_single_esm_repr, path = 'esm/proteins') + get_protein_repr_fn = cache_fn(get_single_esm_repr, path='esm/proteins') + + return calc_protein_representations_with_subunits([proteins], get_protein_repr_fn, device=device) - return calc_protein_representations_with_subunits([proteins], get_protein_repr_fn, device = device) +def get_coords(pdbpath: str, chain_id: str = "A"): + try: + import esm.inverse_folding as esm_if + except ImportError as exc: + raise ImportError( + "ESM inverse folding is not installed. Install optional esm inverse-folding dependencies " + "to use get_coords()." + ) from exc -def get_coords(pdbpath): - #init_esm_if() - #model, batch_converter = GLOBAL_VARIABLES['esmif_model'] - addpath = '/home/ubuntu/CatPred-DB/CatPred-DB/' - coords = esm_if.util.load_coords(addpath+pdbpath, 'A') - return coords + if not os.path.exists(pdbpath): + raise FileNotFoundError(f'PDB file not found: "{pdbpath}"') + + return esm_if.util.load_coords(pdbpath, chain_id) def get_esm_tokens(protein_str, device): if isinstance(protein_str, torch.Tensor): - proteins = tensor_to_aa_str(proteins) + protein_str = tensor_to_aa_str(protein_str) + if len(protein_str) != 1: + raise ValueError("get_esm_tokens expects a single protein sequence.") + protein_str = protein_str[0] init_esm() model, batch_converter = GLOBAL_VARIABLES['model'] @@ -187,8 +197,8 @@ def get_esm_tokens(protein_str, device): batch_tokens = batch_tokens[:, :ESM_MAX_LENGTH] - if device!='cpu': - batch_tokens = batch_tokens.cuda() + if device != 'cpu': + batch_tokens = batch_tokens.to(device) return batch_tokens @@ -204,7 +214,8 @@ def get_esm_tokens(protein_str, device): def get_protein_embedder(name): allowed_protein_embedders = list(PROTEIN_REPR_CONFIG.keys()) - assert name in allowed_protein_embedders, f"must be one of {', '.join(allowed_protein_embedders)}" + if name not in allowed_protein_embedders: + raise ValueError(f"Unsupported protein embedder '{name}'. 
Must be one of {', '.join(allowed_protein_embedders)}") config = PROTEIN_REPR_CONFIG[name] return config diff --git a/catpred/data/utils.py b/catpred/data/utils.py index 341791a..34f03be 100644 --- a/catpred/data/utils.py +++ b/catpred/data/utils.py @@ -15,7 +15,6 @@ import numpy as np import pandas as pd from tqdm import tqdm -import ipdb from .esm_utils import get_protein_embedder, get_coords from .data import MoleculeDatapoint, MoleculeDataset, make_mols diff --git a/catpred/features/__init__.py b/catpred/features/__init__.py index 596e02f..94220bf 100644 --- a/catpred/features/__init__.py +++ b/catpred/features/__init__.py @@ -4,7 +4,8 @@ from .featurization import atom_features, bond_features, BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph, \ MolGraph, onek_encoding_unk, set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h, \ - set_adding_hs, set_keeping_atom_map, is_reaction, is_explicit_h, is_adding_hs, is_keeping_atom_map, is_mol, reset_featurization_parameters + set_adding_hs, set_keeping_atom_map, is_reaction, is_explicit_h, is_adding_hs, is_keeping_atom_map, is_mol, \ + reset_featurization_parameters, featurization_session from .utils import load_features, save_features, load_valid_atom_or_bond_features __all__ = [ @@ -37,5 +38,6 @@ 'load_features', 'save_features', 'load_valid_atom_or_bond_features', - 'reset_featurization_parameters' + 'reset_featurization_parameters', + 'featurization_session' ] diff --git a/catpred/features/featurization.py b/catpred/features/featurization.py index 7a9b5d2..b272475 100644 --- a/catpred/features/featurization.py +++ b/catpred/features/featurization.py @@ -1,6 +1,9 @@ from typing import List, Tuple, Union from itertools import zip_longest import logging +from contextlib import contextmanager +from copy import deepcopy +import threading from rdkit import Chem import torch @@ -50,6 +53,31 @@ def __init__(self) -> None: # Create a global parameter object for reference throughout this module 
PARAMS = Featurization_parameters() +_FEATURIZATION_LOCK = threading.RLock() + + +def _clone_featurization_parameters(params: Featurization_parameters) -> Featurization_parameters: + cloned = Featurization_parameters() + cloned.__dict__.update(deepcopy(params.__dict__)) + return cloned + + +@contextmanager +def featurization_session(): + """ + Protects global featurization state for one train/predict transaction. + + This is an interim safety mechanism for multi-request serving environments + where concurrent requests can otherwise mutate shared featurization globals. + """ + global PARAMS + + with _FEATURIZATION_LOCK: + snapshot = _clone_featurization_parameters(PARAMS) + try: + yield + finally: + PARAMS = snapshot def reset_featurization_parameters(logger: logging.Logger = None) -> None: diff --git a/catpred/models/model.py b/catpred/models/model.py index d9e8a88..5b50099 100644 --- a/catpred/models/model.py +++ b/catpred/models/model.py @@ -1,28 +1,18 @@ from typing import List, Union, Tuple -from rotary_embedding_torch import RotaryEmbedding +import os import numpy as np from rdkit import Chem import torch import torch.nn as nn +from rotary_embedding_torch import RotaryEmbedding +from torch.nn.utils.rnn import pad_sequence + from .mpn import MPN from .ffn import build_ffn, MultiReadout from catpred.args import TrainArgs from catpred.features import BatchMolGraph from catpred.nn_utils import initialize_weights -from torch.nn.utils.rnn import pad_sequence - -from collections import OrderedDict -import ipdb -import os - -import torch -import torch.nn as nn - -import ipdb -import torch -from torch import nn -from torch import einsum def _torch_load_compat(path): @@ -34,9 +24,6 @@ def _torch_load_compat(path): def exists(val): return val is not None - -def default(val, d): - return val if exists(val) else d class AttentivePooling(nn.Module): def __init__(self, input_size=1280, hidden_size=1280): @@ -238,7 +225,10 @@ def create_ffn(self, args: TrainArgs) -> None: 
first_linear_dim_now += 1280 if args.add_pretrained_egnn_feats: first_linear_dim_now+=128 - assert(os.path.exists(args.pretrained_egnn_feats_path)) + if not os.path.exists(args.pretrained_egnn_feats_path): + raise FileNotFoundError( + f'Pretrained EGNN features file not found: "{args.pretrained_egnn_feats_path}"' + ) self.readout = build_ffn( first_linear_dim=first_linear_dim_now, diff --git a/catpred/train/make_predictions.py b/catpred/train/make_predictions.py index 9e285ad..419507c 100644 --- a/catpred/train/make_predictions.py +++ b/catpred/train/make_predictions.py @@ -7,7 +7,7 @@ from catpred.args import PredictArgs, TrainArgs from catpred.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset, StandardScaler, AtomBondScaler from catpred.utils import load_args, load_checkpoint, load_scalers, makedirs, timeit, update_prediction_args -from catpred.features import set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h, set_adding_hs, set_keeping_atom_map, reset_featurization_parameters +from catpred.features import set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h, set_adding_hs, set_keeping_atom_map, reset_featurization_parameters, featurization_session from catpred.models import MoleculeModel from catpred.uncertainty import UncertaintyCalibrator, build_uncertainty_calibrator, UncertaintyEstimator, build_uncertainty_evaluator from catpred.multitask_utils import reshape_values @@ -240,8 +240,14 @@ def predict_and_save( # Save results if save_results: print(f"Saving predictions to {args.preds_path}") - assert len(test_data) == len(preds) - assert len(test_data) == len(unc) + if len(test_data) != len(preds): + raise RuntimeError( + f"Prediction count mismatch: expected {len(test_data)} rows, got {len(preds)} predictions." + ) + if len(test_data) != len(unc): + raise RuntimeError( + f"Uncertainty count mismatch: expected {len(test_data)} rows, got {len(unc)} rows." 
+ ) makedirs(args.preds_path, isfile=True) @@ -401,110 +407,108 @@ def make_predictions( num_models = len(args.checkpoint_paths) - set_features(args, train_args) + with featurization_session(): + set_features(args, train_args) - # Note: to get the invalid SMILES for your data, use the get_invalid_smiles_from_file or get_invalid_smiles_from_list functions from data/utils.py - full_data, test_data, test_data_loader, full_to_valid_indices = load_data( - args, smiles - ) - - if args.uncertainty_method is None and (args.calibration_method is not None or args.evaluation_methods is not None): - if args.dataset_type in ['classification', 'multiclass']: - args.uncertainty_method = 'classification' - else: - raise ValueError('Cannot calibrate or evaluate uncertainty without selection of an uncertainty method.') + # Note: to get the invalid SMILES for your data, use get_invalid_smiles_from_file/get_invalid_smiles_from_list. + full_data, test_data, test_data_loader, full_to_valid_indices = load_data( + args, smiles + ) + if args.uncertainty_method is None and (args.calibration_method is not None or args.evaluation_methods is not None): + if args.dataset_type in ['classification', 'multiclass']: + args.uncertainty_method = 'classification' + else: + raise ValueError('Cannot calibrate or evaluate uncertainty without selection of an uncertainty method.') + + if calibrator is None and args.calibration_path is not None: + + calibration_data = get_data( + protein_records_path=args.protein_records_path, + path=args.calibration_path, + smiles_columns=args.smiles_columns, + target_columns=task_names, + args=args, + features_path=args.calibration_features_path, + features_generator=args.features_generator, + phase_features_path=args.calibration_phase_features_path, + atom_descriptors_path=args.calibration_atom_descriptors_path, + bond_descriptors_path=args.calibration_bond_descriptors_path, + max_data_size=args.max_data_size, + loss_function=args.loss_function, + ) - if calibrator is 
None and args.calibration_path is not None: + calibration_data_loader = MoleculeDataLoader( + dataset=calibration_data, + batch_size=args.batch_size, + num_workers=args.num_workers, + ) - calibration_data = get_data( - protein_records_path=args.protein_records_path, - path=args.calibration_path, - smiles_columns=args.smiles_columns, - target_columns=task_names, - args=args, - features_path=args.calibration_features_path, - features_generator=args.features_generator, - phase_features_path=args.calibration_phase_features_path, - atom_descriptors_path=args.calibration_atom_descriptors_path, - bond_descriptors_path=args.calibration_bond_descriptors_path, - max_data_size=args.max_data_size, - loss_function=args.loss_function, - ) + if isinstance(models, list) and isinstance(scalers, list): + calibration_models = models + calibration_scalers = scalers + else: + calibration_model_objects = load_model(args, generator=True) + calibration_models = calibration_model_objects[2] + calibration_scalers = calibration_model_objects[3] - calibration_data_loader = MoleculeDataLoader( - dataset=calibration_data, - batch_size=args.batch_size, - num_workers=args.num_workers, - ) + calibrator = build_uncertainty_calibrator( + calibration_method=args.calibration_method, + uncertainty_method=args.uncertainty_method, + interval_percentile=args.calibration_interval_percentile, + regression_calibrator_metric=args.regression_calibrator_metric, + calibration_data=calibration_data, + calibration_data_loader=calibration_data_loader, + models=calibration_models, + scalers=calibration_scalers, + num_models=num_models, + dataset_type=args.dataset_type, + loss_function=args.loss_function, + uncertainty_dropout_p=args.uncertainty_dropout_p, + dropout_sampling_size=args.dropout_sampling_size, + spectra_phase_mask=getattr(train_args, "spectra_phase_mask", None), + ) - if isinstance(models, List) and isinstance(scalers, List): - calibration_models = models - calibration_scalers = scalers + # Edge case if 
empty list of smiles is provided + if len(test_data) == 0: + preds = [None] * len(full_data) + unc = [None] * len(full_data) else: - calibration_model_objects = load_model(args, generator=True) - calibration_models = calibration_model_objects[2] - calibration_scalers = calibration_model_objects[3] - - calibrator = build_uncertainty_calibrator( - calibration_method=args.calibration_method, - uncertainty_method=args.uncertainty_method, - interval_percentile=args.calibration_interval_percentile, - regression_calibrator_metric=args.regression_calibrator_metric, - calibration_data=calibration_data, - calibration_data_loader=calibration_data_loader, - models=calibration_models, - scalers=calibration_scalers, - num_models=num_models, - dataset_type=args.dataset_type, - loss_function=args.loss_function, - uncertainty_dropout_p=args.uncertainty_dropout_p, - dropout_sampling_size=args.dropout_sampling_size, - spectra_phase_mask=getattr(train_args, "spectra_phase_mask", None), - ) - - # Edge case if empty list of smiles is provided - if len(test_data) == 0: - preds = [None] * len(full_data) - unc = [None] * len(full_data) - else: - preds, unc = predict_and_save( - args=args, - train_args=train_args, - test_data=test_data, - task_names=task_names, - num_tasks=num_tasks, - test_data_loader=test_data_loader, - full_data=full_data, - full_to_valid_indices=full_to_valid_indices, - models=models, - scalers=scalers, - num_models=num_models, - calibrator=calibrator, - return_invalid_smiles=return_invalid_smiles, - ) + preds, unc = predict_and_save( + args=args, + train_args=train_args, + test_data=test_data, + task_names=task_names, + num_tasks=num_tasks, + test_data_loader=test_data_loader, + full_data=full_data, + full_to_valid_indices=full_to_valid_indices, + models=models, + scalers=scalers, + num_models=num_models, + calibrator=calibrator, + return_invalid_smiles=return_invalid_smiles, + ) - if return_index_dict: - preds_dict = {} - unc_dict = {} - for i in 
range(len(full_data)): - if return_invalid_smiles: - preds_dict[i] = preds[i] - unc_dict[i] = unc[i] - else: - valid_index = full_to_valid_indices.get(i, None) - if valid_index is not None: - preds_dict[i] = preds[valid_index] - unc_dict[i] = unc[valid_index] - if return_uncertainty: - return preds_dict, unc_dict - else: + if return_index_dict: + preds_dict = {} + unc_dict = {} + for i in range(len(full_data)): + if return_invalid_smiles: + preds_dict[i] = preds[i] + unc_dict[i] = unc[i] + else: + valid_index = full_to_valid_indices.get(i, None) + if valid_index is not None: + preds_dict[i] = preds[valid_index] + unc_dict[i] = unc[valid_index] + if return_uncertainty: + return preds_dict, unc_dict return preds_dict - else: + if return_uncertainty: return preds, unc - else: - return preds + return preds def catpred_predict() -> None: diff --git a/catpred/train/run_training.py b/catpred/train/run_training.py index 1e52394..aa0b762 100644 --- a/catpred/train/run_training.py +++ b/catpred/train/run_training.py @@ -237,7 +237,7 @@ def run_training(args: TrainArgs, makedirs(save_dir) try: writer = SummaryWriter(log_dir=save_dir) - except: + except TypeError: writer = SummaryWriter(logdir=save_dir) # Load/build model diff --git a/catpred/utils.py b/catpred/utils.py index ecde744..3f70e62 100644 --- a/catpred/utils.py +++ b/catpred/utils.py @@ -134,7 +134,7 @@ def load_checkpoint( # Build model model = MoleculeModel(args) - print(model) + debug(model) model_state_dict = model.state_dict() # Skip missing parameters and parameters of mismatched size @@ -172,9 +172,6 @@ def load_checkpoint( model = model.to(args.device) return model - -import ipdb - def overwrite_state_dict( loaded_param_name: str, model_param_name: str, diff --git a/demo_run.py b/demo_run.py index 5b77e38..4bb9b57 100644 --- a/demo_run.py +++ b/demo_run.py @@ -52,18 +52,20 @@ def prepare_prediction_inputs(parameter, input_file_path): for i, smi in enumerate(smiles_list): try: mol = Chem.MolFromSmiles(smi) + 
if mol is None: + raise ValueError("RDKit could not parse SMILES") smi = Chem.MolToSmiles(mol) if parameter == 'kcat' and '.' in smi: smi = '.'.join(sorted(smi.split('.'))) smiles_list_new.append(smi) - except: - print(f'Invalid SMILES input in input row {i}') + except Exception as exc: + print(f'Invalid SMILES input in input row {i}: {exc}') print('Correct your input! Exiting..') return None valid_aas = set('ACDEFGHIKLMNPQRSTVWY') for i, seq in enumerate(seq_list): - if not set(seq).issubset(valid_aas): + if not isinstance(seq, str) or not set(seq).issubset(valid_aas): print(f'Invalid Enzyme sequence input in row {i}!') print('Correct your input! Exiting..') return None diff --git a/setup.cfg b/setup.cfg index 168dee3..d75df80 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,11 +51,6 @@ console_scripts = catpred_train=catpred.train:catpred_train catpred_predict=catpred.train:catpred_predict catpred_fingerprint=catpred.train:catpred_fingerprint - catpred_hyperopt=catpred.hyperparameter_optimization:catpred_hyperopt - catpred_interpret=catpred.interpret:catpred_interpret - catpred_web=catpred.web.run:catpred_web - sklearn_train=catpred.sklearn_train:sklearn_train - sklearn_predict=catpred.sklearn_predict:sklearn_predict [options.extras_require] test = pytest>=6.2.2; parameterized>=0.8.1 diff --git a/setup.py b/setup.py index 5de8c5d..e7d62bf 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "console_scripts": [ "catpred_train=catpred.train:catpred_train", "catpred_predict=catpred.train:catpred_predict", + "catpred_fingerprint=catpred.train:catpred_fingerprint", ] }, install_requires=[ From 89d686662f331ac6aee15c2300ac8bbf46b11101 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 14:29:42 -0500 Subject: [PATCH 04/34] Phase 1: extract reusable inference service core --- catpred/__init__.py | 1 + catpred/data/esm_utils.py | 2 +- catpred/inference/__init__.py | 17 +++ catpred/inference/service.py | 240 ++++++++++++++++++++++++++++++++ demo_run.py | 
250 +++++++--------------------------- 5 files changed, 306 insertions(+), 204 deletions(-) create mode 100644 catpred/inference/__init__.py create mode 100644 catpred/inference/service.py diff --git a/catpred/__init__.py b/catpred/__init__.py index e27ba71..c4b3702 100644 --- a/catpred/__init__.py +++ b/catpred/__init__.py @@ -3,6 +3,7 @@ import catpred.models import catpred.train import catpred.uncertainty +import catpred.inference import catpred.args import catpred.constants diff --git a/catpred/data/esm_utils.py b/catpred/data/esm_utils.py index 359613a..ff527e9 100644 --- a/catpred/data/esm_utils.py +++ b/catpred/data/esm_utils.py @@ -22,7 +22,7 @@ def _env_flag(name: str, default: str = "0") -> bool: return str(raw).strip().lower() in {"1", "true", "yes", "on"} -PROTEIN_EMBED_USE_CPU = _env_flag("PROTEIN_EMBED_USE_CPU", "0") +PROTEIN_EMBED_USE_CPU = _env_flag("PROTEIN_EMBED_USE_CPU", "0") or not torch.cuda.is_available() if PROTEIN_EMBED_USE_CPU: print('calculating protein embed only on cpu') diff --git a/catpred/inference/__init__.py b/catpred/inference/__init__.py new file mode 100644 index 0000000..c4b814e --- /dev/null +++ b/catpred/inference/__init__.py @@ -0,0 +1,17 @@ +from .service import ( + PredictionRequest, + PreparedInputPaths, + prepare_prediction_inputs, + run_raw_prediction, + postprocess_predictions, + run_prediction_pipeline, +) + +__all__ = [ + "PredictionRequest", + "PreparedInputPaths", + "prepare_prediction_inputs", + "run_raw_prediction", + "postprocess_predictions", + "run_prediction_pipeline", +] diff --git a/catpred/inference/service.py b/catpred/inference/service.py new file mode 100644 index 0000000..99f768d --- /dev/null +++ b/catpred/inference/service.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import os +import subprocess +from typing import Tuple + +import numpy as np +import pandas as pd +from rdkit import Chem + + +_VALID_PARAMETERS = {"kcat", "km", 
"ki"} +_TARGET_COLUMNS = { + "kcat": ("log10kcat_max", "s^(-1)"), + "km": ("log10km_mean", "mM"), + "ki": ("log10ki_mean", "mM"), +} +_VALID_AAS = set("ACDEFGHIKLMNPQRSTVWY") + + +@dataclass(frozen=True) +class PredictionRequest: + parameter: str + input_file: str + checkpoint_dir: str + use_gpu: bool = False + repo_root: str | None = None + python_executable: str = "python" + + +@dataclass(frozen=True) +class PreparedInputPaths: + input_csv: str + records_file: str + output_csv: str + + +def _validate_parameter(parameter: str) -> str: + parameter = parameter.lower() + if parameter not in _VALID_PARAMETERS: + raise ValueError(f"Unsupported parameter '{parameter}'. Must be one of: kcat, km, ki.") + return parameter + + +def _resolve_repo_root(repo_root: str | None) -> Path: + root = Path(repo_root) if repo_root else Path.cwd() + root = root.resolve() + if not root.exists(): + raise FileNotFoundError(f'Repository root does not exist: "{root}"') + return root + + +def _resolve_input_path(input_file: str, repo_root: Path) -> Path: + path = Path(input_file) + if not path.is_absolute(): + path = (repo_root / path).resolve() + if not path.exists(): + raise FileNotFoundError(f'Input CSV not found: "{path}"') + return path + + +def _validate_and_prepare_dataframe(parameter: str, df: pd.DataFrame, input_csv: Path) -> pd.DataFrame: + required_columns = {"SMILES", "sequence", "pdbpath"} + missing = required_columns.difference(df.columns) + if missing: + raise ValueError( + f'Missing required column(s) in "{input_csv}": {", ".join(sorted(missing))}.' + ) + + conflicting_pdbpaths = ( + df.groupby("pdbpath")["sequence"] + .nunique(dropna=False) + .loc[lambda value: value > 1] + ) + if len(conflicting_pdbpaths) > 0: + preview = ", ".join(conflicting_pdbpaths.index.astype(str).tolist()[:5]) + raise ValueError( + "Found pdbpath values mapped to multiple sequences. " + f"Each unique sequence must have a unique pdbpath. 
Examples: {preview}" + ) + + canonical_smiles = [] + for i, raw_smiles in enumerate(df["SMILES"]): + mol = Chem.MolFromSmiles(raw_smiles) + if mol is None: + raise ValueError(f'Invalid SMILES input in row {i + 2}: "{raw_smiles}"') + smiles = Chem.MolToSmiles(mol) + if parameter == "kcat" and "." in smiles: + smiles = ".".join(sorted(smiles.split("."))) + canonical_smiles.append(smiles) + + for i, sequence in enumerate(df["sequence"]): + if not isinstance(sequence, str) or not set(sequence).issubset(_VALID_AAS): + raise ValueError(f'Invalid enzyme sequence in row {i + 2}: "{sequence}"') + + prepared = df.copy() + prepared["SMILES"] = canonical_smiles + return prepared + + +def prepare_prediction_inputs(parameter: str, input_file: str, repo_root: str | None = None) -> PreparedInputPaths: + parameter = _validate_parameter(parameter) + root = _resolve_repo_root(repo_root) + input_csv = _resolve_input_path(input_file, root) + + df = pd.read_csv(input_csv) + prepared_df = _validate_and_prepare_dataframe(parameter, df, input_csv) + + input_base = input_csv.with_suffix("") + prepared_input_csv = Path(f"{input_base}_input.csv") + prepared_df.to_csv(prepared_input_csv, index=False) + + test_prefix = prepared_input_csv.with_suffix("") + records_file = Path(f"{test_prefix}.json.gz") + output_csv = Path(f"{test_prefix}_output.csv") + + return PreparedInputPaths( + input_csv=str(prepared_input_csv), + records_file=str(records_file), + output_csv=str(output_csv), + ) + + +def _build_prediction_commands( + python_executable: str, + repo_root: Path, + paths: PreparedInputPaths, + checkpoint_dir: str, +) -> Tuple[list[str], list[str]]: + create_records_cmd = [ + python_executable, + str(repo_root / "scripts" / "create_pdbrecords.py"), + "--data_file", + paths.input_csv, + "--out_file", + paths.records_file, + ] + predict_cmd = [ + python_executable, + str(repo_root / "predict.py"), + "--test_path", + paths.input_csv, + "--preds_path", + paths.output_csv, + "--checkpoint_dir", + 
checkpoint_dir, + "--uncertainty_method", + "mve", + "--smiles_column", + "SMILES", + "--individual_ensemble_predictions", + "--protein_records_path", + paths.records_file, + ] + return create_records_cmd, predict_cmd + + +def run_raw_prediction(request: PredictionRequest, paths: PreparedInputPaths) -> None: + root = _resolve_repo_root(request.repo_root) + create_records_cmd, predict_cmd = _build_prediction_commands( + python_executable=request.python_executable, + repo_root=root, + paths=paths, + checkpoint_dir=request.checkpoint_dir, + ) + + env = os.environ.copy() + env["PROTEIN_EMBED_USE_CPU"] = "0" if request.use_gpu else "1" + + subprocess.run(create_records_cmd, cwd=str(root), env=env, check=True) + subprocess.run(predict_cmd, cwd=str(root), env=env, check=True) + + if not os.path.exists(paths.output_csv): + raise FileNotFoundError(f'Prediction output file was not generated: "{paths.output_csv}"') + + +def postprocess_predictions(parameter: str, output_csv: str) -> pd.DataFrame: + parameter = _validate_parameter(parameter) + target_col, unit = _TARGET_COLUMNS[parameter] + unc_col = f"{target_col}_mve_uncal_var" + + df = pd.read_csv(output_csv) + missing_cols = [col for col in [target_col, unc_col] if col not in df.columns] + if missing_cols: + raise ValueError( + f'Prediction output is missing required column(s): {", ".join(missing_cols)}' + ) + + pred_col, pred_logcol, pred_sd_tot, pred_sd_alea, pred_sd_epi = [], [], [], [], [] + + for _, row in df.iterrows(): + model_cols = [col for col in row.index if col.startswith(target_col) and "model_" in col] + + unc = row[unc_col] + prediction_log = row[target_col] + prediction_linear = np.power(10, prediction_log) + + if model_cols: + model_outs = np.array([row[col] for col in model_cols]) + epi_unc_var = np.var(model_outs) + else: + epi_unc_var = 0.0 + + alea_unc_var = max(unc - epi_unc_var, 0.0) + epi_unc = np.sqrt(epi_unc_var) + alea_unc = np.sqrt(alea_unc_var) + total_unc = np.sqrt(max(unc, 0.0)) + + 
pred_col.append(prediction_linear) + pred_logcol.append(prediction_log) + pred_sd_tot.append(total_unc) + pred_sd_alea.append(alea_unc) + pred_sd_epi.append(epi_unc) + + df[f"Prediction_({unit})"] = pred_col + df["Prediction_log10"] = pred_logcol + df["SD_total"] = pred_sd_tot + df["SD_aleatoric"] = pred_sd_alea + df["SD_epistemic"] = pred_sd_epi + return df + + +def run_prediction_pipeline(request: PredictionRequest, results_dir: str = "../results") -> str: + parameter = _validate_parameter(request.parameter) + paths = prepare_prediction_inputs(parameter, request.input_file, request.repo_root) + run_raw_prediction(request, paths) + + output_final = postprocess_predictions(parameter, paths.output_csv) + + results_path = Path(results_dir) + if not results_path.is_absolute(): + results_path = (_resolve_repo_root(request.repo_root) / results_path).resolve() + results_path.mkdir(parents=True, exist_ok=True) + + out_name = Path(paths.output_csv).name + final_output = results_path / out_name + output_final.to_csv(final_output, index=False) + return str(final_output) diff --git a/demo_run.py b/demo_run.py index 4bb9b57..e6c3d08 100644 --- a/demo_run.py +++ b/demo_run.py @@ -1,220 +1,64 @@ """ -Enzyme Kinetics Parameter Prediction Script - -This script predicts enzyme kinetics parameters (kcat, Km, or Ki) using a pre-trained model. -It processes input data, generates predictions, and saves the results. +Enzyme kinetics parameter prediction CLI for local/demo usage. 
Usage: - python demo_run.py --parameter --input_file --checkpoint_dir [--use_gpu] - -Dependencies: - pandas, numpy, rdkit, IPython, argparse + python demo_run.py --parameter --input_file --checkpoint_dir [--use_gpu] """ -import time -import os -import subprocess -import pandas as pd -import numpy as np -from IPython.display import Image, display -from rdkit import Chem -from IPython.display import display, Latex, Math import argparse +import subprocess +from catpred.inference import PredictionRequest, run_prediction_pipeline -def prepare_prediction_inputs(parameter, input_file_path): - df = pd.read_csv(input_file_path) - required_columns = {"SMILES", "sequence", "pdbpath"} - missing_columns = required_columns.difference(df.columns) - if missing_columns: - print( - f'Missing required column(s) in input file: {", ".join(sorted(missing_columns))}.' - ) - return None - conflicting_pdbpaths = ( - df.groupby("pdbpath")["sequence"] - .nunique(dropna=False) - .loc[lambda value: value > 1] +def main(args: argparse.Namespace) -> int: + request = PredictionRequest( + parameter=args.parameter.lower(), + input_file=args.input_file, + checkpoint_dir=args.checkpoint_dir, + use_gpu=args.use_gpu, + repo_root=".", ) - if len(conflicting_pdbpaths) > 0: - preview = ", ".join(conflicting_pdbpaths.index.astype(str).tolist()[:5]) - print( - "Found pdbpath values mapped to multiple sequences. " - f"Each unique sequence must have a unique pdbpath. Examples: {preview}" - ) - return None - - smiles_list = df.SMILES - seq_list = df.sequence - smiles_list_new = [] - - for i, smi in enumerate(smiles_list): - try: - mol = Chem.MolFromSmiles(smi) - if mol is None: - raise ValueError("RDKit could not parse SMILES") - smi = Chem.MolToSmiles(mol) - if parameter == 'kcat' and '.' in smi: - smi = '.'.join(sorted(smi.split('.'))) - smiles_list_new.append(smi) - except Exception as exc: - print(f'Invalid SMILES input in input row {i}: {exc}') - print('Correct your input! 
Exiting..') - return None - - valid_aas = set('ACDEFGHIKLMNPQRSTVWY') - for i, seq in enumerate(seq_list): - if not isinstance(seq, str) or not set(seq).issubset(valid_aas): - print(f'Invalid Enzyme sequence input in row {i}!') - print('Correct your input! Exiting..') - return None - - input_file_base, _ = os.path.splitext(input_file_path) - input_file_new_path = f'{input_file_base}_input.csv' - df['SMILES'] = smiles_list_new - df.to_csv(input_file_new_path, index=False) - - test_file_prefix = input_file_new_path[:-4] - return { - "input_csv": input_file_new_path, - "records_file": f"{test_file_prefix}.json.gz", - "output_csv": f"{test_file_prefix}_output.csv", - } - -def get_predictions(parameter, outfile): - """ - Process prediction results and add additional metrics. - - Args: - parameter (str): The kinetics parameter that was predicted. - outfile (str): Path to the output CSV file from the prediction. - - Returns: - pandas.DataFrame: Processed predictions with additional metrics. - """ - df = pd.read_csv(outfile) - pred_col, pred_logcol, pred_sd_totcol, pred_sd_aleacol, pred_sd_epicol = [], [], [], [], [] - - unit = 'mM' - if parameter == 'kcat': - target_col = 'log10kcat_max' - unit = 's^(-1)' - elif parameter == 'km': - target_col = 'log10km_mean' - else: - target_col = 'log10ki_mean' - - unc_col = f'{target_col}_mve_uncal_var' - - missing_cols = [col for col in [target_col, unc_col] if col not in df.columns] - if missing_cols: - raise ValueError( - f'Prediction output is missing required column(s): {", ".join(missing_cols)}' - ) - - for _, row in df.iterrows(): - model_cols = [col for col in row.index if col.startswith(target_col) and 'model_' in col] - - unc = row[unc_col] - prediction = row[target_col] - prediction_linear = np.power(10, prediction) - if model_cols: - model_outs = np.array([row[col] for col in model_cols]) - epi_unc = np.var(model_outs) - else: - epi_unc = 0.0 - alea_unc = max(unc - epi_unc, 0.0) - epi_unc = np.sqrt(epi_unc) - alea_unc = 
np.sqrt(alea_unc) - unc = np.sqrt(max(unc, 0.0)) + print("Predicting.. This will take a while..") + try: + final_output = run_prediction_pipeline(request=request, results_dir="../results") + except (ValueError, FileNotFoundError) as exc: + print(str(exc)) + return 1 + except subprocess.CalledProcessError as exc: + print(f"Prediction command failed with exit code {exc.returncode}.") + return exc.returncode if exc.returncode is not None else 1 - pred_col.append(prediction_linear) - pred_logcol.append(prediction) - pred_sd_totcol.append(unc) - pred_sd_aleacol.append(alea_unc) - pred_sd_epicol.append(epi_unc) + print(f"Output saved to {final_output}") + return 0 - df[f'Prediction_({unit})'] = pred_col - df['Prediction_log10'] = pred_logcol - df['SD_total'] = pred_sd_totcol - df['SD_aleatoric'] = pred_sd_aleacol - df['SD_epistemic'] = pred_sd_epicol - - return df - -def main(args): - run_paths = prepare_prediction_inputs(args.parameter, args.input_file) - if run_paths is None: - return - - outfile = run_paths["output_csv"] - print('Predicting.. This will take a while..') - - env = os.environ.copy() - env["PROTEIN_EMBED_USE_CPU"] = "0" if args.use_gpu else "1" - - create_records_cmd = [ - "python", - "./scripts/create_pdbrecords.py", - "--data_file", - run_paths["input_csv"], - "--out_file", - run_paths["records_file"], - ] - predict_cmd = [ - "python", - "predict.py", - "--test_path", - run_paths["input_csv"], - "--preds_path", - outfile, - "--checkpoint_dir", - args.checkpoint_dir, - "--uncertainty_method", - "mve", - "--smiles_column", - "SMILES", - "--individual_ensemble_predictions", - "--protein_records_path", - run_paths["records_file"], - ] - - create_records_result = subprocess.run(create_records_cmd, env=env) - if create_records_result.returncode != 0: - print( - f"Protein record generation failed with exit code {create_records_result.returncode}." 
- ) - return - - predict_result = subprocess.run(predict_cmd, env=env) - if predict_result.returncode != 0: - print(f"Prediction command failed with exit code {predict_result.returncode}.") - return - - if not os.path.exists(outfile): - print(f'Prediction output file was not generated: {outfile}') - return - - output_final = get_predictions(args.parameter, outfile) - filename = outfile.split('/')[-1] - os.makedirs('../results', exist_ok=True) - output_final.to_csv(f'../results/{filename}', index=False) - print('Output saved to results/', filename) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Predict enzyme kinetics parameters.") - parser.add_argument("--parameter", type=str, choices=["kcat", "km", "ki"], required=True, - help="Kinetics parameter to predict (kcat, km, or ki)") - parser.add_argument("--input_file", type=str, required=True, - help="Path to the input CSV file") - parser.add_argument("--use_gpu", action="store_true", - help="Use GPU for prediction (default is CPU)") - parser.add_argument("--checkpoint_dir", type=str, required=True, - help="Path to the model checkpoint directory") - - args = parser.parse_args() - args.parameter = args.parameter.lower() + parser.add_argument( + "--parameter", + type=str, + choices=["kcat", "km", "ki"], + required=True, + help="Kinetics parameter to predict (kcat, km, or ki)", + ) + parser.add_argument( + "--input_file", + type=str, + required=True, + help="Path to the input CSV file", + ) + parser.add_argument( + "--use_gpu", + action="store_true", + help="Use GPU for prediction (default is CPU)", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + required=True, + help="Path to the model checkpoint directory", + ) - main(args) + raise SystemExit(main(parser.parse_args())) From 968b54918a0265c1acba079f17a7ccd45570c461 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 16:04:12 -0500 Subject: [PATCH 05/34] Phase 2: add web API and optional modal backend routing --- 
.gitignore | 1 + README.md | 46 +++++ catpred/__init__.py | 39 ++-- catpred/inference/__init__.py | 40 ++++- catpred/inference/backends.py | 325 ++++++++++++++++++++++++++++++++++ catpred/inference/service.py | 19 +- catpred/inference/types.py | 20 +++ catpred/web/__init__.py | 3 + catpred/web/app.py | 160 +++++++++++++++++ catpred/web/run.py | 28 +++ setup.cfg | 4 + setup.py | 7 + 12 files changed, 655 insertions(+), 37 deletions(-) create mode 100644 catpred/inference/backends.py create mode 100644 catpred/inference/types.py create mode 100644 catpred/web/__init__.py create mode 100644 catpred/web/app.py create mode 100644 catpred/web/run.py diff --git a/.gitignore b/.gitignore index a3b307a..4edc983 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ __pycache__/ *.pyd .DS_Store *.egg-info/ +.venv/ .ipynb_checkpoints/ # Local validation artifacts diff --git a/README.md b/README.md index 2a92ed8..0e9d6d9 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ - [System Requirements](#requirements) - [Installation](#installing) - [Prediction](#predict) + - [Web API (Optional)](#web-api-optional) - [Reproducibility](#reproduce) - [Acknowledgements](#acknw) - [License](#license) @@ -100,6 +101,51 @@ CatPred currently expects one sequence per row. Multi-protein complexes (e.g., h For released benchmark datasets, the number of entries with 3D structure can be smaller than the total sequence/substrate pairs; 3D-derived artifacts are available only for the subset with valid structure mapping. +### 🌐 Web API (Optional) + +CatPred also provides an optional FastAPI service for prediction workflows. + +Install web dependencies: + +```bash +pip install -e ".[web]" +``` + +Run the API: + +```bash +catpred_web --host 0.0.0.0 --port 8000 +``` + +Endpoints: +- `GET /health` — liveness check. +- `GET /ready` — backend configuration/readiness. +- `POST /predict` — run inference.
+ +Minimal `POST /predict` example for local inference: + +```bash +curl -X POST http://127.0.0.1:8000/predict \ + -H "Content-Type: application/json" \ + -d '{ + "parameter": "kcat", + "checkpoint_dir": "../data/pretrained/reproduce_checkpoints/kcat", + "input_file": "./demo/batch_kcat_pred.csv", + "backend": "local" + }' +``` + +You can keep local inference as default and optionally enable Modal as another backend: + +```bash +export CATPRED_DEFAULT_BACKEND=local +export CATPRED_MODAL_ENDPOINT="https://" +export CATPRED_MODAL_TOKEN="" +export CATPRED_MODAL_FALLBACK_TO_LOCAL=1 +``` + +Use `"backend": "modal"` in `/predict` requests to route through Modal. If fallback is enabled (env var above or request field `fallback_to_local`), failed modal requests can automatically reroute to local inference. + ### 🧪 Fine-Tuning On Custom Data You can fine-tune CatPred on your own regression targets using `train.py`. diff --git a/catpred/__init__.py b/catpred/__init__.py index c4b3702..92cfc40 100644 --- a/catpred/__init__.py +++ b/catpred/__init__.py @@ -1,14 +1,29 @@ -import catpred.data -import catpred.features -import catpred.models -import catpred.train -import catpred.uncertainty -import catpred.inference - -import catpred.args -import catpred.constants -import catpred.nn_utils -import catpred.utils -import catpred.rdkit +from __future__ import annotations + +import importlib __version__ = "0.0.1" + +_LAZY_SUBMODULES = { + "args", + "constants", + "data", + "features", + "inference", + "models", + "nn_utils", + "rdkit", + "train", + "uncertainty", + "utils", +} + +__all__ = sorted(_LAZY_SUBMODULES) + ["__version__"] + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + module = importlib.import_module(f"catpred.{name}") + globals()[name] = module + return module + raise AttributeError(f"module 'catpred' has no attribute '{name}'") diff --git a/catpred/inference/__init__.py b/catpred/inference/__init__.py index c4b814e..103d8ed 100644 ---
a/catpred/inference/__init__.py +++ b/catpred/inference/__init__.py @@ -1,13 +1,24 @@ -from .service import ( - PredictionRequest, - PreparedInputPaths, - prepare_prediction_inputs, - run_raw_prediction, - postprocess_predictions, - run_prediction_pipeline, +from __future__ import annotations + +from .backends import ( + BackendPredictionResult, + BackendRouterSettings, + InferenceBackend, + InferenceBackendError, + InferenceBackendRouter, + LocalInferenceBackend, + ModalHTTPInferenceBackend, ) +from .types import PreparedInputPaths, PredictionRequest __all__ = [ + "BackendPredictionResult", + "BackendRouterSettings", + "InferenceBackend", + "InferenceBackendError", + "InferenceBackendRouter", + "LocalInferenceBackend", + "ModalHTTPInferenceBackend", "PredictionRequest", "PreparedInputPaths", "prepare_prediction_inputs", @@ -15,3 +26,18 @@ "postprocess_predictions", "run_prediction_pipeline", ] + + +def __getattr__(name: str): + if name in { + "prepare_prediction_inputs", + "run_raw_prediction", + "postprocess_predictions", + "run_prediction_pipeline", + }: + from . 
import service + + value = getattr(service, name) + globals()[name] = value + return value + raise AttributeError(f"module 'catpred.inference' has no attribute '{name}'") diff --git a/catpred/inference/backends.py b/catpred/inference/backends.py new file mode 100644 index 0000000..916dfe9 --- /dev/null +++ b/catpred/inference/backends.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from pathlib import Path +import csv +import json +import os +from typing import Any +from urllib import error, request as urllib_request + +import pandas as pd + +from .types import PredictionRequest + + +class InferenceBackendError(RuntimeError): + """Raised when a backend cannot satisfy an inference request.""" + + +@dataclass(frozen=True) +class BackendPredictionResult: + backend_name: str + output_file: str + metadata: dict[str, Any] = field(default_factory=dict) + + +class InferenceBackend: + name = "base" + + def readiness(self) -> dict[str, Any]: + raise NotImplementedError + + def predict(self, request_obj: PredictionRequest, results_dir: str) -> BackendPredictionResult: + raise NotImplementedError + + +class LocalInferenceBackend(InferenceBackend): + name = "local" + + def __init__(self, repo_root: str | None = None) -> None: + self._repo_root = repo_root + + def readiness(self) -> dict[str, Any]: + root = Path(self._repo_root) if self._repo_root else Path.cwd() + root = root.resolve() + required = [ + root / "predict.py", + root / "scripts" / "create_pdbrecords.py", + ] + missing = [str(path) for path in required if not path.exists()] + return { + "configured": True, + "ready": len(missing) == 0, + "missing_files": missing, + "repo_root": str(root), + } + + def predict(self, request_obj: PredictionRequest, results_dir: str) -> BackendPredictionResult: + from .service import run_prediction_pipeline + + effective_request = request_obj + if not request_obj.repo_root and self._repo_root: + effective_request = 
replace(request_obj, repo_root=self._repo_root) + + output_file = run_prediction_pipeline(effective_request, results_dir=results_dir) + return BackendPredictionResult(backend_name=self.name, output_file=output_file) + + +class ModalHTTPInferenceBackend(InferenceBackend): + name = "modal" + + def __init__( + self, + endpoint: str | None, + token: str | None = None, + timeout_seconds: int = 900, + repo_root: str | None = None, + ) -> None: + self._endpoint = endpoint + self._token = token + self._timeout_seconds = timeout_seconds + self._repo_root = repo_root + + def readiness(self) -> dict[str, Any]: + configured = bool(self._endpoint) + return { + "configured": configured, + "ready": configured, + "endpoint": self._endpoint, + "timeout_seconds": self._timeout_seconds, + } + + def _resolve_input_file(self, request_obj: PredictionRequest) -> Path: + input_path = Path(request_obj.input_file) + if not input_path.is_absolute(): + root = Path(request_obj.repo_root or self._repo_root or Path.cwd()).resolve() + input_path = (root / input_path).resolve() + if not input_path.exists(): + raise FileNotFoundError(f'Input CSV not found for modal backend: "{input_path}"') + return input_path + + def _resolve_results_dir(self, results_dir: str, request_obj: PredictionRequest) -> Path: + out_dir = Path(results_dir) + if not out_dir.is_absolute(): + root = Path(request_obj.repo_root or self._repo_root or Path.cwd()).resolve() + out_dir = (root / out_dir).resolve() + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir + + @staticmethod + def _load_rows(csv_path: Path) -> list[dict[str, Any]]: + with csv_path.open(newline="", encoding="utf-8") as handle: + return list(csv.DictReader(handle)) + + def _post_json(self, payload: dict[str, Any]) -> dict[str, Any]: + if not self._endpoint: + raise InferenceBackendError( + "Modal backend is not configured. Set CATPRED_MODAL_ENDPOINT." 
+ ) + + encoded_payload = json.dumps(payload).encode("utf-8") + headers = {"Content-Type": "application/json"} + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + + req = urllib_request.Request( + url=self._endpoint, + method="POST", + data=encoded_payload, + headers=headers, + ) + try: + with urllib_request.urlopen(req, timeout=self._timeout_seconds) as resp: + raw = resp.read() + except error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + raise InferenceBackendError( + f"Modal backend request failed with HTTP {exc.code}: {body}" + ) from exc + except error.URLError as exc: + raise InferenceBackendError( + f"Modal backend request failed: {exc.reason}" + ) from exc + + try: + decoded = json.loads(raw.decode("utf-8")) + except json.JSONDecodeError as exc: + raise InferenceBackendError("Modal backend returned non-JSON output.") from exc + + if not isinstance(decoded, dict): + raise InferenceBackendError( + "Modal backend response must be a JSON object." 
+ ) + return decoded + + def _materialize_output( + self, + response: dict[str, Any], + input_path: Path, + results_dir: str, + request_obj: PredictionRequest, + ) -> BackendPredictionResult: + output_file = response.get("output_file") + if isinstance(output_file, str) and output_file: + resolved = Path(output_file).resolve() + if resolved.exists(): + return BackendPredictionResult( + backend_name=self.name, + output_file=str(resolved), + metadata={"endpoint": self._endpoint, "mode": "output_file"}, + ) + + output_rows = response.get("output_rows") + if isinstance(output_rows, list): + out_dir = self._resolve_results_dir(results_dir, request_obj) + out_name = response.get("output_filename") + if not isinstance(out_name, str) or not out_name: + out_name = f"{input_path.stem}_modal_output.csv" + if not out_name.endswith(".csv"): + out_name = f"{out_name}.csv" + + final_output = out_dir / out_name + pd.DataFrame(output_rows).to_csv(final_output, index=False) + return BackendPredictionResult( + backend_name=self.name, + output_file=str(final_output), + metadata={"endpoint": self._endpoint, "mode": "output_rows"}, + ) + + output_csv_text = response.get("output_csv_text") + if isinstance(output_csv_text, str) and output_csv_text: + out_dir = self._resolve_results_dir(results_dir, request_obj) + out_name = response.get("output_filename") + if not isinstance(out_name, str) or not out_name: + out_name = f"{input_path.stem}_modal_output.csv" + if not out_name.endswith(".csv"): + out_name = f"{out_name}.csv" + final_output = out_dir / out_name + final_output.write_text(output_csv_text, encoding="utf-8") + return BackendPredictionResult( + backend_name=self.name, + output_file=str(final_output), + metadata={"endpoint": self._endpoint, "mode": "output_csv_text"}, + ) + + raise InferenceBackendError( + "Modal backend response must include one of: output_file, output_rows, output_csv_text." 
+ ) + + def predict(self, request_obj: PredictionRequest, results_dir: str) -> BackendPredictionResult: + input_path = self._resolve_input_file(request_obj) + payload = { + "parameter": request_obj.parameter, + "checkpoint_dir": request_obj.checkpoint_dir, + "use_gpu": request_obj.use_gpu, + "input_rows": self._load_rows(input_path), + "input_filename": input_path.name, + } + response = self._post_json(payload) + return self._materialize_output( + response=response, + input_path=input_path, + results_dir=results_dir, + request_obj=request_obj, + ) + + +def _env_flag(name: str, default: bool = False) -> bool: + value = os.environ.get(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "y", "on"} + + +@dataclass(frozen=True) +class BackendRouterSettings: + default_backend: str = "local" + modal_endpoint: str | None = None + modal_token: str | None = None + modal_timeout_seconds: int = 900 + repo_root: str | None = None + + @classmethod + def from_env(cls) -> "BackendRouterSettings": + timeout = int(os.environ.get("CATPRED_MODAL_TIMEOUT_SECONDS", "900")) + return cls( + default_backend=os.environ.get("CATPRED_DEFAULT_BACKEND", "local").lower(), + modal_endpoint=os.environ.get("CATPRED_MODAL_ENDPOINT"), + modal_token=os.environ.get("CATPRED_MODAL_TOKEN"), + modal_timeout_seconds=timeout, + repo_root=os.environ.get("CATPRED_REPO_ROOT"), + ) + + +class InferenceBackendRouter: + def __init__(self, settings: BackendRouterSettings | None = None) -> None: + self.settings = settings or BackendRouterSettings.from_env() + self._backends: dict[str, InferenceBackend] = { + "local": LocalInferenceBackend(repo_root=self.settings.repo_root), + "modal": ModalHTTPInferenceBackend( + endpoint=self.settings.modal_endpoint, + token=self.settings.modal_token, + timeout_seconds=self.settings.modal_timeout_seconds, + repo_root=self.settings.repo_root, + ), + } + if self.settings.default_backend not in self._backends: + raise ValueError( + 
f"Unsupported CATPRED_DEFAULT_BACKEND '{self.settings.default_backend}'. " + "Use one of: local, modal." + ) + + def available_backends(self) -> list[str]: + return sorted(self._backends.keys()) + + def resolve_backend(self, backend_name: str | None = None) -> InferenceBackend: + selected = (backend_name or self.settings.default_backend).lower() + if selected not in self._backends: + raise ValueError( + f"Unsupported backend '{selected}'. Use one of: {', '.join(self.available_backends())}." + ) + + backend = self._backends[selected] + state = backend.readiness() + if not state.get("configured", False): + raise InferenceBackendError( + f"Backend '{selected}' is not configured. Readiness: {state}" + ) + return backend + + def predict( + self, + request_obj: PredictionRequest, + results_dir: str, + backend_name: str | None = None, + fallback_to_local: bool = False, + ) -> BackendPredictionResult: + selected_name = (backend_name or self.settings.default_backend).lower() + backend: InferenceBackend | None = None + try: + backend = self.resolve_backend(selected_name) + return backend.predict(request_obj=request_obj, results_dir=results_dir) + except Exception as exc: + if fallback_to_local and selected_name != "local": + local_backend = self._backends["local"] + local_result = local_backend.predict(request_obj=request_obj, results_dir=results_dir) + metadata = dict(local_result.metadata) + metadata["fallback_from"] = selected_name + metadata["fallback_reason"] = str(exc) + return BackendPredictionResult( + backend_name=local_result.backend_name, + output_file=local_result.output_file, + metadata=metadata, + ) + raise + + def readiness(self) -> dict[str, Any]: + backends = {name: backend.readiness() for name, backend in self._backends.items()} + default_state = backends[self.settings.default_backend] + return { + "default_backend": self.settings.default_backend, + "ready": bool(default_state.get("ready", False)), + "backends": backends, + "fallback_to_local_enabled": 
_env_flag("CATPRED_MODAL_FALLBACK_TO_LOCAL", default=False), + } diff --git a/catpred/inference/service.py b/catpred/inference/service.py index 99f768d..ae8a875 100644 --- a/catpred/inference/service.py +++ b/catpred/inference/service.py @@ -1,6 +1,5 @@ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path import os import subprocess @@ -10,6 +9,7 @@ import pandas as pd from rdkit import Chem +from .types import PreparedInputPaths, PredictionRequest _VALID_PARAMETERS = {"kcat", "km", "ki"} _TARGET_COLUMNS = { @@ -20,23 +20,6 @@ _VALID_AAS = set("ACDEFGHIKLMNPQRSTVWY") -@dataclass(frozen=True) -class PredictionRequest: - parameter: str - input_file: str - checkpoint_dir: str - use_gpu: bool = False - repo_root: str | None = None - python_executable: str = "python" - - -@dataclass(frozen=True) -class PreparedInputPaths: - input_csv: str - records_file: str - output_csv: str - - def _validate_parameter(parameter: str) -> str: parameter = parameter.lower() if parameter not in _VALID_PARAMETERS: diff --git a/catpred/inference/types.py b/catpred/inference/types.py new file mode 100644 index 0000000..51c4857 --- /dev/null +++ b/catpred/inference/types.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class PredictionRequest: + parameter: str + input_file: str + checkpoint_dir: str + use_gpu: bool = False + repo_root: str | None = None + python_executable: str = "python" + + +@dataclass(frozen=True) +class PreparedInputPaths: + input_csv: str + records_file: str + output_csv: str diff --git a/catpred/web/__init__.py b/catpred/web/__init__.py new file mode 100644 index 0000000..e4dd907 --- /dev/null +++ b/catpred/web/__init__.py @@ -0,0 +1,3 @@ +"""Web API interface for CatPred inference.""" + +__all__ = ["app", "run"] diff --git a/catpred/web/app.py b/catpred/web/app.py new file mode 100644 index 0000000..01cbc0a --- /dev/null +++ b/catpred/web/app.py @@ -0,0 
+1,160 @@ +from __future__ import annotations + +from pathlib import Path +import os +import subprocess +import tempfile +from typing import Any, Optional + +import pandas as pd + +try: + from fastapi import FastAPI, HTTPException + from pydantic import BaseModel, Field, root_validator +except ImportError as exc: # pragma: no cover - import guard for optional dependency + raise ImportError( + "catpred.web requires optional dependencies. Install with `pip install .[web]`." + ) from exc + +from catpred.inference import ( + InferenceBackendError, + InferenceBackendRouter, + PredictionRequest, +) + + +def _env_flag(name: str, default: bool = False) -> bool: + value = os.environ.get(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "y", "on"} + + +class PredictRequest(BaseModel): + parameter: str = Field(..., description="One of: kcat, km, ki") + checkpoint_dir: str = Field(..., description="Path to the checkpoint directory") + input_file: Optional[str] = Field(default=None, description="Path to input CSV") + input_rows: Optional[list[dict[str, Any]]] = Field( + default=None, + description="Optional in-request CSV rows. 
Use this for remote clients.", + ) + use_gpu: bool = False + backend: Optional[str] = Field(default=None, description="Override backend (local|modal)") + fallback_to_local: Optional[bool] = Field(default=None, description="Fallback if modal fails") + results_dir: str = Field(default="../results", description="Directory for final predictions") + repo_root: Optional[str] = Field(default=None, description="Repository root path") + python_executable: str = Field(default="python", description="Python executable to run scripts") + + @root_validator + def _validate_input_source(cls, values: dict[str, Any]) -> dict[str, Any]: + has_file = bool(values.get("input_file")) + has_rows = bool(values.get("input_rows")) + if has_file == has_rows: + raise ValueError("Provide exactly one of `input_file` or `input_rows`.") + return values + + +class PredictResponse(BaseModel): + backend: str + output_file: str + row_count: int + preview_rows: list[dict[str, Any]] + metadata: dict[str, Any] = Field(default_factory=dict) + + +def _resolve_repo_root(repo_root: Optional[str]) -> Path: + if repo_root: + return Path(repo_root).resolve() + env_repo_root = os.environ.get("CATPRED_REPO_ROOT") + if env_repo_root: + return Path(env_repo_root).resolve() + return Path.cwd().resolve() + + +def _write_rows_to_temp_csv(rows: list[dict[str, Any]], repo_root: Path) -> tuple[str, str]: + if len(rows) == 0: + raise ValueError("`input_rows` cannot be empty.") + + tmp_dir = repo_root / ".e2e-tests" / "api_tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp(prefix="api_input_", suffix=".csv", dir=str(tmp_dir)) + os.close(fd) + pd.DataFrame(rows).to_csv(tmp_path, index=False) + return tmp_path, tmp_path + + +def _resolve_input_file(payload: PredictRequest) -> tuple[str, Optional[str]]: + if payload.input_file: + return payload.input_file, None + + repo_root = _resolve_repo_root(payload.repo_root) + return _write_rows_to_temp_csv(payload.input_rows or [], repo_root) + + +def 
_preview_output(output_file: str, preview_limit: int = 5) -> tuple[int, list[dict[str, Any]]]: + df = pd.read_csv(output_file) + return len(df), df.head(preview_limit).to_dict(orient="records") + + +def create_app(router: Optional[InferenceBackendRouter] = None) -> FastAPI: + app = FastAPI(title="CatPred API", version="0.1.0") + backend_router = router or InferenceBackendRouter() + default_fallback = _env_flag("CATPRED_MODAL_FALLBACK_TO_LOCAL", default=False) + + @app.get("/health") + def health() -> dict[str, str]: + return {"status": "ok"} + + @app.get("/ready") + def ready() -> dict[str, Any]: + return backend_router.readiness() + + @app.post("/predict", response_model=PredictResponse) + def predict(payload: PredictRequest) -> PredictResponse: + input_file, temp_file = _resolve_input_file(payload) + try: + request_obj = PredictionRequest( + parameter=payload.parameter.lower(), + input_file=input_file, + checkpoint_dir=payload.checkpoint_dir, + use_gpu=payload.use_gpu, + repo_root=payload.repo_root, + python_executable=payload.python_executable, + ) + + fallback = payload.fallback_to_local + if fallback is None: + fallback = default_fallback + + result = backend_router.predict( + request_obj=request_obj, + results_dir=payload.results_dir, + backend_name=payload.backend, + fallback_to_local=fallback, + ) + + row_count, preview_rows = _preview_output(result.output_file) + return PredictResponse( + backend=result.backend_name, + output_file=result.output_file, + row_count=row_count, + preview_rows=preview_rows, + metadata=result.metadata, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except (FileNotFoundError, InferenceBackendError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except subprocess.CalledProcessError as exc: + raise HTTPException( + status_code=500, + detail=f"Prediction command failed with exit code {exc.returncode}.", + ) from exc + finally: + if temp_file and 
Path(temp_file).exists(): + Path(temp_file).unlink() + + return app + + +app = create_app() diff --git a/catpred/web/run.py b/catpred/web/run.py new file mode 100644 index 0000000..41fc261 --- /dev/null +++ b/catpred/web/run.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import argparse + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Run CatPred web API server.") + parser.add_argument("--host", default="0.0.0.0", help="Host interface to bind") + parser.add_argument("--port", type=int, default=8000, help="Port to bind") + parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development") + parser.add_argument("--workers", type=int, default=1, help="Number of worker processes") + args = parser.parse_args(argv) + + try: + import uvicorn + except ImportError: + print("uvicorn is not installed. Install optional web dependencies with `pip install .[web]`.") + return 1 + + from .app import create_app + + app = create_app() + uvicorn.run(app, host=args.host, port=args.port, reload=args.reload, workers=args.workers) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/setup.cfg b/setup.cfg index d75df80..4557b02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,9 +51,13 @@ console_scripts = catpred_train=catpred.train:catpred_train catpred_predict=catpred.train:catpred_predict catpred_fingerprint=catpred.train:catpred_fingerprint + catpred_web=catpred.web.run:main [options.extras_require] test = pytest>=6.2.2; parameterized>=0.8.1 +web = + fastapi>=0.95,<1.0 + uvicorn>=0.22,<1.0 [options.package_data] catpred = py.typed diff --git a/setup.py b/setup.py index e7d62bf..5502b99 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "catpred_train=catpred.train:catpred_train", "catpred_predict=catpred.train:catpred_predict", "catpred_fingerprint=catpred.train:catpred_fingerprint", + "catpred_web=catpred.web.run:main", ] }, install_requires=[ @@ -46,6 +47,12 @@ 
"scipy<1.11 ; python_version=='3.7'", "scipy>=1.9 ; python_version=='3.8'", ], + extras_require={ + "web": [ + "fastapi>=0.95,<1.0", + "uvicorn>=0.22,<1.0", + ] + }, python_requires=">=3.7", classifiers=[ "Programming Language :: Python :: 3", From 7d1167289485a3688e67603ade082a44a27116cc Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 16:12:41 -0500 Subject: [PATCH 06/34] Phase 3.1: harden API request surface and path guardrails --- README.md | 29 ++++++- catpred/web/app.py | 192 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 195 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 0e9d6d9..c185b4b 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,12 @@ Endpoints: - `GET /ready` โ€” backend configuration/readiness. - `POST /predict` โ€” run inference. -Minimal `POST /predict` example for local inference: +By default, the API is hardened for service use: +- `input_file` requests are disabled (use `input_rows` instead). +- request-time overrides of `repo_root` / `python_executable` are disabled. +- `results_dir` is constrained under `CATPRED_API_RESULTS_ROOT`. + +Minimal `POST /predict` example for local inference using `input_rows`: ```bash curl -X POST http://127.0.0.1:8000/predict \ @@ -130,7 +135,11 @@ curl -X POST http://127.0.0.1:8000/predict \ -d '{ "parameter": "kcat", "checkpoint_dir": "../data/pretrained/reproduce_checkpoints/kcat", - "input_file": "./demo/batch_kcat_pred.csv", + "input_rows": [ + {"SMILES": "CCO", "sequence": "ACDEFGHIK", "pdbpath": "seq_a"}, + {"SMILES": "CCN", "sequence": "LMNPQRSTV", "pdbpath": "seq_b"} + ], + "results_dir": "batch1", "backend": "local" }' ``` @@ -146,6 +155,22 @@ export CATPRED_MODAL_FALLBACK_TO_LOCAL=1 Use `"backend": "modal"` in `/predict` requests to route through Modal. If fallback is enabled (env var above or request field `fallback_to_local`), failed modal requests can automatically reroute to local inference. 
+Optional API environment variables: + +```bash +# Root directories used by API path constraints +export CATPRED_API_INPUT_ROOT="/absolute/path/for/input-csvs" +export CATPRED_API_RESULTS_ROOT="/absolute/path/for/results" + +# Enable only for trusted local workflows (not recommended for public deployments) +export CATPRED_API_ALLOW_INPUT_FILE=1 +export CATPRED_API_ALLOW_UNSAFE_OVERRIDES=1 + +# Request limits +export CATPRED_API_MAX_INPUT_ROWS=1000 +export CATPRED_API_MAX_INPUT_FILE_BYTES=5000000 +``` + ### 🧪 Fine-Tuning On Custom Data You can fine-tune CatPred on your own regression targets using `train.py`. diff --git a/catpred/web/app.py b/catpred/web/app.py index 01cbc0a..ed1081f 100644 --- a/catpred/web/app.py +++ b/catpred/web/app.py @@ -1,8 +1,10 @@ from __future__ import annotations +from dataclasses import dataclass from pathlib import Path import os import subprocess +import sys import tempfile from typing import Any, Optional @@ -30,6 +32,38 @@ def _env_flag(name: str, default: bool = False) -> bool: return value.strip().lower() in {"1", "true", "yes", "y", "on"} +@dataclass(frozen=True) +class APISettings: + repo_root: str + python_executable: str + input_root: str + results_root: str + allow_input_file: bool = False + allow_unsafe_request_overrides: bool = False + max_input_rows: int = 1000 + max_input_file_bytes: int = 5_000_000 + preview_rows: int = 5 + + @classmethod + def from_env(cls) -> "APISettings": + env_repo_root = os.environ.get("CATPRED_REPO_ROOT") + repo_root = str(Path(env_repo_root).resolve()) if env_repo_root else str(Path.cwd().resolve()) + input_root = os.environ.get("CATPRED_API_INPUT_ROOT") + results_root = os.environ.get("CATPRED_API_RESULTS_ROOT") + + return cls( + repo_root=repo_root, + python_executable=os.environ.get("CATPRED_PYTHON_EXECUTABLE", sys.executable or "python"), + input_root=str(Path(input_root).resolve()) if input_root else str((Path(repo_root) / "inputs").resolve()), + 
results_root=str(Path(results_root).resolve()) if results_root else str((Path(repo_root) / "results").resolve()), + allow_input_file=_env_flag("CATPRED_API_ALLOW_INPUT_FILE", default=False), + allow_unsafe_request_overrides=_env_flag("CATPRED_API_ALLOW_UNSAFE_OVERRIDES", default=False), + max_input_rows=max(int(os.environ.get("CATPRED_API_MAX_INPUT_ROWS", "1000")), 1), + max_input_file_bytes=max(int(os.environ.get("CATPRED_API_MAX_INPUT_FILE_BYTES", "5000000")), 1024), + preview_rows=max(int(os.environ.get("CATPRED_API_PREVIEW_ROWS", "5")), 1), + ) + + class PredictRequest(BaseModel): parameter: str = Field(..., description="One of: kcat, km, ki") checkpoint_dir: str = Field(..., description="Path to the checkpoint directory") @@ -41,9 +75,18 @@ class PredictRequest(BaseModel): use_gpu: bool = False backend: Optional[str] = Field(default=None, description="Override backend (local|modal)") fallback_to_local: Optional[bool] = Field(default=None, description="Fallback if modal fails") - results_dir: str = Field(default="../results", description="Directory for final predictions") - repo_root: Optional[str] = Field(default=None, description="Repository root path") - python_executable: str = Field(default="python", description="Python executable to run scripts") + results_dir: str = Field( + default="results", + description="Results subdirectory under CATPRED_API_RESULTS_ROOT.", + ) + repo_root: Optional[str] = Field( + default=None, + description="Unsafe override. Disabled by default; only for trusted local workflows.", + ) + python_executable: Optional[str] = Field( + default=None, + description="Unsafe override. 
Disabled by default; only for trusted local workflows.", + ) @root_validator def _validate_input_source(cls, values: dict[str, Any]) -> dict[str, Any]: @@ -62,20 +105,96 @@ class PredictResponse(BaseModel): metadata: dict[str, Any] = Field(default_factory=dict) -def _resolve_repo_root(repo_root: Optional[str]) -> Path: - if repo_root: +def _is_subpath(path: Path, root: Path) -> bool: + try: + path.relative_to(root) + return True + except ValueError: + return False + + +def _resolve_repo_root(repo_root: Optional[str], settings: APISettings) -> Path: + if repo_root is not None: + if not settings.allow_unsafe_request_overrides: + raise ValueError( + "Request field `repo_root` is disabled. " + "Set CATPRED_API_ALLOW_UNSAFE_OVERRIDES=1 for trusted local use." + ) return Path(repo_root).resolve() - env_repo_root = os.environ.get("CATPRED_REPO_ROOT") - if env_repo_root: - return Path(env_repo_root).resolve() - return Path.cwd().resolve() + return Path(settings.repo_root).resolve() + +def _resolve_python_executable(python_executable: Optional[str], settings: APISettings) -> str: + if python_executable is not None: + if not settings.allow_unsafe_request_overrides: + raise ValueError( + "Request field `python_executable` is disabled. " + "Set CATPRED_API_ALLOW_UNSAFE_OVERRIDES=1 for trusted local use." 
+ ) + return python_executable + return settings.python_executable -def _write_rows_to_temp_csv(rows: list[dict[str, Any]], repo_root: Path) -> tuple[str, str]: + +def _resolve_and_validate_path_under_root(raw_path: str, root: Path, purpose: str) -> Path: + candidate = Path(raw_path) + resolved = candidate.resolve() if candidate.is_absolute() else (root / candidate).resolve() + if not _is_subpath(resolved, root): + raise ValueError(f"{purpose} must stay under configured root: {root}") + return resolved + + +def _resolve_results_dir(raw_results_dir: str, settings: APISettings) -> str: + results_root = Path(settings.results_root).resolve() + results_root.mkdir(parents=True, exist_ok=True) + resolved = _resolve_and_validate_path_under_root( + raw_path=raw_results_dir, + root=results_root, + purpose="results_dir", + ) + resolved.mkdir(parents=True, exist_ok=True) + return str(resolved) + + +def _resolve_input_file_path(input_file: str, settings: APISettings) -> Path: + if not settings.allow_input_file: + raise ValueError( + "Request field `input_file` is disabled. Submit `input_rows` instead, " + "or set CATPRED_API_ALLOW_INPUT_FILE=1 for trusted local use." + ) + + input_root = Path(settings.input_root).resolve() + input_root.mkdir(parents=True, exist_ok=True) + resolved = _resolve_and_validate_path_under_root( + raw_path=input_file, + root=input_root, + purpose="input_file", + ) + + if not resolved.exists(): + raise FileNotFoundError(f'Input CSV not found: "{resolved}"') + if not resolved.is_file(): + raise ValueError(f'Input CSV path is not a file: "{resolved}"') + if resolved.stat().st_size > settings.max_input_file_bytes: + raise ValueError( + f'Input file exceeds CATPRED_API_MAX_INPUT_FILE_BYTES ({settings.max_input_file_bytes}).' 
+ ) + + return resolved + + +def _write_rows_to_temp_csv( + rows: list[dict[str, Any]], + settings: APISettings, + repo_root: Path, +) -> tuple[str, str]: if len(rows) == 0: raise ValueError("`input_rows` cannot be empty.") + if len(rows) > settings.max_input_rows: + raise ValueError( + f"`input_rows` exceeds CATPRED_API_MAX_INPUT_ROWS ({settings.max_input_rows})." + ) - tmp_dir = repo_root / ".e2e-tests" / "api_tmp" + tmp_dir = (repo_root.resolve() / ".e2e-tests" / "api_tmp").resolve() tmp_dir.mkdir(parents=True, exist_ok=True) fd, tmp_path = tempfile.mkstemp(prefix="api_input_", suffix=".csv", dir=str(tmp_dir)) os.close(fd) @@ -83,21 +202,29 @@ def _write_rows_to_temp_csv(rows: list[dict[str, Any]], repo_root: Path) -> tupl return tmp_path, tmp_path -def _resolve_input_file(payload: PredictRequest) -> tuple[str, Optional[str]]: +def _resolve_input_file( + payload: PredictRequest, + settings: APISettings, + repo_root: Path, +) -> tuple[str, Optional[str]]: if payload.input_file: - return payload.input_file, None + resolved = _resolve_input_file_path(payload.input_file, settings) + return str(resolved), None - repo_root = _resolve_repo_root(payload.repo_root) - return _write_rows_to_temp_csv(payload.input_rows or [], repo_root) + return _write_rows_to_temp_csv(payload.input_rows or [], settings=settings, repo_root=repo_root) -def _preview_output(output_file: str, preview_limit: int = 5) -> tuple[int, list[dict[str, Any]]]: +def _preview_output(output_file: str, preview_limit: int) -> tuple[int, list[dict[str, Any]]]: df = pd.read_csv(output_file) return len(df), df.head(preview_limit).to_dict(orient="records") -def create_app(router: Optional[InferenceBackendRouter] = None) -> FastAPI: - app = FastAPI(title="CatPred API", version="0.1.0") +def create_app( + router: Optional[InferenceBackendRouter] = None, + settings: Optional[APISettings] = None, +) -> FastAPI: + api_settings = settings or APISettings.from_env() + app = FastAPI(title="CatPred API", version="0.2.0") 
backend_router = router or InferenceBackendRouter() default_fallback = _env_flag("CATPRED_MODAL_FALLBACK_TO_LOCAL", default=False) @@ -107,19 +234,33 @@ def health() -> dict[str, str]: @app.get("/ready") def ready() -> dict[str, Any]: - return backend_router.readiness() + readiness = backend_router.readiness() + readiness["api"] = { + "allow_input_file": api_settings.allow_input_file, + "allow_unsafe_request_overrides": api_settings.allow_unsafe_request_overrides, + "max_input_rows": api_settings.max_input_rows, + "max_input_file_bytes": api_settings.max_input_file_bytes, + "input_root": str(Path(api_settings.input_root).resolve()), + "results_root": str(Path(api_settings.results_root).resolve()), + } + return readiness @app.post("/predict", response_model=PredictResponse) def predict(payload: PredictRequest) -> PredictResponse: - input_file, temp_file = _resolve_input_file(payload) + temp_file: Optional[str] = None try: + repo_root = _resolve_repo_root(payload.repo_root, api_settings) + python_executable = _resolve_python_executable(payload.python_executable, api_settings) + input_file, temp_file = _resolve_input_file(payload, settings=api_settings, repo_root=repo_root) + safe_results_dir = _resolve_results_dir(payload.results_dir, api_settings) + request_obj = PredictionRequest( parameter=payload.parameter.lower(), input_file=input_file, checkpoint_dir=payload.checkpoint_dir, use_gpu=payload.use_gpu, - repo_root=payload.repo_root, - python_executable=payload.python_executable, + repo_root=str(repo_root), + python_executable=python_executable, ) fallback = payload.fallback_to_local @@ -128,12 +269,15 @@ def predict(payload: PredictRequest) -> PredictResponse: result = backend_router.predict( request_obj=request_obj, - results_dir=payload.results_dir, + results_dir=safe_results_dir, backend_name=payload.backend, fallback_to_local=fallback, ) - row_count, preview_rows = _preview_output(result.output_file) + row_count, preview_rows = _preview_output( + 
result.output_file, + preview_limit=api_settings.preview_rows, + ) return PredictResponse( backend=result.backend_name, output_file=result.output_file, From cd6f22a0f413626916ed58c55ce8ed275cdc28b4 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 16:13:41 -0500 Subject: [PATCH 07/34] Phase 3.2: unify packaging metadata via setup.cfg --- setup.cfg | 8 +++---- setup.py | 72 ++++--------------------------------------------------- 2 files changed, 8 insertions(+), 72 deletions(-) diff --git a/setup.cfg b/setup.cfg index 4557b02..3db702f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,10 +1,10 @@ [metadata] name = catpred version = 0.0.1 -author = -author_email = +author = Veda Sheersh Boorla, Costas D. Maranas +author_email = mailforveda@gmail.com license = MIT -description = +description = A comprehensive framework for deep learning in vitro enzyme kinetic parameters kcat, Km and Ki keywords = protein language model machine learning @@ -21,7 +21,7 @@ classifiers = License :: OSI Approved :: MIT License Operating System :: OS Independent project_urls = - Documentation = + Documentation = https://github.com/maranasgroup/catpred/ Source = https://github.com/maranasgroup/catpred PyPi = Demo = http://tiny.cc/catpred diff --git a/setup.py b/setup.py index 5502b99..b357122 100644 --- a/setup.py +++ b/setup.py @@ -1,70 +1,6 @@ -from setuptools import find_packages, setup +from setuptools import setup -__version__ = "0.0.1" -# Load README -with open("README.md", encoding="utf-8") as f: - long_description = f.read() - -setup( - name="catpred", - author="Veda Sheersh Boorla, Costas D. 
Maranas", - author_email="mailforveda@gmail.com", - description="A comprehensive framework for deep learning in vitro enzyme kinetic parameters kcat, Km and Ki", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/maranasgroup/catpred", - download_url=f"https://github.com/maranasgroup/catpred/v_{__version__}.tar.gz", - project_urls={ - "Documentation": "https://github.com/maranasgroup/catpred/", - "Source": "https://github.com/maranasgroup/catpred", - "PyPi": "", - "Demo": "https://tiny.cc/catpred", - }, - license="MIT", - packages=find_packages(), - package_data={"catpred": ["py.typed"]}, - entry_points={ - "console_scripts": [ - "catpred_train=catpred.train:catpred_train", - "catpred_predict=catpred.train:catpred_predict", - "catpred_fingerprint=catpred.train:catpred_fingerprint", - "catpred_web=catpred.web.run:main", - ] - }, - install_requires=[ - "matplotlib>=3.1.3", - "numpy>=1.18.1", - "pandas>=1.0.3", - "pandas-flavor>=0.2.0", - "scikit-learn>=0.22.2.post1", - "tensorboardX>=2.0", - "sphinx>=3.1.2", - "torch>=1.4.0", - "tqdm>=4.45.0", - "typed-argument-parser>=1.6.1", - "rdkit>=2020.03.1.0", - "scipy<1.11 ; python_version=='3.7'", - "scipy>=1.9 ; python_version=='3.8'", - ], - extras_require={ - "web": [ - "fastapi>=0.95,<1.0", - "uvicorn>=0.22,<1.0", - ] - }, - python_requires=">=3.7", - classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - keywords=[ - "bioinformatics", - "machine learning", - "enzyme function prediction", - "message passing neural network", - ], -) +if __name__ == "__main__": + # Keep setup.py as a thin shim so setup.cfg is the single source of packaging metadata. 
+ setup() From 84f3e424ea9aa2414ce73b9cd1d708ce330b4bdd Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 17:18:53 -0500 Subject: [PATCH 08/34] Phase 3.3: centralize deserialization policy and enforce trusted roots --- README.md | 12 ++ catpred/__init__.py | 1 + catpred/args.py | 8 +- catpred/data/cache_utils.py | 14 +-- catpred/data/utils.py | 20 ++-- catpred/features/utils.py | 6 +- catpred/models/model.py | 13 +-- catpred/security/__init__.py | 17 +++ catpred/security/deserialization.py | 168 ++++++++++++++++++++++++++++ catpred/utils.py | 39 ++++--- catpred/web/app.py | 37 +++++- scripts/baseline_analysis.py | 10 +- 12 files changed, 288 insertions(+), 57 deletions(-) create mode 100644 catpred/security/__init__.py create mode 100644 catpred/security/deserialization.py diff --git a/README.md b/README.md index c185b4b..c0ac109 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,7 @@ Optional API environment variables: # Root directories used by API path constraints export CATPRED_API_INPUT_ROOT="/absolute/path/for/input-csvs" export CATPRED_API_RESULTS_ROOT="/absolute/path/for/results" +export CATPRED_API_CHECKPOINT_ROOT="/absolute/path/for/checkpoints" # Enable only for trusted local workflows (not recommended for public deployments) export CATPRED_API_ALLOW_INPUT_FILE=1 @@ -171,6 +172,17 @@ export CATPRED_API_MAX_INPUT_ROWS=1000 export CATPRED_API_MAX_INPUT_FILE_BYTES=5000000 ``` +Deserialization hardening controls: + +```bash +# Trusted roots used by secure loaders (colon-separated list on Unix) +export CATPRED_TRUSTED_DESERIALIZATION_ROOTS="/srv/catpred:/srv/catpred-data" + +# Backward-compatible default is enabled (1). Set to 0 to block unsafe pickle-based loading. +# Use 0 only after validating your artifacts are safe-load compatible. +export CATPRED_ALLOW_UNSAFE_DESERIALIZATION=1 +``` + ### 🧪 Fine-Tuning On Custom Data You can fine-tune CatPred on your own regression targets using `train.py`. 
diff --git a/catpred/__init__.py b/catpred/__init__.py index 92cfc40..75b676b 100644 --- a/catpred/__init__.py +++ b/catpred/__init__.py @@ -13,6 +13,7 @@ "models", "nn_utils", "rdkit", + "security", "train", "uncertainty", "utils", diff --git a/catpred/args.py b/catpred/args.py index d52b198..99cc1b3 100644 --- a/catpred/args.py +++ b/catpred/args.py @@ -1,7 +1,6 @@ import json import os from tempfile import TemporaryDirectory -import pickle from typing import List, Optional from typing_extensions import Literal from packaging import version @@ -15,6 +14,7 @@ import catpred.data.utils from catpred.data import set_cache_mol, empty_cache from catpred.features import get_available_features_generators +from catpred.security import load_index_artifact Metric = Literal['auc', 'prc-auc', 'rmse', 'mae', 'mse', 'r2', 'accuracy', 'cross_entropy', 'binary_cross_entropy', 'sid', 'wasserstein', 'f1', 'mcc', 'bounded_rmse', 'bounded_mae', 'bounded_mse'] @@ -815,8 +815,10 @@ def process_args(self) -> None: raise ValueError('When using crossval or index_predetermined split type, must provide crossval_index_file.') if self.split_type in ['crossval', 'index_predetermined']: - with open(self.crossval_index_file, 'rb') as rf: - self._crossval_index_sets = pickle.load(rf) + self._crossval_index_sets = load_index_artifact( + self.crossval_index_file, + purpose="cross-validation index file", + ) self.num_folds = len(self.crossval_index_sets) self.seed = 0 diff --git a/catpred/data/cache_utils.py b/catpred/data/cache_utils.py index 90cac3a..a72812d 100644 --- a/catpred/data/cache_utils.py +++ b/catpred/data/cache_utils.py @@ -4,6 +4,7 @@ import hashlib from functools import wraps from pathlib import Path +from catpred.security import load_torch_artifact def exists(val): return val is not None @@ -27,13 +28,6 @@ def md5_hash_fn(s): encoded = s.encode('utf-8') return hashlib.md5(encoded).hexdigest() - -def _torch_load_compat(path): - try: - return torch.load(path, weights_only=False) - 
except TypeError: - return torch.load(path) - # run once function GLOBAL_RUN_RECORDS = dict() @@ -99,7 +93,11 @@ def inner(t, *args, __cache_key = None, **kwargs): if entry_path.exists(): log(f'cache hit: fetching {t} from {str(entry_path)}') - return _torch_load_compat(str(entry_path)) + return load_torch_artifact( + str(entry_path), + purpose="esm cache entry", + roots=[CACHE_PATH], + ) out = fn(t, *args, **kwargs) diff --git a/catpred/data/utils.py b/catpred/data/utils.py index 34f03be..ad23098 100644 --- a/catpred/data/utils.py +++ b/catpred/data/utils.py @@ -3,7 +3,6 @@ import csv import ctypes from logging import Logger -import pickle from random import Random from typing import List, Set, Tuple, Union import os @@ -22,6 +21,7 @@ from catpred.args import PredictArgs, TrainArgs from catpred.features import load_features, load_valid_atom_or_bond_features, is_mol from catpred.rdkit import make_mol +from catpred.security import load_index_artifact, load_pickle_artifact # Increase maximum size of field in the csv processing for the current architecture csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2)) @@ -791,8 +791,12 @@ def split_data(data: MoleculeDataset, for split in range(3): split_indices = [] for index in index_set[split]: - with open(os.path.join(args.crossval_index_dir, f'{index}.pkl'), 'rb') as rf: - split_indices.extend(pickle.load(rf)) + split_indices.extend( + load_index_artifact( + os.path.join(args.crossval_index_dir, f"{index}.pkl"), + purpose="cross-validation fold index file", + ) + ) data_split.append([data[i] for i in split_indices]) train, val, test = tuple(data_split) return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test) @@ -841,12 +845,10 @@ def split_data(data: MoleculeDataset, if test_fold_index is None: raise ValueError('arg "test_fold_index" can not be None!') - try: - with open(folds_file, 'rb') as f: - all_fold_indices = pickle.load(f) - except UnicodeDecodeError: - with open(folds_file, 'rb') as f: - 
all_fold_indices = pickle.load(f, encoding='latin1') # in case we're loading indices from python2 + all_fold_indices = load_pickle_artifact( + folds_file, + purpose="predetermined folds file", + ) log_scaffold_stats(data, all_fold_indices, logger=logger) diff --git a/catpred/features/utils.py b/catpred/features/utils.py index 617af86..c0e24dd 100644 --- a/catpred/features/utils.py +++ b/catpred/features/utils.py @@ -1,11 +1,11 @@ import csv import os -import pickle from typing import List import numpy as np import pandas as pd from rdkit.Chem import PandasTools +from catpred.security import load_pickle_artifact def save_features(path: str, features: List[np.ndarray]) -> None: @@ -49,8 +49,8 @@ def load_features(path: str) -> np.ndarray: next(reader) # skip header features = np.array([[float(value) for value in row] for row in reader]) elif extension in ['.pkl', '.pckl', '.pickle']: - with open(path, 'rb') as f: - features = np.array([np.squeeze(np.array(feat.todense())) for feat in pickle.load(f)]) + payload = load_pickle_artifact(path, purpose="feature matrix pickle") + features = np.array([np.squeeze(np.array(feat.todense())) for feat in payload]) else: raise ValueError(f'Features path extension {extension} not supported.') diff --git a/catpred/models/model.py b/catpred/models/model.py index 5b50099..16e5c6f 100644 --- a/catpred/models/model.py +++ b/catpred/models/model.py @@ -13,13 +13,7 @@ from catpred.args import TrainArgs from catpred.features import BatchMolGraph from catpred.nn_utils import initialize_weights - - -def _torch_load_compat(path): - try: - return torch.load(path, weights_only=False) - except TypeError: - return torch.load(path) +from catpred.security import load_torch_artifact def exists(val): @@ -126,7 +120,10 @@ def create_protein_model(self, args: TrainArgs) -> None: self.seq_embedder = nn.Embedding(21, args.seq_embed_dim, padding_idx=20) #last index is for padding if self.args.add_pretrained_egnn_feats: - self.pretrained_egnn_feats_dict = 
_torch_load_compat(self.args.pretrained_egnn_feats_path) + self.pretrained_egnn_feats_dict = load_torch_artifact( + self.args.pretrained_egnn_feats_path, + purpose="pretrained EGNN features", + ) x = list(self.pretrained_egnn_feats_dict.values()) self.pretrained_egnn_feats_avg = torch.stack(x).mean(dim=0) diff --git a/catpred/security/__init__.py b/catpred/security/__init__.py new file mode 100644 index 0000000..8c5f371 --- /dev/null +++ b/catpred/security/__init__.py @@ -0,0 +1,17 @@ +from .deserialization import ( + DeserializationSecurityError, + ensure_trusted_path, + load_index_artifact, + load_pickle_artifact, + load_torch_artifact, + unsafe_deserialization_enabled, +) + +__all__ = [ + "DeserializationSecurityError", + "ensure_trusted_path", + "load_index_artifact", + "load_pickle_artifact", + "load_torch_artifact", + "unsafe_deserialization_enabled", +] diff --git a/catpred/security/deserialization.py b/catpred/security/deserialization.py new file mode 100644 index 0000000..a3fc8f9 --- /dev/null +++ b/catpred/security/deserialization.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import gzip +import json +import os +import pickle +from pathlib import Path +from typing import Any, Iterable + + +class DeserializationSecurityError(RuntimeError): + """Raised when deserialization is blocked by policy.""" + + +_ALLOW_UNSAFE_ENV = "CATPRED_ALLOW_UNSAFE_DESERIALIZATION" +_TRUSTED_ROOTS_ENV = "CATPRED_TRUSTED_DESERIALIZATION_ROOTS" + + +def _env_flag(name: str, default: bool = False) -> bool: + raw = os.environ.get(name) + if raw is None: + return default + return raw.strip().lower() in {"1", "true", "yes", "y", "on"} + + +def _dedupe_paths(paths: list[Path]) -> list[Path]: + unique: list[Path] = [] + seen = set() + for path in paths: + resolved = path.resolve() + key = str(resolved) + if key not in seen: + unique.append(resolved) + seen.add(key) + return unique + + +def _default_trusted_roots() -> list[Path]: + raw = os.environ.get(_TRUSTED_ROOTS_ENV) + 
if raw: + candidates = [Path(item) for item in raw.split(os.pathsep) if item.strip()] + else: + cwd = Path.cwd().resolve() + candidates = [cwd, cwd.parent] + return _dedupe_paths(candidates) + + +def trusted_roots(extra_roots: Iterable[str | Path] | None = None) -> list[Path]: + roots = _default_trusted_roots() + if extra_roots: + roots.extend(Path(path) for path in extra_roots) + return _dedupe_paths(roots) + + +def is_trusted_path(path: str | Path, roots: Iterable[str | Path] | None = None) -> bool: + resolved = Path(path).resolve() + candidate_roots = trusted_roots(roots) + for root in candidate_roots: + try: + resolved.relative_to(root) + return True + except ValueError: + continue + return False + + +def ensure_trusted_path( + path: str | Path, + *, + purpose: str, + roots: Iterable[str | Path] | None = None, +) -> Path: + resolved = Path(path).resolve() + candidate_roots = trusted_roots(roots) + if is_trusted_path(resolved, candidate_roots): + return resolved + roots_display = ", ".join(str(item) for item in candidate_roots) + raise DeserializationSecurityError( + f'Refusing to load untrusted {purpose} from "{resolved}". ' + f"Allowed roots: {roots_display}. " + f"Use {_TRUSTED_ROOTS_ENV} to expand trusted roots." + ) + + +def unsafe_deserialization_enabled(default: bool = True) -> bool: + return _env_flag(_ALLOW_UNSAFE_ENV, default=default) + + +def load_pickle_artifact( + path: str | Path, + *, + purpose: str, + roots: Iterable[str | Path] | None = None, + allow_unsafe: bool | None = None, + encoding: str | None = None, +) -> Any: + resolved = ensure_trusted_path(path, purpose=purpose, roots=roots) + unsafe = unsafe_deserialization_enabled() if allow_unsafe is None else allow_unsafe + if not unsafe: + raise DeserializationSecurityError( + f"Pickle deserialization is disabled for {purpose}. " + f"Set {_ALLOW_UNSAFE_ENV}=1 only for trusted artifacts." 
+ ) + + with resolved.open("rb") as handle: + if encoding is None: + try: + return pickle.load(handle) + except UnicodeDecodeError: + handle.seek(0) + return pickle.load(handle, encoding="latin1") + return pickle.load(handle, encoding=encoding) + + +def load_index_artifact( + path: str | Path, + *, + purpose: str, + roots: Iterable[str | Path] | None = None, + allow_unsafe: bool | None = None, +) -> Any: + resolved = ensure_trusted_path(path, purpose=purpose, roots=roots) + suffixes = tuple(s.lower() for s in resolved.suffixes) + if suffixes and suffixes[-1] == ".json": + with resolved.open("rt", encoding="utf-8") as handle: + return json.load(handle) + if suffixes[-2:] == (".json", ".gz"): + with gzip.open(resolved, "rt", encoding="utf-8") as handle: + return json.load(handle) + return load_pickle_artifact( + resolved, + purpose=purpose, + roots=roots, + allow_unsafe=allow_unsafe, + ) + + +def load_torch_artifact( + path: str | Path, + *, + purpose: str, + map_location=None, + roots: Iterable[str | Path] | None = None, + allow_unsafe: bool | None = None, +) -> Any: + resolved = ensure_trusted_path(path, purpose=purpose, roots=roots) + unsafe = unsafe_deserialization_enabled() if allow_unsafe is None else allow_unsafe + + import torch + + if unsafe: + try: + return torch.load(str(resolved), map_location=map_location, weights_only=False) + except TypeError: + return torch.load(str(resolved), map_location=map_location) + + try: + return torch.load(str(resolved), map_location=map_location, weights_only=True) + except TypeError as exc: + raise DeserializationSecurityError( + "Safe torch deserialization requires a torch version that supports weights_only loading. " + f"Set {_ALLOW_UNSAFE_ENV}=1 for trusted legacy checkpoints." + ) from exc + except Exception as exc: + raise DeserializationSecurityError( + f"Safe torch deserialization rejected {purpose}. " + f"If this checkpoint is trusted, set {_ALLOW_UNSAFE_ENV}=1." 
+ ) from exc diff --git a/catpred/utils.py b/catpred/utils.py index 3f70e62..102aeda 100644 --- a/catpred/utils.py +++ b/catpred/utils.py @@ -23,18 +23,7 @@ from catpred.models import MoleculeModel from catpred.nn_utils import NoamLR from catpred.models.ffn import MultiReadout - - -def _torch_load_compat(path, map_location=None): - """ - Explicitly disable weights-only loading for backward compatibility with - CatPred checkpoints that store non-tensor objects (e.g., argparse.Namespace). - """ - try: - return torch.load(path, map_location=map_location, weights_only=False) - except TypeError: - # For older torch versions that do not support weights_only. - return torch.load(path, map_location=map_location) +from catpred.security import load_torch_artifact def makedirs(path: str, isfile: bool = False) -> None: @@ -122,7 +111,11 @@ def load_checkpoint( debug = info = print # Load model and args - state = _torch_load_compat(path, map_location=lambda storage, loc: storage) + state = load_torch_artifact( + path, + purpose="model checkpoint", + map_location=lambda storage, loc: storage, + ) args = TrainArgs() args.from_dict(vars(state["args"]), skip_unsettable=True) if not pretrained_egnn_feats_path is None: @@ -230,7 +223,11 @@ def load_frzn_model( """ debug = logger.debug if logger is not None else print - loaded_mpnn_model = _torch_load_compat(path, map_location=lambda storage, loc: storage) + loaded_mpnn_model = load_torch_artifact( + path, + purpose="frozen model checkpoint", + map_location=lambda storage, loc: storage, + ) loaded_state_dict = loaded_mpnn_model["state_dict"] loaded_args = loaded_mpnn_model["args"] @@ -452,7 +449,11 @@ def load_scalers( :return: A tuple with the data :class:`~catpred.data.scaler.StandardScaler` and features :class:`~catpred.data.scaler.StandardScaler`. 
""" - state = _torch_load_compat(path, map_location=lambda storage, loc: storage) + state = load_torch_artifact( + path, + purpose="scaler checkpoint", + map_location=lambda storage, loc: storage, + ) if state["data_scaler"] is not None: scaler = StandardScaler(state["data_scaler"]["means"], state["data_scaler"]["stds"]) @@ -507,7 +508,13 @@ def load_args(path: str) -> TrainArgs: """ args = TrainArgs() args.from_dict( - vars(_torch_load_compat(path, map_location=lambda storage, loc: storage)["args"]), + vars( + load_torch_artifact( + path, + purpose="args checkpoint", + map_location=lambda storage, loc: storage, + )["args"] + ), skip_unsettable=True, ) diff --git a/catpred/web/app.py b/catpred/web/app.py index ed1081f..62c38d3 100644 --- a/catpred/web/app.py +++ b/catpred/web/app.py @@ -38,6 +38,7 @@ class APISettings: python_executable: str input_root: str results_root: str + checkpoint_root: str allow_input_file: bool = False allow_unsafe_request_overrides: bool = False max_input_rows: int = 1000 @@ -50,12 +51,18 @@ def from_env(cls) -> "APISettings": repo_root = str(Path(env_repo_root).resolve()) if env_repo_root else str(Path.cwd().resolve()) input_root = os.environ.get("CATPRED_API_INPUT_ROOT") results_root = os.environ.get("CATPRED_API_RESULTS_ROOT") + checkpoint_root = os.environ.get("CATPRED_API_CHECKPOINT_ROOT") return cls( repo_root=repo_root, python_executable=os.environ.get("CATPRED_PYTHON_EXECUTABLE", sys.executable or "python"), input_root=str(Path(input_root).resolve()) if input_root else str((Path(repo_root) / "inputs").resolve()), results_root=str(Path(results_root).resolve()) if results_root else str((Path(repo_root) / "results").resolve()), + checkpoint_root=( + str(Path(checkpoint_root).resolve()) + if checkpoint_root + else str((Path(repo_root) / "checkpoints").resolve()) + ), allow_input_file=_env_flag("CATPRED_API_ALLOW_INPUT_FILE", default=False), allow_unsafe_request_overrides=_env_flag("CATPRED_API_ALLOW_UNSAFE_OVERRIDES", default=False), 
max_input_rows=max(int(os.environ.get("CATPRED_API_MAX_INPUT_ROWS", "1000")), 1), @@ -155,6 +162,21 @@ def _resolve_results_dir(raw_results_dir: str, settings: APISettings) -> str: return str(resolved) +def _resolve_checkpoint_dir(raw_checkpoint_dir: str, settings: APISettings) -> str: + checkpoint_root = Path(settings.checkpoint_root).resolve() + checkpoint_root.mkdir(parents=True, exist_ok=True) + resolved = _resolve_and_validate_path_under_root( + raw_path=raw_checkpoint_dir, + root=checkpoint_root, + purpose="checkpoint_dir", + ) + if not resolved.exists(): + raise FileNotFoundError(f'Checkpoint directory not found: "{resolved}"') + if not resolved.is_dir(): + raise ValueError(f'checkpoint_dir must be a directory: "{resolved}"') + return str(resolved) + + def _resolve_input_file_path(input_file: str, settings: APISettings) -> Path: if not settings.allow_input_file: raise ValueError( @@ -242,6 +264,7 @@ def ready() -> dict[str, Any]: "max_input_file_bytes": api_settings.max_input_file_bytes, "input_root": str(Path(api_settings.input_root).resolve()), "results_root": str(Path(api_settings.results_root).resolve()), + "checkpoint_root": str(Path(api_settings.checkpoint_root).resolve()), } return readiness @@ -253,20 +276,22 @@ def predict(payload: PredictRequest) -> PredictResponse: python_executable = _resolve_python_executable(payload.python_executable, api_settings) input_file, temp_file = _resolve_input_file(payload, settings=api_settings, repo_root=repo_root) safe_results_dir = _resolve_results_dir(payload.results_dir, api_settings) - + fallback = payload.fallback_to_local + if fallback is None: + fallback = default_fallback + selected_backend = (payload.backend or backend_router.settings.default_backend).lower() + checkpoint_dir = payload.checkpoint_dir + if selected_backend == "local" or fallback: + checkpoint_dir = _resolve_checkpoint_dir(payload.checkpoint_dir, api_settings) request_obj = PredictionRequest( parameter=payload.parameter.lower(), 
input_file=input_file, - checkpoint_dir=payload.checkpoint_dir, + checkpoint_dir=checkpoint_dir, use_gpu=payload.use_gpu, repo_root=str(repo_root), python_executable=python_executable, ) - fallback = payload.fallback_to_local - if fallback is None: - fallback = default_fallback - result = backend_router.predict( request_obj=request_obj, results_dir=safe_results_dir, diff --git a/scripts/baseline_analysis.py b/scripts/baseline_analysis.py index e2076a1..e87507d 100644 --- a/scripts/baseline_analysis.py +++ b/scripts/baseline_analysis.py @@ -1,6 +1,5 @@ import os import sys -import pickle import pandas as pd from tqdm import tqdm from concurrent.futures import ProcessPoolExecutor @@ -8,14 +7,17 @@ from sklearn.metrics import r2_score, mean_absolute_error import ipdb import csv +from catpred.security import load_pickle_artifact OUTPUT_DIR="../results/reproduce_results" DATA_DIR = "../data/external/Baseline/" def load_identity_data(parameter): """Load pre-calculated identity dictionary and mappings.""" - with open(f'{DATA_DIR}/{parameter}/{parameter}_test_train_identities_updated.pkl', 'rb') as f: - data = pickle.load(f) + data = load_pickle_artifact( + f"{DATA_DIR}/{parameter}/{parameter}_test_train_identities_updated.pkl", + purpose="baseline identity pickle", + ) train_seqs_dict = {val: key for key, val in data['train_seq_mapping'].items()} test_seqs_dict = {val: key for key, val in data['test_seq_mapping'].items()} return data, train_seqs_dict, test_seqs_dict @@ -142,4 +144,4 @@ def main(PARAMETER, OUTPUT_FILE, recalculate=False): if sys.argv[1]=='recalculate': main(PARAMETER, OUTPUT_FILE, True) else: - main(PARAMETER, OUTPUT_FILE, False) \ No newline at end of file + main(PARAMETER, OUTPUT_FILE, False) From 919a8635ae6f6fb2074ec8ec4e57232ac93ebda3 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 17:51:41 -0500 Subject: [PATCH 09/34] docs: align API checkpoint_dir guidance with current guardrails --- README.md | 4 +++- 1 file changed, 3 
insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c0ac109..baec9cd 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ By default, the API is hardened for service use: - `input_file` requests are disabled (use `input_rows` instead). - request-time overrides of `repo_root` / `python_executable` are disabled. - `results_dir` is constrained under `CATPRED_API_RESULTS_ROOT`. +- for local backend (and modal requests with fallback enabled), `checkpoint_dir` must resolve under `CATPRED_API_CHECKPOINT_ROOT`. Minimal `POST /predict` example for local inference using `input_rows`: @@ -134,7 +135,7 @@ curl -X POST http://127.0.0.1:8000/predict \ -H "Content-Type: application/json" \ -d '{ "parameter": "kcat", - "checkpoint_dir": "../data/pretrained/reproduce_checkpoints/kcat", + "checkpoint_dir": "kcat", "input_rows": [ {"SMILES": "CCO", "sequence": "ACDEFGHIK", "pdbpath": "seq_a"}, {"SMILES": "CCN", "sequence": "LMNPQRSTV", "pdbpath": "seq_b"} @@ -154,6 +155,7 @@ export CATPRED_MODAL_FALLBACK_TO_LOCAL=1 ``` Use `"backend": "modal"` in `/predict` requests to route through Modal. If fallback is enabled (env var above or request field `fallback_to_local`), failed modal requests can automatically reroute to local inference. +For local backend requests, place local checkpoints under `CATPRED_API_CHECKPOINT_ROOT` and pass a path relative to that root (for example, `"checkpoint_dir": "kcat"`). 
Optional API environment variables: From 135b9e4e97ac1c9ea3162c25b89e7f046f1a6229 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 23:43:04 -0500 Subject: [PATCH 10/34] feat: ship web UI refresh and vercel deployment support --- .vercelignore | 8 + README.md | 22 + api/index.py | 21 + catpred/data/utils.py | 8 +- catpred/features/features_generators.py | 1 - catpred/models/model.py | 7 +- catpred/web/app.py | 103 ++- catpred/web/static/catpred.css | 1099 +++++++++++++++++++++++ catpred/web/static/catpred.js | 730 +++++++++++++++ catpred/web/static/index.html | 214 +++++ requirements.txt | 3 + setup.cfg | 6 +- vercel.json | 21 + 13 files changed, 2227 insertions(+), 16 deletions(-) create mode 100644 .vercelignore create mode 100644 api/index.py create mode 100644 catpred/web/static/catpred.css create mode 100644 catpred/web/static/catpred.js create mode 100644 catpred/web/static/index.html create mode 100644 requirements.txt create mode 100644 vercel.json diff --git a/.vercelignore b/.vercelignore new file mode 100644 index 0000000..e3a5762 --- /dev/null +++ b/.vercelignore @@ -0,0 +1,8 @@ +.venv/ +.e2e-assets/ +.e2e-tests/ +results/ +output/ +checkpoints/ +external/ +*.ipynb diff --git a/README.md b/README.md index baec9cd..00c08c2 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ - [Installation](#installing) - [Prediction](#predict) - [Web API (Optional)](#web-api-optional) + - [Vercel Deployment (Optional)](#vercel-deployment-optional) - [Reproducibility](#reproduce) - [Acknowledgements](#acknw) - [License](#license) @@ -185,6 +186,27 @@ export CATPRED_TRUSTED_DESERIALIZATION_ROOTS="/srv/catpred:/srv/catpred-data" export CATPRED_ALLOW_UNSAFE_DESERIALIZATION=1 ``` +### ▲ Vercel Deployment (Optional) + +This repository includes a Vercel-ready ASGI entrypoint at `api/index.py` and a `vercel.json` route config. + +1. Push this repository to GitHub. +2. In Vercel, create a new project from that repo. +3. 
Set Environment Variables in Vercel Project Settings: + +```bash +# Use remote inference backend in serverless deployments +CATPRED_DEFAULT_BACKEND=modal +CATPRED_MODAL_ENDPOINT=https:// +CATPRED_MODAL_TOKEN= +CATPRED_MODAL_FALLBACK_TO_LOCAL=0 +``` + +Notes: +- Serverless filesystems are ephemeral/read-only except `/tmp`; this app auto-uses `/tmp/catpred` on Vercel. +- Local checkpoint-based inference is not recommended on Vercel serverless due runtime/dependency limits. +- If `CATPRED_MODAL_ENDPOINT` is not configured, the UI still loads but prediction requests will be limited by backend readiness. + ### 🧪 Fine-Tuning On Custom Data You can fine-tune CatPred on your own regression targets using `train.py`. diff --git a/api/index.py b/api/index.py new file mode 100644 index 0000000..962060d --- /dev/null +++ b/api/index.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from pathlib import Path +import os +import sys + + +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +if os.environ.get("VERCEL"): + os.environ.setdefault("CATPRED_API_RUNTIME_ROOT", "/tmp/catpred") + os.environ.setdefault("CATPRED_MODAL_FALLBACK_TO_LOCAL", "0") + if os.environ.get("CATPRED_MODAL_ENDPOINT"): + os.environ.setdefault("CATPRED_DEFAULT_BACKEND", "modal") + +from catpred.web.app import app + + +__all__ = ["app"] diff --git a/catpred/data/utils.py b/catpred/data/utils.py index ad23098..02f6d62 100644 --- a/catpred/data/utils.py +++ b/catpred/data/utils.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from collections import OrderedDict, defaultdict import sys import csv import ctypes from logging import Logger from random import Random -from typing import List, Set, Tuple, Union +from typing import List, Set, Tuple, Union, TYPE_CHECKING import os import json import torch @@ -18,11 +20,13 @@ from .esm_utils import get_protein_embedder, get_coords from .data import MoleculeDatapoint, MoleculeDataset, 
make_mols from .scaffold import log_scaffold_stats, scaffold_split -from catpred.args import PredictArgs, TrainArgs from catpred.features import load_features, load_valid_atom_or_bond_features, is_mol from catpred.rdkit import make_mol from catpred.security import load_index_artifact, load_pickle_artifact +if TYPE_CHECKING: + from catpred.args import PredictArgs, TrainArgs + # Increase maximum size of field in the csv processing for the current architecture csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2)) diff --git a/catpred/features/features_generators.py b/catpred/features/features_generators.py index 40226d2..26ec91f 100644 --- a/catpred/features/features_generators.py +++ b/catpred/features/features_generators.py @@ -85,7 +85,6 @@ def morgan_binary_features_generator(mol: Molecule, # return features -import ipdb @register_features_generator('morgan_diff_fp') def morgan_difference_features_generator(rxn: Reaction) -> np.ndarray: """ diff --git a/catpred/models/model.py b/catpred/models/model.py index 16e5c6f..916cd81 100644 --- a/catpred/models/model.py +++ b/catpred/models/model.py @@ -127,8 +127,11 @@ def create_protein_model(self, args: TrainArgs) -> None: x = list(self.pretrained_egnn_feats_dict.values()) self.pretrained_egnn_feats_avg = torch.stack(x).mean(dim=0) - # For rotary positional embeddings - self.rotary_embedder = RotaryEmbedding(dim=args.seq_embed_dim//4) + # Rotary embeddings require an even feature dimension. 
+ rotary_dim = max(2, args.seq_embed_dim // 4) + if rotary_dim % 2 != 0: + rotary_dim -= 1 + self.rotary_embedder = RotaryEmbedding(dim=rotary_dim) # For self-attention self.multihead_attn = nn.MultiheadAttention(args.seq_embed_dim, diff --git a/catpred/web/app.py b/catpred/web/app.py index 62c38d3..8158f24 100644 --- a/catpred/web/app.py +++ b/catpred/web/app.py @@ -12,6 +12,8 @@ try: from fastapi import FastAPI, HTTPException + from fastapi.responses import FileResponse + from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field, root_validator except ImportError as exc: # pragma: no cover - import guard for optional dependency raise ImportError( @@ -32,12 +34,56 @@ def _env_flag(name: str, default: bool = False) -> bool: return value.strip().lower() in {"1", "true", "yes", "y", "on"} +_SUPPORTED_PARAMETERS = ("kcat", "km", "ki") + + +def _contains_model_checkpoints(path: Path) -> bool: + if not path.exists() or not path.is_dir(): + return False + return any(path.rglob("model.pt")) + + +def _discover_default_checkpoint_root(repo_root: Path) -> Path: + default_root = (repo_root / "checkpoints").resolve() + production_root = (repo_root / ".e2e-assets" / "pretrained" / "production").resolve() + candidates = [default_root, production_root] + + best_root: Optional[Path] = None + best_score = -1 + for candidate in candidates: + if not candidate.exists() or not candidate.is_dir(): + continue + score = sum(int((candidate / parameter).is_dir()) for parameter in _SUPPORTED_PARAMETERS) + if score > best_score: + best_score = score + best_root = candidate + + if best_root and best_score > 0: + return best_root + + for candidate in candidates: + if _contains_model_checkpoints(candidate): + return candidate + + return default_root + + +def _discover_available_checkpoints(checkpoint_root: Path) -> dict[str, str]: + available: dict[str, str] = {} + for parameter in _SUPPORTED_PARAMETERS: + param_dir = (checkpoint_root / parameter).resolve() + if 
param_dir.is_dir() and _contains_model_checkpoints(param_dir): + available[parameter] = parameter + return available + + @dataclass(frozen=True) class APISettings: repo_root: str python_executable: str input_root: str results_root: str + temp_root: str checkpoint_root: str allow_input_file: bool = False allow_unsafe_request_overrides: bool = False @@ -49,19 +95,41 @@ class APISettings: def from_env(cls) -> "APISettings": env_repo_root = os.environ.get("CATPRED_REPO_ROOT") repo_root = str(Path(env_repo_root).resolve()) if env_repo_root else str(Path.cwd().resolve()) + repo_root_path = Path(repo_root).resolve() + env_runtime_root = os.environ.get("CATPRED_API_RUNTIME_ROOT") + if env_runtime_root: + default_runtime_root = Path(env_runtime_root).resolve() + elif os.environ.get("VERCEL"): + default_runtime_root = Path("/tmp/catpred").resolve() + else: + default_runtime_root = repo_root_path input_root = os.environ.get("CATPRED_API_INPUT_ROOT") results_root = os.environ.get("CATPRED_API_RESULTS_ROOT") + temp_root = os.environ.get("CATPRED_API_TEMP_ROOT") checkpoint_root = os.environ.get("CATPRED_API_CHECKPOINT_ROOT") return cls( repo_root=repo_root, python_executable=os.environ.get("CATPRED_PYTHON_EXECUTABLE", sys.executable or "python"), - input_root=str(Path(input_root).resolve()) if input_root else str((Path(repo_root) / "inputs").resolve()), - results_root=str(Path(results_root).resolve()) if results_root else str((Path(repo_root) / "results").resolve()), + input_root=( + str(Path(input_root).resolve()) + if input_root + else str((default_runtime_root / "inputs").resolve()) + ), + results_root=( + str(Path(results_root).resolve()) + if results_root + else str((default_runtime_root / "results").resolve()) + ), + temp_root=( + str(Path(temp_root).resolve()) + if temp_root + else str((default_runtime_root / "tmp").resolve()) + ), checkpoint_root=( str(Path(checkpoint_root).resolve()) if checkpoint_root - else str((Path(repo_root) / "checkpoints").resolve()) + else 
str(_discover_default_checkpoint_root(repo_root_path)) ), allow_input_file=_env_flag("CATPRED_API_ALLOW_INPUT_FILE", default=False), allow_unsafe_request_overrides=_env_flag("CATPRED_API_ALLOW_UNSAFE_OVERRIDES", default=False), @@ -164,7 +232,6 @@ def _resolve_results_dir(raw_results_dir: str, settings: APISettings) -> str: def _resolve_checkpoint_dir(raw_checkpoint_dir: str, settings: APISettings) -> str: checkpoint_root = Path(settings.checkpoint_root).resolve() - checkpoint_root.mkdir(parents=True, exist_ok=True) resolved = _resolve_and_validate_path_under_root( raw_path=raw_checkpoint_dir, root=checkpoint_root, @@ -207,7 +274,6 @@ def _resolve_input_file_path(input_file: str, settings: APISettings) -> Path: def _write_rows_to_temp_csv( rows: list[dict[str, Any]], settings: APISettings, - repo_root: Path, ) -> tuple[str, str]: if len(rows) == 0: raise ValueError("`input_rows` cannot be empty.") @@ -216,7 +282,7 @@ def _write_rows_to_temp_csv( f"`input_rows` exceeds CATPRED_API_MAX_INPUT_ROWS ({settings.max_input_rows})." 
) - tmp_dir = (repo_root.resolve() / ".e2e-tests" / "api_tmp").resolve() + tmp_dir = Path(settings.temp_root).resolve() tmp_dir.mkdir(parents=True, exist_ok=True) fd, tmp_path = tempfile.mkstemp(prefix="api_input_", suffix=".csv", dir=str(tmp_dir)) os.close(fd) @@ -227,13 +293,12 @@ def _write_rows_to_temp_csv( def _resolve_input_file( payload: PredictRequest, settings: APISettings, - repo_root: Path, ) -> tuple[str, Optional[str]]: if payload.input_file: resolved = _resolve_input_file_path(payload.input_file, settings) return str(resolved), None - return _write_rows_to_temp_csv(payload.input_rows or [], settings=settings, repo_root=repo_root) + return _write_rows_to_temp_csv(payload.input_rows or [], settings=settings) def _preview_output(output_file: str, preview_limit: int) -> tuple[int, list[dict[str, Any]]]: @@ -249,6 +314,17 @@ def create_app( app = FastAPI(title="CatPred API", version="0.2.0") backend_router = router or InferenceBackendRouter() default_fallback = _env_flag("CATPRED_MODAL_FALLBACK_TO_LOCAL", default=False) + static_root = (Path(__file__).resolve().parent / "static").resolve() + + if static_root.exists(): + app.mount("/static", StaticFiles(directory=str(static_root)), name="static") + + @app.get("/", include_in_schema=False) + def root() -> FileResponse: + index_path = static_root / "index.html" + if not index_path.exists(): + raise HTTPException(status_code=404, detail="Landing page not found.") + return FileResponse(index_path) @app.get("/health") def health() -> dict[str, str]: @@ -257,6 +333,8 @@ def health() -> dict[str, str]: @app.get("/ready") def ready() -> dict[str, Any]: readiness = backend_router.readiness() + checkpoint_root = Path(api_settings.checkpoint_root).resolve() + available_checkpoints = _discover_available_checkpoints(checkpoint_root) readiness["api"] = { "allow_input_file": api_settings.allow_input_file, "allow_unsafe_request_overrides": api_settings.allow_unsafe_request_overrides, @@ -264,7 +342,12 @@ def ready() -> 
dict[str, Any]: "max_input_file_bytes": api_settings.max_input_file_bytes, "input_root": str(Path(api_settings.input_root).resolve()), "results_root": str(Path(api_settings.results_root).resolve()), - "checkpoint_root": str(Path(api_settings.checkpoint_root).resolve()), + "temp_root": str(Path(api_settings.temp_root).resolve()), + "checkpoint_root": str(checkpoint_root), + "available_checkpoints": available_checkpoints, + "missing_checkpoints": [ + parameter for parameter in _SUPPORTED_PARAMETERS if parameter not in available_checkpoints + ], } return readiness @@ -274,7 +357,7 @@ def predict(payload: PredictRequest) -> PredictResponse: try: repo_root = _resolve_repo_root(payload.repo_root, api_settings) python_executable = _resolve_python_executable(payload.python_executable, api_settings) - input_file, temp_file = _resolve_input_file(payload, settings=api_settings, repo_root=repo_root) + input_file, temp_file = _resolve_input_file(payload, settings=api_settings) safe_results_dir = _resolve_results_dir(payload.results_dir, api_settings) fallback = payload.fallback_to_local if fallback is None: diff --git a/catpred/web/static/catpred.css b/catpred/web/static/catpred.css new file mode 100644 index 0000000..90bb76f --- /dev/null +++ b/catpred/web/static/catpred.css @@ -0,0 +1,1099 @@ +:root { + --bg: #f6f4f0; + --bg-warm: #efecea; + --surface: #ffffff; + --border: rgba(0, 0, 0, 0.06); + --border-soft: rgba(0, 0, 0, 0.04); + + --text: #1a1816; + --text-secondary: #6b6560; + --text-tertiary: #a09890; + --text-muted: #c4bdb5; + + --protein-rose: #c4897a; + --protein-sage: #7a9e8e; + --protein-slate: #7a8a9e; + --protein-lavender: #9a8aae; + + --focus: rgba(122, 158, 142, 0.2); + --danger: #9b4542; + --ok: #3d6c55; + + --radius-lg: 22px; + --radius-md: 14px; + --shadow-card: 0 12px 36px rgba(17, 15, 12, 0.07); + --shadow-float: 0 16px 32px rgba(17, 15, 12, 0.11); +} + +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +html { + scroll-behavior: smooth; + 
height: 100%; + scroll-snap-type: y mandatory; + overscroll-behavior-y: contain; +} + +body { + min-height: 100%; + background: var(--bg); + color: var(--text); + font-family: "Outfit", "Segoe UI", sans-serif; + font-weight: 300; + line-height: 1.6; + -webkit-font-smoothing: antialiased; + overflow-x: hidden; + position: relative; +} + +.texture { + position: fixed; + inset: 0; + pointer-events: none; + z-index: -2; + background: + radial-gradient(circle at 12% 14%, rgba(122, 158, 142, 0.12) 0%, transparent 32%), + radial-gradient(circle at 88% 22%, rgba(196, 137, 122, 0.11) 0%, transparent 33%), + radial-gradient(circle at 32% 80%, rgba(154, 138, 174, 0.08) 0%, transparent 36%); +} + +.texture::before { + content: ""; + position: absolute; + inset: 0; + background-image: + linear-gradient(to right, rgba(255, 255, 255, 0.16) 1px, transparent 1px), + linear-gradient(to bottom, rgba(255, 255, 255, 0.11) 1px, transparent 1px); + background-size: 28px 28px; + mix-blend-mode: soft-light; + opacity: 0.85; +} + +::selection { + background: var(--protein-sage); + color: #fff; +} + +.container { + width: min(1280px, calc(100% - 2.8rem)); + margin-inline: auto; +} + +.top-nav { + position: sticky; + top: 0; + z-index: 90; + backdrop-filter: blur(18px); + -webkit-backdrop-filter: blur(18px); + background: rgba(246, 244, 240, 0.84); + border-bottom: 1px solid var(--border-soft); +} + +.nav-content { + min-height: 78px; + display: flex; + align-items: center; + justify-content: space-between; + gap: 1rem; +} + +.brand { + display: inline-flex; + align-items: center; + gap: 0.72rem; + text-decoration: none; + color: var(--text); +} + +.brand-mark { + width: 30px; + height: 30px; + border-radius: 999px; + display: grid; + place-items: center; + color: var(--protein-sage); + background: rgba(122, 158, 142, 0.1); +} + +.brand-word { + font-family: "Newsreader", serif; + font-size: 1.32rem; + font-weight: 400; + letter-spacing: -0.02em; +} + +.desktop-nav { + display: flex; + 
align-items: center; + gap: 1.65rem; +} + +.desktop-nav a { + text-decoration: none; + color: var(--text-secondary); + font-size: 0.9rem; + font-weight: 300; + letter-spacing: 0.01em; + transition: color 0.25s ease; +} + +.desktop-nav a:hover { + color: var(--text); +} + +.hero { + min-height: calc(100svh - 78px); + display: grid; + grid-template-columns: 0.96fr 1.04fr; + gap: clamp(1.3rem, 4vw, 4.2rem); + align-items: center; + padding-block: clamp(2.2rem, 7.4vw, 5.8rem); +} + +.hero-copy { + position: relative; + z-index: 1; +} + +.eyebrow { + font-family: "IBM Plex Mono", monospace; + font-size: 0.68rem; + font-weight: 300; + letter-spacing: 0.13em; + text-transform: uppercase; + color: var(--text-tertiary); + margin-bottom: 1.2rem; +} + +.hero h1 { + font-family: "Newsreader", serif; + font-size: clamp(2.4rem, 5.4vw, 4.8rem); + font-weight: 300; + line-height: 1.05; + letter-spacing: -0.03em; + margin-bottom: 1.25rem; +} + +.hero h1 em { + font-style: italic; + color: var(--protein-sage); +} + +.lead { + max-width: 60ch; + font-size: clamp(0.98rem, 1.9vw, 1.07rem); + color: var(--text-secondary); + line-height: 1.75; + margin-bottom: 1.8rem; +} + +.lead code, +.muted code, +.helper code, +.faq-item code, +.steps code { + font-family: "IBM Plex Mono", monospace; + font-size: 0.82em; + padding: 0.08rem 0.28rem; + border-radius: 6px; + border: 1px solid var(--border); + background: rgba(255, 255, 255, 0.66); +} + +.hero-actions { + display: flex; + flex-wrap: wrap; + gap: 0.8rem; + margin-bottom: 1.75rem; +} + +.btn { + display: inline-flex; + align-items: center; + justify-content: center; + gap: 0.45rem; + border-radius: 999px; + border: 1px solid var(--border); + padding: 0.76rem 1.25rem; + text-decoration: none; + cursor: pointer; + font-family: "Outfit", sans-serif; + font-size: 0.87rem; + font-weight: 400; + letter-spacing: 0.01em; + transition: transform 0.26s ease, background-color 0.26s ease, color 0.26s ease, border-color 0.26s ease, + box-shadow 0.26s 
ease; +} + +.btn:hover { + transform: translateY(-1px); +} + +.btn.is-running { + position: relative; + cursor: wait; + transform: none; +} + +.btn.is-running::before { + content: ""; + width: 0.78rem; + height: 0.78rem; + border-radius: 999px; + border: 2px solid rgba(255, 255, 255, 0.45); + border-top-color: rgba(255, 255, 255, 0.96); + animation: spin 0.75s linear infinite; +} + +.btn.is-running:hover { + transform: none; +} + +.btn-dark { + background: var(--text); + border-color: var(--text); + color: var(--bg); + box-shadow: var(--shadow-float); +} + +.btn-dark:hover { + background: #2c2925; + border-color: #2c2925; +} + +.btn-ghost { + background: transparent; + color: var(--text-secondary); +} + +.btn-ghost:hover { + color: var(--text); + border-color: rgba(0, 0, 0, 0.12); +} + +.btn-outline { + background: transparent; + color: var(--text); +} + +.btn-outline:hover { + background: var(--text); + color: var(--bg); + border-color: var(--text); +} + +.nav-cta { + padding-inline: 1.1rem; +} + +.hero-stats { + list-style: none; + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 0.7rem; + max-width: 720px; +} + +.hero-stats li { + border: 1px solid var(--border); + background: rgba(255, 255, 255, 0.68); + border-radius: var(--radius-md); + padding: 0.76rem 0.78rem; +} + +.hero-stats strong { + display: block; + font-family: "IBM Plex Mono", monospace; + font-size: 0.94rem; + font-weight: 400; + color: var(--text); +} + +.hero-stats span { + display: block; + margin-top: 0.25rem; + font-size: 0.84rem; + color: var(--text-secondary); +} + +.hero-visual-wrap { + position: relative; + isolation: isolate; +} + +.hero-visual { + position: relative; + width: min(560px, 100%); + aspect-ratio: 1 / 1; + margin-inline: auto; +} + +.hero-visual::before { + content: ""; + position: absolute; + inset: 10%; + border-radius: 999px; + background: radial-gradient(circle, rgba(122, 158, 142, 0.16) 0%, rgba(196, 137, 122, 0.08) 42%, transparent 74%); + 
filter: blur(24px); + animation: breathe 9s ease-in-out infinite; +} + +@keyframes breathe { + 0%, + 100% { + transform: scale(1); + opacity: 0.85; + } + + 50% { + transform: scale(1.06); + opacity: 1; + } +} + +.protein-structure { + width: 100%; + height: 100%; + position: relative; + z-index: 1; + animation: proteinFloat 18s ease-in-out infinite; +} + +@keyframes proteinFloat { + 0%, + 100% { + transform: translateY(0) rotate(0deg); + } + + 30% { + transform: translateY(-8px) rotate(0.4deg); + } + + 60% { + transform: translateY(3px) rotate(-0.25deg); + } + + 85% { + transform: translateY(-4px) rotate(0.18deg); + } +} + +.section { + padding-block: clamp(2.8rem, 7vw, 5.4rem); +} + +.snap-section { + min-height: 100svh; + scroll-snap-align: start; + scroll-snap-stop: always; + scroll-margin-top: 84px; +} + +.section-alt { + background: var(--bg-warm); + border-top: 1px solid var(--border); + border-bottom: 1px solid var(--border); +} + +.section-head { + margin-bottom: 1.2rem; +} + +.section-head.centered { + text-align: center; + max-width: 900px; + margin-inline: auto; +} + +.section-head h2 { + font-family: "Newsreader", serif; + font-size: clamp(1.95rem, 3.5vw, 3.1rem); + font-weight: 300; + line-height: 1.08; + letter-spacing: -0.02em; + margin-bottom: 0.56rem; +} + +.muted { + color: var(--text-secondary); + font-size: 0.98rem; + line-height: 1.72; +} + +.service-strip { + margin-top: 0.9rem; + display: inline-flex; + align-items: center; + gap: 0.6rem; + border: 1px solid var(--border); + border-radius: 999px; + background: rgba(255, 255, 255, 0.62); + padding: 0.36rem 0.62rem; +} + +.service-label { + font-family: "IBM Plex Mono", monospace; + font-size: 0.66rem; + letter-spacing: 0.08em; + text-transform: uppercase; + color: var(--text-tertiary); +} + +.badge { + border-radius: 999px; + border: 1px solid var(--border); + padding: 0.16rem 0.52rem; + font-family: "IBM Plex Mono", monospace; + font-size: 0.68rem; + font-weight: 400; + color: 
var(--text-secondary); + background: #fff; +} + +.badge.ok { + color: var(--ok); + border-color: rgba(61, 108, 85, 0.38); + background: rgba(122, 158, 142, 0.1); +} + +.badge.error { + color: var(--danger); + border-color: rgba(155, 69, 66, 0.38); + background: rgba(196, 137, 122, 0.1); +} + +.service-hint { + color: var(--text-secondary); + font-size: 0.82rem; +} + +.predict-layout { + display: grid; + grid-template-columns: 0.96fr 1.04fr; + gap: 1rem; + align-items: start; +} + +.card { + border: 1px solid var(--border); + border-radius: var(--radius-lg); + background: rgba(255, 255, 255, 0.8); + box-shadow: var(--shadow-card); + padding: 1rem; +} + +.card h3 { + font-family: "Newsreader", serif; + font-size: 1.38rem; + font-weight: 400; + line-height: 1.2; +} + +.input-card { + position: sticky; + top: 98px; +} + +.preset-row { + margin-top: 0.72rem; + display: flex; + gap: 0.45rem; + flex-wrap: wrap; +} + +.preset-btn { + border: 1px solid var(--border); + border-radius: 999px; + padding: 0.42rem 0.78rem; + background: rgba(255, 255, 255, 0.7); + color: var(--text-secondary); + cursor: pointer; + font-family: "IBM Plex Mono", monospace; + font-size: 0.68rem; + letter-spacing: 0.05em; + text-transform: uppercase; + transition: all 0.24s ease; +} + +.preset-btn:hover, +.preset-btn.active { + color: var(--protein-sage); + border-color: rgba(122, 158, 142, 0.4); + background: rgba(122, 158, 142, 0.08); +} + +.preset-row + .row-container { + margin-top: 0.78rem; +} + +.field-label { + display: block; + margin-bottom: 0.3rem; + font-family: "IBM Plex Mono", monospace; + font-size: 0.66rem; + letter-spacing: 0.11em; + text-transform: uppercase; + color: var(--text-tertiary); +} + +input, +textarea, +select, +button { + font: inherit; +} + +input, +textarea, +select { + width: 100%; + border: 1px solid var(--border); + border-radius: 10px; + background: rgba(255, 255, 255, 0.9); + color: var(--text); + padding: 0.62rem 0.68rem; + font-size: 0.9rem; + font-weight: 300; 
+} + +textarea { + line-height: 1.46; + resize: vertical; +} + +input:focus, +textarea:focus, +select:focus, +button:focus-visible { + outline: 2px solid var(--focus); + outline-offset: 1px; +} + +.row-container { + margin-top: 0.78rem; + display: grid; + gap: 0.72rem; +} + +.row-item { + border: 1px solid var(--border); + border-radius: var(--radius-md); + background: rgba(255, 255, 255, 0.7); + padding: 0.72rem; +} + +.row-item-head { + display: flex; + align-items: center; + justify-content: space-between; + gap: 0.6rem; + margin-bottom: 0.58rem; +} + +.row-item-head[hidden] { + display: none !important; +} + +.row-item-head h4 { + font-family: "IBM Plex Mono", monospace; + font-size: 0.72rem; + color: var(--text-tertiary); + text-transform: none; + letter-spacing: 0.02em; + font-weight: 400; +} + +.row-grid { + display: grid; + grid-template-columns: 1fr; + gap: 0.55rem; +} + +.icon-btn { + border: 1px solid var(--border); + border-radius: 999px; + background: rgba(255, 255, 255, 0.82); + color: var(--danger); + padding: 0.34rem 0.6rem; + cursor: pointer; + font-family: "IBM Plex Mono", monospace; + font-size: 0.65rem; + transition: background-color 0.24s ease; +} + +.icon-btn:hover { + background: rgba(196, 137, 122, 0.09); +} + +.form-actions { + margin-top: 0.82rem; + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + +.helper { + margin-top: 0.76rem; + color: var(--text-secondary); + font-size: 0.86rem; + line-height: 1.6; +} + +.status-box { + margin-top: 0.72rem; + position: relative; + overflow: hidden; + border-radius: 12px; + border: 1px solid var(--border); + background: rgba(255, 255, 255, 0.75); + min-height: 44px; + padding: 0.58rem 0.72rem; + display: flex; + align-items: center; + gap: 0.52rem; + color: var(--text-secondary); + font-family: "IBM Plex Mono", monospace; + font-size: 0.68rem; + letter-spacing: 0.06em; + text-transform: uppercase; +} + +.status-box.running { + color: var(--protein-sage); + border-color: rgba(122, 158, 142, 0.42); 
+ background: linear-gradient( + 120deg, + rgba(122, 158, 142, 0.14), + rgba(122, 158, 142, 0.06) 46%, + rgba(122, 138, 158, 0.09) + ); +} + +.status-box.running::before { + content: ""; + width: 0.75rem; + height: 0.75rem; + flex: 0 0 auto; + border-radius: 999px; + border: 2px solid rgba(122, 158, 142, 0.32); + border-top-color: rgba(122, 158, 142, 0.96); + animation: spin 0.74s linear infinite; +} + +.status-box.ok { + color: var(--ok); + border-color: rgba(61, 108, 85, 0.35); + background: rgba(122, 158, 142, 0.08); +} + +.status-box.error { + color: var(--danger); + border-color: rgba(155, 69, 66, 0.36); + background: rgba(196, 137, 122, 0.09); +} + +.result-cards { + margin-top: 0.72rem; + display: grid; + gap: 0.68rem; +} + +.output-card.is-busy { + position: relative; +} + +.output-card.is-busy::after { + content: ""; + position: absolute; + inset: 0; + border-radius: var(--radius-lg); + pointer-events: none; + background: linear-gradient( + 110deg, + rgba(255, 255, 255, 0) 0%, + rgba(122, 158, 142, 0.08) 35%, + rgba(255, 255, 255, 0) 65% + ); + transform: translateX(-100%); + animation: busy-sweep 1.7s ease-in-out infinite; +} + +.output-card.is-busy .result-cards, +.output-card.is-busy .details-output { + opacity: 0.76; +} + +.empty-result { + border: 1px dashed var(--border); + border-radius: var(--radius-md); + background: rgba(255, 255, 255, 0.65); + padding: 0.92rem; + color: var(--text-secondary); +} + +.result-card { + border: 1px solid var(--border); + border-radius: var(--radius-md); + background: rgba(255, 255, 255, 0.76); + padding: 0.84rem; +} + +.result-card-head { + display: flex; + justify-content: space-between; + align-items: center; + gap: 0.6rem; +} + +.result-card-head h4 { + font-family: "Newsreader", serif; + font-size: 1.08rem; + font-weight: 400; +} + +.result-chip { + border: 1px solid var(--border); + border-radius: 999px; + padding: 0.15rem 0.5rem; + font-family: "IBM Plex Mono", monospace; + font-size: 0.64rem; + color: 
var(--text-secondary); +} + +.result-main { + margin-top: 0.52rem; + display: flex; + align-items: baseline; + gap: 0.45rem; + flex-wrap: wrap; +} + +.result-main strong { + font-family: "Newsreader", serif; + font-size: 1.82rem; + line-height: 1; + font-weight: 400; +} + +.result-main span { + color: var(--text-secondary); + font-size: 0.86rem; +} + +.metric-grid { + margin-top: 0.62rem; + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 0.45rem; +} + +.metric-grid div { + border: 1px solid var(--border); + border-radius: 10px; + background: rgba(255, 255, 255, 0.74); + padding: 0.45rem; +} + +.metric-grid dt { + font-family: "IBM Plex Mono", monospace; + font-size: 0.62rem; + color: var(--text-tertiary); + letter-spacing: 0.06em; + text-transform: uppercase; +} + +.metric-grid dd { + margin-top: 0.16rem; + font-size: 0.9rem; + color: var(--text-secondary); +} + +.result-meta { + margin-top: 0.6rem; + display: flex; + flex-wrap: wrap; + gap: 0.42rem; +} + +.result-meta span { + font-family: "IBM Plex Mono", monospace; + font-size: 0.62rem; + color: var(--text-secondary); + border: 1px solid var(--border); + border-radius: 999px; + padding: 0.14rem 0.48rem; + background: rgba(255, 255, 255, 0.66); + max-width: 100%; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.details-output { + margin-top: 0.8rem; + border-top: 1px solid var(--border); + padding-top: 0.65rem; +} + +.details-output summary { + cursor: pointer; + font-family: "IBM Plex Mono", monospace; + font-size: 0.66rem; + letter-spacing: 0.08em; + text-transform: uppercase; + color: var(--text-tertiary); +} + +.rows-wrap { + margin-top: 0.72rem; + overflow: auto; +} + +.input-table { + width: 100%; + min-width: 680px; + border-collapse: collapse; +} + +.input-table th, +.input-table td { + text-align: left; + border-bottom: 1px solid var(--border); + padding: 0.42rem; + vertical-align: top; + font-size: 0.86rem; + color: var(--text-secondary); +} + 
+.input-table th { + font-family: "IBM Plex Mono", monospace; + font-size: 0.65rem; + font-weight: 300; + letter-spacing: 0.1em; + color: var(--text-tertiary); + text-transform: uppercase; +} + +.steps { + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 1rem; +} + +.guide-grid { + margin-top: 0.35rem; +} + +.steps article { + border-radius: 18px; + border: 1px solid var(--border); + background: rgba(255, 255, 255, 0.72); + padding: 1rem; +} + +.steps span { + font-family: "IBM Plex Mono", monospace; + font-size: 0.67rem; + color: var(--text-tertiary); + letter-spacing: 0.1em; +} + +.steps h3 { + margin-top: 0.38rem; + margin-bottom: 0.4rem; + font-family: "Newsreader", serif; + font-size: 1.3rem; + font-weight: 400; +} + +.steps p { + color: var(--text-secondary); + font-size: 0.93rem; + line-height: 1.68; +} + +.guide-note { + margin-top: 0.88rem; + border: 1px solid var(--border); + border-radius: 14px; + background: rgba(255, 255, 255, 0.72); + padding: 0.72rem 0.82rem; + color: var(--text-secondary); + font-size: 0.9rem; +} + +.guide-note strong { + color: var(--text); + font-family: "IBM Plex Mono", monospace; + font-size: 0.82em; + letter-spacing: 0.03em; +} + +.faq-list { + border-top: 1px solid var(--border); +} + +.faq-item { + border-bottom: 1px solid var(--border); + padding-block: 0.86rem; +} + +.faq-item button { + width: 100%; + border: 0; + background: transparent; + text-align: left; + cursor: pointer; + display: flex; + justify-content: space-between; + align-items: center; + gap: 1rem; + color: var(--text); + font-family: "Newsreader", serif; + font-size: 1.24rem; + font-weight: 400; +} + +.faq-item button::after { + content: "+"; + color: var(--text-tertiary); + font-family: "IBM Plex Mono", monospace; + font-size: 0.95rem; +} + +.faq-item button[aria-expanded="true"]::after { + content: "โˆ’"; +} + +.faq-item p { + max-height: 0; + overflow: hidden; + margin: 0; + color: var(--text-secondary); + font-size: 0.95rem; + 
line-height: 1.67; + transition: max-height 0.24s ease, margin-top 0.24s ease; +} + +.faq-item.open p { + max-height: 12rem; + margin-top: 0.4rem; +} + +.footer { + padding: 1.2rem 0 2.4rem; +} + +.footer-inner { + border-top: 1px solid var(--border); + padding-top: 0.95rem; + display: flex; + justify-content: space-between; + align-items: center; + gap: 1rem; + color: var(--text-tertiary); + font-size: 0.84rem; +} + +.mobile-stick { + position: fixed; + left: 1rem; + right: 1rem; + bottom: 1rem; + z-index: 70; + display: none; + text-align: center; + text-decoration: none; + border-radius: 999px; + padding: 0.82rem 1rem; + background: var(--text); + color: var(--bg); + font-family: "Outfit", sans-serif; + font-size: 0.88rem; + font-weight: 400; + box-shadow: var(--shadow-float); +} + +.reveal { + opacity: 0; + transform: translateY(20px); + transition: opacity 0.48s ease, transform 0.48s ease; +} + +.reveal.show { + opacity: 1; + transform: translateY(0); +} + +@keyframes spin { + to { + transform: rotate(360deg); + } +} + +@keyframes busy-sweep { + 100% { + transform: translateX(100%); + } +} + +@media (max-width: 1120px) { + html { + scroll-snap-type: y proximity; + } + + .hero { + grid-template-columns: 1fr; + min-height: auto; + } + + .hero-visual { + width: min(500px, 100%); + } + + .predict-layout, + .steps, + .hero-stats, + .metric-grid { + grid-template-columns: 1fr; + } + + .input-card { + position: static; + } +} + +@media (max-width: 780px) { + .desktop-nav { + display: none; + } + + .container { + width: min(1280px, calc(100% - 1.5rem)); + } + + .hero { + padding-top: 1.5rem; + } + + .snap-section { + min-height: auto; + scroll-snap-stop: normal; + } + + .input-table { + min-width: 560px; + } + + .mobile-stick { + display: block; + } + + .footer { + padding-bottom: 5rem; + } +} diff --git a/catpred/web/static/catpred.js b/catpred/web/static/catpred.js new file mode 100644 index 0000000..707d620 --- /dev/null +++ b/catpred/web/static/catpred.js @@ -0,0 
+1,730 @@ +(function () { + const form = document.getElementById("predictForm"); + const rowContainer = document.getElementById("rowContainer"); + const outputCard = document.querySelector(".output-card"); + + const addRowBtn = document.getElementById("addRowBtn"); + const loadSampleBtn = document.getElementById("loadSampleBtn"); + const runBtn = document.getElementById("runBtn"); + + const statusBox = document.getElementById("statusBox"); + const resultCards = document.getElementById("resultCards"); + const previewTable = document.getElementById("previewTable"); + + const serviceBadge = document.getElementById("serviceBadge"); + const serviceHint = document.getElementById("serviceHint"); + + const presetButtons = document.querySelectorAll(".preset-btn"); + + const sampleRows = [ + { + SMILES: "CCO", + sequence: "ACDEFGHIK", + pdbpath: "seq_001", + }, + { + SMILES: "CCN", + sequence: "LMNPQRSTV", + pdbpath: "seq_002", + }, + ]; + + const supportedParameters = ["kcat", "km", "ki"]; + let availableCheckpointParams = new Set(supportedParameters); + let selectedParameter = "kcat"; + const runningPhaseMessages = [ + "validating input", + "building protein records", + "running model ensemble", + "aggregating uncertainty", + ]; + const defaultRunButtonLabel = runBtn ? 
runBtn.textContent : "Run prediction"; + let runningStatusInterval = null; + let runStartedAtMs = null; + + function jsonPretty(data) { + return JSON.stringify(data, null, 2); + } + + function escapeHtml(text) { + return String(text) + .replace(/&/g, "&") + .replace(//g, ">") + .replace(/\"/g, """) + .replace(/'/g, "'"); + } + + function formatNumber(value) { + const n = Number(value); + if (!Number.isFinite(n)) { + return "โ€”"; + } + return n.toFixed(1); + } + + function truncateText(value, maxLength) { + const text = String(value || ""); + if (text.length <= maxLength) { + return text; + } + return text.slice(0, maxLength - 1) + "โ€ฆ"; + } + + function setStatus(text, kind) { + if (!statusBox) return; + statusBox.textContent = text; + statusBox.classList.remove("ok", "error", "running"); + if (kind) { + statusBox.classList.add(kind); + } + } + + function formatElapsed(seconds) { + const safeSeconds = Math.max(0, Math.floor(seconds)); + const minutes = Math.floor(safeSeconds / 60); + const remainder = safeSeconds % 60; + return String(minutes).padStart(2, "0") + ":" + String(remainder).padStart(2, "0"); + } + + function getElapsedSeconds() { + if (!runStartedAtMs) return 0; + return Math.floor((Date.now() - runStartedAtMs) / 1000); + } + + function setRunButtonState(isRunning, elapsedLabel) { + if (!runBtn) return; + if (isRunning) { + const suffix = elapsedLabel ? " " + elapsedLabel : ""; + runBtn.textContent = "Running" + suffix; + runBtn.classList.add("is-running"); + runBtn.disabled = true; + runBtn.setAttribute("aria-busy", "true"); + return; + } + + runBtn.textContent = defaultRunButtonLabel || "Run prediction"; + runBtn.classList.remove("is-running"); + runBtn.disabled = false; + runBtn.removeAttribute("aria-busy"); + } + + function setBusyUi(isBusy) { + if (form) { + form.classList.toggle("is-busy", isBusy); + form.setAttribute("aria-busy", isBusy ? 
"true" : "false"); + } + if (outputCard) { + outputCard.classList.toggle("is-busy", isBusy); + outputCard.setAttribute("aria-busy", isBusy ? "true" : "false"); + } + + if (addRowBtn) { + addRowBtn.disabled = isBusy; + } + if (loadSampleBtn) { + loadSampleBtn.disabled = isBusy; + } + if (presetButtons && presetButtons.length) { + presetButtons.forEach((btn) => { + if (!isBusy && !isParameterAvailable(btn.dataset.param)) { + btn.disabled = true; + return; + } + btn.disabled = isBusy; + }); + } + } + + function startRunningFeedback(payload) { + stopRunningFeedback(); + runStartedAtMs = Date.now(); + setBusyUi(true); + + const parameterLabel = String(payload.parameter || "kcat").toUpperCase(); + const rowCount = Array.isArray(payload.input_rows) ? payload.input_rows.length : 0; + const rowLabel = rowCount === 1 ? "1 row" : String(rowCount) + " rows"; + + const updateRunningStatus = () => { + const elapsed = getElapsedSeconds(); + const elapsedLabel = formatElapsed(elapsed); + const phaseIndex = Math.floor(elapsed / 3) % runningPhaseMessages.length; + const phaseText = runningPhaseMessages[phaseIndex]; + setRunButtonState(true, elapsedLabel); + setStatus( + "Running " + parameterLabel + " on " + rowLabel + " โ€ข " + elapsedLabel + " โ€ข " + phaseText, + "running" + ); + }; + + updateRunningStatus(); + runningStatusInterval = window.setInterval(updateRunningStatus, 1000); + } + + function stopRunningFeedback() { + if (runningStatusInterval) { + window.clearInterval(runningStatusInterval); + runningStatusInterval = null; + } + setBusyUi(false); + setRunButtonState(false); + } + + function setServiceState(text, hint, kind) { + if (serviceBadge) { + serviceBadge.textContent = text; + serviceBadge.classList.remove("ok", "error"); + if (kind) { + serviceBadge.classList.add(kind); + } + } + if (serviceHint) { + serviceHint.textContent = hint || ""; + } + } + + function setActivePreset(paramValue) { + if (!presetButtons || !presetButtons.length) return; + 
presetButtons.forEach((btn) => { + if (!(btn instanceof HTMLElement)) return; + btn.classList.toggle("active", btn.dataset.param === paramValue); + }); + } + + function getSelectedParameter() { + return String(selectedParameter || "kcat").toLowerCase(); + } + + function firstAvailableParameter() { + for (const param of supportedParameters) { + if (availableCheckpointParams.has(param)) { + return param; + } + } + return null; + } + + function setParameterAvailability(availableCheckpoints) { + if (availableCheckpoints && typeof availableCheckpoints === "object") { + availableCheckpointParams = new Set( + Object.keys(availableCheckpoints) + .map((key) => String(key).toLowerCase()) + .filter((key) => supportedParameters.includes(key)) + ); + } else { + availableCheckpointParams = new Set(supportedParameters); + } + + if (presetButtons && presetButtons.length) { + presetButtons.forEach((btn) => { + const paramValue = String(btn.dataset.param || "").toLowerCase(); + const enabled = availableCheckpointParams.has(paramValue); + btn.disabled = !enabled; + btn.setAttribute("aria-disabled", String(!enabled)); + }); + } + + const fallbackParam = firstAvailableParameter(); + if (fallbackParam && !availableCheckpointParams.has(getSelectedParameter())) { + selectedParameter = fallbackParam; + setActivePreset(fallbackParam); + } + } + + function isParameterAvailable(paramValue) { + return availableCheckpointParams.has(String(paramValue || "").toLowerCase()); + } + + function formatSequenceId(index) { + const safeIndex = Math.max(1, Number(index) || 1); + return "seq_" + String(safeIndex).padStart(3, "0"); + } + + function getNextSequenceId() { + if (!rowContainer) { + return formatSequenceId(1); + } + + let maxSeen = 0; + const idInputs = rowContainer.querySelectorAll('input[name="pdbpath"]'); + idInputs.forEach((input) => { + const raw = String(input.value || "").trim(); + const match = raw.match(/^seq_(\d+)$/i); + if (!match) return; + const parsed = Number(match[1]); + if 
(Number.isFinite(parsed)) { + maxSeen = Math.max(maxSeen, parsed); + } + }); + + if (maxSeen > 0) { + return formatSequenceId(maxSeen + 1); + } + + const currentRowCount = rowContainer.querySelectorAll(".row-item").length; + return formatSequenceId(currentRowCount + 1); + } + + function rowTemplate(index, values) { + const smiles = values && values.SMILES ? values.SMILES : ""; + const seq = values && values.sequence ? values.sequence : ""; + const pdb = values && values.pdbpath ? values.pdbpath : ""; + + return ( + '
' + + '
' + + "

Entry " + + (index + 1) + + "

" + + '' + + "
" + + '
' + + '' + + '' + + '" + + "
" + + "
" + ); + } + + function renumberRows() { + if (!rowContainer) return; + const items = rowContainer.querySelectorAll(".row-item"); + const shouldShowHeader = items.length > 1; + items.forEach((item, idx) => { + const head = item.querySelector(".row-item-head"); + if (head instanceof HTMLElement) { + head.hidden = !shouldShowHeader; + } + const heading = item.querySelector("h4"); + if (heading) { + heading.textContent = "Entry " + String(idx + 1); + } + }); + } + + function addRow(values) { + if (!rowContainer) return; + const nextValues = values ? { ...values } : { SMILES: "", sequence: "", pdbpath: "" }; + if (!String(nextValues.pdbpath || "").trim()) { + nextValues.pdbpath = getNextSequenceId(); + } + const index = rowContainer.querySelectorAll(".row-item").length; + rowContainer.insertAdjacentHTML("beforeend", rowTemplate(index, nextValues)); + const last = rowContainer.lastElementChild; + if (last) { + last.animate( + [ + { opacity: 0, transform: "translateY(8px)" }, + { opacity: 1, transform: "translateY(0)" }, + ], + { duration: 220, easing: "ease-out" } + ); + } + renumberRows(); + } + + function clearRows() { + if (!rowContainer) return; + rowContainer.innerHTML = ""; + } + + function loadSampleRows() { + clearRows(); + sampleRows.forEach((row) => addRow(row)); + } + + function collectRows() { + if (!rowContainer) return []; + + const rows = []; + const items = rowContainer.querySelectorAll(".row-item"); + + items.forEach((item) => { + const smilesInput = item.querySelector('input[name="SMILES"]'); + const sequenceInput = item.querySelector('textarea[name="sequence"]'); + const pdbpathInput = item.querySelector('input[name="pdbpath"]'); + + if (!smilesInput || !sequenceInput || !pdbpathInput) { + return; + } + + const smiles = smilesInput.value.trim(); + const sequence = sequenceInput.value.trim().toUpperCase(); + const pdbpath = pdbpathInput.value.trim(); + + if (!smiles || !sequence || !pdbpath) { + return; + } + + rows.push({ + SMILES: smiles, + sequence: 
sequence, + pdbpath: pdbpath, + }); + }); + + return rows; + } + + function validateRows(rows) { + if (!rows.length) { + return "Please add at least one complete input row."; + } + + const mapping = new Map(); + for (let i = 0; i < rows.length; i++) { + const row = rows[i]; + const key = row.pdbpath; + if (mapping.has(key) && mapping.get(key) !== row.sequence) { + return "Each Sequence ID must map to one unique enzyme sequence."; + } + mapping.set(key, row.sequence); + } + + return ""; + } + + function buildPayload(rows) { + const target = getSelectedParameter(); + + return { + parameter: target, + checkpoint_dir: target, + input_rows: rows, + use_gpu: false, + results_dir: "web-app", + fallback_to_local: true, + }; + } + + function parsePrediction(row) { + const keys = Object.keys(row || {}); + const linearKey = keys.find((key) => key.startsWith("Prediction_(")); + const unitMatch = linearKey ? linearKey.match(/^Prediction_\((.*)\)$/) : null; + const unit = unitMatch ? unitMatch[1] : ""; + + return { + linear: linearKey ? row[linearKey] : null, + linearKey: linearKey || "Prediction", + unit: unit, + log10: row.Prediction_log10, + sdTotal: row.SD_total, + sdAleatoric: row.SD_aleatoric, + sdEpistemic: row.SD_epistemic, + }; + } + + function renderResultCards(previewRows, selectedParam) { + if (!resultCards) return; + + if (!previewRows || !previewRows.length) { + resultCards.innerHTML = + '

No preview rows were returned for this run.

'; + return; + } + + const cardsHtml = previewRows + .map((row, idx) => { + const p = parsePrediction(row); + + return ( + '
' + + '
' + + "

Result " + + String(idx + 1) + + "

" + + '' + + escapeHtml(selectedParam.toUpperCase()) + + "" + + "
" + + '
' + + "" + + escapeHtml(formatNumber(p.linear)) + + "" + + "" + + escapeHtml(p.unit || "predicted unit") + + "" + + "
" + + '
' + + "
log10
" + + escapeHtml(formatNumber(p.log10)) + + "
" + + "
Total SD
" + + escapeHtml(formatNumber(p.sdTotal)) + + "
" + + "
Epistemic SD
" + + escapeHtml(formatNumber(p.sdEpistemic)) + + "
" + + "
" + + '
' + "SMILES: " + escapeHtml(truncateText(row.SMILES, 24)) + "" + "Sequence ID: " + escapeHtml(row.pdbpath || "—") + "" + "
" + + "
" + ); + }) + .join(""); + + resultCards.innerHTML = cardsHtml; + } + + function renderPreviewTable(previewRows) { + if (!previewTable) return; + + if (!previewRows || !previewRows.length) { + previewTable.innerHTML = "No rows to show."; + return; + } + + const keys = Object.keys(previewRows[0]); + const head = + "" + + keys.map((k) => "" + escapeHtml(k) + "").join("") + + ""; + + const bodyRows = previewRows + .map((row) => { + const cells = keys + .map((k) => { + const value = row[k]; + if (value === null || value === undefined) { + return ""; + } + const displayValue = typeof value === "number" ? formatNumber(value) : value; + return "" + escapeHtml(displayValue) + ""; + }) + .join(""); + return "" + cells + ""; + }) + .join(""); + + previewTable.innerHTML = head + "" + bodyRows + ""; + } + + async function fetchReady() { + setServiceState("Checking...", "", ""); + + try { + const response = await fetch("/ready", { + method: "GET", + headers: { Accept: "application/json" }, + }); + + let data; + try { + data = await response.json(); + } catch (_err) { + data = {}; + } + + if (!response.ok) { + setServiceState("Offline", "Service not reachable", "error"); + return; + } + + setParameterAvailability(data && data.api ? data.api.available_checkpoints : null); + const availableParams = Array.from(availableCheckpointParams.values()); + + if (data && data.ready) { + const backend = data.default_backend ? String(data.default_backend) : "default"; + const hint = availableParams.length + ? "Backend: " + backend + " | Checkpoints: " + availableParams.join(", ") + : "Backend: " + backend + " | No local checkpoints found"; + setServiceState("Online", hint, "ok"); + } else { + const hint = availableParams.length + ? 
"Backend configuration needed" + : "Backend configuration needed | No local checkpoints found"; + setServiceState("Limited", hint, "error"); + } + } catch (_err) { + setServiceState("Offline", "Could not contact service", "error"); + } + } + + async function submitPrediction(event) { + event.preventDefault(); + + const rows = collectRows(); + const rowError = validateRows(rows); + if (rowError) { + setStatus(rowError, "error"); + return; + } + + const selectedParameter = getSelectedParameter(); + if (!isParameterAvailable(selectedParameter)) { + setStatus("No local checkpoint available for " + selectedParameter.toUpperCase() + ".", "error"); + return; + } + + const payload = buildPayload(rows); + startRunningFeedback(payload); + await new Promise((resolve) => window.requestAnimationFrame(resolve)); + + try { + const response = await fetch("/predict", { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, + body: JSON.stringify(payload), + }); + + let data; + try { + data = await response.json(); + } catch (_err) { + data = { detail: "Unexpected response format." }; + } + + if (!response.ok) { + const message = data && data.detail ? String(data.detail) : "Prediction could not be completed."; + renderResultCards([], payload.parameter); + renderPreviewTable([]); + const elapsedLabel = formatElapsed(getElapsedSeconds()); + setStatus(message + " (" + elapsedLabel + ")", "error"); + return; + } + + renderResultCards(data.preview_rows || [], payload.parameter); + renderPreviewTable(data.preview_rows || []); + const elapsedLabel = formatElapsed(getElapsedSeconds()); + setStatus("Prediction complete โ€ข " + elapsedLabel, "ok"); + } catch (_err) { + renderResultCards([], payload.parameter); + renderPreviewTable([]); + const elapsedLabel = formatElapsed(getElapsedSeconds()); + setStatus("Network error while running prediction. 
(" + elapsedLabel + ")", "error"); + } finally { + stopRunningFeedback(); + runStartedAtMs = null; + } + } + + function setupFaq() { + const faqRoot = document.getElementById("faqList"); + if (!faqRoot) return; + + const items = faqRoot.querySelectorAll(".faq-item"); + items.forEach((item) => { + const button = item.querySelector("button"); + if (!button) return; + + button.addEventListener("click", function () { + const isOpen = item.classList.contains("open"); + items.forEach((row) => { + row.classList.remove("open"); + const rowButton = row.querySelector("button"); + if (rowButton) { + rowButton.setAttribute("aria-expanded", "false"); + } + }); + if (!isOpen) { + item.classList.add("open"); + button.setAttribute("aria-expanded", "true"); + } + }); + }); + } + + function setupReveal() { + const revealItems = document.querySelectorAll(".reveal"); + if (!revealItems.length) return; + + const observer = new IntersectionObserver( + (entries) => { + entries.forEach((entry) => { + if (entry.isIntersecting) { + entry.target.classList.add("show"); + observer.unobserve(entry.target); + } + }); + }, + { + threshold: 0.12, + rootMargin: "0px 0px -36px 0px", + } + ); + + revealItems.forEach((item) => observer.observe(item)); + } + + function setupEvents() { + if (presetButtons && presetButtons.length) { + presetButtons.forEach((btn) => { + btn.addEventListener("click", function () { + if (btn.disabled) return; + const chosenParam = btn.dataset.param || "kcat"; + selectedParameter = String(chosenParam).toLowerCase(); + setActivePreset(chosenParam); + }); + }); + } + + if (addRowBtn) { + addRowBtn.addEventListener("click", function () { + addRow({ SMILES: "", sequence: "", pdbpath: "" }); + }); + } + + if (loadSampleBtn) { + loadSampleBtn.addEventListener("click", function () { + loadSampleRows(); + setStatus("Sample inputs loaded", "ok"); + }); + } + + if (rowContainer) { + rowContainer.addEventListener("click", function (event) { + const target = event.target; + if (!(target 
instanceof HTMLElement)) return; + if (!target.matches("[data-remove-row]")) return; + + const row = target.closest(".row-item"); + if (!row) return; + + if (rowContainer.querySelectorAll(".row-item").length === 1) { + setStatus("At least one row is required", "error"); + return; + } + + row.remove(); + renumberRows(); + }); + } + + if (form) { + form.addEventListener("submit", submitPrediction); + } + } + + function bootstrap() { + clearRows(); + addRow(sampleRows[0]); + setActivePreset("kcat"); + + renderResultCards([], "kcat"); + renderPreviewTable([]); + + setStatus("Ready when you are.", "ok"); + setupEvents(); + setupFaq(); + setupReveal(); + fetchReady(); + } + + bootstrap(); +})(); diff --git a/catpred/web/static/index.html b/catpred/web/static/index.html new file mode 100644 index 0000000..2ffa804 --- /dev/null +++ b/catpred/web/static/index.html @@ -0,0 +1,214 @@ + + + + + + CatPred | Enzyme Kinetics Prediction + + + + + + + + + +
+ +
+ +
+
+
+

AI for Enzyme Kinetics

+

+ Get kinetic estimates in minutes, + not days. +

+

+ CatPred predicts kcat, Km, and Ki from substrate + SMILES and enzyme sequence inputs, including uncertainty for + confidence-aware decisions. +

+ + + +
    +
  • + 3 parameters + kcat / Km / Ki +
  • +
  • + Batch ready + Run one or multiple rows +
  • +
+
+ + +
+ +
+
+

Prediction Studio

+

Enter inputs and run

+

+ Choose a target, submit one or many rows, and review predictions with uncertainty. +

+ +
+ Service status + Checking... + +
+
+ +
+
+

Inputs

+ +
+ + + +
+ +
+ +
+ + + +
+ +

+ Each row needs: SMILES, sequence, and a Sequence ID. + Sequence IDs auto-fill as seq_001, seq_002, etc. (editable). + Use one Sequence ID per unique sequence. No structure file upload is required. + For CSV/API usage, this field maps to pdbpath. +

+
+ +
+

Results

+
Ready when you are.
+ +
+
+

Run a prediction to see kinetic estimates here.

+
+
+ +
+ View detailed output table +
+
+
+
+
+
+
+ +
+
+
+

Quick Guide

+

Three rules for reliable runs

+
+ +
+
+ 01 +

Use clean inputs

+

+ Each row needs valid SMILES, uppercase amino acid sequence, + and a Sequence ID (auto-filled, internal field: pdbpath). +

+
+
+ 02 +

Keep sequence IDs consistent

+

+ Reuse the same Sequence ID only for the same sequence. Different sequences should + use different IDs. +

+
+
+ 03 +

Read uncertainty first

+

+ Prioritize results with lower uncertainty before selecting candidates for experiments. +

+
+
+

+ Tip: use Load sample in the predictor to test your environment end-to-end before + running larger batches. +

+
+
+
+ + Run prediction + + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..621e861 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +fastapi>=0.95,<0.100 +pydantic>=1.10,<2.0 +pandas>=1.5,<2.3 diff --git a/setup.cfg b/setup.cfg index 3db702f..b61db52 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,4 +60,8 @@ web = uvicorn>=0.22,<1.0 [options.package_data] -catpred = py.typed +catpred = + py.typed + web/static/*.html + web/static/*.css + web/static/*.js diff --git a/vercel.json b/vercel.json new file mode 100644 index 0000000..4582896 --- /dev/null +++ b/vercel.json @@ -0,0 +1,21 @@ +{ + "$schema": "https://openapi.vercel.sh/vercel.json", + "version": 2, + "builds": [ + { + "src": "api/index.py", + "use": "@vercel/python" + } + ], + "functions": { + "api/index.py": { + "maxDuration": 300 + } + }, + "routes": [ + { + "src": "/(.*)", + "dest": "api/index.py" + } + ] +} From cd655e0810b629a1bce4ede1fc9f7d644511ba30 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sat, 28 Feb 2026 23:56:37 -0500 Subject: [PATCH 11/34] docs+deploy: add modal endpoint scaffold for vercel backend --- .env.vercel | 4 ++ README.md | 41 +++++++++++++++ modal_app.py | 141 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 186 insertions(+) create mode 100644 .env.vercel create mode 100644 modal_app.py diff --git a/.env.vercel b/.env.vercel new file mode 100644 index 0000000..7db578e --- /dev/null +++ b/.env.vercel @@ -0,0 +1,4 @@ +CATPRED_DEFAULT_BACKEND=modal +CATPRED_MODAL_ENDPOINT=https:// +CATPRED_MODAL_TOKEN= +CATPRED_MODAL_FALLBACK_TO_LOCAL=0 diff --git a/README.md b/README.md index 00c08c2..1275805 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,47 @@ Notes: - Local checkpoint-based inference is not recommended on Vercel serverless due runtime/dependency limits. - If `CATPRED_MODAL_ENDPOINT` is not configured, the UI still loads but prediction requests will be limited by backend readiness. 
+#### Deploy a Modal endpoint for Vercel + +This repo includes `modal_app.py`, a Modal `POST` endpoint compatible with CatPred's `/predict` modal backend contract. + +1. Install and authenticate Modal CLI: + +```bash +pip install modal +modal setup +``` + +2. Create/upload checkpoints into a Modal Volume (one-time): + +```bash +modal volume create catpred-checkpoints +modal volume put catpred-checkpoints ./checkpoints/kcat kcat +modal volume put catpred-checkpoints ./checkpoints/km km +modal volume put catpred-checkpoints ./checkpoints/ki ki +``` + +3. (Recommended) create a secret token for endpoint auth: + +```bash +modal secret create catpred-modal-auth CATPRED_MODAL_AUTH_TOKEN="" +``` + +4. Deploy: + +```bash +modal deploy modal_app.py +``` + +After deploy, copy the printed endpoint URL (for function `predict`) and set Vercel variables: + +```bash +CATPRED_DEFAULT_BACKEND=modal +CATPRED_MODAL_ENDPOINT=https:// +CATPRED_MODAL_TOKEN= +CATPRED_MODAL_FALLBACK_TO_LOCAL=0 +``` + ### 🧪 Fine-Tuning On Custom Data You can fine-tune CatPred on your own regression targets using `train.py`.
diff --git a/modal_app.py b/modal_app.py new file mode 100644 index 0000000..2769457 --- /dev/null +++ b/modal_app.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from pathlib import Path +import os +import tempfile +from typing import Any + +from fastapi import Header, HTTPException +import modal +from pydantic import BaseModel, Field + + +image = ( + modal.Image.debian_slim(python_version="3.10") + .pip_install( + "fastapi[standard]>=0.115,<1.0", + "pydantic>=1.10,<2.0", + "pandas>=1.5,<2.3", + "numpy>=1.26,<2.3", + "scikit-learn>=1.3,<1.7", + "scipy>=1.10,<1.16", + "torch>=2.1,<2.7", + "tqdm>=4.66", + "typed-argument-parser>=1.10", + "rdkit-pypi>=2022.9.5", + "descriptastorus>=2.6", + "transformers>=4.47,<5", + "sentencepiece>=0.2.0", + "fair-esm==2.0.0", + "progres==0.2.7", + "rotary-embedding-torch==0.6.5", + "ipdb==0.13.13", + "pandas-flavor>=0.6.0", + ) + .add_local_python_source("catpred") + .add_local_dir("scripts", remote_path="/root/scripts") + .add_local_file("predict.py", remote_path="/root/predict.py") +) + +app = modal.App("catpred-modal-api", image=image) +checkpoints_volume = modal.Volume.from_name("catpred-checkpoints", create_if_missing=True) + + +class PredictPayload(BaseModel): + parameter: str = Field(..., description="One of: kcat, km, ki") + checkpoint_dir: str = Field(..., description="Checkpoint subdirectory inside /checkpoints") + use_gpu: bool = Field(default=False) + input_rows: list[dict[str, Any]] = Field(default_factory=list) + input_filename: str | None = Field(default=None) + + +def _safe_checkpoint_path(raw_checkpoint_dir: str) -> Path: + checkpoint_root = Path("/checkpoints").resolve() + checkpoint_dir = (checkpoint_root / raw_checkpoint_dir).resolve() + try: + checkpoint_dir.relative_to(checkpoint_root) + except ValueError as exc: + raise ValueError("checkpoint_dir must stay inside /checkpoints.") from exc + if not checkpoint_dir.is_dir(): + raise ValueError(f'Checkpoint directory not found: "{checkpoint_dir}"') + 
return checkpoint_dir + + +@app.function( + timeout=60 * 15, + cpu=4.0, + memory=16384, + volumes={"/checkpoints": checkpoints_volume}, +) +@modal.fastapi_endpoint(method="POST", docs=True) +def predict( + payload: PredictPayload, + authorization: str | None = Header(default=None), +) -> dict[str, Any]: + import pandas as pd + + from catpred.inference.service import run_prediction_pipeline + from catpred.inference.types import PredictionRequest + + expected_token = os.environ.get("CATPRED_MODAL_AUTH_TOKEN") + if expected_token: + provided = "" + if authorization: + lower = authorization.lower() + if lower.startswith("bearer "): + provided = authorization[7:].strip() + else: + provided = authorization.strip() + if provided != expected_token: + raise HTTPException(status_code=401, detail="Unauthorized") + + if not payload.input_rows: + raise HTTPException(status_code=400, detail="input_rows cannot be empty.") + + parameter = payload.parameter.lower() + if parameter not in {"kcat", "km", "ki"}: + raise HTTPException(status_code=400, detail="parameter must be one of: kcat, km, ki.") + + try: + checkpoint_dir = _safe_checkpoint_path(payload.checkpoint_dir) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + safe_name = Path(payload.input_filename or "api_input.csv").name + if not safe_name.endswith(".csv"): + safe_name = f"{safe_name}.csv" + + runtime_dir = Path("/tmp/catpred-modal").resolve() + runtime_dir.mkdir(parents=True, exist_ok=True) + fd, tmp_input_path = tempfile.mkstemp(prefix="modal_input_", suffix=".csv", dir=str(runtime_dir)) + os.close(fd) + + try: + pd.DataFrame(payload.input_rows).to_csv(tmp_input_path, index=False) + + request_obj = PredictionRequest( + parameter=parameter, + input_file=tmp_input_path, + checkpoint_dir=str(checkpoint_dir), + use_gpu=payload.use_gpu, + repo_root="/root", + python_executable="python", + ) + results_dir = str((runtime_dir / "results").resolve()) + output_file = 
run_prediction_pipeline(request_obj, results_dir=results_dir) + output_df = pd.read_csv(output_file) + + return { + "output_rows": output_df.to_dict(orient="records"), + "output_filename": Path(output_file).name, + "row_count": int(len(output_df)), + "backend": "modal", + } + except HTTPException: + raise + except Exception as exc: + raise HTTPException(status_code=500, detail=f"Modal prediction failed: {exc}") from exc + finally: + input_path = Path(tmp_input_path) + if input_path.exists(): + input_path.unlink() From db8a2bb643018f9556ab379419196ce9b8fadb77 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sun, 1 Mar 2026 00:12:40 -0500 Subject: [PATCH 12/34] fix(modal): make prediction endpoint runtime-compatible and lazy-load train package --- catpred/train/__init__.py | 99 +++++++++++++++++++++------------------ modal_app.py | 11 +++-- 2 files changed, 61 insertions(+), 49 deletions(-) diff --git a/catpred/train/__init__.py b/catpred/train/__init__.py index d8c18d9..4f66fe7 100644 --- a/catpred/train/__init__.py +++ b/catpred/train/__init__.py @@ -1,47 +1,54 @@ -from .metrics import get_metric_func, prc_auc, bce, rmse, bounded_mse, bounded_mae, \ - bounded_rmse, accuracy, f1_metric, mcc_metric, sid_metric, wasserstein_metric -from .loss_functions import get_loss_func, bounded_mse_loss, \ - mcc_class_loss, mcc_multiclass_loss, sid_loss, wasserstein_loss -from .cross_validate import catpred_train, cross_validate, TRAIN_LOGGER_NAME -from .evaluate import evaluate, evaluate_predictions -from .make_predictions import catpred_predict, make_predictions, load_model, set_features, load_data, predict_and_save -from .molecule_fingerprint import catpred_fingerprint, model_fingerprint -from .predict import predict -from .run_training import run_training -from .train import train +from __future__ import annotations -__all__ = [ - 'catpred_train', - 'cross_validate', - 'TRAIN_LOGGER_NAME', - 'evaluate', - 'evaluate_predictions', - 'catpred_predict', - 'catpred_fingerprint', 
- 'make_predictions', - 'load_model', - 'set_features', - 'load_data', - 'predict_and_save', - 'predict', - 'run_training', - 'train', - 'get_metric_func', - 'prc_auc', - 'bce', - 'rmse', - 'bounded_mse', - 'bounded_mae', - 'bounded_rmse', - 'accuracy', - 'f1_metric', - 'mcc_metric', - 'sid_metric', - 'wasserstein_metric', - 'get_loss_func', - 'bounded_mse_loss', - 'mcc_class_loss', - 'mcc_multiclass_loss', - 'sid_loss', - 'wasserstein_loss' -] +import importlib + +_LAZY_ATTRS = { + "get_metric_func": ("metrics", "get_metric_func"), + "prc_auc": ("metrics", "prc_auc"), + "bce": ("metrics", "bce"), + "rmse": ("metrics", "rmse"), + "bounded_mse": ("metrics", "bounded_mse"), + "bounded_mae": ("metrics", "bounded_mae"), + "bounded_rmse": ("metrics", "bounded_rmse"), + "accuracy": ("metrics", "accuracy"), + "f1_metric": ("metrics", "f1_metric"), + "mcc_metric": ("metrics", "mcc_metric"), + "sid_metric": ("metrics", "sid_metric"), + "wasserstein_metric": ("metrics", "wasserstein_metric"), + "get_loss_func": ("loss_functions", "get_loss_func"), + "bounded_mse_loss": ("loss_functions", "bounded_mse_loss"), + "mcc_class_loss": ("loss_functions", "mcc_class_loss"), + "mcc_multiclass_loss": ("loss_functions", "mcc_multiclass_loss"), + "sid_loss": ("loss_functions", "sid_loss"), + "wasserstein_loss": ("loss_functions", "wasserstein_loss"), + "catpred_train": ("cross_validate", "catpred_train"), + "cross_validate": ("cross_validate", "cross_validate"), + "TRAIN_LOGGER_NAME": ("cross_validate", "TRAIN_LOGGER_NAME"), + "evaluate": ("evaluate", "evaluate"), + "evaluate_predictions": ("evaluate", "evaluate_predictions"), + "catpred_predict": ("make_predictions", "catpred_predict"), + "make_predictions": ("make_predictions", "make_predictions"), + "load_model": ("make_predictions", "load_model"), + "set_features": ("make_predictions", "set_features"), + "load_data": ("make_predictions", "load_data"), + "predict_and_save": ("make_predictions", "predict_and_save"), + 
"catpred_fingerprint": ("molecule_fingerprint", "catpred_fingerprint"), + "model_fingerprint": ("molecule_fingerprint", "model_fingerprint"), + "predict": ("predict", "predict"), + "run_training": ("run_training", "run_training"), + "train": ("train", "train"), +} + +__all__ = sorted(_LAZY_ATTRS.keys()) + + +def __getattr__(name: str): + target = _LAZY_ATTRS.get(name) + if target is None: + raise AttributeError(f"module 'catpred.train' has no attribute '{name}'") + + module_name, attr_name = target + module = importlib.import_module(f".{module_name}", __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value diff --git a/modal_app.py b/modal_app.py index 2769457..b1bdee0 100644 --- a/modal_app.py +++ b/modal_app.py @@ -3,7 +3,7 @@ from pathlib import Path import os import tempfile -from typing import Any +from typing import Any, Optional from fastapi import Header, HTTPException import modal @@ -12,6 +12,11 @@ image = ( modal.Image.debian_slim(python_version="3.10") + .apt_install( + "libxrender1", + "libxext6", + "libsm6", + ) .pip_install( "fastapi[standard]>=0.115,<1.0", "pydantic>=1.10,<2.0", @@ -46,7 +51,7 @@ class PredictPayload(BaseModel): checkpoint_dir: str = Field(..., description="Checkpoint subdirectory inside /checkpoints") use_gpu: bool = Field(default=False) input_rows: list[dict[str, Any]] = Field(default_factory=list) - input_filename: str | None = Field(default=None) + input_filename: Optional[str] = Field(default=None) def _safe_checkpoint_path(raw_checkpoint_dir: str) -> Path: @@ -70,7 +75,7 @@ def _safe_checkpoint_path(raw_checkpoint_dir: str) -> Path: @modal.fastapi_endpoint(method="POST", docs=True) def predict( payload: PredictPayload, - authorization: str | None = Header(default=None), + authorization: Optional[str] = Header(default=None), ) -> dict[str, Any]: import pandas as pd From da7b8673bc21a33cf6707bf6c35d39d14815c29f Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sun, 1 Mar 2026 00:12:54 -0500 
Subject: [PATCH 13/34] chore(vercel): prefill env file with deployed modal endpoint --- .env.vercel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.env.vercel b/.env.vercel index 7db578e..2b705e4 100644 --- a/.env.vercel +++ b/.env.vercel @@ -1,4 +1,4 @@ CATPRED_DEFAULT_BACKEND=modal -CATPRED_MODAL_ENDPOINT=https:// -CATPRED_MODAL_TOKEN= +CATPRED_MODAL_ENDPOINT=https://kaalabhairava2026--catpred-modal-api-predict.modal.run +CATPRED_MODAL_TOKEN= CATPRED_MODAL_FALLBACK_TO_LOCAL=0 From 0c902da55bae4cde248254836b5f005549f1aad2 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sun, 1 Mar 2026 00:18:21 -0500 Subject: [PATCH 14/34] fix(vercel): remove functions/builds conflict --- vercel.json | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vercel.json b/vercel.json index 4582896..c146952 100644 --- a/vercel.json +++ b/vercel.json @@ -7,11 +7,6 @@ "use": "@vercel/python" } ], - "functions": { - "api/index.py": { - "maxDuration": 300 - } - }, "routes": [ { "src": "/(.*)", From a89216cff6bd7a5a5bfb42fc59d4d0a59c87db91 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sun, 1 Mar 2026 00:31:38 -0500 Subject: [PATCH 15/34] ci: add GitHub Actions validation and Modal auto-deploy --- .github/workflows/ci.yml | 42 ++++++++++++++++++++++++++++++ .github/workflows/deploy-modal.yml | 42 ++++++++++++++++++++++++++++++ README.md | 25 ++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/deploy-modal.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1467ea9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,42 @@ +name: CI + +on: + pull_request: + push: + branches: + - main + +jobs: + validate: + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install 
minimal runtime dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Compile Python sources + run: | + git ls-files '*.py' | xargs -r python -m py_compile + + - name: Smoke test API entrypoints + run: | + python - <<'PY' + from catpred.web.app import create_app + from api.index import app as vercel_app + + api_app = create_app() + assert api_app.title == "CatPred API" + assert vercel_app is not None + print("API smoke checks passed.") + PY diff --git a/.github/workflows/deploy-modal.yml b/.github/workflows/deploy-modal.yml new file mode 100644 index 0000000..88735bb --- /dev/null +++ b/.github/workflows/deploy-modal.yml @@ -0,0 +1,42 @@ +name: Deploy Modal + +on: + workflow_dispatch: + push: + branches: + - main + paths: + - ".github/workflows/deploy-modal.yml" + - "modal_app.py" + - "predict.py" + - "scripts/create_pdbrecords.py" + - "catpred/**" + +jobs: + deploy: + if: ${{ secrets.MODAL_TOKEN_ID != '' && secrets.MODAL_TOKEN_SECRET != '' }} + runs-on: ubuntu-latest + permissions: + contents: read + concurrency: + group: modal-production-deploy + cancel-in-progress: true + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install Modal CLI + run: | + python -m pip install --upgrade pip + pip install "modal>=0.73" + + - name: Deploy modal_app.py + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + run: modal deploy modal_app.py diff --git a/README.md b/README.md index 1275805..080ba25 100644 --- a/README.md +++ b/README.md @@ -248,6 +248,31 @@ CATPRED_MODAL_TOKEN= CATPRED_MODAL_FALLBACK_TO_LOCAL=0 ``` +#### CI/CD (GitHub Actions + Vercel + Modal) + +This repo includes two GitHub Actions workflows: + +- `.github/workflows/ci.yml` + - Runs on every PR and push to `main`. 
+ - Installs minimal API dependencies, compiles all Python files, and smoke-tests API entrypoints. +- `.github/workflows/deploy-modal.yml` + - Runs on push to `main` when backend files change (and manually via `workflow_dispatch`). + - Deploys `modal_app.py` automatically. + +To enable automatic Modal deploys from GitHub Actions, add repository secrets: + +- `MODAL_TOKEN_ID` +- `MODAL_TOKEN_SECRET` + +Create these from Modal: + +1. Go to [https://modal.com/settings/tokens](https://modal.com/settings/tokens). +2. Create a token with deploy permissions for your workspace. +3. Copy token ID and secret into GitHub repo settings: + `Settings -> Secrets and variables -> Actions -> New repository secret`. + +Vercel deployment remains automatic from the connected GitHub branch (`main`). + ### 🧪 Fine-Tuning On Custom Data You can fine-tune CatPred on your own regression targets using `train.py`. From 75fe7b52e31b070aaa398fa092bb6bbaebe8bba3 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sun, 1 Mar 2026 00:35:57 -0500 Subject: [PATCH 16/34] fix(web): support modal-only deployments and backend-aware fallback --- catpred/web/static/catpred.js | 86 ++++++++++++++++++++++++++++++----- 1 file changed, 75 insertions(+), 11 deletions(-) diff --git a/catpred/web/static/catpred.js b/catpred/web/static/catpred.js index 707d620..bac7128 100644 --- a/catpred/web/static/catpred.js +++ b/catpred/web/static/catpred.js @@ -31,7 +31,14 @@ const supportedParameters = ["kcat", "km", "ki"]; let availableCheckpointParams = new Set(supportedParameters); + let localCheckpointParams = new Set(supportedParameters); let selectedParameter = "kcat"; + const runtimeState = { + defaultBackend: "local", + modalReady: false, + localReady: false, + fallbackToLocalEnabled: false, + }; const runningPhaseMessages = [ "validating input", "building protein records", @@ -204,16 +211,19 @@ return null; } - function setParameterAvailability(availableCheckpoints) { + function
parseAvailableCheckpointParams(availableCheckpoints) { if (availableCheckpoints && typeof availableCheckpoints === "object") { - availableCheckpointParams = new Set( + return new Set( Object.keys(availableCheckpoints) .map((key) => String(key).toLowerCase()) .filter((key) => supportedParameters.includes(key)) ); - } else { - availableCheckpointParams = new Set(supportedParameters); } + return new Set(supportedParameters); + } + + function setParameterAvailability(availableCheckpoints) { + availableCheckpointParams = parseAvailableCheckpointParams(availableCheckpoints); if (presetButtons && presetButtons.length) { presetButtons.forEach((btn) => { @@ -235,6 +245,35 @@ return availableCheckpointParams.has(String(paramValue || "").toLowerCase()); } + function chooseRequestBackend() { + if (runtimeState.defaultBackend === "modal" && runtimeState.modalReady) { + return "modal"; + } + if (runtimeState.defaultBackend === "local" && localCheckpointParams.size > 0) { + return "local"; + } + if (runtimeState.modalReady) { + return "modal"; + } + if (runtimeState.localReady) { + return "local"; + } + return runtimeState.defaultBackend || "local"; + } + + function shouldFallbackToLocal(requestBackend, targetParam) { + if (requestBackend !== "modal") { + return false; + } + if (!runtimeState.fallbackToLocalEnabled) { + return false; + } + if (!runtimeState.localReady) { + return false; + } + return localCheckpointParams.has(String(targetParam || "").toLowerCase()); + } + function formatSequenceId(index) { const safeIndex = Math.max(1, Number(index) || 1); return "seq_" + String(safeIndex).padStart(3, "0"); @@ -393,6 +432,7 @@ function buildPayload(rows) { const target = getSelectedParameter(); + const requestBackend = chooseRequestBackend(); return { parameter: target, @@ -400,7 +440,8 @@ input_rows: rows, use_gpu: false, results_dir: "web-app", - fallback_to_local: true, + backend: requestBackend, + fallback_to_local: shouldFallbackToLocal(requestBackend, target), }; } @@ 
-533,18 +574,41 @@ return; } - setParameterAvailability(data && data.api ? data.api.available_checkpoints : null); + const backends = data && data.backends ? data.backends : {}; + runtimeState.defaultBackend = data && data.default_backend ? String(data.default_backend) : "local"; + runtimeState.modalReady = Boolean(backends.modal && backends.modal.ready); + runtimeState.localReady = Boolean(backends.local && backends.local.ready); + runtimeState.fallbackToLocalEnabled = Boolean(data && data.fallback_to_local_enabled); + + localCheckpointParams = parseAvailableCheckpointParams( + data && data.api ? data.api.available_checkpoints : null + ); + + if (localCheckpointParams.size > 0) { + setParameterAvailability(Object.fromEntries(Array.from(localCheckpointParams).map((key) => [key, key]))); + } else if (runtimeState.modalReady) { + setParameterAvailability({ kcat: "kcat", km: "km", ki: "ki" }); + } else { + setParameterAvailability({}); + } + const availableParams = Array.from(availableCheckpointParams.values()); + const localParams = Array.from(localCheckpointParams.values()); if (data && data.ready) { const backend = data.default_backend ? String(data.default_backend) : "default"; - const hint = availableParams.length - ? "Backend: " + backend + " | Checkpoints: " + availableParams.join(", ") - : "Backend: " + backend + " | No local checkpoints found"; + let hint = ""; + if (localParams.length) { + hint = "Backend: " + backend + " | Checkpoints: " + localParams.join(", "); + } else if (runtimeState.modalReady) { + hint = "Backend: " + backend + " | Remote checkpoints: " + availableParams.join(", "); + } else { + hint = "Backend: " + backend + " | No local checkpoints found"; + } setServiceState("Online", hint, "ok"); } else { - const hint = availableParams.length - ? "Backend configuration needed" + const hint = runtimeState.modalReady + ? 
"Backend available in limited mode" : "Backend configuration needed | No local checkpoints found"; setServiceState("Limited", hint, "error"); } From d8a095dac47902afcf951dd12de55147a941238d Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sun, 1 Mar 2026 00:42:37 -0500 Subject: [PATCH 17/34] fix(web): timeout long predict calls and cache-bust static assets --- catpred/web/static/catpred.js | 24 ++++++++++++++++++++++-- catpred/web/static/index.html | 4 ++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/catpred/web/static/catpred.js b/catpred/web/static/catpred.js index bac7128..4abe371 100644 --- a/catpred/web/static/catpred.js +++ b/catpred/web/static/catpred.js @@ -15,6 +15,7 @@ const serviceHint = document.getElementById("serviceHint"); const presetButtons = document.querySelectorAll(".preset-btn"); + const predictionTimeoutMs = 120000; const sampleRows = [ { @@ -637,6 +638,11 @@ startRunningFeedback(payload); await new Promise((resolve) => window.requestAnimationFrame(resolve)); + const requestController = new AbortController(); + const requestTimeout = window.setTimeout(() => { + requestController.abort(); + }, predictionTimeoutMs); + try { const response = await fetch("/predict", { method: "POST", @@ -644,6 +650,7 @@ "Content-Type": "application/json", Accept: "application/json", }, + signal: requestController.signal, body: JSON.stringify(payload), }); @@ -667,12 +674,25 @@ renderPreviewTable(data.preview_rows || []); const elapsedLabel = formatElapsed(getElapsedSeconds()); setStatus("Prediction complete โ€ข " + elapsedLabel, "ok"); - } catch (_err) { + } catch (err) { renderResultCards([], payload.parameter); renderPreviewTable([]); const elapsedLabel = formatElapsed(getElapsedSeconds()); - setStatus("Network error while running prediction. 
(" + elapsedLabel + ")", "error"); + if (err && err.name === "AbortError") { + const timeoutLabel = formatElapsed(Math.floor(predictionTimeoutMs / 1000)); + setStatus( + "Prediction timed out after " + + timeoutLabel + + ". This is often a cold-start delay; retry once or check backend logs. (" + + elapsedLabel + + ")", + "error" + ); + } else { + setStatus("Network error while running prediction. (" + elapsedLabel + ")", "error"); + } } finally { + window.clearTimeout(requestTimeout); stopRunningFeedback(); runStartedAtMs = null; } diff --git a/catpred/web/static/index.html b/catpred/web/static/index.html index 2ffa804..e4374fe 100644 --- a/catpred/web/static/index.html +++ b/catpred/web/static/index.html @@ -14,7 +14,7 @@ href="https://fonts.googleapis.com/css2?family=Newsreader:ital,opsz,wght@0,6..72,300;0,6..72,400;1,6..72,300;1,6..72,400&family=Outfit:wght@200;300;400;500&family=IBM+Plex+Mono:wght@300;400&display=swap" rel="stylesheet" /> - + @@ -209,6 +209,6 @@

Read uncertainty first

Run prediction - + From 39d6bdcf0401e767fa7f3b04fbe472860e805704 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Sun, 1 Mar 2026 15:36:08 -0500 Subject: [PATCH 18/34] feat(web): streamline production landing and improve scrolling --- catpred/web/static/catpred-favicon.svg | 11 ++ catpred/web/static/catpred.css | 146 +++++++---------------- catpred/web/static/index.html | 156 ++++++------------------- setup.cfg | 1 + 4 files changed, 92 insertions(+), 222 deletions(-) create mode 100644 catpred/web/static/catpred-favicon.svg diff --git a/catpred/web/static/catpred-favicon.svg b/catpred/web/static/catpred-favicon.svg new file mode 100644 index 0000000..3c33710 --- /dev/null +++ b/catpred/web/static/catpred-favicon.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/catpred/web/static/catpred.css b/catpred/web/static/catpred.css index 90bb76f..51d546b 100644 --- a/catpred/web/static/catpred.css +++ b/catpred/web/static/catpred.css @@ -34,8 +34,7 @@ html { scroll-behavior: smooth; height: 100%; - scroll-snap-type: y mandatory; - overscroll-behavior-y: contain; + overscroll-behavior-y: auto; } body { @@ -146,17 +145,27 @@ body { } .hero { - min-height: calc(100svh - 78px); + min-height: auto; display: grid; - grid-template-columns: 0.96fr 1.04fr; - gap: clamp(1.3rem, 4vw, 4.2rem); - align-items: center; - padding-block: clamp(2.2rem, 7.4vw, 5.8rem); + grid-template-columns: minmax(0, 0.92fr) minmax(0, 1.08fr); + gap: 1rem clamp(1rem, 3.2vw, 2.2rem); + align-items: start; + padding-block: clamp(1.6rem, 5.2vw, 3.4rem); } .hero-copy { position: relative; z-index: 1; + padding-top: 0.3rem; +} + +.hero-studio { + align-self: start; +} + +.hero > .output-card { + grid-column: 1 / -1; + margin-top: 0.45rem; } .eyebrow { @@ -171,9 +180,9 @@ body { .hero h1 { font-family: "Newsreader", serif; - font-size: clamp(2.4rem, 5.4vw, 4.8rem); + font-size: clamp(1.78rem, 3.3vw, 2.75rem); font-weight: 300; - line-height: 1.05; + line-height: 1.08; letter-spacing: -0.03em; 
margin-bottom: 1.25rem; } @@ -188,7 +197,7 @@ body { font-size: clamp(0.98rem, 1.9vw, 1.07rem); color: var(--text-secondary); line-height: 1.75; - margin-bottom: 1.8rem; + margin-bottom: 1.1rem; } .lead code, @@ -208,7 +217,7 @@ body { display: flex; flex-wrap: wrap; gap: 0.8rem; - margin-bottom: 1.75rem; + margin-bottom: 1.1rem; } .btn { @@ -293,9 +302,10 @@ body { .hero-stats { list-style: none; display: grid; - grid-template-columns: repeat(3, minmax(0, 1fr)); + grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 0.7rem; - max-width: 720px; + margin-top: 0.86rem; + max-width: 880px; } .hero-stats li { @@ -320,76 +330,12 @@ body { color: var(--text-secondary); } -.hero-visual-wrap { - position: relative; - isolation: isolate; -} - -.hero-visual { - position: relative; - width: min(560px, 100%); - aspect-ratio: 1 / 1; - margin-inline: auto; -} - -.hero-visual::before { - content: ""; - position: absolute; - inset: 10%; - border-radius: 999px; - background: radial-gradient(circle, rgba(122, 158, 142, 0.16) 0%, rgba(196, 137, 122, 0.08) 42%, transparent 74%); - filter: blur(24px); - animation: breathe 9s ease-in-out infinite; -} - -@keyframes breathe { - 0%, - 100% { - transform: scale(1); - opacity: 0.85; - } - - 50% { - transform: scale(1.06); - opacity: 1; - } -} - -.protein-structure { - width: 100%; - height: 100%; - position: relative; - z-index: 1; - animation: proteinFloat 18s ease-in-out infinite; -} - -@keyframes proteinFloat { - 0%, - 100% { - transform: translateY(0) rotate(0deg); - } - - 30% { - transform: translateY(-8px) rotate(0.4deg); - } - - 60% { - transform: translateY(3px) rotate(-0.25deg); - } - - 85% { - transform: translateY(-4px) rotate(0.18deg); - } -} - .section { padding-block: clamp(2.8rem, 7vw, 5.4rem); } .snap-section { - min-height: 100svh; - scroll-snap-align: start; - scroll-snap-stop: always; + min-height: auto; scroll-margin-top: 84px; } @@ -427,6 +373,7 @@ body { .service-strip { margin-top: 0.9rem; display: 
inline-flex; + flex-wrap: wrap; align-items: center; gap: 0.6rem; border: 1px solid var(--border); @@ -471,13 +418,6 @@ body { font-size: 0.82rem; } -.predict-layout { - display: grid; - grid-template-columns: 0.96fr 1.04fr; - gap: 1rem; - align-items: start; -} - .card { border: 1px solid var(--border); border-radius: var(--radius-lg); @@ -494,8 +434,7 @@ body { } .input-card { - position: sticky; - top: 98px; + position: static; } .preset-row { @@ -1019,14 +958,14 @@ button:focus-visible { } .reveal { - opacity: 0; - transform: translateY(20px); - transition: opacity 0.48s ease, transform 0.48s ease; + opacity: 1; + transform: none; + transition: none; } .reveal.show { opacity: 1; - transform: translateY(0); + transform: none; } @keyframes spin { @@ -1042,28 +981,22 @@ button:focus-visible { } @media (max-width: 1120px) { - html { - scroll-snap-type: y proximity; - } - .hero { grid-template-columns: 1fr; min-height: auto; } - .hero-visual { - width: min(500px, 100%); + .hero > .output-card { + margin-top: 0; } - .predict-layout, .steps, - .hero-stats, .metric-grid { grid-template-columns: 1fr; } - .input-card { - position: static; + .hero-stats { + grid-template-columns: repeat(2, minmax(0, 1fr)); } } @@ -1080,9 +1013,14 @@ button:focus-visible { padding-top: 1.5rem; } - .snap-section { - min-height: auto; - scroll-snap-stop: normal; + .hero-stats { + grid-template-columns: 1fr; + } + + .service-strip { + display: flex; + border-radius: 14px; + padding: 0.5rem 0.62rem; } .input-table { diff --git a/catpred/web/static/index.html b/catpred/web/static/index.html index e4374fe..9e2f04e 100644 --- a/catpred/web/static/index.html +++ b/catpred/web/static/index.html @@ -14,14 +14,15 @@ href="https://fonts.googleapis.com/css2?family=Newsreader:ital,opsz,wght@0,6..72,300;0,6..72,400;1,6..72,300;1,6..72,400&family=Outfit:wght@200;300;400;500&family=IBM+Plex+Mono:wght@300;400&display=swap" rel="stylesheet" /> - + +
-
+
- -
- -
-
-

Prediction Studio

-

Enter inputs and run

-

- Choose a target, submit one or many rows, and review predictions with uncertainty. -

- -
- Service status - Checking... - -
-
- -
+
-

Inputs

+

Prediction Studio

@@ -146,65 +104,27 @@

Inputs

For CSV/API usage, this field maps to pdbpath.

- -
-

Results

-
Ready when you are.
- -
-
-

Run a prediction to see kinetic estimates here.

-
-
- -
- View detailed output table -
-
-
-
-
-
-
-
-
-

Quick Guide

-

Three rules for reliable runs

-
+
+

Results

+
Ready when you are.
-
-
- 01 -

Use clean inputs

-

- Each row needs valid SMILES, uppercase amino acid sequence, - and a Sequence ID (auto-filled, internal field: pdbpath). -

-
-
- 02 -

Keep sequence IDs consistent

-

- Reuse the same Sequence ID only for the same sequence. Different sequences should - use different IDs. -

-
-
- 03 -

Read uncertainty first

-

- Prioritize results with lower uncertainty before selecting candidates for experiments. -

+
+
+

Run a prediction to see kinetic estimates here.

-

- Tip: use Load sample in the predictor to test your environment end-to-end before - running larger batches. -

-
+ +
+ View detailed output table +
+
+
+
+
+
Run prediction diff --git a/setup.cfg b/setup.cfg index b61db52..b2c6b99 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,3 +65,4 @@ catpred = web/static/*.html web/static/*.css web/static/*.js + web/static/*.svg From eec74d56682dc66dfb06bbe738db293281a443e2 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Mon, 9 Mar 2026 22:54:30 -0400 Subject: [PATCH 19/34] refactor(web): reduce landing page to prediction studio --- catpred/web/static/catpred.css | 282 ++++++++------------------------- catpred/web/static/index.html | 94 ++++++----- 2 files changed, 109 insertions(+), 267 deletions(-) diff --git a/catpred/web/static/catpred.css b/catpred/web/static/catpred.css index 51d546b..ac188aa 100644 --- a/catpred/web/static/catpred.css +++ b/catpred/web/static/catpred.css @@ -144,28 +144,39 @@ body { color: var(--text); } -.hero { - min-height: auto; - display: grid; - grid-template-columns: minmax(0, 0.92fr) minmax(0, 1.08fr); - gap: 1rem clamp(1rem, 3.2vw, 2.2rem); - align-items: start; - padding-block: clamp(1.6rem, 5.2vw, 3.4rem); +.studio { + padding-block: clamp(1.2rem, 4vw, 2.4rem); } -.hero-copy { - position: relative; - z-index: 1; - padding-top: 0.3rem; +.studio-head { + display: flex; + align-items: end; + justify-content: space-between; + gap: 1rem; + margin-bottom: 1rem; +} + +.studio-head h1 { + font-family: "Newsreader", serif; + font-size: clamp(1.35rem, 2.2vw, 1.9rem); + font-weight: 400; + line-height: 1.05; + letter-spacing: -0.03em; } -.hero-studio { - align-self: start; +.studio-meta { + display: flex; + align-items: center; + justify-content: flex-end; + gap: 0.75rem; + flex-wrap: wrap; } -.hero > .output-card { - grid-column: 1 / -1; - margin-top: 0.45rem; +.workspace-grid { + display: grid; + grid-template-columns: minmax(0, 0.98fr) minmax(0, 1.02fr); + gap: 1rem; + align-items: start; } .eyebrow { @@ -178,28 +189,6 @@ body { margin-bottom: 1.2rem; } -.hero h1 { - font-family: "Newsreader", serif; - font-size: clamp(1.78rem, 3.3vw, 2.75rem); - 
font-weight: 300; - line-height: 1.08; - letter-spacing: -0.03em; - margin-bottom: 1.25rem; -} - -.hero h1 em { - font-style: italic; - color: var(--protein-sage); -} - -.lead { - max-width: 60ch; - font-size: clamp(0.98rem, 1.9vw, 1.07rem); - color: var(--text-secondary); - line-height: 1.75; - margin-bottom: 1.1rem; -} - .lead code, .muted code, .helper code, @@ -213,13 +202,6 @@ body { background: rgba(255, 255, 255, 0.66); } -.hero-actions { - display: flex; - flex-wrap: wrap; - gap: 0.8rem; - margin-bottom: 1.1rem; -} - .btn { display: inline-flex; align-items: center; @@ -295,41 +277,6 @@ body { border-color: var(--text); } -.nav-cta { - padding-inline: 1.1rem; -} - -.hero-stats { - list-style: none; - display: grid; - grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); - gap: 0.7rem; - margin-top: 0.86rem; - max-width: 880px; -} - -.hero-stats li { - border: 1px solid var(--border); - background: rgba(255, 255, 255, 0.68); - border-radius: var(--radius-md); - padding: 0.76rem 0.78rem; -} - -.hero-stats strong { - display: block; - font-family: "IBM Plex Mono", monospace; - font-size: 0.94rem; - font-weight: 400; - color: var(--text); -} - -.hero-stats span { - display: block; - margin-top: 0.25rem; - font-size: 0.84rem; - color: var(--text-secondary); -} - .section { padding-block: clamp(2.8rem, 7vw, 5.4rem); } @@ -418,6 +365,34 @@ body { font-size: 0.82rem; } +.resource-links { + display: flex; + flex-wrap: wrap; + gap: 0.55rem; +} + +.resource-link { + display: inline-flex; + align-items: center; + min-height: 2rem; + padding: 0.3rem 0.62rem; + border: 1px solid var(--border); + border-radius: 999px; + background: rgba(255, 255, 255, 0.72); + color: var(--text-secondary); + text-decoration: none; + font-family: "IBM Plex Mono", monospace; + font-size: 0.67rem; + letter-spacing: 0.03em; + transition: border-color 0.2s ease, color 0.2s ease, background-color 0.2s ease; +} + +.resource-link:hover { + color: var(--text); + border-color: rgba(122, 158, 
142, 0.32); + background: rgba(122, 158, 142, 0.08); +} + .card { border: 1px solid var(--border); border-radius: var(--radius-lg); @@ -817,127 +792,6 @@ button:focus-visible { text-transform: uppercase; } -.steps { - display: grid; - grid-template-columns: repeat(3, minmax(0, 1fr)); - gap: 1rem; -} - -.guide-grid { - margin-top: 0.35rem; -} - -.steps article { - border-radius: 18px; - border: 1px solid var(--border); - background: rgba(255, 255, 255, 0.72); - padding: 1rem; -} - -.steps span { - font-family: "IBM Plex Mono", monospace; - font-size: 0.67rem; - color: var(--text-tertiary); - letter-spacing: 0.1em; -} - -.steps h3 { - margin-top: 0.38rem; - margin-bottom: 0.4rem; - font-family: "Newsreader", serif; - font-size: 1.3rem; - font-weight: 400; -} - -.steps p { - color: var(--text-secondary); - font-size: 0.93rem; - line-height: 1.68; -} - -.guide-note { - margin-top: 0.88rem; - border: 1px solid var(--border); - border-radius: 14px; - background: rgba(255, 255, 255, 0.72); - padding: 0.72rem 0.82rem; - color: var(--text-secondary); - font-size: 0.9rem; -} - -.guide-note strong { - color: var(--text); - font-family: "IBM Plex Mono", monospace; - font-size: 0.82em; - letter-spacing: 0.03em; -} - -.faq-list { - border-top: 1px solid var(--border); -} - -.faq-item { - border-bottom: 1px solid var(--border); - padding-block: 0.86rem; -} - -.faq-item button { - width: 100%; - border: 0; - background: transparent; - text-align: left; - cursor: pointer; - display: flex; - justify-content: space-between; - align-items: center; - gap: 1rem; - color: var(--text); - font-family: "Newsreader", serif; - font-size: 1.24rem; - font-weight: 400; -} - -.faq-item button::after { - content: "+"; - color: var(--text-tertiary); - font-family: "IBM Plex Mono", monospace; - font-size: 0.95rem; -} - -.faq-item button[aria-expanded="true"]::after { - content: "โˆ’"; -} - -.faq-item p { - max-height: 0; - overflow: hidden; - margin: 0; - color: var(--text-secondary); - font-size: 
0.95rem; - line-height: 1.67; - transition: max-height 0.24s ease, margin-top 0.24s ease; -} - -.faq-item.open p { - max-height: 12rem; - margin-top: 0.4rem; -} - -.footer { - padding: 1.2rem 0 2.4rem; -} - -.footer-inner { - border-top: 1px solid var(--border); - padding-top: 0.95rem; - display: flex; - justify-content: space-between; - align-items: center; - gap: 1rem; - color: var(--text-tertiary); - font-size: 0.84rem; -} - .mobile-stick { position: fixed; left: 1rem; @@ -981,23 +835,19 @@ button:focus-visible { } @media (max-width: 1120px) { - .hero { + .studio-head, + .workspace-grid { grid-template-columns: 1fr; - min-height: auto; } - .hero > .output-card { - margin-top: 0; + .studio-head { + display: grid; + align-items: start; } - .steps, .metric-grid { grid-template-columns: 1fr; } - - .hero-stats { - grid-template-columns: repeat(2, minmax(0, 1fr)); - } } @media (max-width: 780px) { @@ -1009,12 +859,12 @@ button:focus-visible { width: min(1280px, calc(100% - 1.5rem)); } - .hero { - padding-top: 1.5rem; + .studio { + padding-top: 1rem; } - .hero-stats { - grid-template-columns: 1fr; + .studio-head h1 { + font-size: 1.5rem; } .service-strip { @@ -1030,8 +880,4 @@ button:focus-visible { .mobile-stick { display: block; } - - .footer { - padding-bottom: 5rem; - } } diff --git a/catpred/web/static/index.html b/catpred/web/static/index.html index 9e2f04e..396d41f 100644 --- a/catpred/web/static/index.html +++ b/catpred/web/static/index.html @@ -15,7 +15,7 @@ rel="stylesheet" /> - + @@ -33,53 +33,49 @@ - - Start prediction
-
-
-

AI for Enzyme Kinetics

-

- Predict kcat, Km, and Ki - for production enzyme workflows. -

-

- Enter substrate SMILES and enzyme sequence, then run a single - prediction or a small batch with uncertainty-aware output. -

- -
- Go to form +
+
+
+

Prediction Studio

+

Prediction Studio

+
Service status Checking...
-
    -
  • - 3 targets - kcat / Km / Ki -
  • -
  • - Batch ready - Add and run multiple rows -
  • -
  • - Uncertainty - Total and epistemic SD reported -
  • -
+ +
-
+

Prediction Studio

@@ -104,25 +100,25 @@

Prediction Studio

For CSV/API usage, this field maps to pdbpath.

-
-
-

Results

-
Ready when you are.
+
+

Results

+
Ready when you are.
-
-
-

Run a prediction to see kinetic estimates here.

-
-
- -
- View detailed output table -
-
+
+
+

Run a prediction to see kinetic estimates here.

+
-
-
+ +
+ View detailed output table +
+
+
+
+
+
From 52476d1f062a54658adeaa592f96a19c4521fdd8 Mon Sep 17 00:00:00 2001 From: theproteinbot Date: Mon, 9 Mar 2026 23:16:02 -0400 Subject: [PATCH 20/34] style(web): highlight publication and repository links --- catpred/web/static/catpred.css | 91 +++++++++++++++++++++++++++------- catpred/web/static/index.html | 24 ++++----- 2 files changed, 85 insertions(+), 30 deletions(-) diff --git a/catpred/web/static/catpred.css b/catpred/web/static/catpred.css index ac188aa..9b12b0c 100644 --- a/catpred/web/static/catpred.css +++ b/catpred/web/static/catpred.css @@ -128,20 +128,33 @@ body { .desktop-nav { display: flex; align-items: center; - gap: 1.65rem; + gap: 0.6rem; } .desktop-nav a { + display: inline-flex; + align-items: center; + min-height: 2.1rem; + padding: 0.35rem 0.7rem; + border: 1px solid var(--border); + border-radius: 999px; + background: rgba(255, 255, 255, 0.72); text-decoration: none; color: var(--text-secondary); - font-size: 0.9rem; - font-weight: 300; - letter-spacing: 0.01em; - transition: color 0.25s ease; + font-family: "IBM Plex Mono", monospace; + font-size: 0.68rem; + font-weight: 400; + letter-spacing: 0.04em; + text-transform: uppercase; + transition: color 0.2s ease, border-color 0.2s ease, background-color 0.2s ease, + transform 0.2s ease; } .desktop-nav a:hover { color: var(--text); + border-color: rgba(122, 158, 142, 0.34); + background: rgba(122, 158, 142, 0.08); + transform: translateY(-1px); } .studio { @@ -164,9 +177,9 @@ body { letter-spacing: -0.03em; } -.studio-meta { +.studio-tools { display: flex; - align-items: center; + align-items: flex-end; justify-content: flex-end; gap: 0.75rem; flex-wrap: wrap; @@ -375,22 +388,44 @@ body { display: inline-flex; align-items: center; min-height: 2rem; - padding: 0.3rem 0.62rem; + padding: 0.34rem 0.68rem; border: 1px solid var(--border); border-radius: 999px; - background: rgba(255, 255, 255, 0.72); - color: var(--text-secondary); + background: rgba(255, 255, 255, 0.88); + color: var(--text); 
text-decoration: none; font-family: "IBM Plex Mono", monospace; font-size: 0.67rem; - letter-spacing: 0.03em; - transition: border-color 0.2s ease, color 0.2s ease, background-color 0.2s ease; + letter-spacing: 0.04em; + box-shadow: 0 8px 20px rgba(17, 15, 12, 0.05); + transition: border-color 0.2s ease, color 0.2s ease, background-color 0.2s ease, + transform 0.2s ease, box-shadow 0.2s ease; } .resource-link:hover { color: var(--text); - border-color: rgba(122, 158, 142, 0.32); - background: rgba(122, 158, 142, 0.08); + transform: translateY(-1px); + box-shadow: 0 12px 24px rgba(17, 15, 12, 0.08); +} + +.resource-link-paper { + border-color: rgba(196, 137, 122, 0.38); + background: linear-gradient(180deg, rgba(255, 255, 255, 0.94), rgba(196, 137, 122, 0.08)); +} + +.resource-link-paper:hover { + border-color: rgba(196, 137, 122, 0.64); + background: linear-gradient(180deg, rgba(255, 255, 255, 0.98), rgba(196, 137, 122, 0.14)); +} + +.resource-link-code { + border-color: rgba(122, 158, 142, 0.4); + background: linear-gradient(180deg, rgba(255, 255, 255, 0.94), rgba(122, 158, 142, 0.08)); +} + +.resource-link-code:hover { + border-color: rgba(122, 158, 142, 0.66); + background: linear-gradient(180deg, rgba(255, 255, 255, 0.98), rgba(122, 158, 142, 0.15)); } .card { @@ -845,16 +880,17 @@ button:focus-visible { align-items: start; } + .studio-tools { + justify-content: flex-start; + align-items: flex-start; + } + .metric-grid { grid-template-columns: 1fr; } } @media (max-width: 780px) { - .desktop-nav { - display: none; - } - .container { width: min(1280px, calc(100% - 1.5rem)); } @@ -867,12 +903,31 @@ button:focus-visible { font-size: 1.5rem; } + .desktop-nav { + gap: 0.4rem; + } + + .desktop-nav a { + min-height: 1.95rem; + padding-inline: 0.58rem; + font-size: 0.62rem; + } + .service-strip { display: flex; border-radius: 14px; padding: 0.5rem 0.62rem; } + .resource-links { + gap: 0.42rem; + } + + .resource-link { + width: 100%; + justify-content: center; + } + 
.input-table { min-width: 560px; } diff --git a/catpred/web/static/index.html b/catpred/web/static/index.html index 396d41f..30e3c9a 100644 --- a/catpred/web/static/index.html +++ b/catpred/web/static/index.html @@ -42,21 +42,21 @@
-
-

Prediction Studio

-

Prediction Studio

+
+

Enzyme Kinetics Prediction

+

kcat, Km, and Ki with uncertainty estimates

-
-
- Service status - Checking... - -
+
+
+ Service status + Checking... + +