Merged
36 changes: 35 additions & 1 deletion README.md
@@ -42,6 +42,12 @@ pip install -e .[dev] # add ,gpu for CUDA; add ,mosek for MOSEK solver
make paper CONFIG=configs/small.yaml
```

### Google Colab Quick Start

You can run an end-to-end demo of this project instantly in Google Colab without installing anything locally:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/initial-d/ml-quant-trading/blob/main/demo_baostock.ipynb)

---

## Factor Library (213 factors: 9 Alpha101 + 204 legacy)
@@ -102,6 +108,34 @@ stock_019 stock_020 stock_021 stock_022

</details>

### Data Sources

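`make_panel` can fetch stock data directly from Yahoo Finance (`source="yfinance"`) or Baostock (`source="baostock"`, for A-shares). Note that the two sources use different ticker formats for the same stock.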

**yfinance:**
```python
from mlquant.data import make_panel

panel = make_panel(
    source="yfinance",
    tickers=["000001.SZ", "600000.SS"],
    start="2020-01-01",
    end="2023-12-31",
)
```

**baostock:**
```python
from mlquant.data import make_panel

panel = make_panel(
    source="baostock",
    tickers=["sh.600000", "sz.000001"],
    start="2020-01-01",
    end="2023-12-31",
)
```
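
The two examples above pass the same A-shares in different formats: `600000.SS` / `000001.SZ` for yfinance and `sh.600000` / `sz.000001` for Baostock. A minimal sketch of converting between the two, assuming only the Shanghai (`SS` / `sh`) and Shenzhen (`SZ` / `sz`) exchanges are needed. The helper names `to_baostock` and `to_yfinance` are hypothetical and are not part of `mlquant`:

```python
# Hypothetical helpers (not part of mlquant): convert A-share tickers between
# the yfinance style ("600000.SS") and the Baostock style ("sh.600000").
_YF_TO_BS = {"SS": "sh", "SZ": "sz"}
_BS_TO_YF = {bs: yf for yf, bs in _YF_TO_BS.items()}


def to_baostock(ticker: str) -> str:
    """Convert '600000.SS' -> 'sh.600000'."""
    code, _, suffix = ticker.partition(".")
    try:
        return f"{_YF_TO_BS[suffix.upper()]}.{code}"
    except KeyError:
        raise ValueError(f"unsupported yfinance ticker: {ticker!r}") from None


def to_yfinance(ticker: str) -> str:
    """Convert 'sh.600000' -> '600000.SS'."""
    prefix, _, code = ticker.partition(".")
    try:
        return f"{code}.{_BS_TO_YF[prefix.lower()]}"
    except KeyError:
        raise ValueError(f"unsupported Baostock ticker: {ticker!r}") from None


print(to_baostock("600000.SS"), to_yfinance("sz.000001"))
```

With helpers like these, one ticker list can be reused for both sources instead of being maintained twice.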

### Usage

```python
@@ -119,7 +153,7 @@ factors, mask, names = compute_legacy_set(panel, names=("best_001", "add_015", "
## Architecture

```
- raw OCHLV → data.loaders / data.synthetic (Panel with mask)
+ raw OCHLV → data.loaders / data.synthetic / data.yfinance_loader / data.baostock_loader (Panel with mask)
→ features.tensor_factors (GPU masked primitives)
→ features.legacy_factors (204 alphas)
→ training.augment + models.nets + models.losses
652 changes: 652 additions & 0 deletions arXiv-2507.07107v1.tex

Large diffs are not rendered by default.

291 changes: 291 additions & 0 deletions demo_baostock.ipynb
@@ -0,0 +1,291 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Machine Learning Enhanced Multi-Factor Quantitative Trading (A-Shares Demo)\n",
"\n",
"This notebook demonstrates the end-to-end pipeline of our quantitative trading system described in *\"Machine Learning Enhanced Multi-Factor Quantitative Trading: A Cross-Sectional Portfolio Optimization Approach with Bias Correction\"* ([arXiv:2507.07107](https://arxiv.org/abs/2507.07107)).\n",
"\n",
"We will build a portfolio focusing on **HS300** stocks, fetch historical market data via **Baostock**, and compute our alpha features. Finally, we'll run a vectorised backtest to get a proof-of-concept result. **You do not need to change any code, just click \"Run All\"!**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Environment Setup\n",
"\n",
"First, we install the project dependencies, including `baostock` for fetching the Chinese A-share historical data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"# Run the setup when in Google Colab, or anywhere outside the repo root\n",
"if 'google.colab' in sys.modules or not os.path.exists('pyproject.toml'):\n",
"    # Clone the repo if it doesn't exist\n",
"    if not os.path.exists('ml-quant-trading'):\n",
"        !git clone https://github.com/Uwater1/ml-quant-trading.git\n",
"\n",
"    # Change directory to the project root\n",
"    %cd ml-quant-trading\n",
"\n",
"    # Install dependencies\n",
"    !pip install -e .[dev]\n",
"    !pip install baostock\n",
"\n",
"    # Fallback: Add 'src' to path so imports work immediately without restarting runtime\n",
"    sys.path.append(os.path.abspath('src'))\n",
"\n",
"print(\"Environment setup complete.\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Fetch HS300 Stocks Data using Baostock\n",
"\n",
"We use `baostock` to get the current list of HS300 (沪深300) constituent stocks. Then we download their OHLCV and trading amount over a target period. The `mlquant.data.make_panel` utility wraps this neatly into a PyTorch-based `Panel` dataclass, masking out un-tradable days automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import baostock as bs\n",
"import pandas as pd\n",
"import torch\n",
"from mlquant.data import make_panel\n",
"\n",
"# Log in to baostock\n",
"lg = bs.login()\n",
"print('Baostock login error_code: ' + lg.error_code)\n",
"\n",
"# 1. Fetch HS300 component list\n",
"rs = bs.query_hs300_stocks()\n",
"hs300_stocks = []\n",
"while rs.error_code == '0' and rs.next():\n",
"    hs300_stocks.append(rs.get_row_data()[1])  # index 1 is the code, e.g. 'sh.600000'\n",
"\n",
"print(f\"Fetched {len(hs300_stocks)} HS300 stocks.\")\n",
"bs.logout()\n",
"\n",
"# 2. Use the full HS300 universe. For a faster demo, slice the list,\n",
"#    e.g. tickers = hs300_stocks[:100].\n",
"tickers = hs300_stocks\n",
"start_date = '2023-01-01'\n",
"end_date = '2024-12-31'\n",
"\n",
"print(f\"Fetching historical data for {len(tickers)} stocks from {start_date} to {end_date}...\")\n",
"panel = make_panel(\n",
"    source=\"baostock\",\n",
"    tickers=tickers,\n",
"    start=start_date,\n",
"    end=end_date,\n",
"    device=\"cpu\",  # Use \"cuda\" if a GPU is available\n",
")\n",
"print(f\"Created Panel with shape: Dates {panel.n_dates} x Stocks {panel.n_stocks}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Alpha Factor Engineering and Bias Correction\n",
"\n",
"Modern quant systems struggle with unintended systematic biases (e.g. market cap or industry exposure). Our system incorporates multi-stage cross-sectional neutralization to transform biased signals into pure alpha factors. \n",
"\n",
"Here we compute every factor registered in `LEGACY_REGISTRY` (the 204 hand-crafted features from `features.legacy_factors`) on whichever device the panel lives on (CPU or GPU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from mlquant.features import compute_legacy_set\n",
"from mlquant.features.legacy_factors import LEGACY_REGISTRY\n",
"\n",
"# Compute all available legacy factors\n",
"print(f\"Computing {len(LEGACY_REGISTRY)} factors...\")\n",
"factors, mask, names = compute_legacy_set(panel, names=None)\n",
"\n",
"print(f\"Computed factor tensor shape: {factors.shape} (Dates x Stocks x Factors)\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Machine Learning Model Training\n",
"\n",
"We use deep learning to combine these alpha factors into a single predictive score. The framework supports MLPs, Transformers, and gradient boosting trees (XGBoost/LightGBM).\n",
"\n",
"For this demo, we'll train a lightweight multi-layer perceptron (MLP) with an Information Coefficient (IC) loss, i.e. the negative cross-sectional correlation between predictions and next-day returns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"from mlquant.models.nets import MLPRegressor\n",
"from mlquant.models.losses import ICLoss\n",
"from mlquant.training.dataset import FactorDataset\n",
"\n",
"# 1. Target: Forward 1-day returns\n",
"targets = panel.returns.roll(shifts=-1, dims=0)  # T, N\n",
"targets[-1] = 0.0  # Last day has no forward return\n",
"\n",
"# 2. Create dataset (FactorDataset automatically handles the mask and alignment)\n",
"dataset = FactorDataset(factors, panel.mask, targets)\n",
"loader = DataLoader(dataset, batch_size=32, shuffle=True)\n",
"\n",
"# 3. Initialize Model (MLPRegressor takes 'in_dim' and 'hidden' size)\n",
"model = MLPRegressor(in_dim=factors.shape[-1], hidden=256, dropout=0.3)\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)\n",
"criterion = ICLoss()\n",
"\n",
"print(\"Training model for 30 epochs...\")\n",
"model.train()\n",
"for epoch in range(30):\n",
"    total_loss = 0.0\n",
"    for X, y in loader:\n",
"        optimizer.zero_grad()\n",
"        preds = model(X)\n",
"\n",
"        # Using our custom IC loss: negative cross-sectional correlation\n",
"        loss = criterion(preds, y)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        total_loss += loss.item()\n",
"    print(\n",
"        f\"Epoch {epoch+1}/30 | Average IC Loss: {total_loss / len(loader):.4f}\"\n",
"    )\n",
"\n",
"print(\"Training complete.\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Inference and Cross-Sectional Portfolio Optimization\n",
"\n",
"Rather than predicting absolute returns, the model focuses on relative performance within the universe. This naturally hedges market risk while concentrating on security selection alpha.\n",
"\n",
"We run the model over the full period to get predicted scores. Instead of a cross-sectional Markowitz optimizer (with leverage and no-short constraints), this demo uses a simple equal-weight top-10 rule to generate the target portfolio weights day by day. Because the model is scored on the same period it was trained on, the results are in-sample."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"print(\"Generating portfolio weights...\")\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    # Get daily predictions\n",
"    # factors has shape [T, N, F]\n",
"    predictions = model(factors)\n",
"\n",
"T, N = panel.n_dates, panel.n_stocks\n",
"weights = torch.zeros((T, N))\n",
"\n",
"# Simple top-K weighting for demonstration:\n",
"# buy the 10 stocks with the highest predicted scores each day, equally weighted\n",
"for t in range(T):\n",
"    day_mask = mask[t]  # Use the mask from section 3\n",
"    if day_mask.sum() > 10:\n",
"        # Clone so that masking does not overwrite `predictions` in place\n",
"        day_preds = predictions[t].clone()\n",
"        day_preds[~day_mask] = -1e9  # Mask out un-tradable stocks\n",
"\n",
"        top_idx = torch.topk(day_preds, k=10).indices\n",
"        weights[t, top_idx] = 1.0 / 10.0\n",
"\n",
"print(\"Portfolio weights calculated.\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Vectorised Backtest & Metrics\n",
"\n",
"Finally, we run a fast vectorised backtest to evaluate the strategy's Sharpe Ratio, Maximum Drawdown, and Annual Return. Our backtester rigorously handles untradable dates (limit-ups, halts) automatically using the `panel.mask`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from mlquant.backtest.engine import run_backtest\n",
"\n",
"print(\"Running backtest...\")\n",
"\n",
"# run_backtest expects numpy arrays\n",
"weights_np = weights.detach().cpu().numpy()\n",
"returns_np = panel.returns.detach().cpu().numpy()\n",
"\n",
"results = run_backtest(weights_np, returns_np, costs_bps=15.0)\n",
"summary = results.metrics\n",
"\n",
"print(\"\\n=== Backtest Results ===\")\n",
"print(f\"Annualized Return : {summary['ann_return'] * 100:.2f}%\")\n",
"print(f\"Sharpe Ratio : {summary['sharpe']:.2f}\")\n",
"print(f\"Max Drawdown : {summary['max_dd'] * 100:.2f}%\")\n",
"print(f\"Daily Turnover : {summary['turnover'] * 100:.2f}%\")\n",
"\n",
"print(\n",
" \"\\nProof-of-concept complete! To improve these results, train on a larger dataset with more factors.\"\n",
")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
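
For reference, section 4 of the notebook describes its IC loss as the negative cross-sectional correlation between predictions and forward returns. A minimal pure-Python sketch of that idea for a single date, assuming a Pearson correlation computed over tradable stocks only. `pearson_ic` and `ic_loss` are hypothetical stand-ins and do not reproduce the implementation of `mlquant.models.losses.ICLoss`:

```python
import math
from typing import Sequence


def pearson_ic(preds: Sequence[float], targets: Sequence[float],
               mask: Sequence[bool]) -> float:
    """Pearson correlation across stocks for one date, tradable stocks only."""
    pairs = [(p, y) for p, y, m in zip(preds, targets, mask) if m]
    n = len(pairs)
    if n < 2:
        return 0.0  # correlation is undefined; treat as no signal
    mean_p = sum(p for p, _ in pairs) / n
    mean_y = sum(y for _, y in pairs) / n
    cov = sum((p - mean_p) * (y - mean_y) for p, y in pairs)
    var_p = sum((p - mean_p) ** 2 for p, _ in pairs)
    var_y = sum((y - mean_y) ** 2 for _, y in pairs)
    denom = math.sqrt(var_p * var_y)
    return cov / denom if denom > 0 else 0.0


def ic_loss(preds: Sequence[float], targets: Sequence[float],
            mask: Sequence[bool]) -> float:
    """Negative IC, so that minimising the loss maximises the correlation."""
    return -pearson_ic(preds, targets, mask)


# The fourth stock is untradable, so its outlier values are ignored.
scores = [1.0, 2.0, 3.0, 100.0]
fwd_returns = [0.01, 0.02, 0.03, -5.0]
tradable = [True, True, True, False]
print(round(ic_loss(scores, fwd_returns, tradable), 6))
```

Only the ordering and spread of the scores within a date matter here, which is why the notebook describes the model as targeting relative rather than absolute returns.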
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -35,6 +35,8 @@ dependencies = [
"click>=8.1",
"rich>=13.0",
"tqdm>=4.65",
"yfinance>=0.2.0",
"baostock>=0.8.8",
]

[project.optional-dependencies]
18 changes: 18 additions & 0 deletions src/mlquant/data/__init__.py
@@ -23,12 +23,16 @@
from .panel import Panel
from .synthetic import SyntheticConfig, make_synthetic_panel
from .loaders import load_ochlv_csv
from .yfinance_loader import load_yfinance_panel
from .baostock_loader import load_baostock_panel

__all__ = [
"Panel",
"SyntheticConfig",
"make_synthetic_panel",
"load_ochlv_csv",
"load_yfinance_panel",
"load_baostock_panel",
"make_panel",
]

@@ -52,4 +56,18 @@ def make_panel(source: str = "synthetic", **kwargs: Any) -> Panel:
        if path is None:
            raise TypeError("make_panel(source='csv', ...) requires a `path=` kwarg")
        return load_ochlv_csv(path, **kwargs)
    if source == "yfinance":
        tickers = kwargs.pop("tickers", None)
        start = kwargs.pop("start", None)
        end = kwargs.pop("end", None)
        if tickers is None or start is None or end is None:
            raise TypeError("make_panel(source='yfinance', ...) requires `tickers`, `start`, and `end` kwargs")
        return load_yfinance_panel(tickers, start, end, **kwargs)
    if source == "baostock":
        tickers = kwargs.pop("tickers", None)
        start = kwargs.pop("start", None)
        end = kwargs.pop("end", None)
        if tickers is None or start is None or end is None:
            raise TypeError("make_panel(source='baostock', ...) requires `tickers`, `start`, and `end` kwargs")
        return load_baostock_panel(tickers, start, end, **kwargs)
    raise ValueError(f"unknown panel source: {source!r}")