diff --git a/README.md b/README.md index 8f01e77..be38183 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,16 @@ # axonal_connectomics Repository for tools developed for axonal connectomics -# Level of support +# Stitching modules +## Deskew and zarr conversion +acpreprocessing.stitching_modules.convert_to_n5.tiff_to_ngff +- Sequentially reads image arrays from a tiff stack series for pixel-wise deskew (optional), for computing a downsampling pyramid to a user-defined depth, and for writing out the data volume into a next-generation file format (zarr v3). -We are not currently supporting this code, but simply releasing it to the community AS IS but are not able to provide any guarantees of support. The community is welcome to submit issues and pull requests, but you should not expect an active response. +## Tile stitching +acpreprocessing.stitching_modules.acstitch.stitch +- Generate point correspondences between tiles from template matching or SIFT features at user-defined resolution level (mip). - -##Stitching requirements: +## Stitching requirements: Set these env variables: ``` export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk-amd64 diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index d699973..52e8fd4 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -6,106 +6,35 @@ """ import numpy import argschema -#from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments +from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_array from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file,save_pointmatch_file -# def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): -# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] -# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] -# # sd = SiftDetector(**sift_kwargs) -# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: -# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) -# else: -# sift_pmlist = None -# if sift_pmlist is None: -# if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None: -# #roilist = stitch_kwargs["roi_list"] -# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) -# else: -# p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, ij_shift, ns, ds -# else: -# roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx -# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs) -# pmlist = [] -# if not p_ptlist is None: -# for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): -# if not p_pts is None and len(p_pts) > 0: -# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) -# else: -# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":None,"q_pts":None}) -# return pmlist - - -# def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None): -# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: -# print("running crosscorrelation with points from " + stitch_kwargs["sift_pointmatch_file"]) -# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) -# else: -# sift_pmlist = None -# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] -# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] -# pmlist = [] -# for i in range(len(p_datasets)): -# print("computing pointmatches for source pair " + str(i)) -# pds = p_datasets[i] -# qds = q_datasets[i] -# if not sift_pmlist is None: -# if i < len(sift_pmlist) and not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0: -# ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs) -# else: -# ppts = None -# qpts = None -# else: -# ppts,qpts = run_ccorr(**ccorr_kwargs) -# if not ppts is None and len(ppts) > 0: -# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts}) -# else: -# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":None,"q_pts":None}) -# return pmlist - - -# def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): -# # TODO: handle overly granular bins with potentially 0 sift points returned -# if axis_range is None: -# axis_range = [[] for i in range(p_siftpts.shape[1])] -# if len(axis_range[0]) == 0: -# zstarts = numpy.linspace(numpy.min(p_siftpts[:,0]),numpy.max(p_siftpts[:,0]),n_cc_pts+1) -# p_pts = numpy.empty((n_cc_pts,3),dtype=int) -# q_pts = numpy.empty((n_cc_pts,3),dtype=int) -# for i in range(n_cc_pts): -# r = numpy.full(p_siftpts.shape[0],True) -# for ai,a in enumerate(axis_range): -# if len(a)>0: -# r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1])) -# elif ai == 0: -# r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1])) -# pr = p_siftpts[r] -# qr = q_siftpts[r] -# if len(pr) > 0: -# imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) -# ppt = pr[imax,:] -# qpt = qr[imax,:] -# else: -# ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int) -# qpt = ppt + numpy.array(axis_shift) -# p_pts[i] = ppt -# q_pts[i] = qpt -# return p_pts,q_pts - - def run_ccorr(p_ds,q_ds,p_pts,q_pts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8): ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold) return ppm,qpm +def run_sift(p_ds,q_ds,miplvl=0,sift_kwargs=None,stitch_kwargs=None): + if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: + sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) + else: + sift_pmlist = None + if sift_pmlist is None: + if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None: + p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_ds,q_ds,**stitch_kwargs) + else: + p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_ds,q_ds,**stitch_kwargs) + else: + roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) + p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_ds,q_ds,roilist,**stitch_kwargs) + return p_ptlist,q_ptlist def get_dataset_from_path(tilepath,miplvl=0): return get_zarr_array(tilepath,miplvl=miplvl) - -def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,miplvl=0,stitch_kwargs=None): +def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,miplvl=0,sift_kwargs=None,stitch_kwargs=None): stitch_kwargs = ({} if stitch_kwargs is None else stitch_kwargs) if stitch_method == "ccorr": kwargs = {"p_ds":get_dataset_from_path(p_tilepath,miplvl), @@ -113,17 +42,22 @@ def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,mipl "p_pts": p_points/(2**miplvl), "q_pts": q_points/(2**miplvl)} p_pts,q_pts = run_ccorr(**kwargs,**stitch_kwargs) + elif stitch_method == "sift": + kwargs = {"p_ds":get_dataset_from_path(p_tilepath,miplvl), + "q_ds":get_dataset_from_path(q_tilepath,miplvl), + "miplvl":miplvl} + p_pts,q_pts = run_sift(**kwargs,sift_kwargs=sift_kwargs,stitch_kwargs=stitch_kwargs) else: return run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method="ccorr",stitch_kwargs=stitch_kwargs) return {"p_tile":p_tilepath,"q_tile":q_tilepath,"p_pts":p_pts*(2**miplvl),"q_pts":q_pts*(2**miplvl)} -def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,miplvl=0,stitch_kwargs=None): +def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,miplvl=0,sift_kwargs=None,stitch_kwargs=None): in_pms = read_pointmatch_file(input_file) out_pms = [] for tspec in in_pms: args = [tspec.get(key) for key in ["p_tile","q_tile","p_pts","q_pts"]] - pmdict = run_stitch_method(*args,stitch_method=stitch_method,miplvl=miplvl,stitch_kwargs=stitch_kwargs) + pmdict = run_stitch_method(*args,stitch_method=stitch_method,miplvl=miplvl,sift_kwargs=sift_kwargs,stitch_kwargs=stitch_kwargs) # else: # print("WARNING: tspec contains None") # pmdict = None @@ -135,7 +69,8 @@ class StitchTilesParameters(argschema.ArgSchema): input_file = argschema.fields.Str(required=True) output_file = argschema.fields.Str(required=True) stitch_method = argschema.fields.Str(required=True) - mip_lvl = argschema.fields.Int(required=False,default=0) + miplvl = argschema.fields.Int(required=False,default=0) + sift_kwargs = argschema.fields.Dict(required=False,default=None) stitch_kwargs = argschema.fields.Dict(required=False,default=None)