diff --git a/src/main.py b/src/main.py index fd87e07..295c3c4 100644 --- a/src/main.py +++ b/src/main.py @@ -85,7 +85,7 @@ def run_prism( coords, mass, radius, ps_names, args, output_dir = None ): fl.write("\n") # Obtain patches for all the beads. - patches = get_patches(bead_spread, args.classes, coords, radius) + patches = get_patches(bead_spread, args.classes, coords, radius, cores_) # Annotate the patches for low-med-high precision. annotated_patches = annotate_patches(patches, args.classes, ps_names, coords.shape[1]) high_prec, low_prec = patches[:args.classes], patches[args.classes+1:] diff --git a/src/patch_computer.py b/src/patch_computer.py index ae37224..169733d 100644 --- a/src/patch_computer.py +++ b/src/patch_computer.py @@ -1,6 +1,9 @@ import numpy as np import itertools import jenkspy +import os +from multiprocessing import Pool +DEFAULT_CORES = min(max((os.cpu_count() or 1) - 1, 1), 16) def calc_bead_spread(tup, grid): inds = tup[0] @@ -21,20 +24,38 @@ def to_array(final, ids): full_arr[ids.index(b),ids.index(a)] = final[i] return full_arr -def calc_distance_matrix(args, coords, radius): - pairs = itertools.combinations(args, 2) - mean_dist = [] - for p in pairs: +def worker_calc_distance(batch_pairs): + batch_mean_dist = [] + for p in batch_pairs: surface_distance = np.linalg.norm(coords[:,p[0],:] - coords[:,p[1],:], axis=1) -(radius[p[0]] + radius[p[1]]) - mean_dist.append(np.mean(surface_distance) if np.mean(surface_distance) >= 0 else 0) + batch_mean_dist.append(max(np.mean(surface_distance), 0)) + return batch_mean_dist + +def initialize_worker(coords_, radius_): + global coords, radius + coords = coords_ + radius = radius_ + +def batch_pair_iterator(args, batch_size): + pairs = itertools.combinations(args, 2) + while True: + batch = list(itertools.islice(pairs, batch_size)) + if not batch: + break + yield batch + +def calc_distance_matrix(args, coords, radius, cores=DEFAULT_CORES): + with Pool(cores, initializer=initialize_worker, initargs=(coords, radius)) as pool: + results = pool.map(worker_calc_distance, batch_pair_iterator(args, batch_size=1000)) + mean_dist = [dist for batch in results for dist in batch] return to_array(mean_dist, args) def thresh_to_arg(bead_spread, low_thresh, high_thresh): return [ n for n,i in enumerate(bead_spread) if i >= low_thresh and i <= high_thresh ] -def get_connected_components(arg, coords, radius, thresh=10): +def get_connected_components(arg, coords, radius, thresh=10, cores=DEFAULT_CORES): import networkx as nx - dist = calc_distance_matrix(arg, coords, radius) + dist = calc_distance_matrix(arg, coords, radius, cores=cores) true_pairs = np.argwhere(dist < thresh) l = [] for tp in true_pairs: @@ -46,10 +67,10 @@ def get_connected_components(arg, coords, radius, thresh=10): clusts.append(list(connected_component)) return clusts -def get_patches(bead_spread, classes, coords, radius): +def get_patches(bead_spread, classes, coords, radius, cores=DEFAULT_CORES): breaks = jenkspy.jenks_breaks(bead_spread, n_classes= (classes*2) + 1) arg_patches = [thresh_to_arg(bead_spread, breaks[i-1], breaks[i]) for i in range(1,len(breaks))] - patches = [get_connected_components(arg, coords, radius) for arg in arg_patches] + patches = [get_connected_components(arg, coords, radius, thresh=10, cores=cores) for arg in arg_patches] return patches def annotate_patches(patches, classes, ps_names, num_beads):