Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 64 additions & 36 deletions u19_pipeline/automatic_job/nwb_export_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
the pipeline stages: data validation, NWB conversion, and validation.
"""

import ast
import json
import time
import traceback
from datetime import datetime
from pathlib import Path

import u19_pipeline.automatic_job.params_config as config
import u19_pipeline.utils.slack_utils as slack_utils
from u19_pipeline import nwb_production, recording
from u19_pipeline import acquisition, nwb_production, recording
from u19_pipeline.imaging_pipeline import imaging_element
from u19_pipeline.nwb_production_utils import (
validate_behavior_data_exists,
Expand All @@ -21,6 +23,26 @@
)


def _parse_number_list(raw) -> list:
"""
Parse a probe_numbers / fov_numbers value into a list of ints.

NwbExportModality stores these as a JSON-array string (e.g. "[0, 1, 2]"),
but the value may already be a list/None. Returns an empty list for NULL.
"""
if raw is None:
return []
if isinstance(raw, (list, tuple)):
return list(raw)
try:
return list(json.loads(raw))
except (ValueError, TypeError):
try:
return list(ast.literal_eval(raw))
except (ValueError, SyntaxError, TypeError):
return []


class NwbExportHandler:
"""Handler for NWB export job processing pipeline."""

Expand All @@ -33,12 +55,12 @@ def pipeline_handler_main():
NWB export jobs through their pipeline stages.
"""
# Get active jobs (status < COMPLETED and not FAILED)
active_jobs = (nwb_production.NwbExportJob & "status_nwb_id >= 0 AND status_nwb_id < 3").fetch(as_dict=True)
active_jobs = (nwb_production.NwbExportJob & "status_id >= 0 AND status_id < 3").fetch(as_dict=True)

print(f"Processing {len(active_jobs)} active NWB export jobs...")

for job in active_jobs:
current_status = job["status_nwb_id"]
current_status = job["status_id"]

try:
# Dispatch to appropriate handler based on current status
Expand Down Expand Up @@ -112,37 +134,43 @@ def process_data_validation(job: dict) -> tuple[bool, dict]:
try:
print(f"Validating data for job {job['nwb_job_id']}...")

# Check behavior data
if nwb_production.NwbExportJob.BehaviorExport & {"nwb_job_id": job["nwb_job_id"]}:
session_key = (nwb_production.NwbExportJob.BehaviorExport & {"nwb_job_id": job["nwb_job_id"]}).fetch1(
"KEY"
)

valid, error_msg = validate_behavior_data_exists(session_key)
if not valid:
raise ValueError(f"Behavior validation failed: {error_msg}")

# Check ephys data
if nwb_production.NwbExportJob.EphysExport & {"nwb_job_id": job["nwb_job_id"]}:
ephys_record = (nwb_production.NwbExportJob.EphysExport & {"nwb_job_id": job["nwb_job_id"]}).fetch1()
recording_key = {k: ephys_record[k] for k in recording.Recording.primary_key if k in ephys_record}
probe_numbers = list(ephys_record["probe_numbers"])

valid, error_msg = validate_ephys_data_exists(recording_key, probe_numbers)
if not valid:
raise ValueError(f"Ephys validation failed: {error_msg}")

# Check imaging data
if nwb_production.NwbExportJob.ImagingExport & {"nwb_job_id": job["nwb_job_id"]}:
imaging_record = (
nwb_production.NwbExportJob.ImagingExport & {"nwb_job_id": job["nwb_job_id"]}
).fetch1()
scan_key = {k: imaging_record[k] for k in imaging_element.Scan.primary_key if k in imaging_record}
fov_numbers = list(imaging_record["fov_numbers"])
# The session that this job exports is the acquisition.Session referenced
# by the NwbExportJob primary key. Derive its key from the job record.
session_key = {
k: job[k] for k in acquisition.Session.primary_key if k in job
}

valid, error_msg = validate_imaging_data_exists(scan_key, fov_numbers)
if not valid:
raise ValueError(f"Imaging validation failed: {error_msg}")
# Modalities to export are recorded in NwbExportModality (one row per
# modality_name). Branch on modality_name to run the right validation.
modalities = (
nwb_production.NwbExportModality & {"nwb_job_id": job["nwb_job_id"]}
).fetch(as_dict=True)

for modality in modalities:
modality_name = modality["modality_name"]

if modality_name == "behavior":
valid, error_msg = validate_behavior_data_exists(session_key)
if not valid:
raise ValueError(f"Behavior validation failed: {error_msg}")

elif modality_name == "ephys":
recording_key = {
k: job[k] for k in recording.Recording.primary_key if k in job
}
probe_numbers = _parse_number_list(modality.get("probe_numbers"))
valid, error_msg = validate_ephys_data_exists(recording_key, probe_numbers)
if not valid:
raise ValueError(f"Ephys validation failed: {error_msg}")

elif modality_name == "imaging":
scan_key = {
k: job[k] for k in imaging_element.Scan.primary_key if k in job
}
fov_numbers = _parse_number_list(modality.get("fov_numbers"))
valid, error_msg = validate_imaging_data_exists(scan_key, fov_numbers)
if not valid:
raise ValueError(f"Imaging validation failed: {error_msg}")

print(f"Data validation passed for job {job['nwb_job_id']}")
return True, error_info
Expand Down Expand Up @@ -295,7 +323,7 @@ def update_status_pipeline(job_key: dict, old_status: int, new_status: int, erro
print(f"Updating job {job_key['nwb_job_id']}: status {old_status} -> {new_status}")

# Update job status
(nwb_production.NwbExportJob & job_key).update1({"status_nwb_id": new_status})
(nwb_production.NwbExportJob & job_key).update1({"status_id": new_status})

# Set completion timestamp if completed
if new_status == 3: # COMPLETED
Expand All @@ -304,8 +332,8 @@ def update_status_pipeline(job_key: dict, old_status: int, new_status: int, erro
# Log status change
log_entry = {
**job_key,
"status_nwb_id_old": old_status,
"status_nwb_id_new": new_status,
"status_old": old_status,
"status_new": new_status,
"status_timestamp": datetime.now(),
}

Expand Down
8 changes: 6 additions & 2 deletions u19_pipeline/nwb_production.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,12 @@ def submit_nwb_export_job(

NwbExportJob.insert1(job_record)

# Get auto-generated job ID
job_id = (NwbExportJob & session_key).fetch1("nwb_job_id")
# Get auto-generated job ID. A session can have multiple jobs, so filtering
# by session_key alone is not unique. Filter by the (session_key, job_name)
# we just inserted and take the most recent nwb_job_id deterministically.
job_id = (NwbExportJob & {**session_key, "job_name": job_name}).fetch(
"nwb_job_id", order_by="nwb_job_id DESC", limit=1
)[0]

# Add modality associations
for modality_name, modality_type, numbers in modalities:
Expand Down
60 changes: 42 additions & 18 deletions u19_pipeline/nwb_production_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,38 +69,62 @@ def estimate_imaging_size_gb(scan_key: dict, fov_numbers: list) -> float:
return size_gb


def _parse_number_list(raw) -> list:
"""Parse a probe_numbers / fov_numbers value (JSON-array string) into a list."""
import ast
import json

if raw is None:
return []
if isinstance(raw, (list, tuple)):
return list(raw)
try:
return list(json.loads(raw))
except (ValueError, TypeError):
try:
return list(ast.literal_eval(raw))
except (ValueError, SyntaxError, TypeError):
return []


def estimate_total_size(nwb_job_key: dict) -> float:
"""
Calculate total estimated size for a job.

Queries modality part tables and sums estimates.
Queries the NwbExportModality association table for the job and sums the
per-modality estimates. The session/recording/scan keys are derived from the
NwbExportJob record (which carries the acquisition.Session primary key).

Args:
nwb_job_key: Dictionary with nwb_job_id

Returns:
Total estimated size in GB
"""
from u19_pipeline import nwb_production
from u19_pipeline import acquisition, nwb_production, recording
from u19_pipeline.imaging_pipeline import imaging_element

total_gb = 0.0

# Check behavior
if nwb_production.NwbExportJob.BehaviorExport & nwb_job_key:
session_key = (nwb_production.NwbExportJob.BehaviorExport & nwb_job_key).fetch1("KEY")
total_gb += estimate_behavior_size_gb(session_key)

# Check ephys
if nwb_production.NwbExportJob.EphysExport & nwb_job_key:
recording_key, probe_numbers = (nwb_production.NwbExportJob.EphysExport & nwb_job_key).fetch1(
"KEY", "probe_numbers"
)
total_gb += estimate_ephys_size_gb(recording_key, probe_numbers)

# Check imaging
if nwb_production.NwbExportJob.ImagingExport & nwb_job_key:
scan_key, fov_numbers = (nwb_production.NwbExportJob.ImagingExport & nwb_job_key).fetch1("KEY", "fov_numbers")
total_gb += estimate_imaging_size_gb(scan_key, fov_numbers)
job = (nwb_production.NwbExportJob & nwb_job_key).fetch1()
modalities = (nwb_production.NwbExportModality & nwb_job_key).fetch(as_dict=True)

for modality in modalities:
modality_name = modality["modality_name"]

if modality_name == "behavior":
session_key = {k: job[k] for k in acquisition.Session.primary_key if k in job}
total_gb += estimate_behavior_size_gb(session_key)

elif modality_name == "ephys":
recording_key = {k: job[k] for k in recording.Recording.primary_key if k in job}
probe_numbers = _parse_number_list(modality.get("probe_numbers"))
total_gb += estimate_ephys_size_gb(recording_key, probe_numbers)

elif modality_name == "imaging":
scan_key = {k: job[k] for k in imaging_element.Scan.primary_key if k in job}
fov_numbers = _parse_number_list(modality.get("fov_numbers"))
total_gb += estimate_imaging_size_gb(scan_key, fov_numbers)

return total_gb

Expand Down