From 5bd2deda979a6c3f315dc1097053f5570fefb684 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 11 Sep 2025 14:42:51 -0700 Subject: [PATCH 1/6] refactoring and modularizing stitching code --- .../acstitch/extract_points.py | 22 +++++ .../stitching_modules/acstitch/stitch.py | 99 +++++++++++++------ 2 files changed, 90 insertions(+), 31 deletions(-) create mode 100644 acpreprocessing/stitching_modules/acstitch/extract_points.py diff --git a/acpreprocessing/stitching_modules/acstitch/extract_points.py b/acpreprocessing/stitching_modules/acstitch/extract_points.py new file mode 100644 index 00000000..5b1f6b25 --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/extract_points.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Sep 11 11:19:46 2025 + +@author: kevint +""" + +import argschema + +class StitchTilesParameters(argschema.ArgSchema): + ptile_path = argschema.fields.Str(required=True) + qtile_path = argschema.fields.Str(required=True) + +class StitchTiles(argschema.ArgSchemaParser): + default_schema = StitchTilesParameters + + def run(self): + pass + +if __name__ == "__main__": + mod = ExtractPointsFromTile() + mod.run() \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 0ba036ed..41f686fd 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -5,6 +5,7 @@ @author: kevint """ import numpy +import argschema from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src @@ -66,41 +67,77 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s return pmlist -def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs): - p_pts,q_pts = get_cc_points_from_sift(p_ds, q_ds, p_siftpts, q_siftpts,n_cc_pts,axis_shift,axis_range) +def run_ccorr(p_ds,q_ds,p_dict,q_dict,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs): + p_pts = get_points_from_tiledict(p_dict) + q_pts = get_points_from_tiledict(q_dict) ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold) return ppm,qpm -def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): - # TODO: handle overly granular bins with potentially 0 sift points returned - if axis_range is None: - axis_range = [[] for i in range(p_siftpts.shape[1])] - if len(axis_range[0]) == 0: - zstarts = numpy.linspace(numpy.min(p_siftpts[:,0]),numpy.max(p_siftpts[:,0]),n_cc_pts+1) - p_pts = numpy.empty((n_cc_pts,3),dtype=int) - q_pts = numpy.empty((n_cc_pts,3),dtype=int) - for i in range(n_cc_pts): - r = numpy.full(p_siftpts.shape[0],True) - for ai,a in enumerate(axis_range): - if len(a)>0: - r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1])) - elif ai == 0: - r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1])) - pr = p_siftpts[r] - qr = q_siftpts[r] - if len(pr) > 0: - imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) - ppt = pr[imax,:] - qpt = qr[imax,:] - else: - ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int) - qpt = ppt + numpy.array(axis_shift) - p_pts[i] = ppt - q_pts[i] = qpt - return p_pts,q_pts +def get_dataset_from_tilepath(tilepath): + pass + + +def get_points_from_tiledict(tiledict): + pass + + +# def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): +# # TODO: handle overly granular bins with potentially 0 sift points returned +# if axis_range is None: +# axis_range = [[] for i in range(p_siftpts.shape[1])] +# if len(axis_range[0]) == 0: +# zstarts = numpy.linspace(numpy.min(p_siftpts[:,0]),numpy.max(p_siftpts[:,0]),n_cc_pts+1) +# p_pts = numpy.empty((n_cc_pts,3),dtype=int) +# q_pts = numpy.empty((n_cc_pts,3),dtype=int) +# for i in range(n_cc_pts): +# r = numpy.full(p_siftpts.shape[0],True) +# for ai,a in enumerate(axis_range): +# if len(a)>0: +# r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1])) +# elif ai == 0: +# r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1])) +# pr = p_siftpts[r] +# qr = q_siftpts[r] +# if len(pr) > 0: +# imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) +# ppt = pr[imax,:] +# qpt = qr[imax,:] +# else: +# ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int) +# qpt = ppt + numpy.array(axis_shift) +# p_pts[i] = ppt +# q_pts[i] = qpt +# return p_pts,q_pts +def run_stitch_method(ptile_path,qtile_path,ptile_cors,qtile_cors,stitch_method,stitch_kwargs): + if stitch_method == "ccorr": + kwargs = {"p_ds":get_dataset_from_tilepath(ptile_path), + "q_ds":get_dataset_from_tilepath(qtile_path), + "p_dict":ptile_cors, + "q_dict":qtile_cors} + run_ccorr(**kwargs,**stitch_kwargs) + else: + pass + + +class StitchTilesParameters(argschema.ArgSchema): + ptile_path = argschema.fields.Str(required=True) + qtile_path = argschema.fields.Str(required=True) + ptile_cors = argschema.fields.Dict(required=True) + qtile_cors = argschema.fields.Dict(required=True) + stitch_method = argschema.fields.Str(required=True) + stitch_kwargs = argschema.fields.Dict(required=True) + + +class StitchTiles(argschema.ArgSchemaParser): + default_schema = StitchTilesParameters -def run_ccorr(**kwargs): - pass \ No newline at end of file + def run(self): + run_stitch_method(**self.args) + + +if __name__ == "__main__": + mod = StitchTiles() + mod.run() \ No newline at end of file From fbd66447924ca4da20c9eca586fff017ed08085a Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 11 Sep 2025 22:06:41 -0700 Subject: [PATCH 2/6] pointmatch file keystone --- .../acstitch/extract_points.py | 8 +- .../stitching_modules/acstitch/stitch.py | 169 +++++++++--------- 2 files changed, 90 insertions(+), 87 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/extract_points.py b/acpreprocessing/stitching_modules/acstitch/extract_points.py index 5b1f6b25..112b04d4 100644 --- a/acpreprocessing/stitching_modules/acstitch/extract_points.py +++ b/acpreprocessing/stitching_modules/acstitch/extract_points.py @@ -7,12 +7,14 @@ import argschema -class StitchTilesParameters(argschema.ArgSchema): + +class ExtractPointsParameters(argschema.ArgSchema): ptile_path = argschema.fields.Str(required=True) qtile_path = argschema.fields.Str(required=True) -class StitchTiles(argschema.ArgSchemaParser): - default_schema = StitchTilesParameters + +class ExtractPointsFromTile(argschema.ArgSchemaParser): + default_schema = ExtractPointsParameters def run(self): pass diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 41f686fd..d23163fb 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -9,77 +9,62 @@ from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src -from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file - - -def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): - p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] - q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] - # sd = SiftDetector(**sift_kwargs) - if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: - sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) - else: - sift_pmlist = None - if sift_pmlist is None: - if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None: - #roilist = stitch_kwargs["roi_list"] - p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) - else: - p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, ij_shift, ns, ds - else: - roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx - p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs) - pmlist = [] - if not p_ptlist is None: - for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): - if not p_pts is None and len(p_pts) > 0: - pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) - else: - pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":None,"q_pts":None}) - return pmlist - - -def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None): - if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: - print("running crosscorrelation with points from " + stitch_kwargs["sift_pointmatch_file"]) - sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) - else: - sift_pmlist = None - p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] - q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] - pmlist = [] - for i in range(len(p_datasets)): - print("computing pointmatches for source pair " + str(i)) - pds = p_datasets[i] - qds = q_datasets[i] - if not sift_pmlist is None: - if i < len(sift_pmlist) and not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0: - ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs) - else: - ppts = None - qpts = None - else: - ppts,qpts = run_ccorr(**ccorr_kwargs) - if not ppts is None and len(ppts) > 0: - pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts}) - else: - pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":None,"q_pts":None}) - return pmlist - - -def run_ccorr(p_ds,q_ds,p_dict,q_dict,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs): - p_pts = get_points_from_tiledict(p_dict) - q_pts = get_points_from_tiledict(q_dict) - ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold) - return ppm,qpm - - -def get_dataset_from_tilepath(tilepath): - pass - - -def get_points_from_tiledict(tiledict): - pass +from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file,save_pointmatch_file + + +# def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): +# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] +# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] +# # sd = SiftDetector(**sift_kwargs) +# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: +# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) +# else: +# sift_pmlist = None +# if sift_pmlist is None: +# if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None: +# #roilist = stitch_kwargs["roi_list"] +# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) +# else: +# p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, ij_shift, ns, ds +# else: +# roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx +# p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs) +# pmlist = [] +# if not p_ptlist is None: +# for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): +# if not p_pts is None and len(p_pts) > 0: +# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) +# else: +# pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":None,"q_pts":None}) +# return pmlist + + +# def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None): +# if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: +# print("running crosscorrelation with points from " + stitch_kwargs["sift_pointmatch_file"]) +# sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) +# else: +# sift_pmlist = None +# p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] +# q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] +# pmlist = [] +# for i in range(len(p_datasets)): +# print("computing pointmatches for source pair " + str(i)) +# pds = p_datasets[i] +# qds = q_datasets[i] +# if not sift_pmlist is None: +# if i < len(sift_pmlist) and not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0: +# ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs) +# else: +# ppts = None +# qpts = None +# else: +# ppts,qpts = run_ccorr(**ccorr_kwargs) +# if not ppts is None and len(ppts) > 0: +# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts}) +# else: +# pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":None,"q_pts":None}) +# return pmlist # def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): @@ -109,24 +94,40 @@ def get_points_from_tiledict(tiledict): # p_pts[i] = ppt # q_pts[i] = qpt # return p_pts,q_pts + + +def run_ccorr(p_ds,q_ds,p_pts,q_pts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs): + ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold) + return ppm,qpm + + +def get_dataset_from_path(tilepath): + return - -def run_stitch_method(ptile_path,qtile_path,ptile_cors,qtile_cors,stitch_method,stitch_kwargs): + +def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,stitch_kwargs): if stitch_method == "ccorr": - kwargs = {"p_ds":get_dataset_from_tilepath(ptile_path), - "q_ds":get_dataset_from_tilepath(qtile_path), - "p_dict":ptile_cors, - "q_dict":qtile_cors} - run_ccorr(**kwargs,**stitch_kwargs) + kwargs = {"p_ds":get_dataset_from_path(p_tilepath), + "q_ds":get_dataset_from_path(q_tilepath), + "p_pts": p_points, + "q_pts": q_points} + p_pts,q_pts = run_ccorr(**kwargs,**stitch_kwargs) else: - pass + return run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method="ccorr",stitch_kwargs=stitch_kwargs) + return {"p_tile":p_tilepath,"q_tile":q_tilepath,"p_pts":p_pts,"q_pts":q_pts} + +def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,stitch_kwargs): + in_pms = read_pointmatch_file(input_file) + args = [in_pms.get(key,None) for key in ["p_tile","q_tile","p_pts","q_pts"]] + if not None in args: + out_pms = run_stitch_method(*args,stitch_method=stitch_method,stitch_kwargs=stitch_kwargs) + save_pointmatch_file(out_pms,output_file) + class StitchTilesParameters(argschema.ArgSchema): - ptile_path = argschema.fields.Str(required=True) - qtile_path = argschema.fields.Str(required=True) - ptile_cors = argschema.fields.Dict(required=True) - qtile_cors = argschema.fields.Dict(required=True) + input_file = argschema.fields.Str(required=True) + output_file = argschema.fields.Str(required=True) stitch_method = argschema.fields.Str(required=True) stitch_kwargs = argschema.fields.Dict(required=True) @@ -135,7 +136,7 @@ class StitchTiles(argschema.ArgSchemaParser): default_schema = StitchTilesParameters def run(self): - run_stitch_method(**self.args) + stitch_tiles_from_pmfile(**self.args) if __name__ == "__main__": From a335ab084bea81aaf03e880b791df61ce78f9e37 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Sat, 11 Oct 2025 14:42:14 -0700 Subject: [PATCH 3/6] added blob-based point extraction methods . . . . --- .../stitching_modules/acstitch/__init__.py | 0 .../acstitch/extract_points.py | 63 +++++++++++++++++-- .../stitching_modules/acstitch/zarrutils.py | 12 +++- 3 files changed, 69 insertions(+), 6 deletions(-) create mode 100644 acpreprocessing/stitching_modules/acstitch/__init__.py diff --git a/acpreprocessing/stitching_modules/acstitch/__init__.py b/acpreprocessing/stitching_modules/acstitch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/acpreprocessing/stitching_modules/acstitch/extract_points.py b/acpreprocessing/stitching_modules/acstitch/extract_points.py index 112b04d4..9bc683a9 100644 --- a/acpreprocessing/stitching_modules/acstitch/extract_points.py +++ b/acpreprocessing/stitching_modules/acstitch/extract_points.py @@ -6,18 +6,73 @@ """ import argschema +from skimage.feature import blob_log,blob_dog,blob_doh +from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_array +DEFAULT_METHOD = blob_dog -class ExtractPointsParameters(argschema.ArgSchema): - ptile_path = argschema.fields.Str(required=True) - qtile_path = argschema.fields.Str(required=True) +def detect_blobs(data,method=None,**kwargs): + # input 3d image data array + # return blobs detected by method + if method == "log": + blob_func = blob_log + elif method == "dog": + blob_func = blob_dog + elif method == "doh": + blob_func = blob_doh + else: + blob_func = DEFAULT_METHOD + + return blob_func(data,**kwargs) +def detect_blobs_roi(zarray,z_range=None,y_range=None,x_range=None,method=None,blob_kwargs=None): + # input zarr array and roi defined by ranges (at miplvl) + # return blobs detected in roi + def get_roi(zarray,z_range,y_range,x_range): + return zarray[0,0,z_range[0]:z_range[1],y_range[0]:y_range[1],x_range[0]:x_range[1]] + + data = get_roi(zarray,z_range,y_range,x_range) + blobs = detect_blobs(data,method,**blob_kwargs) + return blobs + + +def extract_points(tile_path,miplvl,method,roi_list,blob_kwargs): + # input tile path to run method with blob kwargs at mip level + # return + zarray = get_zarr_array(tile_path,miplvl=miplvl) + blobs = [] + for roi in roi_list: + blobs_roi = detect_blobs_roi(zarray,roi["z"],roi["y"],roi["x"],method=method,blob_kwargs=blob_kwargs) + blobs.append(blobs_roi) + return blobs + + +def save_points_to_file(tile_path,output_file,miplvl,method,roi_list,blob_kwargs): + points = extract_points(tile_path,miplvl,method,blob_kwargs) + #save points + + +class BlobDetectionParameters(argschema.schemas.DefaultSchema): + method = argschema.fields.Str(required=False,default=None) + blob_kwargs = argschema.fields.Dict(required=False,default=None) + + +class ExtractPointsParameters(argschema.ArgSchema,BlobDetectionParameters): + tile_path = argschema.fields.Str(required=True) + output_file = argschema.fields.Str(required=True) + mip_lvl = argschema.fields.Int(required=False,default=0) + + class ExtractPointsFromTile(argschema.ArgSchemaParser): default_schema = ExtractPointsParameters def run(self): - pass + extract_points(self.args['tile_path'], + self.args['output_file'], + self.args['mip_lvl'], + self.args['method'], + blob_kwargs=self.args['blob_kwargs']) if __name__ == "__main__": mod = ExtractPointsFromTile() diff --git a/acpreprocessing/stitching_modules/acstitch/zarrutils.py b/acpreprocessing/stitching_modules/acstitch/zarrutils.py index 9af58ece..d01414d1 100644 --- a/acpreprocessing/stitching_modules/acstitch/zarrutils.py +++ b/acpreprocessing/stitching_modules/acstitch/zarrutils.py @@ -2,11 +2,19 @@ import zarr import json -def get_zarr_group(zpath,grpname): +def get_zarr_array(zpath,grpname=None,miplvl=0): + zg = get_zarr_group(zpath,grpname) + return zg[f"{miplvl}"] + + +def get_zarr_group(zpath,grpname=None): # key to working with zarr files # group contains mip datasets and dataset attributes zf = zarr.open(zpath) - return zf[grpname] + if not grpname is None: + return zf[grpname] + else: + return zf def get_group_from_src(srcpath, outpath='zarr://http://bigkahuna.corp.alleninstitute.org/ACdata', # Url for ACdata for NG hosted on BigKahuna From b716fe09e1e96798108ae0727def90ba69deab32 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Wed, 15 Oct 2025 13:18:34 -0700 Subject: [PATCH 4/6] save extracted pointmatch file with offset . . --- .../acstitch/extract_points.py | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/extract_points.py b/acpreprocessing/stitching_modules/acstitch/extract_points.py index 9bc683a9..22a4fba7 100644 --- a/acpreprocessing/stitching_modules/acstitch/extract_points.py +++ b/acpreprocessing/stitching_modules/acstitch/extract_points.py @@ -6,8 +6,10 @@ """ import argschema +import numpy from skimage.feature import blob_log,blob_dog,blob_doh from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_array +from acpreprocessing.stitching_modules.acstitch.io import save_pointmatch_file DEFAULT_METHOD = blob_dog @@ -31,27 +33,47 @@ def detect_blobs_roi(zarray,z_range=None,y_range=None,x_range=None,method=None,b # return blobs detected in roi def get_roi(zarray,z_range,y_range,x_range): return zarray[0,0,z_range[0]:z_range[1],y_range[0]:y_range[1],x_range[0]:x_range[1]] + blob_kwargs = ({} if blob_kwargs is None else blob_kwargs) data = get_roi(zarray,z_range,y_range,x_range) blobs = detect_blobs(data,method,**blob_kwargs) - return blobs + values = data[blobs[:,0].astype(int),blobs[:,1].astype(int),blobs[:,2].astype(int)] + return blobs,values -def extract_points(tile_path,miplvl,method,roi_list,blob_kwargs): +def extract_points(tile_path,miplvl,method,roi_list,blob_kwargs=None): # input tile path to run method with blob kwargs at mip level # return zarray = get_zarr_array(tile_path,miplvl=miplvl) blobs = [] + values = [] for roi in roi_list: - blobs_roi = detect_blobs_roi(zarray,roi["z"],roi["y"],roi["x"],method=method,blob_kwargs=blob_kwargs) + blobs_roi,values_roi = detect_blobs_roi(zarray,roi["z"],roi["y"],roi["x"],method=method,blob_kwargs=blob_kwargs) blobs.append(blobs_roi) - return blobs + values.append(values_roi) + return blobs,values -def save_points_to_file(tile_path,output_file,miplvl,method,roi_list,blob_kwargs): - points = extract_points(tile_path,miplvl,method,blob_kwargs) - #save points - +def save_extracted_pointmatch_file(src_tile,trg_tile,output_file,miplvl,method,roi_list,blob_kwargs=None,n_points=1,offset=None): + src_points_rois,values_rois = extract_points(src_tile,miplvl,method,roi_list,blob_kwargs) + src_points = [] + for points,values in zip(src_points_rois,values_rois): + if len(points) > 0: + if len(points) < n_points: + n = len(points) + else: + n = n_points + ind = numpy.argpartition(values, -n)[-n:] + src_points.append(points[ind,:3]*(2**miplvl)) + if src_points: + src_points = numpy.concatenate(src_points) + if not offset is None: + trg_points = src_points + numpy.asarray(offset) + else: + trg_points = src_points + pmdict = {"p_tile":src_tile,"q_tile":trg_tile,"p_pts":src_points,"q_pts":trg_points} + save_pointmatch_file(pmdict,output_file) + class BlobDetectionParameters(argschema.schemas.DefaultSchema): method = argschema.fields.Str(required=False,default=None) From 6af136320f72a7c2917d03c08b446b8c86784bb3 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 16 Oct 2025 15:39:36 -0700 Subject: [PATCH 5/6] offsets with zarr metadata . . . . . . . . --- .../acstitch/extract_points.py | 81 +++++++++++++------ .../stitching_modules/acstitch/stitch.py | 4 +- .../stitching_modules/acstitch/zarrutils.py | 23 +++++- 3 files changed, 80 insertions(+), 28 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/extract_points.py b/acpreprocessing/stitching_modules/acstitch/extract_points.py index 22a4fba7..69f2dad2 100644 --- a/acpreprocessing/stitching_modules/acstitch/extract_points.py +++ b/acpreprocessing/stitching_modules/acstitch/extract_points.py @@ -7,8 +7,9 @@ import argschema import numpy +import json from skimage.feature import blob_log,blob_dog,blob_doh -from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_array +from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_group,get_zarr_array,ZarrV3Metadata from acpreprocessing.stitching_modules.acstitch.io import save_pointmatch_file DEFAULT_METHOD = blob_dog @@ -37,25 +38,34 @@ def get_roi(zarray,z_range,y_range,x_range): data = get_roi(zarray,z_range,y_range,x_range) blobs = detect_blobs(data,method,**blob_kwargs) - values = data[blobs[:,0].astype(int),blobs[:,1].astype(int),blobs[:,2].astype(int)] - return blobs,values + if len(blobs)>0: + values = data[blobs[:,0].astype(int),blobs[:,1].astype(int),blobs[:,2].astype(int)] + blobs[:,:3] = blobs[:,:3] + numpy.array([[z_range[0],y_range[0],x_range[0]]]) + return blobs,values + print("no blobs detected!!") + return None,None -def extract_points(tile_path,miplvl,method,roi_list,blob_kwargs=None): +def extract_points(zarray,roi_list,method,blob_kwargs=None): # input tile path to run method with blob kwargs at mip level - # return - zarray = get_zarr_array(tile_path,miplvl=miplvl) + # return blobs = [] values = [] for roi in roi_list: blobs_roi,values_roi = detect_blobs_roi(zarray,roi["z"],roi["y"],roi["x"],method=method,blob_kwargs=blob_kwargs) - blobs.append(blobs_roi) - values.append(values_roi) + if not blobs_roi is None: + blobs.append(blobs_roi) + values.append(values_roi) return blobs,values -def save_extracted_pointmatch_file(src_tile,trg_tile,output_file,miplvl,method,roi_list,blob_kwargs=None,n_points=1,offset=None): - src_points_rois,values_rois = extract_points(src_tile,miplvl,method,roi_list,blob_kwargs) +def save_extracted_pointmatch_file(src_tile,trg_tile,output_file,miplvl,roi_list,method,blob_kwargs=None,n_points=1,offset=None): + if offset is None: + p_md = ZarrV3Metadata(zgroup=get_zarr_group(src_tile)) + q_md = ZarrV3Metadata(zgroup=get_zarr_group(trg_tile)) + offset = numpy.divide(q_md.get_coordinate_translation(),q_md.mip_voxel_dims(miplvl))-numpy.divide(p_md.get_coordinate_translation(),p_md.mip_voxel_dims(miplvl)) + zarray = get_zarr_array(src_tile,miplvl=miplvl) + src_points_rois,values_rois = extract_points(zarray,roi_list,method,blob_kwargs) src_points = [] for points,values in zip(src_points_rois,values_rois): if len(points) > 0: @@ -64,15 +74,26 @@ def save_extracted_pointmatch_file(src_tile,trg_tile,output_file,miplvl,method,r else: n = n_points ind = numpy.argpartition(values, -n)[-n:] - src_points.append(points[ind,:3]*(2**miplvl)) + src_points.append(points[ind,:3]) if src_points: src_points = numpy.concatenate(src_points) - if not offset is None: - trg_points = src_points + numpy.asarray(offset) + trg_points = src_points - numpy.asarray(offset) + src_points *= 2**miplvl + trg_points *= 2**miplvl + else: + trg_points = [] + tspec = [{"p_tile":src_tile,"q_tile":trg_tile,"p_pts":src_points.astype(int),"q_pts":trg_points.astype(int)}] + save_pointmatch_file(tspec,output_file) + + +def extract_points_from_tiles(src_tile,trg_tile,output_file,miplvl,roi_file,method,blob_kwargs=None,points_kwargs=None): + points_kwargs = ({} if points_kwargs is None else points_kwargs) + if roi_file is None: + roi_list = [] else: - trg_points = src_points - pmdict = {"p_tile":src_tile,"q_tile":trg_tile,"p_pts":src_points,"q_pts":trg_points} - save_pointmatch_file(pmdict,output_file) + with open(roi_file,'r') as f: + roi_list = json.load(f) + save_extracted_pointmatch_file(src_tile,trg_tile,output_file,miplvl,roi_list,method,blob_kwargs,**points_kwargs) class BlobDetectionParameters(argschema.schemas.DefaultSchema): @@ -80,22 +101,32 @@ class BlobDetectionParameters(argschema.schemas.DefaultSchema): blob_kwargs = argschema.fields.Dict(required=False,default=None) -class ExtractPointsParameters(argschema.ArgSchema,BlobDetectionParameters): - tile_path = argschema.fields.Str(required=True) +class PointExtractionParameters(argschema.schemas.DefaultSchema): + n_points = argschema.fields.Int(required=False,default=1) + + +class ExtractPointsParameters(argschema.ArgSchema, + BlobDetectionParameters, + PointExtractionParameters): + p_tile = argschema.fields.Str(required=True) + q_tile = argschema.fields.Str(required=True) output_file = argschema.fields.Str(required=True) mip_lvl = argschema.fields.Int(required=False,default=0) + roi_file = argschema.fields.Str(required=False,default=None) -class ExtractPointsFromTile(argschema.ArgSchemaParser): +class ExtractPointsFromTilePair(argschema.ArgSchemaParser): default_schema = ExtractPointsParameters def run(self): - extract_points(self.args['tile_path'], - self.args['output_file'], - self.args['mip_lvl'], - self.args['method'], - blob_kwargs=self.args['blob_kwargs']) + extract_points_from_tiles(self.args['p_tile'], + self.args['q_tile'], + self.args['output_file'], + self.args['mip_lvl'], + self.args['roi_file'], + self.args['method'], + blob_kwargs=self.args['blob_kwargs']) if __name__ == "__main__": - mod = ExtractPointsFromTile() + mod = ExtractPointsFromTilePair() mod.run() \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index d23163fb..7a4b135e 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -6,9 +6,9 @@ """ import numpy import argschema -from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments +#from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences -from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src +#from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file,save_pointmatch_file diff --git a/acpreprocessing/stitching_modules/acstitch/zarrutils.py b/acpreprocessing/stitching_modules/acstitch/zarrutils.py index d01414d1..cae5c4e2 100644 --- a/acpreprocessing/stitching_modules/acstitch/zarrutils.py +++ b/acpreprocessing/stitching_modules/acstitch/zarrutils.py @@ -1,6 +1,7 @@ import pathlib import zarr import json +import numpy def get_zarr_array(zpath,grpname=None,miplvl=0): zg = get_zarr_group(zpath,grpname) @@ -16,6 +17,7 @@ def get_zarr_group(zpath,grpname=None): else: return zf + def get_group_from_src(srcpath, outpath='zarr://http://bigkahuna.corp.alleninstitute.org/ACdata', # Url for ACdata for NG hosted on BigKahuna inpath = 'J:'): @@ -28,9 +30,28 @@ def get_group_from_src(srcpath, print(str(p) + " not found!") return None + def get_src_from_json(sourcejson,plane,tile): with open(sourcejson,'r') as f: js = json.load(f) srcList = js[plane]['sources'] ind = [s.split("_")[-1] for s in srcList].index(tile) - return srcList[ind] \ No newline at end of file + return srcList[ind] + + +class ZarrV3Metadata: + + def __init__(self,zgroup=None): + self.zgroup = zgroup + if not self.zgroup is None: + self.attrs = self.zgroup.attrs + else: + self.attrs = None + + def mip_voxel_dims(self,miplvl=0): + dims = self.attrs["multiscales"][0]["datasets"][miplvl]["coordinateTransformations"][0]["scale"][2:] + return numpy.array(dims) + + def get_coordinate_translation(self): + trans = self.attrs["multiscales"][0]["coordinateTransformations"][0]["translation"][2:] + return numpy.array(trans) \ No newline at end of file From c14026736815e734771ff46c6e6a225facf8f9ac Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Fri, 17 Oct 2025 11:22:40 -0700 Subject: [PATCH 6/6] updated stitching code . . . --- .../stitching_modules/acstitch/stitch.py | 39 +++++++++++-------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 7a4b135e..d6999738 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -8,7 +8,7 @@ import argschema #from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences -#from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src +from acpreprocessing.stitching_modules.acstitch.zarrutils import get_zarr_array from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file,save_pointmatch_file @@ -96,40 +96,47 @@ # return p_pts,q_pts -def run_ccorr(p_ds,q_ds,p_pts,q_pts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs): +def run_ccorr(p_ds,q_ds,p_pts,q_pts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8): ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold) return ppm,qpm -def get_dataset_from_path(tilepath): - return +def get_dataset_from_path(tilepath,miplvl=0): + return get_zarr_array(tilepath,miplvl=miplvl) -def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,stitch_kwargs): +def run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method,miplvl=0,stitch_kwargs=None): + stitch_kwargs = ({} if stitch_kwargs is None else stitch_kwargs) if stitch_method == "ccorr": - kwargs = {"p_ds":get_dataset_from_path(p_tilepath), - "q_ds":get_dataset_from_path(q_tilepath), - "p_pts": p_points, - "q_pts": q_points} + kwargs = {"p_ds":get_dataset_from_path(p_tilepath,miplvl), + "q_ds":get_dataset_from_path(q_tilepath,miplvl), + "p_pts": p_points/(2**miplvl), + "q_pts": q_points/(2**miplvl)} p_pts,q_pts = run_ccorr(**kwargs,**stitch_kwargs) else: return run_stitch_method(p_tilepath,q_tilepath,p_points,q_points,stitch_method="ccorr",stitch_kwargs=stitch_kwargs) - return {"p_tile":p_tilepath,"q_tile":q_tilepath,"p_pts":p_pts,"q_pts":q_pts} + return {"p_tile":p_tilepath,"q_tile":q_tilepath,"p_pts":p_pts*(2**miplvl),"q_pts":q_pts*(2**miplvl)} -def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,stitch_kwargs): +def stitch_tiles_from_pmfile(input_file,output_file,stitch_method,miplvl=0,stitch_kwargs=None): in_pms = read_pointmatch_file(input_file) - args = [in_pms.get(key,None) for key in ["p_tile","q_tile","p_pts","q_pts"]] - if not None in args: - out_pms = run_stitch_method(*args,stitch_method=stitch_method,stitch_kwargs=stitch_kwargs) - save_pointmatch_file(out_pms,output_file) + out_pms = [] + for tspec in in_pms: + args = [tspec.get(key) for key in ["p_tile","q_tile","p_pts","q_pts"]] + pmdict = run_stitch_method(*args,stitch_method=stitch_method,miplvl=miplvl,stitch_kwargs=stitch_kwargs) + # else: + # print("WARNING: tspec contains None") + # pmdict = None + out_pms.append(pmdict) + save_pointmatch_file(out_pms,output_file) class StitchTilesParameters(argschema.ArgSchema): input_file = argschema.fields.Str(required=True) output_file = argschema.fields.Str(required=True) stitch_method = argschema.fields.Str(required=True) - stitch_kwargs = argschema.fields.Dict(required=True) + mip_lvl = argschema.fields.Int(required=False,default=0) + stitch_kwargs = argschema.fields.Dict(required=False,default=None) class StitchTiles(argschema.ArgSchemaParser):