diff --git a/beneuro_pose_estimation/cli.py b/beneuro_pose_estimation/cli.py index 3a5880d..8c6d89b 100644 --- a/beneuro_pose_estimation/cli.py +++ b/beneuro_pose_estimation/cli.py @@ -223,15 +223,19 @@ def train( "If not provided, uses default cameras from params.default_cameras." ), ), + custom_labels: bool = typer.Option( + False, + "--custom-labels", "-cl", + help="If set, use custom labels when training the models.", + ), ): """ Train SLEAP models for specified cameras (or all defaults). """ from beneuro_pose_estimation.sleap.sleapTools import train_models - cams = cameras or params.default_cameras - # train_models is the function you already wrote - train_models(cam, custom_labels = custom_labels) + cams = cameras or params.default_cameras + train_models(cameras=cams, custom_labels=custom_labels) # =================================== Updating ==========================================