Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run_prism( coords, mass, radius, ps_names, args, output_dir = None ):
fl.write("\n")

# Obtain patches for all the beads.
patches = get_patches(bead_spread, args.classes, coords, radius)
patches = get_patches(bead_spread, args.classes, coords, radius, cores_)
# Annotate the patches for low-med-high precision.
annotated_patches = annotate_patches(patches, args.classes, ps_names, coords.shape[1])
high_prec, low_prec = patches[:args.classes], patches[args.classes+1:]
Expand Down
39 changes: 30 additions & 9 deletions src/patch_computer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import itertools
import jenkspy
import os
from multiprocessing import Pool
DEFAULT_CORES = min(max((os.cpu_count() or 1) - 1, 1), 16)

def calc_bead_spread(tup, grid):
inds = tup[0]
Expand All @@ -21,20 +24,38 @@ def to_array(final, ids):
full_arr[ids.index(b),ids.index(a)] = final[i]
return full_arr

def calc_distance_matrix(args, coords, radius):
pairs = itertools.combinations(args, 2)
mean_dist = []
for p in pairs:
def worker_calc_distance(batch_pairs):
batch_mean_dist = []
for p in batch_pairs:
surface_distance = np.linalg.norm(coords[:,p[0],:] - coords[:,p[1],:], axis=1) -(radius[p[0]] + radius[p[1]])
mean_dist.append(np.mean(surface_distance) if np.mean(surface_distance) >= 0 else 0)
batch_mean_dist.append(max(np.mean(surface_distance), 0))
return batch_mean_dist

def initialize_worker(coords_, radius_):
global coords, radius
coords = coords_
radius = radius_

def batch_pair_iterator(args, batch_size):
pairs = itertools.combinations(args, 2)
while True:
batch = list(itertools.islice(pairs, batch_size))
if not batch:
break
yield batch

def calc_distance_matrix(args, coords, radius, cores=DEFAULT_CORES):
with Pool(cores, initializer=initialize_worker, initargs=(coords, radius)) as pool:
results = pool.map(worker_calc_distance, batch_pair_iterator(args, batch_size=1000))
mean_dist = [dist for batch in results for dist in batch]
return to_array(mean_dist, args)

def thresh_to_arg(bead_spread, low_thresh, high_thresh):
return [ n for n,i in enumerate(bead_spread) if i >= low_thresh and i <= high_thresh ]

def get_connected_components(arg, coords, radius, thresh=10):
def get_connected_components(arg, coords, radius, thresh=10, cores=DEFAULT_CORES):
import networkx as nx
dist = calc_distance_matrix(arg, coords, radius)
dist = calc_distance_matrix(arg, coords, radius, cores=cores)
true_pairs = np.argwhere(dist < thresh)
l = []
for tp in true_pairs:
Expand All @@ -46,10 +67,10 @@ def get_connected_components(arg, coords, radius, thresh=10):
clusts.append(list(connected_component))
return clusts

def get_patches(bead_spread, classes, coords, radius):
def get_patches(bead_spread, classes, coords, radius, cores=DEFAULT_CORES):
breaks = jenkspy.jenks_breaks(bead_spread, n_classes= (classes*2) + 1)
arg_patches = [thresh_to_arg(bead_spread, breaks[i-1], breaks[i]) for i in range(1,len(breaks))]
patches = [get_connected_components(arg, coords, radius) for arg in arg_patches]
patches = [get_connected_components(arg, coords, radius, thresh=10, cores=cores) for arg in arg_patches]
return patches

def annotate_patches(patches, classes, ps_names, num_beads):
Expand Down
Loading