Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# axonal_connectomics
Repository for tools developed for axonal connectomics

# Level of support
# Stitching modules
## Deskew and zarr conversion
acpreprocessing.stitching_modules.convert_to_n5.tiff_to_ngff
- Sequentially reads image arrays from a tiff stack series for pixel-wise deskew (optional), for computing a downsampling pyramid to a user-defined depth, and for writing out the data volume into a next-generation file format (zarr v3).

We are not currently supporting this code, but simply releasing it to the community AS IS but are not able to provide any guarantees of support. The community is welcome to submit issues and pull requests, but you should not expect an active response.
## Tile stitching
acpreprocessing.stitching_modules.acstitch.stitch
- Generate point correspondences between tiles from template matching or SIFT features at user-defined resolution level (mip).


##Stitching requirements:
## Stitching requirements:
Set these env variables:
```
export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk-amd64
Expand Down
115 changes: 25 additions & 90 deletions acpreprocessing/stitching_modules/acstitch/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,124 +6,58 @@
"""
import numpy
import argschema
#from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments
from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments
from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences
from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_array
from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file,save_pointmatch_file


# def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None):
# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist]
# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist]
# # sd = SiftDetector(**sift_kwargs)
# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]:
# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"])
# else:
# sift_pmlist = None
# if sift_pmlist is None:
# if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None:
# #roilist = stitch_kwargs["roi_list"]
# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs)
# else:
# p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, ij_shift, ns, ds
# else:
# roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx
# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs)
# pmlist = []
# if not p_ptlist is None:
# for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist):
# if not p_pts is None and len(p_pts) > 0:
# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts})
# else:
# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":None,"q_pts":None})
# return pmlist


# def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None):
# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]:
# print("running crosscorrelation with points from " + stitch_kwargs["sift_pointmatch_file"])
# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"])
# else:
# sift_pmlist = None
# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist]
# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist]
# pmlist = []
# for i in range(len(p_datasets)):
# print("computing pointmatches for source pair " + str(i))
# pds = p_datasets[i]
# qds = q_datasets[i]
# if not sift_pmlist is None:
# if i < len(sift_pmlist) and not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0:
# ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs)
# else:
# ppts = None
# qpts = None
# else:
# ppts,qpts = run_ccorr(**ccorr_kwargs)
# if not ppts is None and len(ppts) > 0:
# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts})
# else:
# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":None,"q_pts":None})
# return pmlist


# def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None):
# # TODO: handle overly granular bins with potentially 0 sift points returned
# if axis_range is None:
# axis_range = [[] for i in range(p_siftpts.shape[1])]
# if len(axis_range[0]) == 0:
# zstarts = numpy.linspace(numpy.min(p_siftpts[:,0]),numpy.max(p_siftpts[:,0]),n_cc_pts+1)
# p_pts = numpy.empty((n_cc_pts,3),dtype=int)
# q_pts = numpy.empty((n_cc_pts,3),dtype=int)
# for i in range(n_cc_pts):
# r = numpy.full(p_siftpts.shape[0],True)
# for ai,a in enumerate(axis_range):
# if len(a)>0:
# r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1]))
# elif ai == 0:
# r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1]))
# pr = p_siftpts[r]
# qr = q_siftpts[r]
# if len(pr) > 0:
# imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]])
# ppt = pr[imax,:]
# qpt = qr[imax,:]
# else:
# ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int)
# qpt = ppt + numpy.array(axis_shift)
# p_pts[i] = ppt
# q_pts[i] = qpt
# return p_pts,q_pts


def run_ccorr(p_ds,q_ds,p_pts,q_pts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8):
ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold)
return ppm,qpm

def run_sift(p_ds,q_ds,miplvl=0,sift_kwargs=None,stitch_kwargs=None):
if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]:
sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"])
else:
sift_pmlist = None
if sift_pmlist is None:
if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None:
p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_ds,q_ds,**stitch_kwargs)
else:
p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_ds,q_ds,**stitch_kwargs)
else:
roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs)
p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_ds,q_ds,roilist,**stitch_kwargs)
return p_ptlist,q_ptlist

def get_dataset_from_path(tilepath,miplvl=0):
return get_zarr_array(tilepath,miplvl=miplvl)


def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,miplvl=0,stitch_kwargs=None):
def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,miplvl=0,sift_kwargs=None,stitch_kwargs=None):
stitch_kwargs = ({} if stitch_kwargs is None else stitch_kwargs)
if stitch_method == "ccorr":
kwargs = {"p_ds":get_dataset_from_path(p_tilepath,miplvl),
"q_ds":get_dataset_from_path(q_tilepath,miplvl),
"p_pts": p_points/(2**miplvl),
"q_pts": q_points/(2**miplvl)}
p_pts,q_pts = run_ccorr(**kwargs,**stitch_kwargs)
elif stitch_method == "sift":
kwargs = {"p_ds":get_dataset_from_path(p_tilepath,miplvl),
"q_ds":get_dataset_from_path(q_tilepath,miplvl),
"miplvl":miplvl}
p_pts,q_pts = run_sift(**kwargs,sift_kwargs=sift_kwargs,stitch_kwargs=stitch_kwargs)
else:
return run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method="ccorr",stitch_kwargs=stitch_kwargs)
return {"p_tile":p_tilepath,"q_tile":q_tilepath,"p_pts":p_pts*(2**miplvl),"q_pts":q_pts*(2**miplvl)}


def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,miplvl=0,stitch_kwargs=None):
def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,miplvl=0,sift_kwargs=None,stitch_kwargs=None):
in_pms = read_pointmatch_file(input_file)
out_pms = []
for tspec in in_pms:
args = [tspec.get(key) for key in ["p_tile","q_tile","p_pts","q_pts"]]
pmdict = run_stitch_method(*args,stitch_method=stitch_method,miplvl=miplvl,stitch_kwargs=stitch_kwargs)
pmdict = run_stitch_method(*args,stitch_method=stitch_method,miplvl=miplvl,sift_kwargs=sift_kwargs,stitch_kwargs=stitch_kwargs)
# else:
# print("WARNING: tspec contains None")
# pmdict = None
Expand All @@ -135,7 +69,8 @@ class StitchTilesParameters(argschema.ArgSchema):
input_file = argschema.fields.Str(required=True)
output_file = argschema.fields.Str(required=True)
stitch_method = argschema.fields.Str(required=True)
mip_lvl = argschema.fields.Int(required=False,default=0)
miplvl = argschema.fields.Int(required=False,default=0)
sift_kwargs = argschema.fields.Dict(required=False,default=None)
stitch_kwargs = argschema.fields.Dict(required=False,default=None)


Expand Down
Loading