
Syed Ibrahim Omer edited this page Apr 12, 2026 · 5 revisions

Models Module

src/models/ — ML architectures, training, inference, and TensorRT engine wrappers.

Files

| File | Purpose | Key Classes/Functions |
| --- | --- | --- |
| surge_model.py | Core ML architectures | MigrationLSTM, MigrationTransformer, SurgeJointLoss, build_sequential_tensors() |
| train_and_evaluate.py | Training pipeline | Full train loop: RF → Transformer → LSTM → Ensemble → Evaluate |
| inference.py | Production prediction | MigrationSurgeEnsemble — loads all models, horizon-aware weighting |
| surge_metrics.py | Surge evaluation | detect_surge(), evaluate_surge_performance() |
| embedding.py | Jina v5 embedding runner | Batch-embed articles, append to Parquets, LZ4 compression |
| jinav5_engine.py | Jina TensorRT wrapper | JinaV5EmbeddingTrtModel — CUDA memory, dynamic shapes |
| flant5_engine.py | Flan-T5 TensorRT wrapper | TensorRTFlanT5Engine — tokenize, beam search, generate |
| led_engine.py | LED TensorRT wrapper | allenai/led-base-16384 — long-context labeling |
| utils.py | Jina ONNX export | Export Jina-v5-nano to ONNX with mean pooling |

Model Architectures (surge_model.py)

MigrationLSTM

Input: (B, 6, 3) + country_id → LSTM(2 layers, H=64) + Embedding(8) → FC → (B, 6)
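
A minimal PyTorch sketch of this shape flow. The class name suffix, the country count, and the exact way the country embedding joins the LSTM output (concatenation before the FC head is assumed here) are illustrative, not the repo's code:

```python
import torch
import torch.nn as nn

class MigrationLSTMSketch(nn.Module):
    """Sketch: (B, 6, 3) + country_id -> LSTM(2 layers, H=64) + Embedding(8) -> FC -> (B, 6)."""

    def __init__(self, n_countries: int = 15, n_features: int = 3,
                 hidden: int = 64, emb_dim: int = 8, leads: int = 6):
        super().__init__()
        self.country_emb = nn.Embedding(n_countries, emb_dim)
        self.lstm = nn.LSTM(n_features, hidden, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden + emb_dim, leads)

    def forward(self, x: torch.Tensor, country_id: torch.Tensor) -> torch.Tensor:
        # x: (B, 6, 3); country_id: (B,)
        _, (h_n, _) = self.lstm(x)                  # h_n: (num_layers, B, hidden)
        h = h_n[-1]                                 # final hidden state of last layer: (B, hidden)
        e = self.country_emb(country_id)            # (B, emb_dim)
        return self.fc(torch.cat([h, e], dim=-1))   # (B, 6)
```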

MigrationTransformer

Input: (B, 6, 3) → Linear(64) + PosEnc → TransformerEncoder(2L, 4H) → FC → (B, 6)
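
The same shape flow for the Transformer path, as a sketch. The sinusoidal positional encoding and mean pooling over the 6-step window are assumptions; only the dimensions (d_model=64, 2 layers, 4 heads) come from the description above:

```python
import math
import torch
import torch.nn as nn

class MigrationTransformerSketch(nn.Module):
    """Sketch: (B, 6, 3) -> Linear(64) + PosEnc -> TransformerEncoder(2L, 4H) -> FC -> (B, 6)."""

    def __init__(self, n_features: int = 3, d_model: int = 64,
                 heads: int = 4, layers: int = 2, seq_len: int = 6, leads: int = 6):
        super().__init__()
        self.proj = nn.Linear(n_features, d_model)
        # Fixed sinusoidal positional encoding for the 6-step input window.
        pe = torch.zeros(seq_len, d_model)
        pos = torch.arange(seq_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)
        enc_layer = nn.TransformerEncoderLayer(d_model, heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.fc = nn.Linear(d_model, leads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.proj(x) + self.pe      # (B, 6, 64)
        z = self.encoder(z)             # (B, 6, 64)
        return self.fc(z.mean(dim=1))   # pool over time -> (B, 6)
```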

SurgeJointLoss

L = 0.6 × Huber(ŷ, y) + 0.4 × BCE(ŷ - 1.5σ, surge_flag)
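
One plausible reading of this formula as code: treat (ŷ − 1.5σ) as a logit for "this month is a surge" and score it against the surge flag with binary cross-entropy. How σ is computed (per batch below) and whether the BCE really takes this logit form are assumptions:

```python
import torch
import torch.nn as nn

class SurgeJointLossSketch(nn.Module):
    """Sketch of L = 0.6 * Huber(y_hat, y) + 0.4 * BCE(y_hat - 1.5*sigma, surge_flag)."""

    def __init__(self, w_reg: float = 0.6, w_surge: float = 0.4, k: float = 1.5):
        super().__init__()
        self.huber = nn.HuberLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.w_reg, self.w_surge, self.k = w_reg, w_surge, k

    def forward(self, y_hat, y, surge_flag):
        sigma = y.std()                      # series scale; per-batch std is an assumption
        logits = y_hat - self.k * sigma      # exceeding 1.5 sigma pushes the surge logit positive
        return (self.w_reg * self.huber(y_hat, y)
                + self.w_surge * self.bce(logits, surge_flag.float()))
```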

Training Flow (train_and_evaluate.py)

1. Load train_panel.parquet → temporal split (≤2022 / 2023+)
2. Train 6 × cuML Random Forest (one per lead month)
3. Train MigrationTransformer (MSE, 20 epochs)
4. Train MigrationLSTM (SurgeJointLoss, 25 epochs)
5. Ensemble with horizon-aware weights
6. Evaluate all 4 on OOT test set
7. Save artifacts to src/models/trained_models/
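
Steps 1 and 5 can be sketched as follows. The panel's date column name and the linear RF→Transformer ramp are assumptions; the repo's actual ensemble weights may be tuned differently:

```python
import pandas as pd

def temporal_split(panel: pd.DataFrame, date_col: str = "date"):
    """Step 1: out-of-time split, train on years <= 2022, test on 2023+."""
    years = pd.to_datetime(panel[date_col]).dt.year
    return panel[years <= 2022], panel[years >= 2023]

def horizon_weights(lead: int, n_leads: int = 6) -> dict:
    """Step 5, one plausible scheme: blend linearly from RF (strong at
    short leads) toward the Transformer (strong at long leads)."""
    t = (lead - 1) / (n_leads - 1)   # 0.0 at Lead 1 -> 1.0 at Lead 6
    return {"rf": round(1 - t, 3), "transformer": round(t, 3)}
```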

Trained Model Artifacts

src/models/trained_models/
├── lstm.pth, transformer.pth
├── rf_lead_{1..6}.joblib
├── scaler_x.joblib, scaler_y.joblib
└── country_map.json
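
A small helper to enumerate the expected files from this layout, useful as a pre-flight check before inference. The helper name is hypothetical; the file names come from the tree above:

```python
from pathlib import Path

def artifact_paths(root: str = "src/models/trained_models") -> list:
    """Expected trained-model artifact files, per the documented layout."""
    base = Path(root)
    files = ["lstm.pth", "transformer.pth",
             "scaler_x.joblib", "scaler_y.joblib", "country_map.json"]
    files += [f"rf_lead_{lead}.joblib" for lead in range(1, 7)]
    return [str(base / f) for f in files]
```

Checking `all(Path(p).exists() for p in artifact_paths())` before constructing the ensemble gives a clearer failure than a mid-load exception.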

TensorRT Engines

src/models/tensor-rt/
├── flan-t5-large/int8_wo_cpu/1-gpu/
└── (jina, led engine paths configured externally)

See Also

  • Training Pipeline — out-of-time train/test split, 4 architectures
  • Inference Pipeline — horizon-aware ensemble, production prediction flow
  • TensorRT Engines — deep dive: Jina-v5, Flan-T5, LED engines