From f5e69b2aa1d39b4aeebb9a6c9ff63a69eb55c8c9 Mon Sep 17 00:00:00 2001 From: omgol411 Date: Mon, 13 Apr 2026 22:00:37 +0530 Subject: [PATCH 1/4] parallelize distance calculation during `get_patches` --- src/patch_computer.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/patch_computer.py b/src/patch_computer.py index ae37224..01793ae 100644 --- a/src/patch_computer.py +++ b/src/patch_computer.py @@ -1,6 +1,8 @@ import numpy as np import itertools import jenkspy +import os +from multiprocessing import Pool def calc_bead_spread(tup, grid): inds = tup[0] @@ -21,12 +23,28 @@ def to_array(final, ids): full_arr[ids.index(b),ids.index(a)] = final[i] return full_arr -def calc_distance_matrix(args, coords, radius): - pairs = itertools.combinations(args, 2) - mean_dist = [] - for p in pairs: +def worker_calc_distance(batch_pairs): + batch_mean_dist = [] + for p in batch_pairs: surface_distance = np.linalg.norm(coords[:,p[0],:] - coords[:,p[1],:], axis=1) -(radius[p[0]] + radius[p[1]]) - mean_dist.append(np.mean(surface_distance) if np.mean(surface_distance) >= 0 else 0) + batch_mean_dist.append(max(np.mean(surface_distance), 0)) + return batch_mean_dist + +def initialize_worker(coords_, radius_): + global coords, radius + coords = coords_ + radius = radius_ + +def calc_distance_matrix(args, coords, radius, cores=min(max(os.cpu_count() - 1, 1), 16)): + # pairs = itertools.combinations(args, 2) + # mean_dist = [] + # for p in pairs: + # surface_distance = np.linalg.norm(coords[:,p[0],:] - coords[:,p[1],:], axis=1) -(radius[p[0]] + radius[p[1]]) + # mean_dist.append(np.mean(surface_distance) if np.mean(surface_distance) >= 0 else 0) + batches = np.array_split(list(itertools.combinations(args, 2)), cores) + with Pool(cores, initializer=initialize_worker, initargs=(coords, radius)) as pool: + results = pool.map(worker_calc_distance, batches) + mean_dist = [dist for batch in results for dist in batch] return to_array(mean_dist, args) def thresh_to_arg(bead_spread, low_thresh, high_thresh): From 77953e443325e6a87822dc0ecee2fe606155ebf6 Mon Sep 17 00:00:00 2001 From: omgol411 Date: Mon, 13 Apr 2026 22:08:38 +0530 Subject: [PATCH 2/4] expose `cores` argument in `get_patches` --- src/main.py | 2 +- src/patch_computer.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main.py b/src/main.py index fd87e07..295c3c4 100644 --- a/src/main.py +++ b/src/main.py @@ -85,7 +85,7 @@ def run_prism( coords, mass, radius, ps_names, args, output_dir = None ): fl.write("\n") # Obtain patches for all the beads. - patches = get_patches(bead_spread, args.classes, coords, radius) + patches = get_patches(bead_spread, args.classes, coords, radius, cores_) # Annotate the patches for low-med-high precision. annotated_patches = annotate_patches(patches, args.classes, ps_names, coords.shape[1]) high_prec, low_prec = patches[:args.classes], patches[args.classes+1:] diff --git a/src/patch_computer.py b/src/patch_computer.py index 01793ae..4440226 100644 --- a/src/patch_computer.py +++ b/src/patch_computer.py @@ -50,9 +50,9 @@ def calc_distance_matrix(args, coords, radius, cores=min(max(os.cpu_count() - 1, def thresh_to_arg(bead_spread, low_thresh, high_thresh): return [ n for n,i in enumerate(bead_spread) if i >= low_thresh and i <= high_thresh ] -def get_connected_components(arg, coords, radius, thresh=10): +def get_connected_components(arg, coords, radius, thresh=10, cores=min(max(os.cpu_count() - 1, 1), 16)): import networkx as nx - dist = calc_distance_matrix(arg, coords, radius) + dist = calc_distance_matrix(arg, coords, radius, cores=cores) true_pairs = np.argwhere(dist < thresh) l = [] for tp in true_pairs: @@ -64,10 +64,10 @@ def get_connected_components(arg, coords, radius, thresh=10): clusts.append(list(connected_component)) return clusts -def get_patches(bead_spread, classes, coords, radius): +def get_patches(bead_spread, classes, coords, radius, cores=min(max(os.cpu_count() - 1, 1), 16)): breaks = jenkspy.jenks_breaks(bead_spread, n_classes= (classes*2) + 1) arg_patches = [thresh_to_arg(bead_spread, breaks[i-1], breaks[i]) for i in range(1,len(breaks))] - patches = [get_connected_components(arg, coords, radius) for arg in arg_patches] + patches = [get_connected_components(arg, coords, radius, thresh=10, cores=cores) for arg in arg_patches] return patches def annotate_patches(patches, classes, ps_names, num_beads): From a5991825f8d045be96baba9a11a85cffd4f68046 Mon Sep 17 00:00:00 2001 From: omgol411 Date: Mon, 13 Apr 2026 22:53:35 +0530 Subject: [PATCH 3/4] use iterator for batches of bead pairs - from coderabbit suggestion --- src/patch_computer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/patch_computer.py b/src/patch_computer.py index 4440226..2ccb220 100644 --- a/src/patch_computer.py +++ b/src/patch_computer.py @@ -35,15 +35,17 @@ def initialize_worker(coords_, radius_): coords = coords_ radius = radius_ +def batch_pair_iterator(args, batch_size): + pairs = itertools.combinations(args, 2) + while True: + batch = list(itertools.islice(pairs, batch_size)) + if not batch: + break + yield batch + def calc_distance_matrix(args, coords, radius, cores=min(max(os.cpu_count() - 1, 1), 16)): - # pairs = itertools.combinations(args, 2) - # mean_dist = [] - # for p in pairs: - # surface_distance = np.linalg.norm(coords[:,p[0],:] - coords[:,p[1],:], axis=1) -(radius[p[0]] + radius[p[1]]) - # mean_dist.append(np.mean(surface_distance) if np.mean(surface_distance) >= 0 else 0) - batches = np.array_split(list(itertools.combinations(args, 2)), cores) with Pool(cores, initializer=initialize_worker, initargs=(coords, radius)) as pool: - results = pool.map(worker_calc_distance, batches) + results = pool.map(worker_calc_distance, batch_pair_iterator(args, batch_size=1000)) mean_dist = [dist for batch in results for dist in batch] return to_array(mean_dist, args) From 500420a074208b32eb30792b6a865215459d2f0c Mon Sep 17 00:00:00 2001 From: omgol411 Date: Mon, 13 Apr 2026 22:55:30 +0530 Subject: [PATCH 4/4] specify default cores in one place and reuse --- src/patch_computer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/patch_computer.py b/src/patch_computer.py index 2ccb220..169733d 100644 --- a/src/patch_computer.py +++ b/src/patch_computer.py @@ -3,6 +3,7 @@ import jenkspy import os from multiprocessing import Pool +DEFAULT_CORES = min(max((os.cpu_count() or 1) - 1, 1), 16) def calc_bead_spread(tup, grid): inds = tup[0] @@ -43,7 +44,7 @@ def batch_pair_iterator(args, batch_size): break yield batch -def calc_distance_matrix(args, coords, radius, cores=min(max(os.cpu_count() - 1, 1), 16)): +def calc_distance_matrix(args, coords, radius, cores=DEFAULT_CORES): with Pool(cores, initializer=initialize_worker, initargs=(coords, radius)) as pool: results = pool.map(worker_calc_distance, batch_pair_iterator(args, batch_size=1000)) mean_dist = [dist for batch in results for dist in batch] @@ -52,7 +53,7 @@ def calc_distance_matrix(args, coords, radius, cores=min(max(os.cpu_count() - 1, def thresh_to_arg(bead_spread, low_thresh, high_thresh): return [ n for n,i in enumerate(bead_spread) if i >= low_thresh and i <= high_thresh ] -def get_connected_components(arg, coords, radius, thresh=10, cores=min(max(os.cpu_count() - 1, 1), 16)): +def get_connected_components(arg, coords, radius, thresh=10, cores=DEFAULT_CORES): import networkx as nx dist = calc_distance_matrix(arg, coords, radius, cores=cores) true_pairs = np.argwhere(dist < thresh) @@ -66,7 +67,7 @@ def get_connected_components(arg, coords, radius, thresh=10, cores=min(max(os.cp clusts.append(list(connected_component)) return clusts -def get_patches(bead_spread, classes, coords, radius, cores=min(max(os.cpu_count() - 1, 1), 16)): +def get_patches(bead_spread, classes, coords, radius, cores=DEFAULT_CORES): breaks = jenkspy.jenks_breaks(bead_spread, n_classes= (classes*2) + 1) arg_patches = [thresh_to_arg(bead_spread, breaks[i-1], breaks[i]) for i in range(1,len(breaks))] patches = [get_connected_components(arg, coords, radius, thresh=10, cores=cores) for arg in arg_patches]