Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion causationentropy/core/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def discover_network(
bandwidth="silverman",
k_means: int = 5,
n_shuffles: int = 200,
n_jobs=-1,
kd_tree: bool = False,
n_jobs: int = -1,
) -> nx.MultiDiGraph:
r"""
Infer a causal graph via Optimal Causation Entropy (oCSE).
Expand Down Expand Up @@ -96,6 +97,12 @@ def discover_network(
n_shuffles : int, default=200
Number of permutations for statistical significance testing. Higher values
provide more accurate p-value estimates but increase computational cost.
kd_tree : bool, default=False
If True, uses a KD-Tree for nearest-neighbor searches in kNN and
geometric-kNN information estimators. This reduces complexity from
O(N^2) to approximately O(N log N) for large datasets. Results are
mathematically identical to the brute-force path (kd_tree=False).
Has no effect when information='gaussian', 'kde', or 'poisson'.
n_jobs : int, default=-1
Number of parallel jobs for computation. -1 uses all available processors.

Expand Down Expand Up @@ -248,6 +255,7 @@ def discover_network(
metric=metric,
k=k_means,
bandwidth=bandwidth,
kd_tree=kd_tree,
)

# Compute p-value using shuffle test
Expand Down
84 changes: 56 additions & 28 deletions causationentropy/core/information/conditional_mutual_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def kde_conditional_mutual_information(
return I


def knn_conditional_mutual_information(X, Y, Z, metric="minkowski", k=1):
def knn_conditional_mutual_information(X, Y, Z, metric="minkowski", k=1, kd_tree: bool = False):
"""
Estimate conditional mutual information using k-nearest neighbor method.

Expand Down Expand Up @@ -193,27 +193,40 @@ def knn_conditional_mutual_information(X, Y, Z, metric="minkowski", k=1):
if Z is None:
return knn_mutual_information(X, Y, metric=metric, k=k)
else:
JS = np.column_stack((X, Y, Z))
# Find the K-th smallest distance in the joint space
if metric == "minkowski":
D = np.sort(cdist(JS, JS, metric=metric, p=k + 1), axis=1)[:, k]
JS = np.column_stack((X, Y, Z))
XZ = np.column_stack((X, Z))
YZ = np.column_stack((Y, Z))

if kd_tree:
from scipy.spatial import KDTree
tree_JS = KDTree(JS)
dist_JS, _ = tree_JS.query(JS, k=k + 1)
epsilon = dist_JS[:, k]

tree_XZ = KDTree(XZ)
tree_YZ = KDTree(YZ)
tree_Z = KDTree(Z)
nxz = np.array(tree_XZ.query_ball_point(XZ, epsilon, return_length=True)) - 1
nyz = np.array(tree_YZ.query_ball_point(YZ, epsilon, return_length=True)) - 1
nz = np.array(tree_Z.query_ball_point(Z, epsilon, return_length=True)) - 1
else:
D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k]
epsilon = D
# Count neighbors within epsilon in marginal spaces
Dxz = cdist(np.column_stack((X, Z)), np.column_stack((X, Z)), metric=metric)
nxz = np.sum(Dxz < epsilon[:, None], axis=1) - 1
Dyz = cdist(np.column_stack((Y, Z)), np.column_stack((Y, Z)), metric=metric)
nyz = np.sum(Dyz < epsilon[:, None], axis=1) - 1
Dz = cdist(Z, Z, metric=metric)
nz = np.sum(Dz < epsilon[:, None], axis=1) - 1

# VP Estimation formula
if metric == "minkowski":
D = np.sort(cdist(JS, JS, metric=metric, p=k + 1), axis=1)[:, k]
else:
D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k]
epsilon = D
Dxz = cdist(XZ, XZ, metric=metric)
nxz = np.sum(Dxz < epsilon[:, None], axis=1) - 1
Dyz = cdist(YZ, YZ, metric=metric)
nyz = np.sum(Dyz < epsilon[:, None], axis=1) - 1
Dz = cdist(Z, Z, metric=metric)
nz = np.sum(Dz < epsilon[:, None], axis=1) - 1

I = digamma(k) - np.mean(digamma(nxz + 1) + digamma(nyz + 1) - digamma(nz + 1))
return I


def geometric_knn_conditional_mutual_information(X, Y, Z, metric="euclidean", k=1):
def geometric_knn_conditional_mutual_information(X, Y, Z, metric="euclidean", k=1, kd_tree: bool = False):
"""
Estimate conditional mutual information using geometric k-nearest neighbor method.

Expand Down Expand Up @@ -263,15 +276,29 @@ def geometric_knn_conditional_mutual_information(X, Y, Z, metric="euclidean", k=
"""

if Z is None:
return geometric_knn_mutual_information(X, Y)
YZdist = cdist(np.hstack((Y, Z)), np.hstack((Y, Z)), metric=metric)
XZdist = cdist(np.hstack((X, Z)), np.hstack((X, Z)), metric=metric)
XYZdist = cdist(np.hstack((X, Y, Z)), np.hstack((X, Y, Z)), metric=metric)
Zdist = cdist(Z, Z, metric=metric)
HZ = geometric_knn_entropy(Z, Zdist, k)
HXZ = geometric_knn_entropy(np.hstack((X, Z)), XZdist, k)
HYZ = geometric_knn_entropy(np.hstack((Y, Z)), YZdist, k)
HXYZ = geometric_knn_entropy(np.hstack((X, Y, Z)), XYZdist, k)
return geometric_knn_mutual_information(X, Y, metric=metric, k=k, kd_tree=kd_tree)

if kd_tree:
HZ = geometric_knn_entropy(Z, None, k, kd_tree=True, metric=metric)
HXZ = geometric_knn_entropy(np.hstack((X, Z)), None, k, kd_tree=True, metric=metric)
HYZ = geometric_knn_entropy(np.hstack((Y, Z)), None, k, kd_tree=True, metric=metric)
HXYZ = geometric_knn_entropy(np.hstack((X, Y, Z)), None, k, kd_tree=True, metric=metric)
else:
Zdist = cdist(Z, Z, metric=metric)

XZ = np.hstack((X, Z))
YZ = np.hstack((Y, Z))
XYZ = np.hstack((X, Y, Z))

XZdist = cdist(XZ, XZ, metric=metric)
YZdist = cdist(YZ, YZ, metric=metric)
XYZdist = cdist(XYZ, XYZ, metric=metric)

HZ = geometric_knn_entropy(Z, Zdist, k)
HXZ = geometric_knn_entropy(XZ, XZdist, k)
HYZ = geometric_knn_entropy(YZ, YZdist, k)
HXYZ = geometric_knn_entropy(XYZ, XYZdist, k)

cmi = HXZ + HYZ - HXYZ - HZ
return cmi

Expand Down Expand Up @@ -379,6 +406,7 @@ def conditional_mutual_information(
k=6,
bandwidth="silverman",
kernel="gaussian",
kd_tree: bool = False,
):
"""
Compute conditional mutual information using specified estimation method.
Expand Down Expand Up @@ -483,10 +511,10 @@ def conditional_mutual_information(
)

elif method == "knn":
cmi = knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k)
cmi = knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k, kd_tree=kd_tree)

elif method == "geometric_knn":
cmi = geometric_knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k)
cmi = geometric_knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k, kd_tree=kd_tree)

elif method == "poisson":
cmi = poisson_conditional_mutual_information(X, Y, Z)
Expand Down
31 changes: 23 additions & 8 deletions causationentropy/core/information/entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def kde_entropy(X, bandwidth="silverman", kernel="gaussian"):
return Hx


def geometric_knn_entropy(X, Xdist, k=1):
def geometric_knn_entropy(X, Xdist, k=1, kd_tree: bool = False, metric: str = "euclidean"):
r"""
Estimate entropy using geometric k-nearest neighbor method.

Expand Down Expand Up @@ -160,21 +160,36 @@ def geometric_knn_entropy(X, Xdist, k=1):
.. [1] Lord, W.M., Sun, J., Bollt, E.M. Geometric k-nearest neighbor estimation of
entropy and mutual information. Chaos 28, 033113 (2018).
"""
if kd_tree:
from scipy.spatial import KDTree
tree = KDTree(X)
distances, indices = tree.query(X, k=k + 1)
_knn_indices = indices[:, 1:]
_knn_dists = distances[:, 1:]
Xdist = None
else:
_knn_indices = None
_knn_dists = None

N, d = X.shape
Xknn = np.zeros((N, k), dtype=int)

for i in range(N):
Xknn[i, :] = np.argsort(Xdist[i, :])[1 : k + 1]
if kd_tree:
Xknn = _knn_indices
else:
for i in range(N):
Xknn[i, :] = np.argsort(Xdist[i, :])[1 : k + 1]
H_X = np.log(N) + np.log(np.pi ** (d / 2) / gamma(1 + d / 2))

# Compute distance-based term with safety checks
log_distances = []
for i in range(N):
dist = l2dist(X[i, :], X[Xknn[i, k - 1], :])
if dist > 1e-12: # Avoid log(0)
if kd_tree:
dist = _knn_dists[i, k - 1]
else:
dist = l2dist(X[i, :], X[Xknn[i, k - 1], :])
if dist > 1e-12:
log_distances.append(np.log(dist))
else:
log_distances.append(-12.0) # log(1e-12) as a reasonable lower bound
log_distances.append(-12.0)

H_X += d / N * np.sum(log_distances)

Expand Down
55 changes: 36 additions & 19 deletions causationentropy/core/information/mutual_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def kde_mutual_information(X, Y, bandwidth="silverman", kernel="gaussian"):
return mi


def knn_mutual_information(X, Y, metric="euclidean", k=1):
def knn_mutual_information(X, Y, metric="euclidean", k=1, kd_tree: bool = False):
r"""
Estimate mutual information using k-nearest neighbor (KSG) method.

Expand Down Expand Up @@ -155,19 +155,29 @@ def knn_mutual_information(X, Y, metric="euclidean", k=1):
.. [1] Kraskov, A., Stögbauer, H., Grassberger, P. Estimating mutual information.
Physical Review E 69, 066138 (2004).
"""
# construct the joint space
n = X.shape[0]
JS = np.column_stack((X, Y))

# Find the K^th smallest distance in the joint space
D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k]
epsilon = D

# Count neighbors within epsilon in marginal spaces
Dx = cdist(X, X, metric=metric)
nx = np.sum(Dx < epsilon[:, None], axis=1) - 1
Dy = cdist(Y, Y, metric=metric)
ny = np.sum(Dy < epsilon[:, None], axis=1) - 1
if kd_tree:
from scipy.spatial import KDTree
# Find k-th nearest neighbor distance in joint space
tree_JS = KDTree(JS)
dist_JS, _ = tree_JS.query(JS, k=k + 1)
epsilon = dist_JS[:, k]

tree_X = KDTree(X)
tree_Y = KDTree(Y)
# Count points strictly within epsilon
nx = np.array(tree_X.query_ball_point(X, epsilon, return_length=True)) - 1
ny = np.array(tree_Y.query_ball_point(Y, epsilon, return_length=True)) - 1
else:
# Original brute-force path
D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k]
epsilon = D
Dx = cdist(X, X, metric=metric)
nx = np.sum(Dx < epsilon[:, None], axis=1) - 1
Dy = cdist(Y, Y, metric=metric)
ny = np.sum(Dy < epsilon[:, None], axis=1) - 1

# KSG Estimation formula
I1a = digamma(k)
Expand All @@ -178,7 +188,7 @@ def knn_mutual_information(X, Y, metric="euclidean", k=1):
return mi


def geometric_knn_mutual_information(X, Y, metric="euclidean", k=1):
def geometric_knn_mutual_information(X, Y, metric="euclidean", k=1, kd_tree: bool = False):
"""
Estimate mutual information using geometric k-nearest neighbor method.

Expand Down Expand Up @@ -223,13 +233,20 @@ def geometric_knn_mutual_information(X, Y, metric="euclidean", k=1):
.. [1] Lord, W.M., Sun, J., Bollt, E.M. Geometric k-nearest neighbor estimation of
entropy and mutual information. Chaos 28, 033113 (2018).
"""
Xdist = cdist(X, X, metric=metric)
Ydist = cdist(Y, Y, metric=metric)
XYdist = cdist(np.hstack((X, Y)), np.hstack((X, Y)), metric=metric)

HX = geometric_knn_entropy(X, Xdist, k)
HY = geometric_knn_entropy(Y, Ydist, k)
HXY = geometric_knn_entropy(np.hstack((X, Y)), XYdist, k)
if kd_tree:
HX = geometric_knn_entropy(X, None, k, kd_tree=True, metric=metric)
HY = geometric_knn_entropy(Y, None, k, kd_tree=True, metric=metric)
HXY = geometric_knn_entropy(np.hstack((X, Y)), None, k, kd_tree=True, metric=metric)
else:
Xdist = cdist(X, X, metric=metric)
Ydist = cdist(Y, Y, metric=metric)

XY = np.hstack((X, Y))
XYdist = cdist(XY, XY, metric=metric)

HX = geometric_knn_entropy(X, Xdist, k)
HY = geometric_knn_entropy(Y, Ydist, k)
HXY = geometric_knn_entropy(XY, XYdist, k)

mi = HX + HY - HXY

Expand Down
94 changes: 94 additions & 0 deletions causationentropy/tests/test_kdtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import time

import numpy as np
import pytest

from causationentropy.core.discovery import discover_network
from causationentropy.core.information.conditional_mutual_information import (
conditional_mutual_information,
)
from causationentropy.core.information.mutual_information import (
geometric_knn_mutual_information,
knn_mutual_information,
)


class TestKDTreeCorrectness:
"""kd_tree=True and kd_tree=False must produce numerically identical results."""

def test_knn_mutual_information_correctness(self):
rng = np.random.default_rng(0)
X = rng.standard_normal((150, 2))
Y = rng.standard_normal((150, 1))
mi_bf = knn_mutual_information(X, Y, k=3, kd_tree=False)
mi_kdt = knn_mutual_information(X, Y, k=3, kd_tree=True)
assert abs(mi_bf - mi_kdt) < 1e-9, (
f"knn MI mismatch: brute={mi_bf:.8f}, kd_tree={mi_kdt:.8f}"
)

def test_geometric_knn_mutual_information_correctness(self):
rng = np.random.default_rng(1)
X = rng.standard_normal((100, 2))
Y = rng.standard_normal((100, 1))
mi_bf = geometric_knn_mutual_information(X, Y, k=2, kd_tree=False)
mi_kdt = geometric_knn_mutual_information(X, Y, k=2, kd_tree=True)
assert abs(mi_bf - mi_kdt) < 1e-9

def test_knn_cmi_correctness(self):
rng = np.random.default_rng(2)
X = rng.standard_normal((120, 1))
Y = rng.standard_normal((120, 1))
Z = rng.standard_normal((120, 2))
cmi_bf = conditional_mutual_information(X, Y, Z, method="knn", k=3, kd_tree=False)
cmi_kdt = conditional_mutual_information(X, Y, Z, method="knn", k=3, kd_tree=True)
assert abs(cmi_bf - cmi_kdt) < 1e-9

def test_geometric_knn_cmi_correctness(self):
rng = np.random.default_rng(3)
X = rng.standard_normal((80, 1))
Y = rng.standard_normal((80, 1))
Z = rng.standard_normal((80, 1))
cmi_bf = conditional_mutual_information(X, Y, Z, method="geometric_knn", k=2, kd_tree=False)
cmi_kdt = conditional_mutual_information(X, Y, Z, method="geometric_knn", k=2, kd_tree=True)
assert abs(cmi_bf - cmi_kdt) < 1e-9

def test_discover_network_kd_tree_flag(self):
"""discover_network runs with kd_tree=True and produces a valid graph."""
rng = np.random.default_rng(42)
data = rng.standard_normal((60, 3))
G = discover_network(data, information="knn", max_lag=1, n_shuffles=10, kd_tree=True)
import networkx as nx
assert isinstance(G, nx.MultiDiGraph)
assert G.number_of_nodes() == 3

def test_default_behavior_unchanged(self):
"""kd_tree=False (default) still works exactly as before."""
rng = np.random.default_rng(7)
data = rng.standard_normal((60, 3))
G = discover_network(data, information="knn", max_lag=1, n_shuffles=10)
import networkx as nx
assert isinstance(G, nx.MultiDiGraph)


class TestKDTreeBenchmark:
"""Benchmark to show speedup. Not a strict assertion — print results."""

@pytest.mark.slow
def test_runtime_scaling(self):
for N in [300, 800]:
rng = np.random.default_rng(0)
X = rng.standard_normal((N, 2))
Y = rng.standard_normal((N, 1))
t0 = time.perf_counter()
knn_mutual_information(X, Y, k=5, kd_tree=False)
t_bf = time.perf_counter() - t0

t0 = time.perf_counter()
knn_mutual_information(X, Y, k=5, kd_tree=True)
t_kdt = time.perf_counter() - t0

speedup = t_bf / t_kdt if t_kdt > 0 else float("inf")
print(f"\nN={N}: brute={t_bf:.3f}s kd_tree={t_kdt:.3f}s speedup={speedup:.1f}x")
# At N=800 in ≥2D the KD-Tree should be meaningfully faster
if N >= 800:
assert speedup > 1.5, f"Expected >1.5x speedup at N={N}, got {speedup:.2f}x"