Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
132 changes: 132 additions & 0 deletions acpreprocessing/stitching_modules/acstitch/extract_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 11 11:19:46 2025

@author: kevint
"""

import argschema
import numpy
import json
from skimage.feature import blob_log,blob_dog,blob_doh
from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_group,get_zarr_array,ZarrV3Metadata
from acpreprocessing.stitching_modules.acstitch.io import save_pointmatch_file

DEFAULT_METHOD = blob_dog

def detect_blobs(data,method=None,**kwargs):
# input 3d image data array
# return blobs detected by method
if method == "log":
blob_func = blob_log
elif method == "dog":
blob_func = blob_dog
elif method == "doh":
blob_func = blob_doh
else:
blob_func = DEFAULT_METHOD

return blob_func(data,**kwargs)


def detect_blobs_roi(zarray,z_range=None,y_range=None,x_range=None,method=None,blob_kwargs=None):
# input zarr array and roi defined by ranges (at miplvl)
# return blobs detected in roi
def get_roi(zarray,z_range,y_range,x_range):
return zarray[0,0,z_range[0]:z_range[1],y_range[0]:y_range[1],x_range[0]:x_range[1]]
blob_kwargs = ({} if blob_kwargs is None else blob_kwargs)

data = get_roi(zarray,z_range,y_range,x_range)
blobs = detect_blobs(data,method,**blob_kwargs)
if len(blobs)>0:
values = data[blobs[:,0].astype(int),blobs[:,1].astype(int),blobs[:,2].astype(int)]
blobs[:,:3] = blobs[:,:3] + numpy.array([[z_range[0],y_range[0],x_range[0]]])
return blobs,values
print("no blobs detected!!")
return None,None


def extract_points(zarray,roi_list,method,blob_kwargs=None):
# input tile path to run method with blob kwargs at mip level
# return
blobs = []
values = []
for roi in roi_list:
blobs_roi,values_roi = detect_blobs_roi(zarray,roi["z"],roi["y"],roi["x"],method=method,blob_kwargs=blob_kwargs)
if not blobs_roi is None:
blobs.append(blobs_roi)
values.append(values_roi)
return blobs,values


def save_extracted_pointmatch_file(src_tile,trg_tile,output_file,miplvl,roi_list,method,blob_kwargs=None,n_points=1,offset=None):
if offset is None:
p_md = ZarrV3Metadata(zgroup=get_zarr_group(src_tile))
q_md = ZarrV3Metadata(zgroup=get_zarr_group(trg_tile))
offset = numpy.divide(q_md.get_coordinate_translation(),q_md.mip_voxel_dims(miplvl))-numpy.divide(p_md.get_coordinate_translation(),p_md.mip_voxel_dims(miplvl))
zarray = get_zarr_array(src_tile,miplvl=miplvl)
src_points_rois,values_rois = extract_points(zarray,roi_list,method,blob_kwargs)
src_points = []
for points,values in zip(src_points_rois,values_rois):
if len(points) > 0:
if len(points) < n_points:
n = len(points)
else:
n = n_points
ind = numpy.argpartition(values, -n)[-n:]
src_points.append(points[ind,:3])
if src_points:
src_points = numpy.concatenate(src_points)
trg_points = src_points - numpy.asarray(offset)
src_points *= 2**miplvl
trg_points *= 2**miplvl
else:
trg_points = []
tspec = [{"p_tile":src_tile,"q_tile":trg_tile,"p_pts":src_points.astype(int),"q_pts":trg_points.astype(int)}]
save_pointmatch_file(tspec,output_file)


def extract_points_from_tiles(src_tile,trg_tile,output_file,miplvl,roi_file,method,blob_kwargs=None,points_kwargs=None):
points_kwargs = ({} if points_kwargs is None else points_kwargs)
if roi_file is None:
roi_list = []
else:
with open(roi_file,'r') as f:
roi_list = json.load(f)
save_extracted_pointmatch_file(src_tile,trg_tile,output_file,miplvl,roi_list,method,blob_kwargs,**points_kwargs)


class BlobDetectionParameters(argschema.schemas.DefaultSchema):
method = argschema.fields.Str(required=False,default=None)
blob_kwargs = argschema.fields.Dict(required=False,default=None)


class PointExtractionParameters(argschema.schemas.DefaultSchema):
n_points = argschema.fields.Int(required=False,default=1)


class ExtractPointsParameters(argschema.ArgSchema,
BlobDetectionParameters,
PointExtractionParameters):
p_tile = argschema.fields.Str(required=True)
q_tile = argschema.fields.Str(required=True)
output_file = argschema.fields.Str(required=True)
mip_lvl = argschema.fields.Int(required=False,default=0)
roi_file = argschema.fields.Str(required=False,default=None)


class ExtractPointsFromTilePair(argschema.ArgSchemaParser):
default_schema = ExtractPointsParameters

def run(self):
extract_points_from_tiles(self.args['p_tile'],
self.args['q_tile'],
self.args['output_file'],
self.args['mip_lvl'],
self.args['roi_file'],
self.args['method'],
blob_kwargs=self.args['blob_kwargs'])

if __name__ == "__main__":
mod = ExtractPointsFromTilePair()
mod.run()
221 changes: 133 additions & 88 deletions acpreprocessing/stitching_modules/acstitch/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,102 +5,147 @@
@author: kevint
"""
import numpy
from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments
import argschema
#from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments
from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences
from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src
from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file
from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_array
from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file,save_pointmatch_file


def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None):
p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist]
q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist]
# sd = SiftDetector(**sift_kwargs)
if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]:
sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"])
else:
sift_pmlist = None
if sift_pmlist is None:
if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None:
#roilist = stitch_kwargs["roi_list"]
p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs)
else:
p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, ij_shift, ns, ds
else:
roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx
p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs)
pmlist = []
if not p_ptlist is None:
for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist):
if not p_pts is None and len(p_pts) > 0:
pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts})
else:
pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":None,"q_pts":None})
return pmlist


def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None):
if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]:
print("running crosscorrelation with points from " + stitch_kwargs["sift_pointmatch_file"])
sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"])
else:
sift_pmlist = None
p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist]
q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist]
pmlist = []
for i in range(len(p_datasets)):
print("computing pointmatches for source pair " + str(i))
pds = p_datasets[i]
qds = q_datasets[i]
if not sift_pmlist is None:
if i < len(sift_pmlist) and not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0:
ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs)
else:
ppts = None
qpts = None
else:
ppts,qpts = run_ccorr(**ccorr_kwargs)
if not ppts is None and len(ppts) > 0:
pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts})
else:
pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":None,"q_pts":None})
return pmlist

# def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None):
# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist]
# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist]
# # sd = SiftDetector(**sift_kwargs)
# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]:
# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"])
# else:
# sift_pmlist = None
# if sift_pmlist is None:
# if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None:
# #roilist = stitch_kwargs["roi_list"]
# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs)
# else:
# p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, ij_shift, ns, ds
# else:
# roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx
# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs)
# pmlist = []
# if not p_ptlist is None:
# for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist):
# if not p_pts is None and len(p_pts) > 0:
# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts})
# else:
# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":None,"q_pts":None})
# return pmlist


# def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None):
# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]:
# print("running crosscorrelation with points from " + stitch_kwargs["sift_pointmatch_file"])
# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"])
# else:
# sift_pmlist = None
# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist]
# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist]
# pmlist = []
# for i in range(len(p_datasets)):
# print("computing pointmatches for source pair " + str(i))
# pds = p_datasets[i]
# qds = q_datasets[i]
# if not sift_pmlist is None:
# if i < len(sift_pmlist) and not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0:
# ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs)
# else:
# ppts = None
# qpts = None
# else:
# ppts,qpts = run_ccorr(**ccorr_kwargs)
# if not ppts is None and len(ppts) > 0:
# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts})
# else:
# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":None,"q_pts":None})
# return pmlist


# def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None):
# # TODO: handle overly granular bins with potentially 0 sift points returned
# if axis_range is None:
# axis_range = [[] for i in range(p_siftpts.shape[1])]
# if len(axis_range[0]) == 0:
# zstarts = numpy.linspace(numpy.min(p_siftpts[:,0]),numpy.max(p_siftpts[:,0]),n_cc_pts+1)
# p_pts = numpy.empty((n_cc_pts,3),dtype=int)
# q_pts = numpy.empty((n_cc_pts,3),dtype=int)
# for i in range(n_cc_pts):
# r = numpy.full(p_siftpts.shape[0],True)
# for ai,a in enumerate(axis_range):
# if len(a)>0:
# r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1]))
# elif ai == 0:
# r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1]))
# pr = p_siftpts[r]
# qr = q_siftpts[r]
# if len(pr) > 0:
# imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]])
# ppt = pr[imax,:]
# qpt = qr[imax,:]
# else:
# ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int)
# qpt = ppt + numpy.array(axis_shift)
# p_pts[i] = ppt
# q_pts[i] = qpt
# return p_pts,q_pts


def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs):
p_pts,q_pts = get_cc_points_from_sift(p_ds, q_ds, p_siftpts, q_siftpts,n_cc_pts,axis_shift,axis_range)
def run_ccorr(p_ds,q_ds,p_pts,q_pts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8):
ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold)
return ppm,qpm


def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None):
# TODO: handle overly granular bins with potentially 0 sift points returned
if axis_range is None:
axis_range = [[] for i in range(p_siftpts.shape[1])]
if len(axis_range[0]) == 0:
zstarts = numpy.linspace(numpy.min(p_siftpts[:,0]),numpy.max(p_siftpts[:,0]),n_cc_pts+1)
p_pts = numpy.empty((n_cc_pts,3),dtype=int)
q_pts = numpy.empty((n_cc_pts,3),dtype=int)
for i in range(n_cc_pts):
r = numpy.full(p_siftpts.shape[0],True)
for ai,a in enumerate(axis_range):
if len(a)>0:
r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1]))
elif ai == 0:
r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1]))
pr = p_siftpts[r]
qr = q_siftpts[r]
if len(pr) > 0:
imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]])
ppt = pr[imax,:]
qpt = qr[imax,:]
else:
ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int)
qpt = ppt + numpy.array(axis_shift)
p_pts[i] = ppt
q_pts[i] = qpt
return p_pts,q_pts
def get_dataset_from_path(tilepath,miplvl=0):
return get_zarr_array(tilepath,miplvl=miplvl)


def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,miplvl=0,stitch_kwargs=None):
stitch_kwargs = ({} if stitch_kwargs is None else stitch_kwargs)
if stitch_method == "ccorr":
kwargs = {"p_ds":get_dataset_from_path(p_tilepath,miplvl),
"q_ds":get_dataset_from_path(q_tilepath,miplvl),
"p_pts": p_points/(2**miplvl),
"q_pts": q_points/(2**miplvl)}
p_pts,q_pts = run_ccorr(**kwargs,**stitch_kwargs)
else:
return run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method="ccorr",stitch_kwargs=stitch_kwargs)
return {"p_tile":p_tilepath,"q_tile":q_tilepath,"p_pts":p_pts*(2**miplvl),"q_pts":q_pts*(2**miplvl)}


def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,miplvl=0,stitch_kwargs=None):
in_pms = read_pointmatch_file(input_file)
out_pms = []
for tspec in in_pms:
args = [tspec.get(key) for key in ["p_tile","q_tile","p_pts","q_pts"]]
pmdict = run_stitch_method(*args,stitch_method=stitch_method,miplvl=miplvl,stitch_kwargs=stitch_kwargs)
# else:
# print("WARNING: tspec contains None")
# pmdict = None
out_pms.append(pmdict)
save_pointmatch_file(out_pms,output_file)


class StitchTilesParameters(argschema.ArgSchema):
input_file = argschema.fields.Str(required=True)
output_file = argschema.fields.Str(required=True)
stitch_method = argschema.fields.Str(required=True)
mip_lvl = argschema.fields.Int(required=False,default=0)
stitch_kwargs = argschema.fields.Dict(required=False,default=None)


class StitchTiles(argschema.ArgSchemaParser):
default_schema = StitchTilesParameters

def run_ccorr(**kwargs):
pass
def run(self):
stitch_tiles_from_pmfile(**self.args)


if __name__ == "__main__":
mod = StitchTiles()
mod.run()
Loading
Loading