Skip to content

Optimizations for wide histogram building#12158

Open
siqi-he wants to merge 12 commits into dmlc:master from siqi-he:tiling
Open

Optimizations for wide histogram building#12158
siqi-he wants to merge 12 commits into dmlc:master from siqi-he:tiling

Conversation

@siqi-he
Copy link
Copy Markdown
Contributor

@siqi-he siqi-he commented Apr 13, 2026

Optimizations for wide histogram building using column block tiling

Motivation

For wide datasets (e.g. >500 features), the per-thread histogram buffer in CPU hist tree building exceeds L2 cache. Each row scatters gradient updates across the full buffer, causing heavy cache misses. The existing ColsWiseBuildHistKernel mitigates this for dense data by iterating column-by-column, but suffers from poor gradient-pair reuse (reloads gpair for every column). There is no mitigation for the sparse (any-missing) row-wise path.

Changes

Add column-block tiling with a thread-local local buffer to both histogram kernels. Instead of scattering into the full histogram, each thread accumulates into a small buffer covering ~32 columns worth of bins (~128 KB, fits in L2), then flushes to the full histogram. This localizes writes and amortizes gradient-pair loads across multiple columns per row.

Benchmark methodology

All benchmarks use tree_method='hist', max_depth=8, max_bin=256, 100 rounds, 3 repeats (average of runs 2-3). CPU pinning via taskset. The benchmarks were run on an AWS EC2 c6i.32xlarge instance, using all physical cores (i.e. nthread=64).

Datasets include real (Epsilon, Bosch, Santander) and synthetic (HIGGS with PolynomialFeatures expansion, various sparsity levels via injected NaN at fixed seed). Sparse datasets force the row-wise kernel path (IsDense()=false). Predictions were verified to be identical (measured by np.allclose) between master (b2f15e6) and tiling branch across all datasets at 100 rounds.

Results

Dataset Features Hist size Sparsity Baseline Tiled Speedup
epsilon-dense 2000 7.8 MB 0% 38.75s 36.42s 1.06x
epsilon-sparse0.1% 2000 7.8 MB 0.1% 138.74s 46.94s 2.95x
higgs3-sparse0.1% 4494 17.6 MB 0.1% 74.58s 41.25s 1.81x
higgs3-sparse1% 4494 17.6 MB 1% 77.58s 43.48s 1.78x
higgs3-1kf-sparse0.1% 1000 3.9 MB 0.1% 12.41s 10.03s 1.24x
higgs3-1kf-sparse1% 1000 3.9 MB 1% 11.03s 9.49s 1.16x
higgs3-1kf-sparse10% 1000 3.9 MB 10% 12.43s 10.01s 1.24x
higgs3-1kf-sparse49.99% 1000 3.9 MB 49.99% 9.99s 9.23s 1.08x
higgs3-500f-dense 500 2.0 MB 0% 3.50s 3.38s 1.04x
santander 200 0.8 MB 0% 1.65s 1.74s 0.95x
bosch 968 3.8 MB 81% 5.17s 5.16s 1.00x
full benchmark script
#!/usr/bin/env python
"""Benchmark script for tiling PR.
Reproduces the results table in the PR description.

Usage:
    taskset -c 0-63 python bench_pr.py <label> [nthreads]

Datasets required:
    - HIGGS.csv (Kaggle Higgs)
    - data/epsilon.bz2 (LIBSVM format, from https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/)
    - data/santander_train.csv (Kaggle Santander)
    - train_numeric.csv (Kaggle Bosch)
"""
import datetime
import gc
import os
import sys

import numpy as np
import xgboost as xgb

# CLI: bench_pr.py <label> [nthreads]
LABEL = sys.argv[1] if len(sys.argv) > 1 else "unknown"
# Default assumes 2-way SMT, so cpu_count()//2 is the physical core count
# — TODO confirm on the benchmark host.
N_THREADS = int(sys.argv[2]) if len(sys.argv) > 2 else os.cpu_count() // 2
NUM_ROUND = 100  # boosting rounds per timed training run
N_REPEATS = 3  # timed runs per dataset; bench() averages runs 2..N_REPEATS

# Training parameters shared by every benchmark (matches the PR methodology:
# hist tree method, depth 8, 256 bins).
PARAM = {
    "objective": "binary:logistic",
    "eta": 0.1,
    "max_depth": 8,
    "nthread": N_THREADS,
    "tree_method": "hist",
    "max_bin": 256,
}


def log(msg):
    """Emit *msg* to stdout immediately, flushing so progress is visible
    even when output is piped or redirected."""
    print(str(msg), flush=True)


def inject_sparse(X, fraction, seed=123):
    """Return a copy of *X* with roughly *fraction* of its entries set to NaN.

    Uses a fixed-seed RandomState so the injected missing pattern is
    reproducible across runs and branches. *X* itself is left untouched.
    """
    corrupted = X.copy()
    drop_mask = np.random.RandomState(seed).random(X.shape) < fraction
    corrupted[drop_mask] = np.nan
    return corrupted


def bench(name, dm, n_features, n_bins=None):
    """Train on *dm* repeatedly and log timing statistics.

    Does one short (5-round) warm-up training, then N_REPEATS trainings of
    NUM_ROUND rounds each. The reported average/stddev discard the first
    timed run (average of runs 2..N_REPEATS), matching the PR methodology.

    Args:
        name: Label used in the log output.
        dm: xgb.DMatrix holding the dataset.
        n_features: Feature count, used only to report the histogram size.
        n_bins: Bins per feature for the size estimate. Defaults to
            PARAM["max_bin"] so the estimate stays consistent with the
            training parameters instead of hardcoding 256.

    Returns:
        (avg, std) of the retained run times, in seconds.
    """
    if n_bins is None:
        n_bins = PARAM["max_bin"]
    # Per-node histogram footprint: one 16-byte gradient pair (double grad +
    # double hess) per (feature, bin) cell.
    hist_mb = n_features * n_bins * 16 / (1024 * 1024)
    log(f"\n  {name}: hist={hist_mb:.1f}MB")
    xgb.train(PARAM, dm, 5)  # warm-up run, excluded from timing
    times = []
    for r in range(N_REPEATS):
        st = datetime.datetime.now()
        xgb.train(PARAM, dm, NUM_ROUND)
        elapsed = (datetime.datetime.now() - st).total_seconds()
        times.append(elapsed)
        log(f"    Run {r + 1}: {elapsed:.2f}s")
    # Drop run 1 from the statistics; keep everything if there was only one
    # run so the mean is defined rather than NaN.
    kept = times[1:] if len(times) > 1 else times
    avg = np.mean(kept)
    std = np.std(kept)
    log(f"    >> {name}: {avg:.2f}s +/-{std:.2f}")
    return avg, std


# Banner with the run configuration so logs are self-describing.
log(f"{'=' * 70}")
log(f"Benchmark: {LABEL}")
log(f"Threads: {N_THREADS}, Rounds: {NUM_ROUND}, Repeats: {N_REPEATS}")
log(f"{'=' * 70}")

# ---------- Epsilon ----------
log("\nLoading Epsilon...")
from sklearn.datasets import load_svmlight_file

X_eps, y_eps = load_svmlight_file("data/epsilon.bz2", n_features=2000)
# Densify: this benchmark targets the dense / near-dense kernel paths.
X_eps = X_eps.toarray().astype(np.float32)
# Map LIBSVM labels {-1, +1} -> {0, 1} for binary:logistic.
y_eps = ((y_eps + 1) / 2).astype(np.float32)

bench("epsilon-dense", xgb.DMatrix(X_eps, label=y_eps), 2000)
bench(
    "epsilon-sparse0.1%",
    # Injected NaN forces the sparse (any-missing) kernel path.
    xgb.DMatrix(inject_sparse(X_eps, 0.001), label=y_eps, missing=np.nan),
    2000,
)
# Release the large dense matrix before loading the next dataset.
del X_eps, y_eps
gc.collect()

# ---------- HIGGS poly-3 ----------
log("\nLoading HIGGS poly-3...")
import pandas as pd
from sklearn.preprocessing import PolynomialFeatures

# HIGGS.csv: column 0 is the label, remaining columns are the raw features.
df = pd.read_csv("HIGGS.csv", header=None, nrows=100_000)
data_raw = df.values[:, 1:].astype(np.float32)
label_h = df.values[:, 0].astype(np.float32)
# Degree-3 polynomial expansion widens the 28 raw features to 4494 columns,
# producing the "wide" histogram the tiling optimization targets.
poly = PolynomialFeatures(degree=3, interaction_only=False, include_bias=False)
X_h3 = poly.fit_transform(data_raw).astype(np.float32)
del df, data_raw, poly
gc.collect()

# 4494 features
bench(
    "higgs3-sparse0.1%",
    xgb.DMatrix(inject_sparse(X_h3, 0.001), label=label_h, missing=np.nan),
    4494,
)
bench(
    "higgs3-sparse1%",
    xgb.DMatrix(inject_sparse(X_h3, 0.01), label=label_h, missing=np.nan),
    4494,
)

# 1000 features
# .copy() materializes a contiguous slice so X_h3 can be freed.
X_1k = X_h3[:, :1000].copy()
del X_h3
gc.collect()

# Sweep sparsity levels at fixed width to see how the tiled kernel's
# benefit varies with density.
bench(
    "higgs3-1kf-sparse0.1%",
    xgb.DMatrix(inject_sparse(X_1k, 0.001), label=label_h, missing=np.nan),
    1000,
)
bench(
    "higgs3-1kf-sparse1%",
    xgb.DMatrix(inject_sparse(X_1k, 0.01), label=label_h, missing=np.nan),
    1000,
)
bench(
    "higgs3-1kf-sparse10%",
    xgb.DMatrix(inject_sparse(X_1k, 0.10), label=label_h, missing=np.nan),
    1000,
)
bench(
    "higgs3-1kf-sparse49.99%",
    xgb.DMatrix(inject_sparse(X_1k, 0.4999), label=label_h, missing=np.nan),
    1000,
)

# 500 features dense
X_500 = X_1k[:, :500].copy()
del X_1k
gc.collect()
bench("higgs3-500f-dense", xgb.DMatrix(X_500, label=label_h), 500)
del X_500, label_h
gc.collect()

# ---------- Santander ----------
# Narrow (200 features), fully dense — regression check that tiling does not
# hurt datasets that already fit in cache.
log("\nLoading Santander...")
df_s = pd.read_csv("data/santander_train.csv")
y_s = df_s["target"].values.astype(np.float32)
X_s = df_s.drop(columns=["ID_code", "target"]).values.astype(np.float32)
bench("santander", xgb.DMatrix(X_s, label=y_s), 200)
del X_s, y_s, df_s
gc.collect()

# ---------- Bosch ----------
# Naturally very sparse (~81% missing) real-world dataset; NaNs come from the
# CSV itself rather than from inject_sparse.
log("\nLoading Bosch...")
df_b = pd.read_csv("train_numeric.csv")
y_b = df_b["Response"].values.astype(np.float32)
X_b = df_b.drop(columns=["Id", "Response"]).values.astype(np.float32)
bench("bosch", xgb.DMatrix(X_b, label=y_b, missing=np.nan), 968)
del X_b, y_b, df_b
gc.collect()

log(f"\n{'=' * 70}")
log("Done")

Comment on lines +502 to +503
constexpr double kMinDensityForTiling = 0.5;
bool bin_sorted = !BuildingManager::kAnyMissing || gmat.RowsSortedByBin();
Copy link
Copy Markdown
Contributor Author

@siqi-he siqi-he Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The local buffer is flushed entirely every column block. When the data is very sparse, only a few bins are actually hit. Therefore doing a full sweep would actually slow things down. The 0.5 threshold is a rough heuristic. The idea is that denser data tend to benefit more from tiling.

For the tiled kernel to work, entries within a row need to be in ascending bin order. It seems that this is the case for standard SparsePage but not guaranteed for CSRArrayAdapter as it accepts user-provided CSR data where column indices may not be sorted. This guard is thus added to avoid silent failures.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant