Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 59 additions & 192 deletions scripts/run_nwb_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@
import argparse
import json
import logging
import os
import sys
import traceback
from datetime import datetime
from pathlib import Path

logging.basicConfig(
Expand All @@ -52,6 +50,7 @@
# Helpers
# ──────────────────────────────────────────────────────────────────────────────


def _connect_dj() -> None:
"""Ensure DataJoint is connected, trying conf_file_finding first."""
try:
Expand All @@ -66,146 +65,23 @@ def _connect_dj() -> None:
dj.conn()


def _find_kilosort_output(probe_dir: Path) -> Path | None:
"""
Return the highest-numbered job's Kilosort output directory under probe_dir.

Expected layout:
<kilosort_dir>/<probe>_imec<N>/job_id_<N>/kilosort<k>_output/
"""
job_dirs = sorted(
[p for p in probe_dir.glob("job_id_*") if p.name.split("_")[-1].isdigit()],
key=lambda p: int(p.name.split("_")[-1]),
)
if not job_dirs:
return None
kilosort_outputs = list(job_dirs[-1].glob("kilosort*_output"))
return kilosort_outputs[0] if kilosort_outputs else None


def _build_source_data(
job: dict,
export_params: dict,
virmen_file: Path | None,
kilosort_dir: Path | None,
) -> dict:
"""
Translate the DataJoint job record + export_params into a
`source_data` dict accepted by TowersNWBConverter.

Returns
-------
dict
Ready to pass as ``source_data=`` to TowersNWBConverter.

Raises
------
FileNotFoundError
If a required data file cannot be located.
"""
source_data: dict = {}

# ── Behavior (always required) ────────────────────────────────────────────
if virmen_file is None:
raise FileNotFoundError(
"No --virmen-file provided. Cannot locate the behavioral .mat file. "
"Re-run with --virmen-file /path/to/session.mat"
)
if not Path(virmen_file).exists():
raise FileNotFoundError(f"Virmen file not found: {virmen_file}")
source_data["VirmenData"] = {"file_path": str(virmen_file)}

# ── Ephys ─────────────────────────────────────────────────────────────────
if export_params.get("include_ephys") and kilosort_dir is not None:
kilosort_dir = Path(kilosort_dir)
probe_dirs = sorted(kilosort_dir.glob("*_imec*"))
if not probe_dirs:
log.warning("--kilosort-dir given but no *_imec* subdirectories found – skipping ephys.")
else:
for probe_dir in probe_dirs:
probe_idx = "".join(filter(str.isdigit, probe_dir.name.split("imec")[-1]))
interface_name = f"KilosortProbe{probe_idx}" if probe_idx else "Kilosort"
ks_output = _find_kilosort_output(probe_dir)
if ks_output is None:
log.warning(f"No Kilosort output found under {probe_dir} – skipping.")
continue
source_data[interface_name] = {"folder_path": str(ks_output)}
log.info(f" {interface_name}: {ks_output}")
elif export_params.get("include_ephys"):
log.warning(
"include_ephys=True but no --kilosort-dir provided. "
"Re-run with --kilosort-dir to include ephys data."
)

return source_data


def _query_metadata(session_key: dict) -> dict:
    """
    Pull experimenter, subject sex/DoB and sync timestamps from DataJoint.

    Both lookups are best-effort: a failure is logged and the corresponding
    defaults are returned instead of raising.

    Parameters
    ----------
    session_key : dict
        Must contain ``"subject_fullname"``; the whole dict is also used to
        restrict the optional ``BehaviorSync`` table.

    Returns
    -------
    dict
        Keys: ``experimenter`` (list of str, possibly empty), ``subject_sex``
        (``"M"``, ``"F"`` or ``"U"``), ``subject_dob`` (``None`` if unknown),
        ``sync_timestamps`` (numpy array, or ``None`` if not found).
    """
    import datajoint as dj

    result: dict = {
        "experimenter": [],
        "subject_sex": "U",
        "subject_dob": None,
        "sync_timestamps": None,
    }

    try:
        prefix = dj.config["custom"]["database.prefix"]
        subject = dj.create_virtual_module("subject", prefix + "subject")
        lab = dj.create_virtual_module("lab", prefix + "lab")

        # Restrict by dict rather than interpolating the name into an SQL
        # string, so names containing quotes cannot break (or alter) the query.
        sub_info = (
            subject.Subject() * lab.User()
            & {"subject_fullname": session_key["subject_fullname"]}
        ).fetch1()

        # Owner name: "First Last" -> "Last, First". A NULL/empty full_name
        # falls back to user_id; nothing is recorded if both are empty.
        owner_full = str(sub_info.get("full_name") or sub_info.get("user_id") or "").strip()
        if " " in owner_full:
            given, family = owner_full.rsplit(" ", 1)
            result["experimenter"].append(f"{family}, {given}")
        elif owner_full:
            result["experimenter"].append(owner_full)

        # Sex: compared case-insensitively so values already stored as
        # "M"/"F" are kept instead of collapsing to "U".
        sex_map = {"male": "M", "m": "M", "female": "F", "f": "F"}
        raw_sex = str(sub_info.get("sex") or "").strip().lower()
        result["subject_sex"] = sex_map.get(raw_sex, "U")

        # Date of birth: date-like values are promoted to a midnight datetime.
        dob = sub_info.get("dob")
        if dob is not None:
            result["subject_dob"] = (
                datetime.combine(dob, datetime.min.time()) if hasattr(dob, "year") else dob
            )

    except Exception as exc:
        log.warning(f"Could not query all metadata from DB: {exc}")

    # Sync timestamps (optional BehaviorSync table)
    try:
        nwb_prod = dj.create_virtual_module(
            "nwb_production", dj.config["custom"]["database.prefix"] + "nwb_production"
        )
        sync_rows = (nwb_prod.BehaviorSync & session_key).fetch("sync_timestamps", as_dict=True)
        if sync_rows:
            import numpy as np

            result["sync_timestamps"] = np.array(sync_rows[0]["sync_timestamps"])
    except Exception as exc:
        # BehaviorSync is optional, so this is not an error; keep a trace for debugging.
        log.debug(f"Sync timestamps not loaded from BehaviorSync: {exc}")

    return result

# Shared conversion logic lives in u19_pipeline.nwb_export.conversion so the CLI
# and the cronjob handler share one code path. Re-exported under the historical
# private names to keep this module's internal references working.
from u19_pipeline.nwb_export.conversion import ( # noqa: E402
run_conversion_to_file as _run_conversion_to_file,
)

# ──────────────────────────────────────────────────────────────────────────────
# Status helpers
# ──────────────────────────────────────────────────────────────────────────────

def _transition(nwb_production, job_key: dict, new_status_id: int, dry_run: bool) -> None:
from u19_pipeline.nwb_production import update_job_status # type: ignore

def _transition(
nwb_production, job_key: dict, new_status_id: int, dry_run: bool
) -> None:
from u19_pipeline.nwb_export_enums import NwbExportStatusEnum # type: ignore
from u19_pipeline.nwb_production import update_job_status # type: ignore

new_status = NwbExportStatusEnum(new_status_id)
log.info(f" → {new_status.name}")
Expand All @@ -214,9 +90,9 @@ def _transition(nwb_production, job_key: dict, new_status_id: int, dry_run: bool


def _fail(nwb_production, job_key: dict, exc: Exception, dry_run: bool) -> None:
from u19_pipeline.nwb_production import update_job_status # type: ignore
from u19_pipeline.nwb_export_enums import NwbExportStatusEnum # type: ignore
from u19_pipeline.nwb_export.error_capture import capture_exception # type: ignore
from u19_pipeline.nwb_export_enums import NwbExportStatusEnum # type: ignore
from u19_pipeline.nwb_production import update_job_status # type: ignore

tb = capture_exception(exc)
log.error(f"Job FAILED: {tb['error_message']}")
Expand All @@ -233,15 +109,15 @@ def _fail(nwb_production, job_key: dict, exc: Exception, dry_run: bool) -> None:
# Main
# ──────────────────────────────────────────────────────────────────────────────


def run(
job_id: int,
virmen_file: Path | None,
kilosort_dir: Path | None,
dry_run: bool,
) -> None:
import datajoint as dj
from u19_pipeline.nwb_export_enums import NwbExportStatusEnum # type: ignore
from u19_pipeline import nwb_production # type: ignore
from u19_pipeline.nwb_export_enums import NwbExportStatusEnum # type: ignore

job_key = {"nwb_job_id": job_id}

Expand Down Expand Up @@ -275,10 +151,13 @@ def run(
except json.JSONDecodeError:
# Streamlit stored it as a Python repr string in early versions
import ast

try:
export_params = ast.literal_eval(raw_params)
except Exception:
log.warning(f"Could not parse export_parameters: {raw_params!r}; proceeding with empty params")
log.warning(
f"Could not parse export_parameters: {raw_params!r}; proceeding with empty params"
)

session_key = {
"subject_fullname": job["subject_fullname"],
Expand All @@ -288,15 +167,21 @@ def run(

log.info(f" export_params: {export_params}")
if dry_run:
log.info("[DRY RUN] Would transition: QUEUED → DATA_VALIDATION → PROCESSING → VALIDATION → COMPLETED")
log.info(
"[DRY RUN] Would transition: QUEUED → DATA_VALIDATION → PROCESSING → VALIDATION → COMPLETED"
)
log.info("[DRY RUN] Exiting without any DB writes or file operations.")
return

# ── 2. DATA_VALIDATION ────────────────────────────────────────────────────
_transition(nwb_production, job_key, int(NwbExportStatusEnum.DATA_VALIDATION), dry_run)
_transition(
nwb_production, job_key, int(NwbExportStatusEnum.DATA_VALIDATION), dry_run
)

try:
from u19_pipeline.nwb_production_utils import validate_behavior_data_exists # type: ignore
from u19_pipeline.nwb_production_utils import (
validate_behavior_data_exists, # type: ignore
)

ok, msg = validate_behavior_data_exists(session_key)
if not ok:
Expand All @@ -310,59 +195,36 @@ def run(
_transition(nwb_production, job_key, int(NwbExportStatusEnum.PROCESSING), dry_run)

try:
from tank_lab_to_nwb.convert_towers_task.towersnwbconverter import TowersNWBConverter # type: ignore
from tank_lab_to_nwb.convert_towers_task.towersnwbconverter import (
TowersNWBConverter, # type: ignore
)
except ImportError:
log.error(
"tank-lab-to-nwb is not importable. Install it with:\n"
" pip install -e /path/to/tank-lab-to-nwb-clean\n"
"or add it to PYTHONPATH."
)
_fail(nwb_production, job_key, ImportError("tank_lab_to_nwb not installed"), dry_run)
_fail(
nwb_production,
job_key,
ImportError("tank_lab_to_nwb not installed"),
dry_run,
)
sys.exit(1)

output_path = job["output_filepath"]
try:
source_data = _build_source_data(job, export_params, virmen_file, kilosort_dir)
log.info(f" source_data keys: {list(source_data.keys())}")

metadata = _query_metadata(session_key)

converter = TowersNWBConverter(
source_data=source_data,
sync_timestamps=metadata["sync_timestamps"],
)

raw_metadata = converter.get_metadata()

# Inject DB-sourced metadata
raw_metadata["NWBFile"]["session_description"] = (
f"U19 pipeline export – {session_key['subject_fullname']} "
f"{session_key['session_date']}"
)
if metadata["experimenter"]:
raw_metadata["NWBFile"]["experimenter"] = metadata["experimenter"]

if "Subject" not in raw_metadata:
raw_metadata["Subject"] = {}
if metadata["subject_sex"]:
raw_metadata["Subject"]["sex"] = metadata["subject_sex"]
if metadata["subject_dob"] is not None:
raw_metadata["Subject"]["date_of_birth"] = metadata["subject_dob"]

output_path = job["output_filepath"]
Path(output_path).parent.mkdir(parents=True, exist_ok=True)

log.info(f" Writing NWB to: {output_path}")
converter.run_conversion(
nwbfile_path=output_path,
metadata=raw_metadata,
overwrite=True,
size_gb = _run_conversion_to_file(
job=job,
export_params=export_params,
session_key=session_key,
virmen_file=virmen_file,
kilosort_dir=kilosort_dir,
output_path=output_path,
)
log.info(" ✓ Conversion complete")

# Record actual file size
size_gb = Path(output_path).stat().st_size / (1024**3)
nwb_production.NwbExportJob.update1({**job_key, "actual_file_size_gb": size_gb})
log.info(f" ✓ File size: {size_gb:.3f} GB")

except Exception as exc:
log.error(traceback.format_exc())
Expand Down Expand Up @@ -392,13 +254,15 @@ def run(
# CLI entry point
# ──────────────────────────────────────────────────────────────────────────────


def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument(
"--job-id", "-j",
"--job-id",
"-j",
type=int,
default=None,
help=(
Expand Down Expand Up @@ -434,19 +298,22 @@ def _parse_args() -> argparse.Namespace:

def _get_pending_job_ids() -> list[int]:
    """Return nwb_job_ids for all jobs not in a terminal state (COMPLETED / FAILED).

    The ids are fetched from ``nwb_production.NwbExportJob`` in ascending
    ``submission_timestamp`` order. The table is queried exactly once.
    """
    # Imported only for its side effect (original note: "ensure FK context").
    from u19_pipeline import acquisition  # noqa: F401
    from u19_pipeline import nwb_production  # type: ignore
    from u19_pipeline.nwb_export_enums import NwbExportStatusEnum

    terminal_ids = [
        int(NwbExportStatusEnum.COMPLETED),
        int(NwbExportStatusEnum.FAILED),
    ]
    # Yields e.g. "status_id != <completed> AND status_id != <failed>".
    restriction = " AND ".join(f"status_id != {s}" for s in terminal_ids)
    pending_jobs = nwb_production.NwbExportJob & restriction
    return pending_jobs.fetch("nwb_job_id", order_by="submission_timestamp ASC").tolist()


Expand All @@ -469,7 +336,7 @@ def _get_pending_job_ids() -> list[int]:
else:
log.info(f"Found {len(job_ids)} pending job(s): {job_ids}")
for job_id in job_ids:
log.info(f"\n{'='*60}\nProcessing job #{job_id}\n{'='*60}")
log.info(f"\n{'=' * 60}\nProcessing job #{job_id}\n{'=' * 60}")
try:
run(
job_id=job_id,
Expand Down
Loading