From a5c8b35739017e6221a4f5498f0f512526a82578 Mon Sep 17 00:00:00 2001 From: iPopovS Date: Thu, 4 Jun 2026 16:19:47 +0300 Subject: [PATCH 1/5] Add KDTree option to geometric CMI estimator --- .../conditional_mutual_information.py | 84 ++++++++++++------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/causationentropy/core/information/conditional_mutual_information.py b/causationentropy/core/information/conditional_mutual_information.py index eea7f22..13bece4 100644 --- a/causationentropy/core/information/conditional_mutual_information.py +++ b/causationentropy/core/information/conditional_mutual_information.py @@ -139,7 +139,7 @@ def kde_conditional_mutual_information( return I -def knn_conditional_mutual_information(X, Y, Z, metric="minkowski", k=1): +def knn_conditional_mutual_information(X, Y, Z, metric="minkowski", k=1, kd_tree: bool = False): """ Estimate conditional mutual information using k-nearest neighbor method. @@ -193,27 +193,40 @@ def knn_conditional_mutual_information(X, Y, Z, metric="minkowski", k=1): if Z is None: return knn_mutual_information(X, Y, metric=metric, k=k) else: - JS = np.column_stack((X, Y, Z)) - # Find the K-th smallest distance in the joint space - if metric == "minkowski": - D = np.sort(cdist(JS, JS, metric=metric, p=k + 1), axis=1)[:, k] + JS = np.column_stack((X, Y, Z)) + XZ = np.column_stack((X, Z)) + YZ = np.column_stack((Y, Z)) + + if kd_tree: + from scipy.spatial import KDTree + tree_JS = KDTree(JS) + dist_JS, _ = tree_JS.query(JS, k=k + 1) + epsilon = dist_JS[:, k] + + tree_XZ = KDTree(XZ) + tree_YZ = KDTree(YZ) + tree_Z = KDTree(Z) + nxz = np.array(tree_XZ.query_ball_point(XZ, epsilon, return_length=True)) - 1 + nyz = np.array(tree_YZ.query_ball_point(YZ, epsilon, return_length=True)) - 1 + nz = np.array(tree_Z.query_ball_point(Z, epsilon, return_length=True)) - 1 else: - D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k] - epsilon = D - # Count neighbors within epsilon in marginal spaces - Dxz = cdist(np.column_stack((X, Z)), np.column_stack((X, Z)), metric=metric) - nxz = np.sum(Dxz < epsilon[:, None], axis=1) - 1 - Dyz = cdist(np.column_stack((Y, Z)), np.column_stack((Y, Z)), metric=metric) - nyz = np.sum(Dyz < epsilon[:, None], axis=1) - 1 - Dz = cdist(Z, Z, metric=metric) - nz = np.sum(Dz < epsilon[:, None], axis=1) - 1 - - # VP Estimation formula + if metric == "minkowski": + D = np.sort(cdist(JS, JS, metric=metric, p=k + 1), axis=1)[:, k] + else: + D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k] + epsilon = D + Dxz = cdist(XZ, XZ, metric=metric) + nxz = np.sum(Dxz < epsilon[:, None], axis=1) - 1 + Dyz = cdist(YZ, YZ, metric=metric) + nyz = np.sum(Dyz < epsilon[:, None], axis=1) - 1 + Dz = cdist(Z, Z, metric=metric) + nz = np.sum(Dz < epsilon[:, None], axis=1) - 1 + I = digamma(k) - np.mean(digamma(nxz + 1) + digamma(nyz + 1) - digamma(nz + 1)) return I -def geometric_knn_conditional_mutual_information(X, Y, Z, metric="euclidean", k=1): +def geometric_knn_conditional_mutual_information(X, Y, Z, metric="euclidean", k=1, kd_tree: bool = False): """ Estimate conditional mutual information using geometric k-nearest neighbor method. @@ -263,15 +276,29 @@ def geometric_knn_conditional_mutual_information(X, Y, Z, metric="euclidean", k= """ if Z is None: - return geometric_knn_mutual_information(X, Y) - YZdist = cdist(np.hstack((Y, Z)), np.hstack((Y, Z)), metric=metric) - XZdist = cdist(np.hstack((X, Z)), np.hstack((X, Z)), metric=metric) - XYZdist = cdist(np.hstack((X, Y, Z)), np.hstack((X, Y, Z)), metric=metric) - Zdist = cdist(Z, Z, metric=metric) - HZ = geometric_knn_entropy(Z, Zdist, k) - HXZ = geometric_knn_entropy(np.hstack((X, Z)), XZdist, k) - HYZ = geometric_knn_entropy(np.hstack((Y, Z)), YZdist, k) - HXYZ = geometric_knn_entropy(np.hstack((X, Y, Z)), XYZdist, k) + return geometric_knn_mutual_information(X, Y, metric=metric, k=k, kd_tree=kd_tree) + + if kd_tree: + HZ = geometric_knn_entropy(Z, None, k, kd_tree=True, metric=metric) + HXZ = geometric_knn_entropy(np.hstack((X, Z)), None, k, kd_tree=True, metric=metric) + HYZ = geometric_knn_entropy(np.hstack((Y, Z)), None, k, kd_tree=True, metric=metric) + HXYZ = geometric_knn_entropy(np.hstack((X, Y, Z)), None, k, kd_tree=True, metric=metric) + else: + Zdist = cdist(Z, Z, metric=metric) + + XZ = np.hstack((X, Z)) + YZ = np.hstack((Y, Z)) + XYZ = np.hstack((X, Y, Z)) + + XZdist = cdist(XZ, XZ, metric=metric) + YZdist = cdist(YZ, YZ, metric=metric) + XYZdist = cdist(XYZ, XYZ, metric=metric) + + HZ = geometric_knn_entropy(Z, Zdist, k) + HXZ = geometric_knn_entropy(XZ, XZdist, k) + HYZ = geometric_knn_entropy(YZ, YZdist, k) + HXYZ = geometric_knn_entropy(XYZ, XYZdist, k) + cmi = HXZ + HYZ - HXYZ - HZ return cmi @@ -379,6 +406,7 @@ def conditional_mutual_information( k=6, bandwidth="silverman", kernel="gaussian", + kd_tree: bool = False, ): """ Compute conditional mutual information using specified estimation method. @@ -483,10 +511,10 @@ def conditional_mutual_information( ) elif method == "knn": - cmi = knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k) + cmi = knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k, kd_tree=kd_tree) elif method == "geometric_knn": - cmi = geometric_knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k) + cmi = geometric_knn_conditional_mutual_information(X, Y, Z, metric=metric, k=k, kd_tree=kd_tree) elif method == "poisson": cmi = poisson_conditional_mutual_information(X, Y, Z) From a9b98fdc738c0cede6671e09cea9ba6d478ccfb0 Mon Sep 17 00:00:00 2001 From: iPopovS Date: Thu, 4 Jun 2026 16:21:00 +0300 Subject: [PATCH 2/5] Add kd_tree parameter to network discovery API --- causationentropy/core/discovery.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/causationentropy/core/discovery.py b/causationentropy/core/discovery.py index ae5d442..f081bb6 100644 --- a/causationentropy/core/discovery.py +++ b/causationentropy/core/discovery.py @@ -27,7 +27,8 @@ def discover_network( bandwidth="silverman", k_means: int = 5, n_shuffles: int = 200, - n_jobs=-1, + kd_tree: bool = False, + n_jobs: int = -1, ) -> nx.MultiDiGraph: r""" Infer a causal graph via Optimal Causation Entropy (oCSE). @@ -96,6 +97,12 @@ def discover_network( n_shuffles : int, default=200 Number of permutations for statistical significance testing. Higher values provide more accurate p-value estimates but increase computational cost. + kd_tree : bool, default=False + If True, uses a KD-Tree for nearest-neighbor searches in kNN and + geometric-kNN information estimators. This reduces complexity from + O(N^2) to approximately O(N log N) for large datasets. Results are + mathematically identical to the brute-force path (kd_tree=False). + Has no effect when information='gaussian', 'kde', or 'poisson'. n_jobs : int, default=-1 Number of parallel jobs for computation. -1 uses all available processors. @@ -248,6 +255,7 @@ def discover_network( metric=metric, k=k_means, bandwidth=bandwidth, + kd_tree=kd_tree, ) # Compute p-value using shuffle test From 5688d3e3345a60acb422473d3f7c55219ed96de5 Mon Sep 17 00:00:00 2001 From: iPopovS Date: Thu, 4 Jun 2026 16:22:05 +0300 Subject: [PATCH 3/5] Add tests for KDTree-based kNN estimators --- causationentropy/tests/test_kdtree.py | 94 +++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 causationentropy/tests/test_kdtree.py diff --git a/causationentropy/tests/test_kdtree.py b/causationentropy/tests/test_kdtree.py new file mode 100644 index 0000000..171f444 --- /dev/null +++ b/causationentropy/tests/test_kdtree.py @@ -0,0 +1,94 @@ +import time + +import numpy as np +import pytest + +from causationentropy.core.discovery import discover_network +from causationentropy.core.information.conditional_mutual_information import ( + conditional_mutual_information, +) +from causationentropy.core.information.mutual_information import ( + geometric_knn_mutual_information, + knn_mutual_information, +) + + +class TestKDTreeCorrectness: + """kd_tree=True and kd_tree=False must produce numerically identical results.""" + + def test_knn_mutual_information_correctness(self): + rng = np.random.default_rng(0) + X = rng.standard_normal((150, 2)) + Y = rng.standard_normal((150, 1)) + mi_bf = knn_mutual_information(X, Y, k=3, kd_tree=False) + mi_kdt = knn_mutual_information(X, Y, k=3, kd_tree=True) + assert abs(mi_bf - mi_kdt) < 1e-9, ( + f"knn MI mismatch: brute={mi_bf:.8f}, kd_tree={mi_kdt:.8f}" + ) + + def test_geometric_knn_mutual_information_correctness(self): + rng = np.random.default_rng(1) + X = rng.standard_normal((100, 2)) + Y = rng.standard_normal((100, 1)) + mi_bf = geometric_knn_mutual_information(X, Y, k=2, kd_tree=False) + mi_kdt = geometric_knn_mutual_information(X, Y, k=2, kd_tree=True) + assert abs(mi_bf - mi_kdt) < 1e-9 + + def test_knn_cmi_correctness(self): + rng = np.random.default_rng(2) + X = rng.standard_normal((120, 1)) + Y = rng.standard_normal((120, 1)) + Z = rng.standard_normal((120, 2)) + cmi_bf = conditional_mutual_information(X, Y, Z, method="knn", k=3, kd_tree=False) + cmi_kdt = conditional_mutual_information(X, Y, Z, method="knn", k=3, kd_tree=True) + assert abs(cmi_bf - cmi_kdt) < 1e-9 + + def test_geometric_knn_cmi_correctness(self): + rng = np.random.default_rng(3) + X = rng.standard_normal((80, 1)) + Y = rng.standard_normal((80, 1)) + Z = rng.standard_normal((80, 1)) + cmi_bf = conditional_mutual_information(X, Y, Z, method="geometric_knn", k=2, kd_tree=False) + cmi_kdt = conditional_mutual_information(X, Y, Z, method="geometric_knn", k=2, kd_tree=True) + assert abs(cmi_bf - cmi_kdt) < 1e-9 + + def test_discover_network_kd_tree_flag(self): + """discover_network runs with kd_tree=True and produces a valid graph.""" + rng = np.random.default_rng(42) + data = rng.standard_normal((60, 3)) + G = discover_network(data, information="knn", max_lag=1, n_shuffles=10, kd_tree=True) + import networkx as nx + assert isinstance(G, nx.MultiDiGraph) + assert G.number_of_nodes() == 3 + + def test_default_behavior_unchanged(self): + """kd_tree=False (default) still works exactly as before.""" + rng = np.random.default_rng(7) + data = rng.standard_normal((60, 3)) + G = discover_network(data, information="knn", max_lag=1, n_shuffles=10) + import networkx as nx + assert isinstance(G, nx.MultiDiGraph) + + +class TestKDTreeBenchmark: + """Benchmark to show speedup. Not a strict assertion — print results.""" + + @pytest.mark.slow + def test_runtime_scaling(self): + for N in [300, 800]: + rng = np.random.default_rng(0) + X = rng.standard_normal((N, 2)) + Y = rng.standard_normal((N, 1)) + t0 = time.perf_counter() + knn_mutual_information(X, Y, k=5, kd_tree=False) + t_bf = time.perf_counter() - t0 + + t0 = time.perf_counter() + knn_mutual_information(X, Y, k=5, kd_tree=True) + t_kdt = time.perf_counter() - t0 + + speedup = t_bf / t_kdt if t_kdt > 0 else float("inf") + print(f"\nN={N}: brute={t_bf:.3f}s kd_tree={t_kdt:.3f}s speedup={speedup:.1f}x") + # At N=800 in ≥2D the KD-Tree should be meaningfully faster + if N >= 800: + assert speedup > 1.5, f"Expected >1.5x speedup at N={N}, got {speedup:.2f}x" From 877c098007ee68501de7f3c2f663d34ca6a89081 Mon Sep 17 00:00:00 2001 From: iPopovS Date: Thu, 4 Jun 2026 16:22:45 +0300 Subject: [PATCH 4/5] Use KDTree for geometric kNN entropy estimation --- causationentropy/core/information/entropy.py | 31 +++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/causationentropy/core/information/entropy.py b/causationentropy/core/information/entropy.py index 483ea67..71a68c4 100644 --- a/causationentropy/core/information/entropy.py +++ b/causationentropy/core/information/entropy.py @@ -112,7 +112,7 @@ def kde_entropy(X, bandwidth="silverman", kernel="gaussian"): return Hx -def geometric_knn_entropy(X, Xdist, k=1): +def geometric_knn_entropy(X, Xdist, k=1, kd_tree: bool = False, metric: str = "euclidean"): r""" Estimate entropy using geometric k-nearest neighbor method. @@ -160,21 +160,36 @@ def geometric_knn_entropy(X, Xdist, k=1): .. [1] Lord, W.M., Sun, J., Bollt, E.M. Geometric k-nearest neighbor estimation of entropy and mutual information. Chaos 28, 033113 (2018). """ + if kd_tree: + from scipy.spatial import KDTree + tree = KDTree(X) + distances, indices = tree.query(X, k=k + 1) + _knn_indices = indices[:, 1:] + _knn_dists = distances[:, 1:] + Xdist = None + else: + _knn_indices = None + _knn_dists = None + N, d = X.shape Xknn = np.zeros((N, k), dtype=int) - - for i in range(N): - Xknn[i, :] = np.argsort(Xdist[i, :])[1 : k + 1] + if kd_tree: + Xknn = _knn_indices + else: + for i in range(N): + Xknn[i, :] = np.argsort(Xdist[i, :])[1 : k + 1] H_X = np.log(N) + np.log(np.pi ** (d / 2) / gamma(1 + d / 2)) - # Compute distance-based term with safety checks log_distances = [] for i in range(N): - dist = l2dist(X[i, :], X[Xknn[i, k - 1], :]) - if dist > 1e-12: # Avoid log(0) + if kd_tree: + dist = _knn_dists[i, k - 1] + else: + dist = l2dist(X[i, :], X[Xknn[i, k - 1], :]) + if dist > 1e-12: log_distances.append(np.log(dist)) else: - log_distances.append(-12.0) # log(1e-12) as a reasonable lower bound + log_distances.append(-12.0) H_X += d / N * np.sum(log_distances) From 4a61a462c625aaec5af042edc293ac78624f0ae7 Mon Sep 17 00:00:00 2001 From: iPopovS Date: Thu, 4 Jun 2026 16:23:24 +0300 Subject: [PATCH 5/5] Add KDTree option to geometric MI estimator --- .../core/information/mutual_information.py | 55 ++++++++++++------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/causationentropy/core/information/mutual_information.py b/causationentropy/core/information/mutual_information.py index 3131c85..62c0743 100644 --- a/causationentropy/core/information/mutual_information.py +++ b/causationentropy/core/information/mutual_information.py @@ -106,7 +106,7 @@ def kde_mutual_information(X, Y, bandwidth="silverman", kernel="gaussian"): return mi -def knn_mutual_information(X, Y, metric="euclidean", k=1): +def knn_mutual_information(X, Y, metric="euclidean", k=1, kd_tree: bool = False): r""" Estimate mutual information using k-nearest neighbor (KSG) method. @@ -155,19 +155,29 @@ def knn_mutual_information(X, Y, metric="euclidean", k=1): .. [1] Kraskov, A., Stögbauer, H., Grassberger, P. Estimating mutual information. Physical Review E 69, 066138 (2004). """ - # construct the joint space n = X.shape[0] JS = np.column_stack((X, Y)) - # Find the K^th smallest distance in the joint space - D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k] - epsilon = D - - # Count neighbors within epsilon in marginal spaces - Dx = cdist(X, X, metric=metric) - nx = np.sum(Dx < epsilon[:, None], axis=1) - 1 - Dy = cdist(Y, Y, metric=metric) - ny = np.sum(Dy < epsilon[:, None], axis=1) - 1 + if kd_tree: + from scipy.spatial import KDTree + # Find k-th nearest neighbor distance in joint space + tree_JS = KDTree(JS) + dist_JS, _ = tree_JS.query(JS, k=k + 1) + epsilon = dist_JS[:, k] + + tree_X = KDTree(X) + tree_Y = KDTree(Y) + # Count points strictly within epsilon + nx = np.array(tree_X.query_ball_point(X, epsilon, return_length=True)) - 1 + ny = np.array(tree_Y.query_ball_point(Y, epsilon, return_length=True)) - 1 + else: + # Original brute-force path + D = np.sort(cdist(JS, JS, metric=metric), axis=1)[:, k] + epsilon = D + Dx = cdist(X, X, metric=metric) + nx = np.sum(Dx < epsilon[:, None], axis=1) - 1 + Dy = cdist(Y, Y, metric=metric) + ny = np.sum(Dy < epsilon[:, None], axis=1) - 1 # KSG Estimation formula I1a = digamma(k) @@ -178,7 +188,7 @@ def knn_mutual_information(X, Y, metric="euclidean", k=1): return mi -def geometric_knn_mutual_information(X, Y, metric="euclidean", k=1): +def geometric_knn_mutual_information(X, Y, metric="euclidean", k=1, kd_tree: bool = False): """ Estimate mutual information using geometric k-nearest neighbor method. @@ -223,13 +233,20 @@ def geometric_knn_mutual_information(X, Y, metric="euclidean", k=1): .. [1] Lord, W.M., Sun, J., Bollt, E.M. Geometric k-nearest neighbor estimation of entropy and mutual information. Chaos 28, 033113 (2018). """ - Xdist = cdist(X, X, metric=metric) - Ydist = cdist(Y, Y, metric=metric) - XYdist = cdist(np.hstack((X, Y)), np.hstack((X, Y)), metric=metric) - - HX = geometric_knn_entropy(X, Xdist, k) - HY = geometric_knn_entropy(Y, Ydist, k) - HXY = geometric_knn_entropy(np.hstack((X, Y)), XYdist, k) + if kd_tree: + HX = geometric_knn_entropy(X, None, k, kd_tree=True, metric=metric) + HY = geometric_knn_entropy(Y, None, k, kd_tree=True, metric=metric) + HXY = geometric_knn_entropy(np.hstack((X, Y)), None, k, kd_tree=True, metric=metric) + else: + Xdist = cdist(X, X, metric=metric) + Ydist = cdist(Y, Y, metric=metric) + + XY = np.hstack((X, Y)) + XYdist = cdist(XY, XY, metric=metric) + + HX = geometric_knn_entropy(X, Xdist, k) + HY = geometric_knn_entropy(Y, Ydist, k) + HXY = geometric_knn_entropy(XY, XYdist, k) mi = HX + HY - HXY