From 21a630eb1888b2ed13f34b8ed20e3e0fca3048c0 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Fri, 19 Jun 2026 14:59:12 +0200 Subject: [PATCH 01/16] Refactor data_service to separate GetDataSamples and GetMetadata into two functions; and so adapt UI (#213) --- .github/workflows/ci.yml | 19 -- .github/workflows/release.yml | 9 +- pyproject.toml | 3 +- .../docker_in_docker/Dockerfile | 3 +- .../siblings_self_contained/Dockerfile | 3 +- weightslab/proto/experiment_service.proto | 28 ++ weightslab/proto/experiment_service_pb2.py | 220 +++++++------- .../proto/experiment_service_pb2_grpc.py | 47 +++ .../tests/backend/test_compare_dataloaders.py | 2 +- .../tests/gRPC/test_grpc_user_actions.py | 43 ++- .../services/test_trainer_services_unit.py | 11 +- weightslab/trainer/services/data_service.py | 281 ++++++++++++------ .../trainer/services/experiment_service.py | 99 ++++++ weightslab/trainer/trainer_services.py | 4 + 14 files changed, 535 insertions(+), 237 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00c53b8f..0f6ce350 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -223,25 +223,6 @@ jobs: # A per-test timeout guards against any regression that hangs a test. python -m pytest ./weightslab/tests -v --timeout=300 - # TODO (GP): WL CI do not find WS CI for now; token or visibility problem ?? - # - name: Trigger WeightsStudio CI - # env: - # WS_TOKEN: ${{ secrets.WEIGHTS_STUDIO_API_TOKEN }} - # run: | - # if [ -z "${WS_TOKEN}" ]; then - # echo "WEIGHTS_STUDIO_API_TOKEN not set; skipping WeightsStudio trigger." - # exit 0 - # fi - - # # Trigger the ws-ci workflow in the weights_studio repository on main. - # curl -fSs -X POST "https://api.github.com/repos/GrayboxTech/weights_studio/actions/workflows/ws-ci.yml/dispatches" \ - # -H "Authorization: Bearer ${WS_TOKEN}" \ - # -H "Accept: application/vnd.github+json" \ - # -H "Content-Type: application/json" \ - # -d '{"ref":"main"}' - - # echo "WeightsStudio workflow dispatch requested successfully." - build-and-publish-dev: # Only publish to TestPyPI when pushing to main (not on PRs or dev branch pushes). if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d4edd332..2c391367 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,7 +47,11 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install . --extra-index-url https://download.pytorch.org/whl/cpu + # Install the test extra so pytest, graphviz, torchmetrics, + # pytorch-lightning and tensorboard are available (several test modules + # import pytest / use pytest fixtures and cannot run under bare unittest). + python -m pip install '.[utest]' --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install pytest-timeout - name: Run tests run: | @@ -303,7 +307,8 @@ jobs: build-and-publish-main: name: Build & Publish Main (PyPI) - needs: [detect-target] + needs: [detect-target,test] + # needs: [detect-target] runs-on: ubuntu-latest if: ${{ needs.detect-target.outputs.is_main == 'true' }} permissions: diff --git a/pyproject.toml b/pyproject.toml index 962e5b7f..d9d72aad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,8 @@ dependencies = [ # Imaging "Pillow>=10,<12", - "opencv-python>=4.8,<5", + + # Modeling "onnx>=1.15,<=1.20", # Utility used in examples and progress reporting diff --git a/weightslab/examples/Docker_training/docker_in_docker/Dockerfile b/weightslab/examples/Docker_training/docker_in_docker/Dockerfile index 722a0b32..cab3af85 100644 --- a/weightslab/examples/Docker_training/docker_in_docker/Dockerfile +++ b/weightslab/examples/Docker_training/docker_in_docker/Dockerfile @@ -17,10 +17,9 @@ FROM python:3.11-slim # --- System deps ------------------------------------------------------------- # - docker engine (dockerd + CLI + compose plugin + containerd): installed via # the official convenience script. We need the *daemon* here (DinD). -# - libgl1/libglib2.0-0: runtime libs for opencv-python (a weightslab dep). # - curl/ca-certificates/git: fetch the docker installer + optional dev install. RUN apt-get update && apt-get install -y --no-install-recommends \ - curl sudo ca-certificates git libgl1 libglib2.0-0 \ + curl sudo ca-certificates git \ && curl -fsSL https://get.docker.com | sh \ && rm -rf /var/lib/apt/lists/* diff --git a/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile b/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile index 15d43c6b..b66e2162 100644 --- a/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile +++ b/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile @@ -16,9 +16,8 @@ FROM python:3.11-slim # --- System deps ------------------------------------------------------------- # docker CLI + compose plugin ONLY (no daemon — we use the host's daemon). -# libgl1/libglib2.0-0: runtime libs for opencv-python (a weightslab dep). RUN apt-get update && apt-get install -y --no-install-recommends \ - curl sudo ca-certificates gnupg libgl1 libglib2.0-0 \ + curl sudo ca-certificates gnupg \ && install -m 0755 -d /etc/apt/keyrings \ && curl -fsSL https://download.docker.com/linux/debian/gpg \ | gpg --dearmor -o /etc/apt/keyrings/docker.gpg \ diff --git a/weightslab/proto/experiment_service.proto b/weightslab/proto/experiment_service.proto index 73de39d7..fefb5e47 100644 --- a/weightslab/proto/experiment_service.proto +++ b/weightslab/proto/experiment_service.proto @@ -17,6 +17,11 @@ service ExperimentService { // Data Service (for weights_studio UI) rpc ApplyDataQuery (DataQueryRequest) returns (DataQueryResponse); rpc GetDataSamples (DataSamplesRequest) returns (DataSamplesResponse); + // Metadata-only retrieval (dataframe columns). Returns every metadata column + // name for the WHOLE dataset, the current grid slice's per-sample metadata, and + // the open modal sample's metadata. Separated from GetDataSamples, which now + // returns only image / label / prediction data. + rpc GetMetaData (GetMetaDataRequest) returns (GetMetaDataResponse); // Raw point cloud of one sample (task_type "detection_pointcloud"), server-streamed // in binary chunks for the interactive 3D viewer. rpc GetPointCloud (PointCloudRequest) returns (stream PointCloudChunk); @@ -162,6 +167,13 @@ message PlotNoteOperation { string note = 4; } +// Manual "save now" trigger: force a checkpoint of the current model weights +// (and, when requested, the architecture) regardless of pending-change tracking. +message SaveCheckpointOperation { + bool save_architecture = 1; // force re-dump architecture even if a file already exists + bool save_optimizer = 2; // also persist optimizer state +} + message TrainerCommand { bool get_hyper_parameters = 4; bool get_interactive_layers = 5; @@ -174,6 +186,7 @@ message TrainerCommand { optional DenySamplesOperation remove_from_denylist_operation = 11; optional DenySamplesOperation remove_eval_from_denylist_operation = 12; optional PlotNoteOperation plot_note_operation = 13; + optional SaveCheckpointOperation save_checkpoint_operation = 14; } message HyperParameterDesc { @@ -382,6 +395,21 @@ message DataSamplesResponse { repeated DataRecord data_records = 3; } +// --- Metadata retrieval (separated from GetDataSamples) --- +message GetMetaDataRequest { + int32 start_index = 1; // grid slice start (current view order) + int32 records_cnt = 2; // grid slice size + string modal_sample_id = 3; // optional: sample_id of the open modal ("" = none) +} + +message GetMetaDataResponse { + bool success = 1; + string message = 2; + repeated string all_metadata_names = 3; // every metadata column for the WHOLE dataset + repeated DataRecord grid_records = 4; // per-sample metadata for the requested slice + DataRecord modal_record = 5; // metadata for the open modal sample (if found) +} + // --- Point cloud transfer (task_type "detection_pointcloud") --- message PointCloudRequest { string sample_id = 1; diff --git a/weightslab/proto/experiment_service_pb2.py b/weightslab/proto/experiment_service_pb2.py index 9ea162d1..1560bb5b 100644 --- a/weightslab/proto/experiment_service_pb2.py +++ b/weightslab/proto/experiment_service_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"\xdc\x06\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\x84\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"W\n\x12GetMetaDataRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12\x17\n\x0fmodal_sample_id\x18\x03 \x01(\t\"\x99\x01\n\x13GetMetaDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1a\n\x12\x61ll_metadata_names\x18\x03 \x03(\t\x12!\n\x0cgrid_records\x18\x04 \x03(\x0b\x32\x0b.DataRecord\x12!\n\x0cmodal_record\x18\x05 \x01(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xbe\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x38\n\x0bGetMetaData\x12\x13.GetMetaDataRequest\x1a\x14.GetMetaDataResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -37,16 +37,16 @@ _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_options = b'8\001' _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._loaded_options = None _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_options = b'8\001' - _globals['_WEIGHTOPERATIONTYPE']._serialized_start=8812 - _globals['_WEIGHTOPERATIONTYPE']._serialized_end=8912 - _globals['_ZEROFYPREDICATE']._serialized_start=8914 - _globals['_ZEROFYPREDICATE']._serialized_end=9025 - _globals['_AGENTINTENTTYPE']._serialized_start=9027 - _globals['_AGENTINTENTTYPE']._serialized_end=9104 - _globals['_SAMPLEEDITTYPE']._serialized_start=9106 - _globals['_SAMPLEEDITTYPE']._serialized_end=9179 - _globals['_AGENTPROVIDERTYPE']._serialized_start=9181 - _globals['_AGENTPROVIDERTYPE']._serialized_end=9225 + _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9231 + _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9331 + _globals['_ZEROFYPREDICATE']._serialized_start=9333 + _globals['_ZEROFYPREDICATE']._serialized_end=9444 + _globals['_AGENTINTENTTYPE']._serialized_start=9446 + _globals['_AGENTINTENTTYPE']._serialized_end=9523 + _globals['_SAMPLEEDITTYPE']._serialized_start=9525 + _globals['_SAMPLEEDITTYPE']._serialized_end=9598 + _globals['_AGENTPROVIDERTYPE']._serialized_start=9600 + _globals['_AGENTPROVIDERTYPE']._serialized_end=9644 _globals['_GETLATESTLOGGERDATAREQUEST']._serialized_start=46 _globals['_GETLATESTLOGGERDATAREQUEST']._serialized_end=183 _globals['_LOGGERDATAPOINT']._serialized_start=186 @@ -81,100 +81,106 @@ _globals['_LOADCHECKPOINTOPERATION']._serialized_end=2367 _globals['_PLOTNOTEOPERATION']._serialized_start=2369 _globals['_PLOTNOTEOPERATION']._serialized_end=2467 - _globals['_TRAINERCOMMAND']._serialized_start=2470 - _globals['_TRAINERCOMMAND']._serialized_end=3330 - _globals['_HYPERPARAMETERDESC']._serialized_start=3333 - _globals['_HYPERPARAMETERDESC']._serialized_end=3490 - _globals['_NEURONSTATISTICS']._serialized_start=3493 - _globals['_NEURONSTATISTICS']._serialized_end=3863 - _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_start=3722 - _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_end=3771 - _globals['_LAYERREPRESENTATION']._serialized_start=3866 - _globals['_LAYERREPRESENTATION']._serialized_end=4234 - _globals['_ACTIVATIONREQUEST']._serialized_start=4236 - _globals['_ACTIVATIONREQUEST']._serialized_end=4308 - _globals['_ACTIVATIONMAP']._serialized_start=4310 - _globals['_ACTIVATIONMAP']._serialized_end=4382 - _globals['_ACTIVATIONRESPONSE']._serialized_start=4384 - _globals['_ACTIVATIONRESPONSE']._serialized_end=4484 - _globals['_TASKFIELD']._serialized_start=4487 - _globals['_TASKFIELD']._serialized_end=4634 - _globals['_RECORDMETADATA']._serialized_start=4637 - _globals['_RECORDMETADATA']._serialized_end=4969 - _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_start=4916 - _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_end=4969 - _globals['_SAMPLESTATISTICS']._serialized_start=4972 - _globals['_SAMPLESTATISTICS']._serialized_end=5119 - _globals['_COMMANDRESPONSE']._serialized_start=5122 - _globals['_COMMANDRESPONSE']._serialized_end=5352 - _globals['_SAMPLEREQUEST']._serialized_start=5354 - _globals['_SAMPLEREQUEST']._serialized_end=5439 - _globals['_SAMPLEREQUESTRESPONSE']._serialized_start=5442 - _globals['_SAMPLEREQUESTRESPONSE']._serialized_end=5743 - _globals['_BATCHSAMPLEREQUEST']._serialized_start=5746 - _globals['_BATCHSAMPLEREQUEST']._serialized_end=5892 - _globals['_BATCHSAMPLERESPONSE']._serialized_start=5894 - _globals['_BATCHSAMPLERESPONSE']._serialized_end=5956 - _globals['_WEIGHTSREQUEST']._serialized_start=5958 - _globals['_WEIGHTSREQUEST']._serialized_end=6004 - _globals['_WEIGHTSRESPONSE']._serialized_start=6007 - _globals['_WEIGHTSRESPONSE']._serialized_end=6292 - _globals['_DATAQUERYREQUEST']._serialized_start=6294 - _globals['_DATAQUERYREQUEST']._serialized_end=6376 - _globals['_CATEGORICALTAGDEF']._serialized_start=6378 - _globals['_CATEGORICALTAGDEF']._serialized_end=6431 - _globals['_DATAQUERYRESPONSE']._serialized_start=6434 - _globals['_DATAQUERYRESPONSE']._serialized_end=6731 - _globals['_DATASAMPLESREQUEST']._serialized_start=6734 - _globals['_DATASAMPLESREQUEST']._serialized_end=6928 - _globals['_DATASTAT']._serialized_start=6930 - _globals['_DATASTAT']._serialized_end=7039 - _globals['_DATARECORD']._serialized_start=7041 - _globals['_DATARECORD']._serialized_end=7103 - _globals['_DATASAMPLESRESPONSE']._serialized_start=7105 - _globals['_DATASAMPLESRESPONSE']._serialized_end=7195 - _globals['_POINTCLOUDREQUEST']._serialized_start=7197 - _globals['_POINTCLOUDREQUEST']._serialized_end=7271 - _globals['_POINTCLOUDCHUNK']._serialized_start=7274 - _globals['_POINTCLOUDCHUNK']._serialized_end=7465 - _globals['_DATAEDITSREQUEST']._serialized_start=7468 - _globals['_DATAEDITSREQUEST']._serialized_end=7688 - _globals['_DATAEDITSRESPONSE']._serialized_start=7690 - _globals['_DATAEDITSRESPONSE']._serialized_end=7743 - _globals['_DATASPLITSRESPONSE']._serialized_start=7745 - _globals['_DATASPLITSRESPONSE']._serialized_end=7803 - _globals['_AGENTHEALTHRESPONSE']._serialized_start=7805 - _globals['_AGENTHEALTHRESPONSE']._serialized_end=7862 - _globals['_INITIALIZEAGENTREQUEST']._serialized_start=7864 - _globals['_INITIALIZEAGENTREQUEST']._serialized_end=7958 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=7960 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8019 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8021 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8061 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8063 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8123 - _globals['_GETAGENTMODELSREQUEST']._serialized_start=8125 - _globals['_GETAGENTMODELSREQUEST']._serialized_end=8148 - _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8150 - _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8224 - _globals['_RESETAGENTRESPONSE']._serialized_start=8226 - _globals['_RESETAGENTRESPONSE']._serialized_end=8280 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8282 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8333 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8335 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8396 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8398 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8480 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8482 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8543 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8545 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8573 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8576 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=8705 - _globals['_CANCELEVALUATIONREQUEST']._serialized_start=8707 - _globals['_CANCELEVALUATIONREQUEST']._serialized_end=8748 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=8750 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=8810 - _globals['_EXPERIMENTSERVICE']._serialized_start=9228 - _globals['_EXPERIMENTSERVICE']._serialized_end=10512 + _globals['_SAVECHECKPOINTOPERATION']._serialized_start=2469 + _globals['_SAVECHECKPOINTOPERATION']._serialized_end=2545 + _globals['_TRAINERCOMMAND']._serialized_start=2548 + _globals['_TRAINERCOMMAND']._serialized_end=3504 + _globals['_HYPERPARAMETERDESC']._serialized_start=3507 + _globals['_HYPERPARAMETERDESC']._serialized_end=3664 + _globals['_NEURONSTATISTICS']._serialized_start=3667 + _globals['_NEURONSTATISTICS']._serialized_end=4037 + _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_start=3896 + _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_end=3945 + _globals['_LAYERREPRESENTATION']._serialized_start=4040 + _globals['_LAYERREPRESENTATION']._serialized_end=4408 + _globals['_ACTIVATIONREQUEST']._serialized_start=4410 + _globals['_ACTIVATIONREQUEST']._serialized_end=4482 + _globals['_ACTIVATIONMAP']._serialized_start=4484 + _globals['_ACTIVATIONMAP']._serialized_end=4556 + _globals['_ACTIVATIONRESPONSE']._serialized_start=4558 + _globals['_ACTIVATIONRESPONSE']._serialized_end=4658 + _globals['_TASKFIELD']._serialized_start=4661 + _globals['_TASKFIELD']._serialized_end=4808 + _globals['_RECORDMETADATA']._serialized_start=4811 + _globals['_RECORDMETADATA']._serialized_end=5143 + _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_start=5090 + _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_end=5143 + _globals['_SAMPLESTATISTICS']._serialized_start=5146 + _globals['_SAMPLESTATISTICS']._serialized_end=5293 + _globals['_COMMANDRESPONSE']._serialized_start=5296 + _globals['_COMMANDRESPONSE']._serialized_end=5526 + _globals['_SAMPLEREQUEST']._serialized_start=5528 + _globals['_SAMPLEREQUEST']._serialized_end=5613 + _globals['_SAMPLEREQUESTRESPONSE']._serialized_start=5616 + _globals['_SAMPLEREQUESTRESPONSE']._serialized_end=5917 + _globals['_BATCHSAMPLEREQUEST']._serialized_start=5920 + _globals['_BATCHSAMPLEREQUEST']._serialized_end=6066 + _globals['_BATCHSAMPLERESPONSE']._serialized_start=6068 + _globals['_BATCHSAMPLERESPONSE']._serialized_end=6130 + _globals['_WEIGHTSREQUEST']._serialized_start=6132 + _globals['_WEIGHTSREQUEST']._serialized_end=6178 + _globals['_WEIGHTSRESPONSE']._serialized_start=6181 + _globals['_WEIGHTSRESPONSE']._serialized_end=6466 + _globals['_DATAQUERYREQUEST']._serialized_start=6468 + _globals['_DATAQUERYREQUEST']._serialized_end=6550 + _globals['_CATEGORICALTAGDEF']._serialized_start=6552 + _globals['_CATEGORICALTAGDEF']._serialized_end=6605 + _globals['_DATAQUERYRESPONSE']._serialized_start=6608 + _globals['_DATAQUERYRESPONSE']._serialized_end=6905 + _globals['_DATASAMPLESREQUEST']._serialized_start=6908 + _globals['_DATASAMPLESREQUEST']._serialized_end=7102 + _globals['_DATASTAT']._serialized_start=7104 + _globals['_DATASTAT']._serialized_end=7213 + _globals['_DATARECORD']._serialized_start=7215 + _globals['_DATARECORD']._serialized_end=7277 + _globals['_DATASAMPLESRESPONSE']._serialized_start=7279 + _globals['_DATASAMPLESRESPONSE']._serialized_end=7369 + _globals['_GETMETADATAREQUEST']._serialized_start=7371 + _globals['_GETMETADATAREQUEST']._serialized_end=7458 + _globals['_GETMETADATARESPONSE']._serialized_start=7461 + _globals['_GETMETADATARESPONSE']._serialized_end=7614 + _globals['_POINTCLOUDREQUEST']._serialized_start=7616 + _globals['_POINTCLOUDREQUEST']._serialized_end=7690 + _globals['_POINTCLOUDCHUNK']._serialized_start=7693 + _globals['_POINTCLOUDCHUNK']._serialized_end=7884 + _globals['_DATAEDITSREQUEST']._serialized_start=7887 + _globals['_DATAEDITSREQUEST']._serialized_end=8107 + _globals['_DATAEDITSRESPONSE']._serialized_start=8109 + _globals['_DATAEDITSRESPONSE']._serialized_end=8162 + _globals['_DATASPLITSRESPONSE']._serialized_start=8164 + _globals['_DATASPLITSRESPONSE']._serialized_end=8222 + _globals['_AGENTHEALTHRESPONSE']._serialized_start=8224 + _globals['_AGENTHEALTHRESPONSE']._serialized_end=8281 + _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8283 + _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8377 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8379 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8438 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8440 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8480 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8482 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8542 + _globals['_GETAGENTMODELSREQUEST']._serialized_start=8544 + _globals['_GETAGENTMODELSREQUEST']._serialized_end=8567 + _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8569 + _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8643 + _globals['_RESETAGENTRESPONSE']._serialized_start=8645 + _globals['_RESETAGENTRESPONSE']._serialized_end=8699 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8701 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8752 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8754 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8815 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8817 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8899 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8901 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8962 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8964 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8992 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8995 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=9124 + _globals['_CANCELEVALUATIONREQUEST']._serialized_start=9126 + _globals['_CANCELEVALUATIONREQUEST']._serialized_end=9167 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=9169 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=9229 + _globals['_EXPERIMENTSERVICE']._serialized_start=9647 + _globals['_EXPERIMENTSERVICE']._serialized_end=10989 # @@protoc_insertion_point(module_scope) diff --git a/weightslab/proto/experiment_service_pb2_grpc.py b/weightslab/proto/experiment_service_pb2_grpc.py index 004c4d93..51f19e87 100644 --- a/weightslab/proto/experiment_service_pb2_grpc.py +++ b/weightslab/proto/experiment_service_pb2_grpc.py @@ -74,6 +74,11 @@ def __init__(self, channel): request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesRequest.SerializeToString, response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesResponse.FromString, _registered_method=True) + self.GetMetaData = channel.unary_unary( + '/ExperimentService/GetMetaData', + request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, + response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.FromString, + _registered_method=True) self.GetPointCloud = channel.unary_stream( '/ExperimentService/GetPointCloud', request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.PointCloudRequest.SerializeToString, @@ -188,6 +193,16 @@ def GetDataSamples(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetMetaData(self, request, context): + """Metadata-only retrieval (dataframe columns). Returns every metadata column + name for the WHOLE dataset, the current grid slice's per-sample metadata, and + the open modal sample's metadata. Separated from GetDataSamples, which now + returns only image / label / prediction data. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def GetPointCloud(self, request, context): """Raw point cloud of one sample (task_type "detection_pointcloud"), server-streamed in binary chunks for the interactive 3D viewer. @@ -307,6 +322,11 @@ def add_ExperimentServiceServicer_to_server(servicer, server): request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesRequest.FromString, response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesResponse.SerializeToString, ), + 'GetMetaData': grpc.unary_unary_rpc_method_handler( + servicer.GetMetaData, + request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.FromString, + response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.SerializeToString, + ), 'GetPointCloud': grpc.unary_stream_rpc_method_handler( servicer.GetPointCloud, request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.PointCloudRequest.FromString, @@ -594,6 +614,33 @@ def GetDataSamples(request, metadata, _registered_method=True) + @staticmethod + def GetMetaData(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/ExperimentService/GetMetaData', + weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, + weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def GetPointCloud(request, target, diff --git a/weightslab/tests/backend/test_compare_dataloaders.py b/weightslab/tests/backend/test_compare_dataloaders.py index 0b4d137b..245b413c 100644 --- a/weightslab/tests/backend/test_compare_dataloaders.py +++ b/weightslab/tests/backend/test_compare_dataloaders.py @@ -8,7 +8,7 @@ import unittest # On Windows, DataLoader workers use spawn: each worker re-imports the heavy -# weightslab package (torch + cv2 + onnx + langchain + cert/banner setup), so a +# weightslab package (torch + onnx + langchain + cert/banner setup), so a # multi-worker loader takes far longer than any sane test timeout. These tests # are meaningful on Linux/CI (cheap fork workers); skip the num_workers>0 cases # on Windows. Single-worker correctness still runs everywhere. diff --git a/weightslab/tests/gRPC/test_grpc_user_actions.py b/weightslab/tests/gRPC/test_grpc_user_actions.py index 05bfc639..f9b655c2 100644 --- a/weightslab/tests/gRPC/test_grpc_user_actions.py +++ b/weightslab/tests/gRPC/test_grpc_user_actions.py @@ -285,6 +285,9 @@ def _make_real_data_service(self): ds._compute_natural_sort = False ds._is_filtered = False ds._last_internals_update_time = 0 + # Thread pool used by GetDataSamples' per-sample path (GetMetaData uses the + # vectorized path and doesn't need it, but GetDataSamples does). + ds._data_executor = ThreadPoolExecutor(max_workers=2) ds._agent = MagicMock() ds._agent.is_ollama_available.return_value = True ds.audit_logger = MagicMock() @@ -426,9 +429,9 @@ def test_grpc_apply_data_query_direct_filter_reduces_view(self): kept = [idx[1] for idx in data_service._all_datasets_df.index.tolist()] self.assertEqual(sorted(kept), ["2", "3"]) - def test_grpc_get_data_samples_returns_scalar_stats(self): - """GetDataSamples must stream per-sample records with requested scalar stats - (no raw images needed for this path).""" + def test_grpc_get_data_samples_excludes_metadata(self): + """GetDataSamples returns image / label / prediction data only — metadata + columns (e.g. 'loss') are now served exclusively by GetMetaData.""" data_service, _ = self._make_real_data_service() servicer = self._make_servicer_with_real_data_service(data_service) @@ -444,10 +447,42 @@ def test_grpc_get_data_samples_returns_scalar_stats(self): self.assertEqual(len(response.data_records), 3) returned_ids = {r.sample_id for r in response.data_records} self.assertEqual(returned_ids, {"1", "2", "3"}) - # The requested 'loss' stat is present on each record. for rec in response.data_records: + names = {s.name for s in rec.data_stats} + # The 'loss' metadata stat must NOT leak through GetDataSamples anymore. + self.assertNotIn("loss", names) + # Rendering flags (origin/task_type/discarded) still travel with image data. + self.assertIn("origin", names) + self.assertIn("discarded", names) + + def test_grpc_get_metadata_returns_names_records_and_modal(self): + """GetMetaData returns whole-dataset column names, per-sample grid metadata + for the slice, and the open modal sample's metadata.""" + data_service, _ = self._make_real_data_service() + servicer = self._make_servicer_with_real_data_service(data_service) + + request = pb2.GetMetaDataRequest( + start_index=0, + records_cnt=10, + modal_sample_id="2", + ) + response = servicer.GetMetaData(request, _MockContext()) + + self.assertTrue(response.success) + # All metadata column names for the whole dataset include 'loss'. + self.assertIn("loss", list(response.all_metadata_names)) + # Grid records cover the slice and carry the 'loss' metadata stat. + self.assertEqual(len(response.grid_records), 3) + returned_ids = {r.sample_id for r in response.grid_records} + self.assertEqual(returned_ids, {"1", "2", "3"}) + for rec in response.grid_records: names = {s.name for s in rec.data_stats} self.assertIn("loss", names) + # Modal record resolves the requested sample_id with its metadata. + self.assertTrue(response.HasField("modal_record")) + self.assertEqual(response.modal_record.sample_id, "2") + modal_names = {s.name for s in response.modal_record.data_stats} + self.assertIn("loss", modal_names) class TestGRPCLoggerOutputIntegration(_TimeoutMixin, unittest.TestCase): diff --git a/weightslab/tests/trainer/services/test_trainer_services_unit.py b/weightslab/tests/trainer/services/test_trainer_services_unit.py index d1454521..40914b1d 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_unit.py +++ b/weightslab/tests/trainer/services/test_trainer_services_unit.py @@ -365,15 +365,10 @@ def test_metadata_only_response_uses_dataframe_columns(self): "quality": [0.1, 0.9], } ) - request = type( - "Req", - (), - { - "stats_to_retrieve": ["quality"], - }, - )() - response = service._build_metadata_only_response(df_slice, request) + # _build_metadata_only_response now takes an explicit requested_cols list + # (it is the building block reused by the GetMetaData RPC). + response = service._build_metadata_only_response(df_slice, ["quality"]) self.assertTrue(response.success) self.assertEqual(len(response.data_records), 2) diff --git a/weightslab/trainer/services/data_service.py b/weightslab/trainer/services/data_service.py index b12495c3..13ba5a17 100755 --- a/weightslab/trainer/services/data_service.py +++ b/weightslab/trainer/services/data_service.py @@ -901,26 +901,23 @@ def _compute_natural_sort_stats(self): logger.info(f"[DataService] Starting natural sort stats computation with weights: {SORT_WEIGHTS}") - try: - import cv2 - except ImportError: - logger.warning("[DataService] OpenCV not found. Skipping natural sort computation.") - return "OpenCV not installed" - if self._all_datasets_df is None or self._all_datasets_df.empty: return "No data to process" - # Helper: Calculate Shannon Entropy (Complexity) + # Helper: Calculate Shannon Entropy (Complexity). + # 256-bin histogram of the 8-bit grayscale image via numpy (no OpenCV). def calc_entropy(img_gray): try: - # Calculate histogram (256 bins for 8-bit) - hist = cv2.calcHist([img_gray], [0], None, [256], [0, 256]) - # Normalize histogram to get probabilities - p = hist.ravel() / hist.sum() - # Filter out zero probabilities to avoid log(0) + gray_u8 = np.clip(np.rint(img_gray), 0, 255).astype(np.uint8) + counts = np.bincount(gray_u8.ravel(), minlength=256).astype(np.float64) + total = counts.sum() + if total <= 0: + return 0.0 + # Normalize to probabilities, dropping zeros to avoid log(0) + p = counts / total p = p[p > 0] # Shannon Entropy in bits - return -np.sum(p * np.log2(p)) + return float(-np.sum(p * np.log2(p))) except Exception: return 0.0 @@ -961,27 +958,28 @@ def process_sample(args): # Convert to numpy (RGB) img_np = np.array(pil_img) - # Brightness (mean pixel intensity) - # If RGB, convert to Gray, else just mean - if img_np.ndim == 3: - # OpenCV expects BGR usually, but PIL gives RGB. - # cvtColor RGB2GRAY is correct. - try: - gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) - except Exception: - gray = img_np + # Brightness (mean pixel intensity). For color images, reduce to + # luma using the ITU-R 601-2 transform — the same weights PIL's + # "L" mode uses — with plain numpy. + if img_np.ndim == 3 and img_np.shape[2] >= 3: + luma = np.array([0.299, 0.587, 0.114], dtype=np.float32) + gray = img_np[..., :3].astype(np.float32) @ luma + elif img_np.ndim == 3: + gray = img_np[..., 0].astype(np.float32) else: - gray = img_np + gray = img_np.astype(np.float32) - brightness = np.mean(gray) + brightness = float(np.mean(gray)) entropy = calc_entropy(gray) - # HSV Stats - if img_np.ndim == 3: + # HSV stats (hue/saturation). Computed with Pillow's "HSV" mode + # (H, S, V each in 0-255) to avoid an OpenCV dependency. + if img_np.ndim == 3 and img_np.shape[2] >= 3: try: - hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV) - hue = np.mean(hsv[:, :, 0]) - saturation = np.mean(hsv[:, :, 1]) + rgb_img = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB") + hsv = np.asarray(rgb_img.convert("HSV")) + hue = float(np.mean(hsv[:, :, 0])) + saturation = float(np.mean(hsv[:, :, 1])) except Exception: hue = 0.0 saturation = 0.0 @@ -997,8 +995,8 @@ def process_sample(args): # Entropy: 0-8 bits typical for 8-bit image norm_entropy = min(max(entropy / 8.0, 0.0), 1.0) - # Hue: 0-179 in OpenCV - norm_hue = min(max(hue / 179.0, 0.0), 1.0) + # Hue: 0-255 in Pillow's HSV space + norm_hue = min(max(hue / 255.0, 0.0), 1.0) score = ( SORT_WEIGHTS.get("brightness", 0) * norm_brightness + @@ -1121,7 +1119,7 @@ def _process_sample_row(self, args): try: origin = row.get(SampleStatsEx.ORIGIN.value, 'unknown') sample_id = row.get(SampleStatsEx.SAMPLE_ID.value, 0) - # logger.debug(f"Processing sample_id={sample_id} from origin={origin} with request: {request}") + logger.debug(f"Processing sample_id={sample_id} from origin={origin} with request: {request}") # ===== Timing accumulators ===== t_image_load = 0.0 @@ -1204,58 +1202,14 @@ def _process_sample_row(self, args): except Exception: pass - # ====== Step 5a: Process stats ====== - stats_to_retrieve = list(request.stats_to_retrieve) - - # These columns are handled explicitly later in the pipeline - exclude_cols = { - SampleStatsEx.SAMPLE_ID.value, - SampleStatsEx.ORIGIN.value, - SampleStatsEx.TARGET.value if not skip_label_for_request else None, - SampleStatsEx.PREDICTION.value, - SampleStatsEx.TASK_TYPE.value, - '_instance_signals', # Special handling for multi-instance signals - 'annotation_id', # Internal multi-index tracking - } - - if not stats_to_retrieve: - stats_to_retrieve = [col for col in df_columns if col not in exclude_cols] - - # Optimized bulk processing of stats - for stat_name in stats_to_retrieve: - # Never re-process core fields generically (prevents duplicates/bad db state overwriting calculated state) - if stat_name in exclude_cols: - continue - - value = row.get(stat_name) - - # Skip prediction raw array - if (isinstance(value, np.ndarray) and value.ndim > 1) or (isinstance(value, (list, tuple, np.ndarray)) and len(value) == 0): - continue - elif isinstance(value, float): - value = round(value, 7) - elif isinstance(value, bool): - value = int(value) - - # Check if it s a tag column here and handle it as a string stat with the tag name as value - value_string = str(value) - if stat_name.startswith(f"{SampleStatsEx.TAG.value}"): - tag_name = stat_name[len(f"{SampleStatsEx.TAG.value}:"):] # Remove "tags_" prefix to get tag name - data_stats.append( - create_data_stat( - f"{SampleStatsEx.TAG.value}:{tag_name}", - "string", - shape=[1], - value_string=value_string, - thumbnail=b"" - ) - ) - else: - data_stats.append( - create_data_stat(stat_name, "string", shape=[1], value_string=value_string[:512], thumbnail=b"") - ) + # ====== Step 5a: Metadata stats — moved to GetMetaData ====== + # Generic dataframe metadata columns (signals, tags, custom fields, etc.) + # are no longer returned by GetDataSamples; the dedicated GetMetaData RPC + # serves them. GetDataSamples returns only the rendering flags + # origin / task_type / discarded (needed for the split border, overlay + # mode and gray-out) plus image / label / prediction data below. - # ====== Step 6: Add origin and task_type stats ====== + # ====== Step 6: Add origin, task_type and discarded rendering flags ====== data_stats.append( create_data_stat( "origin", 'string', shape=[1], value_string=origin, thumbnail=b"" @@ -1266,6 +1220,18 @@ def _process_sample_row(self, args): "task_type", 'string', shape=[1], value_string=str(task_type), thumbnail=b"" ) ) + # 'discarded' drives the grayed-out cell rendering, so it rides with the + # image data as "1"/"0" (not treated as analytical metadata). This keeps + # the gray-out reliable on every grid (re)fetch / scroll. + try: + _discarded_str = "1" if bool(row.get(SampleStatsEx.DISCARDED.value)) else "0" + except Exception: + _discarded_str = "0" + data_stats.append( + create_data_stat( + SampleStatsEx.DISCARDED.value, 'string', shape=[1], value_string=_discarded_str, thumbnail=b"" + ) + ) target_mask_stat_index = None pred_mask_stat_index = None @@ -2716,12 +2682,13 @@ def _is_metadata_only_request(self, request) -> bool: except Exception: return False - def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): - """Build DataSamplesResponse from dataframe columns only (no dataset/image traversal). + def _build_metadata_only_response(self, df_slice: pd.DataFrame, requested_cols=None): + """Build a DataSamplesResponse of metadata DataRecords from dataframe columns only. - This is a single-job vectorized path: the entire df_slice is processed - at once using pandas operations rather than dispatching per-sample_id - work to the thread pool. + No dataset/image traversal: the entire df_slice is processed at once using + vectorized pandas operations rather than dispatching per-sample_id work to + the thread pool. ``requested_cols`` restricts the columns; when None/empty + all columns are returned except heavy per-sample blobs. Used by GetMetaData. """ if df_slice is None or df_slice.empty: return pb2.DataSamplesResponse( @@ -2730,7 +2697,7 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): data_records=[], ) - requested_cols = list(getattr(request, 'stats_to_retrieve', []) or []) + requested_cols = list(requested_cols or []) # NOTE: ORIGIN is intentionally NOT excluded. The histogram (and any caller # that needs per-sample split coloring) requests 'origin' explicitly and # relies on this fast vectorized path to return it — without this, the client @@ -2792,6 +2759,11 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): series = df_slice[col] if series.dtype.kind == 'f': str_vals = series.round(7).astype(str).str[:512].tolist() + elif series.dtype.kind == 'b': + # Booleans (e.g. 'discarded') → "1"/"0" so the UI's boolean/discarded + # handling — which expects the legacy per-sample "1"/"0" form — keeps + # working now that metadata is served exclusively by GetMetaData. + str_vals = series.astype(int).astype(str).tolist() else: str_vals = series.astype(str).str[:512].tolist() # NaN → None @@ -2841,6 +2813,132 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): data_records=data_records, ) + def _get_all_metadata_column_names(self) -> list: + """Return every metadata column name available across the WHOLE dataset. + + Excludes heavy per-sample blob columns (pred/target) and internal + bookkeeping columns, matching the column set _build_metadata_only_response + emits. Order follows the dataframe columns (de-duplicated) so the UI column + picker stays stable across refreshes. + """ + try: + df = self._all_datasets_df + if df is None or df.empty: + return [] + _HEAVY_BLOB_COLS = {"pred", "prediction", "prediction_raw", "target"} + _INTERNAL_COLS = { + SampleStatsEx.SAMPLE_ID.value, + SampleStatsEx.TASK_TYPE.value, + "annotation_id", + "_instance_signals", + } + seen, names = set(), [] + for col in df.columns: + name = str(col) + if col in _HEAVY_BLOB_COLS or col in _INTERNAL_COLS: + continue + if name in seen: + continue + seen.add(name) + names.append(name) + # Include index-level names too (e.g. 'origin' in the (origin, sample_id) + # multi-index), excluding internal levels like sample_id/annotation_id. + if isinstance(df.index, pd.MultiIndex): + index_names = [n for n in (df.index.names or []) if n] + elif df.index.name: + index_names = [df.index.name] + else: + index_names = [] + for n in index_names: + name = str(n) + if n in _HEAVY_BLOB_COLS or n in _INTERNAL_COLS or name in seen: + continue + seen.add(name) + names.append(name) + return names + except Exception as e: + logger.warning("Error enumerating metadata column names: %s", e) + return [] + + def GetMetaData(self, request, context): + """Metadata-only retrieval, separated from GetDataSamples. + + Returns: + - all_metadata_names: every metadata column for the WHOLE dataset + - grid_records: per-sample metadata for the requested grid slice + - modal_record: metadata for the open modal sample (by sample_id), if any + """ + try: + # Read the current view directly (kept fresh by the same mechanisms + # GetDataSamples relies on); no forced refresh on the 15s metadata poll. + all_names = self._get_all_metadata_column_names() + df = self._all_datasets_df + + if df is None or df.empty: + return pb2.GetMetaDataResponse( + success=False, + message="Internal dataframe is empty or not initialized.", + all_metadata_names=all_names, + grid_records=[], + ) + + # ---- Grid slice metadata (current view order) ---- + grid_records = [] + start = max(0, int(getattr(request, "start_index", 0))) + count = int(getattr(request, "records_cnt", 0)) + if count > 0: + try: + df_slice = safe_reset_index(df.iloc[start:start + count]) + except IndexError: + df_slice = None + if df_slice is not None and not df_slice.empty: + df_slice, _ = self._merge_multi_instance_signals(df_slice) + grid_resp = self._build_metadata_only_response(df_slice) + if grid_resp.success: + grid_records = list(grid_resp.data_records) + + # ---- Modal sample metadata (optional, by sample_id) ---- + modal_record = None + modal_id = str(getattr(request, "modal_sample_id", "") or "").strip() + if modal_id: + try: + sid_col = SampleStatsEx.SAMPLE_ID.value + matches = None + if sid_col in df.columns: + matches = df[df[sid_col].astype(str) == modal_id] + elif isinstance(df.index, pd.MultiIndex) and sid_col in (df.index.names or []): + # sample_id is a multi-index level (origin, sample_id). + level_vals = df.index.get_level_values(sid_col).astype(str) + matches = df[level_vals == modal_id] + else: + matches = df[df.index.astype(str) == modal_id] + if matches is not None and not matches.empty: + modal_df = safe_reset_index(matches.iloc[[0]]) + modal_df, _ = self._merge_multi_instance_signals(modal_df) + modal_resp = self._build_metadata_only_response(modal_df) + if modal_resp.success and modal_resp.data_records: + modal_record = modal_resp.data_records[0] + except Exception as e: + logger.warning("GetMetaData modal lookup failed for %s: %s", modal_id, e) + + resp = pb2.GetMetaDataResponse( + success=True, + message=f"Retrieved {len(grid_records)} metadata records, {len(all_names)} columns", + all_metadata_names=all_names, + grid_records=grid_records, + ) + if modal_record is not None: + resp.modal_record.CopyFrom(modal_record) + return resp + except Exception as e: + logger.error("Error in GetMetaData: %s", str(e), exc_info=True) + return pb2.GetMetaDataResponse( + success=False, + message=f"Failed to retrieve metadata: {str(e)}", + all_metadata_names=[], + grid_records=[], + ) + def _merge_multi_instance_signals(self, df_slice): """Merge per-instance signals into dictionaries for multi-index dataframes. @@ -2988,8 +3086,9 @@ def _process_get_data_samples(self, request, context): logger.debug( "Retrieving samples from %s to %s", request.start_index, request.start_index + request.records_cnt) - if self._is_metadata_only_request(request): - return self._build_metadata_only_response(df_slice, request) + # NOTE: metadata-only requests are no longer served here. GetDataSamples + # returns image / label / prediction data only; metadata columns are + # served by the dedicated GetMetaData RPC. # ---- Preview-cache tolerant path ------------------------------- # For preview-tier requests, serve what is available from cache diff --git a/weightslab/trainer/services/experiment_service.py b/weightslab/trainer/services/experiment_service.py index 436997b5..50398dd4 100644 --- a/weightslab/trainer/services/experiment_service.py +++ b/weightslab/trainer/services/experiment_service.py @@ -664,6 +664,105 @@ def ExperimentCommand(self, request, context): ), ) + if request.HasField("save_checkpoint_operation"): + op = request.save_checkpoint_operation + + checkpoint_manager = components.get("checkpoint_manager") if isinstance(components, dict) else None + if checkpoint_manager is None: + try: + checkpoint_manager = ledgers.get_checkpoint_manager() + except Exception: + checkpoint_manager = None + if checkpoint_manager is None: + return pb2.CommandResponse(success=False, message="Checkpoint manager not initialized") + + # Resolve the live model. The ledger may hand back a proxy, so unwrap it + # the same way checkpoint_manager._save_changes does before dumping. + model = components.get("model") if isinstance(components, dict) else None + if model is None: + try: + model = ledgers.get_model() + except Exception: + model = None + if model is not None and callable(getattr(model, "get", None)): + try: + inner = model.get() + if inner is not None: + model = inner + except Exception: + pass + if model is None: + return pb2.CommandResponse(success=False, message="No model available to checkpoint") + + if getattr(checkpoint_manager, "current_exp_hash", None) is None: + return pb2.CommandResponse( + success=False, + message="No experiment hash set yet; run at least one training step before saving.", + ) + + # Snapshot model state under the training lock so we never persist a + # half-applied optimizer step (mirrors the load-checkpoint hygiene). + if not try_acquire_rlock(): + logger.error( + "[ExperimentCommand] save_checkpoint: weightslab_rlock timed out after %.0fs", + _GRPC_LOCK_TIMEOUT_S, + ) + return pb2.CommandResponse( + success=False, + message=f"Training busy: lock not acquired within {_GRPC_LOCK_TIMEOUT_S:.0f}s. Try again.", + ) + + ckpt_path = None + try: + # save_model_architecture is a no-op when the .pkl already exists, + # so delete it first when a forced architecture re-dump is requested. + if op.save_architecture: + try: + h = checkpoint_manager.current_exp_hash[8:-8] + arch_file = checkpoint_manager.models_dir / h / f"{h}_architecture.pkl" + if arch_file.exists(): + arch_file.unlink() + except Exception as e: + logger.warning(f"Could not remove existing architecture file for forced re-dump: {e}") + checkpoint_manager.save_model_architecture(model) + + ckpt_path = checkpoint_manager.save_model_checkpoint( + model=model, + save_optimizer=bool(op.save_optimizer), + save_model_checkpoint=True, + force_dump_pending=True, + ) + except Exception as e: + logger.error(f"Error during manual checkpoint save: {e}") + self._log_audit( + "checkpoint_save", + "failed", + {"save_architecture": bool(op.save_architecture), "save_optimizer": bool(op.save_optimizer)}, + error=str(e), + ) + return pb2.CommandResponse(success=False, message=f"Failed to save checkpoint: {e}") + finally: + weightslab_rlock.release() + + self._log_audit( + "checkpoint_save", + "success" if ckpt_path is not None else "failed", + { + "experiment_hash": checkpoint_manager.current_exp_hash, + "save_architecture": bool(op.save_architecture), + "save_optimizer": bool(op.save_optimizer), + "checkpoint_file": str(ckpt_path) if ckpt_path else None, + }, + ) + return pb2.CommandResponse( + success=ckpt_path is not None, + message=( + f"Saved model weights{' and architecture' if op.save_architecture else ''} for {checkpoint_manager.current_exp_hash}" + if ckpt_path is not None + else "Checkpoint save produced no weights file (weight dumping may be disabled in config)." + ), + ) + if request.HasField("hyper_parameter_change"): hyper_parameters = request.hyper_parameter_change.hyper_parameters diff --git a/weightslab/trainer/trainer_services.py b/weightslab/trainer/trainer_services.py index dfe1ad42..de0f1ece 100644 --- a/weightslab/trainer/trainer_services.py +++ b/weightslab/trainer/trainer_services.py @@ -334,6 +334,10 @@ def GetDataSamples(self, request, context): logger.debug(f"\nExperimentServiceServicer.GetDataSamples({request})") return self._exp_service.data_service.GetDataSamples(request, context) + def GetMetaData(self, request, context): + logger.debug(f"\nExperimentServiceServicer.GetMetaData({request})") + return self._exp_service.data_service.GetMetaData(request, context) + def GetPointCloud(self, request, context): logger.debug(f"\nExperimentServiceServicer.GetPointCloud({request})") # Server-streaming RPC: delegate the generator directly. From 9c4b79870465e4b73c5884756630ad8140122da2 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Fri, 19 Jun 2026 15:00:21 +0200 Subject: [PATCH 02/16] Refactor value proxy (#212) * Fix release note generation from PR commits only * Add __getitem__ to ValueProxy for parameters wrap --- .github/workflows/release.yml | 59 +++++++++++++++++++++--- weightslab/backend/ledgers.py | 17 +++++++ weightslab/tests/backend/test_ledgers.py | 23 +++++++++ 3 files changed, 93 insertions(+), 6 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2c391367..ae6155af 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,10 +47,11 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - # Install the test extra so pytest, graphviz, torchmetrics, + +# Install the test extra so pytest, graphviz, torchmetrics, # pytorch-lightning and tensorboard are available (several test modules # import pytest / use pytest fixtures and cannot run under bare unittest). - python -m pip install '.[utest]' --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install .[utest] --extra-index-url https://download.pytorch.org/whl/cpu python -m pip install pytest-timeout - name: Run tests @@ -230,7 +231,7 @@ jobs: export PRS_JSON PRS_JSON=$(gh pr list --state merged --base "dev" --limit 100 \ - --json number,title,mergedAt,url,author,body \ + --json number,title,mergedAt,url,author,body,commits \ 2>/dev/null || echo "[]") python3 << 'PYEOF' @@ -253,6 +254,29 @@ jobs: line += f"\n\n > {desc}" return line filtered = [pr for pr in prs_data if pr.get("mergedAt", "") > prev_date] + # "What's Changed precisely": list each merged PR's own developer commits, + # NOT the squashed/merge/chore commits that land on the release branch + # (those are redundant with the PR list above and not useful). + _seen_commits = set() + _commit_blocks = [] + for pr in filtered: + _commit_lines = [] + for c in (pr.get("commits") or []): + oid = (c.get("oid") or "")[:7] + head = (c.get("messageHeadline") or "").strip() + if not head or head.startswith("Merge "): + continue + key = (oid, head) + if key in _seen_commits: + continue + _seen_commits.add(key) + _commit_lines.append(f" - `{oid}` {head}") + if _commit_lines: + _commit_blocks.append( + f"**[#{pr['number']}]({pr['url']}) {pr['title']}**\n" + "\n".join(_commit_lines) + ) + # Fall back to the raw git log only when no PR commits are available. + commits_section = "\n\n".join(_commit_blocks) if _commit_blocks else commits if filtered: prs_lines = "\n".join(_pr_entry(pr) for pr in filtered) seen, clabels = set(), [] @@ -280,7 +304,7 @@ jobs: "Happy Training!\n\n" "---\n\n" "### What's Changed precisely:\n\n" - f"{commits}\n\n" + f"{commits_section}\n\n" "---\n\n" "### Thank you!\n\n" f"{contributors}\n" @@ -463,7 +487,7 @@ jobs: export PRS_JSON PRS_JSON=$(gh pr list --state merged --base "main" --limit 100 \ - --json number,title,mergedAt,url,author,body \ + --json number,title,mergedAt,url,author,body,commits \ 2>/dev/null || echo "[]") python3 << 'PYEOF' @@ -486,6 +510,29 @@ jobs: line += f"\n\n > {desc}" return line filtered = [pr for pr in prs_data if pr.get("mergedAt", "") > prev_date] + # "What's Changed precisely": list each merged PR's own developer commits, + # NOT the squashed/merge/chore commits that land on the release branch + # (those are redundant with the PR list above and not useful). + _seen_commits = set() + _commit_blocks = [] + for pr in filtered: + _commit_lines = [] + for c in (pr.get("commits") or []): + oid = (c.get("oid") or "")[:7] + head = (c.get("messageHeadline") or "").strip() + if not head or head.startswith("Merge "): + continue + key = (oid, head) + if key in _seen_commits: + continue + _seen_commits.add(key) + _commit_lines.append(f" - `{oid}` {head}") + if _commit_lines: + _commit_blocks.append( + f"**[#{pr['number']}]({pr['url']}) {pr['title']}**\n" + "\n".join(_commit_lines) + ) + # Fall back to the raw git log only when no PR commits are available. + commits_section = "\n\n".join(_commit_blocks) if _commit_blocks else commits if filtered: prs_lines = "\n".join(_pr_entry(pr) for pr in filtered) seen, clabels = set(), [] @@ -513,7 +560,7 @@ jobs: "Happy Training!\n\n" "---\n\n" "### What's Changed precisely:\n\n" - f"{commits}\n\n" + f"{commits_section}\n\n" "---\n\n" "### Thank you!\n\n" f"{contributors}\n" diff --git a/weightslab/backend/ledgers.py b/weightslab/backend/ledgers.py index eb188dde..a1b4e1ee 100644 --- a/weightslab/backend/ledgers.py +++ b/weightslab/backend/ledgers.py @@ -231,6 +231,23 @@ def __contains__(self, item: Any) -> bool: except TypeError: return False + def __getitem__(self, key: Any) -> Any: + """Support subscript access: ``proxy[key]``. + + Equivalent to ``.get(key)`` — reaches into the resolved value with + ``[]`` and wraps nested dicts in a fresh _ValueProxy so chaining + (e.g. ``proxy['dataset']['batch_size']``) keeps resolving live. + Missing keys raise ``KeyError``, matching standard subscript access. + """ + key = self._unwrap(key) + v = self._resolve() + if v is None: + raise KeyError(key) + value = v[key] + if isinstance(value, dict): + return Proxy._ValueProxy(Proxy(v), key) + return value + def __int__(self) -> int: return int(self._resolve()) diff --git a/weightslab/tests/backend/test_ledgers.py b/weightslab/tests/backend/test_ledgers.py index 07f679d7..232dbb3e 100644 --- a/weightslab/tests/backend/test_ledgers.py +++ b/weightslab/tests/backend/test_ledgers.py @@ -197,6 +197,29 @@ def test_proxy_get_key_default_mode_returns_live_proxy(self): hp_handle["lr"] = 0.02 self.assertEqual(lr.get(), 0.02) + def test_value_proxy_subscript_access(self): + """ValueProxy subscript [key] is equivalent to .get(key) and chains.""" + hp_handle = GLOBAL_LEDGER.get_hyperparams() + GLOBAL_LEDGER.register_hyperparams( + params={"dataset": {"batch_size": 32, "splits": {"train": 0.8}}} + ) + + dataset = hp_handle.get("dataset") + # [key] matches .get(key) for the resolved mapping. + self.assertEqual(dataset["batch_size"], dataset.get("batch_size")) + self.assertEqual(dataset["batch_size"], 32) + + # Nested dicts are wrapped in a live proxy so chaining keeps resolving. + self.assertEqual(dataset["splits"]["train"], 0.8) + + # Reads stay fresh against the underlying mapping. + hp_handle["dataset"] = {"batch_size": 64, "splits": {"train": 0.9}} + self.assertEqual(dataset["batch_size"], 64) + + # Missing keys raise KeyError, matching standard subscript semantics. + with self.assertRaises(KeyError): + dataset["missing"] + def test_proxy_pickles_and_restores(self): proxy = Proxy({"flag": True, "count": 3}) From 19944266be9cdff037b11e3419d4745413ec2b51 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Fri, 19 Jun 2026 15:00:39 +0200 Subject: [PATCH 03/16] docs: document BB_THUMB_RENDER / BB_MODAL_RENDER env vars [force ci] (#211) Document the Weights Studio per-image bounding-box render caps (GT and PRED capped independently) in the configuration env-var reference and the weights_studio deployment .env examples. Co-authored-by: Claude Opus 4.8 (1M context) --- docs/configuration.rst | 42 +++++++++++++++++++++++++++++++++++++++++ docs/weights_studio.rst | 8 +++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/docs/configuration.rst b/docs/configuration.rst index fd23ee41..de866df6 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -547,3 +547,45 @@ These variables are injected into the browser bundle at build / dev time. * - ``VITE_WS_MODAL_CACHE_MAX_MB`` - ``64`` - Maximum memory (MB) for the full-resolution modal image cache. + + +Bounding-box render limits +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Detection samples can carry many bounding boxes per image (dense scenes, +high-recall predictions). Drawing them all slows rendering and turns the +overlay into noise, so the number of boxes drawn per image is capped. The cap +is applied **separately** to ground-truth (GT) and predictions (PRED) — a value +of ``10`` allows up to 10 GT boxes *and* 10 PRED boxes per image. Boxes beyond +the cap are simply not drawn (predictions are typically score-ordered, so the +most confident ones are kept). + +These are set on the Weights Studio frontend container (for example in +``../weights_studio/docker/docker-compose.yml``) and injected into the page at +startup by the nginx entrypoint — changing them needs no rebuild, just a +container restart. For a local ``vite`` dev server, use the ``VITE_`` fallbacks +shown below. Values are clamped to a hard ceiling of ``10000``. + +.. list-table:: + :header-rows: 1 + :widths: 30 12 58 + + * - Variable + - Default + - Description + * - ``BB_THUMB_RENDER`` + - ``10`` + - Maximum bounding boxes drawn per image in the grid **thumbnails**, per + overlay (up to N ground-truth and N predictions). Dev-server fallback: + ``VITE_BB_THUMB_RENDER``. + * - ``BB_MODAL_RENDER`` + - ``100`` + - Maximum bounding boxes drawn per image in the **modal** detail view, per + overlay (up to N ground-truth and N predictions). A ``?`` button in the + top-right of the modal image surfaces the active limit on hover. + Dev-server fallback: ``VITE_BB_MODAL_RENDER``. + +.. note:: + + These caps only affect *rendering* — no sample data is dropped. They apply to + detection bounding-box overlays; segmentation masks are unaffected. diff --git a/docs/weights_studio.rst b/docs/weights_studio.rst index 79cba1e0..1d6036dc 100644 --- a/docs/weights_studio.rst +++ b/docs/weights_studio.rst @@ -71,6 +71,8 @@ Default values in ``../weights_studio/docker/.env``: - ``VITE_PORT=5173`` - ``VITE_HISTOGRAM_MAX_BINS=512`` +- ``BB_THUMB_RENDER=10`` (max bounding boxes drawn per thumbnail image, per overlay) +- ``BB_MODAL_RENDER=100`` (max bounding boxes drawn per modal image, per overlay) - ``WS_SERVER_HOST=localhost`` - ``WS_SERVER_PORT=8080`` - ``WS_SERVER_PROTOCOL=https`` @@ -313,6 +315,8 @@ consistently with your deployed endpoints: WS_SERVER_HOST=studio.your-domain.com WS_SERVER_PORT=443 VITE_HISTOGRAM_MAX_BINS=512 + BB_THUMB_RENDER=10 + BB_MODAL_RENDER=100 # envoy / backend internal wiring ENVOY_PORT=8080 @@ -369,7 +373,9 @@ Use this pattern for a simple single-VM production-like deployment. WS_SERVER_PROTOCOL=https WS_SERVER_HOST=studio.your-domain.com WS_SERVER_PORT=443 - VITE_HISTOGRAM_MAX_BINS=512 + VITE_HISTOGRAM_MAX_BINS=512 + BB_THUMB_RENDER=10 + BB_MODAL_RENDER=100 ENVOY_PORT=8080 ENVOY_ADMIN_PORT=9901 From 0be0e94f6a09a7ae2f98560a922f27b5a70e652e Mon Sep 17 00:00:00 2001 From: Guillaume Date: Fri, 19 Jun 2026 15:02:25 +0200 Subject: [PATCH 04/16] Manual weights dump (#210) * Add save btn and fix UI text issue of the current version * add details of export fct to readme --- README.md | 23 +++ weightslab/__init__.py | 2 + weightslab/proto/experiment_service.proto | 5 + weightslab/proto/experiment_service_pb2.py | 59 ++++++++ .../services/test_trainer_services_unit.py | 101 +++++++++++++ .../trainer/services/experiment_service.py | 133 ++++++++++++++++++ 6 files changed, 323 insertions(+) diff --git a/README.md b/README.md index 3481c80e..c9c0890a 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,29 @@ def main(): total_loss += loss.item() + # Write the history of these samples every x steps + if model.get_age() % 100 == 0: + print(f'Dump signals history and dataframe at age {model.get_age()}') + wl.write_history( + # path=None, # Use root_log_dir by default, filename generated from parameters md5 hash + type_of_history="all", + graph_name=[ + 'train/clsf_instance', + 'val/clsf_instance' + ], + # experiment_hash=None, Default is 'last', i.e., current experiment hash + sample_id=['11', '29', '28', '27', '22'], + instance_id=[1, 2, 3] + ) + + # Dump the sample dataframe: all signals plus the loss_shape categorical tag, + wl.write_dataframe( + columns=["signals", "tag:loss_shape"], + format='csv' + # sample_id=['0', '28'] + # instance_id=[1, 2], + ) + avg_loss = total_loss / len(dataloader) print(f"Epoch {epoch+1}/5 - Loss: {avg_loss:.4f}") diff --git a/weightslab/__init__.py b/weightslab/__init__.py index 07262ba3..af668cc1 100644 --- a/weightslab/__init__.py +++ b/weightslab/__init__.py @@ -102,8 +102,10 @@ "query_signal_history", "query_sample_history", "query_instance_history", + "write_history", "write_dataframe", + "pointcloud_thumbnail", "pointcloud_boxes", diff --git a/weightslab/proto/experiment_service.proto b/weightslab/proto/experiment_service.proto index fefb5e47..7c4cffd5 100644 --- a/weightslab/proto/experiment_service.proto +++ b/weightslab/proto/experiment_service.proto @@ -160,6 +160,11 @@ message LoadCheckpointOperation { int32 checkpoint_id = 1; } +message SaveCheckpointOperation { + bool save_architecture = 1; // force re-dump the model architecture even if a file already exists + bool save_optimizer = 2; // also persist optimizer state alongside the weights +} + message PlotNoteOperation { string metric_name = 1; string experiment_hash = 2; diff --git a/weightslab/proto/experiment_service_pb2.py b/weightslab/proto/experiment_service_pb2.py index 1560bb5b..ce6f6b14 100644 --- a/weightslab/proto/experiment_service_pb2.py +++ b/weightslab/proto/experiment_service_pb2.py @@ -24,6 +24,7 @@ +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\x84\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"W\n\x12GetMetaDataRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12\x17\n\x0fmodal_sample_id\x18\x03 \x01(\t\"\x99\x01\n\x13GetMetaDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1a\n\x12\x61ll_metadata_names\x18\x03 \x03(\t\x12!\n\x0cgrid_records\x18\x04 \x03(\x0b\x32\x0b.DataRecord\x12!\n\x0cmodal_record\x18\x05 \x01(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xbe\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x38\n\x0bGetMetaData\x12\x13.GetMetaDataRequest\x1a\x14.GetMetaDataResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') _globals = globals() @@ -37,6 +38,16 @@ _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_options = b'8\001' _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._loaded_options = None _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_options = b'8\001' + _globals['_WEIGHTOPERATIONTYPE']._serialized_start=8986 + _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9086 + _globals['_ZEROFYPREDICATE']._serialized_start=9088 + _globals['_ZEROFYPREDICATE']._serialized_end=9199 + _globals['_AGENTINTENTTYPE']._serialized_start=9201 + _globals['_AGENTINTENTTYPE']._serialized_end=9278 + _globals['_SAMPLEEDITTYPE']._serialized_start=9280 + _globals['_SAMPLEEDITTYPE']._serialized_end=9353 + _globals['_AGENTPROVIDERTYPE']._serialized_start=9355 + _globals['_AGENTPROVIDERTYPE']._serialized_end=9399 _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9231 _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9331 _globals['_ZEROFYPREDICATE']._serialized_start=9333 @@ -79,6 +90,10 @@ _globals['_DENYSAMPLESOPERATION']._serialized_end=2317 _globals['_LOADCHECKPOINTOPERATION']._serialized_start=2319 _globals['_LOADCHECKPOINTOPERATION']._serialized_end=2367 + _globals['_SAVECHECKPOINTOPERATION']._serialized_start=2369 + _globals['_SAVECHECKPOINTOPERATION']._serialized_end=2445 + _globals['_PLOTNOTEOPERATION']._serialized_start=2447 + _globals['_PLOTNOTEOPERATION']._serialized_end=2545 _globals['_PLOTNOTEOPERATION']._serialized_start=2369 _globals['_PLOTNOTEOPERATION']._serialized_end=2467 _globals['_SAVECHECKPOINTOPERATION']._serialized_start=2469 @@ -135,6 +150,50 @@ _globals['_DATARECORD']._serialized_end=7277 _globals['_DATASAMPLESRESPONSE']._serialized_start=7279 _globals['_DATASAMPLESRESPONSE']._serialized_end=7369 + _globals['_POINTCLOUDREQUEST']._serialized_start=7371 + _globals['_POINTCLOUDREQUEST']._serialized_end=7445 + _globals['_POINTCLOUDCHUNK']._serialized_start=7448 + _globals['_POINTCLOUDCHUNK']._serialized_end=7639 + _globals['_DATAEDITSREQUEST']._serialized_start=7642 + _globals['_DATAEDITSREQUEST']._serialized_end=7862 + _globals['_DATAEDITSRESPONSE']._serialized_start=7864 + _globals['_DATAEDITSRESPONSE']._serialized_end=7917 + _globals['_DATASPLITSRESPONSE']._serialized_start=7919 + _globals['_DATASPLITSRESPONSE']._serialized_end=7977 + _globals['_AGENTHEALTHRESPONSE']._serialized_start=7979 + _globals['_AGENTHEALTHRESPONSE']._serialized_end=8036 + _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8038 + _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8132 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8134 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8193 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8195 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8235 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8237 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8297 + _globals['_GETAGENTMODELSREQUEST']._serialized_start=8299 + _globals['_GETAGENTMODELSREQUEST']._serialized_end=8322 + _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8324 + _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8398 + _globals['_RESETAGENTRESPONSE']._serialized_start=8400 + _globals['_RESETAGENTRESPONSE']._serialized_end=8454 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8456 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8507 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8509 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8570 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8572 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8654 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8656 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8717 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8719 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8747 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8750 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=8879 + _globals['_CANCELEVALUATIONREQUEST']._serialized_start=8881 + _globals['_CANCELEVALUATIONREQUEST']._serialized_end=8922 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=8924 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=8984 + _globals['_EXPERIMENTSERVICE']._serialized_start=9402 + _globals['_EXPERIMENTSERVICE']._serialized_end=10686 _globals['_GETMETADATAREQUEST']._serialized_start=7371 _globals['_GETMETADATAREQUEST']._serialized_end=7458 _globals['_GETMETADATARESPONSE']._serialized_start=7461 diff --git a/weightslab/tests/trainer/services/test_trainer_services_unit.py b/weightslab/tests/trainer/services/test_trainer_services_unit.py index 40914b1d..01c965c9 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_unit.py +++ b/weightslab/tests/trainer/services/test_trainer_services_unit.py @@ -148,6 +148,107 @@ def test_restore_checkpoint_weights_step_mode(self): _, kwargs = checkpoint_manager.load_state.call_args self.assertEqual(kwargs.get("target_step"), 5) + def _make_save_service(self, components): + ctx = _DummyCtx(components=components) + with patch("weightslab.trainer.services.experiment_service.DataService"): + return ExperimentService(ctx) + + def test_save_checkpoint_with_optimizer_and_architecture(self): + """Manual save with both optimizer + architecture: pauses, then dumps all three.""" + trainer = MagicMock() + model = MagicMock() + checkpoint_manager = MagicMock() + checkpoint_manager.save_model_checkpoint.return_value = "/tmp/exp/weights_step_000010.pt" + checkpoint_manager.save_model_architecture.return_value = "/tmp/exp/arch.pkl" + hp = {"is_training": True} + + service = self._make_save_service({ + "trainer": trainer, + "model": model, + "checkpoint_manager": checkpoint_manager, + "hyperparams": hp, + }) + + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation( + save_architecture=True, save_optimizer=True + ) + ) + response = service.ExperimentCommand(request, None) + + self.assertTrue(response.success) + # Training is paused BEFORE the dump, and is_training is cleared. + trainer.pause.assert_called_once() + self.assertFalse(hp["is_training"]) + # Weights dumped with optimizer; architecture dumped too. + checkpoint_manager.save_model_checkpoint.assert_called_once() + _, kwargs = checkpoint_manager.save_model_checkpoint.call_args + self.assertTrue(kwargs.get("save_optimizer")) + self.assertIs(kwargs.get("model"), model) + checkpoint_manager.save_model_architecture.assert_called_once_with(model) + self.assertIn("optimizer", response.message) + self.assertIn("architecture", response.message) + + def test_save_checkpoint_weights_only(self): + """Manual save without optimizer/architecture: only weights are dumped.""" + trainer = MagicMock() + model = MagicMock() + checkpoint_manager = MagicMock() + checkpoint_manager.save_model_checkpoint.return_value = "/tmp/exp/weights_step_000010.pt" + hp = {"is_training": True} + + service = self._make_save_service({ + "trainer": trainer, + "model": model, + "checkpoint_manager": checkpoint_manager, + "hyperparams": hp, + }) + + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation( + save_architecture=False, save_optimizer=False + ) + ) + response = service.ExperimentCommand(request, None) + + self.assertTrue(response.success) + trainer.pause.assert_called_once() + checkpoint_manager.save_model_checkpoint.assert_called_once() + _, kwargs = checkpoint_manager.save_model_checkpoint.call_args + self.assertFalse(kwargs.get("save_optimizer")) + # No architecture dump requested. + checkpoint_manager.save_model_architecture.assert_not_called() + self.assertNotIn("optimizer", response.message) + self.assertNotIn("architecture", response.message) + + def test_save_checkpoint_no_model_registered(self): + """No registered model: fails clearly and does NOT pause or dump.""" + trainer = MagicMock() + checkpoint_manager = MagicMock() + + service = self._make_save_service({ + "trainer": trainer, + "model": None, + "checkpoint_manager": checkpoint_manager, + "hyperparams": {}, + }) + + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation( + save_architecture=True, save_optimizer=True + ) + ) + # The ledger fallback must also report no model. + with patch("weightslab.trainer.services.experiment_service.ledgers.get_model", return_value=None): + response = service.ExperimentCommand(request, None) + + self.assertFalse(response.success) + self.assertIn("No model registered", response.message) + # Nothing was dumped, and a running experiment is left untouched. + trainer.pause.assert_not_called() + checkpoint_manager.save_model_checkpoint.assert_not_called() + checkpoint_manager.save_model_architecture.assert_not_called() + class TestModelServiceUnit(unittest.TestCase): def test_get_weights_success(self): diff --git a/weightslab/trainer/services/experiment_service.py b/weightslab/trainer/services/experiment_service.py index 50398dd4..4d32323b 100644 --- a/weightslab/trainer/services/experiment_service.py +++ b/weightslab/trainer/services/experiment_service.py @@ -596,6 +596,134 @@ def _get_live_hyper_parameter_descs(self, components): ) return hyper_parameter_descs + def _handle_save_checkpoint(self, save_op, components, context): + """Pause training and force-dump the current model weights. + + Implements the manual "Save weights" button. Training is paused *first* + (so the weights are captured at a consistent point, between training + steps) and only then are the latest weights written to a checkpoint — + optionally with the optimizer state and/or a fresh architecture dump, + per the ``SaveCheckpointOperation`` flags. + + If no model is registered there is nothing to dump: we return a clear + failure *without* disrupting the run (training is left untouched). + + Returns a ``pb2.CommandResponse``. + """ + save_optimizer = bool(getattr(save_op, "save_optimizer", False)) + save_architecture = bool(getattr(save_op, "save_architecture", False)) + audit_details = { + "save_optimizer": save_optimizer, + "save_architecture": save_architecture, + } + + # 1) Resolve the model first. Without a registered model there is nothing + # to dump — fail early so we don't needlessly pause a running experiment. + model = components.get("model") if components else None + if model is None: + model = ledgers.get_model() + if model is None: + msg = ( + "No model registered — nothing to dump. Register one with " + "watch_or_edit(model, flag='model')." + ) + logger.warning("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=msg) + return pb2.CommandResponse(success=False, message=msg) + + # 2) Resolve the checkpoint manager (component cache, then ledger fallback). + checkpoint_manager = components.get("checkpoint_manager") if components else None + if checkpoint_manager is None: + checkpoint_manager = ledgers.get_checkpoint_manager() + if checkpoint_manager is None: + msg = "Checkpoint manager not initialized" + logger.warning("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=msg) + return pb2.CommandResponse(success=False, message=msg) + + # 3) Pause training before dumping. Acquire the global lock so any + # in-flight training step has finished; clearing is_training keeps the + # loop parked at its next pause point, so the subsequent save reads a + # consistent model state. + if not try_acquire_rlock(): + logger.error( + "[SaveCheckpoint] weightslab_rlock timed out after %.0fs", + _GRPC_LOCK_TIMEOUT_S, + ) + if context is not None: + context.abort( + grpc.StatusCode.RESOURCE_EXHAUSTED, + f"Training lock not acquired within {_GRPC_LOCK_TIMEOUT_S:.0f}s", + ) + return pb2.CommandResponse(success=False, message="Lock timeout") + try: + trainer = components.get("trainer") if components else None + if trainer is not None: + logger.info("[SaveCheckpoint] Pausing training before weights dump...") + trainer.pause() + hp = components.get("hyperparams") if components else None + if hp is not None: + try: + hp["is_training"] = False + except Exception: + logger.debug("[SaveCheckpoint] Could not set is_training=False", exc_info=True) + finally: + weightslab_rlock.release() + + # 4) Ensure an experiment hash exists so save_model_checkpoint has a + # target directory (a brand-new experiment may not have one yet). + try: + if getattr(checkpoint_manager, "current_exp_hash", None) is None: + if hasattr(checkpoint_manager, "get_current_experiment_hash"): + checkpoint_manager.get_current_experiment_hash() + if ( + getattr(checkpoint_manager, "current_exp_hash", None) is None + and hasattr(checkpoint_manager, "update_experiment_hash") + ): + checkpoint_manager.update_experiment_hash(first_time=True) + except Exception: + logger.debug("[SaveCheckpoint] Could not ensure experiment hash", exc_info=True) + + # 5) Dump the weights (force_dump_pending flushes any pending changes). + try: + checkpoint_path = checkpoint_manager.save_model_checkpoint( + model=model, + save_optimizer=save_optimizer, + force_dump_pending=True, + ) + except Exception as e: + msg = f"Failed to save model weights: {e}" + logger.error("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=str(e)) + return pb2.CommandResponse(success=False, message=msg) + + if checkpoint_path is None: + msg = "Failed to save model weights (no checkpoint produced)." + logger.warning("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=msg) + return pb2.CommandResponse(success=False, message=msg) + + # 6) Optionally dump the architecture as well. + arch_saved = False + if save_architecture: + try: + arch_path = checkpoint_manager.save_model_architecture(model) + arch_saved = arch_path is not None + except Exception: + logger.warning("[SaveCheckpoint] Could not save model architecture", exc_info=True) + + saved = ["weights"] + if save_optimizer: + saved.append("optimizer") + if arch_saved: + saved.append("architecture") + msg = f"Saved {', '.join(saved)} (training paused)." + logger.info("[SaveCheckpoint] %s -> %s", msg, checkpoint_path) + audit_details["path"] = str(checkpoint_path) + audit_details["architecture_saved"] = arch_saved + self._log_audit("checkpoint_save", "success", audit_details) + return pb2.CommandResponse(success=True, message=msg) + # Training & hyperparameter commands # ------------------------------------------------------------------------- def ExperimentCommand(self, request, context): @@ -603,6 +731,11 @@ def ExperimentCommand(self, request, context): components = self._ctx.components # Write requests + if request.HasField("save_checkpoint_operation"): + return self._handle_save_checkpoint( + request.save_checkpoint_operation, components, context + ) + if request.HasField("plot_note_operation"): note_op = request.plot_note_operation metric_name = str(note_op.metric_name or "") From 28ed7eb90ff0717d872a4e90c7bc21adac3f4e3e Mon Sep 17 00:00:00 2001 From: Guillaume Date: Fri, 19 Jun 2026 15:02:39 +0200 Subject: [PATCH 05/16] DuckDB integratiion (#209) --- CLAUDE.md | 144 ++ pyproject.toml | 1 + weightslab/backend/logger.py | 1312 +++++++++-------- weightslab/src.py | 10 +- .../backend/test_instance_signal_logger.py | 65 +- weightslab/tests/backend/test_logger_core.py | 116 +- .../tests/gRPC/test_grpc_user_actions.py | 18 +- .../services/test_trainer_services_unit.py | 19 +- 8 files changed, 951 insertions(+), 734 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..9a1a56ba --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,144 @@ +# WeightsLab Workspace — Project Knowledge + +Self-contained knowledge file for the WeightsLab workspace. One file, organized by topic so you can jump to the section you need. **Jump to:** + +| If you need… | Go to section | +|---|---| +| The 3 repos and how they relate | [1. Workspace layout](#1-workspace-layout) | +| How backend & frontend talk at runtime | [2. Runtime integration](#2-runtime-integration) | +| Where backend Python code lives | [3. weightslab backend module map](#3-weightslab-backend-module-map) | +| Where frontend TS code / tests live | [4. weights_studio frontend module map](#4-weights_studio-frontend-module-map) | +| How a user training script plugs in | [5. Integration API (the usecase pattern)](#5-integration-api-the-usecase-pattern) | +| Testing rules, data/H5/tags features | [6. Topic notes](#6-topic-notes) | + +> Paths/line claims are point-in-time — verify against current code before asserting as fact. + +--- + +## 1. Workspace layout + +Three sibling repos under `c:\Users\GuillaumePELLUET\Documents\Codes\`: + +- **weightslab** (`weightslab/`) — Python backend/core. ML training, data processing, gRPC API, the published `pip install weightslab` package. Python pkg root: `weightslab/weightslab/`. +- **weights_studio** (`weights_studio/`) — TypeScript/Vite web UI. Consumes the backend over grpc-web. **All Playwright/E2E user-simulation tests live here, not in weightslab.** +- **weightslab_kitchen** (`weightslab_kitchen/`) — private examples/reference, minimal docs. + +They must be checked out **side-by-side**: weights_studio's proto codegen reads into the weightslab directory (see §2). + +--- + +## 2. Runtime integration + +### Shared contract — one proto, two sides +- Source of truth: `weightslab/weightslab/proto/experiment_service.proto` defines `service ExperimentService` (~20 RPCs). +- **Backend** implements it in `weightslab/weightslab/trainer/services/experiment_service.py` (the gRPC servicer), delegating to `model_service.py`, `data_service.py`, `agent_service.py`. +- **Frontend** consumes it via generated client `weights_studio/src/experiment_service.client.ts` (+ `experiment_service.ts`), produced by `npm run generate-proto:data` → + `protoc --ts_out src/ --proto_path ../weightslab/weightslab/proto experiment_service.proto`. + +### Wire path (browser → training process) +``` +weights_studio (browser, GrpcWebFetchTransport, src/main.ts) + → http(s) :8080 Envoy proxy (grpc-web ↔ grpc transcoding) + → cluster grpc-backend :__GRPC_BACKEND_PORT__ (Python gRPC servicer) + → in-process training loop (watched model/optimizer/data/loss) +``` +Browsers can't speak raw gRPC, so Envoy translates. Frontend default server port **8080** (Envoy listener); admin **9901**. Backend gRPC port is templated (`__GRPC_BACKEND_PORT__`, substituted at deploy). `main.ts` supports path-based deploys (`//api`, `/demo//api`) and loopback/TLS host normalization. + +### RPC groups +- **Training control:** `ExperimentCommand` (pause/resume/…), `GetLatestLoggerData` (metric streaming). +- **Weights/arch:** `ManipulateWeights`, `GetWeights`, `GetActivations`. +- **Data:** `GetSamples`, `ApplyDataQuery`, `GetDataSamples`, `EditDataSample`, `GetDataSplits`. +- **Agent (LLM):** `CheckAgentHealth`, `InitializeAgent`, `ChangeAgentModel`, `GetAgentModels`, `ResetAgent`. +- **Checkpoint/eval:** `RestoreCheckpoint`, `TriggerEvaluation`, `GetEvaluationStatus`, `CancelEvaluation`. + +### Deployment +- `weightslab ui launch` → `weightslab/weightslab/ui_docker_bridge.py` brings up the bundled Docker stack (`weightslab/weightslab/ui/docker/docker-compose.yml` + `envoy.yaml`) with TLS via `weightslab/security/CertAuthManager`. This is how the published package serves the studio UI. +- weights_studio also ships its own dev/prod Docker + Envoy under `weights_studio/docker/` and `weights_studio/envoy/`. `npm run dev` = Vite on :5173. + +### Proto-change checklist (keep all three in sync) +1. Edit `.proto`. 2. Regenerate backend `*_pb2*.py`. 3. Run `npm run generate-proto:data` in weights_studio. + +--- + +## 3. weightslab backend module map + +Package root `weightslab/weightslab/`. Public API re-exported from `__init__.py` (← `src.py`). Used as `import weightslab as wl`. + +Layers (top depends on lower): +- **`src.py`** — facade implementing public verbs: `watch_or_edit`, `serve`, `keep_serving`, `save_signals`, `save_instance_signals`, `tag_samples`, `register_categorical_tag`, `discard_samples`, `query_signal_history` / `query_sample_history` / `query_instance_history`, `get_current_experiment_hash`, etc. +- **`trainer/`** — orchestration. `trainer_services.py`, `trainer_tools.py`, `experiment_context.py`. + - `services/experiment_service.py` — the gRPC servicer implementing `ExperimentService`. + - `services/{model_service,data_service,agent_service}.py` — per-domain delegates. + - `services/agent/` — LLM agent (configured by repo-root `agent_config.yaml`: `ollama` local / `openrouter` remote). + - `services/instance_merger.py` — multi-instance (detection/seg) handling. +- **`components/`** — cross-cutting runtime machinery. + - `global_monitoring.py` — `guard_training_context` / `guard_testing_context`, pause controller, the global rlock used by the servicer (training + serving run in one process, different threads). + - `evaluation_controller.py` (`eval_controller`), `checkpoint_manager.py`, `tracking.py`, `experiment_hash.py`, `parallel_primitives.py`. +- **`models/`** — `model_with_ops.py` (watched/op-able model wrapper), `monkey_patcher.py`. +- **`data/`** — dataframe + storage backbone. `dataframe_manager.py`, `data_samples_with_ops.py`, `sample_stats.py` (`SampleStatsEx`); storage `h5_dataframe_store.py`, `h5_array_store.py`, `array_proxy.py`. +- **`backend/`** — primitives. `ledgers.py` (`GLOBAL_LEDGER`, hyperparameter registry: `get_hyperparams`/`set_hyperparam`/`Proxy`), `model_interface.py`, `optimizer_interface.py`, `dataloader_interface.py`, `audit_logger.py`, `logger.py`, `cli.py` (optional localhost TCP REPL). +- **`proto/`** — `.proto` + generated `*_pb2*.py` (shared with weights_studio). +- **`baseline_models/`** — ready nets (e.g. `baseline_models.pytorch.models.FashionCNN`). +- **`ui/`** — bundled Docker/Envoy/nginx assets. **`security/`** — `CertAuthManager`. **`examples/`** — see §5. + +**Key fan-in points:** `ledgers.GLOBAL_LEDGER` is the hub (`watch_or_edit` registers objects there; the servicer reads/mutates through it). `components/global_monitoring` locks coordinate the training thread with gRPC calls. + +--- + +## 4. weights_studio frontend module map + +Vite + TypeScript. Entry `index.html` → `src/main.ts`. + +- **`main.ts`** — bootstrap: infers server host/port (default :8080), builds `GrpcWebFetchTransport`, wires panels, handles path-based deploy + TLS host normalization. +- **`experiment_service.client.ts` + `experiment_service.ts`** — generated gRPC-web client/types (also under `src/proto/`). **Do not hand-edit;** regenerate via `generate-proto:data`. +- **`left_panel/`** (`leftPanel.ts`, `panelResizer.ts` — controls, class/tag prefs), **`main_area/`** (board resizers), **`plots/`** (Chart.js + zoom), **`grid_data/`** (sample grid/table), **`agent/agentPanel.ts`** (LLM agent UI), **`ui/`/`utils/`/`helpers.ts`/`ContextMenu.ts`/`darkMode.ts`/`resilience.ts`** (shared UI + reconnection), **`test/`** (vitest). + +### Build / proto scripts (package.json) +- `generate-proto:data` reads the sibling weightslab repo (must be side-by-side). +- `npm run dev` (Vite, `VITE_HOST` 0.0.0.0 / `VITE_PORT` 5173), `build`, `preview`. + +### Tests (see §6 for placement rule) +- Unit: `npm run test` (vitest). +- Managed realtime (spins backend, via `scripts/run-managed-playwright.mjs`): `test:realtime:cls`, `test:realtime:seg`. +- Real-usecase E2E: `test:e2e:detection` (`tests/playwright/real_usecases/user_detection_yolo.spec.ts`), `test:e2e:segmentation` (`...user_segmentation_bdd.spec.ts`). +- `test:all` = unit + realtime cls/seg + e2e. + +--- + +## 5. Integration API (the usecase pattern) + +How a user's own PyTorch script plugs in so weights_studio can inspect/edit it live. Examples: `weightslab/weightslab/examples/{PyTorch,PyTorch_Lightning}//` — each is `main.py` + `config.yaml`. Usecases: `ws-classification`, `ws-segmentation`, `ws-face_recognition-triplet_loss`, `ws-vad` (+ Lightning classification). + +### Pattern — `import weightslab as wl` +Wrap each training object with `wl.watch_or_edit(obj, flag=...)`; the returned tracked proxy is registered in the ledger so the gRPC service can read stats / apply edits at runtime: +- `flag="hyperparameters"` — HP dict (required flag for trainer-services/UI visibility). +- `flag="model"` — wraps `nn.Module` (`device=…`); enables weight inspection + arch ops + `.get_age()`. +- `flag="optimizer"`. +- `flag="data"` — wraps a `Dataset` into a tracked loader: `loader_name`, `batch_size`, `shuffle`, `is_training`, `preload_labels`, `enable_h5_persistence`, … +- `flag="loss"` — wraps a `reduction="none"` criterion (`signal_name`, `log=True`); called `(preds_raw, targets, batch_ids=ids, preds=preds)` so per-sample loss maps to sample ids. +- `flag="metric"` — wraps a torchmetrics metric. + +### Dataset contract +`Dataset.__getitem__` returns **`(image, idx, label)`** — the sample id is threaded through training so per-sample signals attribute back to the sample. + +### Loop conventions +- `with guard_training_context:` (train step) / `with guard_testing_context:` (eval) — drives pause/resume + train/test stat separation. +- `wl.save_signals(preds_raw=, targets=, batch_ids=ids, signals={...}, preds=)` for extra per-sample signals. +- Use `model.get_age()` (steps actually trained, survives checkpoint reloads), not the raw loop counter. + +### Serving lifecycle +- `wl.serve(serving_grpc=…, serving_cli=…)` starts background serving threads **in the same process** as training. +- End the script with `wl.keep_serving()` to keep serving threads alive after the loop. +- Config from sibling `config.yaml`: `root_log_dir`, `device`, `training_steps_to_do`, `eval_full_to_train_steps_ratio`, `data.*_loader`, `optimizer.lr`, `enable_h5_persistence`, `serving_grpc`, … + +--- + +## 6. Topic notes + +- **Playwright test placement:** E2E/user-simulation tests belong in **weights_studio** (UI simulation), not weightslab. The Python backend now starts in **parallel** with Docker deployment (not sequentially) in the managed test runner. +- **Multi-instance dataframe:** MultiIndex `(sample_id, annotation_id)` supports per-instance data for detection/segmentation. +- **H5 storage:** `H5DataFrameStore` preserves the `(sample_id, annotation_id)` multi-index through write/read. `tag:xxx` columns are auto-optimized to categorical dtype (~90% memory savings). +- **Categorical tags:** planned support for multi-value tags with predefined categories; boolean tags unchanged. +- **Detection class colors:** class preferences from the left panel apply to detection bbox rendering. +- **Audit logger:** json/csv output configurable via `AUDIT_LOG_FORMAT` env var. +- **Docker-in-Docker:** envoy template mounting / file access fixed for GitHub Actions runner DinD environments. diff --git a/pyproject.toml b/pyproject.toml index d9d72aad..71aef233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "numpy>=1.25.2,<2.0; python_version < '3.13'", "numpy>=2.1,<3; python_version >= '3.13'", "pandas>=2.2.3,<3", + "duckdb>=1.1,<2", # signal/sample/instance history store "PyYAML>=6.0.3,<7", "dill>=0.3.8,<0.5", "zstandard>=0.22,<1", diff --git a/weightslab/backend/logger.py b/weightslab/backend/logger.py index c4a780c9..2104eff9 100644 --- a/weightslab/backend/logger.py +++ b/weightslab/backend/logger.py @@ -1,70 +1,89 @@ -import torch as th +"""DuckDB-backed signal history logger. + +``LoggerQueue`` is a thin interface that maps the logger's public methods onto +a DuckDB database holding three history tables: + +* ``signals`` — aggregated training-curve points (one row per averaged + step entry / evaluation marker). +* ``per_sample`` — per-sample signal values ``(sample_id, step, value)``. +* ``per_instance`` — per-instance values ``(sample_id, annotation_id, step, value)`` + for detection / segmentation. + +Design notes +------------ +* **Hot path is RAM, reads hit DuckDB.** ``add_scalars`` / + ``add_instance_scalars`` only append to in-memory staging lists (O(1), no SQL). + Rows are bulk-inserted into DuckDB lazily — right before any query, snapshot, + delete or update — via a single vectorized ``INSERT ... SELECT``. This keeps + per-step logging cheap while letting DuckDB do the heavy aggregation + (``GROUP BY step`` over millions of rows) in native code — exactly what + break-by-slices needs. +* **Transient runtime state stays in Python.** The live-streaming pending queue, + the per-step aggregation buffer and the evaluation accumulator are small and + short-lived, so they remain plain Python structures. +* **Persistence.** ``db_path`` defaults to ``":memory:"``. Pass a file path to + back the history with an on-disk DuckDB file. Either way ``save_snapshot`` / + ``load_snapshot`` round-trip the full history as a plain dict, so the + checkpoint manager's snapshotting is unchanged. +* **Thread-safety.** A single DuckDB connection is guarded by an ``RLock``; + staging appends and flushes take the same lock. +""" + +import json +import threading import time -from array import array as _array -from copy import deepcopy - -from weightslab.backend.ledgers import get_logger, register_logger, get_checkpoint_manager - - -def _make_per_sample_buf(): - """Compact storage for per-sample signals: three typed C arrays. - Uses array.array instead of a list of dicts to reduce memory by ~20-40x: - - list of dicts: ~400-600 bytes/entry (Python dict overhead + 6 string keys) - - compact arrays: 12 bytes/entry (int32 + int32 + float32) +import duckdb +import pandas as pd +import torch as th - Fields: - sample_ids: list of str - dataset sample index - steps: signed int32 - global training step - values: float32 - signal value at that step for that sample - """ - return { - "sample_ids": [], # str - "steps": _array('i'), # int32, 4 bytes each - "values": _array('f'), # float32, 4 bytes each - } +from weightslab.backend.ledgers import get_logger, register_logger, get_checkpoint_manager -def _make_per_instance_buf(): - """Compact storage for per-instance signals: four typed C arrays. +# Column order for each table's staging buffer / bulk insert. +_SIGNAL_COLS = [ + "metric_name", "experiment_hash", "step", "metric_value", "timestamp", + "audit_mode", "is_evaluation_marker", "split_name", "evaluation_tags", + "point_note", "seq", +] +_SAMPLE_COLS = ["metric_name", "experiment_hash", "sample_id", "step", "value", "seq"] +_INSTANCE_COLS = [ + "metric_name", "experiment_hash", "sample_id", "annotation_id", "step", "value", "seq", +] - Fields: - sample_ids: list of str - dataset sample index - annotation_ids: signed int32 - instance index within sample (1-based) - steps: signed int32 - global training step - values: float32 - signal value at that step for that instance - """ - return { - "sample_ids": [], # str - "annotation_ids": _array('i'), # int32, 4 bytes each - "steps": _array('i'), # int32, 4 bytes each - "values": _array('f'), # float32, 4 bytes each - } +# Auto-flush staged rows to DuckDB once the combined staging buffers exceed this +# many rows, to bound memory during long runs that never read history. +_STAGE_FLUSH_THRESHOLD = 50_000 class LoggerQueue: - def __init__(self, register: bool = True) -> None: + def __init__(self, register: bool = True, db_path: str = ":memory:") -> None: self.graph_names = set() self._current_step_buffer = {} self._last_step = None - self._signal_history = {} # Keep all signals in memory for persistence - self._signal_history_per_sample = {} # Keep all signals per sample in memory for persistence - self._signal_history_per_instance = {} # Keep all signals per instance in memory for persistence - # Reverse indices: O(1) lookup by sample_id or (sample_id, annotation_id) - # Structure: {graph_name: {exp_hash: {sample_id: [row_indices]}}} - self._sample_index = {} - # Structure: {graph_name: {exp_hash: {(sample_id, annotation_id): [row_indices]}}} - self._instance_index = {} - self._pending_queue = [] # Queue for new signals waiting to be sent to WeightsStudio + + # Live-streaming queue of new points waiting to be sent to WeightsStudio. + self._pending_queue = [] self._buffered_step = None - # Evaluation mode state + # Evaluation mode state (transient). self._eval_mode_active: bool = False self._eval_mode_hash: str = "" self._eval_mode_split: str = "" self._eval_mode_tags: list[str] = [] self._eval_accum: dict = {} # {graph_name: [sum, count]} + # DuckDB connection + write-staging buffers. + self._lock = threading.RLock() + self._db_path = db_path + self._conn = duckdb.connect(database=db_path) + self._stage_signals: list = [] + self._stage_sample: list = [] + self._stage_instance: list = [] + self._seq = 0 + self._ensure_tables() + self._restore_runtime_state_from_db() + lg = None if register: try: @@ -76,27 +95,155 @@ def __init__(self, register: bool = True) -> None: # Init checkpoint manager for experiment hash retrieval (if available) self.chkpt_manager = get_checkpoint_manager() + # ------------------------------------------------------------------ + # DuckDB plumbing + # ------------------------------------------------------------------ + def _ensure_tables(self) -> None: + with self._lock: + self._conn.execute( + """ + CREATE TABLE IF NOT EXISTS signals ( + metric_name VARCHAR, + experiment_hash VARCHAR, + step INTEGER, + metric_value DOUBLE, + timestamp BIGINT, + audit_mode BOOLEAN, + is_evaluation_marker BOOLEAN, + split_name VARCHAR, + evaluation_tags VARCHAR, + point_note VARCHAR, + seq BIGINT + ); + CREATE TABLE IF NOT EXISTS per_sample ( + metric_name VARCHAR, + experiment_hash VARCHAR, + sample_id VARCHAR, + step INTEGER, + value REAL, + seq BIGINT + ); + CREATE TABLE IF NOT EXISTS per_instance ( + metric_name VARCHAR, + experiment_hash VARCHAR, + sample_id VARCHAR, + annotation_id INTEGER, + step INTEGER, + value REAL, + seq BIGINT + ); + """ + ) + + def _restore_runtime_state_from_db(self) -> None: + """Repopulate seq counter and graph names from an existing (file) DB.""" + with self._lock: + max_seq = self._conn.execute( + """ + SELECT max(m) FROM ( + SELECT max(seq) AS m FROM signals + UNION ALL SELECT max(seq) FROM per_sample + UNION ALL SELECT max(seq) FROM per_instance + ) + """ + ).fetchone()[0] + self._seq = (int(max_seq) + 1) if max_seq is not None else 0 + + for tbl in ("signals", "per_sample", "per_instance"): + for (name,) in self._conn.execute( + f"SELECT DISTINCT metric_name FROM {tbl}" + ).fetchall(): + if name is not None: + self.graph_names.add(name) + + def _next_seq(self) -> int: + s = self._seq + self._seq += 1 + return s + + def _maybe_autoflush(self) -> None: + if (len(self._stage_signals) + len(self._stage_sample) + + len(self._stage_instance)) >= _STAGE_FLUSH_THRESHOLD: + self._flush_stage() + + def _flush_stage(self) -> None: + """Bulk-insert all staged rows into DuckDB and clear the buffers.""" + with self._lock: + if self._stage_signals: + df = pd.DataFrame(self._stage_signals, columns=_SIGNAL_COLS) + self._conn.register("_stg_sig", df) + self._conn.execute("INSERT INTO signals SELECT * FROM _stg_sig") + self._conn.unregister("_stg_sig") + self._stage_signals = [] + if self._stage_sample: + df = pd.DataFrame(self._stage_sample, columns=_SAMPLE_COLS) + self._conn.register("_stg_ps", df) + self._conn.execute("INSERT INTO per_sample SELECT * FROM _stg_ps") + self._conn.unregister("_stg_ps") + self._stage_sample = [] + if self._stage_instance: + df = pd.DataFrame(self._stage_instance, columns=_INSTANCE_COLS) + self._conn.register("_stg_pi", df) + self._conn.execute("INSERT INTO per_instance SELECT * FROM _stg_pi") + self._conn.unregister("_stg_pi") + self._stage_instance = [] + + def _stage_signal_row(self, graph_name, exp_hash, step, metric_value, timestamp, + audit_mode, is_marker, split_name, eval_tags, point_note): + self._stage_signals.append(( + graph_name, exp_hash, int(step), float(metric_value), int(timestamp), + bool(audit_mode), bool(is_marker), split_name or "", + json.dumps(list(eval_tags or [])), point_note or "", self._next_seq(), + )) + self._maybe_autoflush() + + def _stage_sample_row(self, graph_name, exp_hash, sample_id, step, value): + self._stage_sample.append(( + graph_name, exp_hash, str(sample_id), int(step), float(value), self._next_seq(), + )) + self._maybe_autoflush() + + def _stage_instance_row(self, graph_name, exp_hash, sample_id, annotation_id, step, value): + self._stage_instance.append(( + graph_name, exp_hash, str(sample_id), int(annotation_id), int(step), + float(value), self._next_seq(), + )) + self._maybe_autoflush() + + @staticmethod + def _hash_filter(exp_hash, params, table_alias=""): + """Append an experiment-hash WHERE fragment. ``None`` means 'all hashes'.""" + if exp_hash is None: + return "" + params.append(exp_hash) + col = f"{table_alias}experiment_hash" if table_alias else "experiment_hash" + return f" AND {col} = ?" + def __len__(self): - """Return logger length.""" - len_history = 0 - for k in self._signal_history: - for exp_hash in self._signal_history[k]: - l = len(self._signal_history[k][exp_hash]) - len_history = max(len_history, l) - return len_history - - # Clear history method (can be called by WeightsLabCallback at the start of a new experiment to reset state, - # while preserving graph names which are derived from signals and may be needed for future signals after clearing history) + """Max number of distinct steps recorded for any (metric, hash) curve.""" + with self._lock: + self._flush_stage() + row = self._conn.execute( + """ + SELECT max(cnt) FROM ( + SELECT count(DISTINCT step) AS cnt + FROM signals GROUP BY metric_name, experiment_hash + ) + """ + ).fetchone() + return int(row[0]) if row and row[0] is not None else 0 + def clear_signal_histories(self): - """Clear signal histories.""" - # Note: We do not clear graph names here as they are derived from signals and may be needed for future signals after clearing history. - self._signal_history.clear() - self._signal_history_per_sample.clear() - self._signal_history_per_instance.clear() - self._sample_index.clear() - self._instance_index.clear() - self._current_step_buffer.clear() - self._buffered_step = None + """Clear all signal histories (keeps graph names and runtime buffers reset).""" + with self._lock: + self._stage_signals = [] + self._stage_sample = [] + self._stage_instance = [] + self._conn.execute("DELETE FROM signals") + self._conn.execute("DELETE FROM per_sample") + self._conn.execute("DELETE FROM per_instance") + self._current_step_buffer.clear() + self._buffered_step = None def _to_float(self, value): if isinstance(value, th.Tensor): @@ -111,7 +258,6 @@ def _get_audit_mode(self): 2. Check hyperparams auditor_mode (fallback for legacy/hyperparams-based control) """ try: - # First priority: check registered model interface from weightslab.backend.ledgers import get_model model = get_model() if model is not None and hasattr(model, 'audit_mode'): @@ -120,7 +266,6 @@ def _get_audit_mode(self): pass try: - # Fallback: check hyperparams auditor_mode from weightslab.backend.ledgers import get_hyperparams hp = get_hyperparams() if hp is not None: @@ -129,27 +274,33 @@ def _get_audit_mode(self): pass return False - def _append_history_entry(self, graph_name, exp_hash, global_step, metric_value, audit_mode=None): + def _append_history_entry(self, graph_name, exp_hash, global_step, metric_value, + audit_mode=None, is_marker=False, split_name="", + evaluation_tags=None): + """Stage a signals row and return the live-queue entry dict.""" if audit_mode is None: audit_mode = self._get_audit_mode() + timestamp = int(time.time()) signal_entry = { "model_age": global_step, "metric_name": graph_name, "metric_value": metric_value, "experiment_hash": exp_hash, - "timestamp": int(time.time()), + "timestamp": timestamp, "audit_mode": audit_mode, } - - if graph_name not in self._signal_history: - self._signal_history[graph_name] = {} - if exp_hash not in self._signal_history[graph_name]: - self._signal_history[graph_name][exp_hash] = {} - if global_step not in self._signal_history[graph_name][exp_hash]: - self._signal_history[graph_name][exp_hash][global_step] = [] - - self._signal_history[graph_name][exp_hash][global_step].append(signal_entry) + if is_marker: + signal_entry["is_evaluation_marker"] = True + signal_entry["split_name"] = split_name + signal_entry["evaluation_tags"] = list(evaluation_tags or []) + + with self._lock: + self._stage_signal_row( + graph_name, exp_hash, global_step, metric_value, timestamp, + bool(audit_mode), bool(is_marker), split_name, + list(evaluation_tags or []), "", + ) return signal_entry def _flush_current_step_buffer(self, add_to_queue: bool): @@ -179,21 +330,27 @@ def _flush_current_step_buffer(self, add_to_queue: bool): def get_next_evaluation_count(self, base_hash: str) -> int: """Return the next unused evaluation index for *base_hash*. - Scans the current signal history for keys of the form + Scans recorded experiment hashes for keys of the form ``_`` and returns max(found) + 1 (or 1 if none). """ prefix = base_hash + "_" max_count = 0 - for gname in self._signal_history: - for hash_key in self._signal_history[gname]: - if isinstance(hash_key, str) and hash_key.startswith(prefix): - suffix = hash_key[len(prefix):] - try: - count = int(suffix) - if count > max_count: - max_count = count - except ValueError: - pass + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT DISTINCT experiment_hash FROM signals " + "WHERE experiment_hash LIKE ?", + [prefix + "%"], + ).fetchall() + for (hash_key,) in rows: + if isinstance(hash_key, str) and hash_key.startswith(prefix): + suffix = hash_key[len(prefix):] + try: + count = int(suffix) + if count > max_count: + max_count = count + except ValueError: + pass return max_count + 1 def start_evaluation_mode(self, split_name: str, eval_hash: str, evaluation_tags=None) -> None: @@ -205,10 +362,6 @@ def start_evaluation_mode(self, split_name: str, eval_hash: str, evaluation_tags Per-sample history *is* still updated (for Break-By-Slice on eval results), using *eval_hash* as the experiment key. - - Args: - split_name: Human-readable split name (e.g. ``"train_loader"``). - eval_hash: Modified experiment hash (e.g. ``"abc123_1"``). """ self._flush_current_step_buffer(add_to_queue=True) self._eval_mode_active = True @@ -225,9 +378,6 @@ def stop_evaluation_mode(self, model_age: int) -> dict: history under *eval_hash* and into the pending queue, then resets evaluation-mode state. - Args: - model_age: Current model age (training step) at time of evaluation. - Returns: Dict mapping graph_name → averaged value for all signals seen. """ @@ -248,26 +398,16 @@ def stop_evaluation_mode(self, model_age: int) -> dict: results[graph_name] = avg self.graph_names.add(graph_name) - # Store in signal history under eval_hash - if graph_name not in self._signal_history: - self._signal_history[graph_name] = {} - if eval_hash not in self._signal_history[graph_name]: - self._signal_history[graph_name][eval_hash] = {} - if model_age not in self._signal_history[graph_name][eval_hash]: - self._signal_history[graph_name][eval_hash][model_age] = [] - - entry = { - "model_age": model_age, - "metric_name": graph_name, - "metric_value": avg, - "experiment_hash": eval_hash, - "timestamp": int(time.time()), - "is_evaluation_marker": True, - "split_name": split_name, - "evaluation_tags": evaluation_tags, - "audit_mode": audit_mode, - } - self._signal_history[graph_name][eval_hash][model_age].append(entry) + entry = self._append_history_entry( + graph_name=graph_name, + exp_hash=eval_hash, + global_step=model_age, + metric_value=avg, + audit_mode=audit_mode, + is_marker=True, + split_name=split_name, + evaluation_tags=evaluation_tags, + ) self._pending_queue.append(entry) self._eval_accum = {} @@ -277,12 +417,7 @@ def stop_evaluation_mode(self, model_age: int) -> dict: return results def abort_evaluation_mode(self) -> None: - """Abort evaluation mode and drop all in-progress evaluation data. - - This is used when an evaluation is canceled or timed out. - It clears the accumulation buffer and removes any per-sample history - that may have been written under the in-flight evaluation hash. - """ + """Abort evaluation mode and drop all in-progress evaluation data.""" if not self._eval_mode_active: return @@ -304,19 +439,10 @@ def remove_evaluation_hash(self, eval_hash: str) -> None: if not eval_hash: return - # Remove any marker/history entries tied to the evaluation hash. - for graph_name in list(self._signal_history.keys()): - try: - self._signal_history[graph_name].pop(eval_hash, None) - except Exception: - pass - - # Remove per-sample traces recorded under the same hash. - for graph_name in list(self._signal_history_per_sample.keys()): - try: - self._signal_history_per_sample[graph_name].pop(eval_hash, None) - except Exception: - pass + with self._lock: + self._flush_stage() + self._conn.execute("DELETE FROM signals WHERE experiment_hash = ?", [eval_hash]) + self._conn.execute("DELETE FROM per_sample WHERE experiment_hash = ?", [eval_hash]) # Drop queued points that reference this hash. self._pending_queue = [ @@ -335,109 +461,124 @@ def add_scalars(self, graph_name, signal, global_step, signal_per_sample, aggreg - Evaluation mode active: accumulate into internal buffer; per-sample history still gets written under the eval hash for Break-By-Slice support. """ - self.graph_names.add(graph_name) - self._last_step = global_step - - # ---------------------------------------------------------------- - # Evaluation-mode interception - # ---------------------------------------------------------------- - if self._eval_mode_active: - # Collect scalar values to accumulate - values: list = [] - if aggregate_by_step and signal_per_sample and isinstance(signal_per_sample, dict): - values = [self._to_float(v) for v in signal_per_sample.values()] - elif signal and isinstance(signal, dict): - values = [self._to_float(v) for _, v in signal.items()] - - if values: - if graph_name not in self._eval_accum: - self._eval_accum[graph_name] = [0.0, 0] - self._eval_accum[graph_name][0] += sum(values) - self._eval_accum[graph_name][1] += len(values) - - # Still store per-sample signals under eval_hash (for Break-By-Slice) - if signal_per_sample and isinstance(signal_per_sample, dict): - eval_hash = self._eval_mode_hash - if graph_name not in self._signal_history_per_sample: - self._signal_history_per_sample[graph_name] = {} - if eval_hash not in self._signal_history_per_sample[graph_name]: - self._signal_history_per_sample[graph_name][eval_hash] = _make_per_sample_buf() - buf = self._signal_history_per_sample[graph_name][eval_hash] + with self._lock: + self.graph_names.add(graph_name) + self._last_step = global_step + + # ------------------------------------------------------------ + # Evaluation-mode interception + # ------------------------------------------------------------ + if self._eval_mode_active: + values: list = [] + if aggregate_by_step and signal_per_sample and isinstance(signal_per_sample, dict): + values = [self._to_float(v) for v in signal_per_sample.values()] + elif signal and isinstance(signal, dict): + values = [self._to_float(v) for _, v in signal.items()] + + if values: + if graph_name not in self._eval_accum: + self._eval_accum[graph_name] = [0.0, 0] + self._eval_accum[graph_name][0] += sum(values) + self._eval_accum[graph_name][1] += len(values) + + # Still store per-sample signals under eval_hash (for Break-By-Slice) + if signal_per_sample and isinstance(signal_per_sample, dict): + eval_hash = self._eval_mode_hash + step_i = int(global_step) + for sid, value in signal_per_sample.items(): + self._stage_sample_row(graph_name, eval_hash, sid, step_i, self._to_float(value)) + + return # Do NOT add to normal history during evaluation mode + # ------------------------------------------------------------ + + exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager else None + + if self._buffered_step is not None and global_step != self._buffered_step: + self._flush_current_step_buffer(add_to_queue=True) + + if not aggregate_by_step and self._current_step_buffer: + self._flush_current_step_buffer(add_to_queue=True) + + # Update per-sample signal history + if isinstance(signal_per_sample, dict) and len(signal_per_sample): step_i = int(global_step) - idx_map = self._sample_index.setdefault(graph_name, {}).setdefault(eval_hash, {}) for sid, value in signal_per_sample.items(): - row = len(buf["sample_ids"]) - buf["sample_ids"].append(sid) - buf["steps"].append(step_i) - buf["values"].append(self._to_float(value)) - idx_map.setdefault(str(sid), []).append(row) - - return # Do NOT add to normal history during evaluation mode - # ---------------------------------------------------------------- - - exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager else None + self._stage_sample_row(graph_name, exp_hash, sid, step_i, self._to_float(value)) - if self._buffered_step is not None and global_step != self._buffered_step: - self._flush_current_step_buffer(add_to_queue=True) + metric_values = [] + if isinstance(signal_per_sample, dict) and aggregate_by_step and len(signal_per_sample): + for value in signal_per_sample.values(): + metric_values.append(self._to_float(value)) + else: + for _, line_value in signal.items(): + metric_values.append(self._to_float(line_value)) + + if aggregate_by_step: + if metric_values: + self._buffered_step = global_step + buffer_key = (global_step, graph_name, exp_hash) + if buffer_key not in self._current_step_buffer: + self._current_step_buffer[buffer_key] = {"sum": 0.0, "count": 0} + self._current_step_buffer[buffer_key]["sum"] += sum(metric_values) + self._current_step_buffer[buffer_key]["count"] += len(metric_values) + return + + # Update averaged signal history immediately. Only emit when we have at + # least one valid metric value (signals carrying only per-sample data are + # stored separately in per_sample). + signal_entry = None + if len(metric_values) > 0: + signal_entry = self._append_history_entry( + graph_name=graph_name, + exp_hash=exp_hash, + global_step=global_step, + metric_value=sum(metric_values) / len(metric_values) if len(metric_values) > 1 else metric_values[0], + ) + + if signal_entry is not None: + self._pending_queue.append(signal_entry) - if not aggregate_by_step and self._current_step_buffer: - self._flush_current_step_buffer(add_to_queue=True) + def ingest_per_sample(self, graph_name: str, exp_hash, triples) -> None: + """Insert per-sample ``(sample_id, step, value)`` triples, de-duplicating + on ``(sample_id, step)`` within ``(graph_name, exp_hash)``. - # Update per-sample signal history with compact array storage - if isinstance(signal_per_sample, dict) and len(signal_per_sample): - if graph_name not in self._signal_history_per_sample: - self._signal_history_per_sample[graph_name] = {} - if exp_hash not in self._signal_history_per_sample[graph_name]: - self._signal_history_per_sample[graph_name][exp_hash] = _make_per_sample_buf() + Unlike ``add_scalars`` (which always appends), this is idempotent on the + ``(sample_id, step)`` key: the first value wins and later duplicates are + ignored. Useful for back-filling / importing history without creating + repeated points. - buf = self._signal_history_per_sample[graph_name][exp_hash] - step_i = int(global_step) - idx_map = self._sample_index.setdefault(graph_name, {}).setdefault(exp_hash, {}) - for sid, value in signal_per_sample.items(): - row = len(buf["sample_ids"]) - buf["sample_ids"].append(sid) - buf["steps"].append(step_i) - buf["values"].append(self._to_float(value)) - idx_map.setdefault(str(sid), []).append(row) - - metric_values = [] - if isinstance(signal_per_sample, dict) and aggregate_by_step and len(signal_per_sample): - for value in signal_per_sample.values(): - metric_values.append(self._to_float(value)) - else: - for _, line_value in signal.items(): - metric_values.append(self._to_float(line_value)) - - if aggregate_by_step: - if metric_values: - self._buffered_step = global_step - buffer_key = (global_step, graph_name, exp_hash) - if buffer_key not in self._current_step_buffer: - self._current_step_buffer[buffer_key] = {"sum": 0.0, "count": 0} - self._current_step_buffer[buffer_key]["sum"] += sum(metric_values) - self._current_step_buffer[buffer_key]["count"] += len(metric_values) + Args: + graph_name: Signal name. + exp_hash: Experiment hash (``None`` allowed). + triples: Iterable of ``(sample_id, step, value)``. + """ + triples = list(triples) + if not triples: return - # Update averaged signal history immediately - signal_entry = None + with self._lock: + self.graph_names.add(graph_name) + self._flush_stage() - # Only add to history if we have at least one valid metric value (otherwise we may end up with empty/invalid entries from signals that only contain per-sample values, which are stored separately in _signal_history_per_sample) - if len(metric_values) > 0: - signal_entry = self._append_history_entry( - graph_name=graph_name, - exp_hash=exp_hash, - global_step=global_step, - metric_value=sum(metric_values) / len(metric_values) if len(metric_values) > 1 else metric_values[0], - ) + # Existing (sample_id, step) keys for this (graph, hash). + params = [graph_name] + sql = "SELECT sample_id, step FROM per_sample WHERE metric_name = ?" + sql += self._hash_filter(exp_hash, params) + seen = {(str(s), int(t)) for s, t in self._conn.execute(sql, params).fetchall()} - # Add signal to pending queue for live incremental update to WeightsStudio - if signal_entry is not None: - self._pending_queue.append(signal_entry) + for sid, step, value in triples: + key = (str(sid), int(step)) + if key in seen: + continue + seen.add(key) + self._stage_sample_row(graph_name, exp_hash, sid, step, self._to_float(value)) - # Print methods for debugging/inspection of logger state + # ------------------------------------------------------------------ + # Print helpers (debug) + # ------------------------------------------------------------------ def print_history(self): - """Print all items in history.""" - for metric_name, experiments in self._signal_history.items(): + history = self.get_signal_history() + for metric_name, experiments in history.items(): print(f"Metric: {metric_name}") for exp_hash, steps in experiments.items(): print(f" Experiment Hash: {exp_hash}") @@ -445,133 +586,145 @@ def print_history(self): print(f" Step: {step}") for signal in signals: print(f" Signal: {signal}") - return self._signal_history + return history def print_history_per_sample(self): - """Print all items in per-sample history.""" - for metric_name, exps in self._signal_history_per_sample.items(): + history = self.get_signal_history_per_sample() + for metric_name, exps in history.items(): print(f"Metric: {metric_name}") - for exp_hash, buf in exps.items(): + for exp_hash, entries in exps.items(): print(f" Experiment Hash: {exp_hash}") - for sid, step, val in zip(buf["sample_ids"], buf["steps"], buf["values"]): - print(f" Sample ID: {sid}, Step: {step}, Value: {val}") - return self._signal_history_per_sample + for e in entries: + print(f" Sample ID: {e['sample_id']}, Step: {e['model_age']}, Value: {e['metric_value']}") + return history def print_buffer(self): - """Print current step buffer contents.""" print(f"Current step: {self._last_step}") print(f"Buffered metrics: {self._current_step_buffer}") return self._current_step_buffer - # Accessor methods for retrieving logger state (e.g. for checkpoint saving or programmatic access) + # ------------------------------------------------------------------ + # Accessors + # ------------------------------------------------------------------ def get_graph_names(self): - """ - Get list of all graph names encountered in signals. - Returns: - List of graph names. - """ + """Get list of all graph names encountered in signals.""" return list(self.graph_names) + def list_sample_signal_names(self) -> list: + """Distinct signal names that have per-sample history.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute("SELECT DISTINCT metric_name FROM per_sample").fetchall() + return [r[0] for r in rows] + + def list_instance_signal_names(self) -> list: + """Distinct signal names that have per-instance history.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute("SELECT DISTINCT metric_name FROM per_instance").fetchall() + return [r[0] for r in rows] + def get_signal_history(self): - """Retrieve all accumulated signals from memory.""" - # self._flush_current_step_buffer(add_to_queue=False) # History should already be up to date since we flush on step change and on add_scalars when not aggregating by step, but we can flush here as well to be safe before retrieving history for checkpoint saving - return deepcopy(self._signal_history) + """Reconstruct aggregated history as ``{metric: {hash: {step: [entry, ...]}}}``.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + """ + SELECT metric_name, experiment_hash, step, metric_value, timestamp, + audit_mode, is_evaluation_marker, split_name, evaluation_tags, point_note + FROM signals ORDER BY seq + """ + ).fetchall() + + result: dict = {} + for (metric, h, step, val, ts, audit, marker, split, tags, note) in rows: + entry = { + "model_age": step, + "metric_name": metric, + "metric_value": val, + "experiment_hash": h, + "timestamp": int(ts) if ts is not None else 0, + "audit_mode": bool(audit), + "is_evaluation_marker": bool(marker), + "split_name": split or "", + "evaluation_tags": json.loads(tags) if tags else [], + } + if note: + entry["point_note"] = note + result.setdefault(metric, {}).setdefault(h, {}).setdefault(step, []).append(entry) + return result def get_current_signaL_history(self, graph_name: str, meta: bool = False): - """Get current history for a specific signal.""" - if graph_name not in self._signal_history: + """Get current-hash aggregated history for a specific signal.""" + if graph_name not in self.graph_names: return {} - # Get Current Hash exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager else None - # Process history + with self._lock: + self._flush_stage() + params = [graph_name] + sql = "SELECT step, metric_value FROM signals WHERE metric_name = ?" + sql += self._hash_filter(exp_hash, params) + sql += " ORDER BY seq" + rows = self._conn.execute(sql, params).fetchall() + if meta: - return self._signal_history.get(graph_name, {}).get(exp_hash, {}) - else: - history = self._signal_history.get(graph_name, {}).get(exp_hash, {}) - result = [] - for _, entries in history.items(): - for entry in entries: - result.append({ - "model_age": entry.get("model_age"), - "metric_value": entry.get("metric_value"), - }) - return result + steps: dict = {} + for step, val in rows: + steps.setdefault(step, []).append({ + "model_age": step, "metric_value": val, + }) + return steps + + return [{"model_age": step, "metric_value": val} for step, val in rows] def get_signal_history_per_sample(self): - """Reconstruct per-sample history as list-of-dicts from compact array storage.""" - result = {} - for graph_name, exps in self._signal_history_per_sample.items(): - result[graph_name] = {} - for exp_hash, buf in exps.items(): - entries = [] - for sid, step, val in zip(buf["sample_ids"], buf["steps"], buf["values"]): - entries.append({ - "sample_id": sid, - "model_age": step, - "metric_name": graph_name, - "metric_value": float(val), - "experiment_hash": exp_hash, - }) - result[graph_name][exp_hash] = entries + """Per-sample history as ``{metric: {hash: [entry, ...]}}``.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT metric_name, experiment_hash, sample_id, step, value " + "FROM per_sample ORDER BY seq" + ).fetchall() + + result: dict = {} + for (metric, h, sid, step, val) in rows: + result.setdefault(metric, {}).setdefault(h, []).append({ + "sample_id": sid, + "model_age": step, + "metric_name": metric, + "metric_value": float(val), + "experiment_hash": h, + }) return result def get_current_signaL_history_per_sample(self, graph_name: str, sample_ids: list = None, exp_hash: str = None): - """Get current history for a specific signal.""" - if graph_name not in self._signal_history: + """Get current-hash per-sample history for a specific signal.""" + if graph_name not in self.graph_names: return {} - # Get Current Hash exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager and exp_hash is None else exp_hash - - # Return history for the specified graph name, filtered by sample IDs and experiment hash if provided. If meta=True, returns raw history dict; otherwise returns list of (sample_id, step, value) tuples. - result = self.query_per_sample( - graph_name, - sample_ids=sample_ids, - exp_hash=exp_hash - ) - return result + return self.query_per_sample(graph_name, sample_ids=sample_ids, exp_hash=exp_hash) def query_per_sample(self, graph_name: str, sample_ids=None, exp_hash=None): - """Efficiently query per-sample history for specific sample IDs. + """Query per-sample history. - Returns a dict mapping sample_id → list of {model_age, signal_value} dicts, - filtered by sample_ids and optionally by experiment hash. - Much faster than get_signal_history_per_sample() for targeted queries - (e.g., "show me only samples with label 8"). - - Args: - graph_name: Signal name (e.g., "loss", "accuracy"). - sample_ids: Collection of sample IDs to filter by. If None, returns all. - exp_hash: Specific experiment hash to query. If None, queries all hashes. - - Returns: - List of (sample_id, step, value, experiment_hash) tuples. + Returns a list of ``(sample_id, step, value, experiment_hash)`` tuples, + filtered by *sample_ids* and optionally *exp_hash* (``None`` = all hashes). """ - if graph_name not in self._signal_history_per_sample: - return [] - - exps = self._signal_history_per_sample[graph_name] - hashes = [exp_hash] if exp_hash is not None else list(exps.keys()) - # Stored ids are ints; callers pass str (df index is str-normalized) — compare as str. - sid_set = {str(s) for s in sample_ids} if sample_ids is not None else None - - results = [] - for h in hashes: - buf = exps.get(h) - if buf is None: - continue - if sid_set is None: - for sid, step, val in zip(buf["sample_ids"], buf["steps"], buf["values"]): - results.append((sid, step, float(val), h)) - else: - idx_map = self._sample_index.get(graph_name, {}).get(h, {}) - for sid in sid_set: - for row in idx_map.get(sid, []): - results.append((sid, buf["steps"][row], float(buf["values"][row]), h)) - - return results + with self._lock: + self._flush_stage() + params = [graph_name] + sql = "SELECT sample_id, step, value, experiment_hash FROM per_sample WHERE metric_name = ?" + sql += self._hash_filter(exp_hash, params) + if sample_ids is not None: + sql += " AND sample_id IN (SELECT UNNEST(?))" + params.append([str(s) for s in sample_ids]) + sql += " ORDER BY seq" + rows = self._conn.execute(sql, params).fetchall() + + return [(sid, int(step), float(val), h) for (sid, step, val, h) in rows] def query_per_instance( self, @@ -585,52 +738,24 @@ def query_per_instance( Returns a list of ``(sample_id, annotation_id, step, value, exp_hash)`` tuples. Any of *sample_id*, *annotation_id*, *exp_hash* may be ``None`` to return all values along that dimension. - - Args: - graph_name: Signal name (e.g. ``"confidence"``). - sample_id: Filter to a single sample. ``None`` returns all samples. - annotation_id: Filter to a single instance (1-based). ``None`` = all. - exp_hash: Filter to one experiment hash. ``None`` = all. """ - if graph_name not in self._signal_history_per_instance: - return [] - - exps = self._signal_history_per_instance[graph_name] - hashes = [exp_hash] if exp_hash is not None else list(exps.keys()) - sid_filter = str(sample_id) if sample_id is not None else None - aid_filter = int(annotation_id) if annotation_id is not None else None - - results = [] - for h in hashes: - buf = exps.get(h) - if buf is None: - continue - if sid_filter is None and aid_filter is None: - # No filter: full scan - for sid, aid, step, val in zip( - buf["sample_ids"], buf["annotation_ids"], buf["steps"], buf["values"] - ): - results.append((str(sid), int(aid), int(step), float(val), h)) - elif sid_filter is not None and aid_filter is not None: - # Both filters: O(1) index lookup - idx_map = self._instance_index.get(graph_name, {}).get(h, {}) - for row in idx_map.get((sid_filter, aid_filter), []): - results.append((sid_filter, aid_filter, int(buf["steps"][row]), float(buf["values"][row]), h)) - elif sid_filter is not None: - # Sample filter only: collect all annotation_ids for this sample - idx_map = self._instance_index.get(graph_name, {}).get(h, {}) - for (sid_k, aid_k), rows in idx_map.items(): - if sid_k == sid_filter: - for row in rows: - results.append((sid_filter, aid_k, int(buf["steps"][row]), float(buf["values"][row]), h)) - else: - # annotation_id filter only: scan index keys - idx_map = self._instance_index.get(graph_name, {}).get(h, {}) - for (sid_k, aid_k), rows in idx_map.items(): - if aid_k == aid_filter: - for row in rows: - results.append((sid_k, aid_filter, int(buf["steps"][row]), float(buf["values"][row]), h)) - return results + with self._lock: + self._flush_stage() + params = [graph_name] + sql = ("SELECT sample_id, annotation_id, step, value, experiment_hash " + "FROM per_instance WHERE metric_name = ?") + sql += self._hash_filter(exp_hash, params) + if sample_id is not None: + sql += " AND sample_id = ?" + params.append(str(sample_id)) + if annotation_id is not None: + sql += " AND annotation_id = ?" + params.append(int(annotation_id)) + sql += " ORDER BY seq" + rows = self._conn.execute(sql, params).fetchall() + + return [(str(sid), int(aid), int(step), float(val), h) + for (sid, aid, step, val, h) in rows] def aggregate_per_sample_by_step( self, @@ -640,58 +765,29 @@ def aggregate_per_sample_by_step( ) -> dict: """Return mean signal value per step, aggregated over matching samples. - Uses numpy vectorized operations instead of a Python loop — ~100× faster - than iterating ``query_per_sample`` results for large sample counts. - - Args: - graph_name: Signal name. - sample_ids: Samples to include. ``None`` = all samples. - exp_hash: Filter to one experiment hash. ``None`` = all hashes. + DuckDB performs the ``GROUP BY step`` average natively, which scales to + millions of rows far better than a Python loop — this is the path used + by break-by-slices. Returns: - ``{exp_hash: [(step, mean_value), ...]}`` — one sorted series per hash. + ``{exp_hash: [(step, mean_value), ...]}`` — one step-sorted series + per hash. """ - import numpy as _np - - if graph_name not in self._signal_history_per_sample: - return {} - - exps = self._signal_history_per_sample[graph_name] - hashes = [exp_hash] if exp_hash is not None else list(exps.keys()) - sid_set = {str(s) for s in sample_ids} if sample_ids is not None else None - - result = {} - for h in hashes: - buf = exps.get(h) - if buf is None: - continue - - # Convert typed C arrays to numpy with zero-copy (frombuffer gives a read-only view) - steps_np = _np.frombuffer(buf["steps"], dtype=_np.int32).copy() - values_np = _np.frombuffer(buf["values"], dtype=_np.float32).copy() - - if sid_set is not None: - idx_map = self._sample_index.get(graph_name, {}).get(h, {}) - rows = [] - for sid in sid_set: - rows.extend(idx_map.get(sid, [])) - if not rows: - continue - row_idx = _np.array(rows, dtype=_np.intp) - steps_np = steps_np[row_idx] - values_np = values_np[row_idx] - - if len(steps_np) == 0: - continue - - # Vectorized group-by step → mean - unique_steps, inverse = _np.unique(steps_np, return_inverse=True) - sums = _np.bincount(inverse, weights=values_np.astype(_np.float64)) - counts = _np.bincount(inverse) - means = sums / counts - - result[h] = list(zip(unique_steps.tolist(), means.tolist())) - + with self._lock: + self._flush_stage() + params = [graph_name] + sql = ("SELECT experiment_hash, step, avg(value) AS mean_value " + "FROM per_sample WHERE metric_name = ?") + sql += self._hash_filter(exp_hash, params) + if sample_ids is not None: + sql += " AND sample_id IN (SELECT UNNEST(?))" + params.append([str(s) for s in sample_ids]) + sql += " GROUP BY experiment_hash, step ORDER BY experiment_hash, step" + rows = self._conn.execute(sql, params).fetchall() + + result: dict = {} + for (h, step, mean_val) in rows: + result.setdefault(h, []).append((int(step), float(mean_val))) return result def add_instance_scalars( @@ -703,11 +799,10 @@ def add_instance_scalars( global_step: int, exp_hash: str | None = None, ) -> None: - """Record per-instance scalar values in compact storage. + """Record per-instance scalar values. - Call this from ``save_instance_signals`` once per scalar signal per - batch. Each element of *sample_ids*, *annotation_ids*, *values* - corresponds to one detection / segmentation instance. + Each element of *sample_ids*, *annotation_ids*, *values* corresponds to + one detection / segmentation instance. Args: graph_name: Signal name (e.g. ``"confidence"``). @@ -724,76 +819,63 @@ def add_instance_scalars( else None ) - if graph_name not in self._signal_history_per_instance: - self._signal_history_per_instance[graph_name] = {} - if exp_hash not in self._signal_history_per_instance[graph_name]: - self._signal_history_per_instance[graph_name][exp_hash] = _make_per_instance_buf() - - buf = self._signal_history_per_instance[graph_name][exp_hash] - step_i = int(global_step) - idx_map = self._instance_index.setdefault(graph_name, {}).setdefault(exp_hash, {}) try: import numpy as _np vals = _np.asarray(values, dtype=_np.float32).ravel() except Exception: vals = [float(v) for v in values] - for sid, aid, val in zip(sample_ids, annotation_ids, vals): - row = len(buf["sample_ids"]) - sid_s, aid_i = str(sid), int(aid) - buf["sample_ids"].append(sid_s) - buf["annotation_ids"].append(aid_i) - buf["steps"].append(step_i) - buf["values"].append(float(val)) - idx_map.setdefault((sid_s, aid_i), []).append(row) + with self._lock: + step_i = int(global_step) + for sid, aid, val in zip(sample_ids, annotation_ids, vals): + self._stage_instance_row(graph_name, exp_hash, sid, aid, step_i, float(val)) def get_signal_history_per_instance(self) -> dict: - """Reconstruct per-instance history as list-of-dicts from compact array storage.""" - result = {} - for graph_name, exps in self._signal_history_per_instance.items(): - result[graph_name] = {} - for exp_hash, buf in exps.items(): - entries = [] - for sid, aid, step, val in zip( - buf["sample_ids"], buf["annotation_ids"], buf["steps"], buf["values"] - ): - entries.append({ - "sample_id": str(sid), - "annotation_id": int(aid), - "model_age": int(step), - "metric_name": graph_name, - "metric_value": float(val), - "experiment_hash": exp_hash, - }) - result[graph_name][exp_hash] = entries + """Per-instance history as ``{metric: {hash: [entry, ...]}}``.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT metric_name, experiment_hash, sample_id, annotation_id, step, value " + "FROM per_instance ORDER BY seq" + ).fetchall() + + result: dict = {} + for (metric, h, sid, aid, step, val) in rows: + result.setdefault(metric, {}).setdefault(h, []).append({ + "sample_id": str(sid), + "annotation_id": int(aid), + "model_age": int(step), + "metric_name": metric, + "metric_value": float(val), + "experiment_hash": h, + }) return result def save_snapshot(self) -> dict: - """Build a serializable snapshot of the logger state.""" + """Build a serializable snapshot of the logger state (compact format).""" self._flush_current_step_buffer(add_to_queue=False) - # Compact serialization: store parallel lists instead of list-of-dicts - per_sample_compact = {} - for graph_name, exps in self._signal_history_per_sample.items(): + per_sample_compact: dict = {} + for graph_name, exps in self.get_signal_history_per_sample().items(): per_sample_compact[graph_name] = {} - for exp_hash, buf in exps.items(): + for exp_hash, entries in exps.items(): per_sample_compact[graph_name][exp_hash] = { "_compact": True, - "sample_ids": list(buf["sample_ids"]), - "steps": list(buf["steps"]), - "values": list(buf["values"]), + "sample_ids": [e["sample_id"] for e in entries], + "steps": [e["model_age"] for e in entries], + "values": [e["metric_value"] for e in entries], } - per_instance_compact = {} - for graph_name, exps in self._signal_history_per_instance.items(): + per_instance_compact: dict = {} + for graph_name, exps in self.get_signal_history_per_instance().items(): per_instance_compact[graph_name] = {} - for exp_hash, buf in exps.items(): + for exp_hash, entries in exps.items(): per_instance_compact[graph_name][exp_hash] = { - "_compact": True, - "sample_ids": list(buf["sample_ids"]), - "annotation_ids": list(buf["annotation_ids"]), - "steps": list(buf["steps"]), - "values": list(buf["values"]), + "_compact": True, + "sample_ids": [e["sample_id"] for e in entries], + "annotation_ids": [e["annotation_id"] for e in entries], + "steps": [e["model_age"] for e in entries], + "values": [e["metric_value"] for e in entries], } return { @@ -803,32 +885,34 @@ def save_snapshot(self) -> dict: "signal_history_per_instance": per_instance_compact, } - # ------------------------------------------------------------------ - # Convenience: list all evaluation-marker hashes in history - # ------------------------------------------------------------------ def get_evaluation_marker_hashes(self) -> list: - """Return all experiment hashes that correspond to evaluation markers.""" + """Return all experiment hashes of the form ``_`` in history.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT DISTINCT experiment_hash FROM signals WHERE experiment_hash IS NOT NULL" + ).fetchall() + hashes = set() - for gname in self._signal_history: - for hash_key in self._signal_history[gname]: - if isinstance(hash_key, str) and "_" in hash_key: - # Check that the suffix is a pure integer - suffix = hash_key.rsplit("_", 1)[-1] - try: - int(suffix) - hashes.add(hash_key) - except ValueError: - pass + for (hash_key,) in rows: + if isinstance(hash_key, str) and "_" in hash_key: + suffix = hash_key.rsplit("_", 1)[-1] + try: + int(suffix) + hashes.add(hash_key) + except ValueError: + pass return sorted(hashes) def get_and_clear_queue(self): """Get pending queue and clear it (for incremental updates to WeightsStudio).""" - queue_copy = list(self._pending_queue) - self._pending_queue.clear() + with self._lock: + queue_copy = list(self._pending_queue) + self._pending_queue.clear() return queue_copy def set_point_note(self, metric_name: str, experiment_hash: str, model_age: int, note: str) -> bool: - """Attach or clear a note for a specific signal point identified by metric/hash/step.""" + """Attach or clear a note for a signal point identified by metric/hash/step.""" metric_name = str(metric_name or "") experiment_hash = str(experiment_hash or "") if not metric_name or not experiment_hash: @@ -836,55 +920,64 @@ def set_point_note(self, metric_name: str, experiment_hash: str, model_age: int, normalized_step = int(model_age) cleaned_note = str(note or "").strip() - updated = False - - entries = ( - self._signal_history.get(metric_name, {}) - .get(experiment_hash, {}) - .get(normalized_step, []) - ) - for entry in entries: - if not isinstance(entry, dict): - continue - if cleaned_note: - entry["point_note"] = cleaned_note - else: - entry.pop("point_note", None) - updated = True - for entry in self._pending_queue: - if not isinstance(entry, dict): - continue - if str(entry.get("metric_name", "")) != metric_name: - continue - if str(entry.get("experiment_hash", "")) != experiment_hash: - continue - try: - if int(entry.get("model_age", -1)) != normalized_step: + with self._lock: + self._flush_stage() + matched = self._conn.execute( + "SELECT count(*) FROM signals " + "WHERE metric_name = ? AND experiment_hash = ? AND step = ?", + [metric_name, experiment_hash, normalized_step], + ).fetchone()[0] + if matched: + self._conn.execute( + "UPDATE signals SET point_note = ? " + "WHERE metric_name = ? AND experiment_hash = ? AND step = ?", + [cleaned_note, metric_name, experiment_hash, normalized_step], + ) + + for entry in self._pending_queue: + if not isinstance(entry, dict): continue - except Exception: - continue - if cleaned_note: - entry["point_note"] = cleaned_note - else: - entry.pop("point_note", None) + if str(entry.get("metric_name", "")) != metric_name: + continue + if str(entry.get("experiment_hash", "")) != experiment_hash: + continue + try: + if int(entry.get("model_age", -1)) != normalized_step: + continue + except Exception: + continue + if cleaned_note: + entry["point_note"] = cleaned_note + else: + entry.pop("point_note", None) - return updated + return bool(matched) - # Logger saving/loading methods for checkpoint persistence (used in WeightsLabCallback) + # ------------------------------------------------------------------ + # Snapshot loading (checkpoint persistence) + # ------------------------------------------------------------------ def load_signal_history(self, signals): - """Load signal history into memory (supports legacy and nested formats).""" + """Load aggregated signal history (supports legacy list and nested dict).""" if not signals: return - def _append_signal_entry(metric_name, exp_hash, step, signal_entry): - if metric_name not in self._signal_history: - self._signal_history[metric_name] = {} - if exp_hash not in self._signal_history[metric_name]: - self._signal_history[metric_name][exp_hash] = {} - if step not in self._signal_history[metric_name][exp_hash]: - self._signal_history[metric_name][exp_hash][step] = [] - self._signal_history[metric_name][exp_hash][step].append(signal_entry) + def _stage_entry(metric_name, exp_hash, step, entry): + try: + step_i = int(step) + except (TypeError, ValueError): + return + with self._lock: + self._stage_signal_row( + metric_name, exp_hash, step_i, + float(entry.get("metric_value", 0.0)), + int(entry.get("timestamp", int(time.time()))), + bool(entry.get("audit_mode", False)), + bool(entry.get("is_evaluation_marker", False)), + entry.get("split_name", ""), + entry.get("evaluation_tags", []), + entry.get("point_note", "") or "", + ) if isinstance(signals, dict): for metric_name, experiments in signals.items(): @@ -895,23 +988,10 @@ def _append_signal_entry(metric_name, exp_hash, step, signal_entry): if not isinstance(steps, dict): continue for step_key, entries in steps.items(): - step = step_key - if isinstance(step_key, str): - try: - step = int(step_key) - except Exception: - step = step_key - entries_list = entries if isinstance(entries, list) else [entries] for entry in entries_list: - if not isinstance(entry, dict): - continue - signal_entry = dict(entry) - signal_entry.setdefault("metric_name", metric_name) - signal_entry.setdefault("model_age", step) - signal_entry.setdefault("experiment_hash", exp_hash) - signal_entry.setdefault("timestamp", int(time.time())) - _append_signal_entry(metric_name, exp_hash, step, signal_entry) + if isinstance(entry, dict): + _stage_entry(metric_name, exp_hash, step_key, entry) return if isinstance(signals, list): @@ -921,138 +1001,100 @@ def _append_signal_entry(metric_name, exp_hash, step, signal_entry): metric_name = signal.get("metric_name") if not metric_name: continue - exp_hash = signal.get("experiment_hash") - step = signal.get("model_age") - signal_entry = dict(signal) - signal_entry.setdefault("metric_name", metric_name) - signal_entry.setdefault("model_age", step) - signal_entry.setdefault("experiment_hash", exp_hash) - signal_entry.setdefault("timestamp", int(time.time())) self.graph_names.add(metric_name) - _append_signal_entry(metric_name, exp_hash, step, signal_entry) + _stage_entry( + metric_name, + signal.get("experiment_hash"), + signal.get("model_age", 0), + signal, + ) def load_signal_history_per_sample(self, signals_per_sample): - """Load per-sample history into compact array storage. + """Load per-sample history. Handles three formats: - - New compact: {graph_name: {exp_hash: {"_compact": True, "sample_ids": [...], "steps": [...], "values": [...]}}} - - Legacy list: {graph_name: {exp_hash: [{sample_id, model_age, metric_value, ...}, ...]}} - - Legacy dict: {graph_name: {sample_id_as_key: {model_age, metric_value, ...}}} → stored under None key + - Compact: {graph: {hash: {"_compact": True, "sample_ids": [...], "steps": [...], "values": [...]}}} + - Legacy list: {graph: {hash: [{sample_id, model_age, metric_value, ...}, ...]}} + - Legacy dict: {graph: {sample_id_as_key: {model_age, metric_value, ...}}} → stored under None hash """ if not signals_per_sample: return for metric_name, samples_by_exp in signals_per_sample.items(): self.graph_names.add(metric_name) - if metric_name not in self._signal_history_per_sample: - self._signal_history_per_sample[metric_name] = {} - if not isinstance(samples_by_exp, dict): continue for exp_hash, entries in samples_by_exp.items(): - # --- New compact format --- + # --- Compact format --- if isinstance(entries, dict) and entries.get("_compact"): - if exp_hash not in self._signal_history_per_sample[metric_name]: - self._signal_history_per_sample[metric_name][exp_hash] = _make_per_sample_buf() - buf = self._signal_history_per_sample[metric_name][exp_hash] - ids = entries.get("sample_ids", []) + ids = entries.get("sample_ids", []) steps = entries.get("steps", []) - vals = entries.get("values", []) - idx_map = self._sample_index.setdefault(metric_name, {}).setdefault(exp_hash, {}) - for s, t, v in zip(ids, steps, vals): - try: - row = len(buf["sample_ids"]) - sid_s = str(s) - buf["sample_ids"].append(sid_s) - buf["steps"].append(int(t)) - buf["values"].append(float(v)) - idx_map.setdefault(sid_s, []).append(row) - except (TypeError, ValueError): - pass + vals = entries.get("values", []) + with self._lock: + for s, t, v in zip(ids, steps, vals): + try: + self._stage_sample_row(metric_name, exp_hash, s, int(t), float(v)) + except (TypeError, ValueError): + pass - # --- Legacy list-of-dicts format --- + # --- Legacy list-of-dicts --- elif isinstance(entries, list): - if exp_hash not in self._signal_history_per_sample[metric_name]: - self._signal_history_per_sample[metric_name][exp_hash] = _make_per_sample_buf() - buf = self._signal_history_per_sample[metric_name][exp_hash] - idx_map = self._sample_index.setdefault(metric_name, {}).setdefault(exp_hash, {}) - for entry in entries: - if not isinstance(entry, dict): - continue + with self._lock: + for entry in entries: + if not isinstance(entry, dict): + continue + try: + self._stage_sample_row( + metric_name, exp_hash, + entry.get("sample_id", -1), + int(entry.get("model_age", 0)), + float(entry.get("metric_value", 0.0)), + ) + except (TypeError, ValueError): + pass + + # --- Legacy single-dict (exp_hash key was actually the sample_id) --- + elif isinstance(entries, dict): + sid = str(exp_hash) if isinstance(exp_hash, (int, float)) else str(-1) + with self._lock: try: - row = len(buf["sample_ids"]) - sid_s = str(entry.get("sample_id", -1)) - buf["sample_ids"].append(sid_s) - buf["steps"].append(int(entry.get("model_age", 0))) - buf["values"].append(float(entry.get("metric_value", 0.0))) - idx_map.setdefault(sid_s, []).append(row) + self._stage_sample_row( + metric_name, None, sid, + int(entries.get("model_age", 0)), + float(entries.get("metric_value", 0.0)), + ) except (TypeError, ValueError): pass - # --- Legacy single-dict format (exp_hash key was actually the sample_id) --- - elif isinstance(entries, dict): - null_key = None - if null_key not in self._signal_history_per_sample[metric_name]: - self._signal_history_per_sample[metric_name][null_key] = _make_per_sample_buf() - buf = self._signal_history_per_sample[metric_name][null_key] - idx_map = self._sample_index.setdefault(metric_name, {}).setdefault(null_key, {}) - try: - row = len(buf["sample_ids"]) - sid = str(exp_hash) if isinstance(exp_hash, (int, float)) else str(-1) - buf["sample_ids"].append(sid) - buf["steps"].append(int(entries.get("model_age", 0))) - buf["values"].append(float(entries.get("metric_value", 0.0))) - idx_map.setdefault(sid, []).append(row) - except (TypeError, ValueError): - pass - def load_signal_history_per_instance(self, signals_per_instance: dict) -> None: """Load per-instance history from a compact snapshot dict.""" if not signals_per_instance: return for metric_name, exps in signals_per_instance.items(): self.graph_names.add(metric_name) - if metric_name not in self._signal_history_per_instance: - self._signal_history_per_instance[metric_name] = {} if not isinstance(exps, dict): continue for exp_hash, entries in exps.items(): if not (isinstance(entries, dict) and entries.get("_compact")): continue - if exp_hash not in self._signal_history_per_instance[metric_name]: - self._signal_history_per_instance[metric_name][exp_hash] = _make_per_instance_buf() - buf = self._signal_history_per_instance[metric_name][exp_hash] - ids = entries.get("sample_ids", []) + ids = entries.get("sample_ids", []) aids = entries.get("annotation_ids", []) steps = entries.get("steps", []) - vals = entries.get("values", []) - idx_map = self._instance_index.setdefault(metric_name, {}).setdefault(exp_hash, {}) - for s, a, t, v in zip(ids, aids, steps, vals): - try: - row = len(buf["sample_ids"]) - sid_s, aid_i = str(s), int(a) - buf["sample_ids"].append(sid_s) - buf["annotation_ids"].append(aid_i) - buf["steps"].append(int(t)) - buf["values"].append(float(v)) - idx_map.setdefault((sid_s, aid_i), []).append(row) - except (TypeError, ValueError): - pass + vals = entries.get("values", []) + with self._lock: + for s, a, t, v in zip(ids, aids, steps, vals): + try: + self._stage_instance_row(metric_name, exp_hash, s, int(a), int(t), float(v)) + except (TypeError, ValueError): + pass def load_snapshot(self, snapshot: dict): """Restore logger state from a snapshot dict.""" if not snapshot: return - graph_names = snapshot.get("graph_names", []) - self.graph_names.update(graph_names) - - signals = snapshot.get("signal_history", []) - self.load_signal_history(signals) - - signals_per_sample = snapshot.get("signal_history_per_sample", {}) - self.load_signal_history_per_sample(signals_per_sample) - - signals_per_instance = snapshot.get("signal_history_per_instance", {}) - self.load_signal_history_per_instance(signals_per_instance) + self.graph_names.update(snapshot.get("graph_names", [])) + self.load_signal_history(snapshot.get("signal_history", [])) + self.load_signal_history_per_sample(snapshot.get("signal_history_per_sample", {})) + self.load_signal_history_per_instance(snapshot.get("signal_history_per_instance", {})) diff --git a/weightslab/src.py b/weightslab/src.py index d3600fc6..dde692d1 100644 --- a/weightslab/src.py +++ b/weightslab/src.py @@ -3123,7 +3123,7 @@ def query_sample_history( names = ( [signal_name] if signal_name - else list(_lg._signal_history_per_sample.keys()) + else _lg.list_sample_signal_names() ) results = [] for name in names: @@ -3156,7 +3156,7 @@ def query_instance_history( names = ( [signal_name] if signal_name - else list(_lg._signal_history_per_instance.keys()) + else _lg.list_instance_signal_names() ) results = [] for name in names: @@ -3352,7 +3352,7 @@ def write_history( instance_rows: list = [] if write_global: - for gn, hashes in _lg._signal_history.items(): + for gn, hashes in _lg.get_signal_history().items(): if _gn_filter is not None and gn not in _gn_filter: continue for h, steps in hashes.items(): @@ -3378,7 +3378,7 @@ def write_history( graphs_s = ( list(_gn_filter) if _gn_filter is not None - else list(_lg._signal_history_per_sample.keys()) + else _lg.list_sample_signal_names() ) for gn in graphs_s: for sid, step, val, h in _lg.query_per_sample( @@ -3400,7 +3400,7 @@ def write_history( graphs_i = ( list(_gn_filter) if _gn_filter is not None - else list(_lg._signal_history_per_instance.keys()) + else _lg.list_instance_signal_names() ) # query_per_instance filters by a single (sample_id, annotation_id); iterate when multiple given _sid_iter = _sid_filter if _sid_filter is not None else [None] diff --git a/weightslab/tests/backend/test_instance_signal_logger.py b/weightslab/tests/backend/test_instance_signal_logger.py index 60e6613c..9d4b8dcc 100644 --- a/weightslab/tests/backend/test_instance_signal_logger.py +++ b/weightslab/tests/backend/test_instance_signal_logger.py @@ -341,8 +341,9 @@ def test_sample_index_built_on_add(self): lg = _fresh_logger() lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"img0": 0.5, "img1": 0.3}, aggregate_by_step=False) - idx = lg._sample_index.get("loss", {}) - self.assertTrue(any("img0" in h_idx for h_idx in idx.values())) + rows = lg.query_per_sample("loss", sample_ids=["img0"]) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], "img0") def test_sample_index_points_to_correct_rows(self): lg = _fresh_logger() @@ -350,28 +351,24 @@ def test_sample_index_points_to_correct_rows(self): signal_per_sample={"img0": 0.5}, aggregate_by_step=False) lg.add_scalars("loss", {"loss": 0.3}, 2, signal_per_sample={"img0": 0.3}, aggregate_by_step=False) - # img0 appears twice — both rows should be indexed - h = list(lg._sample_index["loss"].keys())[0] - rows = lg._sample_index["loss"][h]["img0"] + # img0 appears twice — both rows should be returned, at steps 1 and 2 + rows = lg.query_per_sample("loss", sample_ids=["img0"]) self.assertEqual(len(rows), 2) - buf = list(lg._signal_history_per_sample["loss"].values())[0] - self.assertEqual(buf["steps"][rows[0]], 1) - self.assertEqual(buf["steps"][rows[1]], 2) + self.assertEqual({r[1] for r in rows}, {1, 2}) def test_instance_index_built_on_add(self): lg = _fresh_logger() lg.add_instance_scalars("iou", ["s0", "s0", "s1"], [1, 2, 1], [0.9, 0.8, 0.7], 5, "h1") - idx = lg._instance_index["iou"]["h1"] - self.assertIn(("s0", 1), idx) - self.assertIn(("s0", 2), idx) - self.assertIn(("s1", 1), idx) + keys = {(r[0], r[1]) for r in lg.query_per_instance("iou")} + self.assertIn(("s0", 1), keys) + self.assertIn(("s0", 2), keys) + self.assertIn(("s1", 1), keys) def test_instance_index_points_to_correct_values(self): lg = _fresh_logger() lg.add_instance_scalars("iou", ["s0", "s0"], [1, 1], [0.9, 0.8], 5, "h1") - # Same (s0, 1) at two different steps → two rows - idx = lg._instance_index["iou"]["h1"] - rows = idx[("s0", 1)] + # Same (s0, 1) recorded twice → two rows returned + rows = lg.query_per_instance("iou", sample_id="s0", annotation_id=1) self.assertEqual(len(rows), 2) def test_sample_index_rebuilt_after_snapshot_load(self): @@ -381,8 +378,8 @@ def test_sample_index_rebuilt_after_snapshot_load(self): snap = lg.save_snapshot() lg2 = _fresh_logger() lg2.load_snapshot(snap) - self.assertIn("img0", list(lg2._sample_index.get("loss", {}).values())[0]) - self.assertIn("img1", list(lg2._sample_index.get("loss", {}).values())[0]) + self.assertEqual(len(lg2.query_per_sample("loss", sample_ids=["img0"])), 1) + self.assertEqual(len(lg2.query_per_sample("loss", sample_ids=["img1"])), 1) def test_instance_index_rebuilt_after_snapshot_load(self): lg = _fresh_logger() @@ -390,9 +387,9 @@ def test_instance_index_rebuilt_after_snapshot_load(self): snap = lg.save_snapshot() lg2 = _fresh_logger() lg2.load_snapshot(snap) - idx = lg2._instance_index.get("iou", {}).get("h1", {}) - self.assertIn(("s0", 1), idx) - self.assertIn(("s1", 2), idx) + keys = {(r[0], r[1]) for r in lg2.query_per_instance("iou", exp_hash="h1")} + self.assertIn(("s0", 1), keys) + self.assertIn(("s1", 2), keys) def test_clear_signal_histories_also_clears_indices(self): lg = _fresh_logger() @@ -400,8 +397,8 @@ def test_clear_signal_histories_also_clears_indices(self): signal_per_sample={"img0": 0.5}, aggregate_by_step=False) lg.add_instance_scalars("iou", ["s0"], [1], [0.8], 1, "h1") lg.clear_signal_histories() - self.assertEqual(lg._sample_index, {}) - self.assertEqual(lg._instance_index, {}) + self.assertEqual(lg.query_per_sample("loss"), []) + self.assertEqual(lg.query_per_instance("iou"), []) def test_query_uses_index_not_full_scan(self): """query_per_sample with filter returns correct results via index path.""" @@ -422,10 +419,10 @@ def test_eval_mode_also_updates_sample_index(self): lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"imgA": 0.5, "imgB": 0.3}, aggregate_by_step=False) lg._eval_mode_active = False - idx = lg._sample_index.get("loss", {}).get("eval_h1", {}) - self.assertIn("imgA", idx) - self.assertIn("imgB", idx) - # Query must find them + # Per-sample data was written under the eval hash and is queryable + under_eval = {r[0] for r in lg.query_per_sample("loss", exp_hash="eval_h1")} + self.assertIn("imgA", under_eval) + self.assertIn("imgB", under_eval) rows = lg.query_per_sample("loss", sample_ids=["imgA"]) self.assertEqual(len(rows), 1) @@ -445,9 +442,9 @@ def test_legacy_list_of_dicts_snapshot_rebuilds_index(self): }, } lg.load_snapshot(legacy_snap) - idx = lg._sample_index.get("loss", {}).get("h1", {}) - self.assertIn("img0", idx) - self.assertIn("img1", idx) + under_h1 = {r[0] for r in lg.query_per_sample("loss", exp_hash="h1")} + self.assertIn("img0", under_h1) + self.assertIn("img1", under_h1) rows = lg.query_per_sample("loss", sample_ids=["img0"]) self.assertEqual(len(rows), 1) @@ -456,14 +453,8 @@ def test_multi_exp_hash_filter(self): lg = _fresh_logger() lg.add_scalars("loss", {"loss": 0.1}, 1, signal_per_sample={"s0": 0.1}, aggregate_by_step=False) - # Manually inject a second hash entry to simulate two runs - from array import array as _array - lg._signal_history_per_sample["loss"]["h2"] = { - "sample_ids": ["s0"], - "steps": _array('i', [1]), - "values": _array('f', [0.9]), - } - lg._sample_index.setdefault("loss", {}).setdefault("h2", {})["s0"] = [0] + # Add a second run's data under hash "h2" + lg.ingest_per_sample("loss", "h2", [("s0", 1, 0.9)]) rows_h2 = lg.query_per_sample("loss", sample_ids=["s0"], exp_hash="h2") self.assertEqual(len(rows_h2), 1) self.assertAlmostEqual(rows_h2[0][2], 0.9, places=4) diff --git a/weightslab/tests/backend/test_logger_core.py b/weightslab/tests/backend/test_logger_core.py index 867f6d37..9443e571 100644 --- a/weightslab/tests/backend/test_logger_core.py +++ b/weightslab/tests/backend/test_logger_core.py @@ -39,6 +39,17 @@ def _add(lg, sig, sid, step, val, aggregate_by_step=False): aggregate_by_step=aggregate_by_step) +def _seed_eval_hash(lg, exp_hash, sig="loss", val=0.5, step=1): + """Write an aggregated marker under *exp_hash* via the evaluation lifecycle. + + Replaces white-box seeding of the (now DuckDB-backed) history dict. + """ + lg.start_evaluation_mode("val", exp_hash) + lg.add_scalars(sig, {sig: val}, step, + signal_per_sample=None, aggregate_by_step=False) + lg.stop_evaluation_mode(model_age=step) + + # --------------------------------------------------------------------------- # 1. __len__ # --------------------------------------------------------------------------- @@ -82,22 +93,24 @@ def test_no_existing_evals_returns_1(self): def test_existing_h1_1_returns_2(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {}} + _seed_eval_hash(lg, "h1_1") self.assertEqual(lg.get_next_evaluation_count("h1"), 2) def test_existing_h1_1_and_h1_3_returns_4(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {}, "h1_3": {}} + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1_3") self.assertEqual(lg.get_next_evaluation_count("h1"), 4) def test_non_int_suffix_is_ignored(self): lg = _lg() - lg._signal_history["loss"] = {"h1_abc": {}, "h1_1": {}} + _seed_eval_hash(lg, "h1_abc") + _seed_eval_hash(lg, "h1_1") self.assertEqual(lg.get_next_evaluation_count("h1"), 2) def test_different_base_hash_not_counted(self): lg = _lg() - lg._signal_history["loss"] = {"h2_5": {}} + _seed_eval_hash(lg, "h2_5") self.assertEqual(lg.get_next_evaluation_count("h1"), 1) @@ -126,7 +139,8 @@ def test_add_scalars_during_eval_goes_to_accum_not_history(self): lg.add_scalars("loss", {"loss": 0.4}, 10, signal_per_sample=None, aggregate_by_step=False) self.assertIn("loss", lg._eval_accum) - self.assertNotIn("loss", lg._signal_history) + # Nothing written to the aggregated history during evaluation mode + self.assertEqual(lg.get_signal_history(), {}) def test_add_scalars_during_eval_accumulates_values(self): lg = _lg() @@ -147,15 +161,15 @@ def test_stop_computes_mean_and_writes_history(self): results = lg.stop_evaluation_mode(model_age=10) self.assertIn("loss", results) self.assertAlmostEqual(results["loss"], 0.5, places=5) - # Written into _signal_history under eval_hash - self.assertIn("h1_1", lg._signal_history.get("loss", {})) + # Written into history under eval_hash + self.assertIn("h1_1", lg.get_signal_history().get("loss", {})) def test_stop_emits_is_evaluation_marker(self): lg = _lg() lg.start_evaluation_mode("val", "h1_1") lg.add_scalars("loss", {"loss": 0.5}, 10, signal_per_sample=None, aggregate_by_step=False) lg.stop_evaluation_mode(model_age=10) - entries = lg._signal_history["loss"]["h1_1"][10] + entries = lg.get_signal_history()["loss"]["h1_1"][10] self.assertTrue(entries[0].get("is_evaluation_marker")) def test_stop_adds_to_pending_queue(self): @@ -190,7 +204,7 @@ def test_stop_stores_split_name_and_tags(self): lg.start_evaluation_mode("val", "h1_1", evaluation_tags=["hard", "easy"]) lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample=None, aggregate_by_step=False) lg.stop_evaluation_mode(model_age=1) - entry = lg._signal_history["loss"]["h1_1"][1][0] + entry = lg.get_signal_history()["loss"]["h1_1"][1][0] self.assertEqual(entry["split_name"], "val") self.assertEqual(entry["evaluation_tags"], ["hard", "easy"]) @@ -230,7 +244,7 @@ def test_abort_removes_per_sample_written_during_eval(self): signal_per_sample={"img0": 0.5}, aggregate_by_step=True) lg.abort_evaluation_mode() # Per-sample history under "h1_1" should be gone - self.assertNotIn("h1_1", lg._signal_history_per_sample.get("loss", {})) + self.assertEqual(lg.query_per_sample("loss", exp_hash="h1_1"), []) def test_abort_removes_queue_entries_for_eval_hash(self): lg = _lg() @@ -253,21 +267,24 @@ class TestRemoveEvaluationHash(unittest.TestCase): def test_removes_from_signal_history(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {1: []}, "h1": {1: []}} + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1") lg.remove_evaluation_hash("h1_1") - self.assertNotIn("h1_1", lg._signal_history["loss"]) - self.assertIn("h1", lg._signal_history["loss"]) + hist = lg.get_signal_history()["loss"] + self.assertNotIn("h1_1", hist) + self.assertIn("h1", hist) def test_removes_from_per_sample_history(self): lg = _lg() _add(lg, "loss", "s0", 1, 0.5) - # manually inject an eval hash entry - from array import array as _array - lg._signal_history_per_sample["loss"]["h1_1"] = { - "sample_ids": ["s0"], "steps": _array('i', [1]), "values": _array('f', [0.5]) - } + # write per-sample data under an eval hash via evaluation mode + lg.start_evaluation_mode("val", "h1_1") + lg.add_scalars("loss", {"loss": 0.5}, 1, + signal_per_sample={"s0": 0.5}, aggregate_by_step=True) + lg._eval_mode_active = False + self.assertEqual(len(lg.query_per_sample("loss", exp_hash="h1_1")), 1) lg.remove_evaluation_hash("h1_1") - self.assertNotIn("h1_1", lg._signal_history_per_sample["loss"]) + self.assertEqual(lg.query_per_sample("loss", exp_hash="h1_1"), []) def test_removes_matching_entries_from_queue(self): lg = _lg() @@ -281,9 +298,9 @@ def test_removes_matching_entries_from_queue(self): def test_empty_hash_is_noop(self): lg = _lg() - lg._signal_history["loss"] = {"h1": {}} + _seed_eval_hash(lg, "h1") lg.remove_evaluation_hash("") - self.assertIn("h1", lg._signal_history["loss"]) + self.assertIn("h1", lg.get_signal_history()["loss"]) def test_missing_hash_does_not_raise(self): lg = _lg() @@ -300,8 +317,9 @@ def test_immediate_mode_writes_to_history(self): lg = _lg() lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"s0": 0.5}, aggregate_by_step=False) - self.assertIn(None, lg._signal_history.get("loss", {})) - self.assertEqual(lg._signal_history["loss"][None][1][0]["metric_value"], 0.5) + hist = lg.get_signal_history() + self.assertIn(None, hist.get("loss", {})) + self.assertEqual(hist["loss"][None][1][0]["metric_value"], 0.5) def test_immediate_mode_adds_to_queue(self): lg = _lg() @@ -314,7 +332,7 @@ def test_aggregate_mode_buffers_not_writes(self): lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"s0": 0.5}, aggregate_by_step=True) # Not in history yet — buffered - self.assertNotIn("loss", lg._signal_history) + self.assertEqual(lg.get_signal_history(), {}) self.assertIn((1, "loss", None), lg._current_step_buffer) def test_aggregate_mode_step_change_flushes_to_history(self): @@ -327,7 +345,7 @@ def test_aggregate_mode_step_change_flushes_to_history(self): lg.add_scalars("loss", {"loss": 0.9}, 2, signal_per_sample={"s0": 0.9}, aggregate_by_step=True) # Step 1 should now be averaged in history - entries = lg._signal_history["loss"][None][1] + entries = lg.get_signal_history()["loss"][None][1] self.assertAlmostEqual(entries[0]["metric_value"], 0.5, places=5) def test_aggregate_mode_averages_multiple_calls_same_step(self): @@ -338,7 +356,7 @@ def test_aggregate_mode_averages_multiple_calls_same_step(self): # Force flush lg.add_scalars("acc", {"acc": 0.9}, 2, signal_per_sample=None, aggregate_by_step=False) - entries = lg._signal_history["loss"][None][1] + entries = lg.get_signal_history()["loss"][None][1] self.assertAlmostEqual(entries[0]["metric_value"], 0.4, places=5) def test_per_sample_written_even_in_aggregate_mode(self): @@ -402,9 +420,9 @@ def test_dedup_does_not_corrupt_index(self): lg = _lg() lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.4)]) lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # duplicate ignored - # index should still point to exactly 1 row - idx = lg._sample_index["loss"]["h1"]["s0"] - self.assertEqual(len(idx), 1) + # exactly one row should remain queryable for (s0, h1) + rows = lg.query_per_sample("loss", sample_ids=["s0"], exp_hash="h1") + self.assertEqual(len(rows), 1) # --------------------------------------------------------------------------- @@ -479,9 +497,10 @@ def test_returns_deepcopy(self): lg = _lg() _add(lg, "loss", "s0", 1, 0.5) hist = lg.get_signal_history() - # Mutate the copy — internal state must not change + # Mutate the returned copy — a fresh read must not reflect the mutation hist["loss"][None][1][0]["metric_value"] = 999.0 - self.assertNotEqual(lg._signal_history["loss"][None][1][0]["metric_value"], 999.0) + fresh = lg.get_signal_history() + self.assertNotEqual(fresh["loss"][None][1][0]["metric_value"], 999.0) def test_empty_when_nothing_added(self): lg = _lg() @@ -531,7 +550,9 @@ def test_empty_when_no_eval(self): def test_returns_eval_hashes(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {}, "h1_2": {}, "h1": {}} + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1_2") + _seed_eval_hash(lg, "h1") hashes = lg.get_evaluation_marker_hashes() self.assertIn("h1_1", hashes) self.assertIn("h1_2", hashes) @@ -539,12 +560,15 @@ def test_returns_eval_hashes(self): def test_returns_sorted(self): lg = _lg() - lg._signal_history["loss"] = {"h1_3": {}, "h1_1": {}, "h1_2": {}} + _seed_eval_hash(lg, "h1_3") + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1_2") self.assertEqual(lg.get_evaluation_marker_hashes(), ["h1_1", "h1_2", "h1_3"]) def test_non_int_suffix_excluded(self): lg = _lg() - lg._signal_history["loss"] = {"h1_abc": {}, "h1_1": {}} + _seed_eval_hash(lg, "h1_abc") + _seed_eval_hash(lg, "h1_1") hashes = lg.get_evaluation_marker_hashes() self.assertNotIn("h1_abc", hashes) self.assertIn("h1_1", hashes) @@ -622,7 +646,7 @@ def test_set_note_on_history_entry(self): _add(lg, "loss", "s0", 5, 0.4) result = lg.set_point_note("loss", "run1", 5, "my note") self.assertTrue(result) - entry = lg._signal_history["loss"]["run1"][5][0] + entry = lg.get_signal_history()["loss"]["run1"][5][0] self.assertEqual(entry["point_note"], "my note") def test_clear_note_with_empty_string(self): @@ -630,7 +654,7 @@ def test_clear_note_with_empty_string(self): _add(lg, "loss", "s0", 5, 0.4) lg.set_point_note("loss", "run1", 5, "my note") lg.set_point_note("loss", "run1", 5, "") - entry = lg._signal_history["loss"]["run1"][5][0] + entry = lg.get_signal_history()["loss"]["run1"][5][0] self.assertNotIn("point_note", entry) def test_updates_pending_queue_entry(self): @@ -671,9 +695,10 @@ def test_dict_format_loads_correctly(self): } } }) - self.assertIn("loss", lg._signal_history) - self.assertIn("h1", lg._signal_history["loss"]) - self.assertEqual(lg._signal_history["loss"]["h1"][1][0]["metric_value"], 0.5) + hist = lg.get_signal_history() + self.assertIn("loss", hist) + self.assertIn("h1", hist["loss"]) + self.assertEqual(hist["loss"]["h1"][1][0]["metric_value"], 0.5) def test_dict_format_string_step_key_converted_to_int(self): lg = _lg() @@ -685,7 +710,7 @@ def test_dict_format_string_step_key_converted_to_int(self): } } }) - self.assertIn(42, lg._signal_history["loss"]["h1"]) + self.assertIn(42, lg.get_signal_history()["loss"]["h1"]) def test_list_format_loads_correctly(self): lg = _lg() @@ -693,21 +718,22 @@ def test_list_format_loads_correctly(self): {"metric_name": "acc", "experiment_hash": "h1", "model_age": 3, "metric_value": 0.9, "timestamp": 0}, ]) - self.assertIn("acc", lg._signal_history) - self.assertEqual(lg._signal_history["acc"]["h1"][3][0]["metric_value"], 0.9) + hist = lg.get_signal_history() + self.assertIn("acc", hist) + self.assertEqual(hist["acc"]["h1"][3][0]["metric_value"], 0.9) def test_list_format_skips_entries_without_metric_name(self): lg = _lg() lg.load_signal_history([ {"experiment_hash": "h1", "model_age": 1, "metric_value": 0.5}, ]) - self.assertEqual(lg._signal_history, {}) + self.assertEqual(lg.get_signal_history(), {}) def test_empty_input_is_noop(self): lg = _lg() lg.load_signal_history({}) lg.load_signal_history([]) - self.assertEqual(lg._signal_history, {}) + self.assertEqual(lg.get_signal_history(), {}) def test_adds_to_graph_names(self): lg = _lg() @@ -723,7 +749,7 @@ def test_missing_fields_get_defaults(self): {"metric_name": "loss", "model_age": 5, "metric_value": 0.1}, ]) # experiment_hash defaults to None - self.assertIn(None, lg._signal_history["loss"]) + self.assertIn(None, lg.get_signal_history()["loss"]) if __name__ == "__main__": diff --git a/weightslab/tests/gRPC/test_grpc_user_actions.py b/weightslab/tests/gRPC/test_grpc_user_actions.py index f9b655c2..fb405e83 100644 --- a/weightslab/tests/gRPC/test_grpc_user_actions.py +++ b/weightslab/tests/gRPC/test_grpc_user_actions.py @@ -531,16 +531,22 @@ def test_break_by_slices_from_tags_filters_expected_sample(self): {"tag:hard": [True, False]}, index=[11, 12], ) - # break-by-slices reads compact (sample_id, step, value, hash) tuples via - # query_per_sample (filtered by the tag-derived sample_ids), then aggregates - # the matching samples into a single MEAN curve per experiment_hash. + # break-by-slices aggregates the tag-derived sample_ids into a single MEAN + # curve per experiment_hash via aggregate_per_sample_by_step. _pts = [("11", 5, 0.2, "exp-1"), ("12", 5, 0.8, "exp-1")] - def _qps(graph_name, sample_ids=None, exp_hash=None): + def _agg(graph_name, sample_ids=None, exp_hash=None): wanted = {str(s) for s in sample_ids} if sample_ids is not None else None - return [t for t in _pts if wanted is None or str(t[0]) in wanted] + rows = [t for t in _pts if wanted is None or str(t[0]) in wanted] + by_hash: dict = {} + for sid, step, val, h in rows: + by_hash.setdefault(h, {}).setdefault(step, []).append(val) + return { + h: sorted((s, sum(v) / len(v)) for s, v in steps.items()) + for h, steps in by_hash.items() + } - signal_logger.query_per_sample.side_effect = _qps + signal_logger.aggregate_per_sample_by_step.side_effect = _agg signal_logger.get_evaluation_marker_hashes.return_value = [] servicer = ExperimentServiceServicer(exp_service=exp_service) diff --git a/weightslab/tests/trainer/services/test_trainer_services_unit.py b/weightslab/tests/trainer/services/test_trainer_services_unit.py index 01c965c9..3333489e 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_unit.py +++ b/weightslab/tests/trainer/services/test_trainer_services_unit.py @@ -77,16 +77,23 @@ def test_get_latest_logger_data_queue_mode(self): ) def test_get_latest_logger_data_break_by_slices(self): signal_logger = MagicMock() - # break-by-slices reads compact (sample_id, step, value, hash) tuples via - # query_per_sample (filtered by the tag-derived sample_ids), then aggregates - # the matching samples into a single MEAN curve per experiment_hash. + # break-by-slices aggregates the tag-derived sample_ids into a single MEAN + # curve per experiment_hash via aggregate_per_sample_by_step (DuckDB does the + # GROUP BY step / AVG natively). _pts = [("11", 3, 0.3, "exp"), ("12", 3, 0.6, "exp")] - def _qps(graph_name, sample_ids=None, exp_hash=None): + def _agg(graph_name, sample_ids=None, exp_hash=None): wanted = {str(s) for s in sample_ids} if sample_ids is not None else None - return [t for t in _pts if wanted is None or str(t[0]) in wanted] + rows = [t for t in _pts if wanted is None or str(t[0]) in wanted] + by_hash: dict = {} + for sid, step, val, h in rows: + by_hash.setdefault(h, {}).setdefault(step, []).append(val) + return { + h: sorted((s, sum(v) / len(v)) for s, v in steps.items()) + for h, steps in by_hash.items() + } - signal_logger.query_per_sample.side_effect = _qps + signal_logger.aggregate_per_sample_by_step.side_effect = _agg signal_logger.get_evaluation_marker_hashes.return_value = [] df_manager = MagicMock() df_manager.get_df_view.return_value = pd.DataFrame( From cd068d6a4badfca5ca7997c4c93e6c8c7b119236 Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Fri, 19 Jun 2026 16:50:59 +0200 Subject: [PATCH 06/16] Fix duplicate SaveCheckpointOperation in proto; document Studio feature toggles The branch sync left two identical `message SaveCheckpointOperation` definitions, which makes protoc reject the file. Remove the redundant first copy (keeping the documented "manual save now" one) and regenerate the Python descriptor. The grpc stub is unchanged (no rpc was duplicated). Docs: add the Weights Studio feature toggles (ENABLE_PLOTS, ENABLE_DATA_EXPLORATION, ENABLE_HYPERPARAMETERS_OPTIMIZATION, ENABLE_AGENT) to configuration.rst and the AGENTS.md frontend env-var table. Co-Authored-By: Claude Opus 4.8 (1M context) --- AGENTS.md | 17 +++++-- docs/configuration.rst | 56 ++++++++++++++++++++ weightslab/proto/experiment_service.proto | 5 -- weightslab/proto/experiment_service_pb2.py | 59 ---------------------- 4 files changed, 68 insertions(+), 69 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 50773f7c..8b8eeda5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -155,11 +155,18 @@ ones when debugging: | `WS_HISTOGRAM_MAX_BINS` | `512` | Cap on metadata histogram bars. | | `BB_THUMB_RENDER` | `10` | Max bounding boxes drawn per **thumbnail**, per overlay (GT and PRED capped independently). | | `BB_MODAL_RENDER` | `100` | Max bounding boxes drawn per **modal** image, per overlay. A `?` button in the modal shows the active limit. | - -> **VITE_ vs WS_/BB_:** `VITE_*` variables are baked at **build time** (changing -> them needs a rebuild). `WS_*` / `BB_*` are injected at **container start** into -> `config.js` and read as `window.*` globals — changing them needs only a -> container restart + browser reload (see the caching note in §5). +| `ENABLE_PLOTS` | `1` | `0`/`false` removes the plots board + Signals card and stops plot auto-refresh. | +| `ENABLE_DATA_EXPLORATION` | `1` | `0`/`false` removes the data grid + metadata/details panel and stops the data/metadata auto-refresh. | +| `ENABLE_HYPERPARAMETERS_OPTIMIZATION` | `1` | `0`/`false` removes the Hyperparameters section, makes HP inputs read-only, and stops the HP poll. | +| `ENABLE_AGENT` | `1` | `0`/`false` removes the agent chat bar + history panel and stops the agent health poll. | + +> **VITE_ vs WS_/BB_/ENABLE_:** `VITE_*` variables are baked at **build time** +> (changing them needs a rebuild). `WS_*` / `BB_*` / `ENABLE_*` are injected at +> **container start** into `config.js` and read as `window.*` globals (the +> toggles as `window.WS_ENABLE_*`) — changing them needs only a container restart +> + browser reload (see the caching note in §5). Each `ENABLE_*` defaults to on; +> set it to `0`/`false`/`no`/`off` to disable. Full reference: +> `weightslab/docs/configuration.rst` (“Feature toggles”). --- diff --git a/docs/configuration.rst b/docs/configuration.rst index de866df6..8eff60e1 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -589,3 +589,59 @@ shown below. Values are clamped to a hard ceiling of ``10000``. These caps only affect *rendering* — no sample data is dropped. They apply to detection bounding-box overlays; segmentation masks are unaffected. + + +Feature toggles +~~~~~~~~~~~~~~~ + +Whole areas of the Studio UI can be turned off for a given deployment — for +example a read-only demo that only shows plots, or a labelling-only view with no +agent. Each toggle **removes the area from the UI** (the elements are hidden) +**and stops its background work** (auto-refresh timers and gRPC polls are never +started), so a disabled area costs nothing at runtime. + +Like the bounding-box render limits, these are set on the Weights Studio frontend +container (for example in ``../weights_studio/docker/docker-compose.yml``) and +injected into the page at startup by the nginx entrypoint — changing them needs +no rebuild, just a container restart + browser reload. For a local ``vite`` dev +server, use the ``VITE_`` fallbacks shown below. Every toggle **defaults to +enabled**; set it to ``0`` / ``false`` / ``no`` / ``off`` (any case) to disable. + +.. list-table:: + :header-rows: 1 + :widths: 38 10 52 + + * - Variable + - Default + - Description + * - ``ENABLE_PLOTS`` + - ``1`` + - When disabled, removes the plots board and the left-panel Signals/metrics + card, and stops the plot-data auto-refresh (the ``GetLatestLoggerData`` + poll and the chart redraw loop). Dev-server fallback: + ``VITE_ENABLE_PLOTS``. + * - ``ENABLE_DATA_EXPLORATION`` + - ``1`` + - When disabled, removes the data sample grid and the metadata / details + left panel, and stops the data auto-refresh (the ``GetDataSamples`` / + ``GetMetaData`` timers and the slider-histogram poll). Dev-server + fallback: ``VITE_ENABLE_DATA_EXPLORATION``. + * - ``ENABLE_HYPERPARAMETERS_OPTIMIZATION`` + - ``1`` + - When disabled, removes the Hyperparameters section from the left panel, + makes the hyperparameter inputs read-only (no user edits are sent to the + backend), and stops the hyperparameter sync poll. Dev-server fallback: + ``VITE_ENABLE_HYPERPARAMETERS_OPTIMIZATION``. + * - ``ENABLE_AGENT`` + - ``1`` + - When disabled, removes the agent chat input bar (and its send button) and + the chat-history panel, and stops the agent health-check poll. Dev-server + fallback: ``VITE_ENABLE_AGENT``. + +.. note:: + + Each variable maps to a ``window.WS_ENABLE_*`` global injected into + ``config.js`` at container start (the same mechanism as the bounding-box + limits), with a build-time ``VITE_ENABLE_*`` fallback for the dev server. + Because ``config.js`` is served ``no-store``, a container restart + normal + reload is enough to pick up a change. diff --git a/weightslab/proto/experiment_service.proto b/weightslab/proto/experiment_service.proto index 7c4cffd5..fefb5e47 100644 --- a/weightslab/proto/experiment_service.proto +++ b/weightslab/proto/experiment_service.proto @@ -160,11 +160,6 @@ message LoadCheckpointOperation { int32 checkpoint_id = 1; } -message SaveCheckpointOperation { - bool save_architecture = 1; // force re-dump the model architecture even if a file already exists - bool save_optimizer = 2; // also persist optimizer state alongside the weights -} - message PlotNoteOperation { string metric_name = 1; string experiment_hash = 2; diff --git a/weightslab/proto/experiment_service_pb2.py b/weightslab/proto/experiment_service_pb2.py index ce6f6b14..1560bb5b 100644 --- a/weightslab/proto/experiment_service_pb2.py +++ b/weightslab/proto/experiment_service_pb2.py @@ -24,7 +24,6 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\x84\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"W\n\x12GetMetaDataRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12\x17\n\x0fmodal_sample_id\x18\x03 \x01(\t\"\x99\x01\n\x13GetMetaDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1a\n\x12\x61ll_metadata_names\x18\x03 \x03(\t\x12!\n\x0cgrid_records\x18\x04 \x03(\x0b\x32\x0b.DataRecord\x12!\n\x0cmodal_record\x18\x05 \x01(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xbe\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x38\n\x0bGetMetaData\x12\x13.GetMetaDataRequest\x1a\x14.GetMetaDataResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') _globals = globals() @@ -38,16 +37,6 @@ _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_options = b'8\001' _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._loaded_options = None _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_options = b'8\001' - _globals['_WEIGHTOPERATIONTYPE']._serialized_start=8986 - _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9086 - _globals['_ZEROFYPREDICATE']._serialized_start=9088 - _globals['_ZEROFYPREDICATE']._serialized_end=9199 - _globals['_AGENTINTENTTYPE']._serialized_start=9201 - _globals['_AGENTINTENTTYPE']._serialized_end=9278 - _globals['_SAMPLEEDITTYPE']._serialized_start=9280 - _globals['_SAMPLEEDITTYPE']._serialized_end=9353 - _globals['_AGENTPROVIDERTYPE']._serialized_start=9355 - _globals['_AGENTPROVIDERTYPE']._serialized_end=9399 _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9231 _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9331 _globals['_ZEROFYPREDICATE']._serialized_start=9333 @@ -90,10 +79,6 @@ _globals['_DENYSAMPLESOPERATION']._serialized_end=2317 _globals['_LOADCHECKPOINTOPERATION']._serialized_start=2319 _globals['_LOADCHECKPOINTOPERATION']._serialized_end=2367 - _globals['_SAVECHECKPOINTOPERATION']._serialized_start=2369 - _globals['_SAVECHECKPOINTOPERATION']._serialized_end=2445 - _globals['_PLOTNOTEOPERATION']._serialized_start=2447 - _globals['_PLOTNOTEOPERATION']._serialized_end=2545 _globals['_PLOTNOTEOPERATION']._serialized_start=2369 _globals['_PLOTNOTEOPERATION']._serialized_end=2467 _globals['_SAVECHECKPOINTOPERATION']._serialized_start=2469 @@ -150,50 +135,6 @@ _globals['_DATARECORD']._serialized_end=7277 _globals['_DATASAMPLESRESPONSE']._serialized_start=7279 _globals['_DATASAMPLESRESPONSE']._serialized_end=7369 - _globals['_POINTCLOUDREQUEST']._serialized_start=7371 - _globals['_POINTCLOUDREQUEST']._serialized_end=7445 - _globals['_POINTCLOUDCHUNK']._serialized_start=7448 - _globals['_POINTCLOUDCHUNK']._serialized_end=7639 - _globals['_DATAEDITSREQUEST']._serialized_start=7642 - _globals['_DATAEDITSREQUEST']._serialized_end=7862 - _globals['_DATAEDITSRESPONSE']._serialized_start=7864 - _globals['_DATAEDITSRESPONSE']._serialized_end=7917 - _globals['_DATASPLITSRESPONSE']._serialized_start=7919 - _globals['_DATASPLITSRESPONSE']._serialized_end=7977 - _globals['_AGENTHEALTHRESPONSE']._serialized_start=7979 - _globals['_AGENTHEALTHRESPONSE']._serialized_end=8036 - _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8038 - _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8132 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8134 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8193 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8195 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8235 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8237 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8297 - _globals['_GETAGENTMODELSREQUEST']._serialized_start=8299 - _globals['_GETAGENTMODELSREQUEST']._serialized_end=8322 - _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8324 - _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8398 - _globals['_RESETAGENTRESPONSE']._serialized_start=8400 - _globals['_RESETAGENTRESPONSE']._serialized_end=8454 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8456 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8507 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8509 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8570 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8572 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8654 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8656 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8717 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8719 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8747 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8750 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=8879 - _globals['_CANCELEVALUATIONREQUEST']._serialized_start=8881 - _globals['_CANCELEVALUATIONREQUEST']._serialized_end=8922 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=8924 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=8984 - _globals['_EXPERIMENTSERVICE']._serialized_start=9402 - _globals['_EXPERIMENTSERVICE']._serialized_end=10686 _globals['_GETMETADATAREQUEST']._serialized_start=7371 _globals['_GETMETADATAREQUEST']._serialized_end=7458 _globals['_GETMETADATARESPONSE']._serialized_start=7461 From 93eb100a2b2838aeba3bdb7811101e747526a803 Mon Sep 17 00:00:00 2001 From: AlexGrayBox Date: Fri, 19 Jun 2026 17:40:26 +0200 Subject: [PATCH 07/16] Server-side histogram binning + grid/sort perf fixes (#216) - GetHistogram RPC: bin one column server-side into <=512 typed bins (min/max/avg/count + per-(origin,discarded) sub-bars) instead of the client pulling every row and binning in the browser. Bit-identical to the client binning; ~116x smaller payload, ~50ms warm. Adds the proto messages + RPC, regenerated pb2/pb2_grpc, DataService.GetHistogram, and servicer delegation. - ApplyDataQuery: skip the forced full-view rebuild for SORT-ONLY operations (a sort just re-orders the existing snapshot). Global sort ~7.5s -> ~0.5s. - _slowUpdateInternals: run the view rebuild on a background thread for reader-triggered (non-force) refreshes, so grid/histogram reads never block on the multi-second collapse+combine. Reader p95 ~3000ms -> ~130ms. Filters/ resets still refresh inline (need fresh data). - ws-classification example: loosen eval (100->500) / checkpoint (25->250) cadence and use a bigger eval batch (16->128) so eval stops dominating wall-clock. Co-authored-by: Alexandru Rotaru Co-authored-by: Claude Opus 4.8 Co-authored-by: Guillaume --- .../PyTorch/ws-classification/config.yaml | 6 +- weightslab/proto/experiment_service.proto | 31 +++++ weightslab/proto/experiment_service_pb2.py | 121 +++++++++++++++++- .../proto/experiment_service_pb2_grpc.py | 22 +++- .../tests/gRPC/test_grpc_user_actions.py | 1 + weightslab/trainer/services/data_service.py | 115 ++++++++++++++++- weightslab/trainer/trainer_services.py | 3 + 7 files changed, 283 insertions(+), 16 deletions(-) diff --git a/weightslab/examples/PyTorch/ws-classification/config.yaml b/weightslab/examples/PyTorch/ws-classification/config.yaml index 655cff99..7efe94e0 100644 --- a/weightslab/examples/PyTorch/ws-classification/config.yaml +++ b/weightslab/examples/PyTorch/ws-classification/config.yaml @@ -8,8 +8,8 @@ training_steps_to_do: null # Set to null for infinite training until manually s compute_natural_sort: false # Experiment parameters -eval_full_to_train_steps_ratio: 100 -experiment_dump_to_train_steps_ratio: 25 +eval_full_to_train_steps_ratio: 500 # was 100 — full 10k eval was the dominant wall-clock cost +experiment_dump_to_train_steps_ratio: 250 # was 25 — frequent checkpoint dumps stalled training skip_checkpoint_load: false # If true restart the experiment from last state tqdm_display: true # Whether to use tqdm progress bars during training/evaluation is_training: false # Start training immediately or not @@ -35,5 +35,5 @@ data: batch_size: 16 test_loader: shuffle: false - batch_size: 16 + batch_size: 128 # was 16 — bigger eval batches => ~8x fewer eval steps per pass drop_last: false diff --git a/weightslab/proto/experiment_service.proto b/weightslab/proto/experiment_service.proto index fefb5e47..0165a56b 100644 --- a/weightslab/proto/experiment_service.proto +++ b/weightslab/proto/experiment_service.proto @@ -17,6 +17,8 @@ service ExperimentService { // Data Service (for weights_studio UI) rpc ApplyDataQuery (DataQueryRequest) returns (DataQueryResponse); rpc GetDataSamples (DataSamplesRequest) returns (DataSamplesResponse); + // Server-side histogram binning of one metadata/signal column. + rpc GetHistogram (HistogramRequest) returns (HistogramResponse); // Metadata-only retrieval (dataframe columns). Returns every metadata column // name for the WHOLE dataset, the current grid slice's per-sample metadata, and // the open modal sample's metadata. Separated from GetDataSamples, which now @@ -395,6 +397,35 @@ message DataSamplesResponse { repeated DataRecord data_records = 3; } +// --- Server-side histogram binning --- +// One stacked sub-segment of a bar: count of samples in this bin for a given +// (origin, discarded) combination (used to colour train/eval/discarded splits). +message HistogramSubBar { + string origin = 1; + bool discarded = 2; + int64 count = 3; +} + +// One histogram bar: aggregate stats over the samples whose row index falls in +// this bin's range, plus the per-(origin,discarded) breakdown. +message HistogramBin { + double min = 1; + double max = 2; + double avg = 3; + int64 count = 4; + repeated HistogramSubBar sub_bars = 5; +} + +message HistogramRequest { + string column = 1; // dataframe/signal column to histogram + int32 max_bins = 2; // 0 => server default (512) +} + +message HistogramResponse { + bool success = 1; + string message = 2; + int64 total_rows = 3; // rows in the view that were binned + repeated HistogramBin bins = 4; // --- Metadata retrieval (separated from GetDataSamples) --- message GetMetaDataRequest { int32 start_index = 1; // grid slice start (current view order) diff --git a/weightslab/proto/experiment_service_pb2.py b/weightslab/proto/experiment_service_pb2.py index 1560bb5b..d1f48342 100644 --- a/weightslab/proto/experiment_service_pb2.py +++ b/weightslab/proto/experiment_service_pb2.py @@ -2,7 +2,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: weightslab/proto/experiment_service.proto -# Protobuf Python Version: 6.31.1 +# Protobuf Python Version: 6.33.5 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -12,8 +12,8 @@ _runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.Domain.PUBLIC, 6, - 31, - 1, + 33, + 5, '', 'weightslab/proto/experiment_service.proto' ) @@ -24,6 +24,7 @@ +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"\xdc\x06\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"C\n\x0fHistogramSubBar\x12\x0e\n\x06origin\x18\x01 \x01(\t\x12\x11\n\tdiscarded\x18\x02 \x01(\x08\x12\r\n\x05\x63ount\x18\x03 \x01(\x03\"h\n\x0cHistogramBin\x12\x0b\n\x03min\x18\x01 \x01(\x01\x12\x0b\n\x03max\x18\x02 \x01(\x01\x12\x0b\n\x03\x61vg\x18\x03 \x01(\x01\x12\r\n\x05\x63ount\x18\x04 \x01(\x03\x12\"\n\x08sub_bars\x18\x05 \x03(\x0b\x32\x10.HistogramSubBar\"4\n\x10HistogramRequest\x12\x0e\n\x06\x63olumn\x18\x01 \x01(\t\x12\x10\n\x08max_bins\x18\x02 \x01(\x05\"f\n\x11HistogramResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\ntotal_rows\x18\x03 \x01(\x03\x12\x1b\n\x04\x62ins\x18\x04 \x03(\x0b\x32\r.HistogramBin\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xbb\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x35\n\x0cGetHistogram\x12\x11.HistogramRequest\x1a\x12.HistogramResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"W\n\x12GetMetaDataRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12\x17\n\x0fmodal_sample_id\x18\x03 \x01(\t\"\x99\x01\n\x13GetMetaDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1a\n\x12\x61ll_metadata_names\x18\x03 \x03(\t\x12!\n\x0cgrid_records\x18\x04 \x03(\x0b\x32\x0b.DataRecord\x12!\n\x0cmodal_record\x18\x05 \x01(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xbe\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x38\n\x0bGetMetaData\x12\x13.GetMetaDataRequest\x1a\x14.GetMetaDataResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') _globals = globals() @@ -37,6 +38,16 @@ _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_options = b'8\001' _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._loaded_options = None _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_options = b'8\001' + _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9145 + _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9245 + _globals['_ZEROFYPREDICATE']._serialized_start=9247 + _globals['_ZEROFYPREDICATE']._serialized_end=9358 + _globals['_AGENTINTENTTYPE']._serialized_start=9360 + _globals['_AGENTINTENTTYPE']._serialized_end=9437 + _globals['_SAMPLEEDITTYPE']._serialized_start=9439 + _globals['_SAMPLEEDITTYPE']._serialized_end=9512 + _globals['_AGENTPROVIDERTYPE']._serialized_start=9514 + _globals['_AGENTPROVIDERTYPE']._serialized_end=9558 _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9231 _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9331 _globals['_ZEROFYPREDICATE']._serialized_start=9333 @@ -81,6 +92,110 @@ _globals['_LOADCHECKPOINTOPERATION']._serialized_end=2367 _globals['_PLOTNOTEOPERATION']._serialized_start=2369 _globals['_PLOTNOTEOPERATION']._serialized_end=2467 + _globals['_TRAINERCOMMAND']._serialized_start=2470 + _globals['_TRAINERCOMMAND']._serialized_end=3330 + _globals['_HYPERPARAMETERDESC']._serialized_start=3333 + _globals['_HYPERPARAMETERDESC']._serialized_end=3490 + _globals['_NEURONSTATISTICS']._serialized_start=3493 + _globals['_NEURONSTATISTICS']._serialized_end=3863 + _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_start=3722 + _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_end=3771 + _globals['_LAYERREPRESENTATION']._serialized_start=3866 + _globals['_LAYERREPRESENTATION']._serialized_end=4234 + _globals['_ACTIVATIONREQUEST']._serialized_start=4236 + _globals['_ACTIVATIONREQUEST']._serialized_end=4308 + _globals['_ACTIVATIONMAP']._serialized_start=4310 + _globals['_ACTIVATIONMAP']._serialized_end=4382 + _globals['_ACTIVATIONRESPONSE']._serialized_start=4384 + _globals['_ACTIVATIONRESPONSE']._serialized_end=4484 + _globals['_TASKFIELD']._serialized_start=4487 + _globals['_TASKFIELD']._serialized_end=4634 + _globals['_RECORDMETADATA']._serialized_start=4637 + _globals['_RECORDMETADATA']._serialized_end=4969 + _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_start=4916 + _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_end=4969 + _globals['_SAMPLESTATISTICS']._serialized_start=4972 + _globals['_SAMPLESTATISTICS']._serialized_end=5119 + _globals['_COMMANDRESPONSE']._serialized_start=5122 + _globals['_COMMANDRESPONSE']._serialized_end=5352 + _globals['_SAMPLEREQUEST']._serialized_start=5354 + _globals['_SAMPLEREQUEST']._serialized_end=5439 + _globals['_SAMPLEREQUESTRESPONSE']._serialized_start=5442 + _globals['_SAMPLEREQUESTRESPONSE']._serialized_end=5743 + _globals['_BATCHSAMPLEREQUEST']._serialized_start=5746 + _globals['_BATCHSAMPLEREQUEST']._serialized_end=5892 + _globals['_BATCHSAMPLERESPONSE']._serialized_start=5894 + _globals['_BATCHSAMPLERESPONSE']._serialized_end=5956 + _globals['_WEIGHTSREQUEST']._serialized_start=5958 + _globals['_WEIGHTSREQUEST']._serialized_end=6004 + _globals['_WEIGHTSRESPONSE']._serialized_start=6007 + _globals['_WEIGHTSRESPONSE']._serialized_end=6292 + _globals['_DATAQUERYREQUEST']._serialized_start=6294 + _globals['_DATAQUERYREQUEST']._serialized_end=6376 + _globals['_CATEGORICALTAGDEF']._serialized_start=6378 + _globals['_CATEGORICALTAGDEF']._serialized_end=6431 + _globals['_DATAQUERYRESPONSE']._serialized_start=6434 + _globals['_DATAQUERYRESPONSE']._serialized_end=6731 + _globals['_DATASAMPLESREQUEST']._serialized_start=6734 + _globals['_DATASAMPLESREQUEST']._serialized_end=6928 + _globals['_DATASTAT']._serialized_start=6930 + _globals['_DATASTAT']._serialized_end=7039 + _globals['_DATARECORD']._serialized_start=7041 + _globals['_DATARECORD']._serialized_end=7103 + _globals['_DATASAMPLESRESPONSE']._serialized_start=7105 + _globals['_DATASAMPLESRESPONSE']._serialized_end=7195 + _globals['_HISTOGRAMSUBBAR']._serialized_start=7197 + _globals['_HISTOGRAMSUBBAR']._serialized_end=7264 + _globals['_HISTOGRAMBIN']._serialized_start=7266 + _globals['_HISTOGRAMBIN']._serialized_end=7370 + _globals['_HISTOGRAMREQUEST']._serialized_start=7372 + _globals['_HISTOGRAMREQUEST']._serialized_end=7424 + _globals['_HISTOGRAMRESPONSE']._serialized_start=7426 + _globals['_HISTOGRAMRESPONSE']._serialized_end=7528 + _globals['_POINTCLOUDREQUEST']._serialized_start=7530 + _globals['_POINTCLOUDREQUEST']._serialized_end=7604 + _globals['_POINTCLOUDCHUNK']._serialized_start=7607 + _globals['_POINTCLOUDCHUNK']._serialized_end=7798 + _globals['_DATAEDITSREQUEST']._serialized_start=7801 + _globals['_DATAEDITSREQUEST']._serialized_end=8021 + _globals['_DATAEDITSRESPONSE']._serialized_start=8023 + _globals['_DATAEDITSRESPONSE']._serialized_end=8076 + _globals['_DATASPLITSRESPONSE']._serialized_start=8078 + _globals['_DATASPLITSRESPONSE']._serialized_end=8136 + _globals['_AGENTHEALTHRESPONSE']._serialized_start=8138 + _globals['_AGENTHEALTHRESPONSE']._serialized_end=8195 + _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8197 + _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8291 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8293 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8352 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8354 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8394 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8396 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8456 + _globals['_GETAGENTMODELSREQUEST']._serialized_start=8458 + _globals['_GETAGENTMODELSREQUEST']._serialized_end=8481 + _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8483 + _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8557 + _globals['_RESETAGENTRESPONSE']._serialized_start=8559 + _globals['_RESETAGENTRESPONSE']._serialized_end=8613 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8615 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8666 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8668 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8729 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8731 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8813 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8815 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8876 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8878 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8906 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8909 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=9038 + _globals['_CANCELEVALUATIONREQUEST']._serialized_start=9040 + _globals['_CANCELEVALUATIONREQUEST']._serialized_end=9081 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=9083 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=9143 + _globals['_EXPERIMENTSERVICE']._serialized_start=9561 + _globals['_EXPERIMENTSERVICE']._serialized_end=10900 _globals['_SAVECHECKPOINTOPERATION']._serialized_start=2469 _globals['_SAVECHECKPOINTOPERATION']._serialized_end=2545 _globals['_TRAINERCOMMAND']._serialized_start=2548 diff --git a/weightslab/proto/experiment_service_pb2_grpc.py b/weightslab/proto/experiment_service_pb2_grpc.py index 51f19e87..ba9759d1 100644 --- a/weightslab/proto/experiment_service_pb2_grpc.py +++ b/weightslab/proto/experiment_service_pb2_grpc.py @@ -5,7 +5,7 @@ from weightslab.proto import experiment_service_pb2 as weightslab_dot_proto_dot_experiment__service__pb2 -GRPC_GENERATED_VERSION = '1.76.0' +GRPC_GENERATED_VERSION = '1.81.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -25,7 +25,7 @@ ) -class ExperimentServiceStub(object): +class ExperimentServiceStub: """Missing associated documentation comment in .proto file.""" def __init__(self, channel): @@ -74,6 +74,10 @@ def __init__(self, channel): request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesRequest.SerializeToString, response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesResponse.FromString, _registered_method=True) + self.GetHistogram = channel.unary_unary( + '/ExperimentService/GetHistogram', + request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString, + response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString, self.GetMetaData = channel.unary_unary( '/ExperimentService/GetMetaData', request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, @@ -141,7 +145,7 @@ def __init__(self, channel): _registered_method=True) -class ExperimentServiceServicer(object): +class ExperimentServiceServicer: """Missing associated documentation comment in .proto file.""" def GetLatestLoggerData(self, request, context): @@ -193,6 +197,8 @@ def GetDataSamples(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetHistogram(self, request, context): + """Server-side histogram binning of one metadata/signal column. def GetMetaData(self, request, context): """Metadata-only retrieval (dataframe columns). Returns every metadata column name for the WHOLE dataset, the current grid slice's per-sample metadata, and @@ -322,6 +328,10 @@ def add_ExperimentServiceServicer_to_server(servicer, server): request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesRequest.FromString, response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesResponse.SerializeToString, ), + 'GetHistogram': grpc.unary_unary_rpc_method_handler( + servicer.GetHistogram, + request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.FromString, + response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.SerializeToString, 'GetMetaData': grpc.unary_unary_rpc_method_handler( servicer.GetMetaData, request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.FromString, @@ -395,7 +405,7 @@ def add_ExperimentServiceServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. -class ExperimentService(object): +class ExperimentService: """Missing associated documentation comment in .proto file.""" @staticmethod @@ -615,6 +625,7 @@ def GetDataSamples(request, _registered_method=True) @staticmethod + def GetHistogram(request, def GetMetaData(request, target, options=(), @@ -628,6 +639,9 @@ def GetMetaData(request, return grpc.experimental.unary_unary( request, target, + '/ExperimentService/GetHistogram', + weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString, + weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString, '/ExperimentService/GetMetaData', weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.FromString, diff --git a/weightslab/tests/gRPC/test_grpc_user_actions.py b/weightslab/tests/gRPC/test_grpc_user_actions.py index fb405e83..cd44c8dd 100644 --- a/weightslab/tests/gRPC/test_grpc_user_actions.py +++ b/weightslab/tests/gRPC/test_grpc_user_actions.py @@ -280,6 +280,7 @@ def _make_real_data_service(self): # the first call proceeds (mirrors DataService.__init__). ds._update_done = threading.Event() ds._update_done.set() + ds._refresh_in_flight = threading.Lock() # mirrors __init__: bg view-refresh guard ds._df_manager = df_manager ds._all_datasets_df = df.copy() ds._compute_natural_sort = False diff --git a/weightslab/trainer/services/data_service.py b/weightslab/trainer/services/data_service.py index 13ba5a17..9d2904da 100755 --- a/weightslab/trainer/services/data_service.py +++ b/weightslab/trainer/services/data_service.py @@ -237,6 +237,10 @@ def __init__(self, ctx): self._update_lock = threading.Lock() self._update_done = threading.Event() self._update_done.set() # "done" initially so the very first call proceeds + # Guard so a non-force (reader-triggered) view refresh runs in the BACKGROUND + # at most once at a time — readers never pay the rebuild cost (they read the + # current snapshot; the bg thread swaps in fresh data when ready). + self._refresh_in_flight = threading.Lock() self._df_manager = get_dataframe() # init references to the context components @@ -2492,6 +2496,17 @@ def _watched_lock(self, lock_name: str = "_lock"): # ------------------------------------------------------------------ # Main update method # ------------------------------------------------------------------ + def _bg_view_refresh(self) -> None: + """Background view rebuild for reader-triggered (non-force) refreshes. Runs the + real rebuild+swap via force=True OFF the request path, then releases the guard so + a later stale read can trigger another. Never raises into a request.""" + try: + self._slowUpdateInternals(force=True) + except Exception: + logger.exception("[ViewRefresh] background view refresh failed") + finally: + self._refresh_in_flight.release() + def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> None: """Update the internal dataframe view with the latest data from the manager. @@ -2521,7 +2536,24 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> current_time - self._last_internals_update_time <= 10: return - # --- Try to become the single updater --- + # --- Non-force (reader-triggered) refresh: run it in the BACKGROUND --- + # The view is stale, but a reader (grid/histogram/periodic fetch) must NOT block + # on the multi-second collapse+combine rebuild. Kick a single background refresh + # (the WL-ViewRefresh thread calls force=True, which does the real rebuild+atomic + # swap below) and return immediately — the caller reads the current (last-completed) + # snapshot. If a refresh is already running, just return; the next fetch sees the swap. + if not force: + if self._refresh_in_flight.acquire(blocking=False): + try: + threading.Thread( + target=self._bg_view_refresh, name="WL-ViewRefresh", daemon=True + ).start() + except Exception: + self._refresh_in_flight.release() # never leak the guard + logger.exception("[ViewRefresh] failed to start background refresh") + return + + # --- Try to become the single updater (force path: rebuild inline) --- t_wait_start = time.time() acquired = self._update_lock.acquire(blocking=False) @@ -3361,11 +3393,17 @@ def ApplyDataQuery(self, request, context): # Apply operations with lock with self._watched_lock("_lock[ApplyDataQuery/ops]"): - # If this is JUST a view-sort for a specific page, DO NOT force an internal refresh - # as that wipes out the existing slice/sort state before we try to modify the next slice. - is_only_view_sort = len(operations) == 1 and operations[0].get("function") == "df.sort_view_slice" - if not is_only_view_sort: - self._slowUpdateInternals(force=True) # Refresh internals before applying Agent operations + # Skip the forced full-view rebuild for SORT-ONLY operations. Sorting just + # re-orders the existing snapshot, so a fresh collapse+combine (hundreds of + # ms on large views, and — being lock-held — contends with the training + # thread for multi-second stalls) is unnecessary. Filters/edits still refresh + # so they operate on the latest data. The view is frozen on direct queries + # anyway (_is_filtered=True), so it wasn't auto-refreshing mid-sort regardless. + _SORT_FUNCS = {"df.sort_values", "df.sort_index", "df.sort_view_slice"} + is_sort_only = bool(operations) and all( + op.get("function") in _SORT_FUNCS for op in operations) + if not is_sort_only: + self._slowUpdateInternals(force=True) # Refresh internals before applying non-sort operations # Work on a copy to allow concurrent readers to see a consistent state df = self._all_datasets_df # Remove copy because memory waste and slowdown @@ -3539,6 +3577,71 @@ def GetDataSamples(self, request, context): data_records=[] ) + def GetHistogram(self, request, context): + """Server-side histogram binning of one column (typed RPC). + + Bins the current all-data view by ROW ORDER into <= max_bins equal- + population bins; each bin carries {min,max,avg,count} over its finite + values plus a per-(origin,discarded) sub-bar breakdown. Returns typed + HistogramBin messages (no DataStat name-encoding). Empty bins are emitted + with count=0 and NaN stats so the client's positional bars stay aligned. + """ + try: + column = request.column or "" + max_bins = int(request.max_bins) if request.max_bins > 0 else 512 + df = getattr(self, "_all_datasets_df", None) + if df is None or df.empty: + return pb2.HistogramResponse( + success=False, message="empty dataframe view", total_rows=0, bins=[]) + df = safe_reset_index(df) + n = len(df) + if column not in df.columns: + return pb2.HistogramResponse( + success=False, message=f"column '{column}' not in view", + total_rows=n, bins=[]) + + bars = max(1, min(n, max_bins)) + vals = pd.to_numeric(df[column], errors="coerce").to_numpy() + origin = (df["origin"].astype(str).to_numpy() if "origin" in df.columns + else np.full(n, "")) + disc = (df["discarded"].astype(bool).to_numpy() if "discarded" in df.columns + else np.zeros(n, bool)) + edges = (np.arange(bars + 1) * n) // bars + bin_of_row = np.searchsorted(edges, np.arange(n), side="right") - 1 + fin = np.isfinite(vals) + gf = pd.DataFrame({"b": bin_of_row[fin], "v": vals[fin], + "o": origin[fin], "d": disc[fin]}) + stats = gf.groupby("b")["v"].agg(["min", "max", "mean", "count"]) + sub_by_bin = {} + for (b, d, o), c in gf.groupby(["b", "d", "o"]).size().items(): + sub_by_bin.setdefault(int(b), []).append( + pb2.HistogramSubBar(origin=str(o), discarded=bool(d), count=int(c))) + have = stats.index.to_numpy() + mn, mx, av, cn = (stats["min"].to_numpy(), stats["max"].to_numpy(), + stats["mean"].to_numpy(), stats["count"].to_numpy()) + pos = {int(b): i for i, b in enumerate(have)} + _nan = float("nan") + bins = [] + for b in range(bars): + i = pos.get(b) + if i is None: + bins.append(pb2.HistogramBin( + min=_nan, max=_nan, avg=_nan, count=0, sub_bars=[])) + else: + bins.append(pb2.HistogramBin( + min=float(mn[i]), max=float(mx[i]), avg=float(av[i]), + count=int(cn[i]), sub_bars=sub_by_bin.get(b, []))) + logger.info("[HistBin] column=%s rows=%d bins=%d", column, n, len(bins)) + return pb2.HistogramResponse( + success=True, + message=f"histogram {column}: {len(bins)} bins from {n} rows", + total_rows=n, bins=bins) + except Exception as e: + logger.error("Error in GetHistogram: %s", str(e), exc_info=True) + return pb2.HistogramResponse( + success=False, message=f"histogram failed: {str(e)}", + total_rows=0, bins=[]) + # Streamed chunk size for GetPointCloud (raw float32 bytes per message). _POINT_CLOUD_CHUNK_BYTES = 1 << 20 # 1 MiB diff --git a/weightslab/trainer/trainer_services.py b/weightslab/trainer/trainer_services.py index de0f1ece..76e76f43 100644 --- a/weightslab/trainer/trainer_services.py +++ b/weightslab/trainer/trainer_services.py @@ -334,6 +334,9 @@ def GetDataSamples(self, request, context): logger.debug(f"\nExperimentServiceServicer.GetDataSamples({request})") return self._exp_service.data_service.GetDataSamples(request, context) + def GetHistogram(self, request, context): + logger.debug(f"\nExperimentServiceServicer.GetHistogram({request})") + return self._exp_service.data_service.GetHistogram(request, context) def GetMetaData(self, request, context): logger.debug(f"\nExperimentServiceServicer.GetMetaData({request})") return self._exp_service.data_service.GetMetaData(request, context) From 50df577fe338a9c947114f40cff510df10f3e025 Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 10:25:19 +0200 Subject: [PATCH 08/16] Add explore actions for weightslab, i.e., run weightslab in limited mode in root log dir --- weightslab/__init__.py | 3 +- weightslab/backend/explore_mode.py | 43 ++++ weightslab/src.py | 110 +++++++++ .../tests/integrations/test_explore_mode.py | 210 ++++++++++++++++++ .../services/test_trainer_services_unit.py | 103 +++++++++ .../trainer/services/experiment_service.py | 19 ++ weightslab/trainer/services/model_service.py | 5 + 7 files changed, 492 insertions(+), 1 deletion(-) create mode 100644 weightslab/backend/explore_mode.py create mode 100644 weightslab/tests/integrations/test_explore_mode.py diff --git a/weightslab/__init__.py b/weightslab/__init__.py index af668cc1..ecb5be12 100644 --- a/weightslab/__init__.py +++ b/weightslab/__init__.py @@ -11,7 +11,7 @@ import logging import threading -from .src import watch_or_edit, start_training, serve, keep_serving, save_signals, save_instance_signals, save_group_signals, tag_samples, register_categorical_tag, set_categorical_tag, discard_samples, get_samples_by_tag, get_discarded_samples, signal, eval_fn, compute_signals, SignalContext, clear_all, run_pending_evaluation, trigger_pending_evaluation_async, query_signal_history, query_sample_history, query_instance_history, write_history, write_dataframe, get_current_experiment_hash, pointcloud_thumbnail, pointcloud_boxes +from .src import watch_or_edit, start_training, serve, keep_serving, load_experiment_for_explore, save_signals, save_instance_signals, save_group_signals, tag_samples, register_categorical_tag, set_categorical_tag, discard_samples, get_samples_by_tag, get_discarded_samples, signal, eval_fn, compute_signals, SignalContext, clear_all, run_pending_evaluation, trigger_pending_evaluation_async, query_signal_history, query_sample_history, query_instance_history, write_history, write_dataframe, get_current_experiment_hash, pointcloud_thumbnail, pointcloud_boxes from .backend.ledgers import GLOBAL_LEDGER as ledger from .art import _BANNER from .utils.logs import setup_logging, set_log_directory, is_main_process @@ -76,6 +76,7 @@ "watch_or_edit", "serve", "keep_serving", + "load_experiment_for_explore", "save_signals", "save_instance_signals", "save_group_signals", diff --git a/weightslab/backend/explore_mode.py b/weightslab/backend/explore_mode.py new file mode 100644 index 00000000..4effe651 --- /dev/null +++ b/weightslab/backend/explore_mode.py @@ -0,0 +1,43 @@ +"""Process-wide read-only "explore" mode. + +When the backend is launched to browse a finished experiment loaded from disk +(``weightslab --logdir ``), it runs in *explore mode*: there is no +training loop, and the experiment is reconstructed from the checkpoints/logs on +disk so a user can inspect it in the UI while training continues elsewhere +(e.g. on a cluster). + +In this mode the backend refuses the actions that would mutate the model or the +training run — starting/resuming training, changing hyperparameters, and +loading/restoring/saving weights or checkpoints. Local **data management** +(tagging, discarding, queries, plot notes) and all **reads** stay available, +since the whole point is to manage and explore the data locally. + +This is a simple process-wide flag: a given backend process is either a live +training server or a read-only explorer for its whole lifetime. +""" + +import logging + +logger = logging.getLogger(__name__) + +_EXPLORE_MODE = False + +# Returned by guarded RPC handlers when a forbidden (mutating) action is attempted. +EXPLORE_BLOCKED_MESSAGE = ( + "This experiment is open in read-only explore mode (loaded from --logdir). " + "Training, hyperparameter changes, and weight/checkpoint loading are disabled." +) + + +def set_explore_mode(enabled: bool) -> None: + """Enable/disable the process-wide read-only explore mode.""" + global _EXPLORE_MODE + _EXPLORE_MODE = bool(enabled) + logger.info( + "Explore (read-only) mode %s", "ENABLED" if _EXPLORE_MODE else "disabled" + ) + + +def is_explore_mode() -> bool: + """True when the backend is serving a read-only experiment from disk.""" + return _EXPLORE_MODE diff --git a/weightslab/src.py b/weightslab/src.py index dde692d1..44470efc 100644 --- a/weightslab/src.py +++ b/weightslab/src.py @@ -1113,6 +1113,116 @@ def keep_serving(timeout: int = None, release_gpu: bool = True) -> None: logger.info("Shutting down WeightsLab services.") +def _rehydrate_dataframe_from_disk(root_log_dir) -> list: + """Best-effort: rebuild the data grid from the persisted H5 store so samples + are browsable in explore mode without the original ``Dataset`` object. + + Returns the list of data origins (splits) that were rehydrated. Any failure + is non-fatal — logs/plots still work without the data grid. + """ + from pathlib import Path + import pandas as _pd + from weightslab.backend import ledgers as _ledgers + from weightslab.data.h5_dataframe_store import H5DataFrameStore + from weightslab.data.dataframe_manager import LedgeredDataFrameManager + + data_h5 = Path(str(root_log_dir)) / "checkpoints" / "data" / "data.h5" + if not data_h5.exists(): + logger.info("Explore: no persisted data store at %s; data grid unavailable.", data_h5) + return [] + try: + store = H5DataFrameStore(str(data_h5)) + # Enumerate origins (HDF groups under the store's key prefix). + prefix = f"/{getattr(store, '_key_prefix', 'stats')}_" + origins: list = [] + try: + with _pd.HDFStore(str(data_h5), mode="r") as h5: + origins = sorted({k[len(prefix):] for k in h5.keys() if k.startswith(prefix)}) + except Exception: + logger.debug("Explore: could not enumerate data origins", exc_info=True) + if not origins: + return [] + + # No flush threads in a read-only explorer; data writes (tags/discard) are + # still applied in-memory and persisted on demand. + dfm = LedgeredDataFrameManager(enable_flushing_threads=False, enable_h5_persistence=True) + dfm.set_store(store) + loaded = [] + for origin in origins: + try: + dfm.register_split(origin, _pd.DataFrame(), store, autoload_arrays=False) + loaded.append(origin) + except Exception: + logger.warning("Explore: failed to rehydrate data split '%s'", origin, exc_info=True) + if loaded: + _ledgers.register_dataframe(dfm) + return loaded + except Exception: + logger.warning("Explore: data rehydration from disk failed; logs/plots still available.", exc_info=True) + return [] + + +def load_experiment_for_explore(root_log_dir, exp_hash: str = None) -> dict: + """Load a finished experiment from ``root_log_dir`` into a fresh, read-only ledger. + + Reconstructs hyperparameters, logger history, the checkpoint manager (and, + best-effort, the model and the data grid) purely from disk, then flips the + process into read-only **explore mode** (see + :mod:`weightslab.backend.explore_mode`). No training script, dataset, GPU, or + network is required — intended for browsing a run that is finished or still + training elsewhere (e.g. on a cluster). + + After this returns, start the gRPC server with :func:`serve` (``serving_grpc=True``) + and the UI can read everything while training/HP/weight mutations are refused. + + Args: + root_log_dir: An experiment ``root_log_dir`` produced by a previous run. + exp_hash: Optional specific experiment hash to open (defaults to the latest). + + Returns: + A dict summary: ``{root_log_dir, experiment_hash, has_logger, origins}``. + """ + from pathlib import Path + from weightslab.backend import ledgers as _ledgers + from weightslab.backend.explore_mode import set_explore_mode + from weightslab.components.checkpoint_manager import CheckpointManager + + root = Path(str(root_log_dir)).absolute() + if not root.exists(): + raise FileNotFoundError(f"root_log_dir does not exist: {root}") + + # A read-only explorer must not inherit any live training objects. + _ledgers.clear_all() + + # CheckpointManager.__init__ loads the manifest + logger snapshots (registering + # a logger with the saved history) and bootstraps the latest experiment state + # (config/HP, model best-effort, data snapshot) from disk. + manager = CheckpointManager(str(root)) + _ledgers.register_checkpoint_manager(manager) + + if exp_hash: + try: + manager.load_state(exp_hash) + except Exception: + logger.warning( + "Explore: could not load requested hash %s; using bootstrapped state.", + exp_hash, exc_info=True, + ) + + origins = _rehydrate_dataframe_from_disk(root) + + set_explore_mode(True) + + summary = { + "root_log_dir": str(root), + "experiment_hash": manager.get_current_experiment_hash(), + "has_logger": _ledgers.get_logger() is not None, + "origins": origins, + } + logger.info("Explore mode ready: %s", summary) + return summary + + def signal(name: str, subscribe_to: str = None, compute_every_n_steps: int = 1, **kwargs): """ Decorator that registers a custom signal function. diff --git a/weightslab/tests/integrations/test_explore_mode.py b/weightslab/tests/integrations/test_explore_mode.py new file mode 100644 index 00000000..36daf3ab --- /dev/null +++ b/weightslab/tests/integrations/test_explore_mode.py @@ -0,0 +1,210 @@ +"""Integration test for read-only "explore" mode (``weightslab --logdir``). + +Simulates the real workflow: train a small experiment with weightslab (writing +checkpoints + logger snapshots + the H5 data store to a ``root_log_dir``), then +"kill" the training (clear the ledger, like a fresh process), and finally load +the experiment purely from disk via ``wl.load_experiment_for_explore`` and serve +it read-only. + +Asserts that, after loading: +- the logged history is readable through the real gRPC servicer (the "access the + logs through the UI" requirement); +- the data splits are browsable; +- every mutating action a user must NOT be able to do — start training, change + hyperparameters, load/restore/save weights — is refused, while reads and data + management still work. +""" + +import os +import tempfile +import shutil +import unittest +import warnings + +warnings.filterwarnings("ignore") + +import torch as th +import torch.nn as nn + +import weightslab as wl +import weightslab.proto.experiment_service_pb2 as pb2 +from weightslab.backend import ledgers +from weightslab.backend import explore_mode +from weightslab.components.global_monitoring import ( + guard_training_context, + pause_controller, + start_hp_sync_thread_event, +) +from weightslab.trainer.experiment_context import ExperimentContext +from weightslab.trainer.services.experiment_service import ExperimentService +from weightslab.utils.tools import seed_everything + + +start_hp_sync_thread_event() + + +class _TinyDataset: + """Minimal (data, uid, target) dataset — no downloads, fully synthetic.""" + + def __init__(self, n=8, dim=4, num_classes=3): + g = th.Generator().manual_seed(0) + self._x = th.randn(n, dim, generator=g) + self._y = th.randint(0, num_classes, (n,), generator=g) + + def __len__(self): + return len(self._x) + + def __getitem__(self, idx): + return self._x[idx], th.tensor(idx, dtype=th.long), self._y[idx] + + +class _TinyNet(nn.Module): + def __init__(self, dim=4, num_classes=3): + super().__init__() + self.input_shape = (1, dim) + self.fc = nn.Linear(dim, num_classes) + + def forward(self, x): + return self.fc(x) + + +class ExploreModeTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + seed_everything() + cls.temp_dir = tempfile.mkdtemp(prefix="wl_explore_test_") + cls.root_log_dir = os.path.join(cls.temp_dir, "experiments") + + cls.config = { + "experiment_name": "explore_test", + "device": "cpu", + "root_log_dir": cls.root_log_dir, + "experiment_dump_to_train_steps_ratio": 2, + "data": {"train_loader": {"batch_size": 2, "shuffle": False}}, + "checkpoint_manager": {"dump_model_architecture": True}, + "ledger_enable_flushing_threads": True, + "ledger_enable_h5_persistence": True, + "ledger_flush_max_rows": 2, + "ledger_flush_interval": 1.0, + "serving_grpc": False, + "serving_cli": False, + "optimizer": {"lr": 0.01}, + } + + # ---- Train a small experiment (produces on-disk artifacts) ----------- + pause_controller.pause() + cls.dataset = _TinyDataset() + cls.logger = __import__( + "weightslab.backend.logger", fromlist=["LoggerQueue"] + ).LoggerQueue(register=True) + + cls.config = wl.watch_or_edit( + cls.config, flag="hyperparameters", defaults=cls.config, poll_interval=1.0 + ) + model = wl.watch_or_edit( + _TinyNet(), flag="model", device="cpu", + skip_previous_auto_load=True, compute_dependencies=False, + ) + wl.watch_or_edit( + cls.dataset, flag="data", compute_hash=False, is_training=True, + batch_size=2, shuffle=False, + ) + wl.watch_or_edit( + th.optim.Adam(model.parameters(), lr=0.01), flag="optimizer" + ) + wl.watch_or_edit( + nn.CrossEntropyLoss(reduction="none"), flag="signal", + log=True, name="train/loss", + ) + + cls.chkpt = ledgers.get_checkpoint_manager() + cls.chkpt.update_experiment_hash(first_time=True) + + loader = ledgers.get_dataloader() + optimizer = ledgers.get_optimizer() + criterion = ledgers.get_signal(name="train/loss") + + pause_controller.resume() + for _ in range(6): + with guard_training_context: + inputs, ids, labels = next(loader) + optimizer.zero_grad() + preds_raw = model(inputs) + preds = preds_raw.argmax(dim=1, keepdim=True) + loss = criterion(preds_raw, labels, batch_ids=ids, preds=preds) + loss.mean().backward() + optimizer.step() + pause_controller.pause() + + # Ensure everything is flushed to disk (checkpoints, logger, data). + cls.chkpt.save_model_checkpoint() + cls.chkpt.save_logger_snapshot() + cls.chkpt.save_pending_changes(force=True) + cls.trained_hash = cls.chkpt.get_current_experiment_hash() + + @classmethod + def tearDownClass(cls): + explore_mode.set_explore_mode(False) + ledgers.clear_all() + shutil.rmtree(cls.temp_dir, ignore_errors=True) + + def setUp(self): + # Each test starts from a freshly loaded, read-only explorer (simulates a + # new `weightslab --logdir` process attaching to the killed run). + explore_mode.set_explore_mode(False) + self.summary = wl.load_experiment_for_explore(self.root_log_dir) + self.ctx = ExperimentContext() + self.service = ExperimentService(self.ctx) + + def test_explore_mode_is_enabled_and_experiment_loaded(self): + self.assertTrue(explore_mode.is_explore_mode()) + self.assertTrue(self.summary["has_logger"]) + self.assertIsNotNone(self.summary["experiment_hash"]) + + def test_logger_history_is_readable_through_servicer(self): + resp = self.service.GetLatestLoggerData( + pb2.GetLatestLoggerDataRequest( + request_full_history=True, max_points=1000, break_by_slices=False + ), + None, + ) + # The training above logged "train/loss" each step; it must survive the + # save→fresh-process→load round trip and be visible in the UI. + self.assertGreater(len(resp.points), 0) + + def test_data_is_rehydrated_from_disk(self): + # The persisted H5 data store is rebuilt into the ledger so the sample + # grid is browsable without the original Dataset object. (The split name + # is auto-derived from the dataset, so we don't assert a specific name.) + self.assertTrue(self.summary["origins"], "expected at least one data split") + dfm = ledgers.get_dataframe() + self.assertIsNotNone(dfm) + self.assertEqual(len(dfm.get_df_view()), len(self.dataset)) + + def test_blocks_training_start(self): + resp = self.service.ExperimentCommand( + pb2.TrainerCommand( + hyper_parameter_change=pb2.HyperParameterCommand( + hyper_parameters=pb2.HyperParameters(is_training=True) + ) + ), + None, + ) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_weight_restore(self): + resp = self.service.RestoreCheckpoint( + pb2.RestoreCheckpointRequest(experiment_hash=self.trained_hash), None + ) + self.assertFalse(resp.success) + + def test_reads_still_work(self): + resp = self.service.ExperimentCommand( + pb2.TrainerCommand(get_hyper_parameters=True), None + ) + self.assertTrue(resp.success) + + +if __name__ == "__main__": + unittest.main() diff --git a/weightslab/tests/trainer/services/test_trainer_services_unit.py b/weightslab/tests/trainer/services/test_trainer_services_unit.py index 3333489e..9d9676d8 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_unit.py +++ b/weightslab/tests/trainer/services/test_trainer_services_unit.py @@ -511,5 +511,108 @@ def test_manual_save_data_state_force_enables_h5_and_flushes(self): self.assertTrue(checkpoint_manager.save_data_snapshot.called) +class TestExploreModeGuards(unittest.TestCase): + """Read-only explore mode: training/HP/weight mutations are refused; data + management and reads stay allowed.""" + + def setUp(self): + from weightslab.backend import explore_mode + self._explore_mode = explore_mode + explore_mode.set_explore_mode(True) + self.addCleanup(lambda: explore_mode.set_explore_mode(False)) + + def _experiment_service(self, components=None): + ctx = _DummyCtx(components=components or {}) + with patch("weightslab.trainer.services.experiment_service.DataService"): + return ExperimentService(ctx) + + def test_explore_mode_off_by_default(self): + # Sanity: the cleanup of other tests must restore the disabled state. + self._explore_mode.set_explore_mode(False) + self.assertFalse(self._explore_mode.is_explore_mode()) + self._explore_mode.set_explore_mode(True) + self.assertTrue(self._explore_mode.is_explore_mode()) + + def test_blocks_hyperparameter_change(self): + service = self._experiment_service() + request = pb2.TrainerCommand( + hyper_parameter_change=pb2.HyperParameterCommand( + hyper_parameters=pb2.HyperParameters(is_training=True) + ) + ) + resp = service.ExperimentCommand(request, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_save_checkpoint(self): + service = self._experiment_service() + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation(save_optimizer=True) + ) + resp = service.ExperimentCommand(request, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_load_checkpoint(self): + service = self._experiment_service() + request = pb2.TrainerCommand( + load_checkpoint_operation=pb2.LoadCheckpointOperation(checkpoint_id=3) + ) + resp = service.ExperimentCommand(request, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_restore_checkpoint(self): + service = self._experiment_service() + resp = service.RestoreCheckpoint( + pb2.RestoreCheckpointRequest(experiment_hash="abc"), None + ) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_trigger_evaluation(self): + service = self._experiment_service() + resp = service.TriggerEvaluation( + pb2.TriggerEvaluationRequest(split_name="test"), None + ) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_manipulate_weights(self): + ctx = _DummyCtx(components={"model": MagicMock()}) + service = ModelService(ctx) + req = pb2.WeightsOperationRequest() + req.weight_operation.op_type = pb2.WeightOperationType.FREEZE + resp = service.ManipulateWeights(req, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_allows_plot_note_data_annotation(self): + # Data/annotation writes (plot notes) are NOT blocked in explore mode. + signal_logger = MagicMock() + signal_logger.set_point_note.return_value = True + service = self._experiment_service(components={"signal_logger": signal_logger}) + + request = pb2.TrainerCommand( + plot_note_operation=pb2.PlotNoteOperation( + metric_name="train/loss", + experiment_hash="abcdef", + model_age=5, + note="checkpoint of interest", + ) + ) + resp = service.ExperimentCommand(request, None) + self.assertTrue(resp.success) + signal_logger.set_point_note.assert_called_once() + + def test_allows_reads(self): + # Read requests are unaffected by explore mode. + service = self._experiment_service() + resp = service.ExperimentCommand( + pb2.TrainerCommand(get_hyper_parameters=True), None + ) + self.assertTrue(resp.success) + + if __name__ == "__main__": unittest.main() diff --git a/weightslab/trainer/services/experiment_service.py b/weightslab/trainer/services/experiment_service.py index 4d32323b..8f838469 100644 --- a/weightslab/trainer/services/experiment_service.py +++ b/weightslab/trainer/services/experiment_service.py @@ -14,6 +14,7 @@ from weightslab.backend.ledgers import set_hyperparam, list_hyperparams, resolve_hp_name, get_hyperparams from weightslab.backend import ledgers from weightslab.backend.audit_logger import AuditLogger +from weightslab.backend.explore_mode import is_explore_mode, EXPLORE_BLOCKED_MESSAGE from weightslab.trainer.services.model_service import ModelService from weightslab.trainer.services.data_service import DataService from weightslab.trainer.services.agent_service import AgentService @@ -364,6 +365,8 @@ def RestoreCheckpoint(self, request, context): - Calls checkpoint manager to load the state - Returns success flag and message """ + if is_explore_mode(): + return pb2.RestoreCheckpointResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) try: raw_experiment_hash = request.experiment_hash experiment_hash = raw_experiment_hash @@ -477,6 +480,10 @@ def TriggerEvaluation(self, request, context): This stores the evaluation request in the global eval_controller. The actual pass runs in the training thread via ``run_pending_evaluation()``. """ + # No training thread runs in explore mode, so evaluation can't execute. + if is_explore_mode(): + return pb2.TriggerEvaluationResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) + split_name = request.split_name or "" tags = list(request.tags) if request.tags else [] use_full_set = bool(request.use_full_set) @@ -730,6 +737,18 @@ def ExperimentCommand(self, request, context): self._ctx.ensure_components() components = self._ctx.components + # Read-only explore mode: refuse the mutating commands that would change + # the model or the training run. Data management (deny/tag operations, + # plot notes) and all read requests below stay allowed. + if is_explore_mode(): + for forbidden in ( + "hyper_parameter_change", + "save_checkpoint_operation", + "load_checkpoint_operation", + ): + if request.HasField(forbidden): + return pb2.CommandResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) + # Write requests if request.HasField("save_checkpoint_operation"): return self._handle_save_checkpoint( diff --git a/weightslab/trainer/services/model_service.py b/weightslab/trainer/services/model_service.py index bca63e66..5b9aac9c 100644 --- a/weightslab/trainer/services/model_service.py +++ b/weightslab/trainer/services/model_service.py @@ -10,6 +10,7 @@ from weightslab.trainer.trainer_tools import process_sample, _get_input_tensor_for_sample from weightslab.modules.neuron_ops import ArchitectureNeuronsOpType from weightslab.components.global_monitoring import weightslab_rlock, try_acquire_rlock, _GRPC_LOCK_TIMEOUT_S +from weightslab.backend.explore_mode import is_explore_mode, EXPLORE_BLOCKED_MESSAGE logger = logging.getLogger(__name__) @@ -311,6 +312,10 @@ def hook(mod, inp, out): # Weight manipulation (architecture operations) # ------------------------------------------------------------------------- def ManipulateWeights(self, request, context): + # Read-only explore mode: architecture/weight edits are disabled. + if is_explore_mode(): + return pb2.WeightsOperationResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) + self._ctx.ensure_components() components = self._ctx.components From 082b201d1db88c14b81c7ba3148d7e0fd7f57650 Mon Sep 17 00:00:00 2001 From: AlexGrayBox Date: Mon, 22 Jun 2026 11:51:10 +0200 Subject: [PATCH 09/16] Server-side histogram binning + grid/sort perf fixes (#221) * Server-side histogram binning + grid/sort perf fixes - GetHistogram RPC: bin one column server-side into <=512 typed bins (min/max/avg/count + per-(origin,discarded) sub-bars) instead of the client pulling every row and binning in the browser. Bit-identical to the client binning; ~116x smaller payload, ~50ms warm. Adds the proto messages + RPC, regenerated pb2/pb2_grpc, DataService.GetHistogram, and servicer delegation. - ApplyDataQuery: skip the forced full-view rebuild for SORT-ONLY operations (a sort just re-orders the existing snapshot). Global sort ~7.5s -> ~0.5s. - _slowUpdateInternals: run the view rebuild on a background thread for reader-triggered (non-force) refreshes, so grid/histogram reads never block on the multi-second collapse+combine. Reader p95 ~3000ms -> ~130ms. Filters/ resets still refresh inline (need fresh data). - ws-classification example: loosen eval (100->500) / checkpoint (25->250) cadence and use a bigger eval batch (16->128) so eval stops dominating wall-clock. Co-Authored-By: Claude Opus 4.8 * Fix proto files and is/is not ValueProxy from ledgers to ==/!= --------- Co-authored-by: Alexandru Rotaru Co-authored-by: Claude Opus 4.8 Co-authored-by: Guillaume --- .../PyTorch/ws-classification/main.py | 8 +- .../PyTorch/ws-detection/utils/data.py | 2 +- .../PyTorch/ws-segmentation/utils/data.py | 2 +- .../examples/Ultralytics/ws-detection/main.py | 2 +- .../ws-2d-lidar-detection/utils/data.py | 2 +- .../ws-3d-lidar-detection/utils/data.py | 2 +- weightslab/proto/experiment_service.proto | 2 + weightslab/proto/experiment_service_pb2.py | 247 +++++------------- .../proto/experiment_service_pb2_grpc.py | 40 ++- 9 files changed, 116 insertions(+), 191 deletions(-) diff --git a/weightslab/examples/PyTorch/ws-classification/main.py b/weightslab/examples/PyTorch/ws-classification/main.py index 8f2ec0d0..8dc53e78 100644 --- a/weightslab/examples/PyTorch/ws-classification/main.py +++ b/weightslab/examples/PyTorch/ws-classification/main.py @@ -89,7 +89,7 @@ def _build_filepath_mapping(self): # For each index, construct a meaningful filepath # MNIST doesn't have original individual files, so we create virtual paths for idx in range(len(self.mnist)): - if self.max_samples is not None and idx >= self.max_samples: + if self.max_samples != None and idx >= self.max_samples: break label = self.mnist.targets[idx].item() if hasattr(self.mnist.targets[idx], 'item') else self.mnist.targets[idx] split = 'train' if self.train else 'test' @@ -105,7 +105,7 @@ def _build_filepath_mapping(self): self.filepaths[idx] = virtual_path def __len__(self): - if self.max_samples is not None: + if self.max_samples != None: return min(len(self.mnist), self.max_samples) return len(self.mnist) @@ -371,7 +371,7 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): # Setup clean progress bar with custom format if tqdm_display: train_range = tqdm.tqdm( - range(training_steps_to_do) if training_steps_to_do is not None else itertools.count(), + range(training_steps_to_do) if training_steps_to_do != None else itertools.count(), desc="Training", bar_format="{desc}: {n}/{total} [{elapsed}<{remaining}, {rate_fmt}] {bar} | {postfix}", ncols=140, @@ -379,7 +379,7 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): leave=True ) else: - train_range = range(training_steps_to_do) if training_steps_to_do is not None else itertools.count() + train_range = range(training_steps_to_do) if training_steps_to_do != None else itertools.count() # ================ # Training Loop diff --git a/weightslab/examples/PyTorch/ws-detection/utils/data.py b/weightslab/examples/PyTorch/ws-detection/utils/data.py index 93e60dd1..c4823197 100644 --- a/weightslab/examples/PyTorch/ws-detection/utils/data.py +++ b/weightslab/examples/PyTorch/ws-detection/utils/data.py @@ -131,7 +131,7 @@ def __init__( val_set = set(all_imgs[::k]) selected = [f for f in all_imgs if f not in val_set] - selected = selected[:max_samples] if max_samples is not None else selected + selected = selected[:max_samples] if max_samples != None else selected self.images = [] self.masks = [] diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py index 864c6853..28ba4dde 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py @@ -51,7 +51,7 @@ def __init__( for f in os.listdir(img_dir) if f.lower().endswith((".jpg", ".jpeg", ".png")) ] - image_files = sorted(set(image_files))[:max_samples] if max_samples is not None else sorted(set(image_files)) # Optionally limit number of samples for faster testing + image_files = sorted(set(image_files))[:max_samples] if max_samples != None else sorted(set(image_files)) # Optionally limit number of samples for faster testing self.images = [] self.masks = [] diff --git a/weightslab/examples/Ultralytics/ws-detection/main.py b/weightslab/examples/Ultralytics/ws-detection/main.py index d2ba1a61..ea1876e3 100644 --- a/weightslab/examples/Ultralytics/ws-detection/main.py +++ b/weightslab/examples/Ultralytics/ws-detection/main.py @@ -64,7 +64,7 @@ def main(): trainer=WLAwareTrainer, data=data_root, imgsz=image_size, - epochs=1000 if max_steps is None else max(1, int(max_steps)), + epochs=1000 if max_steps == None else max(1, int(max_steps)), device=device, project=project, name=name, # → UL save_dir → WL logger log_dir/name resume=False, diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py index 10e19736..fd544bf3 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py @@ -116,7 +116,7 @@ def __init__( else: val_set = set(frames[::k]) selected = [f for f in frames if f not in val_set] - self.frames = selected[:max_samples] if max_samples is not None else selected + self.frames = selected[:max_samples] if max_samples != None else selected def __len__(self): return len(self.frames) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py index 77df470b..45a15a0a 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py @@ -446,7 +446,7 @@ def __init__( else: val_set = set(frames[::k]) selected = [f for f in frames if f not in val_set] - self.frames = selected[:max_samples] if max_samples is not None else selected + self.frames = selected[:max_samples] if max_samples != None else selected if len(self.frames) == 0: raise RuntimeError(f"No LiDAR frames found (source={source}, root={root})") diff --git a/weightslab/proto/experiment_service.proto b/weightslab/proto/experiment_service.proto index 0165a56b..770965fd 100644 --- a/weightslab/proto/experiment_service.proto +++ b/weightslab/proto/experiment_service.proto @@ -426,6 +426,8 @@ message HistogramResponse { string message = 2; int64 total_rows = 3; // rows in the view that were binned repeated HistogramBin bins = 4; +} + // --- Metadata retrieval (separated from GetDataSamples) --- message GetMetaDataRequest { int32 start_index = 1; // grid slice start (current view order) diff --git a/weightslab/proto/experiment_service_pb2.py b/weightslab/proto/experiment_service_pb2.py index d1f48342..6a521ac8 100644 --- a/weightslab/proto/experiment_service_pb2.py +++ b/weightslab/proto/experiment_service_pb2.py @@ -2,7 +2,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: weightslab/proto/experiment_service.proto -# Protobuf Python Version: 6.33.5 +# Protobuf Python Version: 6.31.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -12,8 +12,8 @@ _runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.Domain.PUBLIC, 6, - 33, - 5, + 31, + 1, '', 'weightslab/proto/experiment_service.proto' ) @@ -24,8 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"\xdc\x06\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"C\n\x0fHistogramSubBar\x12\x0e\n\x06origin\x18\x01 \x01(\t\x12\x11\n\tdiscarded\x18\x02 \x01(\x08\x12\r\n\x05\x63ount\x18\x03 \x01(\x03\"h\n\x0cHistogramBin\x12\x0b\n\x03min\x18\x01 \x01(\x01\x12\x0b\n\x03max\x18\x02 \x01(\x01\x12\x0b\n\x03\x61vg\x18\x03 \x01(\x01\x12\r\n\x05\x63ount\x18\x04 \x01(\x03\x12\"\n\x08sub_bars\x18\x05 \x03(\x0b\x32\x10.HistogramSubBar\"4\n\x10HistogramRequest\x12\x0e\n\x06\x63olumn\x18\x01 \x01(\t\x12\x10\n\x08max_bins\x18\x02 \x01(\x05\"f\n\x11HistogramResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\ntotal_rows\x18\x03 \x01(\x03\x12\x1b\n\x04\x62ins\x18\x04 \x03(\x0b\x32\r.HistogramBin\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xbb\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x35\n\x0cGetHistogram\x12\x11.HistogramRequest\x1a\x12.HistogramResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"W\n\x12GetMetaDataRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12\x17\n\x0fmodal_sample_id\x18\x03 \x01(\t\"\x99\x01\n\x13GetMetaDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1a\n\x12\x61ll_metadata_names\x18\x03 \x03(\t\x12!\n\x0cgrid_records\x18\x04 \x03(\x0b\x32\x0b.DataRecord\x12!\n\x0cmodal_record\x18\x05 \x01(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xbe\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x38\n\x0bGetMetaData\x12\x13.GetMetaDataRequest\x1a\x14.GetMetaDataResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 \x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e \x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n \x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c \x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"C\n\x0fHistogramSubBar\x12\x0e\n\x06origin\x18\x01 \x01(\t\x12\x11\n\tdiscarded\x18\x02 \x01(\x08\x12\r\n\x05\x63ount\x18\x03 \x01(\x03\"h\n\x0cHistogramBin\x12\x0b\n\x03min\x18\x01 \x01(\x01\x12\x0b\n\x03max\x18\x02 \x01(\x01\x12\x0b\n\x03\x61vg\x18\x03 \x01(\x01\x12\r\n\x05\x63ount\x18\x04 \x01(\x03\x12\"\n\x08sub_bars\x18\x05 \x03(\x0b\x32\x10.HistogramSubBar\"4\n\x10HistogramRequest\x12\x0e\n\x06\x63olumn\x18\x01 \x01(\t\x12\x10\n\x08max_bins\x18\x02 \x01(\x05\"f\n\x11HistogramResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\ntotal_rows\x18\x03 \x01(\x03\x12\x1b\n\x04\x62ins\x18\x04 \x03(\x0b\x32\r.HistogramBin\"W\n\x12GetMetaDataRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12\x17\n\x0fmodal_sample_id\x18\x03 \x01(\t\"\x99\x01\n\x13GetMetaDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1a\n\x12\x61ll_metadata_names\x18\x03 \x03(\t\x12!\n\x0cgrid_records\x18\x04 \x03(\x0b\x32\x0b.DataRecord\x12!\n\x0cmodal_record\x18\x05 \x01(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xf5\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x35\n\x0cGetHistogram\x12\x11.HistogramRequest\x1a\x12.HistogramResponse\x12\x38\n\x0bGetMetaData\x12\x13.GetMetaDataRequest\x1a\x14.GetMetaDataResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -38,26 +37,16 @@ _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_options = b'8\001' _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._loaded_options = None _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_options = b'8\001' - _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9145 - _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9245 - _globals['_ZEROFYPREDICATE']._serialized_start=9247 - _globals['_ZEROFYPREDICATE']._serialized_end=9358 - _globals['_AGENTINTENTTYPE']._serialized_start=9360 - _globals['_AGENTINTENTTYPE']._serialized_end=9437 - _globals['_SAMPLEEDITTYPE']._serialized_start=9439 - _globals['_SAMPLEEDITTYPE']._serialized_end=9512 - _globals['_AGENTPROVIDERTYPE']._serialized_start=9514 - _globals['_AGENTPROVIDERTYPE']._serialized_end=9558 - _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9231 - _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9331 - _globals['_ZEROFYPREDICATE']._serialized_start=9333 - _globals['_ZEROFYPREDICATE']._serialized_end=9444 - _globals['_AGENTINTENTTYPE']._serialized_start=9446 - _globals['_AGENTINTENTTYPE']._serialized_end=9523 - _globals['_SAMPLEEDITTYPE']._serialized_start=9525 - _globals['_SAMPLEEDITTYPE']._serialized_end=9598 - _globals['_AGENTPROVIDERTYPE']._serialized_start=9600 - _globals['_AGENTPROVIDERTYPE']._serialized_end=9644 + _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9564 + _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9664 + _globals['_ZEROFYPREDICATE']._serialized_start=9666 + _globals['_ZEROFYPREDICATE']._serialized_end=9777 + _globals['_AGENTINTENTTYPE']._serialized_start=9779 + _globals['_AGENTINTENTTYPE']._serialized_end=9856 + _globals['_SAMPLEEDITTYPE']._serialized_start=9858 + _globals['_SAMPLEEDITTYPE']._serialized_end=9931 + _globals['_AGENTPROVIDERTYPE']._serialized_start=9933 + _globals['_AGENTPROVIDERTYPE']._serialized_end=9977 _globals['_GETLATESTLOGGERDATAREQUEST']._serialized_start=46 _globals['_GETLATESTLOGGERDATAREQUEST']._serialized_end=183 _globals['_LOGGERDATAPOINT']._serialized_start=186 @@ -92,110 +81,6 @@ _globals['_LOADCHECKPOINTOPERATION']._serialized_end=2367 _globals['_PLOTNOTEOPERATION']._serialized_start=2369 _globals['_PLOTNOTEOPERATION']._serialized_end=2467 - _globals['_TRAINERCOMMAND']._serialized_start=2470 - _globals['_TRAINERCOMMAND']._serialized_end=3330 - _globals['_HYPERPARAMETERDESC']._serialized_start=3333 - _globals['_HYPERPARAMETERDESC']._serialized_end=3490 - _globals['_NEURONSTATISTICS']._serialized_start=3493 - _globals['_NEURONSTATISTICS']._serialized_end=3863 - _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_start=3722 - _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_end=3771 - _globals['_LAYERREPRESENTATION']._serialized_start=3866 - _globals['_LAYERREPRESENTATION']._serialized_end=4234 - _globals['_ACTIVATIONREQUEST']._serialized_start=4236 - _globals['_ACTIVATIONREQUEST']._serialized_end=4308 - _globals['_ACTIVATIONMAP']._serialized_start=4310 - _globals['_ACTIVATIONMAP']._serialized_end=4382 - _globals['_ACTIVATIONRESPONSE']._serialized_start=4384 - _globals['_ACTIVATIONRESPONSE']._serialized_end=4484 - _globals['_TASKFIELD']._serialized_start=4487 - _globals['_TASKFIELD']._serialized_end=4634 - _globals['_RECORDMETADATA']._serialized_start=4637 - _globals['_RECORDMETADATA']._serialized_end=4969 - _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_start=4916 - _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_end=4969 - _globals['_SAMPLESTATISTICS']._serialized_start=4972 - _globals['_SAMPLESTATISTICS']._serialized_end=5119 - _globals['_COMMANDRESPONSE']._serialized_start=5122 - _globals['_COMMANDRESPONSE']._serialized_end=5352 - _globals['_SAMPLEREQUEST']._serialized_start=5354 - _globals['_SAMPLEREQUEST']._serialized_end=5439 - _globals['_SAMPLEREQUESTRESPONSE']._serialized_start=5442 - _globals['_SAMPLEREQUESTRESPONSE']._serialized_end=5743 - _globals['_BATCHSAMPLEREQUEST']._serialized_start=5746 - _globals['_BATCHSAMPLEREQUEST']._serialized_end=5892 - _globals['_BATCHSAMPLERESPONSE']._serialized_start=5894 - _globals['_BATCHSAMPLERESPONSE']._serialized_end=5956 - _globals['_WEIGHTSREQUEST']._serialized_start=5958 - _globals['_WEIGHTSREQUEST']._serialized_end=6004 - _globals['_WEIGHTSRESPONSE']._serialized_start=6007 - _globals['_WEIGHTSRESPONSE']._serialized_end=6292 - _globals['_DATAQUERYREQUEST']._serialized_start=6294 - _globals['_DATAQUERYREQUEST']._serialized_end=6376 - _globals['_CATEGORICALTAGDEF']._serialized_start=6378 - _globals['_CATEGORICALTAGDEF']._serialized_end=6431 - _globals['_DATAQUERYRESPONSE']._serialized_start=6434 - _globals['_DATAQUERYRESPONSE']._serialized_end=6731 - _globals['_DATASAMPLESREQUEST']._serialized_start=6734 - _globals['_DATASAMPLESREQUEST']._serialized_end=6928 - _globals['_DATASTAT']._serialized_start=6930 - _globals['_DATASTAT']._serialized_end=7039 - _globals['_DATARECORD']._serialized_start=7041 - _globals['_DATARECORD']._serialized_end=7103 - _globals['_DATASAMPLESRESPONSE']._serialized_start=7105 - _globals['_DATASAMPLESRESPONSE']._serialized_end=7195 - _globals['_HISTOGRAMSUBBAR']._serialized_start=7197 - _globals['_HISTOGRAMSUBBAR']._serialized_end=7264 - _globals['_HISTOGRAMBIN']._serialized_start=7266 - _globals['_HISTOGRAMBIN']._serialized_end=7370 - _globals['_HISTOGRAMREQUEST']._serialized_start=7372 - _globals['_HISTOGRAMREQUEST']._serialized_end=7424 - _globals['_HISTOGRAMRESPONSE']._serialized_start=7426 - _globals['_HISTOGRAMRESPONSE']._serialized_end=7528 - _globals['_POINTCLOUDREQUEST']._serialized_start=7530 - _globals['_POINTCLOUDREQUEST']._serialized_end=7604 - _globals['_POINTCLOUDCHUNK']._serialized_start=7607 - _globals['_POINTCLOUDCHUNK']._serialized_end=7798 - _globals['_DATAEDITSREQUEST']._serialized_start=7801 - _globals['_DATAEDITSREQUEST']._serialized_end=8021 - _globals['_DATAEDITSRESPONSE']._serialized_start=8023 - _globals['_DATAEDITSRESPONSE']._serialized_end=8076 - _globals['_DATASPLITSRESPONSE']._serialized_start=8078 - _globals['_DATASPLITSRESPONSE']._serialized_end=8136 - _globals['_AGENTHEALTHRESPONSE']._serialized_start=8138 - _globals['_AGENTHEALTHRESPONSE']._serialized_end=8195 - _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8197 - _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8291 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8293 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8352 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8354 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8394 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8396 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8456 - _globals['_GETAGENTMODELSREQUEST']._serialized_start=8458 - _globals['_GETAGENTMODELSREQUEST']._serialized_end=8481 - _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8483 - _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8557 - _globals['_RESETAGENTRESPONSE']._serialized_start=8559 - _globals['_RESETAGENTRESPONSE']._serialized_end=8613 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8615 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8666 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8668 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8729 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8731 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8813 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8815 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8876 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8878 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8906 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8909 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=9038 - _globals['_CANCELEVALUATIONREQUEST']._serialized_start=9040 - _globals['_CANCELEVALUATIONREQUEST']._serialized_end=9081 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=9083 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=9143 - _globals['_EXPERIMENTSERVICE']._serialized_start=9561 - _globals['_EXPERIMENTSERVICE']._serialized_end=10900 _globals['_SAVECHECKPOINTOPERATION']._serialized_start=2469 _globals['_SAVECHECKPOINTOPERATION']._serialized_end=2545 _globals['_TRAINERCOMMAND']._serialized_start=2548 @@ -250,52 +135,60 @@ _globals['_DATARECORD']._serialized_end=7277 _globals['_DATASAMPLESRESPONSE']._serialized_start=7279 _globals['_DATASAMPLESRESPONSE']._serialized_end=7369 - _globals['_GETMETADATAREQUEST']._serialized_start=7371 - _globals['_GETMETADATAREQUEST']._serialized_end=7458 - _globals['_GETMETADATARESPONSE']._serialized_start=7461 - _globals['_GETMETADATARESPONSE']._serialized_end=7614 - _globals['_POINTCLOUDREQUEST']._serialized_start=7616 - _globals['_POINTCLOUDREQUEST']._serialized_end=7690 - _globals['_POINTCLOUDCHUNK']._serialized_start=7693 - _globals['_POINTCLOUDCHUNK']._serialized_end=7884 - _globals['_DATAEDITSREQUEST']._serialized_start=7887 - _globals['_DATAEDITSREQUEST']._serialized_end=8107 - _globals['_DATAEDITSRESPONSE']._serialized_start=8109 - _globals['_DATAEDITSRESPONSE']._serialized_end=8162 - _globals['_DATASPLITSRESPONSE']._serialized_start=8164 - _globals['_DATASPLITSRESPONSE']._serialized_end=8222 - _globals['_AGENTHEALTHRESPONSE']._serialized_start=8224 - _globals['_AGENTHEALTHRESPONSE']._serialized_end=8281 - _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8283 - _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8377 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8379 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8438 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8440 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8480 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8482 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8542 - _globals['_GETAGENTMODELSREQUEST']._serialized_start=8544 - _globals['_GETAGENTMODELSREQUEST']._serialized_end=8567 - _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8569 - _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8643 - _globals['_RESETAGENTRESPONSE']._serialized_start=8645 - _globals['_RESETAGENTRESPONSE']._serialized_end=8699 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8701 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8752 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8754 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8815 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8817 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8899 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8901 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8962 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8964 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8992 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8995 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=9124 - _globals['_CANCELEVALUATIONREQUEST']._serialized_start=9126 - _globals['_CANCELEVALUATIONREQUEST']._serialized_end=9167 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=9169 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=9229 - _globals['_EXPERIMENTSERVICE']._serialized_start=9647 - _globals['_EXPERIMENTSERVICE']._serialized_end=10989 + _globals['_HISTOGRAMSUBBAR']._serialized_start=7371 + _globals['_HISTOGRAMSUBBAR']._serialized_end=7438 + _globals['_HISTOGRAMBIN']._serialized_start=7440 + _globals['_HISTOGRAMBIN']._serialized_end=7544 + _globals['_HISTOGRAMREQUEST']._serialized_start=7546 + _globals['_HISTOGRAMREQUEST']._serialized_end=7598 + _globals['_HISTOGRAMRESPONSE']._serialized_start=7600 + _globals['_HISTOGRAMRESPONSE']._serialized_end=7702 + _globals['_GETMETADATAREQUEST']._serialized_start=7704 + _globals['_GETMETADATAREQUEST']._serialized_end=7791 + _globals['_GETMETADATARESPONSE']._serialized_start=7794 + _globals['_GETMETADATARESPONSE']._serialized_end=7947 + _globals['_POINTCLOUDREQUEST']._serialized_start=7949 + _globals['_POINTCLOUDREQUEST']._serialized_end=8023 + _globals['_POINTCLOUDCHUNK']._serialized_start=8026 + _globals['_POINTCLOUDCHUNK']._serialized_end=8217 + _globals['_DATAEDITSREQUEST']._serialized_start=8220 + _globals['_DATAEDITSREQUEST']._serialized_end=8440 + _globals['_DATAEDITSRESPONSE']._serialized_start=8442 + _globals['_DATAEDITSRESPONSE']._serialized_end=8495 + _globals['_DATASPLITSRESPONSE']._serialized_start=8497 + _globals['_DATASPLITSRESPONSE']._serialized_end=8555 + _globals['_AGENTHEALTHRESPONSE']._serialized_start=8557 + _globals['_AGENTHEALTHRESPONSE']._serialized_end=8614 + _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8616 + _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8710 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8712 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8771 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8773 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8813 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8815 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8875 + _globals['_GETAGENTMODELSREQUEST']._serialized_start=8877 + _globals['_GETAGENTMODELSREQUEST']._serialized_end=8900 + _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8902 + _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8976 + _globals['_RESETAGENTRESPONSE']._serialized_start=8978 + _globals['_RESETAGENTRESPONSE']._serialized_end=9032 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=9034 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=9085 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=9087 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=9148 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=9150 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=9232 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=9234 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=9295 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=9297 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=9325 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=9328 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=9457 + _globals['_CANCELEVALUATIONREQUEST']._serialized_start=9459 + _globals['_CANCELEVALUATIONREQUEST']._serialized_end=9500 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=9502 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=9562 + _globals['_EXPERIMENTSERVICE']._serialized_start=9980 + _globals['_EXPERIMENTSERVICE']._serialized_end=11377 # @@protoc_insertion_point(module_scope) diff --git a/weightslab/proto/experiment_service_pb2_grpc.py b/weightslab/proto/experiment_service_pb2_grpc.py index ba9759d1..e75d391c 100644 --- a/weightslab/proto/experiment_service_pb2_grpc.py +++ b/weightslab/proto/experiment_service_pb2_grpc.py @@ -5,7 +5,7 @@ from weightslab.proto import experiment_service_pb2 as weightslab_dot_proto_dot_experiment__service__pb2 -GRPC_GENERATED_VERSION = '1.81.1' +GRPC_GENERATED_VERSION = '1.76.0' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -25,7 +25,7 @@ ) -class ExperimentServiceStub: +class ExperimentServiceStub(object): """Missing associated documentation comment in .proto file.""" def __init__(self, channel): @@ -78,6 +78,7 @@ def __init__(self, channel): '/ExperimentService/GetHistogram', request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString, response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString, + _registered_method=True) self.GetMetaData = channel.unary_unary( '/ExperimentService/GetMetaData', request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, @@ -145,7 +146,7 @@ def __init__(self, channel): _registered_method=True) -class ExperimentServiceServicer: +class ExperimentServiceServicer(object): """Missing associated documentation comment in .proto file.""" def GetLatestLoggerData(self, request, context): @@ -199,6 +200,11 @@ def GetDataSamples(self, request, context): def GetHistogram(self, request, context): """Server-side histogram binning of one metadata/signal column. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def GetMetaData(self, request, context): """Metadata-only retrieval (dataframe columns). Returns every metadata column name for the WHOLE dataset, the current grid slice's per-sample metadata, and @@ -332,6 +338,7 @@ def add_ExperimentServiceServicer_to_server(servicer, server): servicer.GetHistogram, request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.FromString, response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.SerializeToString, + ), 'GetMetaData': grpc.unary_unary_rpc_method_handler( servicer.GetMetaData, request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.FromString, @@ -405,7 +412,7 @@ def add_ExperimentServiceServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. -class ExperimentService: +class ExperimentService(object): """Missing associated documentation comment in .proto file.""" @staticmethod @@ -626,7 +633,6 @@ def GetDataSamples(request, @staticmethod def GetHistogram(request, - def GetMetaData(request, target, options=(), channel_credentials=None, @@ -642,6 +648,30 @@ def GetMetaData(request, '/ExperimentService/GetHistogram', weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString, weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetMetaData(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, '/ExperimentService/GetMetaData', weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.FromString, From 9d143777d0a390732382bc2cc140736b2c97daa5 Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 14:11:18 +0200 Subject: [PATCH 10/16] Fix UL YAMML and JSON export with custom Cpython functions of ledgered parameters --- weightslab/backend/ledgers.py | 79 +++++++++++++++++-- .../Ultralytics/ws-detection/config.yaml | 2 +- .../examples/Ultralytics/ws-detection/main.py | 14 ++-- weightslab/tests/backend/test_ledgers.py | 51 +++++++++--- 4 files changed, 122 insertions(+), 24 deletions(-) diff --git a/weightslab/backend/ledgers.py b/weightslab/backend/ledgers.py index a1b4e1ee..f93d56ea 100644 --- a/weightslab/backend/ledgers.py +++ b/weightslab/backend/ledgers.py @@ -252,11 +252,7 @@ def __int__(self) -> int: return int(self._resolve()) def __float__(self) -> float: - v = self._resolve() - try: - return float(v) - except (TypeError, ValueError): - return None + return float(self._resolve()) def __index__(self) -> int: v = self._resolve() @@ -452,14 +448,19 @@ def get(self, ref=None, default=None, proxy: bool = True) -> Any: except for ``str`` values, which are returned raw (see below). """ if ref is not None: - if proxy: + value = self._obj.get(ref, default) if hasattr(self._obj, "get") else default + # Avoid wrapping a python list, dict, or any callable in a live proxy — + # only "simple" values become live key proxies. (The previous form put + # `not callable(...)` INSIDE the isinstance group, so callables — which + # are not lists/dicts — slipped through and got proxied.) + if proxy and not (isinstance(value, (list, dict)) or callable(value)): vp = Proxy._ValueProxy(self, ref, default) # str and torch.device are handed back as plain values (see - # _plain_get_value); other types (dict/list/int/float/...) stay + # _plain_get_value); other simple types (int/float/bool/...) stay # live proxies so studio edits keep tracking. plain = _plain_get_value(vp._resolve()) return vp if plain is _KEEP_AS_PROXY else plain - return self._obj.get(ref, default) + return value return self._obj if self._obj is not None else default def __getattr__(self, item): @@ -728,6 +729,68 @@ def __exit__(self, exc_type, exc, tb): MutableMapping.register(Proxy) +def _register_yaml_representers() -> None: + """Teach PyYAML to dump ledger proxies as their underlying value. + + A hyperparameter handle returned by ``watch_or_edit(..., flag='hyperparameters')`` + is a live ``Proxy`` / ``Proxy._ValueProxy``. When such a value flows into a + library that serializes it (e.g. Ultralytics writes its run args to ``args.yaml`` + via ``yaml.safe_dump``), PyYAML has no representer for the proxy type and raises + ``RepresenterError`` — and because the proxy masquerades as its wrapped type via + ``__class__``, callers' ``isinstance(x, (int, str, ...))`` "stringify" guards skip + it too. Registering representers that emit the *resolved* value makes proxies + transparently serializable everywhere, so the live-proxy HP design stays + compatible with such libraries. Registered on both the default and Safe dumpers. + """ + def _represent_obj_proxy(dumper, data): + return dumper.represent_data(data.get()) # Proxy -> wrapped object + + def _represent_value_proxy(dumper, data): + return dumper.represent_data(data._resolve()) # _ValueProxy -> resolved value + + # Register on every dumper variant: the pure-Python Dumper/SafeDumper, the + # libyaml C dumpers (CDumper/CSafeDumper — Ultralytics dumps with CSafeDumper), + # and the base Representer/SafeRepresenter. Each keeps its own representer + # table, so registering on one does not cover the others. + for _name in ("Dumper", "SafeDumper", "CDumper", "CSafeDumper", "Representer", "SafeRepresenter"): + _dumper = getattr(yaml, _name, None) + if _dumper is not None: + _dumper.add_representer(Proxy, _represent_obj_proxy) + _dumper.add_representer(Proxy._ValueProxy, _represent_value_proxy) + + +def _register_json_default() -> None: + """Teach the stdlib ``json`` encoder to serialize ledger proxies as their + underlying value, mirroring :func:`_register_yaml_representers`. + + ``json`` has no global representer registry, so we wrap ``JSONEncoder.default`` + (the hook called for objects the encoder doesn't natively handle). The C + encoder dispatches by concrete type, so a proxy — whose real type is not int/ + str/dict/... despite its ``__class__`` masquerade — reaches ``default`` and is + replaced by its resolved value. Makes ``json.dumps``/``json.dump`` of HP + proxies work everywhere (e.g. audit/JSON config dumps) without per-call hooks. + """ + import json + + if getattr(json.JSONEncoder, "_wl_proxy_patched", False): + return + _orig_default = json.JSONEncoder.default + + def default(self, obj): + if isinstance(obj, Proxy._ValueProxy): + return obj._resolve() + if isinstance(obj, Proxy): + return obj.get() + return _orig_default(self, obj) + + json.JSONEncoder.default = default + json.JSONEncoder._wl_proxy_patched = True + + +_register_yaml_representers() +_register_json_default() + + class Ledger: """Thread-safe ledger storing named registries for different object types. diff --git a/weightslab/examples/Ultralytics/ws-detection/config.yaml b/weightslab/examples/Ultralytics/ws-detection/config.yaml index 2ba8184e..9c07dfb9 100644 --- a/weightslab/examples/Ultralytics/ws-detection/config.yaml +++ b/weightslab/examples/Ultralytics/ws-detection/config.yaml @@ -38,7 +38,7 @@ ledger_flush_interval: 60.0 # Data num_classes: 2 image_size: 320 -data_root: .\data\data.yaml # Uncomment and set the path to your data.yaml file. YOLO format. +data_root: C:\Users\GuillaumePELLUET\Documents\Codes\weightslab_kitchen\guillaume_playground\ws-ultralytics_yolo\data\data.yaml # Uncomment and set the path to your data.yaml file. YOLO format. data: train_loader: batch_size: 4 diff --git a/weightslab/examples/Ultralytics/ws-detection/main.py b/weightslab/examples/Ultralytics/ws-detection/main.py index ea1876e3..c56597ca 100644 --- a/weightslab/examples/Ultralytics/ws-detection/main.py +++ b/weightslab/examples/Ultralytics/ws-detection/main.py @@ -41,8 +41,11 @@ def main(): os.makedirs(cfg["root_log_dir"], exist_ok=True) wl.watch_or_edit(cfg, flag="hyperparameters", defaults=cfg, poll_interval=1.0) - # Read raw config values BEFORE wrapping so YOLO.train kwargs are plain - # Python (avoids ProxyValue.__gt__ during max()/comparisons). + # After watch_or_edit, `cfg` is the live hyperparameter proxy, so these reads + # return ledger handles (e.g. image_size is a ValueProxy) that stay in sync + # with studio edits. They are passed straight to YOLO.train(...): ValueProxy + # supports int/compare ops for YOLO's imgsz handling, and the ledger registers + # a YAML representer (see ledgers.py) so Ultralytics can dump its run args. model_name = cfg["model"]["name"] data_root = str(cfg["data_root"]) image_size = cfg.get("image_size") @@ -52,7 +55,6 @@ def main(): serving_cli = cfg.get("serving_cli", False) project = cfg["root_log_dir"] name = cfg["experiment_name"] - signals_cfg = cfg.get('signals_cfg', {}) wl.serve(serving_grpc=serving_grpc, serving_cli=serving_cli) @@ -78,8 +80,10 @@ def main(): degrees=0.0, translate=0.0, scale=0.0, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.0, erasing=0.0, auto_augment=None, - # Signals cfg - **signals_cfg + # NOTE: signals_cfg (e.g. train_nms) is NOT passed here — it is read by + # WLAwareTrainer from the registered hyperparameters + # (ledgers.get_hyperparams()['signals_cfg']). Spreading it into .train() + # would make Ultralytics reject keys like `train_nms` as invalid YOLO args. ) wl.keep_serving() # Keep main thread alive to analyze training results directly diff --git a/weightslab/tests/backend/test_ledgers.py b/weightslab/tests/backend/test_ledgers.py index 232dbb3e..5f3288c0 100644 --- a/weightslab/tests/backend/test_ledgers.py +++ b/weightslab/tests/backend/test_ledgers.py @@ -197,26 +197,26 @@ def test_proxy_get_key_default_mode_returns_live_proxy(self): hp_handle["lr"] = 0.02 self.assertEqual(lr.get(), 0.02) - def test_value_proxy_subscript_access(self): - """ValueProxy subscript [key] is equivalent to .get(key) and chains.""" + def test_dict_value_returned_raw_not_proxied(self): + """Dict (and list / callable) values come back RAW, not wrapped in a live + proxy (see Proxy.get's list/dict/callable exclusion), so subscripting just + reads the plain mapping.""" hp_handle = GLOBAL_LEDGER.get_hyperparams() GLOBAL_LEDGER.register_hyperparams( params={"dataset": {"batch_size": 32, "splits": {"train": 0.8}}} ) dataset = hp_handle.get("dataset") - # [key] matches .get(key) for the resolved mapping. + # A dict value is handed back raw, not as a live ValueProxy. + self.assertIsInstance(dataset, dict) + self.assertFalse(hasattr(dataset, "set")) + + # Plain mapping access (and nesting) works. self.assertEqual(dataset["batch_size"], dataset.get("batch_size")) self.assertEqual(dataset["batch_size"], 32) - - # Nested dicts are wrapped in a live proxy so chaining keeps resolving. self.assertEqual(dataset["splits"]["train"], 0.8) - # Reads stay fresh against the underlying mapping. - hp_handle["dataset"] = {"batch_size": 64, "splits": {"train": 0.9}} - self.assertEqual(dataset["batch_size"], 64) - - # Missing keys raise KeyError, matching standard subscript semantics. + # Missing keys raise KeyError, matching standard dict subscript semantics. with self.assertRaises(KeyError): dataset["missing"] @@ -254,6 +254,37 @@ def test_proxy_get_key_explicit_plain_value_mode(self): hp_handle["data_root"] = "C:/data/v2" self.assertEqual(data_root, "C:/data/v1") + def test_proxy_yaml_and_json_serialization(self): + """Ledger proxies serialize to their underlying value for both YAML and + JSON, so libraries that dump their config (e.g. Ultralytics' args.yaml, + or JSON audit/config dumps) don't choke on a live hyperparameter proxy.""" + import json + import yaml + + hp = GLOBAL_LEDGER.get_hyperparams() + GLOBAL_LEDGER.register_hyperparams(params={"image_size": 320, "lr": 0.01}) + + img = hp.get("image_size") # a live ValueProxy, not a plain int + self.assertEqual(type(img).__name__, "_ValueProxy") + + # YAML: cover every dumper variant — Ultralytics dumps with CSafeDumper + # (the libyaml C dumper), which keeps its own representer table. + for dumper_name in ("Dumper", "SafeDumper", "CDumper", "CSafeDumper"): + dumper = getattr(yaml, dumper_name, None) + if dumper is None: + continue + self.assertEqual( + yaml.dump({"imgsz": img}, Dumper=dumper).strip(), "imgsz: 320", + f"{dumper_name} did not serialize the proxy", + ) + self.assertEqual( + yaml.safe_load(yaml.safe_dump(hp)), {"image_size": 320, "lr": 0.01} + ) + + # JSON: json.dumps of a scalar proxy and of the whole HP proxy. + self.assertEqual(json.loads(json.dumps({"imgsz": img})), {"imgsz": 320}) + self.assertEqual(json.loads(json.dumps(hp)), {"image_size": 320, "lr": 0.01}) + def test_value_proxy_numeric_comparisons(self): """ValueProxy supports all standard numeric and string comparison operators.""" hp_handle = GLOBAL_LEDGER.get_hyperparams() From cb98a66dc38cce07c465119950684098ae7555cb Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 18:45:34 +0200 Subject: [PATCH 11/16] Fix instance wise segmentation bugs and modal view sync with the UI --- docs/configuration.rst | 8 ++ weightslab/data/data_utils.py | 21 ++-- .../PyTorch/ws-segmentation/config.yaml | 2 +- .../examples/PyTorch/ws-segmentation/main.py | 6 +- .../PyTorch/ws-segmentation/utils/data.py | 31 +++-- weightslab/tests/gRPC/test_get_point_cloud.py | 40 +++++- weightslab/trainer/services/data_service.py | 114 +++++++++++------- 7 files changed, 155 insertions(+), 67 deletions(-) diff --git a/docs/configuration.rst b/docs/configuration.rst index 8eff60e1..7d6a8fe5 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -245,6 +245,14 @@ Data and Cache uniformly downsampled — keeping the first and last point and an evenly-spaced subset in between (no values are interpolated/invented). Set to ``0`` to disable the cap and return every step of the mean curve. + * - ``WL_POINT_CLOUD_CHUNK_BYTES`` + - ``1048576`` + - Size, in bytes, of each chunk streamed by the ``GetPointCloud`` RPC + (raw ``float32`` point-cloud data is sent as a sequence of binary + messages). Defaults to ``1048576`` (1 MiB). Larger chunks mean fewer + gRPC messages but more memory held per message; smaller chunks lower + peak memory at the cost of more round-trips. Must be a positive integer + — non-positive or non-numeric values fall back to the 1 MiB default. Evaluation Mode diff --git a/weightslab/data/data_utils.py b/weightslab/data/data_utils.py index 9ce19864..60c0f998 100644 --- a/weightslab/data/data_utils.py +++ b/weightslab/data/data_utils.py @@ -453,6 +453,13 @@ def load_label(dataset, sample_id): # Get dataset wrapper if exists wrapped = getattr(dataset, "wrapped_dataset", dataset) + def _convert_label(lbl): + if isinstance(lbl, list) and len(lbl) and isinstance(lbl[0], (th.Tensor, np.ndarray)): + label = to_numpy_safe(lbl).max(0) # Aggr. instances + else: + label = to_numpy_safe(lbl) # Third element is typically the label + return label + # Try common dataset patterns first if hasattr(wrapped, '__getitem__'): data = wrapped.get_items(index, include_metadata=False, include_labels=True, include_images=False) if hasattr(wrapped, 'get_items') else wrapped[index] @@ -460,23 +467,21 @@ def load_label(dataset, sample_id): if isinstance(data, (list, tuple)): if len(data) == 1: return None # Only data, no label - elif len(data) == 2: # Commonly (data, label) in standard PyTorch datasets - label = to_numpy_safe(data[1]) - elif len(data) == 3: # if len==3, data, uids, label, no extra info - label = to_numpy_safe(data[2]) # Third element is typically the label + elif len(data) <= 3: # if len==2|3, data, uids, label, no extra info + label = _convert_label(data[2]) elif len(data) > 3: # if len>3, data, uids, label, classes, extra info if len(data) == 4: - label = to_numpy_safe(data[2]) # Third element is typically the label metadata = data[3] classes = to_numpy_safe(metadata['classes']) if isinstance(metadata, dict) and 'classes' in metadata else None if classes is not None: - label = to_numpy_safe(data[2]) # Second element is typically the label + label = _convert_label(data[2]) + # Concat label with classes if available (bbox detection, i.e., (4,) -> (5,) with class id) label = np.concatenate([label, classes[..., None]], axis=1) else: - label = to_numpy_safe(data[2]) # Second element is typically the label + label = _convert_label(data[2]) else: - label = to_numpy_safe(data[2]) # Third element is typically the label + label = _convert_label(data[2]) metadata = data[3:] if label is not None: return label[0] if label.ndim == 1 and label.shape[0] == 1 else label diff --git a/weightslab/examples/PyTorch/ws-segmentation/config.yaml b/weightslab/examples/PyTorch/ws-segmentation/config.yaml index 5054332b..9aceb6c8 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/config.yaml +++ b/weightslab/examples/PyTorch/ws-segmentation/config.yaml @@ -29,7 +29,7 @@ ledger_flush_interval: 60.0 # Data num_classes: 6 image_size: 180 -data_root: .\BDD_subset # Bdd format +data_root: C:\Users\GuillaumePELLUET\Documents\Codes\weightslab\weightslab\examples\PyTorch\ws-segmentation\BDD_subset # Bdd format data: train_loader: batch_size: 2 diff --git a/weightslab/examples/PyTorch/ws-segmentation/main.py b/weightslab/examples/PyTorch/ws-segmentation/main.py index c3a4eb48..1f184724 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/main.py +++ b/weightslab/examples/PyTorch/ws-segmentation/main.py @@ -336,9 +336,9 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 print(f"📂 Data root: {data_root}") print("=" * 60 + "\n") - # ================ - # Training Loop - wl.start_training(timeout=3) # This will block and keep the main thread alive while background services run. You can optionally set a timeout (in seconds) to automatically stop after a certain duration. + # # ================ + # # Training Loop + # wl.start_training(timeout=3) # This will block and keep the main thread alive while background services run. You can optionally set a timeout (in seconds) to automatically stop after a certain duration. # ================ train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py index 28ba4dde..70715f33 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py @@ -114,22 +114,31 @@ def get_items(self, idx, include_metadata=False, include_labels=False, include_i img_t = self.image_transform(img) # Process labels/masks - mask_t_instances = list() + # # Sample wise segmentation mask_t = None if include_labels: mask = Image.open(mask_path) mask_r = self.mask_resize(mask) mask_np = np.array(mask_r, dtype=np.int64) - mask_t = torch.from_numpy(mask_np) # [H, W] int64 - - # Format labels to register multiple instance_ids - lbl_max = mask_t.max().item() - for i in range(1, lbl_max + 1): - m = torch.zeros_like(mask_t) - m[mask_t == i] = i # Assign class ID as instance ID for simplicity; if set to 1, all instances of the same class would be merged... - mask_t_instances.append(m) - return img_t, uid, mask_t_instances, metadata - + mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 + return img_t, uid, mask_t, metadata + # # # Instance wise segmentaiton + # # Process labels/masks + # mask_t_instances = list() + # mask_t = None + # if include_labels: + # mask = Image.open(mask_path) + # mask_r = self.mask_resize(mask) + # mask_np = np.array(mask_r, dtype=np.int64) + # mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 + + # # Format labels to register multiple instance_ids + # lbl_max = mask_t.max().item() + # for i in range(1, lbl_max + 1): + # m = torch.zeros_like(mask_t) + # m[mask_t == i] = i # Assign class ID as instance ID for simplicity; if set to 1, all instances of the same class would be merged... + # mask_t_instances.append(m) + # return img_t, uid, mask_t_instances, metadata def seg_collate(batch): """Collate WL per-sample tuples for instance-segmentation. diff --git a/weightslab/tests/gRPC/test_get_point_cloud.py b/weightslab/tests/gRPC/test_get_point_cloud.py index 1e8ba1ba..035fe319 100644 --- a/weightslab/tests/gRPC/test_get_point_cloud.py +++ b/weightslab/tests/gRPC/test_get_point_cloud.py @@ -3,7 +3,11 @@ import weightslab.proto.experiment_service_pb2 as pb2 -from weightslab.trainer.services.data_service import DataService +from weightslab.trainer.services.data_service import ( + DataService, + _DEFAULT_POINT_CLOUD_CHUNK_BYTES, + _point_cloud_chunk_bytes, +) PC_RANGE = (0.0, -32.0, -3.0, 64.0, 32.0, 1.0) @@ -84,6 +88,40 @@ def test_get_point_cloud_unknown_sample_fails_gracefully(): assert "not found" in chunks[0].message +def test_point_cloud_chunk_bytes_default(monkeypatch): + monkeypatch.delenv("WL_POINT_CLOUD_CHUNK_BYTES", raising=False) + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES == (1 << 20) + + +def test_point_cloud_chunk_bytes_env_override(monkeypatch): + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "4096") + assert _point_cloud_chunk_bytes() == 4096 + + +def test_point_cloud_chunk_bytes_invalid_falls_back(monkeypatch): + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "not-a-number") + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "0") + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "-10") + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES + + +def test_get_point_cloud_honours_configured_chunk_size(): + """A smaller chunk size splits the same cloud into more (correct) messages.""" + class _SmallChunkService(_StubService): + _POINT_CLOUD_CHUNK_BYTES = 4096 # bytes + + stub = _SmallChunkService(_FakeLidarDataset()) + chunks = _collect(stub, pb2.PointCloudRequest(sample_id="7", origin="train_loader")) + + total_bytes = 50_000 * 4 * 4 + assert len(chunks) > 1 + assert all(len(c.data) <= 4096 for c in chunks) + assert chunks[0].total_chunks == len(chunks) + assert sum(len(c.data) for c in chunks) == total_bytes + + def test_get_point_cloud_non_pointcloud_sample_fails_gracefully(): class ImgDataset(_FakeLidarDataset): def get_items(self, idx, **kwargs): diff --git a/weightslab/trainer/services/data_service.py b/weightslab/trainer/services/data_service.py index 9d2904da..004aea89 100755 --- a/weightslab/trainer/services/data_service.py +++ b/weightslab/trainer/services/data_service.py @@ -52,6 +52,34 @@ logger = logging.getLogger(__name__) +# Streamed chunk size for GetPointCloud (raw float32 bytes per gRPC message). +# Larger chunks mean fewer messages but more memory per message. Override with +# the WL_POINT_CLOUD_CHUNK_BYTES env variable (see docs/configuration.rst). +_DEFAULT_POINT_CLOUD_CHUNK_BYTES = 1 << 20 # 1 MiB + + +def _point_cloud_chunk_bytes() -> int: + """Read WL_POINT_CLOUD_CHUNK_BYTES; non-positive/invalid falls back to the default.""" + raw = os.getenv("WL_POINT_CLOUD_CHUNK_BYTES") + if raw is None or raw == "": + return _DEFAULT_POINT_CLOUD_CHUNK_BYTES + try: + val = int(raw) + except (TypeError, ValueError): + logger.warning( + "WL_POINT_CLOUD_CHUNK_BYTES=%r is not an integer — using default %d", + raw, _DEFAULT_POINT_CLOUD_CHUNK_BYTES, + ) + return _DEFAULT_POINT_CLOUD_CHUNK_BYTES + if val <= 0: + logger.warning( + "WL_POINT_CLOUD_CHUNK_BYTES=%r must be > 0 — using default %d", + raw, _DEFAULT_POINT_CLOUD_CHUNK_BYTES, + ) + return _DEFAULT_POINT_CLOUD_CHUNK_BYTES + return val + + def normalize_metadata_copy_source_name(source_name: str, experiment_hash: str = None) -> str: """Normalize a source metadata name for deterministic copied-column naming.""" name = str(source_name or "").strip() @@ -766,49 +794,49 @@ def _is_training_active(self) -> bool: return True def _pull_into_all_data_view_df(self): - """Stream stats from the global in-memory dataframe (ledger manager). + """Stream stats from the global in-memory dataframe (ledger manager). - Uses the shared dataframe manager instead of the H5 store and avoids - blocking on IO. Falls back to last snapshot if retrieval fails. - """ - try: - # Load dataframe from the shared dataframe manager with arrays autoloaded from h5 storage - df = self._df_manager.get_combined_df() if self._df_manager is not None else pd.DataFrame() - if df.empty: - logger.debug(f"[DataService] Pull returned empty dataframe (manager: {self._df_manager is not None})") - return df - - # The manager now expands samples into one row per (sample_id, annotation_id) - # instance. Collapse back to one row per sample for the sample-centric UI/agent - # view, nesting per-instance signals into a dict column. - df = self._df_manager.get_collapse_annotations_to_samples_df() - - # Ensure sample_id is a column if it was the index - df = safe_reset_index(df) - - # Ensure we have a unique index across all origins by using a MultiIndex (origin, sample_id) - # This is CRITICAL for correctly applying reindex() in _slowUpdateInternals without - # exploding the dataframe size due to duplicate sample_id index labels. - if SampleStatsEx.ORIGIN.value in df.columns: - # Use drop=True to ensure origin is NOT in both index and columns (avoids ambiguity) - # GetDataSamples calls reset_index() before processing rows, which restores them as columns - df = df.set_index([SampleStatsEx.ORIGIN.value, SampleStatsEx.SAMPLE_ID.value], drop=True) - else: - # Fallback to single index if origin is missing, though manager should provide it - df = df.set_index([SampleStatsEx.SAMPLE_ID.value], drop=True) + Uses the shared dataframe manager instead of the H5 store and avoids + blocking on IO. Falls back to last snapshot if retrieval fails. + """ + try: + # Load dataframe from the shared dataframe manager with arrays autoloaded from h5 storage + df = self._df_manager.get_combined_df() if self._df_manager is not None else pd.DataFrame() + if df.empty: + logger.debug(f"[DataService] Pull returned empty dataframe (manager: {self._df_manager is not None})") + return df - # DEDUPLICATE: Ensure index is unique before returning. - # If duplicates exist, reindex() will fail later. - if df.index.has_duplicates: - logger.debug(f"[DataService] Dropping {df.index.duplicated().sum()} duplicate index labels from data view.") - df = df[~df.index.duplicated(keep='last')] + # The manager now expands samples into one row per (sample_id, annotation_id) + # instance. Collapse back to one row per sample for the sample-centric UI/agent + # view, nesting per-instance signals into a dict column. + df = self._df_manager.get_collapse_annotations_to_samples_df() - return df - except Exception as e: - logger.debug(f"[DataService] Error pulling data view: {e}") - # Use getattr to safely check for attribute during __init__ - current_df = getattr(self, "_all_datasets_df", None) - return current_df if current_df is not None else pd.DataFrame() + # Ensure sample_id is a column if it was the index + df = safe_reset_index(df) + + # Ensure we have a unique index across all origins by using a MultiIndex (origin, sample_id) + # This is CRITICAL for correctly applying reindex() in _slowUpdateInternals without + # exploding the dataframe size due to duplicate sample_id index labels. + if SampleStatsEx.ORIGIN.value in df.columns: + # Use drop=True to ensure origin is NOT in both index and columns (avoids ambiguity) + # GetDataSamples calls reset_index() before processing rows, which restores them as columns + df = df.set_index([SampleStatsEx.ORIGIN.value, SampleStatsEx.SAMPLE_ID.value], drop=True) + else: + # Fallback to single index if origin is missing, though manager should provide it + df = df.set_index([SampleStatsEx.SAMPLE_ID.value], drop=True) + + # DEDUPLICATE: Ensure index is unique before returning. + # If duplicates exist, reindex() will fail later. + if df.index.has_duplicates: + logger.debug(f"[DataService] Dropping {df.index.duplicated().sum()} duplicate index labels from data view.") + df = df[~df.index.duplicated(keep='last')] + + return df + except Exception as e: + logger.debug(f"[DataService] Error pulling data view: {e}") + # Use getattr to safely check for attribute during __init__ + current_df = getattr(self, "_all_datasets_df", None) + return current_df if current_df is not None else pd.DataFrame() def _get_origin_filter(self, request): """Extract requested origins if present on request (backward compatible).""" @@ -881,7 +909,6 @@ def _is_nan_value(self, value): except (TypeError, ValueError): return False - def _compute_natural_sort_stats(self): """ Compute hardcoded natural sort statistics (brightness, hue, saturation, entropy) for all samples @@ -3642,8 +3669,9 @@ def GetHistogram(self, request, context): success=False, message=f"histogram failed: {str(e)}", total_rows=0, bins=[]) - # Streamed chunk size for GetPointCloud (raw float32 bytes per message). - _POINT_CLOUD_CHUNK_BYTES = 1 << 20 # 1 MiB + # Streamed chunk size for GetPointCloud (raw float32 bytes per message), + # configurable via the WL_POINT_CLOUD_CHUNK_BYTES env var (default 1 MiB). + _POINT_CLOUD_CHUNK_BYTES = _point_cloud_chunk_bytes() def GetPointCloud(self, request, context): """Stream one sample's raw point cloud as binary float32 chunks. From a0a09a9cd0235f003fed5c717e17f9fdf250108a Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 19:09:39 +0200 Subject: [PATCH 12/16] Remove all emoji from source code Strip colorful emoji from console.log, print statements, logger calls, docstrings, and comments across the entire backend codebase. Plain text conveys the same meaning without unicode-rendering concerns. Co-Authored-By: Claude Sonnet 4.6 --- weightslab/__init__.py | 2 +- weightslab/art.py | 26 +- weightslab/backend/audit_logger.py | 4 +- weightslab/backend/cli.py | 10 +- weightslab/backend/dataloader_interface.py | 78 +++--- weightslab/backend/ledgers.py | 26 +- weightslab/backend/logger.py | 28 +-- weightslab/backend/model_interface.py | 14 +- weightslab/baseline_models/pytorch/models.py | 130 +++++----- weightslab/components/__init__.py | 18 +- weightslab/components/checkpoint_manager.py | 196 +++++++-------- weightslab/components/experiment_hash.py | 18 +- weightslab/components/global_monitoring.py | 16 +- weightslab/components/parallel_primitives.py | 30 +-- weightslab/components/tracking.py | 20 +- weightslab/data/array_proxy.py | 6 +- weightslab/data/data_samples_with_ops.py | 40 ++-- weightslab/data/data_utils.py | 50 ++-- weightslab/data/dataframe_manager.py | 84 +++---- weightslab/data/h5_array_store.py | 14 +- weightslab/data/h5_dataframe_store.py | 24 +- weightslab/data/point_cloud_utils.py | 42 ++-- weightslab/data/sample_stats.py | 12 +- .../Lightning/ws-classification/main.py | 20 +- .../PyTorch/ws-classification/main.py | 22 +- .../PyTorch/ws-clustering/face/data.py | 94 ++++---- .../PyTorch/ws-clustering/face/model.py | 44 ++-- .../PyTorch/ws-clustering/face/signals.py | 4 +- .../PyTorch/ws-clustering/face/utils.py | 14 +- .../examples/PyTorch/ws-clustering/main.py | 46 ++-- .../examples/PyTorch/ws-detection/main.py | 30 +-- .../PyTorch/ws-detection/utils/criterions.py | 36 +-- .../PyTorch/ws-detection/utils/data.py | 28 +-- .../PyTorch/ws-detection/utils/model.py | 22 +- .../examples/PyTorch/ws-generation/main.py | 46 ++-- .../examples/PyTorch/ws-segmentation/main.py | 58 ++--- .../ws-segmentation/utils/criterions.py | 18 +- .../PyTorch/ws-segmentation/utils/data.py | 34 +-- .../PyTorch/ws-segmentation/utils/model.py | 4 +- .../examples/Ultralytics/ws-detection/main.py | 6 +- .../Usecases/ws-2d-lidar-detection/main.py | 12 +- .../ws-2d-lidar-detection/utils/criterions.py | 8 +- .../ws-2d-lidar-detection/utils/data.py | 14 +- .../ws-2d-lidar-detection/utils/model.py | 24 +- .../Usecases/ws-3d-lidar-detection/main.py | 68 +++--- .../ws-3d-lidar-detection/utils/criterions.py | 50 ++-- .../ws-3d-lidar-detection/utils/data.py | 92 +++---- .../utils/kitti_download.py | 10 +- .../ws-3d-lidar-detection/utils/model.py | 56 ++--- .../integrations/ultralytics/__init__.py | 4 +- .../integrations/ultralytics/dataset.py | 4 +- .../integrations/ultralytics/signals.py | 24 +- .../integrations/ultralytics/trainer.py | 20 +- weightslab/models/model_with_ops.py | 8 +- weightslab/models/monkey_patcher.py | 2 +- weightslab/modules/modules_with_ops.py | 52 ++-- weightslab/proto/experiment_service_pb2.py | 2 +- weightslab/security/cert_auth_manager.py | 2 +- weightslab/src.py | 220 ++++++++--------- .../tests/backend/test_compare_dataloaders.py | 128 +++++----- .../backend/test_data_loader_interface.py | 130 +++++----- weightslab/tests/backend/test_ledgers.py | 8 +- weightslab/tests/backend/test_logger_core.py | 18 +- .../tests/backend/test_ui_docker_bridge.py | 30 +-- .../tests/backend/test_write_dataframe.py | 6 +- .../tests/backend/test_write_history.py | 10 +- .../test_grpc_chaos_monkey_robustness.py | 2 +- .../components/test_checkpoint_workflow.py | 126 +++++----- .../components/test_global_monitoring_unit.py | 42 ++-- .../tests/data/test_data_samples_with_ops.py | 6 +- .../tests/data/test_dataframe_manager_unit.py | 16 +- weightslab/tests/data/test_flush_pipeline.py | 12 +- weightslab/tests/data/test_h5_array_store.py | 2 +- .../tests/data/test_h5_dataframe_store.py | 6 +- .../tests/data/test_point_cloud_utils.py | 30 +-- weightslab/tests/gRPC/test_get_point_cloud.py | 2 +- .../tests/gRPC/test_grpc_tag_operations.py | 52 ++-- .../tests/gRPC/test_grpc_user_actions.py | 4 +- weightslab/tests/general/test_cli.py | 172 +++++++------- weightslab/tests/general/test_signals.py | 20 +- .../tests/general/test_signals_wrapping.py | 8 +- .../test_pytorch_lightning_integration.py | 224 +++++++++--------- .../ultralytics/ddp/ddp_ablation.py | 26 +- .../ultralytics/ddp/ddp_test_suite.py | 216 ++++++++--------- .../tests/model/test_constraint_generation.py | 8 +- .../tests/model/test_dependency_patterns.py | 12 +- weightslab/tests/model/test_model_with_ops.py | 2 +- weightslab/tests/model/test_tracking.py | 2 +- .../tests/modules/test_modules_with_ops.py | 36 +-- weightslab/tests/test_secure_docker.py | 2 +- .../services/test_agent_prompt_unit.py | 2 +- .../services/test_trainer_services_server.py | 2 +- .../services/test_trainer_services_unit.py | 2 +- .../tests/watchdog/test_lock_monitor.py | 2 +- weightslab/tests/watchdog/test_watchdog.py | 28 +-- weightslab/trainer/experiment_context.py | 4 +- weightslab/trainer/services/agent/agent.py | 18 +- weightslab/trainer/services/agent_service.py | 10 +- .../trainer/services/data_image_utils.py | 6 +- weightslab/trainer/services/data_service.py | 108 ++++----- .../trainer/services/experiment_service.py | 38 +-- .../trainer/services/instance_merger.py | 12 +- weightslab/trainer/services/utils/tools.py | 2 +- weightslab/trainer/trainer_services.py | 24 +- weightslab/trainer/trainer_tools.py | 38 +-- weightslab/ui_docker_bridge.py | 106 ++++----- weightslab/utils/computational_graph.py | 184 +++++++------- weightslab/utils/logs.py | 2 +- weightslab/utils/tools.py | 24 +- weightslab/watchdog/__init__.py | 22 +- weightslab/watchdog/grpc_watchdog.py | 4 +- weightslab/watchdog/lock_monitor.py | 14 +- weightslab/watchdog/log_level.py | 2 +- weightslab/watchdog/watchdog.py | 40 ++-- 114 files changed, 2069 insertions(+), 2069 deletions(-) diff --git a/weightslab/__init__.py b/weightslab/__init__.py index ecb5be12..df7a8154 100644 --- a/weightslab/__init__.py +++ b/weightslab/__init__.py @@ -63,7 +63,7 @@ # Get Package Metadata try: # setuptools_scm will write weightslab/_version.py during build - from ._version import __version__ # type: ignore + from ._version import __version__ # type: ignore except Exception: # Fallback when developing locally or before build; keeps behavior stable. from datetime import datetime diff --git a/weightslab/art.py b/weightslab/art.py index cf34b072..4436ee2f 100644 --- a/weightslab/art.py +++ b/weightslab/art.py @@ -10,11 +10,11 @@ def get_git_info(): git_root = current_dir # Traverse up to find .git directory - for _ in range(10): # Limit search depth + for _ in range(10): # Limit search depth if os.path.isdir(os.path.join(git_root, '.git')): break parent = os.path.dirname(git_root) - if parent == git_root: # Reached filesystem root + if parent == git_root: # Reached filesystem root git_root = None break git_root = parent @@ -40,19 +40,19 @@ def get_git_info(): branch, version, commit_hash = get_git_info() _BANNER = f""" -\x1b[31m /WW /WW\x1b[0m /$$ /$$ /$$ \x1b[32m/$$\x1b[0m /$$ -\x1b[31m| WW /W | WW\x1b[0m |__/ | $$ | $$ \x1b[32m| $$\x1b[0m | $$ -\x1b[31m| WW /WWW| WW\x1b[0m /$$$$$$ /$$ /$$$$$$ | $$$$$$$ /$$$$$$ /$$$$$$$\x1b[32m| $$\x1b[0m /$$$$$$ | $$$$$$$ -\x1b[31m| WW/WW WW WW\x1b[0m /$$__ $$| $$ /$$__ $$| $$__ $$|_ $$_/ /$$_____/\x1b[32m| $$\x1b[0m |____ $$| $$__ $$ -\x1b[31m| WWWW_ WWWW\x1b[0m| $$$$$$$$| $$| $$ \ $$| $$ \ $$ | $$ | $$$$$$ \x1b[32m| $$\x1b[0m /$$$$$$$| $$ \ $$ -\x1b[31m| WWW/ \ WWW\x1b[0m| $$_____/| $$| $$ | $$| $$ | $$ | $$ /$$ \____ $$\x1b[32m| $$\x1b[0m /$$__ $$| $$ | $$ -\x1b[31m| WW/ \ WW\x1b[0m| $$$$$$$| $$| $$$$$$$| $$ | $$ | $$$$/ /$$$$$$$/\x1b[32m| $$$$$$$$\x1b[0m $$$$$$$| $$$$$$$/ -\x1b[31m|__/ \__/\x1b[0m \_______/|__/ \____ $$|__/ |__/ \___/ |_______/ \x1b[32m|________/\x1b[0m \_______/|_______/ - /$$ \ $$ - | $$$$$$/ +\x1b[31m /WW /WW\x1b[0m /$$ /$$ /$$ \x1b[32m/$$\x1b[0m /$$ +\x1b[31m| WW /W | WW\x1b[0m |__/ | $$ | $$ \x1b[32m| $$\x1b[0m | $$ +\x1b[31m| WW /WWW| WW\x1b[0m /$$$$$$ /$$ /$$$$$$ | $$$$$$$ /$$$$$$ /$$$$$$$\x1b[32m| $$\x1b[0m /$$$$$$ | $$$$$$$ +\x1b[31m| WW/WW WW WW\x1b[0m /$$__ $$| $$ /$$__ $$| $$__ $$|_ $$_/ /$$_____/\x1b[32m| $$\x1b[0m |____ $$| $$__ $$ +\x1b[31m| WWWW_ WWWW\x1b[0m| $$$$$$$$| $$| $$ \ $$| $$ \ $$ | $$ | $$$$$$ \x1b[32m| $$\x1b[0m /$$$$$$$| $$ \ $$ +\x1b[31m| WWW/ \ WWW\x1b[0m| $$_____/| $$| $$ | $$| $$ | $$ | $$ /$$ \____ $$\x1b[32m| $$\x1b[0m /$$__ $$| $$ | $$ +\x1b[31m| WW/ \ WW\x1b[0m| $$$$$$$| $$| $$$$$$$| $$ | $$ | $$$$/ /$$$$$$$/\x1b[32m| $$$$$$$$\x1b[0m $$$$$$$| $$$$$$$/ +\x1b[31m|__/ \__/\x1b[0m \_______/|__/ \____ $$|__/ |__/ \___/ |_______/ \x1b[32m|________/\x1b[0m \_______/|_______/ + /$$ \ $$ + | $$$$$$/ \______/ By GrayBx """ if branch is not None and version is not None and commit_hash is not None: _BANNER += f"\nBranch: {branch} | Version: {version} | Commit: {commit_hash}\n" -_BANNER__ = _BANNER # Expose banner with a different name for external use and legacy +_BANNER__ = _BANNER # Expose banner with a different name for external use and legacy diff --git a/weightslab/backend/audit_logger.py b/weightslab/backend/audit_logger.py index 9930e927..fdcde841 100644 --- a/weightslab/backend/audit_logger.py +++ b/weightslab/backend/audit_logger.py @@ -14,9 +14,9 @@ @dataclass class AuditEvent: """Immutable audit event structure.""" - timestamp: str # ISO format string + timestamp: str # ISO format string action_type: str - status: str # "success" or "failed" + status: str # "success" or "failed" details: Optional[Dict[str, Any]] = None error: Optional[str] = None diff --git a/weightslab/backend/cli.py b/weightslab/backend/cli.py index acaf911e..1734a15c 100644 --- a/weightslab/backend/cli.py +++ b/weightslab/backend/cli.py @@ -172,7 +172,7 @@ def _handle_command(cmd: str) -> Any: 'hyperparams_examples': { 'list': 'hp', 'show': 'hp fashion_mnist', - 'set': "set_hp # e.g. set_hp fashion_mnist data.train_loader.batch_size 32", + 'set': "set_hp # e.g. set_hp fashion_mnist data.train_loader.batch_size 32", }, 'evaluate_examples': { 'eval on default split': 'evaluate', @@ -451,7 +451,7 @@ def _handle_command(cmd: str) -> Any: for name in snap.get(k, []): try: getter = { - # 'models': GLOBAL_LEDGER.get_model, # don't print the model out + # 'models': GLOBAL_LEDGER.get_model, # don't print the model out 'dataloaders': GLOBAL_LEDGER.get_dataloader, 'optimizers': GLOBAL_LEDGER.get_optimizer, }[k] @@ -823,7 +823,7 @@ def _handle_command(cmd: str) -> Any: names = GLOBAL_LEDGER.list_hyperparams() if hasattr(GLOBAL_LEDGER, 'list_hyperparams') else [] if len(parts) == 1: return {'ok': True, 'hyperparams': names} - # support: hp list -> same as hp + # support: hp list -> same as hp name = parts[1] if name.lower() in ('list', 'ls', 'all'): return {'ok': True, 'hyperparams': names} @@ -998,7 +998,7 @@ def _handle_command(cmd: str) -> Any: elif toggle in ('off', 'false', '0', 'disable', 'disabled'): value = False else: - return {'ok': False, 'error': f'Unknown audit toggle "{toggle}". Use: audit on or audit off'} + return {'ok': False, 'error': f'Unknown audit toggle "{toggle}". Use: audit on or audit off'} set_hyperparam(name=name, value=value, key_path='auditor_mode') label = 'enabled' if value else 'disabled' @@ -1123,7 +1123,7 @@ def cli_serve(cli_host: str = 'localhost', cli_port: int = 0, *, spawn_client: b pass srv = None if attempt < max_attempts - 1: - continue # Try next port + continue # Try next port else: # All attempts failed logger.exception("cli_bind_failed_all_attempts") diff --git a/weightslab/backend/dataloader_interface.py b/weightslab/backend/dataloader_interface.py index 2d066666..d9b16e94 100644 --- a/weightslab/backend/dataloader_interface.py +++ b/weightslab/backend/dataloader_interface.py @@ -101,7 +101,7 @@ class WeightsLabDataSampler(Sampler): loader = DataLoader(dataset, batch_sampler=sampler) # Toggle shuffle at runtime - sampler.shuffle = False # Switch to sequential + sampler.shuffle = False # Switch to sequential """ def __init__( @@ -132,7 +132,7 @@ def __init__( self._deny_listed_uids_cache: set[str] = set() self._deny_list_revision: Optional[tuple[str, int]] = None # Evaluation-mode allow-list: when set, only samples whose uid is in - # this set are yielded. None = no filter (normal behaviour). + # this set are yielded. None = no filter (normal behaviour). self._eval_allow_list: Optional[set] = None def _get_deny_listed_uids(self, origin: str = None) -> set: @@ -143,7 +143,7 @@ def _get_deny_listed_uids(self, origin: str = None) -> set: if origin is not None: df_view = self.tracked_dataset._get_df_view(column='origin', value=origin) else: - df_view = self.tracked_dataset._get_df_view() # get all by default + df_view = self.tracked_dataset._get_df_view() # get all by default if not df_view.empty and SampleStatsEx.DISCARDED.value in df_view.columns: discarded_rows = df_view[df_view[SampleStatsEx.DISCARDED.value] == True] @@ -360,7 +360,7 @@ class DataLoaderInterface: from where manual next() left off, and after epoch exhaustion, both patterns automatically reset on the next iteration. - ✅ CORRECT usage - for-loops continue from manual next() position: + CORRECT usage - for-loops continue from manual next() position: loader = DataLoaderInterface(dataset, batch_size=32) @@ -779,12 +779,12 @@ def __iter__(self) -> Iterator: reset if we're already mid-epoch, which allows for-loops to continue from where manual next() calls left off. - ✅ CORRECT usage: + CORRECT usage: while step < max_steps: - data = next(loader) # Manual next (mid-epoch) + data = next(loader) # Manual next (mid-epoch) if step % 5 == 0: - for batch in loader: # For-loop continues from batch position - process(batch) # Gets remaining batches until epoch end + for batch in loader: # For-loop continues from batch position + process(batch) # Gets remaining batches until epoch end step += 1 How it works: @@ -828,7 +828,7 @@ def __next__(self) -> Any: try: data = next(loader) except StopIteration: - data = next(loader) # Auto-resets, returns first batch of next epoch + data = next(loader) # Auto-resets, returns first batch of next epoch 2. For-loop with proper termination: for batch in loader: @@ -837,9 +837,9 @@ def __next__(self) -> Any: 3. Mixed usage: while training: - data = next(loader) # Auto-resets as needed + data = next(loader) # Auto-resets as needed if should_eval(): - for batch in loader: # Gets remaining epoch, exits on StopIteration + for batch in loader: # Gets remaining epoch, exits on StopIteration eval(batch) """ self._sync_batch_size_from_ledger() @@ -921,7 +921,7 @@ def _next_batch(self) -> Any: batch = next(self._iterator) except StopIteration: if hasattr(self, 'is_a_loop') and self.is_a_loop: - raise # Re-raise so __next__() can handle epoch exhaustion + raise # Re-raise so __next__() can handle epoch exhaustion else: self._reset_iterator() batch = next(self._iterator) @@ -931,30 +931,30 @@ def _next_batch(self) -> Any: return batch # def _execute_offset(self) -> None: - # """ - # Execute sample offset if set, skipping samples as needed. - # This is a fallback mechanism for user-supplied dataloaders where - # we cannot use an OffsetSampler. - - # TODO (GP): - # We can reproduce the random generation of samples by restoring RNG state, if during the previous checkpoints, batchsize changed dynamically and shuffle is True. - # """ - # if self._sample_offset > 0: - # current_bs = self.get_batch_size() - # # Fast-forward the iterator by the offset amount - # while len(self._skipped) < self._sample_offset: - # try: - # bs = 4 if self._sample_offset - len(self._skipped) >= 4 else self._sample_offset - len(self._skipped) # Autoscale bs to sample offset - # self.set_batch_size(bs) - # self._skipped.extend(next(self._iterator)[1].detach().cpu().tolist()) - # logger.debug(f"Offset sampler: skipped {len(self._skipped)}/{self._sample_offset}") - # except StopIteration as e: - # logger.debug(f"Offset sampler: reached end of iterator while skipping: {e}") - # self._reset_iterator() # Reset iterator and try again - - # self.set_batch_size(current_bs) - # self._skipped = [] - # self._sample_offset = 0 + # """ + # Execute sample offset if set, skipping samples as needed. + # This is a fallback mechanism for user-supplied dataloaders where + # we cannot use an OffsetSampler. + + # TODO (GP): + # We can reproduce the random generation of samples by restoring RNG state, if during the previous checkpoints, batchsize changed dynamically and shuffle is True. + # """ + # if self._sample_offset > 0: + # current_bs = self.get_batch_size() + # # Fast-forward the iterator by the offset amount + # while len(self._skipped) < self._sample_offset: + # try: + # bs = 4 if self._sample_offset - len(self._skipped) >= 4 else self._sample_offset - len(self._skipped) # Autoscale bs to sample offset + # self.set_batch_size(bs) + # self._skipped.extend(next(self._iterator)[1].detach().cpu().tolist()) + # logger.debug(f"Offset sampler: skipped {len(self._skipped)}/{self._sample_offset}") + # except StopIteration as e: + # logger.debug(f"Offset sampler: reached end of iterator while skipping: {e}") + # self._reset_iterator() # Reset iterator and try again + + # self.set_batch_size(current_bs) + # self._skipped = [] + # self._sample_offset = 0 def _should_persist_workers(self, num_workers: int) -> bool: """Whether to keep DataLoader workers alive across iterator resets. @@ -1025,7 +1025,7 @@ def _reset_iterator(self) -> None: # Only relevant when workers are actually respawning; persistent workers # stay alive, so no settle-delay is needed. if respawning: - time.sleep(0.01) # 10ms delay for worker cleanup + time.sleep(0.01) # 10ms delay for worker cleanup # Create new iterator self._iterator = iter(self.dataloader) @@ -1040,8 +1040,8 @@ def reset_iterator(self) -> None: rng_state = capture_rng_state() batch1 = next(dataloader_interface) restore_rng_state(rng_state) - dataloader_interface.reset_iterator() # Create new iterator with restored RNG - batch1_repeat = next(dataloader_interface) # Same batches! + dataloader_interface.reset_iterator() # Create new iterator with restored RNG + batch1_repeat = next(dataloader_interface) # Same batches! """ self._reset_iterator() diff --git a/weightslab/backend/ledgers.py b/weightslab/backend/ledgers.py index f93d56ea..45a64f42 100644 --- a/weightslab/backend/ledgers.py +++ b/weightslab/backend/ledgers.py @@ -546,7 +546,7 @@ def __next__(self): return next(self._it) except StopIteration: # Let StopIteration propagate naturally - self._it.is_a_loop = False # Loop ends here + self._it.is_a_loop = False # Loop ends here raise except KeyError: # Quiet by default; only surface this diagnostic when the user @@ -743,10 +743,10 @@ def _register_yaml_representers() -> None: compatible with such libraries. Registered on both the default and Safe dumpers. """ def _represent_obj_proxy(dumper, data): - return dumper.represent_data(data.get()) # Proxy -> wrapped object + return dumper.represent_data(data.get()) # Proxy -> wrapped object def _represent_value_proxy(dumper, data): - return dumper.represent_data(data._resolve()) # _ValueProxy -> resolved value + return dumper.represent_data(data._resolve()) # _ValueProxy -> resolved value # Register on every dumper variant: the pure-Python Dumper/SafeDumper, the # libyaml C dumpers (CDumper/CSafeDumper — Ultralytics dumps with CSafeDumper), @@ -1257,7 +1257,7 @@ def list_models() -> List[str]: def register_model(model: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_model(name) # Init empty proxy + get_model(name) # Init empty proxy GLOBAL_LEDGER.register_model(model, weak=weak, name=name) @@ -1275,7 +1275,7 @@ def get_dataloaders(names: Optional[List[str]] = None) -> Dict[str, Any]: def register_dataloaders(dataloaders: Dict[str, Any], weak: bool = False) -> None: """Register multiple dataloaders from a dict, e.g., {'train': train_loader, 'val': val_loader}.""" for k in dataloaders.keys(): - get_dataloader(k) # Init empty proxy - get_dataloaders(list(dataloaders.keys())) + get_dataloader(k) # Init empty proxy - get_dataloaders(list(dataloaders.keys())) GLOBAL_LEDGER.register_dataloaders_dict(dataloaders, weak=weak) def list_dataloaders() -> List[str]: @@ -1283,7 +1283,7 @@ def list_dataloaders() -> List[str]: def register_dataloader(dataloader: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_dataloader(name) # Init the empty proxy first + get_dataloader(name) # Init the empty proxy first GLOBAL_LEDGER.register_dataloader(dataloader, weak=weak, name=name) @@ -1299,7 +1299,7 @@ def list_optimizers() -> List[str]: def register_optimizer(optimizer: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_optimizer(name) # Init the empty proxy first + get_optimizer(name) # Init the empty proxy first GLOBAL_LEDGER.register_optimizer(optimizer, weak=weak, name=name) # Hyperparameters @@ -1322,7 +1322,7 @@ def resolve_hp_name() -> str | None: return 'experiment' # If we have any names at all, returning the first one is better than returning None # and causing a "Cannot resolve hyperparams name" error in the UI. - return names[-1] # first is empty proxy parameters generated at init + return names[-1] # first is empty proxy parameters generated at init def set_hyperparam(key_path: str, value: Any, name: str = DEFAULT_NAME) -> None: try: @@ -1338,14 +1338,14 @@ def unwatch_hyperparams_file(name: str = DEFAULT_NAME) -> None: def register_hyperparams(params: Dict[str, Any] = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_hyperparams(name) # Init empty proxy + get_hyperparams(name) # Init empty proxy GLOBAL_LEDGER.register_hyperparams(params, weak=weak, name=name) # Logger def register_logger(logger: Any = None, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_logger(name) # Init empty proxy + get_logger(name) # Init empty proxy GLOBAL_LEDGER.register_logger(logger, name=name) def get_logger(name: str = DEFAULT_NAME) -> Any: @@ -1361,7 +1361,7 @@ def unregister_logger(name: str = DEFAULT_NAME) -> None: # Signals def register_signal(signal: Any = None, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_signal(name) # Init empty proxy + get_signal(name) # Init empty proxy GLOBAL_LEDGER.register_signal(signal, name=name) def get_signal(name: str = DEFAULT_NAME) -> Any: @@ -1377,7 +1377,7 @@ def unregister_signal(name: str = DEFAULT_NAME) -> None: # Checkpoint managers def register_checkpoint_manager(manager: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> Any: name = DEFAULT_NAME if name is None else name - get_checkpoint_manager(name) # Init empty proxy + get_checkpoint_manager(name) # Init empty proxy return GLOBAL_LEDGER.register_checkpoint_manager(manager, weak=weak, name=name) def get_checkpoint_manager(name: str = DEFAULT_NAME) -> Any: @@ -1393,7 +1393,7 @@ def unregister_checkpoint_manager(name: str = DEFAULT_NAME) -> None: # DataFrames def register_dataframe(dataframe: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_dataframe(name) # Init empty proxy + get_dataframe(name) # Init empty proxy return GLOBAL_LEDGER.register_dataframe(dataframe, weak=weak, name=name) def get_dataframe(name: str = DEFAULT_NAME) -> Any: diff --git a/weightslab/backend/logger.py b/weightslab/backend/logger.py index 2104eff9..dc077e6b 100644 --- a/weightslab/backend/logger.py +++ b/weightslab/backend/logger.py @@ -3,10 +3,10 @@ ``LoggerQueue`` is a thin interface that maps the logger's public methods onto a DuckDB database holding three history tables: -* ``signals`` — aggregated training-curve points (one row per averaged +* ``signals`` — aggregated training-curve points (one row per averaged step entry / evaluation marker). -* ``per_sample`` — per-sample signal values ``(sample_id, step, value)``. -* ``per_instance`` — per-instance values ``(sample_id, annotation_id, step, value)`` +* ``per_sample`` — per-sample signal values ``(sample_id, step, value)``. +* ``per_instance`` — per-instance values ``(sample_id, annotation_id, step, value)`` for detection / segmentation. Design notes @@ -71,7 +71,7 @@ def __init__(self, register: bool = True, db_path: str = ":memory:") -> None: self._eval_mode_hash: str = "" self._eval_mode_split: str = "" self._eval_mode_tags: list[str] = [] - self._eval_accum: dict = {} # {graph_name: [sum, count]} + self._eval_accum: dict = {} # {graph_name: [sum, count]} # DuckDB connection + write-staging buffers. self._lock = threading.RLock() @@ -357,7 +357,7 @@ def start_evaluation_mode(self, split_name: str, eval_hash: str, evaluation_tags """Redirect subsequent add_scalars() calls into the evaluation buffer. While evaluation mode is active, signals are NOT added to the normal - curve history. Instead they accumulate in an internal buffer. + curve history. Instead they accumulate in an internal buffer. ``stop_evaluation_mode()`` finalises the buffer into a single marker. Per-sample history *is* still updated (for Break-By-Slice on eval @@ -488,7 +488,7 @@ def add_scalars(self, graph_name, signal, global_step, signal_per_sample, aggreg for sid, value in signal_per_sample.items(): self._stage_sample_row(graph_name, eval_hash, sid, step_i, self._to_float(value)) - return # Do NOT add to normal history during evaluation mode + return # Do NOT add to normal history during evaluation mode # ------------------------------------------------------------ exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager else None @@ -581,11 +581,11 @@ def print_history(self): for metric_name, experiments in history.items(): print(f"Metric: {metric_name}") for exp_hash, steps in experiments.items(): - print(f" Experiment Hash: {exp_hash}") + print(f" Experiment Hash: {exp_hash}") for step, signals in steps.items(): - print(f" Step: {step}") + print(f" Step: {step}") for signal in signals: - print(f" Signal: {signal}") + print(f" Signal: {signal}") return history def print_history_per_sample(self): @@ -593,9 +593,9 @@ def print_history_per_sample(self): for metric_name, exps in history.items(): print(f"Metric: {metric_name}") for exp_hash, entries in exps.items(): - print(f" Experiment Hash: {exp_hash}") + print(f" Experiment Hash: {exp_hash}") for e in entries: - print(f" Sample ID: {e['sample_id']}, Step: {e['model_age']}, Value: {e['metric_value']}") + print(f" Sample ID: {e['sample_id']}, Step: {e['model_age']}, Value: {e['metric_value']}") return history def print_buffer(self): @@ -736,7 +736,7 @@ def query_per_instance( """Query per-instance signal history. Returns a list of ``(sample_id, annotation_id, step, value, exp_hash)`` - tuples. Any of *sample_id*, *annotation_id*, *exp_hash* may be ``None`` + tuples. Any of *sample_id*, *annotation_id*, *exp_hash* may be ``None`` to return all values along that dimension. """ with self._lock: @@ -1013,9 +1013,9 @@ def load_signal_history_per_sample(self, signals_per_sample): """Load per-sample history. Handles three formats: - - Compact: {graph: {hash: {"_compact": True, "sample_ids": [...], "steps": [...], "values": [...]}}} + - Compact: {graph: {hash: {"_compact": True, "sample_ids": [...], "steps": [...], "values": [...]}}} - Legacy list: {graph: {hash: [{sample_id, model_age, metric_value, ...}, ...]}} - - Legacy dict: {graph: {sample_id_as_key: {model_age, metric_value, ...}}} → stored under None hash + - Legacy dict: {graph: {sample_id_as_key: {model_age, metric_value, ...}}} → stored under None hash """ if not signals_per_sample: return diff --git a/weightslab/backend/model_interface.py b/weightslab/backend/model_interface.py index 73803e7f..773d290a 100755 --- a/weightslab/backend/model_interface.py +++ b/weightslab/backend/model_interface.py @@ -134,7 +134,7 @@ def __init__( if dummy_input is None: raise ValueError("Model object must have 'input_shape' attribute for proper registration with WeightsLab.") else: - self.model.input_shape = tuple(dummy_input.shape[1:]) # Exclude batch dimension + self.model.input_shape = tuple(dummy_input.shape[1:]) # Exclude batch dimension # Move dummy input to the correct device, or create a default one if not provided if dummy_input is not None: @@ -443,7 +443,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): True if an exception occurred and it was successfully handled by this method, preventing it from being re-raised. """ - self.visited_nodes = set() # Reset NetworkWithOps nodes visited + self.visited_nodes = set() # Reset NetworkWithOps nodes visited if exc_type is not None: logger.error( f"[{self.__class__.__name__}]: An exception occurred: \ @@ -492,7 +492,7 @@ def load_state_dict(self, state_dict, strict: bool = True): Note: `assign=False` is explicitly passed so that parameter tensors are updated **in-place** (data copy) rather than replaced with new - objects. Replacing parameter objects (assign=True, the NetworkWithOps + objects. Replacing parameter objects (assign=True, the NetworkWithOps default) would silently invalidate any optimizer that was created before this load_state_dict call, because the optimizer holds references to the old Parameter objects. @@ -885,9 +885,9 @@ def __repr__(self): pass seq_lines = seq_module_repr.split('\n') # The first line is formatted with the name, the rest are indented - seq_string += f" ({seq_name}): {seq_lines[0]}\n" + seq_string += f" ({seq_name}): {seq_lines[0]}\n" for seq_line in seq_lines[1:]: - seq_string += f" {seq_line}\n" + seq_string += f" {seq_line}\n" module_repr = f"{seq_string}" else: module_repr = f"ID=None | {module_repr}" @@ -899,9 +899,9 @@ def __repr__(self): lines = module_repr.split('\n') # The first line is formatted with the name, the rest are indented - string += f" ({name}): {lines[0]}\n" + string += f" ({name}): {lines[0]}\n" for line in lines[1:]: - string += f" {line}\n" + string += f" {line}\n" string += ")" return string diff --git a/weightslab/baseline_models/pytorch/models.py b/weightslab/baseline_models/pytorch/models.py index 9a4dd041..0b82a0b7 100644 --- a/weightslab/baseline_models/pytorch/models.py +++ b/weightslab/baseline_models/pytorch/models.py @@ -38,7 +38,7 @@ def __init__(self): self.m1 = nn.MaxPool2d(2) # Block 2 - self.c2 = nn.Conv2d(4, 4, 3) # Default stride=1, no padding + self.c2 = nn.Conv2d(4, 4, 3) # Default stride=1, no padding self.b2 = nn.BatchNorm2d(4) self.r2 = nn.ReLU() self.m2 = nn.MaxPool2d(2) @@ -154,21 +154,21 @@ def __init__(self): self.input_shape = (1, 1, 28, 28) # Block 1 (Path A) - self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Id 0 + self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Id 0 # Block 2 (Residual/Skip Path) # Note: c2 takes b1's output. c3 takes c2's output. - self.c2 = nn.Conv2d(4, 8, 3, padding=1) # Id 2 - self.c3 = nn.Conv2d(8, 4, 3, padding=1) # Id 3 + self.c2 = nn.Conv2d(4, 8, 3, padding=1) # Id 2 + self.c3 = nn.Conv2d(8, 4, 3, padding=1) # Id 3 def forward(self, x): # Path A - x1 = self.c1(x) # [4, 28, 28] - x2 = self.c2(x1) # [8, 28, 28] - x3 = self.c3(x2) # [4, 28, 28] + x1 = self.c1(x) # [4, 28, 28] + x2 = self.c2(x1) # [8, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual Connection (Add operation) - x_out = x1 + x3 # The output of b1 and c3 both flow into the add op + x_out = x1 + x3 # The output of b1 and c3 both flow into the add op return x_out @@ -199,15 +199,15 @@ def forward(self, x): # Main Path (where the skip connection comes from) x2 = self.c2(x1) - x3 = self.c3(x2) # [4, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x4 = self.c4(x1) - x5 = self.c5(self.b1(x4)) # [4, 28, 28] + x5 = self.c5(self.b1(x4)) # [4, 28, 28] # Residual Connection (Add operation) # Now x3 and x5 have the same shape: B x 4 x 28 x 28 - x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 + x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 return x_out @@ -237,15 +237,15 @@ def forward(self, x): # Main Path (where the skip connection comes from) x2 = self.c2(x1) - x3 = self.c3(x2) # [4, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x4 = self.c4(x1) - x5 = self.b1(x4) # [4, 28, 28] + x5 = self.b1(x4) # [4, 28, 28] # Residual Connection (Add operation) # Now x3 and x5 have the same shape: B x 4 x 28 x 28 - x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 + x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 return x_out @@ -259,7 +259,7 @@ def __init__(self): self.input_shape = (1, 1, 28, 28) # Block 1 (Path A) - Stays the same - self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Input (1), Output (4) + self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Input (1), Output (4) # Block 2 (Main Path) - Stays the same self.c2 = nn.Conv2d(4, 8, 3, padding=1) @@ -279,19 +279,19 @@ def forward(self, x): # Main Path (where the skip connection comes from) x2 = self.c2(x1) - x3 = self.c3(x2) # [4, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x4 = self.c4(x1) - x5 = self.b1(x4) # [4, 28, 28] + x5 = self.b1(x4) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x6 = self.c5(x1) - x7 = self.b2(x6) # [4, 28, 28] + x7 = self.b2(x6) # [4, 28, 28] # Residual Connection (Add operation) # Now x3 and x5 have the same shape: B x 4 x 28 x 28 - x_out = x3 + x5 - x7 # Assuming you intended to add x3 and x5/x4 + x_out = x3 + x5 - x7 # Assuming you intended to add x3 and x5/x4 return x_out @@ -381,7 +381,7 @@ def forward(self, x): # Main path out = self.block_conv1(x) out = self.block_bn1(out) - out = self.block_bn3(out) # Second BN to match original code + out = self.block_bn3(out) # Second BN to match original code out = self.relu(out) out = self.block_conv2(out) @@ -478,7 +478,7 @@ def __init__(self, in_channels=1, out_classes=1): nn.BatchNorm2d(c[1]), nn.ReLU(inplace=True) ) - self.pool1 = nn.MaxPool2d(2) # Downsample 1 + self.pool1 = nn.MaxPool2d(2) # Downsample 1 # --- B. BOTTLENECK --- # 2. BOTTLENECK: Conv -> 16 canaux @@ -516,7 +516,7 @@ def __init__(self, in_channels=1, out_classes=1): def forward(self, x): # 1. ENCODER x1 = self.enc1(x) - p1 = self.pool1(x1) # Skip x1 + p1 = self.pool1(x1) # Skip x1 # 2. BOTTLENECK bottleneck = self.bottleneck(p1) @@ -613,7 +613,7 @@ def __init__(self, *args, **kwargs): ) def forward(self, input): - input = torch.cat([input,]*3, dim=1) # Add channels dim + input = torch.cat([input,]*3, dim=1) # Add channels dim return self.model(input) @@ -880,7 +880,7 @@ def _init_generator_sequential(self, z_dim, img_channels, features_g): # Final Conv: N x 64 x 32 x 32 -> N x 3 x 64 x 64 nn.ConvTranspose2d(features_g, img_channels, kernel_size=4, stride=2, padding=1), - nn.Tanh() # Output range [-1, 1] + nn.Tanh() # Output range [-1, 1] ) def _init_discriminator_sequential(self, img_channels, features_d): @@ -888,7 +888,7 @@ def _init_discriminator_sequential(self, img_channels, features_d): return nn.Sequential( # Input: N x C x 64 x 64 nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, - padding=1), # Output: N x 64 x 32 x 32 + padding=1), # Output: N x 64 x 32 x 32 nn.LeakyReLU(0.2, inplace=True), # Block 2: N x 64 x 32 x 32 -> N x 128 x 16 x 16 @@ -960,7 +960,7 @@ def __init__(self, image_size=784, h_dim=200, z_dim=20): nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, image_size), - nn.Sigmoid() # Sigmoid to output pixel values in the range [0, 1] + nn.Sigmoid() # Sigmoid to output pixel values in the range [0, 1] ) def reparameterize(self, mu, log_var): @@ -1041,7 +1041,7 @@ def forward(self, x): x = self.conv2(x) # Flatten layer - x = x.view(x.shape[0], -1) # Flatten the tensor + x = x.view(x.shape[0], -1) # Flatten the tensor # Linear layers and ReLU activation function h_relu = self.linear1(x).clamp(min=0) @@ -1150,7 +1150,7 @@ def double_conv(in_c, out_c): # ------------------ DECODER (Upsampling Path) ------------------ # 4. Up 4 self.up4_up = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=1) - self.up4_conv = double_conv(192, 64) # Input channels: 64 + 64 = 128 + self.up4_conv = double_conv(192, 64) # Input channels: 64 + 64 = 128 # ------------------ OUTPUT Layer ------------------ # 1x1 convolution to map the final feature channels (64) to the number of classes @@ -1165,7 +1165,7 @@ def forward(self, x): # ------------------ DECODER (Concatenate and Convolve) ------------------ # Up 4 - x = self.up4_up(x2) # B3 + x = self.up4_up(x2) # B3 x = self._align_and_concat(x, x1) x = self.up4_conv(x) @@ -1192,7 +1192,7 @@ def _align_and_concat(self, upsampled, skip): upsampled, size=skip.shape[-2:], mode='bilinear', - align_corners=False # Set to False for compatibility and best practice + align_corners=False # Set to False for compatibility and best practice ) # Concatenate along the channel dimension (dim=1) @@ -1246,19 +1246,19 @@ def double_conv(in_c, out_c): # ------------------ DECODER (Upsampling Path) ------------------ # 1. Up 1 (Upsample + Conv + Skip Connection) self.up1_up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2) - self.up1_conv = double_conv(1024, 512) # Input channels: 512 (from up) + 512 (from skip) = 1024 + self.up1_conv = double_conv(1024, 512) # Input channels: 512 (from up) + 512 (from skip) = 1024 # 2. Up 2 self.up2_up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) - self.up2_conv = double_conv(512, 256) # Input channels: 256 + 256 = 512 + self.up2_conv = double_conv(512, 256) # Input channels: 256 + 256 = 512 # 3. Up 3 self.up3_up = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) - self.up3_conv = double_conv(256, 128) # Input channels: 128 + 128 = 256 + self.up3_conv = double_conv(256, 128) # Input channels: 128 + 128 = 256 # 4. Up 4 self.up4_up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) - self.up4_conv = double_conv(128, 64) # Input channels: 64 + 64 = 128 + self.up4_conv = double_conv(128, 64) # Input channels: 64 + 64 = 128 # ------------------ OUTPUT Layer ------------------ # 1x1 convolution to map the final feature channels (64) to the number of classes @@ -1278,11 +1278,11 @@ def forward(self, x): x4 = self.down3_conv(x4) x5 = self.down4_pool(x4) - x5 = self.down4_conv(x5) # This is the bottleneck feature map (lowest resolution) + x5 = self.down4_conv(x5) # This is the bottleneck feature map (lowest resolution) # ------------------ DECODER (Concatenate and Convolve) ------------------ # Up 1 - x = self.up1_up(x5) # Upsample x5 + x = self.up1_up(x5) # Upsample x5 x = self._align_and_concat(x, x4) x = self.up1_conv(x) @@ -1324,7 +1324,7 @@ def _align_and_concat(self, upsampled, skip): upsampled, size=skip.shape[-2:], mode='bilinear', - align_corners=False # Set to False for compatibility and best practice + align_corners=False # Set to False for compatibility and best practice ) # Concatenate along the channel dimension (dim=1) @@ -1346,8 +1346,8 @@ def __init__(self, n_channels=3, n_classes=1, filter_list=[64, 128, 256, 512, 10 self.input_shape = (1, n_channels, 256, 256) self.n_channels = n_channels self.n_classes = n_classes - self.filters = filter_list # [F1, F2, F3, F4, F5] - self.F_cat = self.filters[0] * 5 # Total channels in feature concatenation (e.g., 64 * 5 = 320) + self.filters = filter_list # [F1, F2, F3, F4, F5] + self.F_cat = self.filters[0] * 5 # Total channels in feature concatenation (e.g., 64 * 5 = 320) # ------------------- Internal Building Blocks ------------------- @@ -1537,7 +1537,7 @@ def forward(self, x): # Concatenate and convolve d1 = torch.cat((h1_d1, h2_d1, h3_d1, h4_d1, h5_d1), dim=1) - d1 = self.up_conv1(d1) # Output D1 (64 channels, Size H/1) + d1 = self.up_conv1(d1) # Output D1 (64 channels, Size H/1) # ------------------- OUTPUT ------------------- logits = self.outc(d1) @@ -1569,7 +1569,7 @@ def double_conv_3d(in_c, out_c): # --- ENCODER (Contracting Path) --- # Initial convolution and first block - self.inc = double_conv_3d(input_channels, base_channels) # C -> 32 + self.inc = double_conv_3d(input_channels, base_channels) # C -> 32 # Down 1 self.down1_pool = nn.MaxPool3d(kernel_size=2, stride=2) @@ -1606,39 +1606,39 @@ def forward(self, x): # x shape: (B, C, D, H, W) # --- ENCODER --- - x1 = self.inc(x) # B x 32 x D x H x W (Skip 1) + x1 = self.inc(x) # B x 32 x D x H x W (Skip 1) x2 = self.down1_pool(x1) - x2 = self.down1_conv(x2) # B x 64 x D/2 x H/2 x W/2 (Skip 2) + x2 = self.down1_conv(x2) # B x 64 x D/2 x H/2 x W/2 (Skip 2) x3 = self.down2_pool(x2) - x3 = self.down2_conv(x3) # B x 128 x D/4 x H/4 x W/4 (Skip 3) + x3 = self.down2_conv(x3) # B x 128 x D/4 x H/4 x W/4 (Skip 3) x4 = self.down3_pool(x3) - x4 = self.down3_conv(x4) # B x 256 x D/8 x H/8 x W/8 (Bottleneck) + x4 = self.down3_conv(x4) # B x 256 x D/8 x H/8 x W/8 (Bottleneck) # --- DECODER --- # Up 3 - up3 = self.up3_upsample(x4) # B x 128 x D/4 x H/4 x W/4 (Upsampled) + up3 = self.up3_upsample(x4) # B x 128 x D/4 x H/4 x W/4 (Upsampled) # Skip connection: Concatenate with x3 (128 channels) cat3 = torch.cat([x3, up3], dim=1) # B x 256 x D/4 x H/4 x W/4 - x = self.up3_conv(cat3) # B x 128 x D/4 x H/4 x W/4 + x = self.up3_conv(cat3) # B x 128 x D/4 x H/4 x W/4 # Up 2 - up2 = self.up2_upsample(x) # B x 64 x D/2 x H/2 x W/2 + up2 = self.up2_upsample(x) # B x 64 x D/2 x H/2 x W/2 # Skip connection: Concatenate with x2 (64 channels) cat2 = torch.cat([x2, up2], dim=1) # B x 128 x D/2 x H/2 x W/2 - x = self.up2_conv(cat2) # B x 64 x D/2 x H/2 x W/2 + x = self.up2_conv(cat2) # B x 64 x D/2 x H/2 x W/2 # Up 1 - up1 = self.up1_upsample(x) # B x 32 x D x H x W + up1 = self.up1_upsample(x) # B x 32 x D x H x W # Skip connection: Concatenate with x1 (32 channels) cat1 = torch.cat([x1, up1], dim=1) # B x 64 x D x H x W - x = self.up1_conv(cat1) # B x 32 x D x H x W + x = self.up1_conv(cat1) # B x 32 x D x H x W # Final Output - logits = self.out_conv(x) # B x C_out x D x H x W + logits = self.out_conv(x) # B x C_out x D x H x W return logits @@ -1694,7 +1694,7 @@ def __init__(self, S: int = 7, B: int = 2, C: int = 1, image_size: int = 224): self.preprocess = T.Compose([ T.Resize((image_size, image_size)), T.ToTensor(), - T.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet mean/std + T.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet mean/std std=[0.229, 0.224, 0.225]), ]) @@ -1734,12 +1734,12 @@ def decode_preds(self, pred: torch.Tensor, conf_thresh: float = 0.25) -> List[Li grid_y = grid_y.float() for i in range(N): - img_pred = pred[i] # S x S x (B*5 + C) + img_pred = pred[i] # S x S x (B*5 + C) img_boxes = [] # class score per cell (for single class we use sigmoid) if self.C == 1: - cls_score = torch.sigmoid(img_pred[..., -1]) # (S, S) + cls_score = torch.sigmoid(img_pred[..., -1]) # (S, S) else: cls_logits = img_pred[..., self.B * 5:] cls_score_all = F.softmax(cls_logits, dim=-1) @@ -1762,7 +1762,7 @@ def decode_preds(self, pred: torch.Tensor, conf_thresh: float = 0.25) -> List[Li w = tw / S h = th / S - final_conf = conf * cls_score # class-aware confidence + final_conf = conf * cls_score # class-aware confidence mask = final_conf > conf_thresh if mask.any(): @@ -1842,8 +1842,8 @@ class Yolov11(nn.Module): def __init__(self, variant: str = "yolo11n.pt", device=None, img_size: int = 640): super().__init__() try: - from ultralytics import YOLO # type: ignore - except Exception as exc: # pragma: no cover - optional dependency + from ultralytics import YOLO # type: ignore + except Exception as exc: # pragma: no cover - optional dependency raise ImportError( "Ultralytics is required for Yolov11 baseline. Install with: pip install ultralytics" ) from exc @@ -1871,13 +1871,13 @@ def predict(self, x, **kwargs): model = Yolov11() # Predict with the model - results = model.yolo("https://ultralytics.com/images/bus.jpg") # predict on an image + results = model.yolo("https://ultralytics.com/images/bus.jpg") # predict on an image # Access the results for result in results: - xywh = result.boxes.xywh # center-x, center-y, width, height - xywhn = result.boxes.xywhn # normalized - xyxy = result.boxes.xyxy # top-left-x, top-left-y, bottom-right-x, bottom-right-y - xyxyn = result.boxes.xyxyn # normalized - names = [result.names[cls.item()] for cls in result.boxes.cls.int()] # class name of each box - confs = result.boxes.conf # confidence score of each box \ No newline at end of file + xywh = result.boxes.xywh # center-x, center-y, width, height + xywhn = result.boxes.xywhn # normalized + xyxy = result.boxes.xyxy # top-left-x, top-left-y, bottom-right-x, bottom-right-y + xyxyn = result.boxes.xyxyn # normalized + names = [result.names[cls.item()] for cls in result.boxes.cls.int()] # class name of each box + confs = result.boxes.conf # confidence score of each box \ No newline at end of file diff --git a/weightslab/components/__init__.py b/weightslab/components/__init__.py index cbb7d872..a41418b3 100644 --- a/weightslab/components/__init__.py +++ b/weightslab/components/__init__.py @@ -11,18 +11,18 @@ # # Other components # from weightslab.components.tracking import Tracker, TrackingMode -# # from weightslab.components.global_monitoring import GlobalMonitoring # TODO: Fix missing GlobalMonitoring class +# # from weightslab.components.global_monitoring import GlobalMonitoring # TODO: Fix missing GlobalMonitoring class # __all__ = [ -# # Checkpoint management -# 'CheckpointManager', # Manual checkpoint system -# 'ExperimentHashGenerator', +# # Checkpoint management +# 'CheckpointManager', # Manual checkpoint system +# 'ExperimentHashGenerator', -# # Tracking -# 'Tracker', -# 'TrackingMode', +# # Tracking +# 'Tracker', +# 'TrackingMode', -# # Monitoring - commented out until GlobalMonitoring is implemented -# # 'GlobalMonitoring', +# # Monitoring - commented out until GlobalMonitoring is implemented +# # 'GlobalMonitoring', # ] diff --git a/weightslab/components/checkpoint_manager.py b/weightslab/components/checkpoint_manager.py index ca5a8e4a..6dfca3c8 100644 --- a/weightslab/components/checkpoint_manager.py +++ b/weightslab/components/checkpoint_manager.py @@ -6,12 +6,12 @@ Directory Structure: root_log_dir/ - data/ # Data-related files (global) - logs/ # Training logs (global) + data/ # Data-related files (global) + logs/ # Training logs (global) checkpoints/ - manifest.yaml # Tracks all hashes with timestamps + manifest.yaml # Tracks all hashes with timestamps models/ - {hash}/ # 24-byte hash: HP_MODEL_DATA + {hash}/ # 24-byte hash: HP_MODEL_DATA {hash}_step_000100.pt {hash}_architecture.pkl HP/ @@ -120,12 +120,12 @@ def __init__(self, root_log_dir: str = 'root_experiment', load_model: bool = Tru self.hash_generator = ExperimentHashGenerator() self.current_exp_hash: Optional[str] = None self.previous_exp_hash: Optional[str] = None - self.hash_by_module: list = [None, None, None] # HP, MODEL, DATA + self.hash_by_module: list = [None, None, None] # HP, MODEL, DATA # Step tracking self._step_counter = None self._model_init_step = 0 - self._last_time_loaded: Optional[float] = time.time() # Track last load time for model hash uniqueness + self._last_time_loaded: Optional[float] = time.time() # Track last load time for model hash uniqueness # First time only self.first_time = True @@ -151,9 +151,9 @@ def __init__(self, root_log_dir: str = 'root_experiment', load_model: bool = Tru def __repr__(self) -> str: return ( f"CheckpointManager(\n" - f" root_log_dir={self.root_log_dir}\n" - f" current_exp_hash={self.current_exp_hash}\n" - f" step_counter={self._step_counter}\n" + f" root_log_dir={self.root_log_dir}\n" + f" current_exp_hash={self.current_exp_hash}\n" + f" step_counter={self._step_counter}\n" f")" ) @@ -805,7 +805,7 @@ def _save_changes( # Get checkpoint manager hp manager_hp = config.get('checkpoint_manager', {}) if config else {} enable_checkpoints = manager_hp.get('enable_checkpoints', True) - dump_model_architecture = manager_hp.get('dump_model_architecture', False) # Set to false by default + dump_model_architecture = manager_hp.get('dump_model_architecture', False) # Set to false by default dump_model_state = manager_hp.get('dump_model_state', True) dump_optimizer_state = manager_hp.get('dump_optimizer_state', True) dump_data_state = manager_hp.get('dump_data_state', True) @@ -925,7 +925,7 @@ def save_model_checkpoint( 'model_state_dict': model.state_dict(), 'timestamp': datetime.now().isoformat(), 'exp_hash': self.current_exp_hash, - 'rng_state': capture_rng_state(), # Capture RNG state for reproducible training + 'rng_state': capture_rng_state(), # Capture RNG state for reproducible training } # Capture dataloader iteration state(s) for reproducible resume (support multiple loaders) @@ -976,7 +976,7 @@ def save_model_checkpoint( # If model architecture doesn't exist in this hash directory, save a reference to where it is if self.config.get('checkpoint_manager', {}).get('dump_model_architecture', False): - self._save_architecture_reference_if_needed() # TODO (GP): Disable for now because it adds complexity for big models, and we want to ensure architecture is always saved with weights for simplicity + self._save_architecture_reference_if_needed() # TODO (GP): Disable for now because it adds complexity for big models, and we want to ensure architecture is always saved with weights for simplicity # Persist logger queues alongside weight checkpoints try: @@ -1379,7 +1379,7 @@ def _load_architecture_with_retry(self, arch_file: Path, max_retries: int = 5, b break sleep_time = base_delay * (2 ** (attempt - 1)) logger.warning( - f" [WARN] Architecture load locked (attempt {attempt}/{max_retries}). " + f" [WARN] Architecture load locked (attempt {attempt}/{max_retries}). " f"Retrying in {sleep_time:.2f}s..." ) time.sleep(sleep_time) @@ -1388,7 +1388,7 @@ def _load_architecture_with_retry(self, arch_file: Path, max_retries: int = 5, b if isinstance(e, EOFError): sleep_time = base_delay * (2 ** (attempt - 1)) logger.warning( - f" [WARN] Architecture load incomplete (attempt {attempt}/{max_retries}). " + f" [WARN] Architecture load incomplete (attempt {attempt}/{max_retries}). " f"Retrying in {sleep_time:.2f}s..." ) time.sleep(sleep_time) @@ -1635,8 +1635,8 @@ def load_checkpoint(self, # Logger logger.info(f"Loading checkpoint {exp_hash[:16]}...") - logger.info(f" Target: HP={target_hp_hash} MODEL={target_model_hash} DATA={target_data_hash}") - logger.info(f" Current: HP={current_hp_hash} MODEL={current_model_hash} DATA={current_data_hash}") + logger.info(f" Target: HP={target_hp_hash} MODEL={target_model_hash} DATA={target_data_hash}") + logger.info(f" Current: HP={current_hp_hash} MODEL={current_model_hash} DATA={current_data_hash}") # Load model architecture if different, or load only RNG state for reproducibility if model hash is unchanged model_rng_loaded = False @@ -1651,7 +1651,7 @@ def load_checkpoint(self, with open(arch_ref_file, 'r') as f: ref_data = json.load(f) actual_arch_hash = ref_data.get('architecture_hash', exp_hash[8:-8]) - logger.debug(f" Architecture reference found: pointing to hash {actual_arch_hash}") + logger.debug(f" Architecture reference found: pointing to hash {actual_arch_hash}") except Exception as e: logger.warning(f"Failed to load architecture reference: {e}") @@ -1669,12 +1669,12 @@ def load_checkpoint(self, result['model'].guard_testing_context = guard_testing_context result['loaded_components'].add('model') - logger.info(f" [OK] Loaded model architecture from hash {actual_arch_hash[:16]}") - self._last_time_loaded = time.time() # Update last loaded time after successful load + logger.info(f" [OK] Loaded model architecture from hash {actual_arch_hash[:16]}") + self._last_time_loaded = time.time() # Update last loaded time after successful load except Exception as e: - logger.error(f" [ERROR] Failed to load model architecture: {e}") + logger.error(f" [ERROR] Failed to load model architecture: {e}") else: - logger.warning(f" [WARNING] Model architecture file not found: {actual_arch_file}") + logger.warning(f" [WARNING] Model architecture file not found: {actual_arch_file}") elif load_model and (target_model_hash == current_model_hash and not force): # Try to load only the RNG state from the latest model checkpoint for reproducibility @@ -1690,14 +1690,14 @@ def load_checkpoint(self, if rng_state: result['rng_state'] = rng_state model_rng_loaded = True - logger.info(f" [OK] Loaded model RNG state for reproducibility (model unchanged)") - self._last_time_loaded = time.time() # Update last loaded time after successful load + logger.info(f" [OK] Loaded model RNG state for reproducibility (model unchanged)") + self._last_time_loaded = time.time() # Update last loaded time after successful load except Exception as e: - logger.debug(f" [WARNING] Could not load model RNG state: {e}") + logger.debug(f" [WARNING] Could not load model RNG state: {e}") if not model_rng_loaded: - logger.info(f" [-] Model architecture unchanged, using current model") + logger.info(f" [-] Model architecture unchanged, using current model") else: - logger.info(f" [-] Model architecture unchanged, using current model") + logger.info(f" [-] Model architecture unchanged, using current model") # Load model weights (always if requested) if load_weights: @@ -1712,16 +1712,16 @@ def load_checkpoint(self, checkpoint_path = model_dir / manifest_weight_checkpoint if checkpoint_path.exists(): checkpoint_file_to_load = checkpoint_path - logger.debug(f" Using weight checkpoint from manifest: {manifest_weight_checkpoint}") + logger.debug(f" Using weight checkpoint from manifest: {manifest_weight_checkpoint}") # Fallback: scan for weight files (old behavior for backward compatibility) if checkpoint_file_to_load is None: checkpoint_file_to_load = self._select_weight_checkpoint_file(exp_hash, target_step=target_step) if checkpoint_file_to_load is not None: if target_step is None: - logger.debug(f" Using latest weight checkpoint from directory scan: {checkpoint_file_to_load.name}") + logger.debug(f" Using latest weight checkpoint from directory scan: {checkpoint_file_to_load.name}") else: - logger.debug(f" Using closest weight checkpoint for target step {target_step}: {checkpoint_file_to_load.name}") + logger.debug(f" Using closest weight checkpoint for target step {target_step}: {checkpoint_file_to_load.name}") if checkpoint_file_to_load: try: @@ -1733,9 +1733,9 @@ def load_checkpoint(self, checkpoint_rng_state = result['weights'].get('rng_state') if checkpoint_rng_state: result['rng_state'] = checkpoint_rng_state - logger.info(f" [OK] Loaded weights from step {step} with RNG state") + logger.info(f" [OK] Loaded weights from step {step} with RNG state") else: - logger.info(f" [OK] Loaded weights from step {step}") + logger.info(f" [OK] Loaded weights from step {step}") # Extract dataloader iteration state if available dataloader_iter_state = result['weights'].get('dataloader_iteration_state') @@ -1749,12 +1749,12 @@ def load_checkpoint(self, iter_state_map = {'default': dataloader_iter_state} result['dataloader_iteration_state'] = iter_state_map - logger.debug(f" [OK] Found dataloader iteration state(s): {iter_state_map}") + logger.debug(f" [OK] Found dataloader iteration state(s): {iter_state_map}") except Exception as e: - logger.error(f" [ERROR] Failed to load weights: {e}") + logger.error(f" [ERROR] Failed to load weights: {e}") self._last_time_loaded = time.time() else: - logger.warning(f" [WARNING] No weight files found for {exp_hash[8:-8]}") + logger.warning(f" [WARNING] No weight files found for {exp_hash[8:-8]}") # Load config if different if load_config and (target_hp_hash != current_hp_hash or force): @@ -1767,13 +1767,13 @@ def load_checkpoint(self, config_data = yaml.safe_load(f) result['config'] = config_data.get('hyperparameters', config_data) result['loaded_components'].add('config') - logger.info(f" [OK] Loaded config (hash changed)") + logger.info(f" [OK] Loaded config (hash changed)") except Exception as e: - logger.error(f" [ERROR] Failed to load config: {e}") + logger.error(f" [ERROR] Failed to load config: {e}") else: - logger.warning(f" [WARNING] Config file not found: {config_file}") + logger.warning(f" [WARNING] Config file not found: {config_file}") else: - logger.info(f" [-] Config unchanged, using current config") + logger.info(f" [-] Config unchanged, using current config") # Load data snapshot if different, or if only RNG state changed (for reproducibility) if load_data: @@ -1798,19 +1798,19 @@ def load_checkpoint(self, result['loaded_components'].add('data') if rng_state: result['rng_state'] = rng_state - logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows) with RNG state") + logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows) with RNG state") else: - logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows)") + logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows)") elif load_rng_only and rng_state: # Only RNG state is needed for reproducibility result['rng_state'] = rng_state - logger.info(f" [OK] Loaded RNG state for reproducibility (data unchanged)") + logger.info(f" [OK] Loaded RNG state for reproducibility (data unchanged)") else: - logger.info(f" [-] Data state unchanged, using current data") + logger.info(f" [-] Data state unchanged, using current data") except Exception as e: - logger.error(f" [ERROR] Failed to load data snapshot: {e}") + logger.error(f" [ERROR] Failed to load data snapshot: {e}") else: - logger.warning(f" [WARNING] Data snapshot file not found: {json_file}") + logger.warning(f" [WARNING] Data snapshot file not found: {json_file}") logger.info(f"Loaded components: {result['loaded_components']}") return result @@ -1863,8 +1863,8 @@ def load_state( # Apply model (architecture + weights) if 'model' in checkpoint_data['loaded_components']: try: - model = checkpoint_data['model'] # Include model architecture, weights, and optimizer state at this level - # model.update_optimizer() # Update optimizer with new model parameters if needed + model = checkpoint_data['model'] # Include model architecture, weights, and optimizer state at this level + # model.update_optimizer() # Update optimizer with new model parameters if needed # Remove existing locks if hasattr(model, 'guard_testing_context'): @@ -1876,8 +1876,8 @@ def load_state( ledgers.register_model(model) # Set Model Training Guard - guard_training_context.model = model # Train - guard_testing_context.model = model # Eval + guard_training_context.model = model # Train + guard_testing_context.model = model # Eval loaded_step = None if checkpoint_data.get('weights') is not None: @@ -1923,8 +1923,8 @@ def load_state( logger.warning(f"Could not load optimizer state: {e}") # Set Model Training Guard - guard_training_context.model = model # Train - guard_testing_context.model = model # Eval + guard_training_context.model = model # Train + guard_testing_context.model = model # Eval except Exception: if 'model' not in checkpoint_data['loaded_components']: @@ -1956,14 +1956,14 @@ def load_state( setattr(model, 'current_step', step) except Exception: pass - # model.update_optimizer() # Update optimizer with new model parameters if needed + # model.update_optimizer() # Update optimizer with new model parameters if needed logger.info(f"[OK] Applied weights to reloaded model (step {step})") self._model_init_step = step logger.info("Successfully recovered by reloading full checkpoint with architecture and weights") # Set Model Training Guard - guard_training_context.model = model # Train - guard_testing_context.model = model # Eval + guard_training_context.model = model # Train + guard_testing_context.model = model # Eval self.error_loading_checkpoint.remove('weights') if 'weights' in self.error_loading_checkpoint else None except Exception as e: @@ -1979,7 +1979,7 @@ def load_state( self.error_loading_checkpoint.remove('config') if 'config' in self.error_loading_checkpoint else None except Exception as e: logger.error(f"[ERROR] Failed to apply config: {e}") - self.error_loading_checkpoint.append('config') if 'config' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if config application failed + self.error_loading_checkpoint.append('config') if 'config' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if config application failed # Apply data (merge snapshot columns into current dataframe) if 'data' in checkpoint_data['loaded_components']: @@ -2001,7 +2001,7 @@ def load_state( self.error_loading_checkpoint.remove('data') if 'data' in self.error_loading_checkpoint else None except Exception as e: logger.error(f"[ERROR] Failed to apply data: {e}") - self.error_loading_checkpoint.append('data') if 'data' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if data application failed + self.error_loading_checkpoint.append('data') if 'data' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if data application failed # Restore RNG state if provided and not already restored if checkpoint_data.get('rng_state'): @@ -2013,13 +2013,13 @@ def load_state( # # We should re-enable this in the future once we have a more robust solution for managing dataloader iteration state across different types of dataloaders and shuffling state or not. # # Reset dataloaders iterators to ensure reproducibility # for loader_name in ledgers.get_dataloaders(): - # loader = ledgers.get_dataloader(loader_name) + # loader = ledgers.get_dataloader(loader_name) - # if loader is not None: - # # Resume loader state - # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): - # loader.reset_iterator() - # logger.debug(f"Reset iterator for dataloader: {loader}") + # if loader is not None: + # # Resume loader state + # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): + # loader.reset_iterator() + # logger.debug(f"Reset iterator for dataloader: {loader}") # # Restore RNG state again after resetting dataloaders # restore_rng_state(checkpoint_data['rng_state']) @@ -2034,46 +2034,46 @@ def load_state( # # We should re-enable this in the future once we have a more robust solution for managing dataloader iteration state across different types of dataloaders and shuffling state or not. # # Restore dataloader iteration state if provided # if checkpoint_data.get('dataloader_iteration_state'): - # try: - # iter_state_raw = checkpoint_data['dataloader_iteration_state'] - - # # Normalize to mapping loader_name -> state for backward compatibility - # if isinstance(iter_state_raw, dict) and 'samples_yielded' in iter_state_raw: - # state_map = {'default': iter_state_raw} - # elif isinstance(iter_state_raw, dict): - # state_map = iter_state_raw - # else: - # state_map = {'default': iter_state_raw} - - # restored_any = False - # for loader_name in ledgers.get_dataloaders(): - # loader = ledgers.get_dataloader(loader_name) - # if loader is None or not hasattr(loader, 'restore_iteration_state'): - # continue - - # state_for_loader = state_map.get(loader_name) or state_map.get('default') - # if state_for_loader: - # try: - # loader.restore_iteration_state(state_for_loader) - # # Resume loader state - # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): - # loader.reset_iterator() - # logger.debug(f"Reset iterator for dataloader: {loader}") - # logger.info(f"[OK] Restored dataloader iteration state for {loader_name}: {state_for_loader}") - # restored_any = True - # except Exception as inner_e: - # logger.warning(f"[WARNING] Failed to restore iteration state for {loader_name}: {inner_e}") - - # if not restored_any: - # logger.warning("No dataloader iteration state could be applied to registered loaders") - # self.error_loading_checkpoint.remove('dataloader_iteration') if 'dataloader_iteration' in self.error_loading_checkpoint else None - # except Exception as e: - # logger.error(f"[ERROR] Failed to restore dataloader iteration state: {e}") - # self.error_loading_checkpoint.append('dataloader_iteration') if 'dataloader_iteration' not in self.error_loading_checkpoint else None + # try: + # iter_state_raw = checkpoint_data['dataloader_iteration_state'] + + # # Normalize to mapping loader_name -> state for backward compatibility + # if isinstance(iter_state_raw, dict) and 'samples_yielded' in iter_state_raw: + # state_map = {'default': iter_state_raw} + # elif isinstance(iter_state_raw, dict): + # state_map = iter_state_raw + # else: + # state_map = {'default': iter_state_raw} + + # restored_any = False + # for loader_name in ledgers.get_dataloaders(): + # loader = ledgers.get_dataloader(loader_name) + # if loader is None or not hasattr(loader, 'restore_iteration_state'): + # continue + + # state_for_loader = state_map.get(loader_name) or state_map.get('default') + # if state_for_loader: + # try: + # loader.restore_iteration_state(state_for_loader) + # # Resume loader state + # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): + # loader.reset_iterator() + # logger.debug(f"Reset iterator for dataloader: {loader}") + # logger.info(f"[OK] Restored dataloader iteration state for {loader_name}: {state_for_loader}") + # restored_any = True + # except Exception as inner_e: + # logger.warning(f"[WARNING] Failed to restore iteration state for {loader_name}: {inner_e}") + + # if not restored_any: + # logger.warning("No dataloader iteration state could be applied to registered loaders") + # self.error_loading_checkpoint.remove('dataloader_iteration') if 'dataloader_iteration' in self.error_loading_checkpoint else None + # except Exception as e: + # logger.error(f"[ERROR] Failed to restore dataloader iteration state: {e}") + # self.error_loading_checkpoint.append('dataloader_iteration') if 'dataloader_iteration' not in self.error_loading_checkpoint else None # Restore logger snapshot for this experiment if available logger_len = self.get_logger_length() - if load_logger and logger_len == 0: # Load logger if requested and logger is currently empty (e.g. on fresh start) + if load_logger and logger_len == 0: # Load logger if requested and logger is currently empty (e.g. on fresh start) try: self.load_logger_snapshot() self.error_loading_checkpoint.remove('logger') if 'logger' in self.error_loading_checkpoint else None diff --git a/weightslab/components/experiment_hash.py b/weightslab/components/experiment_hash.py index 5f7ac621..48d8ef0f 100644 --- a/weightslab/components/experiment_hash.py +++ b/weightslab/components/experiment_hash.py @@ -72,15 +72,15 @@ def generate_hash( if config != -1: hp_hash = self._hash_config(config) if config is not None else "00000000" else: - hp_hash = self._last_hp_hash or "00000000" # If config is -1, keep previous HP hash to avoid marking as changed + hp_hash = self._last_hp_hash or "00000000" # If config is -1, keep previous HP hash to avoid marking as changed if model != -1: model_hash = self._hash_model(model, model_init_step=model_init_step, _last_time_loaded=_last_time_loaded) if model is not None else "00000000" else: - model_hash = self._last_model_hash or "00000000" # If model is -1, keep previous model hash to avoid marking as changed + model_hash = self._last_model_hash or "00000000" # If model is -1, keep previous model hash to avoid marking as changed if data_state != -1: data_hash = self._hash_data_state(data_state) if data_state is not None else "00000000" else: - data_hash = self._last_data_hash or "00000000" # If data_state is -1, keep previous data hash to avoid marking as changed + data_hash = self._last_data_hash or "00000000" # If data_state is -1, keep previous data hash to avoid marking as changed # Combine into 24-byte hash: HP (8) + MODEL (8) + DATA (8) final_hash = f"{hp_hash}{model_hash}{data_hash}" @@ -92,9 +92,9 @@ def generate_hash( self._last_data_hash = data_hash logger.info(f"Generated experiment hash: {final_hash}- (HP: {hp_hash}, Model: {model_hash}, Data: {data_hash})") - logger.debug(f" HP hash: {hp_hash}") - logger.debug(f" Model hash: {model_hash}") - logger.debug(f" Data hash: {data_hash}") + logger.debug(f" HP hash: {hp_hash}") + logger.debug(f" Model hash: {model_hash}") + logger.debug(f" Data hash: {data_hash}") return final_hash @@ -187,7 +187,7 @@ def _hash_model(self, model: th.nn.Module, model_init_step: int = 0, _last_time_ arch_info = [] # Model class name - arch_info.append(f"previously_loaded:{_last_time_loaded}") # Add a unique timestamp to ensure different hash for each load, even if architecture is the same + arch_info.append(f"previously_loaded:{_last_time_loaded}") # Add a unique timestamp to ensure different hash for each load, even if architecture is the same arch_info.append(f"class:{model.__class__.__name__}") arch_info.append(f"init_step:{int(model_init_step)}") @@ -197,7 +197,7 @@ def _hash_model(self, model: th.nn.Module, model_init_step: int = 0, _last_time_ # Remove these trackers from hash if 'train_dataset_tracker' in name or 'eval_dataset_tracker' in name: continue - if name: # Skip root module + if name: # Skip root module module_info = f"{name}:{module.__class__.__name__}" # Add key parameters for common layer types @@ -236,7 +236,7 @@ def _hash_config(self, config: Dict[str, Any]) -> str: config_cp.pop('root_log_dir', None) config_cp.pop('is_training', None) config_cp.pop('pause_at_step', None) - # config_cp.pop('auditor_mode', None) # Audit should be another state + # config_cp.pop('auditor_mode', None) # Audit should be another state if 'auditor_mode' not in config_cp: config_cp['auditor_mode'] = False diff --git a/weightslab/components/global_monitoring.py b/weightslab/components/global_monitoring.py index 17e9e029..20099775 100644 --- a/weightslab/components/global_monitoring.py +++ b/weightslab/components/global_monitoring.py @@ -146,7 +146,7 @@ def resume(self, force: bool = False) -> bool: self.checkpoint_manager = get_checkpoint_manager() if self.checkpoint_manager != None: self.checkpoint_manager.update_experiment_hash(first_time=True) - self.checkpoint_manager.save_pending_changes() # Write pending change to disk + self.checkpoint_manager.save_pending_changes() # Write pending change to disk hash_by_module = self.checkpoint_manager.hash_by_module else: logger.warning('Cannot access checkpoint manager on resume.') @@ -291,7 +291,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any, f: bool = Fals logger.debug(f"Suppressing exception: {exc_value} in GuardContext.__exit__:") traceback.print_exc() if os.getenv("WL_DEBUG", "0") == "1" else None self.architecture_guard.__exit__(exc_type, exc_value, traceback) - return True # suppress the exception + return True # suppress the exception self.architecture_guard.__exit__(exc_type, exc_value, traceback) @@ -357,11 +357,11 @@ def _pause_hp_sync_loop(poll_interval: float = 3): # # Drive controller from ledger when ledger explicitly sets the flag # controller_running = not controller_paused # if isinstance(hp_is_training, bool): - # if controller_paused and hp_is_training: - # resumed = pause_controller.resume() - # firstresume = False if resumed else True - # elif controller_running and not hp_is_training: - # pause_controller.pause() + # if controller_paused and hp_is_training: + # resumed = pause_controller.resume() + # firstresume = False if resumed else True + # elif controller_running and not hp_is_training: + # pause_controller.pause() # Re-evaluate controller state after potential changes controller_paused = pause_controller.is_paused() @@ -386,5 +386,5 @@ def start_hp_sync_thread_event(): # Start sync thread once at module import if _pause_sync_thread_started: - _pause_sync_thread_started = False # already activated + _pause_sync_thread_started = False # already activated start_hp_sync_thread_event() diff --git a/weightslab/components/parallel_primitives.py b/weightslab/components/parallel_primitives.py index 6d1feeae..63a98ca2 100644 --- a/weightslab/components/parallel_primitives.py +++ b/weightslab/components/parallel_primitives.py @@ -41,7 +41,7 @@ import logging import os -from weightslab.utils import ddp_info # single source of truth for (rank, world) +from weightslab.utils import ddp_info # single source of truth for (rank, world) logger = logging.getLogger(__name__) @@ -70,7 +70,7 @@ def ddp_log(msg): print(f"[ddp r{r}/{w}] {msg}", flush=True) -_collectives = 0 # collectives since the last reset (i.e. during this step) +_collectives = 0 # collectives since the last reset (i.e. during this step) def reset_collectives(): @@ -149,7 +149,7 @@ def _gather(obj, what): # --------------------------------------------------------------------------- # State registry — reconciled in ONE bundled broadcast at the anchor. # --------------------------------------------------------------------------- -_REGISTRY = [] # (name, snapshot, apply) +_REGISTRY = [] # (name, snapshot, apply) def register_consistent_state(name, snapshot, apply): @@ -182,7 +182,7 @@ def reconcile_all(): payload = [snapshot] else: payload = [None] - bundle = _broadcast(payload, what="reconcile_all") # collective ALWAYS reached + bundle = _broadcast(payload, what="reconcile_all") # collective ALWAYS reached if r != 0 and bundle: for name, _snap, apply in _REGISTRY: if name in bundle: @@ -222,7 +222,7 @@ def clear_registry(): # payload bounded by the per-step change set, not the dataset size. Merge MUST # be idempotent (a delta may re-flush on retry / respawn). # --------------------------------------------------------------------------- -_OUTBOXES = [] # (name, local_dump, merge) +_OUTBOXES = [] # (name, local_dump, merge) def register_outbox(name, local_dump, merge): @@ -252,7 +252,7 @@ def flush_outbox(): except Exception as exc: payload[name] = None logger.debug("[flush_outbox] dump '%s' failed: %s", name, exc) - bucket = _gather(payload, what="flush_outbox") # collective ALWAYS reached + bucket = _gather(payload, what="flush_outbox") # collective ALWAYS reached if r != 0 or not bucket: return for name, _dump, merge in _OUTBOXES: @@ -278,10 +278,10 @@ def _ensure_core_ddp_registered(): `guard_training_context.__enter__` on first entry per process — by that point the hparam store + dataloaders + pause_controller are all wired up. - - "hparams" — rank 0's hyperparams dict; children diff-apply each leaf. + - "hparams" — rank 0's hyperparams dict; children diff-apply each leaf. - "deny-list" — {origin: discarded sample-id set} across all known loaders; children mirror via the WL discard_samples API. - - "paused" — rank 0's pause_controller.is_paused(); rides in the same + - "paused" — rank 0's pause_controller.is_paused(); rides in the same bundle so sync_step's spin uses ONE broadcast per iter. """ global _CORE_REGISTERED @@ -290,10 +290,10 @@ def _ensure_core_ddp_registered(): # Imports are deferred — this module must stay import-light + cycle-free. from weightslab.components.global_monitoring import pause_controller from weightslab.components.parallel_state import ( - rank0_hparams, apply_hparams, # CONFIG plane ↓ - rank0_df_down_state, apply_df_down_state, # DATAFRAME plane ↓ - local_df_writes, merge_df_writes, # DATAFRAME plane ↑ - local_signal_triples, merge_signal_triples_into_logger, # LOGGER plane ↑ + rank0_hparams, apply_hparams, # CONFIG plane ↓ + rank0_df_down_state, apply_df_down_state, # DATAFRAME plane ↓ + local_df_writes, merge_df_writes, # DATAFRAME plane ↑ + local_signal_triples, merge_signal_triples_into_logger, # LOGGER plane ↑ ) # DOWN reconcile — CONFIG + CONTROL + DATAFRAME (DOWN_ONLY cols) — 1 broadcast register_consistent_state("hparams", rank0_hparams, apply_hparams) @@ -322,11 +322,11 @@ def sync_step(spin_wait=0.5): if not _active(): return rank, _ = ddp_info() - reset_collectives() # logs prior step's count (down+up), then resets + reset_collectives() # logs prior step's count (down+up), then resets while True: - bundle = reconcile_all() # DOWN: 1 broadcast, ALL consistent states + bundle = reconcile_all() # DOWN: 1 broadcast, ALL consistent states if not bundle or not bundle.get("paused", False): - return # → step body runs; UP flush happens in __exit__ + return # → step body runs; UP flush happens in __exit__ # Paused: no busy-spin. Rank 0 blocks on the resume Event (wakes on the gRPC # resume); rank-1+ block inside the next reconcile_all broadcast. Cheap only on # gloo (socket-wait); NCCL would spin (NCCL_BLOCKING_WAIT). The bounded timeout diff --git a/weightslab/components/tracking.py b/weightslab/components/tracking.py index fd56aa38..a3e2c98c 100644 --- a/weightslab/components/tracking.py +++ b/weightslab/components/tracking.py @@ -211,13 +211,13 @@ def update(self, tensor: th.Tensor): # Shape is expected to be in the form [batch_size x neuron_count] if len(tensor.shape) > 2: # raise ValueError( - # f"Neuron stats are updated on a per neuron level, hence only " - # f"two dims are expected [batch_size x neuron_count] but " - # f"activation map has shape: {str(tensor.shape)}") + # f"Neuron stats are updated on a per neuron level, hence only " + # f"two dims are expected [batch_size x neuron_count] but " + # f"activation map has shape: {str(tensor.shape)}") tensor = tensor.view(-1, self.number_of_neurons) try: if tensor.shape == th.Size([]): - tensor = tensor[None, None] # Add one dim + tensor = tensor[None, None] # Add one dim bs = tensor.shape[0] self.triggrs_by_neuron += th.sum( tensor, dim=(0, )).view(-1).long() @@ -286,9 +286,9 @@ def get_neuron_age(self, neuron_id: int): return self.updates_by_neuron[neuron_id].item() # def get_neuron_stats(self, neuron_id: int): - # """ Get how often did this neuron trigger on average. """ - # return self.get_neuron_triggers(neuron_id) / \ - # max(self.get_neuron_age(neuron_id), 1) + # """ Get how often did this neuron trigger on average. """ + # return self.get_neuron_triggers(neuron_id) / \ + # max(self.get_neuron_age(neuron_id), 1) def get_neuron_stats(self, neuron_id: int): """ Get how often did this neuron trigger on average. """ @@ -409,9 +409,9 @@ def update(self, tensor: th.Tensor): super().update(tensor) # Update trackers with class and sample ids. The shapes are expected # to be in the following form: - # * tensor: [batch_size x neuron_count] - # * tensor.in_id_batch: [batch_size] - # * tensor.label_batch: [batch_size] + # * tensor: [batch_size x neuron_count] + # * tensor.in_id_batch: [batch_size] + # * tensor.label_batch: [batch_size] if not hasattr(tensor, 'in_id_batch') or \ not hasattr(tensor, 'label_batch'): diff --git a/weightslab/data/array_proxy.py b/weightslab/data/array_proxy.py index cc8b76db..efec4c97 100644 --- a/weightslab/data/array_proxy.py +++ b/weightslab/data/array_proxy.py @@ -138,9 +138,9 @@ class ArrayAccessor: Pandas DataFrame accessor for automatic array loading. Usage: - df.arrays.load('prediction') # Load all prediction arrays - df.arrays.load_sample(sample_id, 'prediction') # Load specific array - df.arrays.set_store(array_store) # Configure array store + df.arrays.load('prediction') # Load all prediction arrays + df.arrays.load_sample(sample_id, 'prediction') # Load specific array + df.arrays.set_store(array_store) # Configure array store """ def __init__(self, pandas_obj): diff --git a/weightslab/data/data_samples_with_ops.py b/weightslab/data/data_samples_with_ops.py index 72f92795..8236fe5f 100644 --- a/weightslab/data/data_samples_with_ops.py +++ b/weightslab/data/data_samples_with_ops.py @@ -59,7 +59,7 @@ def _match_column_patterns(col: str, patterns: list) -> bool: if re.search(pattern, col): return True except re.error: - pass # Invalid regex, skip + pass # Invalid regex, skip return False @@ -109,18 +109,18 @@ class DataSampleTrackingWrapper(Dataset): Examples: Binary classification based on tags: >>> dataset = DataSampleTrackingWrapper( - ... mnist_train, - ... root_log_dir="./logs", - ... use_tags=True, - ... tags_mapping={'huge': 1} # Images tagged 'huge' → label 1, others → 0 + ... mnist_train, + ... root_log_dir="./logs", + ... use_tags=True, + ... tags_mapping={'huge': 1} # Images tagged 'huge' → label 1, others → 0 ... ) Multiclass classification based on tags: >>> dataset = DataSampleTrackingWrapper( - ... mnist_train, - ... root_log_dir="./logs", - ... use_tags=True, - ... tags_mapping={'small': 0, 'medium': 1, 'large': 2} + ... mnist_train, + ... root_log_dir="./logs", + ... use_tags=True, + ... tags_mapping={'small': 0, 'medium': 1, 'large': 2} ... ) """ def __init__( @@ -167,7 +167,7 @@ def __init__( # Setup H5 persistence path self._root_log_dir = Path(root_log_dir) if root_log_dir else self._resolve_root_log_dir() self._h5_path = None - self._h5_pending_uids = set() # Track UIDs with pending H5 saves + self._h5_pending_uids = set() # Track UIDs with pending H5 saves self._stats_store = stats_store self._enable_h5_persistence = enable_h5_persistence @@ -281,7 +281,7 @@ def __init__( self._map_updates_hook_fns = [] self._df_lock = threading.RLock() self.is_training = is_training - self._dataset_split = split # Store for H5 filename (can be train, test, val, validation, eval, etc.) + self._dataset_split = split # Store for H5 filename (can be train, test, val, validation, eval, etc.) # Initialize DataFrame as single source of truth # Start with defaults for all UIDs (single dict build per row to trim overhead) @@ -315,10 +315,10 @@ def __init__( # to get_items() which may load images and run heavy transforms. raw_item = wrapped_dataset.fast_get_label(p_idx) elif hasattr(wrapped_dataset, 'get_items'): - raw_item = wrapped_dataset.get_items(p_idx, include_metadata=preload_metadata, include_labels=preload_labels, include_images=False) # Try to get metadata if supported + raw_item = wrapped_dataset.get_items(p_idx, include_metadata=preload_metadata, include_labels=preload_labels, include_images=False) # Try to get metadata if supported else: # logger.warning(f"Wrapped dataset for split '{split}' does not have get_items method. Falling back to __getitem__, which may cause issues if the dataset is not designed for it. Consider implementing get_items for better performance and compatibility.") - raw_item = wrapped_dataset[p_idx] # By default load everything + raw_item = wrapped_dataset[p_idx] # By default load everything except Exception as e: logger.error(f"Failed to load physical index {p_idx} during initialization: {e}") continue @@ -363,7 +363,7 @@ def __init__( row = SampleStats.DEFAULTS.copy() row.update({ SampleStatsEx.SAMPLE_ID.value: sid, - # SampleStatsEx.INSTANCE_ID.value: str(0), # Added later in the preprocessing during df registration + # SampleStatsEx.INSTANCE_ID.value: str(0), # Added later in the preprocessing during df registration SampleStatsEx.ORIGIN.value: split, SampleStatsEx.GROUP_ID.value: str(group_id), SampleStatsEx.MEMBER_RANK.value: rank @@ -540,7 +540,7 @@ def _getitem_raw(self, index: int = None, id: int = None): target = self._tags_mapping[tag] break else: - target = 0 # Default to 0 if no tags match the mapping + target = 0 # Default to 0 if no tags match the mapping else: # No mapping provided but use_tags=True: keep original target logger.warning(f"use_tags=True but no tags_mapping provided for sample {id}") @@ -658,7 +658,7 @@ def _generate_unique_ids_parallel(self, dataset: Callable = None) -> np.ndarray: dataset = self.wrapped_dataset if dataset is None else dataset n_samples = len(dataset) - unique_ids = [i for i in range(n_samples)] # Initialize with indices as fallback IDs + unique_ids = [i for i in range(n_samples)] # Initialize with indices as fallback IDs unique_id_to_index = {} def compute_id(idx): @@ -681,11 +681,11 @@ def compute_id(idx): # Generate the ID uid = array_id_2bytes(data_array, return_hex=False, tronc_1byte=True) - uid = str(uid) # Convert to string for consistent handling + uid = str(uid) # Convert to string for consistent handling return idx, uid except Exception as e: logger.warning(f"Failed to generate ID for sample {idx}: {e}") - return idx, str(idx) # Fallback to index as ID + return idx, str(idx) # Fallback to index as ID # Use ThreadPoolExecutor; track progress on completed tasks. with ThreadPoolExecutor(thread_name_prefix="unique_id_generator") as executor: @@ -695,10 +695,10 @@ def compute_id(idx): # Collect results as they complete for future in tqdm(as_completed(futures), total=n_samples, desc="Generating unique IDs", unit="sample"): idx, uid = future.result() - uid = str(uid) # Ensure UID is a string for consistent handling + uid = str(uid) # Ensure UID is a string for consistent handling unique_ids[idx] = uid unique_id_to_index[uid] = idx if uid not in unique_id_to_index else unique_id_to_index[uid] - unique_ids = np.asanyarray(unique_ids, dtype=object) # Use object dtype for string UIDs + unique_ids = np.asanyarray(unique_ids, dtype=object) # Use object dtype for string UIDs return unique_ids, unique_id_to_index def _get_df_view(self, limit: int = -1, column: str = None, value: str = None) -> pd.DataFrame: diff --git a/weightslab/data/data_utils.py b/weightslab/data/data_utils.py index 60c0f998..9c0b256a 100644 --- a/weightslab/data/data_utils.py +++ b/weightslab/data/data_utils.py @@ -36,10 +36,10 @@ def _to_uint8_image(img_float: np.ndarray) -> np.ndarray: img = np.asarray(img_float) if img.ndim == 2: - img = img[..., None] # HxWx1 + img = img[..., None] # HxWx1 if img.shape[-1] == 1: - img = np.repeat(img, 3, axis=-1) # grayscale -> RGB + img = np.repeat(img, 3, axis=-1) # grayscale -> RGB if img.shape[-1] != 3: raise ValueError(f"Expected image with 1 or 3 channels, got shape {img.shape}") @@ -68,8 +68,8 @@ def overlay_gt_pred( pred_value=None, alpha_gt=0.45, alpha_pred=0.45, - color_gt=(0, 255, 0), # green - color_pred=(255, 0, 0), # red + color_gt=(0, 255, 0), # green + color_pred=(255, 0, 0), # red show_overlap_as_yellow=True ) -> np.ndarray: """ @@ -280,11 +280,11 @@ def get_mask(raw, dataset=None, dataset_index=None, raw_data=None): segmentation_map = np.zeros((height, width), dtype=np.int64) # Return segmentation map directly if it matches raw shape - if segmentation_map.shape == raw.shape[-2:]: # B, C, H, W + if segmentation_map.shape == raw.shape[-2:]: # B, C, H, W return raw # Generate segmentation map from bboxes - raw = raw[0] if raw.ndim == 3 else raw # Handle batch dimension if present + raw = raw[0] if raw.ndim == 3 else raw # Handle batch dimension if present for bbox_data in raw: x1, y1, x2, y2 = bbox_data[:4].astype(int) if bbox_data.max() > 1 else (bbox_data[:4] * [width, height, width, height]).astype(int) # Extract class id if available, otherwise use 1 @@ -344,10 +344,10 @@ def _extract_slice_from_4d(np_img: np.ndarray, slice_idx: int = None) -> np.ndar # Now we should have (Z, H, W) or (Z, H, W, C) z_dim = np_img.shape[0] if slice_idx is None: - slice_idx = z_dim // 2 # Middle slice + slice_idx = z_dim // 2 # Middle slice slice_idx = max(0, min(slice_idx, z_dim - 1)) - return np_img[slice_idx] # Returns (H, W) or (H, W, C) + return np_img[slice_idx] # Returns (H, W) or (H, W, C) def _get_image_array_and_metadata(wrapped, index, rank: int = 0) -> tuple: @@ -383,7 +383,7 @@ def _get_image_array_and_metadata(wrapped, index, rank: int = 0) -> tuple: if hasattr(np_img, 'numpy'): np_img = np_img.numpy() - is_volumetric = np_img.ndim >= 4 # 3 is for RGB; while 4 is 3D # TODO (GP): Should be fix because this will not work with grayscale image wo. color channel + is_volumetric = np_img.ndim >= 4 # 3 is for RGB; while 4 is 3D # TODO (GP): Should be fix because this will not work with grayscale image wo. color channel # For 4D volumetric data, detect and transpose channel-first formats: # 1. (C, Z, H, W) → (Z, H, W, C) - channels first in all dimensions @@ -421,7 +421,7 @@ def to_uint8(np_img: np.ndarray) -> np.ndarray: if np.issubdtype(np_img.dtype, np.floating): min_v = float(np.nanmin(np_img)) if np_img.size else 0.0 max_v = float(np.nanmax(np_img)) if np_img.size else 1.0 - if max_v <= 128.0: # Scale floats in [0, ~1] to [0, 255] + if max_v <= 128.0: # Scale floats in [0, ~1] to [0, 255] np_img = (np_img - min_v) / (max_v - min_v + 1e-8) * 255.0 # Clip to valid byte range then cast np_img = np.clip(np_img, 0, 255) @@ -455,9 +455,9 @@ def load_label(dataset, sample_id): def _convert_label(lbl): if isinstance(lbl, list) and len(lbl) and isinstance(lbl[0], (th.Tensor, np.ndarray)): - label = to_numpy_safe(lbl).max(0) # Aggr. instances + label = to_numpy_safe(lbl).max(0) # Aggr. instances else: - label = to_numpy_safe(lbl) # Third element is typically the label + label = to_numpy_safe(lbl) # Third element is typically the label return label # Try common dataset patterns first @@ -466,10 +466,10 @@ def _convert_label(lbl): if isinstance(data, (list, tuple)): if len(data) == 1: - return None # Only data, no label - elif len(data) <= 3: # if len==2|3, data, uids, label, no extra info + return None # Only data, no label + elif len(data) <= 3: # if len==2|3, data, uids, label, no extra info label = _convert_label(data[2]) - elif len(data) > 3: # if len>3, data, uids, label, classes, extra info + elif len(data) > 3: # if len>3, data, uids, label, classes, extra info if len(data) == 4: metadata = data[3] classes = to_numpy_safe(metadata['classes']) if isinstance(metadata, dict) and 'classes' in metadata else None @@ -520,12 +520,12 @@ def load_metadata(dataset, sample_id): if isinstance(data, (list, tuple)): if len(data) == 1: - return None # Only data, no metadata - elif len(data) == 2: # if len==2, only data and uid, no extra info - return None # No metadata, only data and uid - elif len(data) == 3: # if len==3, data, uids, label, no extra info - return None # No metadata, only data, uid, and label - elif len(data) > 3: # if len>3, data, uids, label, classes, extra info + return None # Only data, no metadata + elif len(data) == 2: # if len==2, only data and uid, no extra info + return None # No metadata, only data and uid + elif len(data) == 3: # if len==3, data, uids, label, no extra info + return None # No metadata, only data, uid, and label + elif len(data) > 3: # if len>3, data, uids, label, classes, extra info metadata = {} for item in data[3:]: if isinstance(item, dict): @@ -654,7 +654,7 @@ def load_raw_image_array(dataset, index, rank: int = 0) -> tuple: elif channels == 4: middle_pil = Image.fromarray(middle_slice_uint8, mode="RGBA") else: - middle_pil = Image.fromarray(middle_slice_uint8[..., 0], mode="L") # Fallback + middle_pil = Image.fromarray(middle_slice_uint8[..., 0], mode="L") # Fallback return np_img, is_volumetric, original_shape, middle_pil @@ -710,7 +710,7 @@ def load_uid(dataset, sample_id): if isinstance(data, (list, tuple)): if len(data) == 1: - return None # Only data, no metadata - elif len(data) >= 2: # if len==2, only data and uid, no extra info - return data[1] # Second element is typically the uid + return None # Only data, no metadata + elif len(data) >= 2: # if len==2, only data and uid, no extra info + return data[1] # Second element is typically the uid return None diff --git a/weightslab/data/dataframe_manager.py b/weightslab/data/dataframe_manager.py index fc6c8fd3..3185d785 100644 --- a/weightslab/data/dataframe_manager.py +++ b/weightslab/data/dataframe_manager.py @@ -28,7 +28,7 @@ pd.set_option('future.no_silent_downcasting', True) -logger = logging.getLogger(__name__) # Set up logger +logger = logging.getLogger(__name__) # Set up logger def _safe_update(target: pd.DataFrame, source: pd.DataFrame) -> None: @@ -98,10 +98,10 @@ def __init__(self, flush_interval: float = 3.0, flush_max_rows: int = 100, enabl self._flush_max_rows = flush_max_rows self._flush_thread: threading.Thread | None = None self._flush_stop = threading.Event() - self._flush_event = threading.Event() # Event to wake thread for force flush + self._flush_event = threading.Event() # Event to wake thread for force flush self._flush_queue_count = 0 self._dense_store: Dict[str, Dict[int, np.ndarray]] = {} - self._buffer: Dict[int, Dict[str, Any]] = {} # {sample_id: {col: value}} + self._buffer: Dict[int, Dict[str, Any]] = {} # {sample_id: {col: value}} # Registry of categorical tags: tag_name (without "tag:" prefix) -> ordered # list of allowed category values. Distinguishes multi-value string tags # (e.g. weather -> [rainy, sunny]) from the legacy boolean tags. @@ -175,7 +175,7 @@ def _count_instances(target: Any) -> int: # Check if all items are scalar-like all_scalar = all(isinstance(item, (int, float, np.integer, np.floating)) for item in target) if all_scalar: - return 1 # Single instance with multiple values + return 1 # Single instance with multiple values except Exception: pass @@ -198,8 +198,8 @@ def _instance_targets_list(target: Any) -> list: A single array/tensor/scalar/label is the sample's OWN target and lives on the sample row (instance_id 0), so it yields no separate instance rows. - - list/tuple of array-likes → the list (one entry per instance, rows 1..N) - - everything else → ``[]`` (single-target / classification → only the sample row) + - list/tuple of array-likes → the list (one entry per instance, rows 1..N) + - everything else → ``[]`` (single-target / classification → only the sample row) """ if isinstance(target, (list, tuple)) and len(target) > 0 \ and isinstance(target[0], (np.ndarray, torch.Tensor, list)): @@ -249,10 +249,10 @@ def _expand_records_to_multi_index(self, records: List[Dict[str, Any]]) -> pd.Da sid = self._normalize_sample_id(rec.get(SID)) inst_targets = self._instance_targets_list(rec.get(TARGET)) n_inst = len(inst_targets) - total = n_inst + 1 # +1 for the sample row at instance_id 0 + total = n_inst + 1 # +1 for the sample row at instance_id 0 sample_ids.extend([sid] * total) - annotation_ids.extend(range(total)) # 0 (sample), 1..N (instances) + annotation_ids.extend(range(total)) # 0 (sample), 1..N (instances) for key in keys: val = rec.get(key) @@ -465,7 +465,7 @@ def _auto_register_categorical_tags(self, df: pd.DataFrame) -> None: if isinstance(s.dtype, pd.CategoricalDtype): cats = s.dtype.categories.tolist() if any(isinstance(c, bool) for c in cats): - continue # boolean-style categorical → not a categorical tag + continue # boolean-style categorical → not a categorical tag candidate = cats elif pd.api.types.is_bool_dtype(s.dtype): continue @@ -595,12 +595,12 @@ def _load_existing_data(self, origin: str = None, autoload_arrays: bool | list | _safe_update(self._df, loaded_df) # 2) Append rows that exist ONLY in the loaded df. This is the key - # fix: a freshly-registered loader has just the sample row - # (annotation_id == 0) per sample, while the persisted df from a - # previous run also has the instance rows (annotation_id >= 1). - # Those instance rows must be added back, not dropped. Use a - # boolean mask (not .loc[difference]) so duplicate keys can't be - # re-expanded by the label lookup. + # fix: a freshly-registered loader has just the sample row + # (annotation_id == 0) per sample, while the persisted df from a + # previous run also has the instance rows (annotation_id >= 1). + # Those instance rows must be added back, not dropped. Use a + # boolean mask (not .loc[difference]) so duplicate keys can't be + # re-expanded by the label lookup. new_rows = loaded_df[~loaded_df.index.isin(self._df.index)] if not new_rows.empty: self._df = pd.concat([self._df, new_rows]) @@ -754,7 +754,7 @@ def _is_array_column_to_norm(self, column_name: str, value: Any) -> bool: def _should_array_be_stored(self, array_name) -> bool: """Check if array storage is enabled.""" - return array_name in SAMPLES_STATS_TO_SAVE_TO_H5 # Regexed signals are not considered here + return array_name in SAMPLES_STATS_TO_SAVE_TO_H5 # Regexed signals are not considered here def _is_array_column_to_norm(self, column_name: str, value: Any) -> bool: """True if ``column_name`` is an array column whose ``value`` is an array @@ -921,7 +921,7 @@ def enqueue_batch( preds_raw: np.ndarray | dict | None, preds: np.ndarray | dict | None, losses: Dict[str, Any] | None, - targets: np.ndarray | dict | None = None, + targets: np.ndarray | dict | None = None, step: int | None = None ): """ @@ -970,9 +970,9 @@ def index_batch(obj, batch_index, rec=False): pred = index_batch(pred, batch_index) else: pred = index_batch(preds, batch_index) - pred = pred if is_meaningful(pred) else None # Replace nan by None + pred = pred if is_meaningful(pred) else None # Replace nan by None if pred is not None: - rec[SampleStats.Ex.PREDICTION.value] = self._normalize_preds_raw_uint16(pred) # Not normalized as already integer + rec[SampleStats.Ex.PREDICTION.value] = self._normalize_preds_raw_uint16(pred) # Not normalized as already integer ## Target if targets is not None: target = None @@ -986,9 +986,9 @@ def index_batch(obj, batch_index, rec=False): target = torch.cat((target, targets['cls'][mask]), -1) else: target = index_batch(targets, batch_index) - target = target if is_meaningful(target) else None # Replace nan by None + target = target if is_meaningful(target) else None # Replace nan by None if target is not None: - rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer + rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer ## Step if step is not None and is_meaningful(step): rec[SampleStats.Ex.LAST_SEEN.value] = int(step) @@ -1005,7 +1005,7 @@ def index_batch(obj, batch_index, rec=False): for sample_id, record in records_to_add.items(): self._buffer.setdefault(sample_id, {}).update(record) logger.debug(f"Enqueued {len(records_to_add)} records to buffer. Buffer size is now {len(self._buffer)}.") - should_flush = len(self._buffer) >= self._flush_max_rows or self.first_init # Check buffer size and trigger flush if needed + should_flush = len(self._buffer) >= self._flush_max_rows or self.first_init # Check buffer size and trigger flush if needed # Trigger flush outside lock if should_flush: @@ -1130,11 +1130,11 @@ def _index_target(obj, i): bid = int(targets['batch_idx'][mask].ravel()[0].item()) if sid == usid[bid]: if 'bboxes' in targets: - target = targets['bboxes'][mask][aid_i-1] # aid start to 1 for instance rows + target = targets['bboxes'][mask][aid_i-1] # aid start to 1 for instance rows if 'cls' in targets: - target = torch.cat((target, targets['cls'][mask][aid_i-1]), -1) # aid start to 1 for instance rows + target = torch.cat((target, targets['cls'][mask][aid_i-1]), -1) # aid start to 1 for instance rows if target is not None: - rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer + rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer else: # Nested-list targets are flattened sample-major (targets_rav), so the # i-th flat entry is this i-th instance's target — index by i, not the @@ -1279,7 +1279,7 @@ def update_by_groups_bulk(self, origin: str, group_ids: List[Any], updates_list: self._df.loc[indices, col] = val affected_ids.extend(indices) else: - if not affected_ids: # Only print once to avoid log spam + if not affected_ids: # Only print once to avoid log spam print(f"[DEBUG] Could not find gid {repr(gid)} in gid_to_indices keys. Sample key: {repr(list(gid_to_indices.keys())[0]) if gid_to_indices else 'None'}") if affected_ids: @@ -1720,7 +1720,7 @@ def _apply_buffer_records(self, records: List[Dict[str, Any]]): if not records: return - current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds + current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds logger.debug(f"[{current_time}] [LedgeredDataFrameManager] Applying {len(records)} buffered records to Global DataFrame.") sample_ids = [rec["sample_id"] for rec in records] @@ -1740,7 +1740,7 @@ def _apply_buffer_records(self, records: List[Dict[str, Any]]): # Mark all as pending for h5 flush (outside lock) self.mark_dirty_batch(sample_ids) - current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds + current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds logger.debug(f"[{current_time}] [LedgeredDataFrameManager] Applied {len(records)} buffered records to Global DataFrame.") def _apply_buffer_records_nonblocking(self, records: List[Dict[str, Any]]): @@ -1995,8 +1995,8 @@ def _optimize_dataframe_memory(self, df: pd.DataFrame, categorical_tags: Dict[st # === 3) Categorical conversion (LAST step) === # Columns that are typically repetitive (good candidates for categorical) categorical_candidates = [ - SampleStats.Ex.ORIGIN.value, # Alias for origin (if different column name) - SampleStats.Ex.TASK_TYPE.value, # Task type (e.g. 'classification', 'segmentation') + SampleStats.Ex.ORIGIN.value, # Alias for origin (if different column name) + SampleStats.Ex.TASK_TYPE.value, # Task type (e.g. 'classification', 'segmentation') ] for col in categorical_candidates: @@ -2021,7 +2021,7 @@ def _optimize_dataframe_memory(self, df: pd.DataFrame, categorical_tags: Dict[st n_rows = len(df) compression_ratio = n_unique / n_rows if n_rows > 0 else 1.0 - if compression_ratio < 0.5 and n_unique > 1: # Worth compressing if < 50% unique + if compression_ratio < 0.5 and n_unique > 1: # Worth compressing if < 50% unique try: df[col] = df[col].astype('category') logger.debug( @@ -2086,7 +2086,7 @@ def _worker(): # Forced when buffer is full if force_requested: - self._flush_event.clear() # Clear before flush + self._flush_event.clear() # Clear before flush self.flush() # Wait for flush event (force) or timeout (periodic) @@ -2098,11 +2098,11 @@ def _worker(): if not self._flush_stop.is_set(): self.flush_if_needed_nonblocking(force=True) - self._flush_queue_count = 0 # Reset queue count after periodic flush + self._flush_queue_count = 0 # Reset queue count after periodic flush except Exception as e: traceback_str = traceback.format_exc() logger.error(f"[LedgeredDataFrameManager] Flush loop error: {e}\n{traceback_str}") - st = time.time() # Reset start time after each loop + st = time.time() # Reset start time after each loop self._flush_thread = threading.Thread(target=_worker, name="WL-Ledger_Dataframe_Flush", daemon=True) self._flush_thread.start() @@ -2153,7 +2153,7 @@ def get_collapse_annotations_to_samples_df(self, iid: str = None) -> pd.DataFram The shared dataframe manager now expands every sample into one row per instance/annotation using a ``(sample_id, annotation_id)`` MultiIndex. The studio UI and the agent, however, are sample-centric: they expect a - single row per sample. This helper folds the annotation rows back down: + single row per sample. This helper folds the annotation rows back down: - Sample-level columns (metadata, target, prediction, tags, ...) are duplicated identically on every annotation row, so we keep the first. @@ -2197,7 +2197,7 @@ def get_collapse_annotations_to_samples_df(self, iid: str = None) -> pd.DataFram # straight off the index (or columns) as numpy arrays — no reindex of ``df``. if has_annot_index: if SID not in (df.index.names or []): - return df # Cannot locate the sample level — leave untouched. + return df # Cannot locate the sample level — leave untouched. sid_arr = df.index.get_level_values(SID).to_numpy() annot_arr = np.asarray(df.index.get_level_values(ANNOT).to_numpy()) else: @@ -2270,7 +2270,7 @@ def get_collapse_annotations_to_samples_df(self, iid: str = None) -> pd.DataFram vals = None for c in per_instance_cols: v = col_lists[c][i] - if v is None or v != v: # None or NaN + if v is None or v != v: # None or NaN continue if vals is None: vals = {} @@ -2332,7 +2332,7 @@ def _coerce_df_for_h5(self, df: pd.DataFrame) -> pd.DataFrame: target_dtype = SAMPLES_STATS_DEFAULTS_TYPES[col] # Handle union types (e.g., int | list, str | list) - if hasattr(target_dtype, '__origin__'): # Python 3.10+ union types + if hasattr(target_dtype, '__origin__'): # Python 3.10+ union types if hasattr(target_dtype, '__args__'): target_dtype = target_dtype.__args__[0] @@ -2391,7 +2391,7 @@ def flush_async(self): """Signal flush thread. Returns once buffer has been drained (not after H5 write). Training is only blocked for the brief buffer-drain window (~1ms), not for the - full DF→H5 write. If the buffer refills before the flush thread loops back, the + full DF→H5 write. If the buffer refills before the flush thread loops back, the next call will wait again — bounding in-memory usage to 2×flush_max_rows records. """ with self._queue_lock: @@ -2474,7 +2474,7 @@ def create_ledger_manager(): enable_flushing_threads=enable_flush ) except Exception: - pass # Use defaults if hyperparams not available + pass # Use defaults if hyperparams not available return None @@ -2483,6 +2483,6 @@ def create_ledger_manager(): # from weightslab.backend import ledgers # LM = create_ledger_manager() # try: -# ledgers.register_dataframe(LM) +# ledgers.register_dataframe(LM) # except Exception as e: -# logger.debug(f"Failed to register LedgeredDataFrameManager in ledger: {e}") +# logger.debug(f"Failed to register LedgeredDataFrameManager in ledger: {e}") diff --git a/weightslab/data/h5_array_store.py b/weightslab/data/h5_array_store.py index 716b0bab..346e3b86 100644 --- a/weightslab/data/h5_array_store.py +++ b/weightslab/data/h5_array_store.py @@ -23,7 +23,7 @@ # Config global logger logger = logging.getLogger(__name__) -UINT_DEFAULT = 16 # Default to uint8 for array normalization +UINT_DEFAULT = 16 # Default to uint8 for array normalization class LRUArrayCache: @@ -34,7 +34,7 @@ class LRUArrayCache: Tracks memory usage and provides cache statistics. """ - def __init__(self, max_size_bytes: int = 2 * 1024**3): # 2GB default + def __init__(self, max_size_bytes: int = 2 * 1024**3): # 2GB default """ Initialize LRU cache. @@ -42,7 +42,7 @@ def __init__(self, max_size_bytes: int = 2 * 1024**3): # 2GB default max_size_bytes: Maximum total memory for cached arrays in bytes """ self._max_size = max_size_bytes - self._cache = OrderedDict() # Maintains insertion/access order + self._cache = OrderedDict() # Maintains insertion/access order self._current_size = 0 self._lock = threading.RLock() self._hits = 0 @@ -93,7 +93,7 @@ def put(self, key: str, array: np.ndarray) -> None: # Evict LRU entries until there's space while self._current_size + array_size > self._max_size and self._cache: - lru_key, lru_array = self._cache.popitem(last=False) # FIFO = LRU + lru_key, lru_array = self._cache.popitem(last=False) # FIFO = LRU self._current_size -= self._array_size(lru_array) logger.debug(f"[LRUArrayCache] Evicted {lru_key} to free memory (cache size: {self._current_size / 1024**2:.1f}MB)") @@ -283,7 +283,7 @@ def normalize_array_to_uint(arr: np.ndarray, preserve_original: bool = False, ui if arr_max == arr_min: if arr_max == 0: # All zeros, can store as uint with zero values - metadata['normalized'] = False # No need to normalize if all values are the same + metadata['normalized'] = False # No need to normalize if all values are the same return np.zeros(arr.shape, dtype=uint_dtype), metadata elif arr_max < 2**uint - 1: # Constant array @@ -320,7 +320,7 @@ def denormalize_array(arr: np.ndarray, metadata: Dict[str, Any], uint: int = 16) # Denormalize from uint range arr_min = metadata['min'] arr_max = metadata['max'] - uint = metadata.get('uint', uint) # Default to 16 if not specified + uint = metadata.get('uint', uint) # Default to 16 if not specified original_dtype = np.dtype(metadata['original_dtype']) # Scale back from 0-65535 to original range @@ -361,7 +361,7 @@ def __init__( """ self._path = Path(path) self._local_lock = threading.RLock() - self._rw_lock = _ReadWriteLock() # Read-write lock for concurrent reads + self._rw_lock = _ReadWriteLock() # Read-write lock for concurrent reads self._lock_path = self._path.with_suffix(self._path.suffix + ".lock") self._lock_timeout = lock_timeout self._poll_interval = poll_interval diff --git a/weightslab/data/h5_dataframe_store.py b/weightslab/data/h5_dataframe_store.py index 72c13a8e..6a797a96 100644 --- a/weightslab/data/h5_dataframe_store.py +++ b/weightslab/data/h5_dataframe_store.py @@ -15,7 +15,7 @@ from weightslab.data.sample_stats import SampleStats -logger = logging.getLogger(__name__) # Initialize logger +logger = logging.getLogger(__name__) # Initialize logger # WL signal columns use dotted names (e.g. "signals.defaults.brightness"), which # PyTables flags with NaturalNameWarning because they aren't valid Python @@ -42,7 +42,7 @@ def _align_col_dtype_for_assign(existing: pd.DataFrame, source: pd.DataFrame, co Best-effort: dtype alignment must never break a merge. """ try: - src_kind = source[col].dtype.kind # 'O' object, 'b' bool, 'i'/'u'/'f' numeric + src_kind = source[col].dtype.kind # 'O' object, 'b' bool, 'i'/'u'/'f' numeric tgt_dtype = existing[col].dtype if src_kind in ("O", "b") and tgt_dtype != object and tgt_dtype.kind != src_kind: existing[col] = existing[col].astype(object) @@ -101,7 +101,7 @@ def _unlock(): except OSError: pass - self._unlock = _unlock # type: ignore[attr-defined] + self._unlock = _unlock # type: ignore[attr-defined] while True: if _try_lock(): @@ -113,7 +113,7 @@ def _unlock(): def __exit__(self, exc_type, exc_val, exc_tb): try: if hasattr(self, "_unlock"): - self._unlock() # type: ignore[attr-defined] + self._unlock() # type: ignore[attr-defined] finally: if self._fh: try: @@ -184,9 +184,9 @@ def _extract_tag_columns(self, df: pd.DataFrame) -> dict: # Detect tag type from data non_null = df[col].dropna() if non_null.empty: - tag_cols[col] = None # Auto-detect if all null + tag_cols[col] = None # Auto-detect if all null elif all(isinstance(v, bool) for v in non_null): - tag_cols[col] = [True, False] # Boolean tag + tag_cols[col] = [True, False] # Boolean tag else: # String tag: use unique values as categories tag_cols[col] = non_null.unique().tolist() @@ -446,9 +446,9 @@ def deserialize_value(val): # Everything was persisted as plain strings (see upsert). Reconstruct the # in-memory representation for tag/discarded columns: - # 1. missing tokens ("nan"/"none"/"") → real NaN - # 2. boolean columns ("True"/"False") → real bool (bool('False') is truthy, - # so this MUST run before any bool checks) + # 1. missing tokens ("nan"/"none"/"") → real NaN + # 2. boolean columns ("True"/"False") → real bool (bool('False') is truthy, + # so this MUST run before any bool checks) # String categorical tags keep their string values here; their categorical # dtype + full category set is restored by _optimize_categorical_tags below. _BOOL_TOKENS = {"true": True, "false": False, "1": True, "0": False} @@ -505,7 +505,7 @@ def _verify_checksum(self, store: pd.HDFStore, key: str, expected_checksum: str) try: checksum_key = f"{key}/_checksum" if checksum_key not in store: - return True # No checksum to verify + return True # No checksum to verify checksum_df = store.get(checksum_key) stored_checksum = checksum_df["checksum"].iloc[0] return stored_checksum == expected_checksum @@ -562,7 +562,7 @@ def load(self, origin: str, columns: Optional[Iterable[str]] = None, start: Opti return pd.DataFrame() df = store.select(key, start=start, stop=stop, columns=list(columns) if columns else None) except (FileNotFoundError, OSError, KeyError) as exc: - if not non_blocking: # Only warn on blocking reads + if not non_blocking: # Only warn on blocking reads logger.warning(f"[H5DataFrameStore] Failed to load {key} from {self._path}: {exc}") return pd.DataFrame() except TimeoutError: @@ -811,7 +811,7 @@ def delete_column(self, column_name: str, origins: Optional[Iterable[str]] = Non True if successful, False otherwise """ if not self._path.exists(): - return True # Nothing to delete + return True # Nothing to delete # Create backup BEFORE any modifications backup_path = self._create_backup() diff --git a/weightslab/data/point_cloud_utils.py b/weightslab/data/point_cloud_utils.py index 521a1fec..dc090ae8 100644 --- a/weightslab/data/point_cloud_utils.py +++ b/weightslab/data/point_cloud_utils.py @@ -5,7 +5,7 @@ dimensionality) cannot be PIL-encoded directly, so the studio pipeline previews them as a server-rendered BEV (bird's-eye-view) image: - * thumbnails / preview cache / modal image -> ``point_cloud_to_bev_image`` + * thumbnails / preview cache / modal image -> ``point_cloud_to_bev_image`` * GT / prediction boxes overlaid on the BEV -> ``project_boxes_to_bev`` (3D boxes [cx, cy, cz, dx, dy, dz, yaw, cls?, conf?] or 2D metric boxes [cx, cy, dx, dy, cls?, conf?] -> normalized [x1, y1, x2, y2, cls, conf] @@ -78,9 +78,9 @@ def _default_feature_names(num_features: int) -> list: extra = num_features - len(base) if extra == 1: base = base + ["intensity"] - elif extra == 4: # intensity + normals + elif extra == 4: # intensity + normals base = base + ["intensity", "nx", "ny", "nz"] - elif extra == 3: # normals OR rgb — ambiguous, label generically + elif extra == 3: # normals OR rgb — ambiguous, label generically base = base + ["c0", "c1", "c2"] elif extra > 0: base = base + [f"c{i}" for i in range(extra)] @@ -140,11 +140,11 @@ def compute_point_normals(points: np.ndarray, k: int = 16) -> np.ndarray: k = int(max(3, min(k, n))) tree = cKDTree(xyz) _, idx = tree.query(xyz, k=k) - neigh = xyz[idx] # [M, k, 3] + neigh = xyz[idx] # [M, k, 3] centered = neigh - neigh.mean(axis=1, keepdims=True) cov = np.einsum("mki,mkj->mij", centered, centered) / k # Smallest-eigenvector of each 3x3 covariance is the surface normal. - eigvals, eigvecs = np.linalg.eigh(cov) # ascending eigenvalues + eigvals, eigvecs = np.linalg.eigh(cov) # ascending eigenvalues normals = eigvecs[:, :, 0] # Orient toward the sensor (origin) so shading is consistent. flip = np.einsum("mi,mi->m", normals, -xyz) < 0 @@ -175,7 +175,7 @@ def colorize_from_image(points_xyz, image, project_fn): Args: points_xyz: [M, 3] points in the LiDAR frame. - image: [H, W, 3] uint8 camera image (e.g. KITTI image_2). + image: [H, W, 3] uint8 camera image (e.g. KITTI image_2). project_fn: callable(points_xyz) -> ([M, 2] pixel uv, [M] bool valid) mapping LiDAR points to image pixels (dataset-specific, uses the calibration). Points that fall outside the image / behind @@ -213,7 +213,7 @@ def colorize_from_image(points_xyz, image, project_fn): dtype=np.float32, ) -_BEV_BACKGROUND = (13, 17, 23) # dark slate, matches the studio dark theme +_BEV_BACKGROUND = (13, 17, 23) # dark slate, matches the studio dark theme def default_bev_image_size() -> int: @@ -340,8 +340,8 @@ def point_cloud_to_bev_image( brightness grows with point density. +x is right, +y is up. Args: - points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop; derived + points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop; derived from the points when None. image_size: output resolution (default: WL_BEV_IMAGE_SIZE env or 640). """ @@ -408,12 +408,12 @@ def point_cloud_to_range_image( - Pixel value: distance (and optionally intensity) Args: - points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. - image_height: vertical resolution (elevation bins). - image_width: horizontal resolution (azimuth bins, default 512 like KITTI). - fov_up: max elevation angle in degrees (default 3.0°). - fov_down: min elevation angle in degrees (default -25.0°, typical LiDAR). - mode: "distance" (grayscale distance), "intensity" (intensity with hue), + points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. + image_height: vertical resolution (elevation bins). + image_width: horizontal resolution (azimuth bins, default 512 like KITTI). + fov_up: max elevation angle in degrees (default 3.0°). + fov_down: min elevation angle in degrees (default -25.0°, typical LiDAR). + mode: "distance" (grayscale distance), "intensity" (intensity with hue), or "distance+intensity" (default: distance as brightness, z/intensity as hue). Returns: @@ -433,8 +433,8 @@ def point_cloud_to_range_image( distance = np.sqrt(x**2 + y**2 + z**2) distance = np.maximum(distance, 1e-6) - azimuth = np.arctan2(y, x) # [-pi, pi] - elevation = np.arcsin(np.clip(z / distance, -1.0, 1.0)) # [-pi/2, pi/2] in radians + azimuth = np.arctan2(y, x) # [-pi, pi] + elevation = np.arcsin(np.clip(z / distance, -1.0, 1.0)) # [-pi/2, pi/2] in radians elevation_deg = np.degrees(elevation) # Map to image coordinates @@ -462,7 +462,7 @@ def point_cloud_to_range_image( intensity_norm = np.clip(intensity / (intensity.max() + 1e-6), 0.3, 1.0) colors = np.clip(colors * intensity_norm[:, None], 0, 255).astype(np.uint8) canvas[v, u] = colors - else: # "distance+intensity" (default) + else: # "distance+intensity" (default) # Distance as brightness (grayscale), height/intensity for hue dist_norm = distance / (distance.max() + 1e-6) z_norm = np.clip((z - np.percentile(z, 5)) / (np.percentile(z, 95) - np.percentile(z, 5) + 1e-6), 0.0, 1.0) @@ -494,10 +494,10 @@ def project_boxes_to_bev( """Project metric 3D/2D point-cloud boxes into the BEV image frame. Args: - boxes: [N, C] rows; C >= 7 -> 3D (cx, cy, cz, dx, dy, dz, yaw, + boxes: [N, C] rows; C >= 7 -> 3D (cx, cy, cz, dx, dy, dz, yaw, cls?, conf?), C <= 6 -> 2D metric (cx, cy, dx, dy, cls?, conf?). - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) of the + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) of the rendered BEV image. min_norm_size: minimum normalized box width/height (~2 px at 256) so distant pedestrians stay clickable in thumbnails. @@ -716,7 +716,7 @@ def pack_point_cloud(points: np.ndarray, max_points: int = 0, seed: int = 0): if max_points and pts.shape[0] > max_points: rng = np.random.default_rng(seed) keep = rng.choice(pts.shape[0], int(max_points), replace=False) - keep.sort() # preserve original ordering for cache-friendly decode + keep.sort() # preserve original ordering for cache-friendly decode pts = pts[keep] pts = np.ascontiguousarray(pts, dtype=" List[str]: """Return list of stats to save to H5, conditionally including predictions and targets.""" base_list = [ - "signals.*", # Prefix for dynamic signals - "SIGNALS.*", # Prefix for dynamic signals - "tag.*", # Prefix for dynamic TAG - "TAG.*", # Prefix for dynamic TAG + "signals.*", # Prefix for dynamic signals + "SIGNALS.*", # Prefix for dynamic signals + "tag.*", # Prefix for dynamic TAG + "TAG.*", # Prefix for dynamic TAG cls.Ex.DISCARDED.value, cls.Ex.TAG.value, diff --git a/weightslab/examples/Lightning/ws-classification/main.py b/weightslab/examples/Lightning/ws-classification/main.py index 9c81a46d..2bc9dc6e 100644 --- a/weightslab/examples/Lightning/ws-classification/main.py +++ b/weightslab/examples/Lightning/ws-classification/main.py @@ -44,7 +44,7 @@ def __init__(self, root, train=True, download=False, transform=None): root=root, train=train, download=download, - transform=None # We'll apply transform manually to track filepath + transform=None # We'll apply transform manually to track filepath ) self.transform = transform self.train = train @@ -111,7 +111,7 @@ def forward(self, x): def training_step(self, batch): with guard_training_context: x, ids, y = batch - logits = self(x) # forward pass + logits = self(x) # forward pass preds = torch.argmax(logits, dim=1) # WeightsLab tracked loss @@ -368,14 +368,14 @@ def main(): ) print("=" * 60) - print("🚀 STARTING TRAINING (PyTorch Lightning)") - print(f"📊 Max epochs: {max_epochs}") + print(" STARTING TRAINING (PyTorch Lightning)") + print(f" Max epochs: {max_epochs}") print( - f"⚙️ Trainer: accelerator={trainer_accelerator}, devices={trainer_devices}, " + f" Trainer: accelerator={trainer_accelerator}, devices={trainer_devices}, " f"strategy={trainer_strategy}" ) - print(f"📦 Dataset splits: train={len(_train_dataset)}, val={len(_val_dataset)}") - print(f"💾 Logs will be saved to: {log_dir}") + print(f" Dataset splits: train={len(_train_dataset)}, val={len(_val_dataset)}") + print(f" Logs will be saved to: {log_dir}") print("=" * 60 + "\n") # PyTorch Lightning Trainer @@ -383,7 +383,7 @@ def main(): # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. trainer = pl.Trainer( max_epochs=max_epochs, @@ -401,8 +401,8 @@ def main(): trainer.fit(L_model, train_loader, val_loader) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/PyTorch/ws-classification/main.py b/weightslab/examples/PyTorch/ws-classification/main.py index 8dc53e78..056b9736 100644 --- a/weightslab/examples/PyTorch/ws-classification/main.py +++ b/weightslab/examples/PyTorch/ws-classification/main.py @@ -63,7 +63,7 @@ def __init__(self, root, train=True, download=False, transform=None, max_samples root=root, train=train, download=download, - transform=None # We'll apply transform manually to track filepath + transform=None # We'll apply transform manually to track filepath ) except RuntimeError as e: logger.error(f"Error loading MNIST dataset: {e}") @@ -71,7 +71,7 @@ def __init__(self, root, train=True, download=False, transform=None, max_samples root=root, train=train, download=True, - transform=None # We'll apply transform manually to track filepath + transform=None # We'll apply transform manually to track filepath ) self.transform = transform self.train = train @@ -152,7 +152,7 @@ def train(loader, model, optimizer, criterion_mlt, device): batch_ids=ids, preds=preds ) - total_loss = loss_batch_mlt.mean() # Final scalar loss + total_loss = loss_batch_mlt.mean() # Final scalar loss # Model total_loss.backward() @@ -362,10 +362,10 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): ) print("=" * 60) - print("🚀 STARTING TRAINING") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") + print(" STARTING TRAINING") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") print(f"� Dataset splits: train={len(_train_dataset)}, test={len(_test_dataset)}") - print(f"💾 Logs will be saved to: {log_dir}") + print(f" Logs will be saved to: {log_dir}") print("=" * 60 + "\n") # Setup clean progress bar with custom format @@ -383,13 +383,13 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. train_loss = None test_loss, test_metric = None, None - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm for train_step in train_range: - age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) + age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) # Train one step train_loss = train(train_loader, model, optimizer, train_criterion, device) @@ -428,8 +428,8 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): train_range.set_postfix_str(" | ".join(postfix_parts)) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/PyTorch/ws-clustering/face/data.py b/weightslab/examples/PyTorch/ws-clustering/face/data.py index 672913f4..a41eb97c 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/data.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/data.py @@ -3,17 +3,17 @@ Supported back-ends ------------------- -"olivetti" Olivetti Faces from sklearn (40 identities, 400 images, 64×64 grey). - Self-contained; requires only scikit-learn. Good default toy set. -"lfw" Labeled Faces in the Wild (LFW) via torchvision. Downloaded on - first use to *root*. Much larger but realistic. -"folder" Generic ImageFolder layout: root/{split}/{class_name}/*.jpg +"olivetti" Olivetti Faces from sklearn (40 identities, 400 images, 64×64 grey). + Self-contained; requires only scikit-learn. Good default toy set. +"lfw" Labeled Faces in the Wild (LFW) via torchvision. Downloaded on + first use to *root*. Much larger but realistic. +"folder" Generic ImageFolder layout: root/{split}/{class_name}/*.jpg Every sample is returned as: (image_tensor: Tensor[C,H,W], - uid: str, - label: int, - metadata: dict) + uid: str, + label: int, + metadata: dict) These map directly onto the (data, uid, target, metadata) convention used throughout the WeightsLAB kitchen examples so the same training loop works @@ -41,16 +41,16 @@ class FaceDataset(Dataset): """Unified face recognition dataset wrapper. Args: - root: Download / data root (only used for lfw / folder). - dataset_type: One of "olivetti", "lfw", "folder". - split: "train" or "test" (ignored for pre-split sources). - image_size: Spatial size; images are resized to (image_size, image_size). - train_ratio: Fraction of per-class samples used for training + root: Download / data root (only used for lfw / folder). + dataset_type: One of "olivetti", "lfw", "folder". + split: "train" or "test" (ignored for pre-split sources). + image_size: Spatial size; images are resized to (image_size, image_size). + train_ratio: Fraction of per-class samples used for training (Olivetti only). - min_images_per_class: Classes with fewer samples are discarded. - transform: Optional torchvision transform; defaults to + min_images_per_class: Classes with fewer samples are discarded. + transform: Optional torchvision transform; defaults to Resize → ToTensor → Normalize([0.5], [0.5]). - seed: RNG seed for reproducible train/test splits. + seed: RNG seed for reproducible train/test splits. """ def __init__( @@ -65,12 +65,12 @@ def __init__( seed: int = 42, ): self.dataset_type = dataset_type - self.split = split - self.image_size = image_size - self.transform = transform or self._default_transform(image_size) + self.split = split + self.image_size = image_size + self.transform = transform or self._default_transform(image_size) # These are populated by each loader - self.images: Optional[np.ndarray] = None # (N, H, W) float [0,1] — Olivetti only + self.images: Optional[np.ndarray] = None # (N, H, W) float [0,1] — Olivetti only self.img_paths: Optional[np.ndarray] = None self.labels: np.ndarray = np.array([], dtype=np.int64) self.num_classes: int = 0 @@ -108,67 +108,67 @@ def _load_olivetti(self, train_ratio: float, min_images: int, seed: int): """Load and split the Olivetti Faces dataset (sklearn).""" from sklearn.datasets import fetch_olivetti_faces - data = fetch_olivetti_faces(shuffle=True, random_state=seed) - images = data.images # (400, 64, 64) float [0,1] + data = fetch_olivetti_faces(shuffle=True, random_state=seed) + images = data.images # (400, 64, 64) float [0,1] labels = data.target.astype(np.int64) # Drop classes with insufficient samples unique, counts = np.unique(labels, return_counts=True) - valid_classes = unique[counts >= min_images] - mask = np.isin(labels, valid_classes) + valid_classes = unique[counts >= min_images] + mask = np.isin(labels, valid_classes) images, labels = images[mask], labels[mask] # Remap labels to a contiguous 0…N-1 range mapping = {int(c): i for i, c in enumerate(sorted(valid_classes.tolist()))} - labels = np.array([mapping[int(l)] for l in labels], dtype=np.int64) + labels = np.array([mapping[int(l)] for l in labels], dtype=np.int64) # Per-class stratified train/test split rng = np.random.RandomState(seed) train_idx, test_idx = [], [] for cls in np.unique(labels): - idx = np.where(labels == cls)[0] + idx = np.where(labels == cls)[0] n_train = max(1, int(len(idx) * train_ratio)) - perm = rng.permutation(len(idx)) + perm = rng.permutation(len(idx)) train_idx.extend(idx[perm[:n_train]].tolist()) test_idx.extend(idx[perm[n_train:]].tolist()) - indices = train_idx if self.split == "train" else test_idx - self.images = images[indices] - self.labels = labels[indices] + indices = train_idx if self.split == "train" else test_idx + self.images = images[indices] + self.labels = labels[indices] self.num_classes = len(mapping) def _load_lfw(self, root: str, min_images: int, split: str): """Load LFW People via torchvision (downloads on first call).""" from torchvision.datasets import LFWPeople - split_map = {"train": "train", "test": "test", "val": "10fold"} - lfw_split = split_map.get(split, "train") - ds = LFWPeople(root=root, split=lfw_split, download=True, transform=None) + split_map = {"train": "train", "test": "test", "val": "10fold"} + lfw_split = split_map.get(split, "train") + ds = LFWPeople(root=root, split=lfw_split, download=True, transform=None) paths, lbls = zip(*ds.imgs) - lbls = np.array(lbls, dtype=np.int64) + lbls = np.array(lbls, dtype=np.int64) # Filter low-shot identities unique, counts = np.unique(lbls, return_counts=True) - valid = set(unique[counts >= min_images].tolist()) - mask = np.array([int(l) in valid for l in lbls]) + valid = set(unique[counts >= min_images].tolist()) + mask = np.array([int(l) in valid for l in lbls]) self.img_paths = np.array(paths)[mask] - lbls = lbls[mask] + lbls = lbls[mask] - mapping = {int(c): i for i, c in enumerate(sorted(valid))} - self.labels = np.array([mapping[int(l)] for l in lbls], dtype=np.int64) + mapping = {int(c): i for i, c in enumerate(sorted(valid))} + self.labels = np.array([mapping[int(l)] for l in lbls], dtype=np.int64) self.num_classes = len(mapping) def _load_folder(self, root: str, split: str): """Load from a torchvision ImageFolder directory.""" from torchvision.datasets import ImageFolder - split_dir = os.path.join(root, split) - ds = ImageFolder(split_dir) - paths, lbls = zip(*ds.imgs) + split_dir = os.path.join(root, split) + ds = ImageFolder(split_dir) + paths, lbls = zip(*ds.imgs) self.img_paths = list(paths) - self.labels = np.array(lbls, dtype=np.int64) + self.labels = np.array(lbls, dtype=np.int64) self.num_classes = len(ds.classes) # ---------------------------------------------------------- @@ -188,7 +188,7 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, int, Dict]: if self.dataset_type == "olivetti": from PIL import Image as PILImage - img_np = self.images[idx] # (H, W) float [0,1] + img_np = self.images[idx] # (H, W) float [0,1] img_pil = PILImage.fromarray( (img_np * 255).astype(np.uint8), mode="L" ).convert("RGB") @@ -200,9 +200,9 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, int, Dict]: uid = f"{self.split}_cls{label:04d}_idx{idx:06d}" metadata = { - "split": self.split, - "label_id": label, - "idx": idx, + "split": self.split, + "label_id": label, + "idx": idx, "dataset_type": self.dataset_type, } return image_tensor, uid, label, metadata diff --git a/weightslab/examples/PyTorch/ws-clustering/face/model.py b/weightslab/examples/PyTorch/ws-clustering/face/model.py index 5a7e51d0..5c9974f4 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/model.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/model.py @@ -3,20 +3,20 @@ Architecture ------------ -Pretrained backbone (ResNet-18 / ResNet-50 / MobileNet-V3-Small) +Pretrained backbone (ResNet-18 / ResNet-50 / MobileNet-V3-Small) ? -EmbeddingHead Linear ? BN ? ReLU ? Linear +EmbeddingHead Linear ? BN ? ReLU ? Linear ? L2-normalised D-dimensional embedding The backbone is optionally frozen so that only the lightweight head is trained -(recommended toy-example setup). The combined graph is registered with +(recommended toy-example setup). The combined graph is registered with WeightsLAB for model tracking. Public interface ---------------- -FaceEmbeddingModel.get_embeddings(images) ? normalised embeddings (B, D) -FaceEmbeddingModel.train_step(images, labels, ...) ? scalar loss float +FaceEmbeddingModel.get_embeddings(images) ? normalised embeddings (B, D) +FaceEmbeddingModel.train_step(images, labels, ...) ? scalar loss float """ import logging @@ -64,8 +64,8 @@ def __init__(self, backbone: nn.Module, head: EmbeddingHead): self.head = head def forward(self, x: torch.Tensor) -> torch.Tensor: - features = self.backbone(x) # (B, feature_dim) - embeddings = self.head(features) # (B, embedding_dim), L2-normalised + features = self.backbone(x) # (B, feature_dim) + embeddings = self.head(features) # (B, embedding_dim), L2-normalised return embeddings @@ -126,16 +126,16 @@ class FaceEmbeddingModel: """Wrapper that manages the backbone + head, optimiser, and WeightsLAB tracking. Args: - backbone_name: "resnet18" | "resnet50" | "mobilenet_v3_small" - embedding_dim: Output embedding dimensionality (default 128). + backbone_name: "resnet18" | "resnet50" | "mobilenet_v3_small" + embedding_dim: Output embedding dimensionality (default 128). head_hidden_dim: Hidden size of the projection MLP (default 256). - lr: Learning rate for AdamW (default 1e-3). - weight_decay: AdamW weight decay (default 1e-4). + lr: Learning rate for AdamW (default 1e-3). + weight_decay: AdamW weight decay (default 1e-4). freeze_backbone: When True, only the head's parameters receive gradients ? recommended for quick toy runs. - device: "cpu", "cuda", or "cuda:N". - pretrained: Load ImageNet-pretrained weights for the backbone. - margin: Triplet margin (default 0.3). + device: "cpu", "cuda", or "cuda:N". + pretrained: Load ImageNet-pretrained weights for the backbone. + margin: Triplet margin (default 0.3). """ def __init__( @@ -203,11 +203,11 @@ def __init__( f"trainable_params={n_trainable:,}" ) print( - f" Backbone : {backbone_name} (pretrained={pretrained}, frozen={freeze_backbone})\n" - f" Emb dim : {embedding_dim}\n" - f" Head dim : {head_hidden_dim}\n" - f" Trainable : {n_trainable:,} params\n" - f" Device : {self.device}" + f" Backbone : {backbone_name} (pretrained={pretrained}, frozen={freeze_backbone})\n" + f" Emb dim : {embedding_dim}\n" + f" Head dim : {head_hidden_dim}\n" + f" Trainable : {n_trainable:,} params\n" + f" Device : {self.device}" ) def _build_backbone( @@ -278,8 +278,8 @@ def train_step( """One gradient update using online batch-hard triplet mining. Args: - images: (B, C, H, W) float tensor - labels: (B,) long tensor of identity ids + images: (B, C, H, W) float tensor + labels: (B,) long tensor of identity ids batch_ids: list of sample UIDs for WeightsLAB signal logging loss_name: "triplet" (contrastive support planned) @@ -292,7 +292,7 @@ def train_step( images = images.to(self.device) labels = labels.to(self.device) - embeddings = self.net(images) # (B, D) + embeddings = self.net(images) # (B, D) # Mine hardest triplets in the batch anc_idx, pos_idx, neg_idx = mine_batch_hard(embeddings, labels) diff --git a/weightslab/examples/PyTorch/ws-clustering/face/signals.py b/weightslab/examples/PyTorch/ws-clustering/face/signals.py index 9b14f270..2ef47045 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/signals.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/signals.py @@ -6,8 +6,8 @@ Classes ------- -TripletLosses differentiable loss functions (return torch.Tensor) -FaceMetrics evaluation metrics and clustering-oriented test signals +TripletLosses differentiable loss functions (return torch.Tensor) +FaceMetrics evaluation metrics and clustering-oriented test signals """ import numpy as np diff --git a/weightslab/examples/PyTorch/ws-clustering/face/utils.py b/weightslab/examples/PyTorch/ws-clustering/face/utils.py index acc77706..9e9fa306 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/utils.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/utils.py @@ -21,13 +21,13 @@ def pairwise_distances(embeddings: torch.Tensor, squared: bool = False) -> torch Args: embeddings: (B, D) tensor - squared: return squared L2 distances when True + squared: return squared L2 distances when True Returns: (B, B) distance matrix """ dot = torch.matmul(embeddings, embeddings.t()) - sq_norms = torch.diagonal(dot) # (B,) + sq_norms = torch.diagonal(dot) # (B,) distances = sq_norms.unsqueeze(0) - 2.0 * dot + sq_norms.unsqueeze(1) distances = distances.clamp(min=0.0) @@ -58,8 +58,8 @@ def mine_batch_hard( Args: embeddings: (B, D) detached from graph during mining - labels: (B,) integer class ids - squared: use squared L2 distances for mining + labels: (B,) integer class ids + squared: use squared L2 distances for mining Returns: anc_idx, pos_idx, neg_idx 1-D LongTensors; only valid anchors @@ -69,7 +69,7 @@ def mine_batch_hard( B = labels.shape[0] device = labels.device - same = labels.unsqueeze(0) == labels.unsqueeze(1) # (B, B) + same = labels.unsqueeze(0) == labels.unsqueeze(1) # (B, B) diff = ~same eye = torch.eye(B, dtype=torch.bool, device=device) @@ -105,12 +105,12 @@ def compute_verification_metrics( n = len(embeddings) # Pairwise L2 distances - dot = embeddings @ embeddings.T # (N, N) + dot = embeddings @ embeddings.T # (N, N) sq = np.sum(embeddings ** 2, axis=1) dist_mat = (sq[:, None] - 2.0 * dot + sq[None, :]).clip(min=0.0) dist_mat = np.sqrt(dist_mat.clip(min=1e-16)) * (dist_mat != 0.0) - same_pair = labels[:, None] == labels[None, :] # (N, N) + same_pair = labels[:, None] == labels[None, :] # (N, N) # Upper triangle only (avoid double-counting) iu = np.triu_indices(n, k=1) diff --git a/weightslab/examples/PyTorch/ws-clustering/main.py b/weightslab/examples/PyTorch/ws-clustering/main.py index d7c03615..a1f62461 100644 --- a/weightslab/examples/PyTorch/ws-clustering/main.py +++ b/weightslab/examples/PyTorch/ws-clustering/main.py @@ -6,9 +6,9 @@ trained with online batch-hard triplet loss on the Olivetti Faces dataset. Dataset options (set in config.yaml -> data.dataset_type): - "olivetti" - sklearn Olivetti (40 ids, 400 imgs) - works offline, default - "lfw" - LFW People via torchvision (download required) - "folder" - any ImageFolder-style directory + "olivetti" - sklearn Olivetti (40 ids, 400 imgs) - works offline, default + "lfw" - LFW People via torchvision (download required) + "folder" - any ImageFolder-style directory Training flow ------------- @@ -62,7 +62,7 @@ def evaluate( all_uids: List[str] = [] for images, uids, labels, _metadata in loader: - emb = model.get_embeddings(images) # (B, D) + emb = model.get_embeddings(images) # (B, D) all_embeddings.append(emb.numpy()) if isinstance(labels, torch.Tensor): all_labels.append(labels.numpy()) @@ -80,15 +80,15 @@ def evaluate( name=name, ) - print(f" verification_accuracy : {metrics.get('verification_accuracy', float('nan')):.4f}") - print(f" rank1_accuracy : {metrics.get('rank1_accuracy', float('nan')):.4f}") - print(f" FAR : {metrics.get('far', float('nan')):.4f}") - print(f" FRR : {metrics.get('frr', float('nan')):.4f}") - print(f" best_threshold : {metrics.get('best_threshold', float('nan')):.4f}") + print(f" verification_accuracy : {metrics.get('verification_accuracy', float('nan')):.4f}") + print(f" rank1_accuracy : {metrics.get('rank1_accuracy', float('nan')):.4f}") + print(f" FAR : {metrics.get('far', float('nan')):.4f}") + print(f" FRR : {metrics.get('frr', float('nan')):.4f}") + print(f" best_threshold : {metrics.get('best_threshold', float('nan')):.4f}") if "num_clusters" in metrics: - print(f" num_clusters : {metrics['num_clusters']:.0f}") - print(f" noise_ratio : {metrics['noise_ratio']:.4f}") - print(f" mean_nn1_distance : {metrics['mean_nn1_distance']:.4f}") + print(f" num_clusters : {metrics['num_clusters']:.0f}") + print(f" noise_ratio : {metrics['noise_ratio']:.4f}") + print(f" mean_nn1_distance : {metrics['mean_nn1_distance']:.4f}") return metrics @@ -111,10 +111,10 @@ def train( performed every eval_full_to_train_steps_ratio steps when test_loader is provided. """ print("\n" + "=" * 60) - print("Face Recognition Training (open-ended while loop)") - print(f" Loss : {loss_name}") - print(f" Eval every : {eval_full_to_train_steps_ratio} steps") - print(" Max steps : infinite (stop with Ctrl+C)") + print("Face Recognition Training (open-ended while loop)") + print(f" Loss : {loss_name}") + print(f" Eval every : {eval_full_to_train_steps_ratio} steps") + print(" Max steps : infinite (stop with Ctrl+C)") print("=" * 60) data_iter = iter(train_loader) @@ -183,7 +183,7 @@ def train( print("\nTraining summary:") for k, v in summary.items(): - print(f" {k}: {v}") + print(f" {k}: {v}") return summary @@ -313,16 +313,16 @@ def train( print("\n" + "=" * 60) print("STARTING FACE RECOGNITION TRAINING") - print(f" Experiment : {parameters['experiment_name']}") - print(f" Device : {device}") - print(f" Steps : infinite | eval_full_to_train_steps_ratio={eval_full_to_train_steps_ratio}") - print(f" Loss : {model_cfg.get('loss', 'triplet')}") - print(f" Logs : {parameters['root_log_dir']}") + print(f" Experiment : {parameters['experiment_name']}") + print(f" Device : {device}") + print(f" Steps : infinite | eval_full_to_train_steps_ratio={eval_full_to_train_steps_ratio}") + print(f" Loss : {model_cfg.get('loss', 'triplet')}") + print(f" Logs : {parameters['root_log_dir']}") print("=" * 60) # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. train( model=model, diff --git a/weightslab/examples/PyTorch/ws-detection/main.py b/weightslab/examples/PyTorch/ws-detection/main.py index 63710bd4..9ff5357e 100644 --- a/weightslab/examples/PyTorch/ws-detection/main.py +++ b/weightslab/examples/PyTorch/ws-detection/main.py @@ -49,7 +49,7 @@ def train(loader, model, optimizer, sig, device, grid_size, conf_thresh): targets = [t.to(device) for t in targets] optimizer.zero_grad() - outputs = model(inputs) # [B, S, S, 5 + num_classes] + outputs = model(inputs) # [B, S, S, 5 + num_classes] # Decoded boxes for the UI overlay (detached — display only). preds = decode_predictions(outputs.detach(), grid_size, conf_thresh=conf_thresh) @@ -90,7 +90,7 @@ def test(loader, model, sig, device, grid_size, conf_thresh, test_loader_len): loss = float((losses / test_loader_len).detach().cpu().item()) iou = float((ious / test_loader_len).detach().cpu().item()) - return loss, iou * 100.0 # Return mean IoU as percentage + return loss, iou * 100.0 # Return mean IoU as percentage # ============================================================================= @@ -111,7 +111,7 @@ def test(loader, model, sig, device, grid_size, conf_thresh, test_loader_len): parameters.setdefault("training_steps_to_do", 500) parameters.setdefault("eval_full_to_train_steps_ratio", 50) parameters.setdefault("number_of_workers", 4) - parameters.setdefault("num_classes", 1) # Penn-Fudan: single class (person) + parameters.setdefault("num_classes", 1) # Penn-Fudan: single class (person) parameters.setdefault("image_size", 256) parameters.setdefault("grid_size", 8) parameters.setdefault("conf_thresh", 0.3) @@ -255,7 +255,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): class_counts = np.zeros(num_classes, dtype=np.float64) num_samples = min(len(dataset), max_samples) - for idx in tqdm.tqdm(range(num_samples), desc="📊 Analyzing Distribution"): + for idx in tqdm.tqdm(range(num_samples), desc=" Analyzing Distribution"): _, _, target, _ = dataset.get_items(idx, include_labels=True) if target is None or len(target) == 0: continue @@ -263,10 +263,10 @@ def compute_class_weights(dataset, num_classes, max_samples=200): if 0 <= c < num_classes: class_counts[c] += 1 - class_counts = np.maximum(class_counts, 1) # Avoid div by zero + class_counts = np.maximum(class_counts, 1) # Avoid div by zero total = class_counts.sum() class_weights = total / (num_classes * class_counts) - class_weights = class_weights / class_weights.mean() # Normalize + class_weights = class_weights / class_weights.mean() # Normalize print("\nClass distribution and weights:", flush=True) for c in range(num_classes): @@ -287,16 +287,16 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("=" * 60) - print("🚀 STARTING PENN-FUDAN PEDESTRIAN DETECTION TRAINING") - print(f"📈 Total steps: {max_steps}") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") - print(f"💾 Logs will be saved to: {log_dir}") - print(f"📂 Data root: {data_root}") + print(" STARTING PENN-FUDAN PEDESTRIAN DETECTION TRAINING") + print(f" Total steps: {max_steps}") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") + print(f" Logs will be saved to: {log_dir}") + print(f" Data root: {data_root}") print("=" * 60 + "\n") # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. # ================ train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() @@ -310,7 +310,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): # Test if age == 0 or age % eval_full_to_train_steps_ratio == 0: - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating") if tqdm_display else test_loader test_loss, test_metric = test(test_loader_it, model, test_sig, device, grid_size, conf_thresh, test_loader_len) @@ -332,8 +332,8 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/PyTorch/ws-detection/utils/criterions.py b/weightslab/examples/PyTorch/ws-detection/utils/criterions.py index 700b9099..6feeab6e 100644 --- a/weightslab/examples/PyTorch/ws-detection/utils/criterions.py +++ b/weightslab/examples/PyTorch/ws-detection/utils/criterions.py @@ -13,13 +13,13 @@ # is assigned to the grid cell containing its center; that cell is "responsible" # for predicting the box. # -# * PerSampleDetectionLoss -> one differentiable loss scalar per sample ([B]), -# wrapped with ``per_sample=True`` (the value WL backprops + dashboards). -# * PerSampleIoU -> mean IoU over a sample's boxes ([B]), a metric. -# * PerInstanceIoU -> flat tensor of one IoU per GT box (sample-major -# order), wrapped with ``per_instance=True`` so WL auto-saves it at -# (sample_id, annotation_id). The ordering matches the per-sample target -# iteration, so the wrapper's auto ``batch_idx`` maps each value correctly. +# * PerSampleDetectionLoss -> one differentiable loss scalar per sample ([B]), +# wrapped with ``per_sample=True`` (the value WL backprops + dashboards). +# * PerSampleIoU -> mean IoU over a sample's boxes ([B]), a metric. +# * PerInstanceIoU -> flat tensor of one IoU per GT box (sample-major +# order), wrapped with ``per_instance=True`` so WL auto-saves it at +# (sample_id, annotation_id). The ordering matches the per-sample target +# iteration, so the wrapper's auto ``batch_idx`` maps each value correctly. _EPS = 1e-6 _LAMBDA_COORD = 5.0 @@ -45,12 +45,12 @@ def _responsible_cells(boxes, S): Args: boxes: [N, 4] xyxy in [0, 1]. - S: grid size. + S: grid size. Returns: - rows, cols: [N] long, the responsible cell indices. - off_x, off_y: [N] center offset within the cell, in [0, 1). - w, h: [N] box size as a fraction of the image. + rows, cols: [N] long, the responsible cell indices. + off_x, off_y: [N] center offset within the cell, in [0, 1). + w, h: [N] box size as a fraction of the image. """ cx = (boxes[:, 0] + boxes[:, 2]) / 2 cy = (boxes[:, 1] + boxes[:, 3]) / 2 @@ -69,12 +69,12 @@ def _per_sample_loss(outputs, targets, num_classes, weights=None): B, S = outputs.shape[0], outputs.shape[1] device = outputs.device - obj_logit = outputs[..., 0] # [B, S, S] + obj_logit = outputs[..., 0] # [B, S, S] tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) w_pred = torch.sigmoid(outputs[..., 3]) h_pred = torch.sigmoid(outputs[..., 4]) - cls_logits = outputs[..., 5:] # [B, S, S, C] + cls_logits = outputs[..., 5:] # [B, S, S, C] if weights is not None: weights = torch.as_tensor(weights, device=device, dtype=outputs.dtype) @@ -138,7 +138,7 @@ def _per_box_iou(outputs, targets, grid_size): Returns a list[B] of 1-D tensors (one IoU per box for that sample, in annotation order). Detached — this is a metric, not a loss. """ - boxes_grid, _, _ = decode_grid(outputs, grid_size) # [B, S, S, 4] + boxes_grid, _, _ = decode_grid(outputs, grid_size) # [B, S, S, 4] B = outputs.shape[0] S = grid_size device = outputs.device @@ -154,8 +154,8 @@ def _per_box_iou(outputs, targets, grid_size): gt_boxes = tgt[:, :4] rows, cols, _, _, _, _ = _responsible_cells(gt_boxes, S) - pred_boxes = boxes_grid[s, rows, cols] # [N, 4] - ious = box_iou_xyxy(pred_boxes, gt_boxes) # [N] + pred_boxes = boxes_grid[s, rows, cols] # [N, 4] + ious = box_iou_xyxy(pred_boxes, gt_boxes) # [N] per_sample.append(ious.detach()) return per_sample @@ -222,8 +222,8 @@ def decode_predictions(outputs, grid_size, conf_thresh=0.3, max_det=10): boxes_grid, obj, cls_probs = decode_grid(outputs, grid_size) B, S = outputs.shape[0], grid_size - cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] - score = obj * cls_conf # combined confidence + cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] + score = obj * cls_conf # combined confidence flat_boxes = boxes_grid.view(B, S * S, 4) flat_score = score.view(B, S * S) diff --git a/weightslab/examples/PyTorch/ws-detection/utils/data.py b/weightslab/examples/PyTorch/ws-detection/utils/data.py index c4823197..dbff4a02 100644 --- a/weightslab/examples/PyTorch/ws-detection/utils/data.py +++ b/weightslab/examples/PyTorch/ws-detection/utils/data.py @@ -21,9 +21,9 @@ # box per pedestrian from each mask. Downloaded + extracted on first use. # # On-disk layout after extraction: -# /PennFudanPed/ -# PNGImages/FudanPed00001.png ... -# PedMasks/FudanPed00001_mask.png ... # pixel value k = k-th pedestrian, 0 = bg +# /PennFudanPed/ +# PNGImages/FudanPed00001.png ... +# PedMasks/FudanPed00001_mask.png ... # pixel value k = k-th pedestrian, 0 = bg # # WL renders detection targets/predictions from a per-sample [N, 6] array # ``[x1, y1, x2, y2, class_id, confidence]`` normalized to [0, 1] (GT conf = 1.0) @@ -73,7 +73,7 @@ def _boxes_from_mask(mask_path): mask = np.array(Image.open(mask_path)) h, w = mask.shape[:2] obj_ids = np.unique(mask) - obj_ids = obj_ids[obj_ids != 0] # drop background + obj_ids = obj_ids[obj_ids != 0] # drop background boxes = [] for oid in obj_ids: @@ -91,9 +91,9 @@ class PennFudanDetectionDataset(Dataset): """Pedestrian bounding-box detection over the Penn-Fudan images. Args: - root: directory to download/extract the dataset into. - split: "train" or "val" (deterministic split of the 170 images). - image_size: square resize fed to the model. + root: directory to download/extract the dataset into. + split: "train" or "val" (deterministic split of the 170 images). + image_size: square resize fed to the model. val_fraction: fraction of images held out for validation. max_samples: optional cap on the split size (for quick runs). """ @@ -165,16 +165,16 @@ def _load_boxes(self, mask_path): norm[:, [0, 2]] /= float(w) norm[:, [1, 3]] /= float(h) n = norm.shape[0] - cls = np.zeros((n, 1), dtype=np.float32) # single class: person + cls = np.zeros((n, 1), dtype=np.float32) # single class: person conf = np.ones((n, 1), dtype=np.float32) return np.concatenate([norm, cls, conf], axis=1).astype(np.float32) def __getitem__(self, idx): """Returns (item, uid, target, metadata). - - item: normalized image tensor [C, H, W] - - uid: unique sample id (string) - - target: [N, 6] float32 = [x1, y1, x2, y2, class_id, confidence] + - item: normalized image tensor [C, H, W] + - uid: unique sample id (string) + - target: [N, 6] float32 = [x1, y1, x2, y2, class_id, confidence] - metadata: dict with source paths """ return self.get_items(idx, include_metadata=True, include_labels=True, include_images=True) @@ -210,10 +210,10 @@ def det_collate(batch): sample's boxes in annotation order). Returns: - images: FloatTensor [B, C, H, W] - ids: list[str] of length B + images: FloatTensor [B, C, H, W] + ids: list[str] of length B targets: list[B] of [N_i, 6] float tensors ([x1, y1, x2, y2, cls, conf]) - metas: list[B] of metadata dicts + metas: list[B] of metadata dicts """ images = torch.stack([b[0] for b in batch], dim=0) ids = [b[1] for b in batch] diff --git a/weightslab/examples/PyTorch/ws-detection/utils/model.py b/weightslab/examples/PyTorch/ws-detection/utils/model.py index effbc5f0..ed204a66 100644 --- a/weightslab/examples/PyTorch/ws-detection/utils/model.py +++ b/weightslab/examples/PyTorch/ws-detection/utils/model.py @@ -6,12 +6,12 @@ # (objectness, tx, ty, tw, th, class_logits...). # # Encoding (all coordinates normalized to the [0, 1] image frame): -# * objectness = sigmoid(t_obj) -> P(box present in this cell) -# * cx = (col + sigmoid(tx)) / S -> box center, x -# * cy = (row + sigmoid(ty)) / S -> box center, y -# * w = sigmoid(tw) -> box width (fraction of image) -# * h = sigmoid(th) -> box height (fraction of image) -# * class = softmax(class_logits) +# * objectness = sigmoid(t_obj) -> P(box present in this cell) +# * cx = (col + sigmoid(tx)) / S -> box center, x +# * cy = (row + sigmoid(ty)) / S -> box center, y +# * w = sigmoid(tw) -> box width (fraction of image) +# * h = sigmoid(th) -> box height (fraction of image) +# * class = softmax(class_logits) # # Raw forward output keeps logits (loss applies the activations); `decode` # turns logits into xyxy boxes for metrics and UI rendering. @@ -28,12 +28,12 @@ def decode_grid(outputs, grid_size): encoding lives in exactly one place. Args: - outputs: [B, S, S, 5 + num_classes] raw logits. + outputs: [B, S, S, 5 + num_classes] raw logits. grid_size: S. Returns: - boxes: [B, S, S, 4] xyxy in [0, 1] - obj: [B, S, S] objectness probability + boxes: [B, S, S, 4] xyxy in [0, 1] + obj: [B, S, S] objectness probability cls_probs: [B, S, S, num_classes] class probabilities """ B, S, _, _ = outputs.shape @@ -86,7 +86,7 @@ def __init__( # --- Pretrained backbone (ImageNet) --- weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1 if pretrained else None backbone = mobilenet_v3_small(weights=weights) - self.backbone = backbone.features # [B, 576, H/32, W/32] + self.backbone = backbone.features # [B, 576, H/32, W/32] backbone_out_ch = 576 self.freeze_backbone = freeze_backbone @@ -121,7 +121,7 @@ def forward(self, x): feat = self.backbone(x) feat = self.neck(feat) - out = self.head(feat) # [B, preds_per_cell, S', S'] + out = self.head(feat) # [B, preds_per_cell, S', S'] # Resize feature grid to the configured grid_size. if out.shape[-1] != self.grid_size or out.shape[-2] != self.grid_size: diff --git a/weightslab/examples/PyTorch/ws-generation/main.py b/weightslab/examples/PyTorch/ws-generation/main.py index b1890d5b..a2315af4 100644 --- a/weightslab/examples/PyTorch/ws-generation/main.py +++ b/weightslab/examples/PyTorch/ws-generation/main.py @@ -40,22 +40,22 @@ def __init__(self, in_ch: int = 3, base: int = 4, bottleneck: int = 512, image_s # ---- Encoder ---- self.enc1_conv = nn.Conv2d(in_ch, C1, kernel_size=3, padding=1) - self.enc1_bn = nn.BatchNorm2d(C1) + self.enc1_bn = nn.BatchNorm2d(C1) self.enc1_pool = nn.MaxPool2d(2) self.enc2_conv = nn.Conv2d(C1, C2, kernel_size=3, padding=1) - self.enc2_bn = nn.BatchNorm2d(C2) + self.enc2_bn = nn.BatchNorm2d(C2) self.enc2_pool = nn.MaxPool2d(2) self.enc3_conv = nn.Conv2d(C2, C3, kernel_size=3, padding=1) - self.enc3_bn = nn.BatchNorm2d(C3) + self.enc3_bn = nn.BatchNorm2d(C3) self.enc3_pool = nn.MaxPool2d(2) # ---- Mid / Bottleneck ---- - self.mid_conv3 = nn.Conv2d(C3, C3, kernel_size=3, padding=1) - self.mid_conv5 = nn.Conv2d(C3, C3, kernel_size=5, padding=2) - self.mid_conv7 = nn.Conv2d(C3, C3, kernel_size=7, padding=3) - self.mid_bn = nn.BatchNorm2d(C3 * 3) + self.mid_conv3 = nn.Conv2d(C3, C3, kernel_size=3, padding=1) + self.mid_conv5 = nn.Conv2d(C3, C3, kernel_size=5, padding=2) + self.mid_conv7 = nn.Conv2d(C3, C3, kernel_size=7, padding=3) + self.mid_bn = nn.BatchNorm2d(C3 * 3) # NEW: Spatial path for reconstruction (preserves 2D structure) self.spatial_bottleneck = nn.Conv2d(C3 * 3, C3, kernel_size=1) @@ -66,15 +66,15 @@ def __init__(self, in_ch: int = 3, base: int = 4, bottleneck: int = 512, image_s # ---- Decoder ---- self.up1_conv = nn.Conv2d(C3, C2, kernel_size=3, padding=1) - self.up1_bn = nn.BatchNorm2d(C2) + self.up1_bn = nn.BatchNorm2d(C2) self.up2_conv = nn.Conv2d(C2, C1, kernel_size=3, padding=1) - self.up2_bn = nn.BatchNorm2d(C1) + self.up2_bn = nn.BatchNorm2d(C1) # ---- Heads ---- - self.cls_head = nn.Linear(bottleneck, 1) # anomaly classification - self.recon_head = nn.Conv2d(C1, in_ch, kernel_size=1) # reconstruction - self.embed_head = nn.Linear(bottleneck, 64) # contrastive embedding + self.cls_head = nn.Linear(bottleneck, 1) # anomaly classification + self.recon_head = nn.Conv2d(C1, in_ch, kernel_size=1) # reconstruction + self.embed_head = nn.Linear(bottleneck, 64) # contrastive embedding def forward(self, x): # Encoder @@ -139,11 +139,11 @@ def __init__(self, root, split="train", transform=None): print(f"Warning: split directory {split_dir} not found.") return - for folder in sorted(os.listdir(split_dir)): # sorted for determinism + for folder in sorted(os.listdir(split_dir)): # sorted for determinism folder_path = os.path.join(split_dir, folder) if not os.path.isdir(folder_path): continue - for fname in sorted(os.listdir(folder_path)): # sorted for determinism + for fname in sorted(os.listdir(folder_path)): # sorted for determinism if fname.lower().endswith(('.png', '.jpg', '.jpeg')): full_path = os.path.join(folder_path, fname) if folder == 'good': @@ -176,9 +176,9 @@ def __getitem__(self, idx): # Fully Balanced Pairing Strategy (50/50 Pairs, 50/50 Labels) # Cycle through 4 types of pairs based on idx % 4: # 0: Good + Good (Positive Contrastive, 100% Good Labels) - # 1: Bad + Bad (Positive Contrastive, 100% Bad Labels) - # 2: Good + Bad (Negative Contrastive, 50/50 Labels) - # 3: Bad + Good (Negative Contrastive, 50/50 Labels) + # 1: Bad + Bad (Positive Contrastive, 100% Bad Labels) + # 2: Good + Bad (Negative Contrastive, 50/50 Labels) + # 3: Bad + Good (Negative Contrastive, 50/50 Labels) p_type = idx % 4 if p_type == 0: @@ -220,10 +220,10 @@ def __getitem__(self, idx): group_id = f"{self.split}_pair_{uid1}_{uid2}" return ( - [img1_t, img2_t], # The pair of inputs - [idx1, idx2], # Relative indices - [label1, label2], # Individual labels - { # Metadata — "uids" causes 2 ledger rows per pair + [img1_t, img2_t], # The pair of inputs + [idx1, idx2], # Relative indices + [label1, label2], # Individual labels + { # Metadata — "uids" causes 2 ledger rows per pair "group_id": group_id, "uids": [uid1, uid2], } @@ -443,7 +443,7 @@ def evaluate_all(loader, model, cls_criterion, contrastive_criterion, metric, de ]) _train_ds = VADDataset(data_root, split="train", transform=transform) - _test_ds = VADDataset(data_root, split="test", transform=transform) + _test_ds = VADDataset(data_root, split="test", transform=transform) train_loader = wl.watch_or_edit( _train_ds, flag="data", loader_name="train_loader", @@ -486,7 +486,7 @@ def evaluate_all(loader, model, cls_criterion, contrastive_criterion, metric, de # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. pbar = tqdm(range(training_steps), desc="Training") for step in pbar: diff --git a/weightslab/examples/PyTorch/ws-segmentation/main.py b/weightslab/examples/PyTorch/ws-segmentation/main.py index 1f184724..3f3f7059 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/main.py +++ b/weightslab/examples/PyTorch/ws-segmentation/main.py @@ -45,13 +45,13 @@ def _instance_batch_idx(labels): def _run_instance_signals(sig, outputs, labels, ids, preds, return_metric=False): """Compute + log/save the per-sample AND per-instance Dice (metric) and BCE (loss).""" bce_sample = sig["bce_sample"](outputs, labels, batch_ids=ids, preds=preds) - dice_sample = sig["dice_sample"](outputs, labels, batch_ids=ids) # Register processed predictions one time only + dice_sample = sig["dice_sample"](outputs, labels, batch_ids=ids) # Register processed predictions one time only - sig["dice_instance"](outputs, labels, batch_ids=ids) # Register processed predictions one time only + sig["dice_instance"](outputs, labels, batch_ids=ids) # Register processed predictions one time only sig["bce_instance"](outputs, labels, batch_ids=ids) avg_loss = 0.5 * dice_sample + 0.5 * bce_sample - wl.save_signals({"combined_bce_dice_per_sample": avg_loss}, ids) # Save the per-sample aggregate loss for backward step + wl.save_signals({"combined_bce_dice_per_sample": avg_loss}, ids) # Save the per-sample aggregate loss for backward step if return_metric: return avg_loss, dice_sample return avg_loss @@ -91,11 +91,11 @@ def train(loader, model, optimizer, sig, device): with guard_training_context: (inputs, ids, labels, _) = next(loader) inputs = inputs.to(device) - labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances + labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances optimizer.zero_grad() - outputs = model(inputs) # [B,C,H,W] - preds = outputs.argmax(dim=1) # [B,H,W] + outputs = model(inputs) # [B,C,H,W] + preds = outputs.argmax(dim=1) # [B,H,W] # Per-instance + per-sample Dice/BCE (tracked & saved at annotation level). loss_per_sample = _run_instance_signals(sig, outputs, labels, ids, preds=preds) @@ -110,7 +110,7 @@ def train(loader, model, optimizer, sig, device): wl.save_signals( _user_custom_signals(preds, labels), ids - ) # Save the per-sample predictions for visualization + ) # Save the per-sample predictions for visualization return float(loss.detach().cpu().item()) @@ -122,23 +122,23 @@ def test(loader, model, sig, device, test_loader_len): with guard_testing_context, torch.no_grad(): for inputs, ids, labels, _ in loader: inputs = inputs.to(device) - labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances + labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances outputs = model(inputs) - preds = outputs.argmax(dim=1) # [B,H,W] + preds = outputs.argmax(dim=1) # [B,H,W] # Per-instance + per-sample Dice/BCE (tracked & saved at annotation level). loss_per_sample, dice_sample = _run_instance_signals(sig, outputs, labels, ids, preds=preds, return_metric=True) - losses += torch.mean(loss_per_sample) # Average over the batch and accumulate - dices += torch.mean(dice_sample) # Average over the batch and accumulate + losses += torch.mean(loss_per_sample) # Average over the batch and accumulate + dices += torch.mean(dice_sample) # Average over the batch and accumulate # I want to see in the UI the per-sample classes predicted by the model - wl.save_signals(_user_custom_signals(preds, labels), ids) # Save the per-sample predictions for visualization + wl.save_signals(_user_custom_signals(preds, labels), ids) # Save the per-sample predictions for visualization loss = float((losses / test_loader_len).detach().cpu().item()) dice = float((dices / test_loader_len).detach().cpu().item()) - return loss, dice * 100.0 # Return average Dice as percentage + return loss, dice * 100.0 # Return average Dice as percentage # ============================================================================= @@ -159,8 +159,8 @@ def test(loader, model, sig, device, test_loader_len): parameters.setdefault("training_steps_to_do", 500) parameters.setdefault("eval_full_to_train_steps_ratio", 50) parameters.setdefault("number_of_workers", 4) - parameters.setdefault("num_classes", 6) # adjust to your label set - parameters.setdefault("ignore_index", 255) # if you have void pixels + parameters.setdefault("num_classes", 6) # adjust to your label set + parameters.setdefault("ignore_index", 255) # if you have void pixels parameters.setdefault("image_size", 256) parameters.setdefault("compute_natural_sort", True) @@ -212,7 +212,7 @@ def test(loader, model, sig, device, test_loader_len): num_classes=num_classes, ignore_index=ignore_index, image_size=image_size, - max_samples=train_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing + max_samples=train_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing ) _val_dataset = BDD100kSegDataset( root=data_root, @@ -220,7 +220,7 @@ def test(loader, model, sig, device, test_loader_len): num_classes=num_classes, ignore_index=ignore_index, image_size=image_size, - max_samples=test_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing + max_samples=test_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing ) train_loader = wl.watch_or_edit( @@ -299,8 +299,8 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 class_counts = np.zeros(num_classes, dtype=np.float64) num_samples = min(len(dataset), max_samples) - for idx in tqdm.tqdm(range(num_samples), desc="📊 Analyzing Distribution"): - _, _, label, _ = dataset.get_items(idx, include_labels=True) # Get the label/mask for this sample + for idx in tqdm.tqdm(range(num_samples), desc=" Analyzing Distribution"): + _, _, label, _ = dataset.get_items(idx, include_labels=True) # Get the label/mask for this sample label_np = label.numpy() if hasattr(label, 'numpy') else np.array(label) for c in range(num_classes): class_counts[c] += (label_np == c).sum() @@ -329,30 +329,30 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 ) print("=" * 60) - print("🚀 STARTING BDD100k SEGMENTATION TRAINING") - print(f"📈 Total steps: {max_steps}") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") - print(f"💾 Logs will be saved to: {log_dir}") - print(f"📂 Data root: {data_root}") + print(" STARTING BDD100k SEGMENTATION TRAINING") + print(f" Total steps: {max_steps}") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") + print(f" Logs will be saved to: {log_dir}") + print(f" Data root: {data_root}") print("=" * 60 + "\n") # # ================ # # Training Loop - # wl.start_training(timeout=3) # This will block and keep the main thread alive while background services run. You can optionally set a timeout (in seconds) to automatically stop after a certain duration. + # wl.start_training(timeout=3) # This will block and keep the main thread alive while background services run. You can optionally set a timeout (in seconds) to automatically stop after a certain duration. # ================ train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() test_loss, test_metric = None, None start_time = time.time() for train_step in train_range: - age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) + age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) # Train train_loss = train(train_loader, model, optimizer, train_sig, device) # Test if age == 0 or age % eval_full_to_train_steps_ratio == 0: - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating") if tqdm_display else test_loader test_loss, test_metric = test(test_loader_it, model, test_sig, device, test_loader_len) @@ -374,8 +374,8 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 ) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py b/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py index 43358196..e137dd32 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py @@ -9,11 +9,11 @@ # The segmentation dataset yields, per sample, a LIST of instance masks # (each [H, W] with pixel value = class id). These criterions compute Dice and # BCE for every instance against the model's per-class probability map, then: -# * PerInstance* returns a flat tensor (one value per instance, ordered -# sample-major) — wrapped with `per_instance=True` so WL auto-saves it at -# (sample_id, annotation_id). -# * PerSample* aggregates instances to one value per sample (mean) — wrapped -# with `per_sample=True` for the per-sample dashboards. +# * PerInstance* returns a flat tensor (one value per instance, ordered +# sample-major) — wrapped with `per_instance=True` so WL auto-saves it at +# (sample_id, annotation_id). +# * PerSample* aggregates instances to one value per sample (mean) — wrapped +# with `per_sample=True` for the per-sample dashboards. # The instance ordering matches the `batch_idx` passed by the training loop # (built from the same per-sample instance lists), so WL maps each value to the # correct annotation. @@ -26,14 +26,14 @@ def _instance_dice_bce(outputs, labels, **kwargs): Args: outputs: logits [B, C, H, W]. - labels: list[B]; labels[s] is a list of instance masks ([H, W], value = class id). + labels: list[B]; labels[s] is a list of instance masks ([H, W], value = class id). Returns: (dice_per_sample, bce_per_sample) where each is a list[B] of 1-D tensors holding one value per instance for that sample (empty tensor if none). Values are kept on the outputs' device; BCE retains grad, Dice is a metric. """ - probs = torch.softmax(outputs, dim=1) # [B, C, H, W], differentiable + probs = torch.softmax(outputs, dim=1) # [B, C, H, W], differentiable B, C = probs.shape[0], probs.shape[1] device = outputs.device @@ -58,12 +58,12 @@ def _instance_dice_bce(outputs, labels, **kwargs): cls = int(m.max().item()) ch = cls if 0 <= cls < C else 0 gt = (m > 0).float() - p = probs[s, ch].clamp(_EPS, 1.0 - _EPS) # [H, W] + p = probs[s, ch].clamp(_EPS, 1.0 - _EPS) # [H, W] inter = (p * gt).sum() dice = (2.0 * inter + _EPS) / (p.sum() + gt.sum() + _EPS) bce = F.binary_cross_entropy(p, gt) if weights is not None: - bce = bce * weights[ch] # scalar class weight for this instance + bce = bce * weights[ch] # scalar class weight for this instance dices.append(dice) bces.append(bce) diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py index 70715f33..656c0180 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py @@ -51,7 +51,7 @@ def __init__( for f in os.listdir(img_dir) if f.lower().endswith((".jpg", ".jpeg", ".png")) ] - image_files = sorted(set(image_files))[:max_samples] if max_samples != None else sorted(set(image_files)) # Optionally limit number of samples for faster testing + image_files = sorted(set(image_files))[:max_samples] if max_samples != None else sorted(set(image_files)) # Optionally limit number of samples for faster testing self.images = [] self.masks = [] @@ -120,24 +120,24 @@ def get_items(self, idx, include_metadata=False, include_labels=False, include_i mask = Image.open(mask_path) mask_r = self.mask_resize(mask) mask_np = np.array(mask_r, dtype=np.int64) - mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 + mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 return img_t, uid, mask_t, metadata # # # Instance wise segmentaiton # # Process labels/masks # mask_t_instances = list() # mask_t = None # if include_labels: - # mask = Image.open(mask_path) - # mask_r = self.mask_resize(mask) - # mask_np = np.array(mask_r, dtype=np.int64) - # mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 - - # # Format labels to register multiple instance_ids - # lbl_max = mask_t.max().item() - # for i in range(1, lbl_max + 1): - # m = torch.zeros_like(mask_t) - # m[mask_t == i] = i # Assign class ID as instance ID for simplicity; if set to 1, all instances of the same class would be merged... - # mask_t_instances.append(m) + # mask = Image.open(mask_path) + # mask_r = self.mask_resize(mask) + # mask_np = np.array(mask_r, dtype=np.int64) + # mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 + + # # Format labels to register multiple instance_ids + # lbl_max = mask_t.max().item() + # for i in range(1, lbl_max + 1): + # m = torch.zeros_like(mask_t) + # m[mask_t == i] = i # Assign class ID as instance ID for simplicity; if set to 1, all instances of the same class would be merged... + # mask_t_instances.append(m) # return img_t, uid, mask_t_instances, metadata def seg_collate(batch): @@ -150,10 +150,10 @@ def seg_collate(batch): background) are filtered out so every kept instance is a real annotation. Returns: - images: FloatTensor [B, C, H, W] - ids: list[str] of length B - labels: list[B] where labels[s] is a list of instance mask tensors - metas: list[B] of metadata dicts + images: FloatTensor [B, C, H, W] + ids: list[str] of length B + labels: list[B] where labels[s] is a list of instance mask tensors + metas: list[B] of metadata dicts """ images = torch.stack([b[0] for b in batch], dim=0) ids = [b[1] for b in batch] diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/model.py b/weightslab/examples/PyTorch/ws-segmentation/utils/model.py index 26339e9d..d45311f5 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/model.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/model.py @@ -58,7 +58,7 @@ def forward(self, x): # Decoder u2 = self.up2(b) - # ⚠️ Important: no `if` on shapes; always interpolate + # Important: no `if` on shapes; always interpolate u2 = F.interpolate(u2, size=e2.shape[-2:], mode="bilinear", align_corners=False) d2 = self.dec2(torch.cat([u2, e2], dim=1)) @@ -66,5 +66,5 @@ def forward(self, x): u1 = F.interpolate(u1, size=e1.shape[-2:], mode="bilinear", align_corners=False) d1 = self.dec1(torch.cat([u1, e1], dim=1)) - logits = self.head(d1) # [B, C, H, W] + logits = self.head(d1) # [B, C, H, W] return logits diff --git a/weightslab/examples/Ultralytics/ws-detection/main.py b/weightslab/examples/Ultralytics/ws-detection/main.py index c56597ca..3174059f 100644 --- a/weightslab/examples/Ultralytics/ws-detection/main.py +++ b/weightslab/examples/Ultralytics/ws-detection/main.py @@ -60,7 +60,7 @@ def main(): # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. YOLO(model_name).train( trainer=WLAwareTrainer, @@ -68,7 +68,7 @@ def main(): imgsz=image_size, epochs=1000 if max_steps == None else max(1, int(max_steps)), device=device, - project=project, name=name, # → UL save_dir → WL logger log_dir/name + project=project, name=name, # → UL save_dir → WL logger log_dir/name resume=False, cache=False, optimizer="SGD", @@ -86,7 +86,7 @@ def main(): # would make Ultralytics reject keys like `train_nms` as invalid YOLO args. ) - wl.keep_serving() # Keep main thread alive to analyze training results directly + wl.keep_serving() # Keep main thread alive to analyze training results directly if __name__ == "__main__": diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py index 7b48ec51..41c62365 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py @@ -34,7 +34,7 @@ def train(loader, model, optimizer, sig, device, grid_size, pc_range, conf_thres points = points.to(device) targets = [t.to(device) for t in targets] optimizer.zero_grad() - outputs = model(points) # [B, S, S, 5 + num_classes] + outputs = model(points) # [B, S, S, 5 + num_classes] preds = decode_predictions(outputs.detach(), grid_size, pc_range, conf_thresh=conf_thresh) loss_per_sample = sig["loss"](outputs, targets, batch_ids=ids, preds=preds) sig["iou_sample"](outputs, targets, batch_ids=ids) @@ -162,15 +162,15 @@ def _make_det_signals(split): serving_cli=parameters.get("serving_cli", True)) print("=" * 60) - print("🚀 STARTING 2D LiDAR DETECTION TRAINING (Pillars2D-lite)") - print(f"📡 {len(_train_dataset)} train / {len(_val_dataset)} val scans") - print(f"💾 Logs: {log_dir}") + print(" STARTING 2D LiDAR DETECTION TRAINING (Pillars2D-lite)") + print(f" {len(_train_dataset)} train / {len(_val_dataset)} val scans") + print(f" Logs: {log_dir}") print("=" * 60 + "\n") # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() test_loss, test_metric = None, None @@ -193,5 +193,5 @@ def _make_det_signals(split): + (f" test_loss={test_loss:.4f}" if test_loss is not None else "") + (f" IoU={test_metric:.2f}%" if test_metric is not None else "")) - print(f"\n✅ Done in {time.time() - start_time:.1f}s; logs at {log_dir}") + print(f"\n Done in {time.time() - start_time:.1f}s; logs at {log_dir}") wl.keep_serving() diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py index 7c5252ca..5c893bba 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py @@ -11,9 +11,9 @@ # Targets are [N, 6] rows [cx, cy, dx, dy, class_id, confidence] (metric). Each # GT box is assigned to the grid cell containing its (cx, cy) centre. # -# * PerSampleDetection2DLoss -> one differentiable scalar per sample ([B]). -# * PerSampleIoU2D -> mean axis-aligned IoU over a sample's boxes. -# * PerInstanceIoU2D -> one IoU per GT box (sample-major order). +# * PerSampleDetection2DLoss -> one differentiable scalar per sample ([B]). +# * PerSampleIoU2D -> mean axis-aligned IoU over a sample's boxes. +# * PerInstanceIoU2D -> one IoU per GT box (sample-major order). _EPS = 1e-6 _LAMBDA_COORD = 2.0 @@ -95,7 +95,7 @@ def _per_box_iou(outputs, targets, grid_size, pc_range): if tgt.ndim == 1: tgt = tgt.view(-1, 6) rows, cols, _, _ = _responsible_cells(tgt, S, pc_range) - pred = boxes_grid[s, rows, cols] # [N, 4] (cx,cy,w,h) + pred = boxes_grid[s, rows, cols] # [N, 4] (cx,cy,w,h) gt = torch.stack([tgt[:, 0], tgt[:, 1], tgt[:, 2], tgt[:, 3]], dim=1) per_sample.append(iou_2d_axis_aligned(pred, gt).detach()) return per_sample diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py index fd544bf3..40c79b1c 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py @@ -12,10 +12,10 @@ # task is to detect axis-aligned 2D boxes around object clusters. # # Per sample: -# * cloud: [M, 2] float32 (x, y) — genuinely 2D (the studio viewer renders -# it top-down; no z channel, so it is treated as a 2D cloud). -# * target: [N, 6] float32 = [cx, cy, dx, dy, class_id, confidence] -# (metric units; 2D box schema — exactly 6 columns). +# * cloud: [M, 2] float32 (x, y) — genuinely 2D (the studio viewer renders +# it top-down; no z channel, so it is treated as a 2D cloud). +# * target: [N, 6] float32 = [cx, cy, dx, dy, class_id, confidence] +# (metric units; 2D box schema — exactly 6 columns). # # task_type "detection_pointcloud" is shared with the 3D example; the box-row # column count (<= 6) is what marks this as 2D. @@ -29,7 +29,7 @@ PAD_VALUE = -1000.0 # Typical (length, width) per class for the generator. -_CLASS_DIMS = np.array([[3.6, 1.7], [0.7, 0.7]], dtype=np.float32) # Vehicle, Pedestrian +_CLASS_DIMS = np.array([[3.6, 1.7], [0.7, 0.7]], dtype=np.float32) # Vehicle, Pedestrian def _sample_rect_perimeter(rng, dims, n): @@ -74,7 +74,7 @@ def generate_synthetic_scene(seed, pc_range): n_pts = int(np.clip(400.0 / (1.0 + dist / 6.0), 20, 200)) local = _sample_rect_perimeter(rng, dims, n_pts) world = local + np.array([cx, cy], dtype=np.float32) - world += rng.normal(0.0, 0.03, world.shape).astype(np.float32) # sensor noise + world += rng.normal(0.0, 0.03, world.shape).astype(np.float32) # sensor noise clouds.append(world) boxes.append([cx, cy, dims[0], dims[1], float(cls), 1.0]) @@ -96,7 +96,7 @@ def __init__( max_samples=None, seed=0, thumbnail_projection="bev", - **_ignored, # tolerate shared kwargs (kitti_*, extra_features) for parity + **_ignored, # tolerate shared kwargs (kitti_*, extra_features) for parity ): super().__init__() self.split = split diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py index 2ea4e3b9..50df19b2 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py @@ -3,13 +3,13 @@ # ============================================================================= # The 2D analogue of the 3D PointPillars-lite, with z and yaw dropped: # -# 1. Point Feature Net: points are binned into grid cells on the (x, y) plane; -# each point gets 6 features (x, y, offsets to the cell's point mean, -# offsets to the cell center), runs a shared Linear+BN+ReLU, and is -# max-pooled per cell -> a [C, H, W] feature image. -# 2. A tiny 2D CNN backbone. -# 3. A YOLO-style grid head: each S x S cell predicts ONE 2D box -# (objectness, tx, ty, log w, log h, class_logits...). +# 1. Point Feature Net: points are binned into grid cells on the (x, y) plane; +# each point gets 6 features (x, y, offsets to the cell's point mean, +# offsets to the cell center), runs a shared Linear+BN+ReLU, and is +# max-pooled per cell -> a [C, H, W] feature image. +# 2. A tiny 2D CNN backbone. +# 3. A YOLO-style grid head: each S x S cell predicts ONE 2D box +# (objectness, tx, ty, log w, log h, class_logits...). # # decode_grid_2d turns logits into metric (cx, cy, w, h) boxes. import math @@ -24,8 +24,8 @@ def decode_grid_2d(outputs, grid_size, pc_range): """Decode raw grid logits -> per-cell 2D boxes, objectness, class probs. Returns: - boxes: [B, S, S, 4] (cx, cy, w, h) in meters - obj: [B, S, S] objectness probability + boxes: [B, S, S, 4] (cx, cy, w, h) in meters + obj: [B, S, S] objectness probability cls_probs: [B, S, S, num_classes] """ B, S = outputs.shape[0], grid_size @@ -35,7 +35,7 @@ def decode_grid_2d(outputs, grid_size, pc_range): obj = torch.sigmoid(outputs[..., 0]) tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) - dims = torch.exp(outputs[..., 3:5].clamp(-4.0, 4.0)) # (w_x, w_y), meters + dims = torch.exp(outputs[..., 3:5].clamp(-4.0, 4.0)) # (w_x, w_y), meters cls_probs = torch.softmax(outputs[..., 5:], dim=-1) cols = torch.arange(S, device=device).view(1, 1, S).expand(B, S, S) @@ -62,7 +62,7 @@ def __init__(self, num_classes=2, pc_range=DEFAULT_PC_RANGE, voxel_size=0.5, x_min, y_min, _, x_max, y_max, _ = self.pc_range self.nx = int(round((x_max - x_min) / voxel_size)) self.ny = int(round((y_max - y_min) / voxel_size)) - self.preds_per_cell = 5 + num_classes # obj + (tx,ty,log w,log h) + classes + self.preds_per_cell = 5 + num_classes # obj + (tx,ty,log w,log h) + classes self.pfn_channels = pfn_channels self.pfn = nn.Sequential( @@ -109,7 +109,7 @@ def _augment_points(self, points): cy = y_min + (iy.to(pts.dtype) + 0.5) * self.voxel_size f_center = torch.stack([pts[:, 0] - cx, pts[:, 1] - cy], dim=1) - feats = torch.cat([pts[:, :2], f_cluster, f_center], dim=1) # [M, 6] + feats = torch.cat([pts[:, :2], f_cluster, f_center], dim=1) # [M, 6] return feats, flat def _scatter_to_canvas(self, point_feats, flat): diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py index d0d41eb7..7dfc79ae 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py @@ -59,32 +59,32 @@ def render_thumbnail_2d(self, points): # You can customize parameters here (resolution, FOV, rendering mode): return point_cloud_to_range_image( points, - image_height=80, # Custom height (default 64) - image_width=512, # Custom width (default 512, like KITTI) - fov_up=3.0, # Max elevation angle in degrees - fov_down=-25.0, # Min elevation angle (typical LiDAR) - mode="distance+intensity", # or "distance", "intensity" + image_height=80, # Custom height (default 64) + image_width=512, # Custom width (default 512, like KITTI) + fov_up=3.0, # Max elevation angle in degrees + fov_down=-25.0, # Min elevation angle (typical LiDAR) + mode="distance+intensity", # or "distance", "intensity" ) # Optional: override box projection for your custom 2D frame. # Uncomment if needed: # # def project_boxes_2d(self, boxes_3d): - # """Custom box projection to your 2D frame. + # """Custom box projection to your 2D frame. # - # Args: - # boxes_3d: [N, C] where C >= 7 is 3D ([cx,cy,cz,dx,dy,dz,yaw,...]) - # or C <= 6 is 2D ([cx,cy,dx,dy,...]) + # Args: + # boxes_3d: [N, C] where C >= 7 is 3D ([cx,cy,cz,dx,dy,dz,yaw,...]) + # or C <= 6 is 2D ([cx,cy,dx,dy,...]) # - # Returns: - # [N, 6] normalized xyxy boxes [x1, y1, x2, y2, class_id, confidence] - # in [0, 1] range (image coordinates, y down). - # """ - # from weightslab.data.point_cloud_utils import project_boxes_to_bev, get_pc_range - # # For now, just use the standard BEV projection as fallback. - # # Implement your custom projection here. - # pc_range = get_pc_range(self) - # return project_boxes_to_bev(boxes_3d, pc_range) + # Returns: + # [N, 6] normalized xyxy boxes [x1, y1, x2, y2, class_id, confidence] + # in [0, 1] range (image coordinates, y down). + # """ + # from weightslab.data.point_cloud_utils import project_boxes_to_bev, get_pc_range + # # For now, just use the standard BEV projection as fallback. + # # Implement your custom projection here. + # pc_range = get_pc_range(self) + # return project_boxes_to_bev(boxes_3d, pc_range) # ============================================================================= @@ -105,7 +105,7 @@ def train(loader, model, optimizer, sig, device, grid_size, pc_range, conf_thres targets = [t.to(device) for t in targets] optimizer.zero_grad() - outputs = model(points) # [B, S, S, 9 + num_classes] + outputs = model(points) # [B, S, S, 9 + num_classes] # Decoded 3D boxes (detached — stored alongside the loss for analysis). preds = decode_predictions( @@ -147,7 +147,7 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load loss = float((losses / test_loader_len).detach().cpu().item()) iou = float((ious / test_loader_len).detach().cpu().item()) - return loss, iou * 100.0 # Return mean BEV IoU as percentage + return loss, iou * 100.0 # Return mean BEV IoU as percentage # ============================================================================= @@ -168,7 +168,7 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load parameters.setdefault("training_steps_to_do", 500) parameters.setdefault("eval_full_to_train_steps_ratio", 50) parameters.setdefault("number_of_workers", 4) - parameters.setdefault("num_classes", 3) # Car, Pedestrian, Cyclist + parameters.setdefault("num_classes", 3) # Car, Pedestrian, Cyclist parameters.setdefault("point_cloud_range", list(DEFAULT_PC_RANGE)) parameters.setdefault("voxel_size", 0.5) parameters.setdefault("grid_size", 32) @@ -340,7 +340,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): class_counts = np.zeros(num_classes, dtype=np.float64) num_samples = min(len(dataset), max_samples) - for idx in tqdm.tqdm(range(num_samples), desc="📊 Analyzing Distribution"): + for idx in tqdm.tqdm(range(num_samples), desc=" Analyzing Distribution"): _, _, target, _ = dataset.get_items(idx, include_labels=True) if target is None or len(target) == 0: continue @@ -348,10 +348,10 @@ def compute_class_weights(dataset, num_classes, max_samples=200): if 0 <= c < num_classes: class_counts[c] += 1 - class_counts = np.maximum(class_counts, 1) # Avoid div by zero + class_counts = np.maximum(class_counts, 1) # Avoid div by zero total = class_counts.sum() class_weights = total / (num_classes * class_counts) - class_weights = class_weights / class_weights.mean() # Normalize + class_weights = class_weights / class_weights.mean() # Normalize print("\nClass distribution and weights:", flush=True) for c in range(num_classes): @@ -372,18 +372,18 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("=" * 60) - print("🚀 STARTING LIDAR 3D DETECTION TRAINING (PointPillars-lite)") - print(f"📡 Data source: {_train_dataset.source} " + print(" STARTING LIDAR 3D DETECTION TRAINING (PointPillars-lite)") + print(f" Data source: {_train_dataset.source} " f"({len(_train_dataset)} train / {len(_val_dataset)} val frames)") - print(f"📈 Total steps: {max_steps}") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") - print(f"💾 Logs will be saved to: {log_dir}") - print(f"📂 Data root: {data_root}") + print(f" Total steps: {max_steps}") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") + print(f" Logs will be saved to: {log_dir}") + print(f" Data root: {data_root}") print("=" * 60 + "\n") # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. # ================ train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() @@ -399,7 +399,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): # Test if age == 0 or age % eval_full_to_train_steps_ratio == 0: - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating") if tqdm_display else test_loader test_loss, test_metric = test( test_loader_it, model, test_sig, device, @@ -423,8 +423,8 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py index 74809175..7a34da0f 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py @@ -13,25 +13,25 @@ # coordinates. Each GT box is assigned to the BEV grid cell containing its # (cx, cy) center; that cell is "responsible" for predicting the box. # -# * PerSampleDetection3DLoss -> one differentiable loss scalar per sample -# ([B]), wrapped with ``per_sample=True`` (the value WL backprops + -# dashboards). -# * PerSampleBevIoU -> mean BEV IoU over a sample's boxes ([B]). -# * PerInstanceBevIoU -> flat tensor of one IoU per GT box -# (sample-major order), wrapped with ``per_instance=True`` so WL auto-saves -# it at (sample_id, annotation_id). The ordering matches the per-sample -# target iteration, so the wrapper's auto ``batch_idx`` maps each value -# correctly. +# * PerSampleDetection3DLoss -> one differentiable loss scalar per sample +# ([B]), wrapped with ``per_sample=True`` (the value WL backprops + +# dashboards). +# * PerSampleBevIoU -> mean BEV IoU over a sample's boxes ([B]). +# * PerInstanceBevIoU -> flat tensor of one IoU per GT box +# (sample-major order), wrapped with ``per_instance=True`` so WL auto-saves +# it at (sample_id, annotation_id). The ordering matches the per-sample +# target iteration, so the wrapper's auto ``batch_idx`` maps each value +# correctly. # # The IoU metric is axis-aligned in the BEV plane (yaw ignored) — a cheap, # dependency-free proxy for rotated-box IoU that is monotone enough to rank # samples / instances in the dashboards. _EPS = 1e-6 -_LAMBDA_COORD = 2.0 # x, y, z localization -_LAMBDA_SIZE = 1.0 # log-dims -_LAMBDA_YAW = 1.0 # sin / cos regression -_LAMBDA_NOOBJ = 0.5 # empty-cell objectness down-weighting +_LAMBDA_COORD = 2.0 # x, y, z localization +_LAMBDA_SIZE = 1.0 # log-dims +_LAMBDA_YAW = 1.0 # sin / cos regression +_LAMBDA_NOOBJ = 0.5 # empty-cell objectness down-weighting def bev_iou_axis_aligned(a, b): @@ -55,14 +55,14 @@ def _responsible_cells(boxes, grid_size, pc_range): """Map GT boxes -> their responsible BEV (row, col) cell and cell offsets. Args: - boxes: [N, 9] target rows (metric). + boxes: [N, 9] target rows (metric). grid_size: S. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). Returns: - rows, cols: [N] long, the responsible cell indices. - off_x, off_y: [N] center offset within the cell, in [0, 1). - z_t: [N] z center normalized to [0, 1] over the z range. + rows, cols: [N] long, the responsible cell indices. + off_x, off_y: [N] center offset within the cell, in [0, 1). + z_t: [N] z center normalized to [0, 1] over the z range. """ x_min, y_min, z_min, x_max, y_max, z_max = pc_range S = grid_size @@ -82,14 +82,14 @@ def _per_sample_loss(outputs, targets, num_classes, grid_size, pc_range, weights B, S = outputs.shape[0], grid_size device = outputs.device - obj_logit = outputs[..., 0] # [B, S, S] + obj_logit = outputs[..., 0] # [B, S, S] tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) tz = torch.sigmoid(outputs[..., 3]) - log_dims = outputs[..., 4:7] # [B, S, S, 3] + log_dims = outputs[..., 4:7] # [B, S, S, 3] t_sin = outputs[..., 7] t_cos = outputs[..., 8] - cls_logits = outputs[..., 9:] # [B, S, S, C] + cls_logits = outputs[..., 9:] # [B, S, S, C] if weights is not None: weights = torch.as_tensor(weights, device=device, dtype=outputs.dtype) @@ -162,7 +162,7 @@ def _per_box_bev_iou(outputs, targets, grid_size, pc_range): Returns a list[B] of 1-D tensors (one IoU per box for that sample, in annotation order). Detached — this is a metric, not a loss. """ - boxes_grid, _, _ = decode_grid_3d(outputs, grid_size, pc_range) # [B, S, S, 7] + boxes_grid, _, _ = decode_grid_3d(outputs, grid_size, pc_range) # [B, S, S, 7] B, S = outputs.shape[0], grid_size device = outputs.device @@ -176,7 +176,7 @@ def _per_box_bev_iou(outputs, targets, grid_size, pc_range): tgt = tgt.view(-1, 9) rows, cols, _, _, _ = _responsible_cells(tgt, S, pc_range) - pred = boxes_grid[s, rows, cols] # [N, 7] + pred = boxes_grid[s, rows, cols] # [N, 7] pred_bev = torch.stack( [pred[:, 0], pred[:, 1], pred[:, 3], pred[:, 4]], dim=1) gt_bev = torch.stack( @@ -254,8 +254,8 @@ def decode_predictions(outputs, grid_size, pc_range, conf_thresh=0.3, max_det=20 boxes_grid, obj, cls_probs = decode_grid_3d(outputs, grid_size, pc_range) B, S = outputs.shape[0], grid_size - cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] - score = obj * cls_conf # combined confidence + cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] + score = obj * cls_conf # combined confidence flat_boxes = boxes_grid.view(B, S * S, 7) flat_score = score.view(B, S * S) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py index 45a15a0a..7abb9f81 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py @@ -11,23 +11,23 @@ # ============================================================================= # Self-driving 3D detection over LiDAR point clouds. Two sources: # -# * "kitti": the KITTI 3D Object Detection benchmark. Expected layout -# (download from https://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d): -# /kitti/training/velodyne/000000.bin ... (x, y, z, intensity float32) -# /kitti/training/label_2/000000.txt ... (camera-frame 3D boxes) -# /kitti/training/calib/000000.txt ... (velo->cam calibration) -# * "synthetic": procedurally generated road scenes (ground plane + car / -# pedestrian / cyclist point clusters). Lets the example run -# out-of-the-box with zero download; useful to validate the -# whole WL pipeline before pointing it at real data. +# * "kitti": the KITTI 3D Object Detection benchmark. Expected layout +# (download from https://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d): +# /kitti/training/velodyne/000000.bin ... (x, y, z, intensity float32) +# /kitti/training/label_2/000000.txt ... (camera-frame 3D boxes) +# /kitti/training/calib/000000.txt ... (velo->cam calibration) +# * "synthetic": procedurally generated road scenes (ground plane + car / +# pedestrian / cyclist point clusters). Lets the example run +# out-of-the-box with zero download; useful to validate the +# whole WL pipeline before pointing it at real data. # # Per-sample target is a [N, 9] float32 array, one row per ground-truth box, # all in the LiDAR (velodyne) frame, metric units: # -# [cx, cy, cz, dx, dy, dz, yaw, class_id, confidence] +# [cx, cy, cz, dx, dy, dz, yaw, class_id, confidence] # -# cx/cy/cz: box center (m); dx/dy/dz: size along the object's x/y/z axes -# (length, width, height); yaw: rotation around +z; GT confidence = 1.0. +# cx/cy/cz: box center (m); dx/dy/dz: size along the object's x/y/z axes +# (length, width, height); yaw: rotation around +z; GT confidence = 1.0. CLASS_NAMES = ["Car", "Pedestrian", "Cyclist"] @@ -42,14 +42,14 @@ # Typical (length, width, height) per class, used by the synthetic generator. _CLASS_DIMS = np.array( [ - [4.0, 1.7, 1.5], # Car - [0.8, 0.6, 1.75], # Pedestrian - [1.8, 0.6, 1.7], # Cyclist + [4.0, 1.7, 1.5], # Car + [0.8, 0.6, 1.75], # Pedestrian + [1.8, 0.6, 1.7], # Cyclist ], dtype=np.float32, ) -_GROUND_Z = -1.7 # LiDAR is mounted ~1.7 m above the road in KITTI. +_GROUND_Z = -1.7 # LiDAR is mounted ~1.7 m above the road in KITTI. # ============================================================================= @@ -78,7 +78,7 @@ def read_kitti_calib(path): m[:3, :4] = vals.reshape(3, 4) mats["Tr_velo_to_cam"] = m elif key.strip() == "P2": - mats["P2"] = vals.reshape(3, 4) # left colour camera projection + mats["P2"] = vals.reshape(3, 4) # left colour camera projection return mats @@ -89,10 +89,10 @@ def project_velo_to_image(points_xyz, calib): the camera (positive depth). Used to colourise the cloud from image_2. """ n = points_xyz.shape[0] - homo = np.concatenate([points_xyz, np.ones((n, 1))], axis=1) # [N, 4] - cam = (calib["R0_rect"] @ calib["Tr_velo_to_cam"] @ homo.T) # [4, N] + homo = np.concatenate([points_xyz, np.ones((n, 1))], axis=1) # [N, 4] + cam = (calib["R0_rect"] @ calib["Tr_velo_to_cam"] @ homo.T) # [4, N] depth = cam[2] - pix = calib["P2"] @ cam # [3, N] + pix = calib["P2"] @ cam # [3, N] valid = depth > 1e-3 uv = np.zeros((n, 2), dtype=np.float32) uv[valid] = (pix[:2, valid] / pix[2, valid]).T @@ -113,7 +113,7 @@ def _read_kitti_kv_file(path): try: out[key.strip()] = np.array([float(v) for v in vals.split()], dtype=np.float64) except ValueError: - pass # non-numeric header lines (calib_time, etc.) + pass # non-numeric header lines (calib_time, etc.) return out @@ -158,7 +158,7 @@ def parse_tracklets(xml_path): for tracklet in root.iter("item"): otype = tracklet.findtext("objectType") if otype is None or otype not in _TRACKLET_CLASS_MAP: - continue # not a tracklet item, or a class we don't keep + continue # not a tracklet item, or a class we don't keep h = tracklet.findtext("h"); w = tracklet.findtext("w"); l = tracklet.findtext("l") first = tracklet.findtext("first_frame") poses = tracklet.find("poses") @@ -207,8 +207,8 @@ def read_kitti_label(label_path, calib, pc_range): ry = float(parts[14]) center = _cam_to_velo(loc_cam, calib)[0] - center[2] += h / 2.0 # KITTI location is the bottom face center - yaw = -ry - np.pi / 2.0 # camera rotation_y -> velo-frame yaw + center[2] += h / 2.0 # KITTI location is the bottom face center + yaw = -ry - np.pi / 2.0 # camera rotation_y -> velo-frame yaw x_min, y_min, z_min, x_max, y_max, z_max = pc_range if not (x_min <= center[0] <= x_max and y_min <= center[1] <= y_max @@ -225,14 +225,14 @@ def read_kitti_label(label_path, calib, pc_range): def _sample_box_surface(rng, dims, n): """Uniformly sample n points on the surface of an axis-aligned box at origin.""" l, w, h = dims - areas = np.array([w * h, w * h, l * h, l * h, l * w, l * w]) # +-x, +-y, +-z faces + areas = np.array([w * h, w * h, l * h, l * h, l * w, l * w]) # +-x, +-y, +-z faces face = rng.choice(6, size=n, p=areas / areas.sum()) u = rng.uniform(-0.5, 0.5, size=n) v = rng.uniform(-0.5, 0.5, size=n) pts = np.zeros((n, 3), dtype=np.float32) sign = np.where(face % 2 == 0, 0.5, -0.5) - ax = face // 2 # 0: x faces, 1: y faces, 2: z faces + ax = face // 2 # 0: x faces, 1: y faces, 2: z faces pts[ax == 0] = np.stack( [sign[ax == 0] * l, u[ax == 0] * w, v[ax == 0] * h], axis=1) pts[ax == 1] = np.stack( @@ -307,15 +307,15 @@ class Lidar3DDetectionDataset(Dataset): """LiDAR 3D box detection over KITTI scans or synthetic scenes. Args: - root: data directory (expects /kitti/training/* for KITTI). - split: "train" or "val" (deterministic split). - source: "kitti", "synthetic", or "auto" (kitti if present on disk). - num_classes: how many of CLASS_NAMES to keep. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop, meters. - max_points: random subsample cap per cloud (speed / memory). + root: data directory (expects /kitti/training/* for KITTI). + split: "train" or "val" (deterministic split). + source: "kitti", "synthetic", or "auto" (kitti if present on disk). + num_classes: how many of CLASS_NAMES to keep. + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop, meters. + max_points: random subsample cap per cloud (speed / memory). num_synthetic: number of generated scenes when source is synthetic. - val_fraction: fraction of frames held out for validation. - max_samples: optional cap on the split size (for quick runs). + val_fraction: fraction of frames held out for validation. + max_samples: optional cap on the split size (for quick runs). """ def __init__( @@ -353,9 +353,9 @@ def __init__( # Per-point channels. xyz + intensity are always present (the model # consumes the first 4 columns); ``extra_features`` appends extra # VISUALISATION-only channels the studio viewer can colour/shade by: - # "normals" -> nx, ny, nz (PCA over neighbours) - # "rgb" -> r, g, b (camera image projection; KITTI only, - # synthetic falls back to a height pseudo-colour) + # "normals" -> nx, ny, nz (PCA over neighbours) + # "rgb" -> r, g, b (camera image projection; KITTI only, + # synthetic falls back to a height pseudo-colour) self.extra_features = tuple(str(f).strip().lower() for f in (extra_features or ())) # Real KITTI drives ship camera images + calibration, so colourise by # default (set extra_features explicitly to override, e.g. [] or [normals]). @@ -407,7 +407,7 @@ def __init__( download_dir = kitti_download_dir or default_download_dir() drives = list(kitti_raw_drives) or ["drive_0001"] frames = [] - self._raw_tracklets = {} # drive -> {frame_index: [N, 9] GT boxes} + self._raw_tracklets = {} # drive -> {frame_index: [N, 9] GT boxes} for drive in drives: if download: self._raw_date_dir = ensure_sequence(kitti_raw_date, drive, dest_dir=download_dir) @@ -513,7 +513,7 @@ def _enrich_features(self, points, calib, image_path): """Append the configured visualisation channels (normals, rgb) to [M, 4].""" if points.shape[0] == 0 or not self.extra_features: return points.astype(np.float32) - channels = [points[:, :4]] # x, y, z, intensity (always) + channels = [points[:, :4]] # x, y, z, intensity (always) if "normals" in self.extra_features: from weightslab.data.point_cloud_utils import compute_point_normals @@ -536,7 +536,7 @@ def _point_rgb(self, points, calib, image_path): points[:, :3], image, lambda p: project_velo_to_image(p, calib)) except Exception: - pass # fall through to pseudo-colour + pass # fall through to pseudo-colour # Synthetic / no image: pseudo-colour from height so the channel is useful. z_min, z_max = self.pc_range[2], self.pc_range[5] @@ -546,9 +546,9 @@ def _point_rgb(self, points, calib, image_path): def __getitem__(self, idx): """Returns (item, uid, target, metadata). - - item: point cloud FloatTensor [M, 4] (x, y, z, intensity) - - uid: unique sample id (string) - - target: [N, 9] float32 = [cx, cy, cz, dx, dy, dz, yaw, cls, conf] + - item: point cloud FloatTensor [M, 4] (x, y, z, intensity) + - uid: unique sample id (string) + - target: [N, 9] float32 = [cx, cy, cz, dx, dy, dz, yaw, cls, conf] - metadata: dict with source paths / generation seed """ return self.get_items(idx, include_metadata=True, include_labels=True, include_images=True) @@ -571,10 +571,10 @@ def lidar_collate(batch): layout WL's per-instance helpers expect. Returns: - points: FloatTensor [B, M_max, 4] - ids: list[str] of length B + points: FloatTensor [B, M_max, 4] + ids: list[str] of length B targets: list[B] of [N_i, 9] float tensors - metas: list[B] of metadata dicts + metas: list[B] of metadata dicts """ clouds = [ b[0] if isinstance(b[0], torch.Tensor) else torch.as_tensor(b[0], dtype=torch.float32) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py index 372e8a1d..e70c63df 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py @@ -13,8 +13,8 @@ calib_velo_to_cam.txt calib_imu_to_velo.txt __sync/ - velodyne_points/data/0000000000.bin ... (x, y, z, reflectance float32) - image_02/data/0000000000.png ... (left colour camera) + velodyne_points/data/0000000000.bin ... (x, y, z, reflectance float32) + image_02/data/0000000000.png ... (left colour camera) ... Downloads stream to disk with a tqdm progress bar and are idempotent (a @@ -149,8 +149,8 @@ def ensure_sequence(date, drive, dest_dir=None, keep_zip=False): """Download + extract one raw sequence (idempotent). Returns the date dir. Args: - date: e.g. "2011_09_26". - drive: e.g. "drive_0001". + date: e.g. "2011_09_26". + drive: e.g. "drive_0001". dest_dir: where to download/extract (default: a temp dir). keep_zip: keep the downloaded .zip after extraction (default: delete). @@ -162,7 +162,7 @@ def ensure_sequence(date, drive, dest_dir=None, keep_zip=False): seq_dir = os.path.join(dest_dir, date, f"{date}_{drive}_sync") if os.path.isdir(os.path.join(seq_dir, "velodyne_points", "data")): - return os.path.join(dest_dir, date) # already extracted + return os.path.join(dest_dir, date) # already extracted filename = f"{date}_{drive}_sync.zip" zip_path = os.path.join(dest_dir, filename) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py index 6cf83231..0af99b0d 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py @@ -4,24 +4,24 @@ # Three stages, following the PointPillars recipe (Lang et al., CVPR 2019) but # heavily slimmed down: # -# 1. Pillar Feature Net: points are grouped into vertical columns ("pillars") -# on a BEV grid; each point gets 9 features (x, y, z, intensity, offsets -# to the pillar's point mean, offsets to the pillar center), runs through -# a shared Linear+BN+ReLU, and is max-pooled per pillar -> a sparse -# [C, H, W] BEV pseudo-image. -# 2. A tiny 2D CNN backbone over the BEV pseudo-image (2 stride-2 blocks). -# 3. A YOLO-v1-style grid head: each S x S BEV cell predicts ONE 3D box: -# (objectness, tx, ty, tz, log l, log w, log h, sin yaw, cos yaw, -# class_logits...). +# 1. Pillar Feature Net: points are grouped into vertical columns ("pillars") +# on a BEV grid; each point gets 9 features (x, y, z, intensity, offsets +# to the pillar's point mean, offsets to the pillar center), runs through +# a shared Linear+BN+ReLU, and is max-pooled per pillar -> a sparse +# [C, H, W] BEV pseudo-image. +# 2. A tiny 2D CNN backbone over the BEV pseudo-image (2 stride-2 blocks). +# 3. A YOLO-v1-style grid head: each S x S BEV cell predicts ONE 3D box: +# (objectness, tx, ty, tz, log l, log w, log h, sin yaw, cos yaw, +# class_logits...). # # Encoding (BEV cell-relative, mirrors the 2D ws-detection example): -# * objectness = sigmoid(t_obj) -> P(box centered in cell) -# * cx = x_min + (col + sigmoid(tx)) / S * range_x -# * cy = y_min + (row + sigmoid(ty)) / S * range_y -# * cz = z_min + sigmoid(tz) * range_z -# * (l, w, h) = exp(t_l, t_w, t_h) -> size in meters -# * yaw = atan2(t_sin, t_cos) -# * class = softmax(class_logits) +# * objectness = sigmoid(t_obj) -> P(box centered in cell) +# * cx = x_min + (col + sigmoid(tx)) / S * range_x +# * cy = y_min + (row + sigmoid(ty)) / S * range_y +# * cz = z_min + sigmoid(tz) * range_z +# * (l, w, h) = exp(t_l, t_w, t_h) -> size in meters +# * yaw = atan2(t_sin, t_cos) +# * class = softmax(class_logits) # # Raw forward output keeps logits (the loss applies activations); `decode_grid_3d` # turns logits into metric 3D boxes for metrics and prediction dumps. @@ -39,13 +39,13 @@ def decode_grid_3d(outputs, grid_size, pc_range): Shared by the model and the criterions so the encoding lives in one place. Args: - outputs: [B, S, S, 9 + num_classes] raw logits. + outputs: [B, S, S, 9 + num_classes] raw logits. grid_size: S. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). Returns: - boxes: [B, S, S, 7] (cx, cy, cz, l, w, h, yaw) in meters - obj: [B, S, S] objectness probability + boxes: [B, S, S, 7] (cx, cy, cz, l, w, h, yaw) in meters + obj: [B, S, S] objectness probability cls_probs: [B, S, S, num_classes] class probabilities """ B, S = outputs.shape[0], grid_size @@ -56,7 +56,7 @@ def decode_grid_3d(outputs, grid_size, pc_range): tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) tz = torch.sigmoid(outputs[..., 3]) - dims = torch.exp(outputs[..., 4:7].clamp(-4.0, 4.0)) # (l, w, h), meters + dims = torch.exp(outputs[..., 4:7].clamp(-4.0, 4.0)) # (l, w, h), meters yaw = torch.atan2(outputs[..., 7], outputs[..., 8]) cls_probs = torch.softmax(outputs[..., 9:], dim=-1) @@ -94,11 +94,11 @@ def __init__( self.pc_range = tuple(pc_range) self.voxel_size = float(voxel_size) self.pad_value = float(pad_value) - self.input_shape = (1, 4096, 4) # padded cloud [B, M, 4] for summaries + self.input_shape = (1, 4096, 4) # padded cloud [B, M, 4] for summaries x_min, y_min, _, x_max, y_max, _ = self.pc_range - self.nx = int(round((x_max - x_min) / voxel_size)) # BEV canvas cols - self.ny = int(round((y_max - y_min) / voxel_size)) # BEV canvas rows + self.nx = int(round((x_max - x_min) / voxel_size)) # BEV canvas cols + self.ny = int(round((y_max - y_min) / voxel_size)) # BEV canvas rows # Channels per head cell: obj(1) + box(8: tx ty tz, log lwh, sin cos) # + class logits(num_classes) @@ -176,7 +176,7 @@ def _augment_points(self, points): cy = y_min + (iy.to(pts.dtype) + 0.5) * self.voxel_size f_center = torch.stack([pts[:, 0] - cx, pts[:, 1] - cy], dim=1) - feats = torch.cat([pts, f_cluster, f_center], dim=1) # [M, 9] + feats = torch.cat([pts, f_cluster, f_center], dim=1) # [M, 9] return feats, flat def _scatter_to_canvas(self, point_feats, pillar_idx): @@ -217,9 +217,9 @@ def forward(self, points): else: canvases.append(self._scatter_to_canvas(chunks.pop(0), flat)) - x = torch.stack(canvases, dim=0) # [B, C, ny, nx] - x = self.backbone(x) # [B, 128, ny/4, nx/4] - out = self.head(x) # [B, preds_per_cell, S', S'] + x = torch.stack(canvases, dim=0) # [B, C, ny, nx] + x = self.backbone(x) # [B, 128, ny/4, nx/4] + out = self.head(x) # [B, preds_per_cell, S', S'] # Resize the feature grid to the configured head grid_size. if out.shape[-1] != self.grid_size or out.shape[-2] != self.grid_size: diff --git a/weightslab/integrations/ultralytics/__init__.py b/weightslab/integrations/ultralytics/__init__.py index 7bb3989e..15f7c57e 100644 --- a/weightslab/integrations/ultralytics/__init__.py +++ b/weightslab/integrations/ultralytics/__init__.py @@ -21,8 +21,8 @@ YOLO(cfg["model"]).train( trainer=WLAwareTrainer, data=cfg["data_root"], imgsz=640, epochs=1000, batch=4, - project="./logs", name="exp", # → WL log_dir/name - workers=0, # WL invariant (parent-process uid counter) + project="./logs", name="exp", # → WL log_dir/name + workers=0, # WL invariant (parent-process uid counter) ) wl.keep_serving() diff --git a/weightslab/integrations/ultralytics/dataset.py b/weightslab/integrations/ultralytics/dataset.py index 4d819b59..ff4be229 100644 --- a/weightslab/integrations/ultralytics/dataset.py +++ b/weightslab/integrations/ultralytics/dataset.py @@ -60,9 +60,9 @@ def fast_get_label(self, i): if shp is None: from PIL import Image as _PIL with _PIL.open(lab["im_file"]) as im: - w0, h0 = im.size # PIL: (w, h) + w0, h0 = im.size # PIL: (w, h) shp = (h0, w0) - lab["shape"] = shp # memoize + lab["shape"] = shp # memoize h0, w0 = shp new = self.imgsz r = min(new / h0, new / w0) diff --git a/weightslab/integrations/ultralytics/signals.py b/weightslab/integrations/ultralytics/signals.py index c1610a8f..e6a733f7 100644 --- a/weightslab/integrations/ultralytics/signals.py +++ b/weightslab/integrations/ultralytics/signals.py @@ -47,12 +47,12 @@ # fallbacks kick in only if `model.args.{conf,iou}` is `None` — which # happens for training because UL only auto-populates those for predict. # -# OVERLAY_CONF_FALLBACK — tiny so early-epoch overlays aren't empty. -# UL's predict default of 0.25 would hide the -# model entirely while it's still learning. -# OVERLAY_IOU_FALLBACK — matches UL's default inference IoU. -# OVERLAY_MAX_DETS — readability cap; UL's NMS otherwise produces -# up to 300 boxes per image, flooding the studio. +# OVERLAY_CONF_FALLBACK — tiny so early-epoch overlays aren't empty. +# UL's predict default of 0.25 would hide the +# model entirely while it's still learning. +# OVERLAY_IOU_FALLBACK — matches UL's default inference IoU. +# OVERLAY_MAX_DETS — readability cap; UL's NMS otherwise produces +# up to 300 boxes per image, flooding the studio. OVERLAY_CONF_FALLBACK = 1e-4 OVERLAY_IOU_FALLBACK = 0.45 OVERLAY_MAX_DETS = 50 @@ -157,7 +157,7 @@ class Signal: `preds=` kwarg. """ name: str - flag: str # "loss" | "metric" + flag: str # "loss" | "metric" reduce: Callable[[dict], Optional[th.Tensor]] preds: Optional[Callable[[dict], Optional[dict]]] = None @@ -211,8 +211,8 @@ def install_val_pipeline(validator, signals: list[Signal]): channels = _make_channels(signals) _orig = validator.update_metrics def _ship(preds, batch): - validator._wl_preds = preds # exposed to signal reducers/predsers - res = _orig(preds, batch) # runs first — fills _process_batch buf + validator._wl_preds = preds # exposed to signal reducers/predsers + res = _orig(preds, batch) # runs first — fills _process_batch buf _ship_round(signals, channels, batch) return res validator.update_metrics = _ship @@ -255,9 +255,9 @@ def default_train_signals(model, signals_cfg: dict = {}) -> list[Signal]: bl = crit.bbox_loss detect_head = next((m for m in model.modules() if isinstance(m, Detect)), None) - get_bce = fwd_hook(crit.bce) # bce is a plain nn.Module - get_iou = fn_tap(ul_loss, "bbox_iou") # bbox_iou is a plain function - get_dfl = method_call_tap(bl, "dfl_loss") # DFLoss overrides __call__ + get_bce = fwd_hook(crit.bce) # bce is a plain nn.Module + get_iou = fn_tap(ul_loss, "bbox_iou") # bbox_iou is a plain function + get_dfl = method_call_tap(bl, "dfl_loss") # DFLoss overrides __call__ get_bl_args = pre_hook(bl) def _fg_state(): diff --git a/weightslab/integrations/ultralytics/trainer.py b/weightslab/integrations/ultralytics/trainer.py index 95a54dfa..6efc7aaa 100644 --- a/weightslab/integrations/ultralytics/trainer.py +++ b/weightslab/integrations/ultralytics/trainer.py @@ -89,7 +89,7 @@ def _validate(loader): except Exception as e: raised_exc = e finally: - trainer.validator.dataloader = val_loader # Reset val loader + trainer.validator.dataloader = val_loader # Reset val loader # Finally raise exc. if raised_exc is not None: @@ -120,20 +120,20 @@ def _on_val_end(validator): return for ul_key, wl_key in ( ("metrics/precision(B)", "val/precision"), - ("metrics/recall(B)", "val/recall"), - ("metrics/mAP50(B)", "val/mAP50"), - ("metrics/mAP50-95(B)", "val/mAP50-95"), - ("fitness", "val/fitness"), + ("metrics/recall(B)", "val/recall"), + ("metrics/mAP50(B)", "val/mAP50"), + ("metrics/mAP50-95(B)", "val/mAP50-95"), + ("fitness", "val/fitness"), ): if ul_key in rd and wl_key in ch: ch[wl_key](torch.tensor([float(rd[ul_key])])) - self.add_callback("on_train_start", _on_train_start) + self.add_callback("on_train_start", _on_train_start) self.add_callback("on_train_batch_start", _on_train_batch_start) - self.add_callback("on_train_batch_end", _on_train_batch_end) - self.add_callback("on_val_batch_start", _on_val_batch_start) - self.add_callback("on_val_batch_end", _on_val_batch_end) - self.add_callback("on_val_end", _on_val_end) + self.add_callback("on_train_batch_end", _on_train_batch_end) + self.add_callback("on_val_batch_start", _on_val_batch_start) + self.add_callback("on_val_batch_end", _on_val_batch_end) + self.add_callback("on_val_end", _on_val_end) def validate(self): # UL's metrics.process does np.concatenate([]) → ValueError when val diff --git a/weightslab/models/model_with_ops.py b/weightslab/models/model_with_ops.py index c3a7b13e..322abcff 100755 --- a/weightslab/models/model_with_ops.py +++ b/weightslab/models/model_with_ops.py @@ -20,9 +20,9 @@ def __init__(self): # Initialize variables self.current_step = 0 - self.visited_nodes = set() # Memory trace of explored nodes - self.visited_incoming_nodes = set() # Memory trace of explored nodes - self.name = self._get_name() # Name of the model + self.visited_nodes = set() # Memory trace of explored nodes + self.visited_incoming_nodes = set() # Memory trace of explored nodes + self.name = self._get_name() # Name of the model self.linearized_layers = [] self._architecture_change_hook_fns = [] self.tracking_mode = TrackingMode.DISABLED @@ -450,7 +450,7 @@ def _operate( elif current_child_name is not None and current_child_name in module.src_to_dst_mapping_tnsrs: kwargs['current_child_name'] = current_child_name else: - kwargs['current_child_name'] = None # Its child is an Orphan node + kwargs['current_child_name'] = None # Its child is an Orphan node # # Operate module.operate( neuron_indices, diff --git a/weightslab/models/monkey_patcher.py b/weightslab/models/monkey_patcher.py index 8a19ce4c..13e7bd79 100644 --- a/weightslab/models/monkey_patcher.py +++ b/weightslab/models/monkey_patcher.py @@ -74,7 +74,7 @@ def wrapped_forward(self, input): data=input ) return output - module.forward = types.MethodType(wrapped_forward, module) # Monkey patch + module.forward = types.MethodType(wrapped_forward, module) # Monkey patch module.is_leaf = True return module diff --git a/weightslab/modules/modules_with_ops.py b/weightslab/modules/modules_with_ops.py index c4c54c7e..a4b96763 100644 --- a/weightslab/modules/modules_with_ops.py +++ b/weightslab/modules/modules_with_ops.py @@ -46,7 +46,7 @@ def __init__( self.module_name = module_name self.device = device self.tracking_mode = TrackingMode.DISABLED - self.operation_age = {op.name: 0 for op in ArchitectureNeuronsOpType} # keep track of all operations performed + self.operation_age = {op.name: 0 for op in ArchitectureNeuronsOpType} # keep track of all operations performed # IN/OUT neurons indexing & mapping dictionary self.src_to_dst_mapping_tnsrs = {} @@ -77,7 +77,7 @@ def __init__( } # Naming - self.assign_id() # assign ids + self.assign_id() # assign ids # Tracking self.register_trackers() @@ -183,7 +183,7 @@ def __hash__(self) -> int: # Trackers Functions # ================== def register_trackers(self): - is_disabled = bool(getattr(self, "wl_same_flag", False)) # Remove SAME layer like BN from neurons stats ..etc + is_disabled = bool(getattr(self, "wl_same_flag", False)) # Remove SAME layer like BN from neurons stats ..etc # Train if self.get_neurons('out_neurons') is not None: @@ -246,7 +246,7 @@ def get_operation( Callable: The operation function. """ if callable(op_type): - return op_type # if already got, just return the fct + return op_type # if already got, just return the fct elif op_type == ArchitectureNeuronsOpType.ADD or \ op_type == ArchitectureNeuronsOpType.ADD.value: return self._add_neurons @@ -435,7 +435,7 @@ def _process_neurons_indices( elif not isinstance(neuron_indices, set): # If it's a single int, wrap; if it's iterable, cast to set try: - neuron_indices = set(neuron_indices) # type: ignore[arg-type] + neuron_indices = set(neuron_indices) # type: ignore[arg-type] except TypeError: neuron_indices = {neuron_indices} @@ -470,7 +470,7 @@ def _process_neurons_indices( if mapped_indices_dict is not None: mapped_indexs = normalize_dicts( {"mapped": mapped_indices_dict} - )["mapped"] # TODO (GP): Improve this function + )["mapped"] # TODO (GP): Improve this function else: # No mapping tensors available: fall back to identity mapping n_neurons = self.get_neurons( @@ -526,7 +526,7 @@ def register( tracker = self.get_tracker() if tracker is None or activation_map is None or input is None: return - activation_map = (activation_map > 0).long() # bool to int + activation_map = (activation_map > 0).long() # bool to int processed_activation_map = th.sum(activation_map, dim=(-2, -1)) if len(activation_map.shape) > 2 else activation_map copy_forward_tracked_attrs(processed_activation_map, activation_map) tracker.update(processed_activation_map) @@ -995,7 +995,7 @@ def _add_neurons( self.related_dst_to_src_mapping_tnsrs[ current_name ].keys() - )[neuron_indice + -1 + length] # get new index + )[neuron_indice + -1 + length] # get new index # Update the mapping tensor with 1 or range(x) neurons self.related_dst_to_src_mapping_tnsrs[ @@ -1013,7 +1013,7 @@ def _add_neurons( channel_size, mapped_neuron_indice * channel_size ) - ) # in range of x neurons + ) # in range of x neurons ] } ) @@ -1042,13 +1042,13 @@ def _add_neurons( self.super_in_name, self.get_neurons(self.super_in_name) + nb_neurons ) - ) # Update neurons count + ) # Update neurons count elif dependency == DepType.SAME: if self.get_neurons(self.super_out_name) is not None: self.set_neurons( attr_name='in_neurons', new_value=self.get_neurons(self.super_out_name) - ) # Update neurons count + ) # Update neurons count # By default get deps name from current relation deps_names = list(self.dst_to_src_mapping_tnsrs.keys()) @@ -1071,7 +1071,7 @@ def _add_neurons( ) if index >= len(mapped_neuron_indice): logger.warning( - f"Index {index} out of range for " + + f"Index {index} out of range for " + f"mapped_neuron_indice with length " + f"{len(mapped_neuron_indice)}" ) @@ -1086,8 +1086,8 @@ def _add_neurons( self.dst_to_src_mapping_tnsrs[ deps_name ][mapped_neuron_indice] - ) for i in range(0, nb_neurons) # neurons - ] for j in range(0, nb_neurons) # neurons | chan. + ) for i in range(0, nb_neurons) # neurons + ] for j in range(0, nb_neurons) # neurons | chan. } ) @@ -1195,7 +1195,7 @@ def _prune_neurons( f"overlap: {neuron_indices} & {neurons} => " f"{neuron_indices & neurons}" ) - return # Do not change + return # Do not change # # Enough neurons to operate if len(neurons) <= 1: @@ -1221,7 +1221,7 @@ def _prune_neurons( ) self.weight = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if self.weight.grad is not None: with th.no_grad(): @@ -1232,7 +1232,7 @@ def _prune_neurons( ) self.weight.grad = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if hasattr(self, 'bias') and self.bias is not None and \ not is_incoming: @@ -1244,7 +1244,7 @@ def _prune_neurons( ) self.bias.data = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if self.bias.grad is not None: with th.no_grad(): @@ -1255,7 +1255,7 @@ def _prune_neurons( ) self.bias.grad = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if hasattr(self, 'running_mean'): tmp_tsnr = th.index_select( @@ -1263,7 +1263,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_mean = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_mean = tmp_tsnr.clone().detach().to(self.device) # Safe approach if self.running_mean.grad is not None: tmp_tsnr = th.index_select( @@ -1271,7 +1271,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_mean.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_mean.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach if hasattr(self, 'running_var'): tmp_tsnr = th.index_select( @@ -1279,7 +1279,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_var = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_var = tmp_tsnr.clone().detach().to(self.device) # Safe approach if self.running_var.grad is not None: tmp_tsnr = th.index_select( @@ -1287,7 +1287,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_var.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_var.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach # Sort indices to prune from last to first to maintain # the original order @@ -1385,12 +1385,12 @@ def _prune_neurons( self.super_in_name, len(idx_tokeep) ) - ) # Update neurons count + ) # Update neurons count elif dependency == DepType.SAME: self.set_neurons( attr_name='in_neurons', new_value=self.get_neurons(self.super_out_name) - ) # Update neurons count + ) # Update neurons count # By default get deps name from current relation deps_names = self.dst_to_src_mapping_tnsrs.keys() @@ -1494,7 +1494,7 @@ def _freeze_neurons( # Work on the output tensors_name = self.learnable_tensors_name if not is_incoming \ - else ['weight'] # Weight is the only learnable tensor input + else ['weight'] # Weight is the only learnable tensor input for tensor_name in tensors_name: neurons_lr = { neuron_indices[n]: diff --git a/weightslab/proto/experiment_service_pb2.py b/weightslab/proto/experiment_service_pb2.py index 6a521ac8..14d171a0 100644 --- a/weightslab/proto/experiment_service_pb2.py +++ b/weightslab/proto/experiment_service_pb2.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! +# Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: weightslab/proto/experiment_service.proto # Protobuf Python Version: 6.31.1 diff --git a/weightslab/security/cert_auth_manager.py b/weightslab/security/cert_auth_manager.py index f4b5331e..358dd86b 100644 --- a/weightslab/security/cert_auth_manager.py +++ b/weightslab/security/cert_auth_manager.py @@ -193,7 +193,7 @@ def get_or_create_auth_token(self) -> str: try: os.chmod(self.token_file, 0o600) except Exception: - pass # Windows doesn't support chmod + pass # Windows doesn't support chmod logger.info(f"Wrote gRPC auth token to {self.token_file}") except Exception as e: logger.error(f"Could not save token: {e}") diff --git a/weightslab/src.py b/weightslab/src.py index 44470efc..35c1a420 100644 --- a/weightslab/src.py +++ b/weightslab/src.py @@ -43,11 +43,11 @@ def _rebind_caller_local(original_obj: Any, new_obj: Any) -> None: This lets ``wl.watch_or_edit(parameters, ...)`` (without capturing the return value) transparently replace ``parameters`` with the returned Proxy in the - calling scope. Silently does nothing on non-CPython runtimes. + calling scope. Silently does nothing on non-CPython runtimes. """ try: # frame 0 = _rebind_caller_local - # frame 1 = watch_or_edit (or whatever internal caller) + # frame 1 = watch_or_edit (or whatever internal caller) # frame 2 = user code frame = sys._getframe(2) changed = False @@ -217,26 +217,26 @@ def _get_step(step: int | None = None) -> int: if m is not None: # Safe attribute access (handle Proxy returning None for missing attr) if hasattr(m, 'get_age'): - val = m.get_age() -1 # At this point, model already saw one batch, except if we started by evaluation + val = m.get_age() -1 # At this point, model already saw one batch, except if we started by evaluation if val is not None: - step = max([int(val), 0]) # Use age-1 as step to reflect completed step; ensure non-negative + step = max([int(val), 0]) # Use age-1 as step to reflect completed step; ensure non-negative elif hasattr(m, 'current_step'): val = m.current_step if val is not None: - step = max([int(val), 0]) # Use current_step-1 as step to reflect completed step; ensure non-negative + step = max([int(val), 0]) # Use current_step-1 as step to reflect completed step; ensure non-negative elif step is not None: # step = step # fallback to provided step - m.current_step = step # add current_step attribute to model for future tracking + m.current_step = step # add current_step attribute to model for future tracking - m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType + m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType elif step is not None: # If model doesn't have current_step, force it to 0 or try to infer from checkpoint manager - m.current_step = step # add current_step attribute to model for future tracking - m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType + m.current_step = step # add current_step attribute to model for future tracking + m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType return step @@ -334,7 +334,7 @@ def _log_signal(scalar: float, signal_per_sample: dict, reg_name: str, step: int {reg_name: scalar}, global_step=step, signal_per_sample=signal_per_sample, - aggregate_by_step=kwargs.get('per_sample', True) # Aggregate per-sample signals by step for logging if per_sample is True, + aggregate_by_step=kwargs.get('per_sample', True) # Aggregate per-sample signals by step for logging if per_sample is True, ) except Exception: traceback.print_exc() @@ -516,9 +516,9 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): if instance_batch_idx is None and 'batch_idx' in kw: instance_batch_idx = kw['batch_idx'] elif instance_batch_idx is None and targets is not None and isinstance(targets, list): - instance_batch_idx = [i for i, tars in enumerate(targets) for _ in tars] # Auto determine batch_idx from targets if not explicitly provided (assumes targets is list of lists of annotations) + instance_batch_idx = [i for i, tars in enumerate(targets) for _ in tars] # Auto determine batch_idx from targets if not explicitly provided (assumes targets is list of lists of annotations) else: - instance_batch_idx = ledgers.get_dataframe()._df.loc[batch_ids].index.get_level_values(1).tolist() # Query directly instance_ids related and ordered to the samples_ids in the batch + instance_batch_idx = ledgers.get_dataframe()._df.loc[batch_ids].index.get_level_values(1).tolist() # Query directly instance_ids related and ordered to the samples_ids in the batch batch_ids = ledgers.get_dataframe()._df.loc[batch_ids].index.get_level_values(0).tolist() # If output is a dict (from PerInstanceDetectionLoss), pick 'sample' @@ -533,7 +533,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): if kwargs.get('per_sample', False) and not isinstance(out, dict): if hasattr(out, 'ndim') and out.ndim > 1: - out = out.mean(dim=tuple(range(1, out.ndim))) # Reduce to [B,]0 + out = out.mean(dim=tuple(range(1, out.ndim))) # Reduce to [B,]0 # Extract scalar from tensor scalar, batch_scalar = _extract_scalar_from_tensor(batch_scalar, out, batch_ids) @@ -550,7 +550,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): batch_idx=instance_batch_idx, targets=targets, step=step, - log=False, # already logged sample-level above + log=False, # already logged sample-level above ) except Exception as e: traceback.print_exc() if os.environ.get('WEIGHTSLAB_LOG_LEVEL') == 'DEBUG' else None @@ -640,7 +640,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): origin=kwargs.get('origin', 'train') ) try: - res = func(ctx) # Compute per sample result with unified context + res = func(ctx) # Compute per sample result with unified context except TypeError: # Fallback for legacy subscriber functions res = func(sample_id=int(uid), value=val, dataframe=df_proxy) @@ -650,10 +650,10 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): dynamic_updates[name] = signal_value if dynamic_updates and meta.get('log', True): logger.debug(f"Dynamic updates computed for signal '{reg_name}': {list(dynamic_updates.keys())}") - _log_signal(sum(signal_value)/len(signal_value), signal_value, name, step=step, **kwargs) # Log custom subscribed signals + _log_signal(sum(signal_value)/len(signal_value), signal_value, name, step=step, **kwargs) # Log custom subscribed signals except Exception as e: logger.debug(f"Dynamic signal {name} failed: {e}") - pass # User function error, skip + pass # User function error, skip # Save statistics if requested and applicable. # Skip the per-sample save path when per_instance=True — instance values @@ -676,7 +676,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): preds_raw=preds_raw, preds=preds, targets=targets, - log=False # Already logged above, no need to log again in save_signals; set to False to avoid duplicate logging if save_signals is called separately without logging + log=False # Already logged above, no need to log again in save_signals; set to False to avoid duplicate logging if save_signals is called separately without logging ) # Return the original output (dict for per-instance losses so caller can @@ -748,7 +748,7 @@ def watch_or_edit(obj: Callable, obj_name: str = None, flag: str = None, **kwarg forced_model_wrapping = kwargs.pop('forced_model_wrapping', False) # Now construct the wrapper and let it register into the ledger. - wrapper = ModelInterface(obj, **kwargs) if forced_model_wrapping or _model == None else _model + wrapper = ModelInterface(obj, **kwargs) if forced_model_wrapping or _model == None else _model # No rebind here since the model wrapper is designed to be a drop-in replacement for the original model @@ -791,11 +791,11 @@ def watch_or_edit(obj: Callable, obj_name: str = None, flag: str = None, **kwarg if 'loader_name' not in kwargs and 'name' in kwargs: kwargs['loader_name'] = kwargs['name'] except Exception: - pass # If we can't get hyperparameters, continue without root_log_dir + pass # If we can't get hyperparameters, continue without root_log_dir # Now construct the wrapper and let it register into the ledger. wrapper = DataLoaderInterface(obj, **kwargs) - _dataloader.__pl_saved_kwargs = kwargs # Force pytorch lightning compatibility + _dataloader.__pl_saved_kwargs = kwargs # Force pytorch lightning compatibility # There is not rebind here because obj can be a dataloader or a dataset @@ -981,7 +981,7 @@ def new_forward(*a, **kw): logger.info(f"Loaded hyperparameters from checkpoint {latest_hash[:16]}") checkpoint_hp_loaded = True except Exception: - pass # If checkpoint loading fails, proceed with normal registration + pass # If checkpoint loading fails, proceed with normal registration defaults = kwargs.get('defaults', None) if not checkpoint_hp_loaded: @@ -1071,7 +1071,7 @@ def start_training(timeout: int = None) -> None: if timeout is not None and isinstance(timeout, int) and timeout > 0: logger.info(f"Starting WeightsLab training mode with a timeout of {timeout} seconds.") time.sleep(timeout) - pause_ctrl.resume() # Ensure we're not paused if start_training is called after serve + pause_ctrl.resume() # Ensure we're not paused if start_training is called after serve def serve(serving_cli: bool = False, serving_grpc: bool = False, **kwargs) -> None: """Start WeightsLab services. @@ -1742,14 +1742,14 @@ def save_signals( Examples: Classification — one loss scalar per image:: - for inputs, targets, ids in train_loader: # ids: sample IDs, len B - logits = model(inputs) # (B, num_classes) - loss = loss_fn(logits, targets) # (B,) per-sample loss + for inputs, targets, ids in train_loader: # ids: sample IDs, len B + logits = model(inputs) # (B, num_classes) + loss = loss_fn(logits, targets) # (B,) per-sample loss wl.save_signals( - signals={"train_loss": loss}, # (B,) -> signals//train_loss + signals={"train_loss": loss}, # (B,) -> signals//train_loss batch_ids=ids, - preds_raw=logits, # (B, num_classes) - targets=targets, # (B,) + preds_raw=logits, # (B, num_classes) + targets=targets, # (B,) step=current_step, log=True, ) @@ -1757,7 +1757,7 @@ def save_signals( Several named per-sample metrics at once:: wl.save_signals( - signals={"iou": iou_per_image, "dice": dice_per_image}, # each (B,) + signals={"iou": iou_per_image, "dice": dice_per_image}, # each (B,) batch_ids=ids, ) """ @@ -1814,9 +1814,9 @@ def expand_dim(x): return x[:, np.newaxis] return x - preds_np = normalize(preds) + preds_np = normalize(preds) preds_raw_np = normalize(preds_raw) - target_np = normalize(targets) + target_np = normalize(targets) # Processing signals if isinstance(signals, dict): @@ -1836,8 +1836,8 @@ def expand_dim(x): losses_data = None # Expand dims for 1D arrays (skipped for lists) - target_np = expand_dim(target_np) - preds_np = expand_dim(preds_np) + target_np = expand_dim(target_np) + preds_np = expand_dim(preds_np) preds_raw_np = expand_dim(preds_raw_np) # Enqueue to dataframe manager buffer for efficiency @@ -1890,29 +1890,29 @@ def save_instance_signals( Worked example — ``batch_ids = ["img7", "img3"]`` (B = 2), 5 boxes total:: - # box: 0 1 2 3 4 - batch_idx = [ 0, 0, 1, 1, 1 ] # boxes 0-1 -> img7, 2-4 -> img3 - ious = [0.91, 0.62, 0.50, 0.74, 0.30] # one IoU per box + # box: 0 1 2 3 4 + batch_idx = [ 0, 0, 1, 1, 1 ] # boxes 0-1 -> img7, 2-4 -> img3 + ious = [0.91, 0.62, 0.50, 0.74, 0.30] # one IoU per box wl.save_instance_signals( - signals={"iou_instance": ious}, # -> signals//iou_instance + signals={"iou_instance": ious}, # -> signals//iou_instance batch_ids=["img7", "img3"], batch_idx=batch_idx, origin="train", ) # writes: - # ("img7", 1)=0.91 ("img7", 2)=0.62 - # ("img3", 1)=0.50 ("img3", 2)=0.74 ("img3", 3)=0.30 + # ("img7", 1)=0.91 ("img7", 2)=0.62 + # ("img3", 1)=0.50 ("img3", 2)=0.74 ("img3", 3)=0.30 Typical detection loop using the Ultralytics batch dict directly:: image, batch_ids, batch = inputs[0], inputs[1], inputs[3]["batch"] raw_preds = model(image) - iou_per_box = compute_iou(raw_preds, batch) # flat [total_instances] + iou_per_box = compute_iou(raw_preds, batch) # flat [total_instances] wl.save_instance_signals( signals={"iou_instance": iou_per_box}, batch_ids=batch_ids, - batch_idx=batch["batch_idx"], # Ultralytics flat index + batch_idx=batch["batch_idx"], # Ultralytics flat index step=current_step, ) @@ -1921,9 +1921,9 @@ def save_instance_signals( in the same per-sample order ``batch_idx`` implies. It is flattened sample-major to align with the instances:: - targets = [ # batch_ids = ["img7", "img3"] - [box7_0, box7_1], # img7's two boxes -> annotation_id 1, 2 - [box3_0, box3_1, box3_2], # img3's three boxes -> annotation_id 1, 2, 3 + targets = [ # batch_ids = ["img7", "img3"] + [box7_0, box7_1], # img7's two boxes -> annotation_id 1, 2 + [box3_0, box3_1, box3_2], # img3's three boxes -> annotation_id 1, 2, 3 ] wl.save_instance_signals(signals={"iou_instance": ious}, batch_ids=["img7", "img3"], @@ -2096,7 +2096,7 @@ def get_active_group_mask( Example:: # Cosine embedding loss — one value per pair in the batch - loss_embed = loss_cosine(e1, e2, y) # shape: (B/2,) + loss_embed = loss_cosine(e1, e2, y) # shape: (B/2,) group_mask = wl.get_active_group_mask(group_ids, origin="train_loader") # Zero out tainted pairs so they don't update weights n_active = group_mask.sum().clamp(min=1) @@ -2119,7 +2119,7 @@ def get_active_group_mask( if gid in tainted: mask[i] = 0.0 except Exception: - pass # Fail-safe: if check fails, treat all groups as active + pass # Fail-safe: if check fails, treat all groups as active return mask @@ -2237,14 +2237,14 @@ def save_group_signals( try: tainted_group_ids = DATAFRAME_M.get_tainted_group_ids(group_ids, origin) except Exception: - pass # Never block training on best-effort discard check + pass # Never block training on best-effort discard check # Broadcast to all members in ledger (skip tainted groups) all_updates = [] active_group_ids = [] for i, gid in enumerate(group_ids): if gid in tainted_group_ids: - continue # Skip: at least one member was discarded; group loss is undefined + continue # Skip: at least one member was discarded; group loss is undefined # We also record the last seen step for all members updates = scalar_signals.copy() @@ -2258,7 +2258,7 @@ def save_group_signals( active_group_ids.append(gid) if not active_group_ids: - return # All groups were tainted; nothing to write + return # All groups were tainted; nothing to write # Bulk update for performance (avoids repeated dataframe scans) DATAFRAME_M.update_by_groups_bulk(origin=origin, group_ids=active_group_ids, updates_list=all_updates) @@ -2331,7 +2331,7 @@ def _unpack_batch(batch, device=None): def _make_default_eval_fn(model): """Return a default evaluation callable that uses all registered ledger signals. - This is used when no ``@wl.eval_fn`` decorator was applied. For every + This is used when no ``@wl.eval_fn`` decorator was applied. For every batch it: 1. Unpacks ``(inputs, targets, ids)`` using a heuristic (tuple/list/dict). @@ -2342,7 +2342,7 @@ def _make_default_eval_fn(model): evaluation-mode buffer. Loss-style signals (wrapped ``forward``) and metric-style signals - (wrapped ``compute``) are both handled. Per-signal errors are silently + (wrapped ``compute``) are both handled. Per-signal errors are silently skipped so a missing target or shape mismatch does not abort the whole evaluation. """ @@ -2393,7 +2393,7 @@ def _default_eval(loader): except Exception: pass - preds = model(inputs) # infer predictions + preds = model(inputs) # infer predictions # Call each registered signal so its wrapped forward/compute # fires and feeds into the evaluation-mode logger buffer. @@ -2457,8 +2457,8 @@ def eval_fn(func): The decorated function receives a single *loader* argument — a ``_EvalManagedLoader`` wrapping the requested split's - ``DataLoaderInterface``. It should iterate that loader and compute - the watched criteria / metrics exactly as in a normal test pass. All + ``DataLoaderInterface``. It should iterate that loader and compute + the watched criteria / metrics exactly as in a normal test pass. All ``add_scalars`` calls are intercepted by the logger's evaluation-mode buffer. @@ -2487,8 +2487,8 @@ def pointcloud_thumbnail(func): own function, e.g. a range/spherical projection: @wl.pointcloud_thumbnail - def to_range_image(points): # points: [M, 2..F] float - return my_range_projection(points) # -> (H, W, 3) uint8 or PIL.Image + def to_range_image(points): # points: [M, 2..F] float + return my_range_projection(points) # -> (H, W, 3) uint8 or PIL.Image (Note: ``@wl.3d_pc_thumb`` isn't valid Python — identifiers can't start with a digit — so the verb is spelled out.) A ``render_thumbnail_2d`` @@ -2508,7 +2508,7 @@ def pointcloud_boxes(func): @wl.pointcloud_boxes def boxes_to_range(boxes): - return my_boxes_in_range_frame(boxes) # -> [N, 6] normalized + return my_boxes_in_range_frame(boxes) # -> [N, 6] normalized A ``project_boxes_2d`` method on the dataset takes precedence. """ @@ -2587,19 +2587,19 @@ def run_pending_evaluation( Can still be called from the training loop with explicit arguments for backwards-compatibility:: - if wl.run_pending_evaluation(): # ledger mode — no args needed + if wl.run_pending_evaluation(): # ledger mode — no args needed continue Args: - loaders: Optional mapping of *loader_name* → ``DataLoaderInterface``. + loaders: Optional mapping of *loader_name* → ``DataLoaderInterface``. When ``None``, the loader is looked up by split name from the ledger. - model: Optional tracked model instance (used to read ``get_age()``). + model: Optional tracked model instance (used to read ``get_age()``). When ``None``, resolved from the ledger. - eval_fn: Optional callable with signature ``eval_fn(loader) -> None``. + eval_fn: Optional callable with signature ``eval_fn(loader) -> None``. When ``None``, the function registered via ``@wl.eval_fn`` is used. - device: Unused; kept for API symmetry. + device: Unused; kept for API symmetry. Returns: ``True`` if an evaluation was executed (caller should ``continue`` @@ -2840,7 +2840,7 @@ def run_pending_evaluation( model_age = 0 try: - model_age = _model.get_age() - 1 if _model is not None and hasattr(_model, "get_age") else 0 # Model anticipates a step after eval, so subtract 1 to report the age corresponding to the just-evaluated checkpoint. + model_age = _model.get_age() - 1 if _model is not None and hasattr(_model, "get_age") else 0 # Model anticipates a step after eval, so subtract 1 to report the age corresponding to the just-evaluated checkpoint. except Exception: pass @@ -2909,42 +2909,42 @@ def run_pending_evaluation( logger.info(f"\n{'='*70}") logger.info(f"[WeightsLab] Evaluation Results") logger.info(f"{'='*70}") - logger.info(f" Split: {split_name}") - logger.info(f" Model Step: {model_age}") - logger.info(f" Tags: {tags}") - logger.info(f" Total Samples: {filtered_count if filtered_count is not None else 'unknown'}") - logger.info(f" Total Batches: {total_batches}") - logger.info(f" Eval Hash: {eval_hash}") + logger.info(f" Split: {split_name}") + logger.info(f" Model Step: {model_age}") + logger.info(f" Tags: {tags}") + logger.info(f" Total Samples: {filtered_count if filtered_count is not None else 'unknown'}") + logger.info(f" Total Batches: {total_batches}") + logger.info(f" Eval Hash: {eval_hash}") if result: - logger.info(f" Metrics:\n") + logger.info(f" Metrics:\n") for k, v in result.items(): if isinstance(v, float): - logger.info(f" {k:30s} = {v:.6f}") + logger.info(f" {k:30s} = {v:.6f}") else: - logger.info(f" {k:30s} = {v}") + logger.info(f" {k:30s} = {v}") else: - logger.info(f" Status: No metrics recorded") + logger.info(f" Status: No metrics recorded") error_msg = ( f"Evaluation did not produce any metrics.\n" - f" Possible causes:\n" - f" • Evaluation function is not compatible with the experiment setup\n" - f" • No signals were computed during evaluation\n" - f" • Model or data loader not registered in the ledger\n\n" - f" Solution: Create a custom evaluation function decorated with @wl.eval_fn.\n" - f" This function should:\n" - f" 1. Accept only one parameter: loader\n" - f" 2. Be fully based on the WeightsLab ledger\n" - f" 3. Retrieve model, device, and metrics from wl.ledger.*\n" - f" 4. Register loss/metric functions with wl.watch_or_edit(..., flag='loss/metric')\n\n" - f" Example from detection use case:\n" - f" @wl.eval_fn\n" - f" def validate(loader):\n" - f" model = wl.ledger.get_model()\n" - f" device = wl.ledger.get_device()\n" - f" for batch in loader:\n" - f" ...\n\n" - f" See documentation: https://grayboxtech.github.io/weightslab/latest/index.html" + f" Possible causes:\n" + f" • Evaluation function is not compatible with the experiment setup\n" + f" • No signals were computed during evaluation\n" + f" • Model or data loader not registered in the ledger\n\n" + f" Solution: Create a custom evaluation function decorated with @wl.eval_fn.\n" + f" This function should:\n" + f" 1. Accept only one parameter: loader\n" + f" 2. Be fully based on the WeightsLab ledger\n" + f" 3. Retrieve model, device, and metrics from wl.ledger.*\n" + f" 4. Register loss/metric functions with wl.watch_or_edit(..., flag='loss/metric')\n\n" + f" Example from detection use case:\n" + f" @wl.eval_fn\n" + f" def validate(loader):\n" + f" model = wl.ledger.get_model()\n" + f" device = wl.ledger.get_device()\n" + f" for batch in loader:\n" + f" ...\n\n" + f" See documentation: https://grayboxtech.github.io/weightslab/latest/index.html" ) logger.warning(error_msg) @@ -2985,7 +2985,7 @@ def _build_eval_allow_list(loader_if, tags: list, split_name: str) -> set: for tag in tags: col = f"{SampleStatsEx.TAG.value}:{tag}" if col in df.columns: - col_mask = df[col] == True # noqa: E712 + col_mask = df[col] == True # noqa: E712 mask = col_mask if mask is None else (mask & col_mask) if mask is None: @@ -3172,7 +3172,7 @@ def __iter__(self): def get_current_experiment_hash() -> str | None: """Return the hash of the currently active experiment run. - Reads the hash from the registered checkpoint manager. Returns ``None`` + Reads the hash from the registered checkpoint manager. Returns ``None`` when no experiment is active or no checkpoint manager has been registered. Example:: @@ -3291,7 +3291,7 @@ def write_history( Parameters ---------- path : str, optional - Output file path **or** directory. When omitted (``None``), the + Output file path **or** directory. When omitted (``None``), the ``root_log_dir`` from the active checkpoint manager is used as the output directory. @@ -3301,14 +3301,14 @@ def write_history( is auto-generated as ``_history.`` inside that directory, where ```` is an 8-character hex MD5 of the normalized call parameters (*type_of_history*, *graph_name*, - *experiment_hash*, *sample_id*, *instance_id*). The same filter + *experiment_hash*, *sample_id*, *instance_id*). The same filter combination always produces the same filename; different filters produce different filenames. - The directory is created automatically if it does not exist. format : {"json", "csv"} Output format (default ``"json"``). type_of_history : {None, "all", "global", "sample", "instance", "instances"} - Which history to include. ``None`` or ``"all"`` writes every type. + Which history to include. ``None`` or ``"all"`` writes every type. ``"global"`` writes the aggregated training-curve history. ``"sample"`` writes per-sample history. ``"instance"`` / ``"instances"`` writes per-instance history. @@ -3316,7 +3316,7 @@ def write_history( Restrict to one or more signal / metric names. experiment_hash : str, optional ``None`` (default) — use the current experiment hash from the - checkpoint manager. ``"all"`` — include every hash. + checkpoint manager. ``"all"`` — include every hash. Any other string — restrict to that specific experiment run. sample_id : str or list of str, optional Restrict per-sample and per-instance rows to one or more sample IDs. @@ -3389,9 +3389,9 @@ def write_history( # --- Normalize all parameters first (needed for the auto-filename hash) --- # Resolve experiment_hash: - # None → use the current hash from the checkpoint manager (default) - # "all" → no filter, include every hash - # any str → filter to that specific hash + # None → use the current hash from the checkpoint manager (default) + # "all" → no filter, include every hash + # any str → filter to that specific hash if experiment_hash is None or experiment_hash == 'last': try: _current = ( @@ -3403,7 +3403,7 @@ def write_history( except Exception: experiment_hash = None elif experiment_hash == "all": - experiment_hash = None # sentinel: skip hash filtering below + experiment_hash = None # sentinel: skip hash filtering below # Normalize graph_name → set or None _gn_filter = None @@ -3598,19 +3598,19 @@ def write_dataframe( Parameters ---------- path : str, optional - Output file path **or** directory. When omitted (``None``), the + Output file path **or** directory. When omitted (``None``), the ``root_log_dir`` from the active checkpoint manager is used. - If *path* has a file extension the file is written directly. - If *path* has no extension or is an existing directory, a filename is auto-generated as ``_dataframe.`` inside that directory. ```` is an 8-character MD5 hex digest of the normalized call - parameters (*columns*, *sample_id*, *instance_id*). Same filters → + parameters (*columns*, *sample_id*, *instance_id*). Same filters → same filename (idempotent overwrite); different filters → different file. - The directory is created automatically if it does not exist. format : {"json", "csv"} - Output format. Default ``"json"``. + Output format. Default ``"json"``. columns : str or list of str, optional Which columns to include (index levels ``sample_id`` / ``annotation_id`` are always written). @@ -3623,10 +3623,10 @@ def write_dataframe( - ``"discarded"`` — only the boolean ``discarded`` column. - A list of any mix of the above group names and/or exact column names. sample_id : str or list of str, optional - Restrict to one or more sample IDs (index level 0). ``None`` keeps all. + Restrict to one or more sample IDs (index level 0). ``None`` keeps all. instance_id : int or list of int, optional Restrict to one or more annotation IDs (index level 1, 0 = sample row, - ≥ 1 = per-instance rows). ``None`` keeps all. + ≥ 1 = per-instance rows). ``None`` keeps all. Returns ------- @@ -3636,7 +3636,7 @@ def write_dataframe( Notes ----- The function calls ``flush()`` on the dataframe manager before reading so - that any in-flight writes are included in the output. Pass + that any in-flight writes are included in the output. Pass ``instance_id=0`` to keep only sample-level rows; pass ``instance_id=[1,2]`` to keep specific annotation rows. @@ -3765,7 +3765,7 @@ def write_dataframe( mask = df_out.index.get_level_values(1).astype(int).isin(_iid_set) df_out = df_out.loc[mask] except Exception: - pass # non-integer annotation_ids — skip this filter + pass # non-integer annotation_ids — skip this filter logger.debug("write_dataframe: after instance_id filter → %d row(s).", len(df_out)) # Filter columns by group or exact name @@ -3789,7 +3789,7 @@ def write_dataframe( else: if _item in df_out.columns: _selected.append(_item) - _selected = list(dict.fromkeys(_selected)) # deduplicate, preserve order + _selected = list(dict.fromkeys(_selected)) # deduplicate, preserve order df_out = df_out[_selected] if _selected else df_out[[]] logger.debug("write_dataframe: column filter → %d column(s): %s", len(_selected), _selected) diff --git a/weightslab/tests/backend/test_compare_dataloaders.py b/weightslab/tests/backend/test_compare_dataloaders.py index 245b413c..b17469c3 100644 --- a/weightslab/tests/backend/test_compare_dataloaders.py +++ b/weightslab/tests/backend/test_compare_dataloaders.py @@ -62,7 +62,7 @@ def setUp(self): # worker parallelism measurable for the multi-worker throughput test. self.dataset_size = 256 self.batch_size = 32 - self.delay_per_sample = 0.01 # 10ms per sample to justify worker overhead + self.delay_per_sample = 0.01 # 10ms per sample to justify worker overhead pause_controller.resume() def _create_torch_dataloader(self, num_workers=0): @@ -123,7 +123,7 @@ def test_single_worker_correctness(self): self.assertTrue(torch.equal(torch_target, wl_target), f"Batch {i} target mismatch") - print(f"✓ Single worker: {len(torch_batches)} batches match perfectly") + print(f" Single worker: {len(torch_batches)} batches match perfectly") @_SKIP_MULTIWORKER_ON_WIN def test_multi_worker_correctness(self): @@ -141,7 +141,7 @@ def test_multi_worker_correctness(self): for batch in torch_loader: torch_batches.append(batch) - wl_loader.reset_iterator() # Reset for fresh iteration + wl_loader.reset_iterator() # Reset for fresh iteration for batch in wl_loader: wl_batches.append(batch) @@ -158,67 +158,67 @@ def test_multi_worker_correctness(self): self.assertTrue(torch.allclose(torch_sorted, wl_sorted), "All data samples must be present") - print(f"✓ Multi-worker: {len(torch_batches)} batches, all samples present") + print(f" Multi-worker: {len(torch_batches)} batches, all samples present") # def test_throughput_comparison(self): - # """Compare throughput: single worker vs multi-worker.""" - # print("\n" + "="*70) - # print("TEST: Throughput Comparison") - # print("="*70) - - # results = {} - - # # Torch DataLoader: Single Worker - # torch_loader = self._create_torch_dataloader(num_workers=0) - # start = time.time() - # for _ in torch_loader: - # pass - # torch_single_time = time.time() - start - # results['PyTorch (1 worker)'] = torch_single_time - - # # Torch DataLoader: Multiple Workers - # torch_loader = self._create_torch_dataloader(num_workers=4) - # start = time.time() - # for _ in torch_loader: - # pass - # torch_multi_time = time.time() - start - # results['PyTorch (4 workers)'] = torch_multi_time - - # # WeightsLab DataLoaderInterface: Single Worker - # wl_loader = self._create_weightslab_dataloader(num_workers=0) - # start = time.time() - # for _ in wl_loader: - # pass - # wl_single_time = time.time() - start - # results['WeightsLab (1 worker)'] = wl_single_time - - # # WeightsLab DataLoaderInterface: Multiple Workers - # wl_loader = self._create_weightslab_dataloader(num_workers=4) - # wl_loader.reset_iterator() # Ensure fresh start - # start = time.time() - # for _ in wl_loader: - # pass - # wl_multi_time = time.time() - start - # results['WeightsLab (4 workers)'] = wl_multi_time - - # # Print comparison - # print("\nThroughput Results (loading {} batches):".format(self.dataset_size // self.batch_size)) - # print("-" * 70) - # for name, elapsed in results.items(): - # throughput = (self.dataset_size / self.batch_size) / elapsed if elapsed > 0 else 0 - # print(f"{name:35} {elapsed:8.3f}s ({throughput:6.2f} batches/sec)") - - # print("-" * 70) - # speedup_wl = results['WeightsLab (1 worker)'] / results['WeightsLab (4 workers)'] - # speedup_torch = results['PyTorch (1 worker)'] / results['PyTorch (4 workers)'] - - # print(f"Multi-worker speedup:") - # print(f" PyTorch: {speedup_torch:.2f}x faster") - # print(f" WeightsLab: {speedup_wl:.2f}x faster") - - # # Verify multi-worker is faster than single-worker - # self.assertGreater(wl_single_time, wl_multi_time * 0.8, - # "Multi-worker should be faster or comparable to single-worker") + # """Compare throughput: single worker vs multi-worker.""" + # print("\n" + "="*70) + # print("TEST: Throughput Comparison") + # print("="*70) + + # results = {} + + # # Torch DataLoader: Single Worker + # torch_loader = self._create_torch_dataloader(num_workers=0) + # start = time.time() + # for _ in torch_loader: + # pass + # torch_single_time = time.time() - start + # results['PyTorch (1 worker)'] = torch_single_time + + # # Torch DataLoader: Multiple Workers + # torch_loader = self._create_torch_dataloader(num_workers=4) + # start = time.time() + # for _ in torch_loader: + # pass + # torch_multi_time = time.time() - start + # results['PyTorch (4 workers)'] = torch_multi_time + + # # WeightsLab DataLoaderInterface: Single Worker + # wl_loader = self._create_weightslab_dataloader(num_workers=0) + # start = time.time() + # for _ in wl_loader: + # pass + # wl_single_time = time.time() - start + # results['WeightsLab (1 worker)'] = wl_single_time + + # # WeightsLab DataLoaderInterface: Multiple Workers + # wl_loader = self._create_weightslab_dataloader(num_workers=4) + # wl_loader.reset_iterator() # Ensure fresh start + # start = time.time() + # for _ in wl_loader: + # pass + # wl_multi_time = time.time() - start + # results['WeightsLab (4 workers)'] = wl_multi_time + + # # Print comparison + # print("\nThroughput Results (loading {} batches):".format(self.dataset_size // self.batch_size)) + # print("-" * 70) + # for name, elapsed in results.items(): + # throughput = (self.dataset_size / self.batch_size) / elapsed if elapsed > 0 else 0 + # print(f"{name:35} {elapsed:8.3f}s ({throughput:6.2f} batches/sec)") + + # print("-" * 70) + # speedup_wl = results['WeightsLab (1 worker)'] / results['WeightsLab (4 workers)'] + # speedup_torch = results['PyTorch (1 worker)'] / results['PyTorch (4 workers)'] + + # print(f"Multi-worker speedup:") + # print(f" PyTorch: {speedup_torch:.2f}x faster") + # print(f" WeightsLab: {speedup_wl:.2f}x faster") + + # # Verify multi-worker is faster than single-worker + # self.assertGreater(wl_single_time, wl_multi_time * 0.8, + # "Multi-worker should be faster or comparable to single-worker") @_SKIP_MULTIWORKER_ON_WIN def test_correctness_with_reset(self): @@ -233,7 +233,7 @@ def test_correctness_with_reset(self): first_iteration = [] for i, batch in enumerate(wl_loader): first_iteration.append(batch[0].clone()) - if i >= 5: # Just collect a few batches + if i >= 5: # Just collect a few batches break # Reset and iterate again @@ -250,7 +250,7 @@ def test_correctness_with_reset(self): self.assertTrue(torch.allclose(first, second), f"Batch {i} differs after reset") - print(f"✓ Reset iterator works: {len(first_iteration)} batches verified") + print(f" Reset iterator works: {len(first_iteration)} batches verified") if __name__ == '__main__': diff --git a/weightslab/tests/backend/test_data_loader_interface.py b/weightslab/tests/backend/test_data_loader_interface.py index c77cee72..4da94914 100644 --- a/weightslab/tests/backend/test_data_loader_interface.py +++ b/weightslab/tests/backend/test_data_loader_interface.py @@ -265,17 +265,17 @@ def test_mixed_manual_and_for_loop_iteration(self): Pattern: step = 0 while step < max_steps: - data = next(loader) # Manual iteration with auto-reset after epoch + data = next(loader) # Manual iteration with auto-reset after epoch if step % 5 == 0: - for batches in loader: # For-loop continues from current position - process(batches) # Gets remaining batches, ends with StopIteration + for batches in loader: # For-loop continues from current position + process(batches) # Gets remaining batches, ends with StopIteration step += 1 """ iface = DataLoaderInterface(self.train_ds, batch_size=self.batch_size, is_training=False, compute_hash=True) loader = ledgers.get_dataloader() batches_per_epoch = len(iface.dataloader) step = 0 - max_steps = 30 # Run for multiple epochs + max_steps = 30 # Run for multiple epochs manual_batches_collected = 0 for_loop_batches_collected = 0 @@ -368,10 +368,10 @@ def setUpClass(cls): # Auto register hp = ledgers.get_hyperparams() - hp['ledger_flush_interval'] = 10 # Disable flushing threads for tests - hp['ledger_flush_max_rows'] = 15 # Disable flushing threads for tests - hp['ledger_enable_h5_persistence'] = False # Disable flushing threads for tests - hp['ledger_enable_flushing_threads'] = False # Disable flushing threads for tests + hp['ledger_flush_interval'] = 10 # Disable flushing threads for tests + hp['ledger_flush_max_rows'] = 15 # Disable flushing threads for tests + hp['ledger_enable_h5_persistence'] = False # Disable flushing threads for tests + hp['ledger_enable_flushing_threads'] = False # Disable flushing threads for tests # Set controller to resumed state pause_controller._resume() @@ -429,7 +429,7 @@ def test_rng_reproducibility_with_shuffle(self): # 2. Capture RNG state print("\n2. Capturing RNG state...") rng_state = capture_rng_state() - dataloader.reset_iterator() # Reset to use captured RNG + dataloader.reset_iterator() # Reset to use captured RNG print(f"[OK] RNG state captured and iterator reset") # 3. Generate batches with current RNG @@ -455,65 +455,65 @@ def test_rng_reproducibility_with_shuffle(self): b2_check = np.array_equal(bids_2, bids_2_repeat) print(f"\n{'='*60}") print("Verification:") - print(f" Batch 1 match: {b1_check}") - print(f" Batch 2 match: {b2_check}") + print(f" Batch 1 match: {b1_check}") + print(f" Batch 2 match: {b2_check}") self.assertTrue(b1_check, "First batches should be identical") self.assertTrue(b2_check, "Second batches should be identical") print(f"[OK] RNG reproducibility verified!\n") # TODO (GP): Re-enable once OffsetSampler is implemented and tested # def test_iteration_state_reproducibility_without_shuffle(self): - # """Test dataloader reproducibility without shuffle: capture iteration state → resume identically. - - # With shuffle disabled, RNG is irrelevant. We capture the iteration position - # (number of batches yielded) and restore that position efficiently using - # OffsetSampler to skip samples at the index level without data reprocessing. - # """ - # print(f"\n{'='*60}") - # print("Iteration State Reproducibility - No Shuffle") - # print(f"{'='*60}\n") - - # print("1. Creating dataloader (shuffle=False)...") - # dataloader = DataLoaderInterface( - # self.dataset, - - # batch_size=2, - # shuffle=False, - # num_workers=0 - # ) - # print(f"[OK] DataLoader created (batch_size=2, shuffle=False)") - - # # 2. Consume two batches, then capture state - # print("\n2. Consuming first 2 batches...") - # _, bids_1, _ = next(dataloader) - # _, bids_2, _ = next(dataloader) - # print(f"Batches 1-2: {bids_1}, {bids_2}") - - # iter_state = dataloader.capture_iteration_state() - # print(f"[OK] Iteration state captured: {iter_state}") - - # # 3. Consume next two batches - # print("\n3. Consuming batches 3-4...") - # _, bids_3, _ = next(dataloader) - # _, bids_4, _ = next(dataloader) - # print(f"Batches 3-4: {bids_3}, {bids_4}") - - # # 4. Restore iteration state - # print(f"\n4. Restoring to position after batch 2...") - # dataloader.restore_iteration_state(iter_state) - # print(f"[OK] Iteration state restored (skipped first 2 batches efficiently)") - - # # 5. Generate batches again - should match 3 and 4 - # print("\n5. Generating next batches (should match 3-4)...") - # _, bids_3_repeat, _ = next(dataloader) - # _, bids_4_repeat, _ = next(dataloader) - # print(f"Repeated batches: {bids_3_repeat}, {bids_4_repeat}") - - # # Verify - # print(f"\n{'='*60}") - # print("Verification:") - # print(f" Batch 3 match: {torch.equal(bids_3, bids_3_repeat)}") - # print(f" Batch 4 match: {torch.equal(bids_4, bids_4_repeat)}") - # self.assertTrue(torch.equal(bids_3, bids_3_repeat), "Batch 3 should be identical") - # self.assertTrue(torch.equal(bids_4, bids_4_repeat), "Batch 4 should be identical") - # print(f"[OK] Iteration state reproducibility verified!\n") + # """Test dataloader reproducibility without shuffle: capture iteration state → resume identically. + + # With shuffle disabled, RNG is irrelevant. We capture the iteration position + # (number of batches yielded) and restore that position efficiently using + # OffsetSampler to skip samples at the index level without data reprocessing. + # """ + # print(f"\n{'='*60}") + # print("Iteration State Reproducibility - No Shuffle") + # print(f"{'='*60}\n") + + # print("1. Creating dataloader (shuffle=False)...") + # dataloader = DataLoaderInterface( + # self.dataset, + + # batch_size=2, + # shuffle=False, + # num_workers=0 + # ) + # print(f"[OK] DataLoader created (batch_size=2, shuffle=False)") + + # # 2. Consume two batches, then capture state + # print("\n2. Consuming first 2 batches...") + # _, bids_1, _ = next(dataloader) + # _, bids_2, _ = next(dataloader) + # print(f"Batches 1-2: {bids_1}, {bids_2}") + + # iter_state = dataloader.capture_iteration_state() + # print(f"[OK] Iteration state captured: {iter_state}") + + # # 3. Consume next two batches + # print("\n3. Consuming batches 3-4...") + # _, bids_3, _ = next(dataloader) + # _, bids_4, _ = next(dataloader) + # print(f"Batches 3-4: {bids_3}, {bids_4}") + + # # 4. Restore iteration state + # print(f"\n4. Restoring to position after batch 2...") + # dataloader.restore_iteration_state(iter_state) + # print(f"[OK] Iteration state restored (skipped first 2 batches efficiently)") + + # # 5. Generate batches again - should match 3 and 4 + # print("\n5. Generating next batches (should match 3-4)...") + # _, bids_3_repeat, _ = next(dataloader) + # _, bids_4_repeat, _ = next(dataloader) + # print(f"Repeated batches: {bids_3_repeat}, {bids_4_repeat}") + + # # Verify + # print(f"\n{'='*60}") + # print("Verification:") + # print(f" Batch 3 match: {torch.equal(bids_3, bids_3_repeat)}") + # print(f" Batch 4 match: {torch.equal(bids_4, bids_4_repeat)}") + # self.assertTrue(torch.equal(bids_3, bids_3_repeat), "Batch 3 should be identical") + # self.assertTrue(torch.equal(bids_4, bids_4_repeat), "Batch 4 should be identical") + # print(f"[OK] Iteration state reproducibility verified!\n") diff --git a/weightslab/tests/backend/test_ledgers.py b/weightslab/tests/backend/test_ledgers.py index 5f3288c0..8db41501 100644 --- a/weightslab/tests/backend/test_ledgers.py +++ b/weightslab/tests/backend/test_ledgers.py @@ -32,13 +32,13 @@ def test_default_name_usage(self): # Register without providing name - should use DEFAULT_NAME GLOBAL_LEDGER.register_model(model=d) self.assertIn(DEFAULT_NAME, GLOBAL_LEDGER.list_models()) - got = GLOBAL_LEDGER.get_model() # Should get 'main' by default + got = GLOBAL_LEDGER.get_model() # Should get 'main' by default self.assertIs(got, d) def test_proxy_initialization_pattern(self): """Test that get before register returns Proxy(None), then updates on register.""" # Get before register - should return Proxy(None) - hp = GLOBAL_LEDGER.get_hyperparams() # Uses DEFAULT_NAME + hp = GLOBAL_LEDGER.get_hyperparams() # Uses DEFAULT_NAME # Proxy should exist but not have underlying object yet self.assertEqual(hp.get(), {}) @@ -107,7 +107,7 @@ def test_weak_registration(self): self.assertNotIn("w", names) def test_optimizer_live_update_through_proxy(self): - GLOBAL_LEDGER.get_optimizer('opt_live') # Init opt with a proxy entry + GLOBAL_LEDGER.get_optimizer('opt_live') # Init opt with a proxy entry # define a simple optimizer-like object class DummyOpt: @@ -264,7 +264,7 @@ def test_proxy_yaml_and_json_serialization(self): hp = GLOBAL_LEDGER.get_hyperparams() GLOBAL_LEDGER.register_hyperparams(params={"image_size": 320, "lr": 0.01}) - img = hp.get("image_size") # a live ValueProxy, not a plain int + img = hp.get("image_size") # a live ValueProxy, not a plain int self.assertEqual(type(img).__name__, "_ValueProxy") # YAML: cover every dumper variant — Ultralytics dumps with CSafeDumper diff --git a/weightslab/tests/backend/test_logger_core.py b/weightslab/tests/backend/test_logger_core.py index 9443e571..f1b0cc80 100644 --- a/weightslab/tests/backend/test_logger_core.py +++ b/weightslab/tests/backend/test_logger_core.py @@ -195,7 +195,7 @@ def test_stop_when_not_active_returns_empty(self): def test_stop_skips_zero_count_signals(self): lg = _lg() lg.start_evaluation_mode("val", "h1_1") - lg._eval_accum["loss"] = [0.0, 0] # injected directly with count=0 + lg._eval_accum["loss"] = [0.0, 0] # injected directly with count=0 results = lg.stop_evaluation_mode(model_age=1) self.assertNotIn("loss", results) @@ -226,7 +226,7 @@ class TestAbortEvaluationMode(unittest.TestCase): def test_abort_when_not_active_is_noop(self): lg = _lg() - lg.abort_evaluation_mode() # should not raise + lg.abort_evaluation_mode() # should not raise self.assertFalse(lg._eval_mode_active) def test_abort_clears_active_flag_and_accum(self): @@ -253,7 +253,7 @@ def test_abort_removes_queue_entries_for_eval_hash(self): lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample=None, aggregate_by_step=False) # Manually inject a queue entry for the eval hash lg._pending_queue.append({"experiment_hash": "h1_1", "metric_name": "loss"}) - lg._eval_mode_active = True # re-arm + lg._eval_mode_active = True # re-arm lg.abort_evaluation_mode() hashes_in_queue = {e.get("experiment_hash") for e in lg._pending_queue} self.assertNotIn("h1_1", hashes_in_queue) @@ -290,7 +290,7 @@ def test_removes_matching_entries_from_queue(self): lg = _lg() lg._pending_queue = [ {"experiment_hash": "h1_1", "metric_name": "loss"}, - {"experiment_hash": "h1", "metric_name": "loss"}, + {"experiment_hash": "h1", "metric_name": "loss"}, ] lg.remove_evaluation_hash("h1_1") self.assertEqual(len(lg._pending_queue), 1) @@ -304,7 +304,7 @@ def test_empty_hash_is_noop(self): def test_missing_hash_does_not_raise(self): lg = _lg() - lg.remove_evaluation_hash("nonexistent_hash_1") # must not raise + lg.remove_evaluation_hash("nonexistent_hash_1") # must not raise # --------------------------------------------------------------------------- @@ -393,7 +393,7 @@ def test_adds_new_triples(self): def test_dedup_same_sample_and_step(self): lg = _lg() lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.4)]) - lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # same (sid, step) → ignored + lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # same (sid, step) → ignored rows = lg.query_per_sample("loss") self.assertEqual(len(rows), 1) self.assertAlmostEqual(rows[0][2], 0.4, places=4) @@ -401,7 +401,7 @@ def test_dedup_same_sample_and_step(self): def test_different_step_is_not_dedup(self): lg = _lg() lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.4)]) - lg.ingest_per_sample("loss", "h1", [("s0", 2, 0.9)]) # different step → accepted + lg.ingest_per_sample("loss", "h1", [("s0", 2, 0.9)]) # different step → accepted rows = lg.query_per_sample("loss") self.assertEqual(len(rows), 2) @@ -419,7 +419,7 @@ def test_updates_sample_index(self): def test_dedup_does_not_corrupt_index(self): lg = _lg() lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.4)]) - lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # duplicate ignored + lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # duplicate ignored # exactly one row should remain queryable for (s0, h1) rows = lg.query_per_sample("loss", sample_ids=["s0"], exp_hash="h1") self.assertEqual(len(rows), 1) @@ -673,7 +673,7 @@ def test_nonexistent_step_returns_false(self): def test_does_not_modify_non_matching_queue_entries(self): lg = self._lg_with_hash("run1") _add(lg, "loss", "s0", 5, 0.4) - _add(lg, "acc", "s0", 5, 0.9) + _add(lg, "acc", "s0", 5, 0.9) lg.set_point_note("loss", "run1", 5, "only loss") acc_entry = next(e for e in lg._pending_queue if e["metric_name"] == "acc") self.assertNotIn("point_note", acc_entry) diff --git a/weightslab/tests/backend/test_ui_docker_bridge.py b/weightslab/tests/backend/test_ui_docker_bridge.py index 01102b3d..c43f2bd5 100644 --- a/weightslab/tests/backend/test_ui_docker_bridge.py +++ b/weightslab/tests/backend/test_ui_docker_bridge.py @@ -150,7 +150,7 @@ def test_make_executable_is_noop_on_windows(self): def test_make_executable_swallows_oserror(self): # A non-chmod-able path (e.g. root-owned system install) must not raise. with patch("weightslab.ui_docker_bridge.os.stat", side_effect=OSError("denied")): - _make_executable("/root/owned.sh") # should not raise + _make_executable("/root/owned.sh") # should not raise @unittest.skipIf(sys.platform == "win32", "execute bit is POSIX-only") def test_ensure_scripts_executable_marks_bundled_scripts(self): @@ -185,15 +185,15 @@ def test_launch_default_no_cert_gen_cleans_and_launches_unsecured( _mock_shell, _gb, mock_mgr, ): mgr = MagicMock() - mgr.has_valid_certs.return_value = False # no certs on disk -> unsecured + mgr.has_valid_certs.return_value = False # no certs on disk -> unsecured mock_mgr.from_env_or_default.return_value = mgr with patch.dict(os.environ, {}, clear=False): os.environ.pop("VITE_PORT", None) with self.assertLogs("weightslab.ui_docker_bridge", level="INFO") as log_context: ui_launch(argparse.Namespace()) mock_check.assert_called_once() - mock_ensure.assert_not_called() # certs NOT generated by default - mock_clean.assert_called_once() # stale cleanup ran + mock_ensure.assert_not_called() # certs NOT generated by default + mock_clean.assert_called_once() # stale cleanup ran mock_compose.assert_called_once_with( "/fake/docker-compose.yml", "/fake/envoy.yaml", @@ -237,13 +237,13 @@ def test_launch_certs_flag_generates_and_runs_secured( _mock_shell, _gb, mock_mgr, ): mgr = MagicMock(certs_dir="/fake/certs") - mgr.has_valid_certs.return_value = True # certs present after generation + mgr.has_valid_certs.return_value = True # certs present after generation mock_mgr.from_env_or_default.return_value = mgr with patch.dict(os.environ, {}, clear=False): os.environ.pop("VITE_PORT", None) with self.assertLogs("weightslab.ui_docker_bridge", level="INFO") as log_context: ui_launch(argparse.Namespace(certs=True)) - mock_ensure.assert_called_once() # --certs generates certs + mock_ensure.assert_called_once() # --certs generates certs self.assertTrue(any("https://localhost:5173" in msg for msg in log_context.output)) @@ -310,7 +310,7 @@ def test_removes_when_present(self, mock_run): def test_noop_when_absent(self, mock_run): mock_run.return_value = MagicMock(stdout="") _remove_docker_image(_FRONTEND_IMAGE) - mock_run.assert_called_once() # only the 'docker images -q' query, no rmi + mock_run.assert_called_once() # only the 'docker images -q' query, no rmi class TestCleanStaleDockerResources(unittest.TestCase): @@ -354,7 +354,7 @@ class TestUiSecureEnvironment(unittest.TestCase): @patch("weightslab.ui_docker_bridge._generate_certs_with_fallback", return_value=0) def test_ui_secure_environment_success(self, mock_gen_certs, mock_cert_manager): """`weightslab se`: generate certs + token, export WEIGHTSLAB_CERTS_DIR.""" - mock_manager_instance = MagicMock() # certs_dir is a MagicMock (supports .mkdir) + mock_manager_instance = MagicMock() # certs_dir is a MagicMock (supports .mkdir) mock_manager_instance.get_or_create_auth_token.return_value = "fake_token" mock_cert_manager.return_value = mock_manager_instance @@ -411,19 +411,19 @@ def test_main_dispatches_start_example(self, mock_example): def test_main_ui_without_action_does_not_crash(self): with patch("sys.argv", ["weightslab", "ui"]): - main() # should print ui help, not raise + main() # should print ui help, not raise def test_main_start_without_target_does_not_crash(self): with patch("sys.argv", ["weightslab", "start"]): - main() # should print start help, not raise + main() # should print start help, not raise def test_main_help_does_not_crash(self): with patch("sys.argv", ["weightslab", "help"]): - main() # should not raise + main() # should not raise def test_main_no_args_does_not_crash(self): with patch("sys.argv", ["weightslab"]): - main() # should not raise + main() # should not raise class TestUserOnboardingFlow(unittest.TestCase): @@ -637,7 +637,7 @@ def test_installs_requirements_non_interactively_when_present(self, mock_run): mock_run.assert_called_once() cmd = mock_run.call_args.args[0] self.assertEqual(cmd[:5], [sys.executable, "-m", "pip", "install", "-r"]) - self.assertIn("--no-input", cmd) # never prompts + self.assertIn("--no-input", cmd) # never prompts self.assertTrue(mock_run.call_args.kwargs.get("check")) @patch("weightslab.ui_docker_bridge.subprocess.run") @@ -659,12 +659,12 @@ def _capture_main(argv): try: main() except SystemExit: - pass # argparse -h/--help exits 0 + pass # argparse -h/--help exits 0 return buf.getvalue() def test_dash_h_shows_banner_and_command_reference(self): out = self._capture_main(["weightslab", "-h"]) - self.assertIn("WeightsLab", out) # tagline from description + self.assertIn("WeightsLab", out) # tagline from description self.assertIn("ui launch", out) self.assertIn("--certs", out) self.assertIn("start example", out) diff --git a/weightslab/tests/backend/test_write_dataframe.py b/weightslab/tests/backend/test_write_dataframe.py index 882b0a21..163eec03 100644 --- a/weightslab/tests/backend/test_write_dataframe.py +++ b/weightslab/tests/backend/test_write_dataframe.py @@ -197,7 +197,7 @@ def test_returns_path_when_no_manager(self, tmp_json): patch("weightslab.src.get_logger", return_value=None): result = write_dataframe(tmp_json) assert result == tmp_json - assert not os.path.isfile(tmp_json) # nothing written + assert not os.path.isfile(tmp_json) # nothing written # --------------------------------------------------------------------------- @@ -264,7 +264,7 @@ def test_sample_id_single(self, mgr, tmp_json): _call(tmp_json, mgr, sample_id="s1") data = json.loads(open(tmp_json).read()) assert all(r["sample_id"] == "s1" for r in data) - assert len(data) == 2 # s1 has annotation_ids 0 and 1 + assert len(data) == 2 # s1 has annotation_ids 0 and 1 def test_sample_id_list(self, mgr, tmp_json): _call(tmp_json, mgr, sample_id=["s1", "s2"]) @@ -277,7 +277,7 @@ def test_instance_id_zero_keeps_sample_rows(self, mgr, tmp_json): _call(tmp_json, mgr, instance_id=0) data = json.loads(open(tmp_json).read()) assert all(r["annotation_id"] == 0 for r in data) - assert len(data) == 2 # s1 and s2 both have annotation_id=0 + assert len(data) == 2 # s1 and s2 both have annotation_id=0 def test_instance_id_list(self, mgr, tmp_json): _call(tmp_json, mgr, instance_id=[1, 2]) diff --git a/weightslab/tests/backend/test_write_history.py b/weightslab/tests/backend/test_write_history.py index 8bc15616..3608fc70 100644 --- a/weightslab/tests/backend/test_write_history.py +++ b/weightslab/tests/backend/test_write_history.py @@ -24,7 +24,7 @@ def _make_logger(): """Return a fresh LoggerQueue with data under two hashes. h1: loss (steps 1+2), acc (step 1), iou instances (annotation_ids 1,2) - h2: loss (step 1), acc (step 1), iou instance (annotation_id 3) ← current hash + h2: loss (step 1), acc (step 1), iou instance (annotation_id 3) ← current hash """ lg = LoggerQueue(register=False) ckpt = _mock_chkpt("h1") @@ -33,7 +33,7 @@ def _make_logger(): # h1 data lg.add_scalars("loss", {"loss": 1.0}, 1, signal_per_sample={"s1": 1.0}, aggregate_by_step=False) lg.add_scalars("loss", {"loss": 2.0}, 2, signal_per_sample={"s2": 2.0}, aggregate_by_step=False) - lg.add_scalars("acc", {"acc": 0.9}, 1, signal_per_sample={"s1": 0.9}, aggregate_by_step=False) + lg.add_scalars("acc", {"acc": 0.9}, 1, signal_per_sample={"s1": 0.9}, aggregate_by_step=False) lg.add_instance_scalars("iou", sample_ids=["s1"], annotation_ids=[1], values=[0.8], global_step=1, exp_hash="h1") lg.add_instance_scalars("iou", sample_ids=["s2"], annotation_ids=[2], @@ -42,7 +42,7 @@ def _make_logger(): # h2 data — left as the "current" hash after setup ckpt.get_current_experiment_hash.return_value = "h2" lg.add_scalars("loss", {"loss": 3.0}, 1, signal_per_sample={"s1": 3.0}, aggregate_by_step=False) - lg.add_scalars("acc", {"acc": 0.7}, 1, signal_per_sample={"s2": 0.7}, aggregate_by_step=False) + lg.add_scalars("acc", {"acc": 0.7}, 1, signal_per_sample={"s2": 0.7}, aggregate_by_step=False) lg.add_instance_scalars("iou", sample_ids=["s1"], annotation_ids=[3], values=[0.7], global_step=1, exp_hash="h2") @@ -117,7 +117,7 @@ def test_instance_row_keys(self, lg, tmp_json): def test_file_is_valid_json(self, lg, tmp_json): _call(tmp_json, lg) - json.loads(open(tmp_json).read()) # must not raise + json.loads(open(tmp_json).read()) # must not raise def test_output_file_created(self, lg, tmp_json): _call(tmp_json, lg) @@ -395,7 +395,7 @@ def test_empty_logger_writes_header_only_csv(self, tmp_csv): write_history(tmp_csv, format="csv", experiment_hash="all") with open(tmp_csv, newline="", encoding="utf-8") as fh: rows = list(csv.reader(fh)) - assert len(rows) == 1 # header only + assert len(rows) == 1 # header only def test_case_insensitive_type(self, lg, tmp_json): _call(tmp_json, lg, type_of_history="SAMPLE") diff --git a/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py b/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py index 2ae9118a..412b379f 100644 --- a/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py +++ b/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py @@ -209,7 +209,7 @@ def run_one_call_with_watchdog(): def _invoke(): try: result_holder["value"] = wrapped.unary_unary(request={}, context=SimpleNamespace()) - except Exception as exc: # expected on first attempt + except Exception as exc: # expected on first attempt error_holder["error"] = exc worker = threading.Thread(target=_invoke, name="WL-Test-gRPC-Worker", daemon=True) diff --git a/weightslab/tests/components/test_checkpoint_workflow.py b/weightslab/tests/components/test_checkpoint_workflow.py index 1f824018..c7f003c7 100644 --- a/weightslab/tests/components/test_checkpoint_workflow.py +++ b/weightslab/tests/components/test_checkpoint_workflow.py @@ -102,7 +102,7 @@ class SimpleCNN(nn.Module): def __init__(self, conv1_out=8, conv2_out=16): super(SimpleCNN, self).__init__() - self.input_shape = (1, 1, 28, 28) # MNIST input shape + self.input_shape = (1, 1, 28, 28) # MNIST input shape self.conv1 = nn.Conv2d(1, conv1_out, kernel_size=3, padding=1) self.pool1 = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(conv1_out, conv2_out, kernel_size=3, padding=1) @@ -272,27 +272,27 @@ def train_epochs(self, model, loader, optimizer, criterion, num_epochs, criterio def check_reproducibility(self, original_loss, reloaded_loss, original_uids=None, reloaded_uids=None, loss_tol=0.1, uids_msg=None): """Common reproducibility check for losses and UIDs""" return - # # Check reproducibility of losses and UIDs + # # Check reproducibility of losses and UIDs # if isinstance(original_loss, (list, tuple)): - # original_loss_sum = sum(original_loss)/len(original_loss) + # original_loss_sum = sum(original_loss)/len(original_loss) # else: - # original_loss_sum = original_loss + # original_loss_sum = original_loss # if isinstance(reloaded_loss, (list, tuple)): - # reloaded_loss_sum = sum(reloaded_loss)/len(reloaded_loss) + # reloaded_loss_sum = sum(reloaded_loss)/len(reloaded_loss) # else: - # reloaded_loss_sum = reloaded_loss + # reloaded_loss_sum = reloaded_loss # loss_diff = abs(original_loss_sum - reloaded_loss_sum) # loss_relative_diff = loss_diff / original_loss_sum if original_loss_sum != 0 else 0 # print(f"[OK] Loss comparison:") - # print(f" Original: {original_loss_sum:.6f}") - # print(f" Reloaded: {reloaded_loss_sum:.6f}") - # print(f" Relative difference: {loss_relative_diff*100:.3f}%") + # print(f" Original: {original_loss_sum:.6f}") + # print(f" Reloaded: {reloaded_loss_sum:.6f}") + # print(f" Relative difference: {loss_relative_diff*100:.3f}%") # self.assertLess(loss_relative_diff, loss_tol, msg=f"Training should be reproducible within {loss_tol*100:.1f}%") # if original_uids is not None and reloaded_uids is not None: - # print(f"[OK] UIDs comparison:") - # print(f" Original: {original_uids}") - # print(f" Reloaded: {reloaded_uids}") - # self.assertListEqual(reloaded_uids, original_uids, msg=uids_msg or "Sample UIDs should match for reproducibility") + # print(f"[OK] UIDs comparison:") + # print(f" Original: {original_uids}") + # print(f" Reloaded: {reloaded_uids}") + # self.assertListEqual(reloaded_uids, original_uids, msg=uids_msg or "Sample UIDs should match for reproducibility") @classmethod def setUpClass(cls): @@ -365,8 +365,8 @@ def setUpClass(cls): download=True, transform=transform ) - mnist_subset = Subset(full_dataset, list(range(10))) # Create subset with 10 samples - cls.dataset = TaggableDataset(mnist_subset) # Wrap in taggable dataset + mnist_subset = Subset(full_dataset, list(range(10))) # Create subset with 10 samples + cls.dataset = TaggableDataset(mnist_subset) # Wrap in taggable dataset # ================= # Initialize Logger @@ -383,7 +383,7 @@ def setUpClass(cls): # Initialize Model # ================ model = SimpleCNN(conv1_out=8, conv2_out=16) - model = register_in_ledger(model, flag="model", device=DEVICE, skip_previous_auto_load=True, compute_dependencies=False) # Compute dependencies is disabled + model = register_in_ledger(model, flag="model", device=DEVICE, skip_previous_auto_load=True, compute_dependencies=False) # Compute dependencies is disabled # ===================== # Initialize DataLoader @@ -518,7 +518,7 @@ def test_01_train_A(self): self.state['uids_a'] = uids_A # Final verbose - print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") + print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") print(f"\n[OK] TEST A PASSED - Initial training completed") # ============================= @@ -546,15 +546,15 @@ def test_02_train_B_model_change(self): # Modify model architecture # TODO (GP): Still have the pb btw model archi. and torch cache # 11/05/2026-11:20:11.207 DEBUG:weightslab.components.global_monitoring:__exit__: Suppressing exception: Function ConvolutionBackward0 returned an invalid gradient at index 2 - got [15] but expected shape compatible with [16] in GuardContext.__exit__ - # model.operate(0, {-1, -2, -3, -4}, 1) # Increase conv1 out channels by 2 - # model.operate(2, {-1}, 2) # Freeze fc1 layer - # model.operate(-2, {}, 3) # Freeze fc1 layer - # # model.operate(-1, {1}, 4) # Reset fc2 layer + # model.operate(0, {-1, -2, -3, -4}, 1) # Increase conv1 out channels by 2 + # model.operate(2, {-1}, 2) # Freeze fc1 layer + # model.operate(-2, {}, 3) # Freeze fc1 layer + # # model.operate(-1, {1}, 4) # Reset fc2 layer - # print(f" Conv1: 8 -> 12 channels") - # print(f" Conv2: 16 -> 15 channels") - # print(f" FC1: Frozen") - # print(f" FC2: Reset") + # print(f" Conv1: 8 -> 12 channels") + # print(f" Conv2: 16 -> 15 channels") + # print(f" FC1: Frozen") + # print(f" FC2: Reset") # Update hash here to get hash exp_hash_b, _, changed = self.chkpt_manager.update_experiment_hash(force=True) @@ -587,7 +587,7 @@ def test_02_train_B_model_change(self): # Final verbose print(f"\n[OK] TEST B PASSED - Model architecture updated") - print(f" Final model_age: {model.get_age()}") + print(f" Final model_age: {model.get_age()}") # ======================================================================== # Test: 03_train_C_hyperparams_change @@ -614,7 +614,7 @@ def test_03_train_C_hyperparams_change(self): # Change batch size new_bs = 3 self.config['data']['train_loader']['batch_size'] = new_bs - print(f" Batch size: 2 -> 4") + print(f" Batch size: 2 -> 4") # Update hash exp_hash_c, _, _ = self.chkpt_manager.update_experiment_hash() @@ -651,7 +651,7 @@ def test_03_train_C_hyperparams_change(self): # Final verbose print(f"\n[OK] TEST C PASSED - Hyperparameters updated") - print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") + print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") # ======================================================================== # Test: 04_train_D_data_change @@ -666,8 +666,8 @@ def test_04_train_D_data_change(self): model = ledgers.get_model() # Data - dataloader = ledgers.get_dataloader() # Get dataloader - dfm = ledgers.get_dataframe() # Get dataframe manager + dataloader = ledgers.get_dataloader() # Get dataloader + dfm = ledgers.get_dataframe() # Get dataframe manager # Optimizer and criterion optimizer = ledgers.get_optimizer() @@ -686,8 +686,8 @@ def test_04_train_D_data_change(self): rows.append( { SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation - f"{SampleStatsEx.TAG.value}:ugly": True, # Random tag with 'ugly' + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + f"{SampleStatsEx.TAG.value}:ugly": True, # Random tag with 'ugly' SampleStatsEx.DISCARDED.value: bool(1 - dfm.get_df_view()[SampleStatsEx.DISCARDED.value].iloc[idx]) } ) @@ -698,8 +698,8 @@ def test_04_train_D_data_change(self): dfm.upsert_df(df_update, origin='train_loader', force_flush=True) # Changes will be pending - print(f" Added 'ugly' tag to 20 samples") - print(f" Discarded 20 samples") + print(f" Added 'ugly' tag to 20 samples") + print(f" Discarded 20 samples") # Update hash exp_hash_d, _, changed = self.chkpt_manager.update_experiment_hash() @@ -710,7 +710,7 @@ def test_04_train_D_data_change(self): self.assertNotEqual(self.state['exp_hash_c'], exp_hash_d, "Hash should be different") print("\nResuming training for 11 epochs...") - pause_controller.resume() # Pending changes to dump: data state + pause_controller.resume() # Pending changes to dump: data state loss_D, uids_D = self.train_epochs( model, dataloader, optimizer, criterion, num_epochs=self.config['training']['num_epochs'], @@ -739,7 +739,7 @@ def test_04_train_D_data_change(self): # Final verbose print(f"\n[OK] TEST D PASSED - Data state updated") - print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") + print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") # ======================================================================== # Test: 05_train_E_reload_and_branch @@ -759,7 +759,7 @@ def test_05_train_E_reload_and_branch(self): all_hashes = self.chkpt_manager.get_all_hashes(sort_by='created') print(f"\n[OK] Found {len(all_hashes)} experiment states:") for i, entry in enumerate(all_hashes): - print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") + print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") # Reload state B (second state created) hash_a_from_manifest = self.state['exp_hash_a'] @@ -784,19 +784,19 @@ def test_05_train_E_reload_and_branch(self): if 'data' in hp_reloaded and 'train_loader' in hp_reloaded['data']: hp_reloaded['data']['train_loader']['batch_size'] = 1 old_batch_size = hp_original.get('data', {}).get('train_loader', {}).get('batch_size', 2) - print(f" Batch size: {old_batch_size} -> 1") + print(f" Batch size: {old_batch_size} -> 1") # Discard more data # Add 20 random tags with 'ugly' tagged_samples = random.sample(range(10), 1) rows = [] - dfm = ledgers.get_dataframe() # Get dataframe manager + dfm = ledgers.get_dataframe() # Get dataframe manager for idx in tagged_samples: uid, in_uid = dfm.get_df_view().index[idx] rows.append( { SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation f"{SampleStatsEx.TAG.value}:ugly": True, SampleStatsEx.DISCARDED.value: bool(1 - dfm.get_df_view(SampleStatsEx.DISCARDED.value).iloc[idx]) } @@ -846,7 +846,7 @@ def test_05_train_E_reload_and_branch(self): self.state['uids_e'] = uids_E print(f"\n[OK] TEST E PASSED - Reloaded and generate a new train branch successfully") - print(f" Final model_age: {model.get_age()}") + print(f" Final model_age: {model.get_age()}") # ======================================================================== # Test: 06_reload_before_model_change @@ -857,9 +857,9 @@ def test_06_reload_before_model_change(self): print("TEST 06: Reload Before Model Change - Fix Conv Size with RNG State") print(f"{'='*80}\n") - hash_A_original = self.state['exp_hash_a'] # Before model change - loss_A_original = self.state['losses_a'] # Before model change - uids_A_original = self.state['uids_a'] # Before model change + hash_A_original = self.state['exp_hash_a'] # Before model change + loss_A_original = self.state['losses_a'] # Before model change + uids_A_original = self.state['uids_a'] # Before model change print(f"Reloading state A (before model change) for verification: {hash_A_original[:16]}...") success = self.chkpt_manager.load_state(exp_hash=hash_A_original) @@ -901,7 +901,7 @@ def test_06_reload_before_model_change(self): # Fix model conv size - create new model with different architecture print("\nFixing model architecture...") model = ledgers.get_model() - # model.operate(0, {-1}, 1) # Commented; see test 2 - still have the pb btw model archi. and torch cache + # model.operate(0, {-1}, 1) # Commented; see test 2 - still have the pb btw model archi. and torch cache # model.operate(2, {-1}, 2) # model.operate(-2, {}, 3) # model.operate(-1, {-1 }, 4) @@ -929,9 +929,9 @@ def test_06_reload_before_model_change(self): # Compare: First batch should be same, but losses differ due to different model print(f"\n[OK] Reproducibility verified:") - print(f" Original model first batch loss: {loss_A_reloaded}") - print(f" Fixed model first batch loss: {loss_H}") - print(f" (Same RNG = same batches, different losses due to model change)") + print(f" Original model first batch loss: {loss_A_reloaded}") + print(f" Fixed model first batch loss: {loss_H}") + print(f" (Same RNG = same batches, different losses due to model change)") # Store state self.state['losses_h'] = loss_H @@ -949,7 +949,7 @@ def test_07_change_data_from_test06(self): print("TEST 07: Change Data from Test 06 - Discard More Data") print(f"{'='*80}\n") - hash_H = self.state['exp_hash_h'] # From test 06 + hash_H = self.state['exp_hash_h'] # From test 06 print(f"Starting from state H: {hash_H[:16]}...") @@ -962,7 +962,7 @@ def test_07_change_data_from_test06(self): uid, in_uid = dfm.get_df_view().index[idx] rows.append({ SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation f"{SampleStatsEx.TAG.value}:discard_25pct": True, SampleStatsEx.DISCARDED.value: True }) @@ -1004,7 +1004,7 @@ def test_08_reload_before_data_change_verify_and_modify(self): print("TEST 08: Reload Before Data Change - Verify and Modify Model") print(f"{'='*80}\n") - hash_c = self.state['exp_hash_c'] # Before data change (after HP change) + hash_c = self.state['exp_hash_c'] # Before data change (after HP change) print(f"Part A: Reloading state C and verifying training reproducibility...") print(f"Reloading state C: {hash_c[:16]}...") @@ -1068,7 +1068,7 @@ def test_09_reload_before_hp_change_verify_and_modify(self): print("TEST 09: Reload Before HP Change - Verify and Fix Everything") print(f"{'='*80}\n") - hash_b = self.state['exp_hash_b'] # Before HP change (after model change) + hash_b = self.state['exp_hash_b'] # Before HP change (after model change) loss_b = self.state['losses_b'] print(f"Part A: Reloading state B and verifying training reproducibility...") @@ -1100,11 +1100,11 @@ def test_09_reload_before_hp_change_verify_and_modify(self): # Fix HP hp = ledgers.get_hyperparams() - hp['data']['train_loader']['batch_size'] = 7 # Change batch size + hp['data']['train_loader']['batch_size'] = 7 # Change batch size # Fix model model = ledgers.get_model() - # # model.operate(0, {-3}, 1) # Further modify conv1 + # # model.operate(0, {-3}, 1) # Further modify conv1 # model.operate(-1, {-1 }, 4) # Fix data - discard 5 samples @@ -1115,7 +1115,7 @@ def test_09_reload_before_hp_change_verify_and_modify(self): uid, in_uid = dfm.get_df_view().index[idx] rows.append({ SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation f"{SampleStatsEx.TAG.value}:discard_fix": True, SampleStatsEx.DISCARDED.value: True }) @@ -1156,7 +1156,7 @@ def test_10_reload_branch_j_verify_reproducibility(self): print("TEST 10: Reload Branch J - Verify Training Reproducibility") print(f"{'='*80}\n") - hash_j = self.state['exp_hash_j'] # From test 08.b + hash_j = self.state['exp_hash_j'] # From test 08.b print(f"Reloading branch J: {hash_j[:16]}...") @@ -1191,7 +1191,7 @@ def test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): print(f"{'='*80}\n") # Reference variables - target_hash = self.state['exp_hash_d'] # Target is branch_d + target_hash = self.state['exp_hash_d'] # Target is branch_d print(f"Simulating fresh restart: loading everything from config...") print(f"Target state: {target_hash[:16]} (branch_d)") @@ -1215,7 +1215,7 @@ def test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): print("[OK] Hyperparameters re-registered") # Create fresh model - model_restarted = SimpleCNN(conv1_out=8, conv2_out=16) # Match branch_d architecture + model_restarted = SimpleCNN(conv1_out=8, conv2_out=16) # Match branch_d architecture # # Model arch. and weights are updated at the init of model interface model_restarted = register_in_ledger(model_restarted, flag="model", device=DEVICE) @@ -1257,7 +1257,7 @@ def test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): all_hashes = self.chkpt_manager.get_all_hashes(sort_by='created') print(f"\n[OK] Found {len(all_hashes)} experiment states:") for i, entry in enumerate(all_hashes): - print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") + print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") # Reload state B (second state created) hash_a_from_manifest = self.state['exp_hash_a'] @@ -1271,18 +1271,18 @@ def test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): print(f"[OK] Checkpoint loaded to reach target state {target_hash[:16]}") print("\nTraining for 11 epochs to verify reproducibility...") pause_controller.resume() - model_restarted = ledgers.get_model() # Get model after loading state + model_restarted = ledgers.get_model() # Get model after loading state _, _ = self.train_epochs(model_restarted, dataloader, optimizer_restarted, criterion, num_epochs=self.config['training']['num_epochs'], criterion_bin=criterion_bin) pause_controller.pause() # # Check reproducibility with original loss and UIDs # self.assertEqual(model_restarted.layers[-1].operation_age['FREEZE'], 1, - # "Model architecture should match state in D") + # "Model architecture should match state in D") # self.assertEqual(model_restarted.layers[-1].operation_age['RESET'], 1, - # "Model architecture should match state in D") + # "Model architecture should match state in D") # self.assertEqual(model_restarted.layers[0].out_neurons, 12, - # "Model architecture should match state in D") + # "Model architecture should match state in D") # Not possible as data are generated randomly without reproducibility now # self.check_reproducibility(loss_d_original, loss_d_verify, originals_uids, None, loss_tol=1e-1) diff --git a/weightslab/tests/components/test_global_monitoring_unit.py b/weightslab/tests/components/test_global_monitoring_unit.py index 07650885..ee88c790 100644 --- a/weightslab/tests/components/test_global_monitoring_unit.py +++ b/weightslab/tests/components/test_global_monitoring_unit.py @@ -38,31 +38,31 @@ def test_contextvar_set_and_restore(self): self.assertIn(get_current_context(), {Context.UNKNOWN, Context.TESTING, Context.TRAINING}) # def test_guard_context_training_non_audit(self): - # model = _DummyModel() - # gc = GuardContext(for_training=True) - # gc.model = model + # model = _DummyModel() + # gc = GuardContext(for_training=True) + # gc.model = model - # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ - # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value=None), \ - # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={}): - # gc.__enter__() - # self.assertEqual(get_current_context(), Context.TRAINING) - # self.assertIn(True, model.train_calls) - # result = gc.__exit__(None, None, None) + # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ + # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value=None), \ + # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={}): + # gc.__enter__() + # self.assertEqual(get_current_context(), Context.TRAINING) + # self.assertIn(True, model.train_calls) + # result = gc.__exit__(None, None, None) - # self.assertFalse(result) + # self.assertFalse(result) # def test_guard_context_training_audit_uses_eval(self): - # model = _DummyModel() - # gc = GuardContext(for_training=True) - # gc.model = model - - # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ - # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value="hp"), \ - # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={"auditorMode": True}): - # gc.__enter__() - # self.assertEqual(model.eval_calls, 1) - # gc.__exit__(None, None, None) + # model = _DummyModel() + # gc = GuardContext(for_training=True) + # gc.model = model + + # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ + # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value="hp"), \ + # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={"auditorMode": True}): + # gc.__enter__() + # self.assertEqual(model.eval_calls, 1) + # gc.__exit__(None, None, None) def test_guard_context_suppresses_runtime_error(self): gc = GuardContext(for_training=False) diff --git a/weightslab/tests/data/test_data_samples_with_ops.py b/weightslab/tests/data/test_data_samples_with_ops.py index 93700647..702cd1f8 100644 --- a/weightslab/tests/data/test_data_samples_with_ops.py +++ b/weightslab/tests/data/test_data_samples_with_ops.py @@ -46,7 +46,7 @@ def __getitem__(self, idx): # Return random data with shape (3, 32, 32) to simulate images data = np.random.randn(3, 32, 32).astype(np.float32) uid = str(idx) # Consistent with string UID preference - label = idx % 10 # Simulate 10 classes + label = idx % 10 # Simulate 10 classes return data, uid, label @@ -214,7 +214,7 @@ def test_getitem_returns_data_and_id(self): # Should return tuple with (data, id, target, ...) self.assertIsInstance(result, tuple) - self.assertGreaterEqual(len(result), 3) # data, id, target at minimum + self.assertGreaterEqual(len(result), 3) # data, id, target at minimum # First element should be numpy array or tensor self.assertTrue(isinstance(result[0], (np.ndarray, torch.Tensor))) @@ -315,7 +315,7 @@ def test_binary_tag_labeling_single_tag(self): """ self.temp_dir = tempfile.mkdtemp() - tags_mapping = {"target_tag": 1} # Binary: only 1 tag in mapping + tags_mapping = {"target_tag": 1} # Binary: only 1 tag in mapping wrapper = DataSampleTrackingWrapper( wrapped_dataset=self.dataset, root_log_dir=self.temp_dir, diff --git a/weightslab/tests/data/test_dataframe_manager_unit.py b/weightslab/tests/data/test_dataframe_manager_unit.py index 7329d7fb..4fd4b74c 100644 --- a/weightslab/tests/data/test_dataframe_manager_unit.py +++ b/weightslab/tests/data/test_dataframe_manager_unit.py @@ -99,7 +99,7 @@ def test_enqueue_instance_batch_buffers_records(self): self.assertEqual(rec["signal:bbox_loss"], 0.2) self.assertEqual(rec[SampleStats.Ex.LAST_SEEN.value], 3) self.assertIn(SampleStats.Ex.TARGET.value, rec) - self.assertTrue(mgr._df.empty) # df untouched until flush + self.assertTrue(mgr._df.empty) # df untouched until flush def test_flush_applies_instance_records(self): """Flushing instance records writes per-(sample_id, annotation_id) values.""" @@ -119,7 +119,7 @@ def test_flush_applies_instance_records(self): mgr.flush() result = mgr.get_df_view() - self.assertEqual(len(result), 4) # sample row (0) + 3 instance rows + self.assertEqual(len(result), 4) # sample row (0) + 3 instance rows self.assertAlmostEqual(float(result.loc[("1", 1), "signal:il"]), 0.5) self.assertAlmostEqual(float(result.loc[("1", 2), "signal:il"]), 0.6) self.assertAlmostEqual(float(result.loc[("1", 3), "signal:il"]), 0.7) @@ -178,9 +178,9 @@ def test_multi_instance_expansion(self): # Single sample with 3 instances (detections/annotations) # Use list of arrays to indicate multiple instances target = [ - np.array([10, 20, 30, 40]), # instance 0 - np.array([50, 60, 70, 80]), # instance 1 - np.array([90, 100, 110, 120]) # instance 2 + np.array([10, 20, 30, 40]), # instance 0 + np.array([50, 60, 70, 80]), # instance 1 + np.array([90, 100, 110, 120]) # instance 2 ] df = pd.DataFrame([{ "sample_id": 1, @@ -296,7 +296,7 @@ def test_categorical_memory_optimization(self): self.assertTrue((result_df["metadata"] == "urban").sum() > 0) # Memory usage comparison - # original_bytes = 100 * (len("train") + len("urban")) # Rough estimate + # original_bytes = 100 * (len("train") + len("urban")) # Rough estimate # With categorical: ~100 bytes for codes + ~40 bytes for categories = ~140 bytes # Real compression achieved by pandas @@ -357,7 +357,7 @@ def test_per_sample_buffer_into_multi_index_does_not_corrupt(self): losses={"signals//train/clsf_sample": np.array([0.99])}, step=11, ) - mgr.flush() # Would raise if bug regressed + mgr.flush() # Would raise if bug regressed result = mgr.get_df_view() self.assertAlmostEqual(result.loc[("1", 0), col], 0.99) @@ -385,7 +385,7 @@ def get_index_from_sample_id(self, sid): "origin": "train", SampleStats.Ex.TARGET.value: np.zeros((30, 30), dtype=np.float32), }) - row.name = ("12", 0) # MultiIndex-style row.name + row.name = ("12", 0) # MultiIndex-style row.name # Should not raise and should pass just the sample_id, not the tuple mgr._normalize_arrays_for_storage(row) diff --git a/weightslab/tests/data/test_flush_pipeline.py b/weightslab/tests/data/test_flush_pipeline.py index 0398876b..a6001c5b 100644 --- a/weightslab/tests/data/test_flush_pipeline.py +++ b/weightslab/tests/data/test_flush_pipeline.py @@ -25,7 +25,7 @@ def _make_mgr(flush_max_rows=4, enable_flushing_threads=False) -> LedgeredDataFrameManager: mgr = LedgeredDataFrameManager( - flush_interval=60.0, # disable periodic timer during tests + flush_interval=60.0, # disable periodic timer during tests flush_max_rows=flush_max_rows, enable_flushing_threads=enable_flushing_threads, enable_h5_persistence=False, @@ -125,7 +125,7 @@ class TestFlushAsyncReturnsAfterBufferDrain(unittest.TestCase): """ def test_flush_async_does_not_wait_for_h5(self): - H5_WRITE_DELAY = 1.0 # seconds — intentionally slow + H5_WRITE_DELAY = 1.0 # seconds — intentionally slow mgr = _make_mgr(flush_max_rows=4, enable_flushing_threads=True) @@ -166,7 +166,7 @@ class TestBufferRefillDuringH5Write(unittest.TestCase): def test_training_resumes_after_second_drain(self): FLUSH_MAX = 4 - H5_WRITE_DELAY = 0.3 # seconds + H5_WRITE_DELAY = 0.3 # seconds mgr = _make_mgr(flush_max_rows=FLUSH_MAX, enable_flushing_threads=True) @@ -181,9 +181,9 @@ def counting_slow_h5(*args, **kwargs): second_enqueue_returned = threading.Event() def training_sim(): - _enqueue(mgr, [str(i) for i in range(FLUSH_MAX)]) # fills buffer, triggers flush - time.sleep(0.05) # let flush thread start H5 write - _enqueue(mgr, [str(i) for i in range(FLUSH_MAX, FLUSH_MAX * 2)]) # refill + _enqueue(mgr, [str(i) for i in range(FLUSH_MAX)]) # fills buffer, triggers flush + time.sleep(0.05) # let flush thread start H5 write + _enqueue(mgr, [str(i) for i in range(FLUSH_MAX, FLUSH_MAX * 2)]) # refill second_enqueue_returned.set() with patch.object(mgr, "_flush_to_h5_if_needed", side_effect=counting_slow_h5): diff --git a/weightslab/tests/data/test_h5_array_store.py b/weightslab/tests/data/test_h5_array_store.py index 0307b267..5e05bc47 100644 --- a/weightslab/tests/data/test_h5_array_store.py +++ b/weightslab/tests/data/test_h5_array_store.py @@ -210,7 +210,7 @@ def test_clean_write_leaves_no_temp_or_backup(self): def test_recover_safe_on_empty_directory(self): """recover() must not raise when arrays.h5 does not exist yet.""" store = self._make_store() - store.recover() # Should complete without error + store.recover() # Should complete without error if __name__ == "__main__": diff --git a/weightslab/tests/data/test_h5_dataframe_store.py b/weightslab/tests/data/test_h5_dataframe_store.py index 94f38e52..3c6457a9 100644 --- a/weightslab/tests/data/test_h5_dataframe_store.py +++ b/weightslab/tests/data/test_h5_dataframe_store.py @@ -108,8 +108,8 @@ def test_categorical_tags_preservation(self): df = pd.DataFrame({ 'sample_id': [1, 2, 3], 'brightness': [0.75, 0.82, 0.65], - 'tag:quality': ['high', 'low', 'high'], # String tag - 'tag:outdoor': [True, False, True], # Boolean tag + 'tag:quality': ['high', 'low', 'high'], # String tag + 'tag:outdoor': [True, False, True], # Boolean tag }).set_index('sample_id') # Write (should optimize to categorical) @@ -194,7 +194,7 @@ def test_upsert_merge_multi_index(self): # Update with new data for same sample but different annotation df2 = pd.DataFrame({ - 'brightness': [0.80], # Update brightness for annotation 1 + 'brightness': [0.80], # Update brightness for annotation 1 'iou': [0.60], }) df2.index = pd.MultiIndex.from_arrays( diff --git a/weightslab/tests/data/test_point_cloud_utils.py b/weightslab/tests/data/test_point_cloud_utils.py index 8dac4a6b..62243a9d 100644 --- a/weightslab/tests/data/test_point_cloud_utils.py +++ b/weightslab/tests/data/test_point_cloud_utils.py @@ -65,7 +65,7 @@ def test_is_point_cloud_task(): def test_is_point_cloud_detection_task(): assert is_point_cloud_detection_task("detection_pointcloud") assert is_point_cloud_detection_task("Detection_PointCloud") - assert is_point_cloud_detection_task("detection_3d") # legacy alias + assert is_point_cloud_detection_task("detection_3d") # legacy alias assert not is_point_cloud_detection_task("detection") assert not is_point_cloud_detection_task("segmentation") assert not is_point_cloud_detection_task(None) @@ -77,10 +77,10 @@ def test_looks_like_point_cloud(): assert looks_like_point_cloud(_cloud()[:, :2]) # Multi-channel clouds (xyz + intensity + normals + rgb = 10 cols) qualify. assert looks_like_point_cloud(np.zeros((100, 10), np.float32)) - assert not looks_like_point_cloud(_cloud()[:8]) # too few rows - assert not looks_like_point_cloud(np.zeros((100, 20), np.float32)) # too many cols - assert not looks_like_point_cloud(np.zeros((64, 64), np.uint8)) # int image - assert not looks_like_point_cloud(np.zeros((64, 64, 3), np.float32)) # 3D array + assert not looks_like_point_cloud(_cloud()[:8]) # too few rows + assert not looks_like_point_cloud(np.zeros((100, 20), np.float32)) # too many cols + assert not looks_like_point_cloud(np.zeros((64, 64), np.uint8)) # int image + assert not looks_like_point_cloud(np.zeros((64, 64, 3), np.float32)) # 3D array def test_point_distances(): @@ -97,7 +97,7 @@ def test_compute_point_normals_planar(): normals = compute_point_normals(pts, k=12) assert normals.shape == (500, 3) np.testing.assert_allclose(np.linalg.norm(normals, axis=1), 1.0, atol=1e-4) - assert np.abs(normals[:, 2]).mean() > 0.95 # mostly aligned with z + assert np.abs(normals[:, 2]).mean() > 0.95 # mostly aligned with z def test_voxel_downsample_reduces_points(): @@ -106,12 +106,12 @@ def test_voxel_downsample_reduces_points(): out = voxel_downsample(pts, voxel_size=0.25) assert out.shape[1] == 4 assert out.shape[0] < pts.shape[0] - assert out.shape[0] <= 4 ** 3 # at most one point per 0.25 voxel in the unit cube + assert out.shape[0] <= 4 ** 3 # at most one point per 0.25 voxel in the unit cube def test_colorize_from_image(): image = np.zeros((10, 20, 3), np.uint8) - image[:, :, 0] = 255 # all red + image[:, :, 0] = 255 # all red pts = np.array([[1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], np.float32) def project(p): @@ -119,15 +119,15 @@ def project(p): return uv, np.array([True, False]) rgb = colorize_from_image(pts, image, project) - np.testing.assert_allclose(rgb[0], [1.0, 0.0, 0.0], atol=1e-5) # sampled red - np.testing.assert_allclose(rgb[1], [0.5, 0.5, 0.5], atol=1e-5) # invalid -> grey + np.testing.assert_allclose(rgb[0], [1.0, 0.0, 0.0], atol=1e-5) # sampled red + np.testing.assert_allclose(rgb[1], [0.5, 0.5, 0.5], atol=1e-5) # invalid -> grey def test_range_image_shape(): img = point_cloud_to_range_image(_cloud(2000), image_height=48, image_width=256) assert img.size == (256, 48) arr = np.asarray(img) - assert (arr != arr[0, 0]).any() # some points were projected + assert (arr != arr[0, 0]).any() # some points were projected def test_get_point_feature_names_from_dataset_and_default(): @@ -151,7 +151,7 @@ def my_thumb(points): img = render_thumbnail_2d_for_dataset(object(), _cloud()) assert marker["called"] and np.asarray(img)[0, 0, 0] == 9 finally: - register_thumbnail_fn(None) # reset global state + register_thumbnail_fn(None) # reset global state def my_boxes(boxes): return np.zeros((len(boxes), 6), np.float32) @@ -167,7 +167,7 @@ def my_boxes(boxes): def test_filter_valid_points_drops_pads_and_nonfinite(): pts = _cloud(100) - pts[10] = -1000.0 # pad row (all coords at PAD_VALUE) + pts[10] = -1000.0 # pad row (all coords at PAD_VALUE) pts[20, 2] = np.nan out = filter_valid_points(pts) assert out.shape[0] == 98 @@ -234,7 +234,7 @@ def test_project_boxes_min_size_clamp(): def test_project_boxes_2d_rows(): - boxes = np.array([[10.0, 5.0, 2.0, 2.0, 2.0, 0.7]], np.float32) # cx,cy,dx,dy,cls,conf + boxes = np.array([[10.0, 5.0, 2.0, 2.0, 2.0, 0.7]], np.float32) # cx,cy,dx,dy,cls,conf assert boxes_dimensionality(boxes) == 2 bev = project_boxes_to_bev(boxes, PC_RANGE, 0.0) assert bev[0, 4] == 2.0 @@ -311,7 +311,7 @@ def get_items(self, idx, include_metadata=False, include_labels=False, include_i np_img, is_volumetric, shape, pil = load_raw_image_array(PcDataset(), 0) assert not is_volumetric assert pil is not None and pil.mode == "RGB" - assert pil.size[0] == pil.size[1] # square BEV render + assert pil.size[0] == pil.size[1] # square BEV render assert np_img.ndim == 3 and np_img.shape[2] == 3 diff --git a/weightslab/tests/gRPC/test_get_point_cloud.py b/weightslab/tests/gRPC/test_get_point_cloud.py index 035fe319..e74f6629 100644 --- a/weightslab/tests/gRPC/test_get_point_cloud.py +++ b/weightslab/tests/gRPC/test_get_point_cloud.py @@ -110,7 +110,7 @@ def test_point_cloud_chunk_bytes_invalid_falls_back(monkeypatch): def test_get_point_cloud_honours_configured_chunk_size(): """A smaller chunk size splits the same cloud into more (correct) messages.""" class _SmallChunkService(_StubService): - _POINT_CLOUD_CHUNK_BYTES = 4096 # bytes + _POINT_CLOUD_CHUNK_BYTES = 4096 # bytes stub = _SmallChunkService(_FakeLidarDataset()) chunks = _collect(stub, pb2.PointCloudRequest(sample_id="7", origin="train_loader")) diff --git a/weightslab/tests/gRPC/test_grpc_tag_operations.py b/weightslab/tests/gRPC/test_grpc_tag_operations.py index 0f94643a..b3edc63e 100644 --- a/weightslab/tests/gRPC/test_grpc_tag_operations.py +++ b/weightslab/tests/gRPC/test_grpc_tag_operations.py @@ -164,7 +164,7 @@ def setUpClass(cls): flag="model", device=DEVICE, skip_previous_auto_load=True, - compute_dependencies=False, # dependency analysis is currently disabled + compute_dependencies=False, # dependency analysis is currently disabled ) # Register dataloader @@ -269,7 +269,7 @@ def test_01_add_tags_accumulate(self): response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to add tags: {response.message}") - print(f"✓ Successfully added tag 'test_tag' to 10 samples") + print(f" Successfully added tag 'test_tag' to 10 samples") # Verify tags were added by checking the dataframe df = self.data_service._all_datasets_df @@ -300,7 +300,7 @@ def test_01_add_tags_accumulate(self): self.assertTrue(value, f"Sample {sample_id} should have tag 'test_tag'") - print(f"✓ Verified tag column exists and has correct values") + print(f" Verified tag column exists and has correct values") def test_02_add_multiple_tags(self): """Test adding multiple different tags""" @@ -319,7 +319,7 @@ def test_02_add_multiple_tags(self): response1 = self.data_service.EditDataSample(request1, self.mock_context) self.assertTrue(response1.success) - print(f"✓ Added tag 'difficult' to samples 0-4") + print(f" Added tag 'difficult' to samples 0-4") # Add "outlier" tag to samples 5-9 request2 = pb2.DataEditsRequest( @@ -334,13 +334,13 @@ def test_02_add_multiple_tags(self): response2 = self.data_service.EditDataSample(request2, self.mock_context) self.assertTrue(response2.success) - print(f"✓ Added tag 'outlier' to samples 5-9") + print(f" Added tag 'outlier' to samples 5-9") # Verify both tags exist df = self.data_service._all_datasets_df self.assertIn("tag:difficult", df.columns) self.assertIn("tag:outlier", df.columns) - print(f"✓ Both tag columns exist in dataframe") + print(f" Both tag columns exist in dataframe") def test_03_remove_tag_from_samples(self): """Test removing a tag from specific samples using EDIT_REMOVE""" @@ -359,7 +359,7 @@ def test_03_remove_tag_from_samples(self): response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to remove tag: {response.message}") - print(f"✓ Removed tag 'test_tag' from samples 0-4") + print(f" Removed tag 'test_tag' from samples 0-4") # Verify tag was removed from those samples df = self.data_service._all_datasets_df @@ -408,7 +408,7 @@ def test_03_remove_tag_from_samples(self): self.assertTrue(value, f"Sample {sample_id} should still have tag 'test_tag'") - print(f"✓ Verified tag removal worked correctly") + print(f" Verified tag removal worked correctly") def test_04_delete_entire_tag_column(self): """Test deleting an entire tag column using EDIT_REMOVE with value=-1""" @@ -417,22 +417,22 @@ def test_04_delete_entire_tag_column(self): # Delete the "difficult" tag column completely request = pb2.DataEditsRequest( stat_name="tag:difficult", - float_value=-1, # Signal for column deletion + float_value=-1, # Signal for column deletion string_value="", bool_value=False, type=SampleEditType.EDIT_REMOVE, - samples_ids=["0"], # Just need one sample as reference + samples_ids=["0"], # Just need one sample as reference sample_origins=["test"] ) response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to delete tag column: {response.message}") - print(f"✓ Deleted entire 'difficult' tag column") + print(f" Deleted entire 'difficult' tag column") # Verify column no longer exists df = self.data_service._all_datasets_df self.assertNotIn("tag:difficult", df.columns, "Tag column should be deleted") - print(f"✓ Verified tag column no longer exists in dataframe") + print(f" Verified tag column no longer exists in dataframe") def test_05_deny_listed_operations(self): """Test discarded (discard/restore) operations""" @@ -455,7 +455,7 @@ def test_05_deny_listed_operations(self): stat_name=SampleStatsEx.DISCARDED.value, float_value=0, string_value="", - bool_value=True, # True = discarded + bool_value=True, # True = discarded type=SampleEditType.EDIT_OVERRIDE, samples_ids=sample_ids, sample_origins=origins @@ -463,7 +463,7 @@ def test_05_deny_listed_operations(self): response = self.data_service.EditDataSample(request_discard, self.mock_context) self.assertTrue(response.success, f"Failed to discard samples: {response.message}") - print(f"✓ Marked samples 10-14 as discarded") + print(f" Marked samples 10-14 as discarded") # Verify samples are marked as discarded df = self.data_service._all_datasets_df @@ -488,14 +488,14 @@ def test_05_deny_listed_operations(self): self.assertTrue(value, f"Sample {sample_id} should be discarded") - print(f"✓ Verified samples are discarded") + print(f" Verified samples are discarded") # Now restore samples 10-12 request_restore = pb2.DataEditsRequest( stat_name=SampleStatsEx.DISCARDED.value, float_value=0, string_value="", - bool_value=False, # False = restored + bool_value=False, # False = restored type=SampleEditType.EDIT_OVERRIDE, samples_ids=[str(i) for i in range(10, 13)], sample_origins=["test"] * 3 @@ -503,7 +503,7 @@ def test_05_deny_listed_operations(self): response = self.data_service.EditDataSample(request_restore, self.mock_context) self.assertTrue(response.success, f"Failed to restore samples: {response.message}") - print(f"✓ Restored samples 10-12") + print(f" Restored samples 10-12") # Verify restoration df = self.data_service._all_datasets_df @@ -550,7 +550,7 @@ def test_05_deny_listed_operations(self): self.assertTrue(value, f"Sample {sample_id} should still be discarded") - print(f"✓ Verified restoration worked correctly") + print(f" Verified restoration worked correctly") def test_06_batch_tag_operations(self): """Test batch operations on many samples at once""" @@ -583,7 +583,7 @@ def test_06_batch_tag_operations(self): response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to add batch tag: {response.message}") - print(f"✓ Added 'batch_tag' to 50 samples in one operation") + print(f" Added 'batch_tag' to 50 samples in one operation") # Verify all 50 samples have the tag df = self.data_service._all_datasets_df @@ -614,7 +614,7 @@ def test_06_batch_tag_operations(self): success_count += 1 self.assertGreaterEqual(success_count, 45, f"Expected at least 45 samples to have batch_tag, got {success_count}") - print(f"✓ Verified {success_count}/50 samples have the batch tag") + print(f" Verified {success_count}/50 samples have the batch tag") def test_07_tag_persistence(self): """Test that tags persist and can be queried""" @@ -624,14 +624,14 @@ def test_07_tag_persistence(self): # Count tag columns tag_columns = [col for col in df.columns if col.startswith(f"{SampleStatsEx.TAG.value}:")] - print(f"✓ Found {len(tag_columns)} tag columns: {tag_columns}") + print(f" Found {len(tag_columns)} tag columns: {tag_columns}") self.assertGreater(len(tag_columns), 0, "Should have at least one tag column") # Verify we can query tagged samples for tag_col in tag_columns: tagged_samples = df[df[tag_col] == True] - print(f" - {tag_col}: {len(tagged_samples)} samples") + print(f" - {tag_col}: {len(tagged_samples)} samples") self.assertGreaterEqual(len(tagged_samples), 0) @@ -661,14 +661,14 @@ def run_tests(): print("="*80 + "\n") if result.wasSuccessful(): - print("✅ ALL TESTS PASSED!") + print(" ALL TESTS PASSED!") else: - print("❌ SOME TESTS FAILED") + print(" SOME TESTS FAILED") if result.failures: print("\nFailures:") for test, traceback in result.failures: - print(f" - {test}: {traceback}") + print(f" - {test}: {traceback}") if result.errors: print("\nErrors:") for test, traceback in result.errors: - print(f" - {test}: {traceback}") + print(f" - {test}: {traceback}") diff --git a/weightslab/tests/gRPC/test_grpc_user_actions.py b/weightslab/tests/gRPC/test_grpc_user_actions.py index cd44c8dd..1b71da74 100644 --- a/weightslab/tests/gRPC/test_grpc_user_actions.py +++ b/weightslab/tests/gRPC/test_grpc_user_actions.py @@ -280,7 +280,7 @@ def _make_real_data_service(self): # the first call proceeds (mirrors DataService.__init__). ds._update_done = threading.Event() ds._update_done.set() - ds._refresh_in_flight = threading.Lock() # mirrors __init__: bg view-refresh guard + ds._refresh_in_flight = threading.Lock() # mirrors __init__: bg view-refresh guard ds._df_manager = df_manager ds._all_datasets_df = df.copy() ds._compute_natural_sort = False @@ -561,7 +561,7 @@ def _agg(graph_name, sample_ids=None, exp_hash=None): # Only sample 11 is 'hard'-tagged → mean curve over {11} = one aggregated point. self.assertEqual(len(response.points), 1) - self.assertEqual(response.points[0].sample_id, "") # aggregated mean curve + self.assertEqual(response.points[0].sample_id, "") # aggregated mean curve self.assertEqual(response.points[0].metric_name, "test/loss") self.assertAlmostEqual(response.points[0].metric_value, 0.2, places=5) diff --git a/weightslab/tests/general/test_cli.py b/weightslab/tests/general/test_cli.py index 0d34322c..c1d2a9dc 100644 --- a/weightslab/tests/general/test_cli.py +++ b/weightslab/tests/general/test_cli.py @@ -112,7 +112,7 @@ def test_empty_command(self): """Test that empty command returns ok.""" result = _handle_command('') self.assertTrue(result['ok']) - result = _handle_command(' ') + result = _handle_command(' ') self.assertTrue(result['ok']) def test_unknown_command(self): @@ -183,7 +183,7 @@ def test_plot_model_with_model(self): """Test plot_model with registered model.""" # Create a mock model with __str__ method mock_model = MagicMock() - mock_model.__str__ = MagicMock(return_value="Model(\n Layer1\n Layer2\n)") + mock_model.__str__ = MagicMock(return_value="Model(\n Layer1\n Layer2\n)") GLOBAL_LEDGER.register_model(mock_model, name='test_model') @@ -339,90 +339,90 @@ def test_cli_serve_port_binding(self): # TODO (GP): Fix CLI initialization takes too long for integration tests - need to ensure server is fully ready before client tests run, and possibly optimize server startup time for testing purposes # Not working yet - needs check first initialization and teardown of server between tests, and some tweaks to client connection logic to ensure it waits for server to be ready before connecting # class TestCLIIntegration(unittest.TestCase): -# """Integration tests for CLI server-client communication.""" - -# @classmethod -# def setUpClass(cls): -# """Start CLI server for integration tests.""" -# cls.server_info = cli_serve(cli_host='127.0.0.1', cli_port=0, spawn_client=False) -# if not cls.server_info['ok']: -# raise RuntimeError("Failed to start CLI server for integration tests") -# time.sleep(0.2) # Give server time to fully start - -# @classmethod -# def tearDownClass(cls): -# """Stop CLI server after integration tests.""" -# global _server_sock -# if _server_sock: -# try: -# _server_sock.close() -# except Exception: -# pass - -# def _send_command(self, cmd: str) -> dict: -# """Helper to send command to server and get response.""" -# sock = socket.create_connection( -# (self.server_info['host'], self.server_info['port']), -# timeout=5 -# ) -# f = sock.makefile('rwb') - -# # Send command -# f.write((cmd + '\n').encode('utf8')) -# f.flush() - -# # Read response -# response_line = f.readline() -# response = json.loads(response_line.decode('utf8')) - -# f.close() -# sock.close() - -# return response - -# def test_integration_help(self): -# """Test help command through server.""" -# response = self._send_command('help') -# self.assertTrue(response['ok']) -# self.assertIn('commands', response) - -# def test_integration_status(self): -# """Test status command through server.""" -# response = self._send_command('status') -# self.assertTrue(response['ok']) -# self.assertIn('snapshot', response) - -# def test_integration_list_models(self): -# """Test list_models through server.""" -# response = self._send_command('list_models') -# self.assertTrue(response['ok']) -# self.assertIn('models', response) - -# def test_integration_unknown_command(self): -# """Test unknown command through server.""" -# response = self._send_command('invalid_command_xyz') -# self.assertFalse(response['ok']) -# self.assertIn('error', response) - -# def test_integration_quit(self): -# """Test quit command closes connection.""" -# sock = socket.create_connection( -# (self.server_info['host'], self.server_info['port']), -# timeout=5 -# ) -# f = sock.makefile('rwb') - -# # Send quit -# f.write(b'quit\n') -# f.flush() - -# # Read goodbye -# response = json.loads(f.readline().decode('utf8')) -# self.assertTrue(response['ok']) -# self.assertTrue(response.get('bye')) - -# f.close() -# sock.close() +# """Integration tests for CLI server-client communication.""" + +# @classmethod +# def setUpClass(cls): +# """Start CLI server for integration tests.""" +# cls.server_info = cli_serve(cli_host='127.0.0.1', cli_port=0, spawn_client=False) +# if not cls.server_info['ok']: +# raise RuntimeError("Failed to start CLI server for integration tests") +# time.sleep(0.2) # Give server time to fully start + +# @classmethod +# def tearDownClass(cls): +# """Stop CLI server after integration tests.""" +# global _server_sock +# if _server_sock: +# try: +# _server_sock.close() +# except Exception: +# pass + +# def _send_command(self, cmd: str) -> dict: +# """Helper to send command to server and get response.""" +# sock = socket.create_connection( +# (self.server_info['host'], self.server_info['port']), +# timeout=5 +# ) +# f = sock.makefile('rwb') + +# # Send command +# f.write((cmd + '\n').encode('utf8')) +# f.flush() + +# # Read response +# response_line = f.readline() +# response = json.loads(response_line.decode('utf8')) + +# f.close() +# sock.close() + +# return response + +# def test_integration_help(self): +# """Test help command through server.""" +# response = self._send_command('help') +# self.assertTrue(response['ok']) +# self.assertIn('commands', response) + +# def test_integration_status(self): +# """Test status command through server.""" +# response = self._send_command('status') +# self.assertTrue(response['ok']) +# self.assertIn('snapshot', response) + +# def test_integration_list_models(self): +# """Test list_models through server.""" +# response = self._send_command('list_models') +# self.assertTrue(response['ok']) +# self.assertIn('models', response) + +# def test_integration_unknown_command(self): +# """Test unknown command through server.""" +# response = self._send_command('invalid_command_xyz') +# self.assertFalse(response['ok']) +# self.assertIn('error', response) + +# def test_integration_quit(self): +# """Test quit command closes connection.""" +# sock = socket.create_connection( +# (self.server_info['host'], self.server_info['port']), +# timeout=5 +# ) +# f = sock.makefile('rwb') + +# # Send quit +# f.write(b'quit\n') +# f.flush() + +# # Read goodbye +# response = json.loads(f.readline().decode('utf8')) +# self.assertTrue(response['ok']) +# self.assertTrue(response.get('bye')) + +# f.close() +# sock.close() def run_tests(): diff --git a/weightslab/tests/general/test_signals.py b/weightslab/tests/general/test_signals.py index 4c6e5893..0d32127d 100644 --- a/weightslab/tests/general/test_signals.py +++ b/weightslab/tests/general/test_signals.py @@ -57,7 +57,7 @@ def test_signals_with_list_batch_ids(self): mock_gm.return_value = mock_model with patch("weightslab.src.DATAFRAME_M", mock_df): - batch_ids = [20, 21, 22] # list instead of tensor + batch_ids = [20, 21, 22] # list instead of tensor signals = {"loss": 0.3} wl.save_signals( @@ -84,7 +84,7 @@ def test_signals_with_scalar_values(self): with patch("weightslab.src.DATAFRAME_M", mock_df): batch_ids = torch.tensor([30, 31]) signals = { - "loss": 0.25, # scalar float + "loss": 0.25, # scalar float "accuracy": 0.95, "f1": np.float32(0.92) } @@ -462,8 +462,8 @@ def test_save_signals_batch_processing(self, mock_gm, mock_get_dataframe): wl.save_signals( signals={"det_loss": torch.tensor(0.2)}, batch_ids=torch.tensor([6, 7]), - preds=torch.rand((2, 5, 4)), # 5 boxes, 4 coords - targets=torch.rand((2, 4, 4)), # 4 boxes, 4 coords + preds=torch.rand((2, 5, 4)), # 5 boxes, 4 coords + targets=torch.rand((2, 4, 4)), # 4 boxes, 4 coords log=True ) @@ -514,7 +514,7 @@ def test_signals_with_none_batch_ids(self): wl.save_signals( signals=signals, batch_ids=None, - log=False # Don't log without IDs + log=False # Don't log without IDs ) def test_signal_with_mixed_data_types(self): @@ -680,11 +680,11 @@ def test_detection_signals_with_variable_boxes(self, mock_gm, mock_get_dataframe # Variable number of boxes: img1 has 3, img2 has 1, img3 has 5, img4 has 2 preds = [ - torch.tensor([[10, 20, 110, 120], [50, 60, 150, 160], [200, 210, 300, 310]]), # 3 boxes - torch.tensor([[15, 25, 115, 125]]), # 1 box + torch.tensor([[10, 20, 110, 120], [50, 60, 150, 160], [200, 210, 300, 310]]), # 3 boxes + torch.tensor([[15, 25, 115, 125]]), # 1 box torch.tensor([[30, 40, 130, 140], [70, 80, 170, 180], [250, 260, 350, 360], - [100, 110, 200, 210], [180, 190, 280, 290]]), # 5 boxes - torch.tensor([[45, 55, 145, 155], [220, 230, 320, 330]]) # 2 boxes + [100, 110, 200, 210], [180, 190, 280, 290]]), # 5 boxes + torch.tensor([[45, 55, 145, 155], [220, 230, 320, 330]]) # 2 boxes ] targets = [ @@ -800,7 +800,7 @@ def test_signals_for_binary_classification(self, mock_gm, mock_get_dataframe): self.assertTrue(mock_df.enqueue_batch.called) call_kwargs = mock_df.enqueue_batch.call_args[1] losses = call_kwargs['losses'] - self.assertEqual(len(losses), 5) # All signals should be saved + self.assertEqual(len(losses), 5) # All signals should be saved if __name__ == "__main__": unittest.main() diff --git a/weightslab/tests/general/test_signals_wrapping.py b/weightslab/tests/general/test_signals_wrapping.py index 3a307b97..aeac7f48 100644 --- a/weightslab/tests/general/test_signals_wrapping.py +++ b/weightslab/tests/general/test_signals_wrapping.py @@ -58,7 +58,7 @@ def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 32, 3, padding=1) self.pool = nn.AdaptiveAvgPool2d((8, 8)) - self.fc = nn.Linear(32 * 8 * 8, 100) # 100 outputs for bbox/conf + self.fc = nn.Linear(32 * 8 * 8, 100) # 100 outputs for bbox/conf self.task_type = "detection" def forward(self, x): @@ -270,9 +270,9 @@ def test_save_detection_signals_with_variable_boxes(self, mock_gm, mock_get_df): # Variable boxes preds = [ - torch.tensor([[10, 20, 100, 150], [200, 250, 400, 450]]), # 2 boxes - torch.tensor([[15, 25, 110, 160]]), # 1 box - torch.tensor([[30, 40, 130, 140], [70, 80, 170, 180], [250, 260, 350, 360]]) # 3 boxes + torch.tensor([[10, 20, 100, 150], [200, 250, 400, 450]]), # 2 boxes + torch.tensor([[15, 25, 110, 160]]), # 1 box + torch.tensor([[30, 40, 130, 140], [70, 80, 170, 180], [250, 260, 350, 360]]) # 3 boxes ] targets = [ diff --git a/weightslab/tests/integrations/test_pytorch_lightning_integration.py b/weightslab/tests/integrations/test_pytorch_lightning_integration.py index 704eee47..6942b1e4 100644 --- a/weightslab/tests/integrations/test_pytorch_lightning_integration.py +++ b/weightslab/tests/integrations/test_pytorch_lightning_integration.py @@ -154,127 +154,127 @@ def tearDown(self): # These 3 next tests were removed as they are covering disabled feature. We can re-enable them once the feature is re-enabled. # def test_proxy_hashable_in_lightning(self): - # """Test that Proxy objects are hashable and work with Lightning's module system.""" - # model = SimpleCNN() - # print(wl.__file__) - # model_wl = wl.watch_or_edit(model, flag="model", device=self.device) + # """Test that Proxy objects are hashable and work with Lightning's module system.""" + # model = SimpleCNN() + # print(wl.__file__) + # model_wl = wl.watch_or_edit(model, flag="model", device=self.device) - # # Test that proxy can be used in sets (requires __hash__) - # proxy_set = {model_wl} - # self.assertIn(model_wl, proxy_set) + # # Test that proxy can be used in sets (requires __hash__) + # proxy_set = {model_wl} + # self.assertIn(model_wl, proxy_set) - # # Test that proxy can be used as dict key - # proxy_dict = {model_wl: "test_value"} - # self.assertEqual(proxy_dict[model_wl], "test_value") + # # Test that proxy can be used as dict key + # proxy_dict = {model_wl: "test_value"} + # self.assertEqual(proxy_dict[model_wl], "test_value") # def test_lightning_module_with_weightslab_tracking(self): - # """Test that Lightning module can be created with WeightsLab tracked objects.""" - # pause_controller.resume(force=True) # Ensure not pausedv - # # Create model and wrap with WeightsLab - # _model = SimpleCNN().to(self.device) - # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) - - # # Create tracked loss and metrics - # criterion = wl.watch_or_edit( - # nn.CrossEntropyLoss(reduction="none"), - # flag="loss", signal_name="loss-CE", log=True - # ) - - # metric = wl.watch_or_edit( - # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), - # flag="metric", signal_name="metric-ACC", log=True - # ) - - # # Create optimizer - # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) - # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") - - # # Create Lightning module with tracked objects - # lit_model = LitTestModel( - # model=model_wl, - # optimizer=optimizer_wl, - # criterion_wl=criterion, - # metric_wl=metric - # ) - - # # Verify Lightning module was created successfully - # self.assertIsInstance(lit_model, pl.LightningModule) - # self.assertIsInstance(lit_model.model, Proxy) + # """Test that Lightning module can be created with WeightsLab tracked objects.""" + # pause_controller.resume(force=True) # Ensure not pausedv + # # Create model and wrap with WeightsLab + # _model = SimpleCNN().to(self.device) + # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) + + # # Create tracked loss and metrics + # criterion = wl.watch_or_edit( + # nn.CrossEntropyLoss(reduction="none"), + # flag="loss", signal_name="loss-CE", log=True + # ) + + # metric = wl.watch_or_edit( + # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), + # flag="metric", signal_name="metric-ACC", log=True + # ) + + # # Create optimizer + # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) + # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") + + # # Create Lightning module with tracked objects + # lit_model = LitTestModel( + # model=model_wl, + # optimizer=optimizer_wl, + # criterion_wl=criterion, + # metric_wl=metric + # ) + + # # Verify Lightning module was created successfully + # self.assertIsInstance(lit_model, pl.LightningModule) + # self.assertIsInstance(lit_model.model, Proxy) # def test_lightning_training_with_weightslab_loaders(self): - # """Test full training loop with WeightsLab tracked data loaders.""" - # pause_controller.resume(force=True) # Ensure not paused - - # # Create tracked loaders - # train_loader = wl.watch_or_edit( - # self.train_dataset, - # flag="data", - # loader_name="train_loader", - # batch_size=16, - # shuffle=True, - # is_training=True, - # compute_hash=False, - # enable_h5_persistence=False - # ) - - # val_loader = wl.watch_or_edit( - # self.val_dataset, - # flag="data", - # loader_name="val_loader", - # batch_size=16, - # shuffle=False, - # is_training=False, - # compute_hash=False, - # enable_h5_persistence=False - # ) - - # # Create model with tracked components - # _model = SimpleCNN().to(self.device) - # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) - - # criterion = wl.watch_or_edit( - # nn.CrossEntropyLoss(reduction="none"), - # flag="loss", signal_name="loss-CE", log=True - # ) - - # metric = wl.watch_or_edit( - # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), - # flag="metric", signal_name="metric-ACC", log=True - # ) - - # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) - # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") - - # lit_model = LitTestModel( - # model=model_wl, - # optimizer=optimizer_wl, - # criterion_wl=criterion, - # metric_wl=metric - # ) - - # # Create Lightning trainer with minimal configuration - # trainer = pl.Trainer( - # max_epochs=2, - # accelerator=self.device if self.device in ["cpu", "cuda"] else "auto", - # devices=1, - # enable_checkpointing=False, - # logger=False, - # enable_progress_bar=False, - # ) - - # # Train the model - this should complete without errors - # try: - # trainer.fit(lit_model, train_loader, val_loader) - # training_succeeded = True - # except Exception as e: - # training_succeeded = False - # self.fail(f"Training failed with error: {e}") - - # self.assertTrue(training_succeeded, "Training should complete successfully") + # """Test full training loop with WeightsLab tracked data loaders.""" + # pause_controller.resume(force=True) # Ensure not paused + + # # Create tracked loaders + # train_loader = wl.watch_or_edit( + # self.train_dataset, + # flag="data", + # loader_name="train_loader", + # batch_size=16, + # shuffle=True, + # is_training=True, + # compute_hash=False, + # enable_h5_persistence=False + # ) + + # val_loader = wl.watch_or_edit( + # self.val_dataset, + # flag="data", + # loader_name="val_loader", + # batch_size=16, + # shuffle=False, + # is_training=False, + # compute_hash=False, + # enable_h5_persistence=False + # ) + + # # Create model with tracked components + # _model = SimpleCNN().to(self.device) + # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) + + # criterion = wl.watch_or_edit( + # nn.CrossEntropyLoss(reduction="none"), + # flag="loss", signal_name="loss-CE", log=True + # ) + + # metric = wl.watch_or_edit( + # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), + # flag="metric", signal_name="metric-ACC", log=True + # ) + + # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) + # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") + + # lit_model = LitTestModel( + # model=model_wl, + # optimizer=optimizer_wl, + # criterion_wl=criterion, + # metric_wl=metric + # ) + + # # Create Lightning trainer with minimal configuration + # trainer = pl.Trainer( + # max_epochs=2, + # accelerator=self.device if self.device in ["cpu", "cuda"] else "auto", + # devices=1, + # enable_checkpointing=False, + # logger=False, + # enable_progress_bar=False, + # ) + + # # Train the model - this should complete without errors + # try: + # trainer.fit(lit_model, train_loader, val_loader) + # training_succeeded = True + # except Exception as e: + # training_succeeded = False + # self.fail(f"Training failed with error: {e}") + + # self.assertTrue(training_succeeded, "Training should complete successfully") def test_weightslab_context_guards_in_lightning(self): """Test that WeightsLab context guards work correctly in Lightning steps.""" - pause_controller.resume(force=True) # Ensure not paused + pause_controller.resume(force=True) # Ensure not paused context_log = [] class ContextTestModule(pl.LightningModule): diff --git a/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py b/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py index 74c45d72..3c268e84 100644 --- a/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py +++ b/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py @@ -3,9 +3,9 @@ The fair baseline isn't "no logging" — anyone wanting per-sample signals must decode preds, compute per-sample loss/metrics, and store them. So we compare two modes with identical model / batch / imgsz / data: - ulmanual — ultralytics + a HAND-ROLLED minimal per-sample logger: decode + per-sample + ulmanual — ultralytics + a HAND-ROLLED minimal per-sample logger: decode + per-sample loss/IoU + append the scalars to a plain list. The "classic" way. - wl — full WL pipeline: wrapped model/loss/loader, save_signals, anchor + wl — full WL pipeline: wrapped model/loss/loader, save_signals, anchor (reconcile DOWN + flush UP), decode-for-logging. (wl - ulmanual) = WL's internal machinery (dataframe upserts + ledger/H5 + the DDP @@ -14,7 +14,7 @@ ms/step + rank-0 RSS; `wl` also prints the global dataframe RAM + H5 store sizes. WL_ABLATE=ulmanual WL_DDP_CUDA=1 python ddp_ablation.py - WL_ABLATE=wl WL_DDP_CUDA=1 python ddp_ablation.py + WL_ABLATE=wl WL_DDP_CUDA=1 python ddp_ablation.py """ import os os.environ.setdefault("WEIGHTSLAB_SKIP_SECURE_INIT", "true") @@ -54,7 +54,7 @@ def _rss_mb(): with open("/proc/self/status") as f: for ln in f: if ln.startswith("VmRSS:"): - return int(ln.split()[1]) / 1024.0 # KB -> MB + return int(ln.split()[1]) / 1024.0 # KB -> MB except Exception: pass return -1.0 @@ -134,7 +134,7 @@ def _worker(rank, world, master_port): batch_size = int(os.environ.get("WL_DDP_BATCH", "16")) num_workers = int(os.environ.get("WL_DDP_WORKERS", "0")) - is_wl = MODE == "wl" # else: ulmanual (the hand-rolled classic baseline) + is_wl = MODE == "wl" # else: ulmanual (the hand-rolled classic baseline) if is_wl: import yolo_pipeline cfg["compute_natural_sort"] = False @@ -154,7 +154,7 @@ def _worker(rank, world, master_port): else: model, loader, crit, iou, optimizer = _build_ul(cfg, device, batch_size, num_workers) from yolo_pipeline import _decode_preds_to_6col as decode - _manual_store = [] # the "classic" sink: a plain in-memory list + _manual_store = [] # the "classic" sink: a plain in-memory list # identical initial weights on every rank (flattened broadcast) with torch.no_grad(): @@ -191,7 +191,7 @@ def _inf(ld): for step in range(_WARMUP + _STEPS): timed = step >= _WARMUP if step == _WARMUP: - io0 = _proc_io() # I/O counters at the start of the timed window + io0 = _proc_io() # I/O counters at the start of the timed window t0 = time.perf_counter() inputs = next(batches) if is_wl: @@ -282,21 +282,21 @@ def _inf(ld): # Each rank prints its OWN per-rank line (no gather collective — it was flaky on # gloo+CUDA; per-process I/O reads are independent anyway). io = io_d - print(f"[mode={MODE} rank {rank}] RSS={rss:7.0f}MB anchor={t.ms('anchor(WL)'):6.1f}ms " + print(f"[mode={MODE} rank {rank}] RSS={rss:7.0f}MB anchor={t.ms('anchor(WL)'):6.1f}ms " f"IO(MB): rchar={io.get('rchar',0)/1e6:7.1f} wchar={io.get('wchar',0)/1e6:7.1f} " f"read_dsk={io.get('read_bytes',0)/1e6:6.1f} write_dsk={io.get('write_bytes',0)/1e6:6.1f}", flush=True) if rank == 0: total = sum(t.ms(k) for k in order) print("\n" + "=" * 74) - print(f"ABLATION mode={MODE} device={device} world={world} batch={batch_size} steps={_STEPS}") + print(f"ABLATION mode={MODE} device={device} world={world} batch={batch_size} steps={_STEPS}") print("=" * 74) for k in order: - print(f" {k:18s} {t.ms(k):8.1f} ms/step") - print(f" {'STEP TOTAL':18s} {total:8.1f} ms/step") - print(f" {'grad on the wire':18s} {grad_bytes/1e6:8.1f} MB/step") + print(f" {k:18s} {t.ms(k):8.1f} ms/step") + print(f" {'STEP TOTAL':18s} {total:8.1f} ms/step") + print(f" {'grad on the wire':18s} {grad_bytes/1e6:8.1f} MB/step") if is_wl: - print(f" WL df RAM {df_mb:.1f} MB | WL H5 {h5_mb:.1f} MB disk | " + print(f" WL df RAM {df_mb:.1f} MB | WL H5 {h5_mb:.1f} MB disk | " f"H5 cfg: persist={cfg.get('ledger_enable_h5_persistence')} " f"max_rows={cfg.get('ledger_flush_max_rows')} " f"interval={cfg.get('ledger_flush_interval')}s " diff --git a/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py b/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py index fb8efb42..c9608081 100644 --- a/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py +++ b/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py @@ -7,8 +7,8 @@ measure time/processes, and later extract the reusable bits into the SDK. Layout: - * _train_worker(rank, world, ...) -- the distributed SERVER (spawned, rank 0 serves). - * Client + scenarios -- run in the PARENT process, talk gRPC to rank 0. + * _train_worker(rank, world, ...) -- the distributed SERVER (spawned, rank 0 serves). + * Client + scenarios -- run in the PARENT process, talk gRPC to rank 0. First scenario (scenario_epoch_then_pause): 1. spawn `world` ranks (rank 0 serves gRPC plaintext); parent = client. @@ -19,13 +19,13 @@ 5. wait ~50% of the epoch's wall time (training is paused) and assert the last_seen map is byte-identical => pause truly froze training. -Run: python ddp_test_suite.py (WL_DDP_WORLD_SIZE=2, imgsz 96, num_workers 0) +Run: python ddp_test_suite.py (WL_DDP_WORLD_SIZE=2, imgsz 96, num_workers 0) """ import os # --- test-mode env (must be set before importing weightslab) --------------- -os.environ["WEIGHTSLAB_SKIP_SECURE_INIT"] = "true" # plaintext gRPC for the client +os.environ["WEIGHTSLAB_SKIP_SECURE_INIT"] = "true" # plaintext gRPC for the client os.environ["GRPC_TLS_ENABLED"] = "0" -os.environ.setdefault("WL_DDP_IMGSZ", "96") # small images for speed +os.environ.setdefault("WL_DDP_IMGSZ", "96") # small images for speed os.environ.setdefault("WL_DDP_COLLECTIVE_LOG", "/tmp/wl_collective_log.txt") os.environ.setdefault("WL_PRELOAD_IMAGE_OVERVIEW", "0") os.environ.setdefault("WEIGHTSLAB_LOG_LEVEL", "WARNING") @@ -44,7 +44,7 @@ # usecase modules (yolo_pipeline, utils.*) and its config/data/ddp_run resolve. sys.path.insert(0, os.path.abspath(os.path.join( os.path.dirname(__file__), "../../../../examples/PyTorch/ws-detection/src"))) -import yolo_pipeline # reuse _build_pipeline / _decode_preds_to_6col / _HERE / _LOSS_PARTS +import yolo_pipeline # reuse _build_pipeline / _decode_preds_to_6col / _HERE / _LOSS_PARTS import weightslab.proto.experiment_service_pb2 as pb2 import weightslab.proto.experiment_service_pb2_grpc as pb2_grpc @@ -54,7 +54,7 @@ # =========================================================================== -# SERVER (spawned ranks; rank 0 serves gRPC) +# SERVER (spawned ranks; rank 0 serves gRPC) # =========================================================================== def _train_worker(rank, world, master_port, grpc_port): """Spawned per rank. Delegates to main_ddp.train_worker — the clean @@ -80,11 +80,11 @@ def _train_worker(rank, world, master_port, grpc_port): # =========================================================================== -# CLIENT (parent process) +# CLIENT (parent process) # =========================================================================== class Client: def __init__(self, port): - self._port = int(port) # exposed for topology-style scenarios + self._port = int(port) # exposed for topology-style scenarios self.channel = grpc.insecure_channel( f"{_HOST}:{port}", options=[("grpc.max_receive_message_length", 256 * 1024 * 1024)], @@ -298,7 +298,7 @@ def _wait_until_paused(client, n, min_step, timeout=600.0, poll=5.0): server is paused, not just mid-step). drop_last means the per-rank epoch is floor(shard/batch), so we don't require an exact step count.""" _t0 = time.time() - _last_change = _t0 # wall-time of the most recent last_seen-max change + _last_change = _t0 # wall-time of the most recent last_seen-max change deadline = time.time() + timeout prev = None stable = 0 @@ -310,9 +310,9 @@ def _wait_until_paused(client, n, min_step, timeout=600.0, poll=5.0): if cur >= min_step and stable >= 2: if _SCN_TIMING: tot = time.time() - _t0 - active = _last_change - _t0 # last_seen advancing = training/observed work - settle = time.time() - _last_change # stable-confirm + snapshot-lag = observability - print(f"[scn_timing] wait_until_paused total={tot:6.1f}s " + active = _last_change - _t0 # last_seen advancing = training/observed work + settle = time.time() - _last_change # stable-confirm + snapshot-lag = observability + print(f"[scn_timing] wait_until_paused total={tot:6.1f}s " f"active(train)={active:6.1f}s settle(obs)={settle:6.1f}s", flush=True) return cur prev = cur @@ -354,8 +354,8 @@ def scenario_epoch_then_pause(client, world, batch): # auto-pause (pause_at_step) fire at the epoch boundary WITHOUT crossing into # epoch 2 (which would force a sampler re-iteration mid-test). epoch_steps = (n // world) // batch - epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) # fast-debug override - print(f"[client] universe N={n} world={world} batch={batch} -> epoch_steps={epoch_steps}") + epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) # fast-debug override + print(f"[client] universe N={n} world={world} batch={batch} -> epoch_steps={epoch_steps}") t0 = time.time() client.train_steps(epoch_steps) @@ -366,7 +366,7 @@ def scenario_epoch_then_pause(client, world, batch): epoch_secs = time.time() - t0 print(f"[client] epoch done: max last_seen={reached} in {epoch_secs:.1f}s") - s1 = _settled_last_seen(client, n) # wait out the DataService snapshot throttle + s1 = _settled_last_seen(client, n) # wait out the DataService snapshot throttle populated = {k: v for k, v in s1.items() if v is not None and v >= 0} # With the children->rank0 gather (fired on pause), rank 0 sees what ALL ranks # trained: ~reached*batch*world distinct samples (capped at the universe N, @@ -391,8 +391,8 @@ def scenario_epoch_then_pause(client, world, batch): print(f"[client] FROZEN CHECK FAILED, {len(diff)} changed e.g. {list(diff.items())[:5]}") ok = a1 and a1b and a2 - print(f"[1] EPOCH COVERAGE populated>0={a1} populated~=shard={a1b} -> {'PASS' if (a1 and a1b) else 'FAIL'}") - print(f"[2] PAUSE FREEZES last_seen identical after wait={a2} -> {'PASS' if a2 else 'FAIL'}") + print(f"[1] EPOCH COVERAGE populated>0={a1} populated~=shard={a1b} -> {'PASS' if (a1 and a1b) else 'FAIL'}") + print(f"[2] PAUSE FREEZES last_seen identical after wait={a2} -> {'PASS' if a2 else 'FAIL'}") return ok @@ -443,11 +443,11 @@ def scenario_discard_subset_freezes(client, world, batch, n_discard=5): most_advanced = advanced >= int(0.8 * non_discarded_pop1) ok = a0 and all_frozen and most_advanced and (m2 > m1) - print(f"[1] DISCARD REGISTERED exactly {n_discard} added={a0}") - print(f"[2] SUBSET FROZEN all {n_discard} unchanged={all_frozen} values {L} -> {frozen}") - print(f"[3] MOST ADVANCED {advanced}/{non_discarded_pop1} non-discarded advanced " - f"(>=80%)={most_advanced} (epoch max {m1}->{m2})") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f"[1] DISCARD REGISTERED exactly {n_discard} added={a0}") + print(f"[2] SUBSET FROZEN all {n_discard} unchanged={all_frozen} values {L} -> {frozen}") + print(f"[3] MOST ADVANCED {advanced}/{non_discarded_pop1} non-discarded advanced " + f"(>=80%)={most_advanced} (epoch max {m1}->{m2})") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -458,7 +458,7 @@ def scenario_break_by_slice(client, world, batch): epoch_steps = (n // world) // batch epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) origin = client.train_origin() - graph = "train/bbxs" # a per-sample loss component logged by the criterions + graph = "train/bbxs" # a per-sample loss component logged by the criterions client.train_steps(epoch_steps) _wait_until_paused(client, n, min_step=max(1, epoch_steps - batch)) @@ -472,25 +472,25 @@ def scenario_break_by_slice(client, world, batch): # trained set. Measure it directly with a 'uni' slice over all trained samples # (this is the GLOBAL loss universe on rank 0 thanks to the per-sample gather), # then check the 'even' slice returns exactly the even members that have loss. - even = set(trained[::2]) # every other trained sample — spans both ranks' shards + even = set(trained[::2]) # every other trained sample — spans both ranks' shards client.tag(trained, "uni", origin) client.tag(even, "even", origin) print(f"[client] trained={len(trained)} tagged uni + even={len(even)} (origin={origin})") uni_sids = {p[0] for p in client.break_by_slice(graph, ["uni"])} even_sids = {p[0] for p in client.break_by_slice(graph, ["even"])} - expected_even = even & uni_sids # even-tagged samples that actually have loss + expected_even = even & uni_sids # even-tagged samples that actually have loss a1 = len(even_sids) > 0 - a2 = (even_sids == expected_even) # break-by-slice slices correctly + a2 = (even_sids == expected_even) # break-by-slice slices correctly ok = a1 and a2 - print(f"[1] BREAK-BY-SLICE even returned {len(even_sids)} samples (graph={graph})={a1}") - print(f"[2] SLICE CORRECT even == even-with-loss ({len(expected_even)})={a2}") + print(f"[1] BREAK-BY-SLICE even returned {len(even_sids)} samples (graph={graph})={a1}") + print(f"[2] SLICE CORRECT even == even-with-loss ({len(expected_even)})={a2}") # Cross-rank is evidenced by the server-side [siggather] log (rank 0 receives the # children's triples). It's not cleanly black-box-assertable without a per-rank # baseline, and becomes STRUCTURAL once writes go through sync_to_rank0 on rank 0. print(f"[i] loss universe on rank 0 = {len(uni_sids)} samples (spans both ranks via the gather)") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -501,7 +501,7 @@ def scenario_lr_batch_propagate(client, world, batch): epoch_steps = (n // world) // batch epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) new_batch = batch * 2 - phase2 = 4 # short, so trained-count doesn't wrap the universe + phase2 = 4 # short, so trained-count doesn't wrap the universe # phase 1 at the original batch client.train_steps(epoch_steps) @@ -521,8 +521,8 @@ def scenario_lr_batch_propagate(client, world, batch): steps2 = a1 - a0 trained2 = sum(1 for v in s1.values() if v is not None and v > a0) rate = trained2 / steps2 if steps2 > 0 else 0.0 - expected = new_batch * world # both ranks switched - rank0_only = new_batch + batch # only rank 0 switched (the bug we fixed) + expected = new_batch * world # both ranks switched + rank0_only = new_batch + batch # only rank 0 switched (the bug we fixed) # Threshold: must be clearly ABOVE the rank0-only failure mode. We don't # require hitting the full `expected` because under drop_last=False the # DistributedSampler pads the per-rank shard with re-yields of samples @@ -531,10 +531,10 @@ def scenario_lr_batch_propagate(client, world, batch): # rank0_only + 1 cleanly distinguishes "both ranks doubled" (rate ≈ 13–16) # from "only rank-0 doubled" (rate ≈ 12). a1ok = steps2 > 0 and rate >= rank0_only + 1 - print(f"[1] BATCH PROPAGATED {trained2} samples / {steps2} steps = {rate:.1f}/step " + print(f"[1] BATCH PROPAGATED {trained2} samples / {steps2} steps = {rate:.1f}/step " f"(expect ~{expected} all-ranks vs ~{rank0_only} rank0-only)={a1ok}") print(f"[i] lr=0.05 rode the same hparam broadcast that carried batch (proven above)") - print(f" -> {'PASS' if a1ok else 'FAIL'}") + print(f" -> {'PASS' if a1ok else 'FAIL'}") return a1ok @@ -556,11 +556,11 @@ def scenario_checkpoint_data_roundtrip(client, world, batch): client.discard([A], origin) # 2) short resume -> save_pending_changes writes a FULL checkpoint (model+config+ - # data{A}) with non-null weights; then read its combined hash from the manifest. + # data{A}) with non-null weights; then read its combined hash from the manifest. client.train_steps(2) _wait_until_paused(client, n, min_step=a0 + 1) saved_hash = client.latest_full_checkpoint_hash() - time.sleep(12) # clear the DataService snapshot throttle before reading + time.sleep(12) # clear the DataService snapshot throttle before reading disc_save = client.discarded_set(n) print(f"[client] discarded A={A}; full-ckpt hash={saved_hash}; discarded@save={sorted(disc_save)}") @@ -580,13 +580,13 @@ def scenario_checkpoint_data_roundtrip(client, world, batch): f"msg={getattr(resp, 'message', '')[:70]}; discarded@post={sorted(disc_post)}") restore_ok = bool(getattr(resp, "success", False)) - a0c = (C in disc_change) # the divergent discard registered - a1 = (C not in disc_post) # restore undid it - a2 = (A in disc_post) # the saved discard survived the roundtrip + a0c = (C in disc_change) # the divergent discard registered + a1 = (C not in disc_post) # restore undid it + a2 = (A in disc_post) # the saved discard survived the roundtrip ok = restore_ok and a0c and a1 and a2 - print(f"[1] DATA ROUNDTRIP restore_ok={restore_ok} C-registered={a0c} " + print(f"[1] DATA ROUNDTRIP restore_ok={restore_ok} C-registered={a0c} " f"C-reverted={a1} A-intact={a2}") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -618,7 +618,7 @@ def scenario_signal_coverage_all_graphs(client, world, batch): if len(trained) < 10: print(f"[client] too few trained ({len(trained)})"); return False client.tag(trained, "uni", origin) - print(f"[client] trained={len(trained)} tagged 'uni' steps={epoch_steps}") + print(f"[client] trained={len(trained)} tagged 'uni' steps={epoch_steps}") graphs = ["train/bbxs", "train/clsf", "train/dfl", "miou/train"] per_sample_min = max(1, int(0.3 * len(trained))) @@ -631,10 +631,10 @@ def scenario_signal_coverage_all_graphs(client, world, batch): plot_ok = len(plot_points) >= plot_min ok = ps_ok and plot_ok all_ok &= ok - print(f"[1] {g:<18s} per-sample={len(per_sample_sids)}/{len(trained)} " - f"≥{per_sample_min}={ps_ok} plot={len(plot_points)} " - f"≥{plot_min}={plot_ok} both={ok}") - print(f" -> {'PASS' if all_ok else 'FAIL'}") + print(f"[1] {g:<18s} per-sample={len(per_sample_sids)}/{len(trained)} " + f"≥{per_sample_min}={ps_ok} plot={len(plot_points)} " + f"≥{plot_min}={plot_ok} both={ok}") + print(f" -> {'PASS' if all_ok else 'FAIL'}") return all_ok @@ -693,12 +693,12 @@ def scenario_resume_continues_curve(client, world, batch): _wait_until_paused(client, n, min_step=age_diverged + 1) post_train_plot = client.scalar_plot("train/bbxs") a3 = len(post_train_plot) > len(pre_restore_plot) - print(f"[3] PLOT GROWS pre={len(pre_restore_plot)} post={len(post_train_plot)} → {a3}") + print(f"[3] PLOT GROWS pre={len(pre_restore_plot)} post={len(post_train_plot)} → {a3}") ok = a1 and a2 and a3 - print(f"[1] RESTORE OK success={a1}") - print(f"[2] SERVER ALIVE universe={n_after}/{n} → {a2}") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f"[1] RESTORE OK success={a1}") + print(f"[2] SERVER ALIVE universe={n_after}/{n} → {a2}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -711,7 +711,7 @@ def scenario_process_topology(client, world, batch): import re import subprocess - _ = client.universe_size() # confirm we're connected; ranks are alive + _ = client.universe_size() # confirm we're connected; ranks are alive grpc_port = getattr(client, "_port", None) # Walk the descendant tree of the suite process to find spawned ranks. @@ -770,9 +770,9 @@ def listeners_of(pid): # Sanity: at least one PID does listen (otherwise the gRPC server is dead). a2 = len(listening) >= 1 ok = a1 and a2 - print(f"[1] gRPC OWNER PIDs owning port {grpc_port}: {grpc_owners} (==1) → {a1}") - print(f"[2] HAS LISTENERS {len(listening)} PID(s) with TCP sockets (≥1) → {a2}") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f"[1] gRPC OWNER PIDs owning port {grpc_port}: {grpc_owners} (==1) → {a1}") + print(f"[2] HAS LISTENERS {len(listening)} PID(s) with TCP sockets (≥1) → {a2}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -807,17 +807,17 @@ def scenario_multi_epoch_stability(client, world, batch): # Per-graph: no duplicate (sid, age) entries dedup_ok = True for g in ["train/bbxs", "train/clsf", "train/dfl", "miou/train"]: - entries = client.break_by_slice(g, ["uni"]) # [(sid, age, val), ...] + entries = client.break_by_slice(g, ["uni"]) # [(sid, age, val), ...] keys = [(sid, age) for sid, age, _ in entries] unique, total = len(set(keys)), len(keys) ok = (unique == total) dedup_ok &= ok - print(f"[1] {g:<18s} {total} entries, {unique} unique (sid,age) → {ok}") + print(f"[1] {g:<18s} {total} entries, {unique} unique (sid,age) → {ok}") age_mono = ages[0] < ages[1] < ages[2] - print(f"[2] AGE MONOTONIC ages={ages} (strictly increasing) → {age_mono}") + print(f"[2] AGE MONOTONIC ages={ages} (strictly increasing) → {age_mono}") ok = dedup_ok and age_mono - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -825,17 +825,17 @@ def scenario_curate_lifecycle(client, world, batch): """End-to-end UI curation workflow under DDP — multiple composing edits and the loss trajectory tells the story: - epoch 1 (warm up: all populated samples accumulate train/bbxs entries) + epoch 1 (warm up: all populated samples accumulate train/bbxs entries) → tag 3 samples 'suspect' → discard those 3 - epoch 2 (the 3 suspects must produce NO new train/bbxs entries — + epoch 2 (the 3 suspects must produce NO new train/bbxs entries — their slot in the loss trajectory has a gap) → un-discard the 3 - → tag them additionally 'verified' (so each carries BOTH tags) - epoch 3 (the 3 resume; new entries appear beyond the discard age) + → tag them additionally 'verified' (so each carries BOTH tags) + epoch 3 (the 3 resume; new entries appear beyond the discard age) Assertions: - [1] LIFECYCLE — for each suspect: pre-discard entries exist AND + [1] LIFECYCLE — for each suspect: pre-discard entries exist AND no entries in the (discard_age, undiscard_age] window AND post-resume entries exist. The gap is the proof that discard reached the worker fast-path (the shm @@ -895,15 +895,15 @@ def scenario_curate_lifecycle(client, world, batch): ages_by_sid.setdefault(sid, []).append(age) # Per-suspect trajectory: - # pre — every suspect must have ≥1 entry before discard (proves we're - # tracking a sample that was actually trained on); - # gap — NO suspect may have an entry in (discard, undiscard] (proves the - # discard reached the sampler/worker fast-path); - # post — AT LEAST ONE suspect must have a post-undiscard entry (proves - # un-discard reaches the sampler). The shuffled sampler in a - # short 20-step epoch won't yield every sample, so requiring ALL - # suspects to resume would be a shuffle-luck check, not a - # correctness check. + # pre — every suspect must have ≥1 entry before discard (proves we're + # tracking a sample that was actually trained on); + # gap — NO suspect may have an entry in (discard, undiscard] (proves the + # discard reached the sampler/worker fast-path); + # post — AT LEAST ONE suspect must have a post-undiscard entry (proves + # un-discard reaches the sampler). The shuffled sampler in a + # short 20-step epoch won't yield every sample, so requiring ALL + # suspects to resume would be a shuffle-luck check, not a + # correctness check. pre_ok, gap_ok = True, True any_post = False for sid in suspects: @@ -912,26 +912,26 @@ def scenario_curate_lifecycle(client, world, batch): gap = [a for a in ages if age_at_discard < a <= age_at_undiscard] post = [a for a in ages if a > age_at_undiscard] if not pre: pre_ok = False - if gap: gap_ok = False - if post: any_post = True - print(f" sid={sid}: pre={pre[-3:]} gap={gap} post={post[:3]}") + if gap: gap_ok = False + if post: any_post = True + print(f" sid={sid}: pre={pre[-3:]} gap={gap} post={post[:3]}") post_ok = any_post verified_sids = {p[0] for p in client.break_by_slice("train/bbxs", ["verified"])} tag_compose = set(suspects).issubset(verified_sids) plot = client.scalar_plot("train/bbxs") - plot_ok = len(plot) >= 3 * 1 # at least one point per epoch (loose) + plot_ok = len(plot) >= 3 * 1 # at least one point per epoch (loose) a1ok = pre_ok and gap_ok and post_ok a2ok = tag_compose a3ok = plot_ok - print(f"[1] LIFECYCLE pre={pre_ok} gap-empty={gap_ok} any-post={post_ok} → {a1ok}") - print(f"[2] TAG COMPOSE verified⊇suspects ({len(verified_sids)} verified, " + print(f"[1] LIFECYCLE pre={pre_ok} gap-empty={gap_ok} any-post={post_ok} → {a1ok}") + print(f"[2] TAG COMPOSE verified⊇suspects ({len(verified_sids)} verified, " f"{len(set(suspects) & verified_sids)}/3 suspects tagged) → {a2ok}") - print(f"[3] PLOT METRICS scalar_plot has {len(plot)} entries → {a3ok}") + print(f"[3] PLOT METRICS scalar_plot has {len(plot)} entries → {a3ok}") ok = a1ok and a2ok and a3ok - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -948,7 +948,7 @@ def scenario_collective_budget(client, world, batch): log_path = os.environ.get("WL_DDP_COLLECTIVE_LOG") if not log_path: print(f"[client] WL_DDP_COLLECTIVE_LOG not set; skipping"); return False - open(log_path, "w").close() # truncate before this scenario's window + open(log_path, "w").close() # truncate before this scenario's window n = client.universe_size() epoch_steps = (n // world) // batch @@ -966,19 +966,19 @@ def scenario_collective_budget(client, world, batch): # First few entries can include pause-spin reconciles (many per "step" while # the trainer is waiting for the resume signal). Take a slice from the tail # corresponding to clearly-in-the-body steps. - body = [c for c in counts if c <= 5] # drop the spin-inflated outliers + body = [c for c in counts if c <= 5] # drop the spin-inflated outliers spin = [c for c in counts if c > 5] avg_body = (sum(body) / len(body)) if body else float("inf") max_body = max(body) if body else 0 a1 = max_body <= 2 a2 = avg_body <= 2.0 - print(f"[1] BUDGET PER STEP body samples={len(body)}, max={max_body}, " + print(f"[1] BUDGET PER STEP body samples={len(body)}, max={max_body}, " f"avg={avg_body:.2f}, spin samples={len(spin)} (excluded) " f"max-over-budget→{a1}") - print(f"[2] AVG ≤ 2 {avg_body:.2f} → {a2}") + print(f"[2] AVG ≤ 2 {avg_body:.2f} → {a2}") ok = a1 and a2 - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -1013,13 +1013,13 @@ def scenario_seed_determinism(client, world, batch): a1 = len(pull1) > 0 and len(pull1) == len(pull2) a2 = all(p1 == p2 for p1, p2 in zip(pull1, pull2)) - print(f"[1] STABLE LEN pull1={len(pull1)} == pull2={len(pull2)} → {a1}") - print(f"[2] BIT-IDENTICAL every (sid, age, val) matches → {a2}") + print(f"[1] STABLE LEN pull1={len(pull1)} == pull2={len(pull2)} → {a1}") + print(f"[2] BIT-IDENTICAL every (sid, age, val) matches → {a2}") # Spot-check first 3 entries for i in range(min(3, len(pull1))): - print(f" p1[{i}]={pull1[i]} p2[{i}]={pull2[i]}") + print(f" p1[{i}]={pull1[i]} p2[{i}]={pull2[i]}") ok = a1 and a2 - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -1060,7 +1060,7 @@ def scenario_empty_shard_starvation(client, world, batch): else: to_discard = populated[:-keep] client.discard(to_discard, origin) - print(f"[client] discarded {len(to_discard)} (keep={keep}, force={bool(_force)})") + print(f"[client] discarded {len(to_discard)} (keep={keep}, force={bool(_force)})") # Short post-discard train. Bounded timeout: if it hangs, assertion fires. K = min(epoch_steps, 8) @@ -1072,13 +1072,13 @@ def scenario_empty_shard_starvation(client, world, batch): timeout=180.0, poll=3.0) except TimeoutError: elapsed = time.time() - t0 - print(f"[client] HUNG no model_age advance in {elapsed:.0f}s (a0={a0})") - print(f" -> FAIL") + print(f"[client] HUNG no model_age advance in {elapsed:.0f}s (a0={a0})") + print(f" -> FAIL") return False elapsed = time.time() - t0 advanced = a1 > a0 - print(f"[1] NO HANG age advanced {a0}→{a1} in {elapsed:.1f}s → {advanced}") - print(f" -> {'PASS' if advanced else 'FAIL'}") + print(f"[1] NO HANG age advanced {a0}→{a1} in {elapsed:.1f}s → {advanced}") + print(f" -> {'PASS' if advanced else 'FAIL'}") return advanced @@ -1098,7 +1098,7 @@ def scenario_progressive_resample(client, world, batch): def _epoch_steps(live): return max(1, (live // world) // batch) - def _ls(d, sid): # None-safe last_seen ( -1 == never seen ) + def _ls(d, sid): # None-safe last_seen ( -1 == never seen ) v = d.get(sid) return v if v is not None else -1 @@ -1114,8 +1114,8 @@ def _run_epochs(n_ep, steps_each, label, m_start): timeout=180.0, poll=3.0) dt = time.perf_counter() - t0 total = n_ep * steps_each - print(f"[time] {label:24s} {n_ep}ep x {steps_each:>3}st = {total:>4} steps " - f"{dt:6.1f}s ({dt/max(1,total):.2f}s/step)") + print(f"[time] {label:24s} {n_ep}ep x {steps_each:>3}st = {total:>4} steps " + f"{dt:6.1f}s ({dt/max(1,total):.2f}s/step)") return m, dt # --- warm-up: 1 full epoch (100% live) --- @@ -1124,7 +1124,7 @@ def _run_epochs(n_ep, steps_each, label, m_start): m0 = _wait_until_paused(client, n, min_step=max(1, full_epoch_steps - batch)) t_warm = time.perf_counter() - t0 print(f"[time] {'warmup (1 x 100%)':24s} 1ep x {full_epoch_steps:>3}st = {full_epoch_steps:>4} " - f"steps {t_warm:6.1f}s ({t_warm/full_epoch_steps:.2f}s/step)") + f"steps {t_warm:6.1f}s ({t_warm/full_epoch_steps:.2f}s/step)") s0 = _settled_last_seen(client, n) all_ids = sorted(s0.keys(), key=lambda k: int(k)) if sum(1 for k in all_ids if _ls(s0, k) >= 0) < 40: @@ -1146,7 +1146,7 @@ def _run_epochs(n_ep, steps_each, label, m_start): disc_frozen = sum(1 for sid in discard_ids if _ls(s1, sid) == _ls(s0, sid)) a1 = (kept_adv >= int(0.8 * len(keep)) and disc_frozen >= int(0.95 * len(discard_ids))) - print(f"[1] DISCARD SHIFT kept advanced {kept_adv}/{len(keep)} (>=80%), " + print(f"[1] DISCARD SHIFT kept advanced {kept_adv}/{len(keep)} (>=80%), " f"discarded frozen {disc_frozen}/{len(discard_ids)} (>=95%) -> {a1}") # --- grow: un-discard up to ~50% live --- @@ -1165,14 +1165,14 @@ def _run_epochs(n_ep, steps_each, label, m_start): still_frozen = sum(1 for sid in still_disc if _ls(s2, sid) == _ls(s1, sid)) a2 = ((not re_add or readd_adv >= int(0.8 * len(re_add))) and (not still_disc or still_frozen >= int(0.95 * len(still_disc)))) - print(f"[2] GROWTH HANDLED re-added advanced {readd_adv}/{len(re_add)} (>=80%), " + print(f"[2] GROWTH HANDLED re-added advanced {readd_adv}/{len(re_add)} (>=80%), " f"still-discarded frozen {still_frozen}/{len(still_disc)} (>=95%) -> {a2}") - print(f"[time] SUMMARY warmup={t_warm:.0f}s post-discard(2x10%)={t_lo:.0f}s " - f"post-readd(2x50%)={t_hi:.0f}s (warmup/post-discard ~= " + print(f"[time] SUMMARY warmup={t_warm:.0f}s post-discard(2x10%)={t_lo:.0f}s " + f"post-readd(2x50%)={t_hi:.0f}s (warmup/post-discard ~= " f"{t_warm/max(0.1,t_lo):.1f}x, expect ~5x if per-step cost is flat)") ok = a1 and a2 - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -1203,7 +1203,7 @@ def _free_port(): def _run_one(scn, batch): """Spawn a FRESH server (isolation), run one scenario, tear the server down.""" master_port, grpc_port = _free_port(), _free_port() - print(f"\n[suite] === {scn.__name__} === spawning {_WORLD} ranks, gRPC :{grpc_port}, " + print(f"\n[suite] === {scn.__name__} === spawning {_WORLD} ranks, gRPC :{grpc_port}, " f"imgsz={os.environ['WL_DDP_IMGSZ']}") ctx = mp.spawn(_train_worker, args=(_WORLD, master_port, grpc_port), nprocs=_WORLD, join=False) client = Client(grpc_port) @@ -1235,7 +1235,7 @@ def main(): _cfg_batch = yaml.safe_load(open(os.path.join(yolo_pipeline._HERE, "config.yaml")) )["data"]["train_loader"]["batch_size"] batch = int(os.environ.get("WL_DDP_BATCH", _cfg_batch)) - only = os.environ.get("WL_DDP_ONLY") # substring filter to run a single scenario + only = os.environ.get("WL_DDP_ONLY") # substring filter to run a single scenario # WL_DDP_SKIP: comma-separated substrings to EXCLUDE — lets a killed run resume # by skipping the scenarios that already passed (the suite has no checkpoint). skip = [s.strip() for s in os.environ.get("WL_DDP_SKIP", "").split(",") if s.strip()] @@ -1246,9 +1246,9 @@ def main(): print("\n" + "=" * 64) for name, ok in results.items(): - print(f" {name:42s} -> {'PASS' if ok else 'FAIL'}") + print(f" {name:42s} -> {'PASS' if ok else 'FAIL'}") allok = bool(results) and all(results.values()) - print(f" RESULT: {'ALL PASS' if allok else 'FAILURES ABOVE'}") + print(f" RESULT: {'ALL PASS' if allok else 'FAILURES ABOVE'}") print("=" * 64) raise SystemExit(0 if allok else 1) diff --git a/weightslab/tests/model/test_constraint_generation.py b/weightslab/tests/model/test_constraint_generation.py index 0b627a22..ddd00a94 100644 --- a/weightslab/tests/model/test_constraint_generation.py +++ b/weightslab/tests/model/test_constraint_generation.py @@ -42,7 +42,7 @@ def __init__(self): self.grouped_conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, groups=2) self.bn = nn.BatchNorm2d(16) self.relu = nn.ReLU() - self.regular_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1) # No groups + self.regular_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1) # No groups def forward(self, x): x = self.grouped_conv(x) @@ -62,8 +62,8 @@ class DepthwisePointwiseModel(nn.Module): """ def __init__(self): super().__init__() - self.dw = nn.Conv2d(16, 16, kernel_size=3, padding=1, groups=16) # Depthwise - self.pw = nn.Conv2d(16, 32, kernel_size=1) # Pointwise + self.dw = nn.Conv2d(16, 16, kernel_size=3, padding=1, groups=16) # Depthwise + self.pw = nn.Conv2d(16, 32, kernel_size=1) # Pointwise def forward(self, x): x = self.dw(x) @@ -227,7 +227,7 @@ def test_constraint_no_hardcoding(self): """Constraints are detected via introspection, not hardcoding on names""" # Create a custom conv with groups but no special name conv_with_groups = nn.Conv2d(4, 8, 3, padding=1, groups=2) - conv_with_groups.__class__.__name__ = "CustomConv" # Change class name + conv_with_groups.__class__.__name__ = "CustomConv" # Change class name constraints = _detect_layer_constraints(conv_with_groups) diff --git a/weightslab/tests/model/test_dependency_patterns.py b/weightslab/tests/model/test_dependency_patterns.py index f1955282..7d62cd31 100644 --- a/weightslab/tests/model/test_dependency_patterns.py +++ b/weightslab/tests/model/test_dependency_patterns.py @@ -114,7 +114,7 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - out = out + out1 # Residual connection (REC dependency) + out = out + out1 # Residual connection (REC dependency) return out @@ -142,7 +142,7 @@ def forward(self, x): branch2 = self.conv2(x) branch2 = self.relu2(branch2) - merged = torch.cat([branch1, branch2], dim=1) # REC: both branches constrained + merged = torch.cat([branch1, branch2], dim=1) # REC: both branches constrained out = self.conv_merged(merged) return out @@ -341,10 +341,10 @@ def forward(self, x): out_id = out out = self.conv2(out) out = self.bn2(out) - out = out + out_id # inner residual + out = out + out_id # inner residual out = self.conv3(out) out = self.bn3(out) - out = out + out1 # outer residual + out = out + out1 # outer residual return out @@ -965,7 +965,7 @@ def setUp(self): self.model = MinimalConv1DChain() self.model.eval() self._model = self.model - self.dummy_input = torch.randn(1, 4, 64) # N, C, L + self.dummy_input = torch.randn(1, 4, 64) # N, C, L def test_conv1d_onnx(self): self.model = self.get_dependencies_onnx(self.model, self.dummy_input) @@ -998,7 +998,7 @@ def setUp(self): self.model = MinimalConv3DChain() self.model.eval() self._model = self.model - self.dummy_input = torch.randn(1, 2, 8, 16, 16) # N, C, D, H, W + self.dummy_input = torch.randn(1, 2, 8, 16, 16) # N, C, D, H, W def test_conv3d_onnx(self): self.model = self.get_dependencies_onnx(self.model, self.dummy_input) diff --git a/weightslab/tests/model/test_model_with_ops.py b/weightslab/tests/model/test_model_with_ops.py index 84a9a2a9..f381c608 100644 --- a/weightslab/tests/model/test_model_with_ops.py +++ b/weightslab/tests/model/test_model_with_ops.py @@ -21,7 +21,7 @@ # Set Global Default Settings -th.manual_seed(42) # Set SEED +th.manual_seed(42) # Set SEED TMP_DIR = '/tmp/utests/'; os.makedirs('/tmp/utests/', exist_ok=True) diff --git a/weightslab/tests/model/test_tracking.py b/weightslab/tests/model/test_tracking.py index 10ca164d..411730f5 100644 --- a/weightslab/tests/model/test_tracking.py +++ b/weightslab/tests/model/test_tracking.py @@ -17,7 +17,7 @@ # Set Global Default Settings DEVICE = 'cpu' if not th.cuda.is_available() else 'cuda' -th.manual_seed(42) # Set SEED +th.manual_seed(42) # Set SEED @unittest.skip("Constraint detection and propagation tests are currently skipped due to ongoing refactor and potential changes in the underlying implementation. Will be re-enabled once the new system is in place more modeling.") diff --git a/weightslab/tests/modules/test_modules_with_ops.py b/weightslab/tests/modules/test_modules_with_ops.py index c1ce3067..8ceab6a5 100644 --- a/weightslab/tests/modules/test_modules_with_ops.py +++ b/weightslab/tests/modules/test_modules_with_ops.py @@ -13,7 +13,7 @@ # Set Global Default Settings -th.manual_seed(42) # Set SEED +th.manual_seed(42) # Set SEED class LayerWiseOperationsTest(unittest.TestCase): @@ -85,7 +85,7 @@ def _test_operation_core( self._create_layers(device=device) layer_instance = self.all_layers.get(layer_key) - layer_instance.to(device) # Update tracker device + layer_instance.to(device) # Update tracker device if layer_instance == None: self.fail(f"Layer key '{layer_key}' not found in setup.") @@ -134,7 +134,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to increase parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # ADD must strictly increase the count + ) # ADD must strictly increase the count self.assertEqual( layer_instance.get_neurons(attr_name='out_neurons'), initial_nb_out_neurons + len(neuron_indices), @@ -142,7 +142,7 @@ def _test_operation_core( "by 2." + f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # ADD 2 neurons must increase the count by 2 + ) # ADD 2 neurons must increase the count by 2 # --- Incoming --- if len(layer_instance.weight.shape) > 1: @@ -166,7 +166,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to increase parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # ADD must strictly increase the count + ) # ADD must strictly increase the count self.assertEqual( layer_instance.get_neurons(attr_name='in_neurons'), initial_nb_in_neurons + len(neuron_indices), @@ -174,7 +174,7 @@ def _test_operation_core( "by 2." + f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # ADD 2 neurons must increase the count by 2 + ) # ADD 2 neurons must increase the count by 2 elif op == ArchitectureNeuronsOpType.PRUNE: # --- Not Incoming --- @@ -198,7 +198,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # PRUNE must strictly decrease the count + ) # PRUNE must strictly decrease the count self.assertEqual( layer_instance.get_neurons(attr_name='out_neurons'), initial_nb_out_neurons - len(neuron_indices), @@ -206,7 +206,7 @@ def _test_operation_core( "by 2." + f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # PRUNE 2 neurons must decrease the count by 2 + ) # PRUNE 2 neurons must decrease the count by 2 # --- Incoming --- if len(layer_instance.weight.shape) > 1: @@ -230,7 +230,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # PRUNE must strictly decrease the count + ) # PRUNE must strictly decrease the count self.assertEqual( layer_instance.get_neurons(attr_name='in_neurons'), initial_nb_in_neurons - len(neuron_indices), @@ -238,7 +238,7 @@ def _test_operation_core( "by 2." + f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # PRUNE 2 neurons must decrease the count by 2 + ) # PRUNE 2 neurons must decrease the count by 2 elif op == ArchitectureNeuronsOpType.FREEZE: # --- Not Incoming --- @@ -261,7 +261,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE must strictly decrease the count + ) # FREEZE must strictly decrease the count # for tensor_name in layer_instance.learnable_tensors_name: # reverse neuron index @@ -300,7 +300,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to unfreeze parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE must match initial count + ) # UNFREEZE must match initial count # # FREEZE & UNFREEZE every neurons # # FREEZE @@ -319,7 +319,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to freeze every params." + f"Init:{layer_instance.get_neurons(attr_name='in_neurons')}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE every out neurons + ) # FREEZE every out neurons # # # UNFREEZE layer_instance.operate( @@ -337,7 +337,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to unfreeze every params." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE every out neurons + ) # UNFREEZE every out neurons # --- Incoming --- if len(layer_instance.weight.shape) > 1: @@ -360,7 +360,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE must strictly decrease the count + ) # FREEZE must strictly decrease the count # for tensor_name in layer_instance.learnable_tensors_name: # reverse neuron index @@ -400,7 +400,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to unfreeze parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE must match initial count + ) # UNFREEZE must match initial count # # FREEZE & UNFREEZE every neurons # # FREEZE @@ -419,7 +419,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to freeze every params." + f"Init:{layer_instance.get_neurons(attr_name='in_neurons')}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE every out neurons + ) # FREEZE every out neurons # # # UNFREEZE layer_instance.operate( @@ -438,7 +438,7 @@ def _test_operation_core( "params." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE every out neurons + ) # UNFREEZE every out neurons elif op == ArchitectureNeuronsOpType.RESET: # RESET must preserve the number of parameters diff --git a/weightslab/tests/test_secure_docker.py b/weightslab/tests/test_secure_docker.py index cee59f9b..d7b9bd30 100644 --- a/weightslab/tests/test_secure_docker.py +++ b/weightslab/tests/test_secure_docker.py @@ -153,7 +153,7 @@ def test_backend_connection_timeout(self): """Test backend connection with timeout.""" result = _test_backend_connection( host='127.0.0.1', - port=59999, # Likely not listening + port=59999, # Likely not listening timeout=0.5 ) assert result is False diff --git a/weightslab/tests/trainer/services/test_agent_prompt_unit.py b/weightslab/tests/trainer/services/test_agent_prompt_unit.py index 96179db8..52941897 100644 --- a/weightslab/tests/trainer/services/test_agent_prompt_unit.py +++ b/weightslab/tests/trainer/services/test_agent_prompt_unit.py @@ -455,7 +455,7 @@ def test_tag_outliers_by_stddev(self): ctx = SimpleNamespace( _all_datasets_df=agent_mod.pd.DataFrame( { - "signals//train_loss": [0.1, 0.15, 0.12, 0.14, 1.5], # Last one is outlier + "signals//train_loss": [0.1, 0.15, 0.12, 0.14, 1.5], # Last one is outlier }, index=agent_mod.pd.MultiIndex.from_tuples( [(f"train", i) for i in range(5)], diff --git a/weightslab/tests/trainer/services/test_trainer_services_server.py b/weightslab/tests/trainer/services/test_trainer_services_server.py index 288ba402..61e38134 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_server.py +++ b/weightslab/tests/trainer/services/test_trainer_services_server.py @@ -4,7 +4,7 @@ import weightslab.trainer.trainer_services as trainer_services -# Default per-test timeout in seconds. Override with WL_TEST_TIMEOUT env var. +# Default per-test timeout in seconds. Override with WL_TEST_TIMEOUT env var. import os _TEST_TIMEOUT = int(os.getenv("WL_TEST_TIMEOUT", "30")) diff --git a/weightslab/tests/trainer/services/test_trainer_services_unit.py b/weightslab/tests/trainer/services/test_trainer_services_unit.py index 9d9676d8..1e6b0a6f 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_unit.py +++ b/weightslab/tests/trainer/services/test_trainer_services_unit.py @@ -115,7 +115,7 @@ def _agg(graph_name, sample_ids=None, exp_hash=None): # Only sample 11 is 'hard'-tagged → mean curve over {11} = one aggregated point. self.assertEqual(len(response.points), 1) - self.assertEqual(response.points[0].sample_id, "") # aggregated, not a single sample + self.assertEqual(response.points[0].sample_id, "") # aggregated, not a single sample self.assertEqual(response.points[0].metric_name, "test/loss") self.assertAlmostEqual(response.points[0].metric_value, 0.3, places=5) diff --git a/weightslab/tests/watchdog/test_lock_monitor.py b/weightslab/tests/watchdog/test_lock_monitor.py index 29a07bec..279cfcc5 100644 --- a/weightslab/tests/watchdog/test_lock_monitor.py +++ b/weightslab/tests/watchdog/test_lock_monitor.py @@ -90,7 +90,7 @@ class TestMonitoredRLockReentrant(unittest.TestCase): def test_same_thread_can_reacquire(self): lock = MonitoredRLock() lock.acquire() - lock.acquire() # reentrant — must not deadlock + lock.acquire() # reentrant — must not deadlock try: self.assertTrue(lock.is_held()) finally: diff --git a/weightslab/tests/watchdog/test_watchdog.py b/weightslab/tests/watchdog/test_watchdog.py index 1601a3cc..cb915736 100644 --- a/weightslab/tests/watchdog/test_watchdog.py +++ b/weightslab/tests/watchdog/test_watchdog.py @@ -46,7 +46,7 @@ def emit(self, record): log.addHandler(handler) log.setLevel(logging.DEBUG) try: - log.watchdog("hello %s", "world") # type: ignore[attr-defined] + log.watchdog("hello %s", "world") # type: ignore[attr-defined] finally: log.removeHandler(handler) @@ -104,11 +104,11 @@ def test_healthy_lock_not_interrupted(self): def quick_holder(): lock.acquire() - time.sleep(0.02) # well below threshold + time.sleep(0.02) # well below threshold lock.release() watchdog = WeighlabsWatchdog( - stuck_threshold_s=5.0, # high threshold — should not fire + stuck_threshold_s=5.0, # high threshold — should not fire poll_interval_s=0.05, ) watchdog.register_lock("safe_lock", lock) @@ -147,16 +147,16 @@ def test_stuck_rpc_triggers_restart_request(self): def test_healthy_rpc_does_not_trigger_restart(self): watchdog = WeighlabsWatchdog( - stuck_threshold_s=5.0, # high threshold + stuck_threshold_s=5.0, # high threshold poll_interval_s=0.05, restart_threshold=1, ) watchdog.start() rpc_id = watchdog.rpc_state.begin("/test/FastMethod") - time.sleep(0.02) # much less than threshold + time.sleep(0.02) # much less than threshold watchdog.rpc_state.end(rpc_id) - time.sleep(0.1) # let watchdog tick + time.sleep(0.1) # let watchdog tick watchdog.stop() self.assertFalse(watchdog.server_manager.should_restart()) @@ -165,14 +165,14 @@ def test_unhealthy_count_resets_on_recovery(self): watchdog = WeighlabsWatchdog( stuck_threshold_s=0.02, poll_interval_s=0.02, - restart_threshold=10, # high — won't restart + restart_threshold=10, # high — won't restart ) watchdog.start() rpc_id = watchdog.rpc_state.begin("/test/SlowThenFast") - time.sleep(0.12) # trigger unhealthy + time.sleep(0.12) # trigger unhealthy watchdog.rpc_state.end(rpc_id) - time.sleep(0.15) # let watchdog see healthy state + time.sleep(0.15) # let watchdog see healthy state watchdog.stop() self.assertEqual(watchdog._unhealthy_count, 0, "unhealthy_count must reset to 0 on recovery") @@ -194,7 +194,7 @@ class TestWatchdogConfigurability(unittest.TestCase): def test_per_lock_timeout_overrides_global_threshold(self): """A per-lock set_timeout() must take precedence over the global threshold.""" lock = MonitoredRLock() - lock.set_timeout(0.05) # this lock is allowed only 50ms, regardless of global + lock.set_timeout(0.05) # this lock is allowed only 50ms, regardless of global released = threading.Event() started = threading.Event() @@ -230,7 +230,7 @@ def test_restart_threshold_requires_n_consecutive_unhealthy(self): watchdog = WeighlabsWatchdog(stuck_threshold_s=0.01, poll_interval_s=0.03, restart_threshold=3) watchdog.start() rpc_id = watchdog.rpc_state.begin("/test/SlowMethod") - time.sleep(0.3) # many poll cycles → unhealthy count climbs past 3 + time.sleep(0.3) # many poll cycles → unhealthy count climbs past 3 watchdog.stop() watchdog.rpc_state.end(rpc_id) @@ -250,7 +250,7 @@ def holder(): started.set() try: for _ in range(40): - time.sleep(0.02) # ~0.8s, far over the threshold + time.sleep(0.02) # ~0.8s, far over the threshold except _WatchdogInterrupt: interrupted.set() finally: @@ -263,7 +263,7 @@ def holder(): watchdog = WeighlabsWatchdog(stuck_threshold_s=0.05, poll_interval_s=0.05) watchdog.register_lock("eval_lock", lock) watchdog.start() - time.sleep(0.4) # let several poll cycles run + time.sleep(0.4) # let several poll cycles run watchdog.stop() t.join(timeout=2.0) @@ -294,7 +294,7 @@ def test_dead_worker_with_running_controller_is_marked_error(self): controller = self._FakeController() dead_thread = threading.Thread(target=lambda: None) dead_thread.start() - dead_thread.join() # now not alive + dead_thread.join() # now not alive watchdog = WeighlabsWatchdog(poll_interval_s=0.05) watchdog.register_eval_monitor(lambda: controller, lambda: dead_thread) diff --git a/weightslab/trainer/experiment_context.py b/weightslab/trainer/experiment_context.py index 13a5f83e..52899f06 100644 --- a/weightslab/trainer/experiment_context.py +++ b/weightslab/trainer/experiment_context.py @@ -83,7 +83,7 @@ def ensure_components(self, force: bool = False): try: dnames = list_dataloaders() for dname in dnames: - data_loaders[dname] = get_dataloader(dname) # pre-load to catch errors early + data_loaders[dname] = get_dataloader(dname) # pre-load to catch errors early except Exception: logger.error("Error while listing/resolving dataloaders", exc_info=True) pass @@ -154,7 +154,7 @@ def ensure_components(self, force: bool = False): "checkpoint_manager": checkpoint_manager, "df_manager": df_manager } - self._components.update(data_loaders) # add all dataloaders found + self._components.update(data_loaders) # add all dataloaders found self._last_resolve_time = now # Build hyper-parameter descriptors used by the protocol. Use diff --git a/weightslab/trainer/services/agent/agent.py b/weightslab/trainer/services/agent/agent.py index 715e52f4..05a80621 100644 --- a/weightslab/trainer/services/agent/agent.py +++ b/weightslab/trainer/services/agent/agent.py @@ -194,9 +194,9 @@ def build_op(self, step: AtomicIntent, context: Intent) -> Optional[dict]: pattern = r"(\[\s*['\"])(.*?)(['\"]\s*\])" def replace_col(match): - prefix = match.group(1) # e.g. [' + prefix = match.group(1) # e.g. [' content = match.group(2) # e.g. signals//train_loss - suffix = match.group(3) # e.g. '] + suffix = match.group(3) # e.g. '] resolved = self.agent._resolve_column(content) # Try to resolve the content to a real column @@ -373,7 +373,7 @@ def _setup_schema(self): self._build_column_index() def _load_config(self): - self.preferred_provider = os.environ.get("PREFERRED_PROVIDER", "openrouter") # Default to OpenRouter if API key is provided, otherwise fallback to local Ollama. This can be overridden by config file or env variable. + self.preferred_provider = os.environ.get("PREFERRED_PROVIDER", "openrouter") # Default to OpenRouter if API key is provided, otherwise fallback to local Ollama. This can be overridden by config file or env variable. # Cloud provider settings with sensible defaults. OpenRouter is the default cloud provider if API key is provided. self.openrouter_model = os.environ.get("OPENROUTER_MODEL", "meta-llama/llama-3.3-70b-instruct") @@ -382,12 +382,12 @@ def _load_config(self): self.openrouter_request_timeout = float(os.environ.get("OPENROUTER_REQUEST_TIMEOUT", "15.0")) # Local fallback if no cloud (OpenRouter) is available or if the user prefers it. Ollama is the default local provider. - self.fallback_to_local = True # Default to allowing fallback to local Ollama if OpenRouter fails + self.fallback_to_local = True # Default to allowing fallback to local Ollama if OpenRouter fails self.ollama_host = "localhost" self.ollama_port = "11435" self.ollama_model = "llama3.2:3b" - repo_root = Path(__file__).resolve().parents[4] # weightslab/ root + repo_root = Path(__file__).resolve().parents[4] # weightslab/ root inner_pkg = Path(__file__).resolve().parents[3] env_paths = [repo_root / ".env", inner_pkg / ".env"] @@ -574,9 +574,9 @@ def initialize_with_cloud_key(self, api_key: str, provider: str, model: Optional Initialize (or reinitialize) the OpenRouter cloud provider. Args: - api_key: The API key obtained from the provider's website. + api_key: The API key obtained from the provider's website. provider: Must be ``"openrouter"``. - model: OpenRouter model identifier chosen by the user. + model: OpenRouter model identifier chosen by the user. Returns: ``(True, success_message)`` or ``(False, error_message)``. @@ -688,7 +688,7 @@ def _resolve_column(self, user_name: str) -> Optional[str]: # Normalize Input: lowercase, replace spaces AND SLASHES with underscores user_lower = user_name.strip().lower() - user_clean = re.sub(r"[ /_]+", "_", user_lower) # "signals//train_loss" -> "signals_train_loss" + user_clean = re.sub(r"[ /_]+", "_", user_lower) # "signals//train_loss" -> "signals_train_loss" # 1. Exact Match (Fast path) if user_name in self._cols: return user_name @@ -735,7 +735,7 @@ def _build_python_mask(self, conditions: List[Condition], n: Optional[int] = Non # 2. Normalize Operator op = cond.op.lower() - if op == "=" or op == "equals": op = "==" # Fix "equals" + if op == "=" or op == "equals": op = "==" # Fix "equals" val = cond.value diff --git a/weightslab/trainer/services/agent_service.py b/weightslab/trainer/services/agent_service.py index 29ad164e..e39998b0 100644 --- a/weightslab/trainer/services/agent_service.py +++ b/weightslab/trainer/services/agent_service.py @@ -4,16 +4,16 @@ gRPC surface for AI-agent lifecycle management. Responsibilities: - - CheckAgentHealth : report whether any LLM provider is ready. - - InitializeAgent : wire up a cloud provider from a user-supplied API key. + - CheckAgentHealth : report whether any LLM provider is ready. + - InitializeAgent : wire up a cloud provider from a user-supplied API key. The actual ``DataManipulationAgent`` instance lives inside ``DataService`` because it requires the live dataframe context (schema, column index, etc.) -that ``DataService`` owns. ``AgentService`` receives a reference to +that ``DataService`` owns. ``AgentService`` receives a reference to ``DataService`` at construction time and delegates to its agent. Wire-up (in ExperimentService): - data_service = DataService(ctx) + data_service = DataService(ctx) agent_service = AgentService(data_service) """ @@ -62,7 +62,7 @@ def CheckAgentHealth(self, request, context): Returns: AgentHealthResponse { available: bool, message: str } - - available=True → "Ready to help you." + - available=True → "Ready to help you." - available=False → "Agent not configured. Type /init to set up." """ available = self._is_available() diff --git a/weightslab/trainer/services/data_image_utils.py b/weightslab/trainer/services/data_image_utils.py index 872e8f3d..81358b32 100644 --- a/weightslab/trainer/services/data_image_utils.py +++ b/weightslab/trainer/services/data_image_utils.py @@ -2,7 +2,7 @@ data_image_utils — Image encoding, mask compression, and proto helpers for gRPC data serving. Extracted from data_service.py to keep image-specific logic separate from the -DataService orchestration class. All functions here are pure (stateless) and +DataService orchestration class. All functions here are pure (stateless) and safe to call from any thread. """ @@ -47,8 +47,8 @@ def rle_encode_mask(mask_flat: np.ndarray) -> bytes: ends = np.empty_like(starts) ends[:-1] = starts[1:] ends[-1] = mask_flat.size - lengths = ends - starts # numpy int array - values = mask_flat[starts] # numpy uint8 array + lengths = ends - starts # numpy int array + values = mask_flat[starts] # numpy uint8 array # Split any runs > 65535 into multiple segments out_vals: list[int] = [] diff --git a/weightslab/trainer/services/data_service.py b/weightslab/trainer/services/data_service.py index 004aea89..f30f8831 100755 --- a/weightslab/trainer/services/data_service.py +++ b/weightslab/trainer/services/data_service.py @@ -55,7 +55,7 @@ # Streamed chunk size for GetPointCloud (raw float32 bytes per gRPC message). # Larger chunks mean fewer messages but more memory per message. Override with # the WL_POINT_CLOUD_CHUNK_BYTES env variable (see docs/configuration.rst). -_DEFAULT_POINT_CLOUD_CHUNK_BYTES = 1 << 20 # 1 MiB +_DEFAULT_POINT_CLOUD_CHUNK_BYTES = 1 << 20 # 1 MiB def _point_cloud_chunk_bytes() -> int: @@ -259,12 +259,12 @@ def __init__(self, ctx): # rather than queuing to redo the same work. # # Protocol: - # 1. try_acquire(_update_lock, blocking=False) - # → won: clear _update_done, do the update, release, set _update_done - # → lost: _update_done.wait() then return (result already fresh) + # 1. try_acquire(_update_lock, blocking=False) + # → won: clear _update_done, do the update, release, set _update_done + # → lost: _update_done.wait() then return (result already fresh) self._update_lock = threading.Lock() self._update_done = threading.Event() - self._update_done.set() # "done" initially so the very first call proceeds + self._update_done.set() # "done" initially so the very first call proceeds # Guard so a non-force (reader-triggered) view refresh runs in the BACKGROUND # at most once at a time — readers never pay the rebuild cost (they read the # current snapshot; the bg thread swaps in fresh data when ready). @@ -290,7 +290,7 @@ def __init__(self, ctx): # Check hyperparameters for compute_natural_sort flag (default: False) # Users can enable it by setting compute_natural_sort=True in their hyperparameters. hp = self._ctx.components.get("hyperparams") if self._ctx and self._ctx.components else None - hp_dict = hp.get() if Proxy.is_proxy(hp) else (hp if isinstance(hp, dict) else {}) # is it already a proxy ? + hp_dict = hp.get() if Proxy.is_proxy(hp) else (hp if isinstance(hp, dict) else {}) # is it already a proxy ? self._compute_natural_sort = bool((hp_dict or {}).get("compute_natural_sort", False)) # How per-instance (per-annotation) numeric columns are folded to a single @@ -323,14 +323,14 @@ def __init__(self, ctx): max_workers=8 ) - self._is_filtered = False # Track if the current view is filtered/modified by user + self._is_filtered = False # Track if the current view is filtered/modified by user # logger.info("[DataService] Skipping expensive startup computations (aspect ratio, natural sort, signals).") # These should be triggered on-demand or run in background to avoid blocking training start. if self._compute_natural_sort: self._compute_natural_sort_stats() - self._is_filtered = False # Track if the current view is filtered/modified by user + self._is_filtered = False # Track if the current view is filtered/modified by user # ===================================================================== # Preview cache: pre-generate 64×64 or less WebP thumbnails + RLE masks for @@ -353,7 +353,7 @@ def __init__(self, ctx): daemon=True, ).start() else: - self._preview_cache_ready.set() # No preload → mark immediately ready + self._preview_cache_ready.set() # No preload → mark immediately ready logger.info("DataService initialized.") @@ -378,8 +378,8 @@ def _build_preview_cache(self) -> None: """Pre-generate 64×64 or less or less thumbnail + RLE mask for every row in the DF. Each entry is a lightweight ``DataRecord`` containing only: - • raw_data (bytes) — 64×64 or less or less WebP thumbnail - • target (rle_mask) — RLE-encoded GT mask resized to 64×64 or less or less + • raw_data (bytes) — 64×64 or less or less WebP thumbnail + • target (rle_mask) — RLE-encoded GT mask resized to 64×64 or less or less • pred_mask (rle_mask) — RLE-encoded prediction mask resized to 64×64 or less or less • origin, task_type, num_classes, class_names (metadata) Respects ``_preview_cache_max`` to cap memory usage. @@ -395,7 +395,7 @@ def _build_preview_cache(self) -> None: logger.info("[PreviewCache] Building 64×64 or less or less preview cache for %d samples …", total) t0 = time.time() - PREVIEW_SIZE = 64 # fixed low-res dimension + PREVIEW_SIZE = 64 # fixed low-res dimension built = 0 index_names = list(getattr(df.index, "names", []) or []) @@ -847,7 +847,7 @@ def _get_origin_filter(self, request): if val: # Normalize to list if isinstance(val, str): - origins = [val] if val.strip() else [] # Filter empty strings + origins = [val] if val.strip() else [] # Filter empty strings else: # Filter out empty strings from list origins = [o for o in list(val) if o and str(o).strip()] @@ -925,9 +925,9 @@ def _compute_natural_sort_stats(self): # 4. "Grouped" (Pseudo-primary key): Brightness=5.0, Entropy=1.0 (Forces clustering by light) SORT_WEIGHTS = { - "brightness": 0.7, # Primary cue: Lighting conditions - "entropy": 0.3, # Secondary cue: Texture/Scene complexity - "hue": 0.0 # Optional: Color tint + "brightness": 0.7, # Primary cue: Lighting conditions + "entropy": 0.3, # Secondary cue: Texture/Scene complexity + "hue": 0.0 # Optional: Color tint } logger.info(f"[DataService] Starting natural sort stats computation with weights: {SORT_WEIGHTS}") @@ -1208,7 +1208,7 @@ def _process_sample_row(self, args): task_type = "unknown" else: # 4. Safe Heuristic evaluation - task_type = "classification" # Default fallback + task_type = "classification" # Default fallback if label is not None: if isinstance(label, dict): if ('boxes' in label or 'bboxes' in label): @@ -1400,7 +1400,7 @@ def _process_sample_row(self, args): max_id = int(label_arr.max()) num_classes = max(1, max_id) + 1 else: - num_classes = 2 # Always at least 2 classes for segmentation (foreground/background) + num_classes = 2 # Always at least 2 classes for segmentation (foreground/background) data_stats.append( create_data_stat( @@ -1465,7 +1465,7 @@ def _process_sample_row(self, args): else: # Check if label is NaN (handle both scalars and arrays) if self._is_nan_value(label): - pass # Skip NaN labels + pass # Skip NaN labels # Handle scalar labels try: @@ -1614,7 +1614,7 @@ def _process_sample_row(self, args): else: # Classification: get prediction from row or dataset if pred is None: - pass # No prediction to process + pass # No prediction to process else: # Handle scalar predictions (int, float, or unwrapped from H5) @@ -1696,7 +1696,7 @@ def _process_sample_row(self, args): target_width = w_limit target_height = int(target_width / aspect_ratio) elif request.resize_width == 0 and request.resize_height == 0: - target_height = int(os.environ.get("WL_DEFAULT_THUMBNAIL_SIZE", 180)) # Default full resolution image is 360p on the longest side, but can be overridden by env var + target_height = int(os.environ.get("WL_DEFAULT_THUMBNAIL_SIZE", 180)) # Default full resolution image is 360p on the longest side, but can be overridden by env var target_width = int(target_height * aspect_ratio) if is_full_resolution: @@ -1874,7 +1874,7 @@ def _build_success_response( """ total_count = len(df) discarded_count = ( - len(df[df.get(SampleStatsEx.DISCARDED.value, False) == True]) # noqa: E712 + len(df[df.get(SampleStatsEx.DISCARDED.value, False) == True]) # noqa: E712 if df is not None and SampleStatsEx.DISCARDED.value in df.columns else 0 ) @@ -2219,7 +2219,7 @@ def _apply_agent_operation(self, df, func: str, params: dict) -> str: # We must split the updates by origin and upsert them to the manager if self._df_manager is not None: # Create a minimal update dataframe with just the modified column - update_payload = df[[col]] # .copy() # Remove copy because memory waste and slowdown + update_payload = df[[col]] # .copy() # Remove copy because memory waste and slowdown # Ensure origin is available for grouping if isinstance(df.index, pd.MultiIndex) and "origin" in df.index.names: @@ -2268,7 +2268,7 @@ def _apply_agent_operation(self, df, func: str, params: dict) -> str: if start < len(df): logger.debug(f"[sort_view_slice] Sorting slice {start}:{end}") # Extract and sort slice - sub_df = df.iloc[start:end] # .copy() # Remove copy because memory waste and slowdown + sub_df = df.iloc[start:end] # .copy() # Remove copy because memory waste and slowdown # Apply sort to slice # Filter params for sort_values @@ -2328,12 +2328,12 @@ def _apply_agent_operation(self, df, func: str, params: dict) -> str: # otherwise Sample ID X will point to data from Sample ID Y (corruption). try: if isinstance(df.index, pd.MultiIndex): - new_index_values = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown + new_index_values = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown new_index_values[start:end] = sub_df.index.to_numpy() df.index = pd.MultiIndex.from_tuples(new_index_values, names=df.index.names) else: idx_name = df.index.name - new_index = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown + new_index = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown new_index[start:end] = sub_df.index.to_numpy() df.index = pd.Index(new_index, name=idx_name) except Exception as e: @@ -2463,7 +2463,7 @@ def _restore_index(): # Lock watchdog helpers # ------------------------------------------------------------------ # ------------------------------------------------------------------ - # Lock watchdog helpers (build on MonitoredRLock from watchdog/) + # Lock watchdog helpers (build on MonitoredRLock from watchdog/) # ------------------------------------------------------------------ @staticmethod def _lock_caller() -> str: @@ -2500,7 +2500,7 @@ def _watched_lock(self, lock_name: str = "_lock"): self._lock.acquire() waited_ms = (time.time() - t0) * 1000 logger.debug( - "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", + "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", lock_name, thread, caller, waited_ms, ) t_held = time.time() @@ -2511,12 +2511,12 @@ def _watched_lock(self, lock_name: str = "_lock"): self._lock.release() if held_ms > 1000: logger.warning( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", lock_name, thread, held_ms, ) else: logger.debug( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", lock_name, thread, held_ms, ) @@ -2576,7 +2576,7 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> target=self._bg_view_refresh, name="WL-ViewRefresh", daemon=True ).start() except Exception: - self._refresh_in_flight.release() # never leak the guard + self._refresh_in_flight.release() # never leak the guard logger.exception("[ViewRefresh] failed to start background refresh") return @@ -2585,7 +2585,7 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> acquired = self._update_lock.acquire(blocking=False) if not acquired: - # Another worker is already updating. Wait for it to finish (bounded), + # Another worker is already updating. Wait for it to finish (bounded), # then return — the caller will read the already-refreshed view. thread = threading.current_thread().name logger.debug( @@ -2601,7 +2601,7 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> thread = threading.current_thread().name caller = self._lock_caller() logger.debug( - "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", + "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", "_update_lock[_slowUpdateInternals]", thread, caller, waited_ms, ) # Signal to latecomers that an update is now in progress. @@ -2710,12 +2710,12 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> self._update_lock.release() if held_ms > 1000: logger.warning( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", "_update_lock[_slowUpdateInternals]", threading.current_thread().name, held_ms, ) else: logger.debug( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", "_update_lock[_slowUpdateInternals]", threading.current_thread().name, held_ms, ) # Unblock all workers that were waiting on this update. @@ -2795,7 +2795,7 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, requested_cols=N # -- Vectorized pre-processing: build string matrices via pandas ------ # Separate tag columns from regular metadata columns for different - # handling. All heavy conversion is done once on the full column + # handling. All heavy conversion is done once on the full column # vectors, not per-row. tag_cols = [c for c in metadata_cols if c.startswith(tag_prefix)] meta_cols = [c for c in metadata_cols if not c.startswith(tag_prefix)] @@ -2809,10 +2809,10 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, requested_cols=N # -- Column-wise DataStat construction -------------------------------- # Build all DataStat objects for one column at a time using list # comprehensions (CPython fast-path) and inline pb2.DataStat() to - # eliminate the create_data_stat wrapper overhead. Then scatter - # them into the per-row bins. At 1M rows × 10 cols this avoids + # eliminate the create_data_stat wrapper overhead. Then scatter + # them into the per-row bins. At 1M rows × 10 cols this avoids # a 10M-iteration nested Python loop. - _DataStat = pb2.DataStat # local ref – avoids repeated attr lookup + _DataStat = pb2.DataStat # local ref – avoids repeated attr lookup for col in meta_cols: series = df_slice[col] @@ -3079,14 +3079,14 @@ def _process_get_data_samples(self, request, context): resolution (both dims ≤ ``_PREVIEW_CACHE_THRESHOLD``), serve from the cache instantly without touching the file system. 2. **Parallel batch processing** – All samples are submitted to the - thread pool at once so all 8 workers stay busy. The chunk-size + thread pool at once so all 8 workers stay busy. The chunk-size env-var ``WL_BATCH_CHUNK_SIZE`` is kept for backward compat but the default is now the full request size (all at once). """ - _PREVIEW_CACHE_THRESHOLD = 80 # max px to consider a "preview" request + _PREVIEW_CACHE_THRESHOLD = 80 # max px to consider a "preview" request # Default: process ALL rows at once in the thread pool (workers = 8). # Override with WL_BATCH_CHUNK_SIZE to throttle concurrency. - _BATCH_CHUNK_SIZE = int(os.environ.get("WL_BATCH_CHUNK_SIZE", "0")) # 0 = all at once + _BATCH_CHUNK_SIZE = int(os.environ.get("WL_BATCH_CHUNK_SIZE", "0")) # 0 = all at once try: start_time = time.time() @@ -3253,7 +3253,7 @@ def _process_get_data_samples(self, request, context): # ---- Parallel batch processing --------------------------------- # Submit ALL rows to the thread pool at once so all 8 workers - # stay busy. This avoids the old sequential-chunk bottleneck + # stay busy. This avoids the old sequential-chunk bottleneck # where each sub-batch had to finish before the next started. data_records: list = [] rows_list = list(df_slice.iterrows()) @@ -3319,7 +3319,7 @@ def _calculate_tag_column_updates(self, sample_id: int, origin: str, new_tag_nam new_tag_name = f'{SampleStatsEx.TAG.value}:{stripped_tag_name}' # Get current tags from the in-memory dataframe or df_manager - existing_tag_value = True # Default to True for new tags + existing_tag_value = True # Default to True for new tags try: if self._all_datasets_df is not None: # Read current tag columns from in-memory dataframe @@ -3335,7 +3335,7 @@ def _calculate_tag_column_updates(self, sample_id: int, origin: str, new_tag_nam if row is not None: for col in row.index: - if col == new_tag_name and row[col]: # If existing, revert the value + if col == new_tag_name and row[col]: # If existing, revert the value existing_tag_value = bool(1 - row[col]) except (KeyError, AttributeError) as e: @@ -3343,7 +3343,7 @@ def _calculate_tag_column_updates(self, sample_id: int, origin: str, new_tag_nam # Calculate target tags based on edit type if edit_type == SampleEditType.EDIT_REMOVE: - existing_tag_value = False # For removal, we set the tag to False + existing_tag_value = False # For removal, we set the tag to False target_tags_set = self._parse_tags(new_tag_name) else: # Override: replace all tags with the new value @@ -3385,8 +3385,8 @@ def ApplyDataQuery(self, request, context): Apply a query on the in-memory dataframe. Modes: - - request.query == "" -> just return counts, do not modify df - - request.query != "" -> always handled by the agent (natural language path) + - request.query == "" -> just return counts, do not modify df + - request.query != "" -> always handled by the agent (natural language path) Counts returned: - number_of_all_samples: all rows currently in the dataframe @@ -3430,10 +3430,10 @@ def ApplyDataQuery(self, request, context): is_sort_only = bool(operations) and all( op.get("function") in _SORT_FUNCS for op in operations) if not is_sort_only: - self._slowUpdateInternals(force=True) # Refresh internals before applying non-sort operations + self._slowUpdateInternals(force=True) # Refresh internals before applying non-sort operations # Work on a copy to allow concurrent readers to see a consistent state - df = self._all_datasets_df # Remove copy because memory waste and slowdown + df = self._all_datasets_df # Remove copy because memory waste and slowdown messages = [] for op in operations: @@ -3464,7 +3464,7 @@ def ApplyDataQuery(self, request, context): logger.info(f"[ApplyDataQuery] BYPASSING AGENT - Direct DataFrame operation: {request.query[:100]}...") with self._lock: - self._all_datasets_df, message = execute_df_operation(self._all_datasets_df, request.query) # in-place operation, or replace previous dataframe + self._all_datasets_df, message = execute_df_operation(self._all_datasets_df, request.query) # in-place operation, or replace previous dataframe logger.info(f"[ApplyDataQuery] Executed direct DataFrame operation. Message: {message}") if operations: @@ -3480,8 +3480,8 @@ def ApplyDataQuery(self, request, context): logger.info(f"[ApplyDataQuery] BYPASSING AGENT - Direct reset/clear operation: {request.query[:100]}...") # Force view reset with self._lock: - self._is_filtered = False # Unfreeze view first - self._slowUpdateInternals(force=True) # Force update to ensure we have the latest data + self._is_filtered = False # Unfreeze view first + self._slowUpdateInternals(force=True) # Force update to ensure we have the latest data logger.info(f"[ApplyDataQuery] Force view reset and unfrozen.") return pb2.DataQueryResponse( @@ -3531,7 +3531,7 @@ def status_cb(msg: str): if self._all_datasets_df is None: self._all_datasets_df = self._pull_into_all_data_view_df() or pd.DataFrame() - df = self._all_datasets_df # .copy() # Remove copy because memory waste and slowdown + df = self._all_datasets_df # .copy() # Remove copy because memory waste and slowdown messages = [] intent_type = pb2.INTENT_FILTER analysis_result = "" diff --git a/weightslab/trainer/services/experiment_service.py b/weightslab/trainer/services/experiment_service.py index 8f838469..1ad86ff4 100644 --- a/weightslab/trainer/services/experiment_service.py +++ b/weightslab/trainer/services/experiment_service.py @@ -177,7 +177,7 @@ def _get_latest_logger_data_impl(self, request, context): self._ctx.ensure_components() components = self._ctx.components signal_logger = components.get("signal_logger") - if signal_logger == None: + if signal_logger == None: return pb2.GetLatestLoggerDataResponse(points=[]) # Drop the request early if the client already disconnected @@ -239,7 +239,7 @@ def _get_latest_logger_data_impl(self, request, context): # WL_MAX_POINTS_PER_SAMPLE bounds points per returned curve (endpoints kept). max_points = _max_points_per_sample() for exp_hash, series in per_hash.items(): - series.sort(key=lambda sv: sv[0]) # order by step (model age) + series.sort(key=lambda sv: sv[0]) # order by step (model age) series = _downsample_uniform(series, max_points) audit = exp_hash in eval_hashes for step, mean_val in series: @@ -250,7 +250,7 @@ def _get_latest_logger_data_impl(self, request, context): metric_value=mean_val, experiment_hash=exp_hash, timestamp=now, - sample_id="", # aggregated mean curve — not a single sample + sample_id="", # aggregated mean curve — not a single sample audit_mode=audit, ) ) @@ -321,7 +321,7 @@ def _get_latest_logger_data_impl(self, request, context): metric_value=s.get("metric_value", 0.0), experiment_hash=s.get("experiment_hash", "N.A."), timestamp=int(s.get("timestamp", time.time())), - sample_id="", # No sample_id in aggregated mode + sample_id="", # No sample_id in aggregated mode is_evaluation_marker=bool(s.get("is_evaluation_marker", False)), split_name=str(s.get("split_name", "")), evaluation_tags=[str(tag) for tag in s.get("evaluation_tags", []) or []], @@ -347,7 +347,7 @@ def _get_latest_logger_data_impl(self, request, context): metric_value=s.get("metric_value", 0.0), experiment_hash=s.get("experiment_hash", "N.A."), timestamp=int(s.get("timestamp", time.time())), - sample_id="", # No sample_id in queue mode + sample_id="", # No sample_id in queue mode is_evaluation_marker=bool(s.get("is_evaluation_marker", False)), split_name=str(s.get("split_name", "")), evaluation_tags=[str(tag) for tag in s.get("evaluation_tags", []) or []], @@ -420,11 +420,11 @@ def RestoreCheckpoint(self, request, context): load_weights=True, load_config=True, load_data=True, - load_logger=False, # Don't load logger for weights-only restore to avoid overwriting signals, + load_logger=False, # Don't load logger for weights-only restore to avoid overwriting signals, target_step=target_step, ) else: - success = checkpoint_manager.load_state(experiment_hash, load_logger=False) # Don't load logger for full restore to avoid overwriting signals already in memory + success = checkpoint_manager.load_state(experiment_hash, load_logger=False) # Don't load logger for full restore to avoid overwriting signals already in memory # Reply if success: @@ -625,7 +625,7 @@ def _handle_save_checkpoint(self, save_op, components, context): } # 1) Resolve the model first. Without a registered model there is nothing - # to dump — fail early so we don't needlessly pause a running experiment. + # to dump — fail early so we don't needlessly pause a running experiment. model = components.get("model") if components else None if model is None: model = ledgers.get_model() @@ -649,9 +649,9 @@ def _handle_save_checkpoint(self, save_op, components, context): return pb2.CommandResponse(success=False, message=msg) # 3) Pause training before dumping. Acquire the global lock so any - # in-flight training step has finished; clearing is_training keeps the - # loop parked at its next pause point, so the subsequent save reads a - # consistent model state. + # in-flight training step has finished; clearing is_training keeps the + # loop parked at its next pause point, so the subsequent save reads a + # consistent model state. if not try_acquire_rlock(): logger.error( "[SaveCheckpoint] weightslab_rlock timed out after %.0fs", @@ -678,7 +678,7 @@ def _handle_save_checkpoint(self, save_op, components, context): weightslab_rlock.release() # 4) Ensure an experiment hash exists so save_model_checkpoint has a - # target directory (a brand-new experiment may not have one yet). + # target directory (a brand-new experiment may not have one yet). try: if getattr(checkpoint_manager, "current_exp_hash", None) is None: if hasattr(checkpoint_manager, "get_current_experiment_hash"): @@ -950,12 +950,12 @@ def ExperimentCommand(self, request, context): ) # TODO (GP): Disabled with modelling for now. # if hyper_parameters.HasField("learning_rate"): - # hp_changes["learning_rate"] = hyper_parameters.learning_rate - # set_hyperparam( - # name=hp_name, - # key_path="optimizer.lr", - # value=hyper_parameters.learning_rate - # ) + # hp_changes["learning_rate"] = hyper_parameters.learning_rate + # set_hyperparam( + # name=hp_name, + # key_path="optimizer.lr", + # value=hyper_parameters.learning_rate + # ) if hyper_parameters.HasField("batch_size"): hp_changes["batch_size"] = hyper_parameters.batch_size set_hyperparam( @@ -1117,7 +1117,7 @@ def ExperimentCommand(self, request, context): # Set number of steps desired to run before next pause if provided, based on current model age + requested nb_steps if hyper_parameters.HasField("nb_steps"): - m = components.get("model") # Get model + m = components.get("model") # Get model m_age = m.get_age() logger.info(f"\n[WeightsLab] UI Command: Define number of steps at {hyper_parameters.nb_steps}") if hyper_parameters.nb_steps > 0: diff --git a/weightslab/trainer/services/instance_merger.py b/weightslab/trainer/services/instance_merger.py index 66f7efa1..bc582936 100644 --- a/weightslab/trainer/services/instance_merger.py +++ b/weightslab/trainer/services/instance_merger.py @@ -102,7 +102,7 @@ def merge_segmentation_instances(instance_values: List[Any], task_type: str = "s mask1: [[0, 1], [0, 1]] mask2: [[1, 0], [0, 1]] Output: np.max([mask0, mask1, mask2], axis=0) - = [[1, 1], [1, 1]] [MAX aggregated!] + = [[1, 1], [1, 1]] [MAX aggregated!] - Input: [mask0] (single mask) Output: mask0 as-is (512, 512) @@ -133,7 +133,7 @@ def merge_segmentation_instances(instance_values: List[Any], task_type: str = "s # Multiple masks: aggregate using max at each pixel # Stack temporarily for max operation, then return result stacked = np.stack(masks_np, axis=0) - return np.max(stacked, axis=0) # Take max across instances → (H, W) + return np.max(stacked, axis=0) # Take max across instances → (H, W) def merge_classification_instances(instance_values: List[Any], task_type: str = "classification") -> Union[list, None]: @@ -152,7 +152,7 @@ def merge_classification_instances(instance_values: List[Any], task_type: str = - All None: Return None Example: - - Input: ['cat', None, None] → Output: ['cat'] [LIST!] + - Input: ['cat', None, None] → Output: ['cat'] [LIST!] - Input: ['cat', 'dog', 'animal'] → Output: ['cat', 'dog', 'animal'] """ labels = [] @@ -212,9 +212,9 @@ def group_instances_by_sample(df_slice, target_column: str, task_type: str): Returns: Dict mapping sample_id to merged value Example: { - 'sample_0': [bbox0, bbox1, bbox2], # Detection: list of bboxes - 'sample_1': [mask0, mask1], # Segmentation: list of masks - 'sample_2': 'cat', # Classification: single label + 'sample_0': [bbox0, bbox1, bbox2], # Detection: list of bboxes + 'sample_1': [mask0, mask1], # Segmentation: list of masks + 'sample_2': 'cat', # Classification: single label } """ if df_slice.empty or target_column not in df_slice.columns: diff --git a/weightslab/trainer/services/utils/tools.py b/weightslab/trainer/services/utils/tools.py index 3654fabd..11148009 100644 --- a/weightslab/trainer/services/utils/tools.py +++ b/weightslab/trainer/services/utils/tools.py @@ -3,7 +3,7 @@ ======================= Shared utility helpers for the trainer service layer. -Keep this file free of heavy domain logic. It is the right place for: +Keep this file free of heavy domain logic. It is the right place for: - Small, stateless helper functions used by two or more services. - Shared constants / lookup tables (e.g. provider maps). - Thin wrappers that reduce boilerplate inside service methods. diff --git a/weightslab/trainer/trainer_services.py b/weightslab/trainer/trainer_services.py index 76e76f43..0400c434 100644 --- a/weightslab/trainer/trainer_services.py +++ b/weightslab/trainer/trainer_services.py @@ -282,7 +282,7 @@ def intercept_service(self, continuation, handler_call_details): # --------------------------------------------------------------------------- # Backward-compat note: RpcWatchdogState, RpcTimingAndWatchdogInterceptor and # GrpcServerManager are now defined in weightslab.watchdog.grpc_watchdog and -# re-exported above. External code that imported them from trainer_services +# re-exported above. External code that imported them from trainer_services # continues to work unchanged. # --------------------------------------------------------------------------- @@ -468,12 +468,12 @@ def grpc_serve( grpc_host = os.getenv("GRPC_BACKEND_HOST", "0.0.0.0") if not force_parameters or grpc_host is None else grpc_host grpc_port = int(os.getenv("GRPC_BACKEND_PORT", 50051)) if not force_parameters or grpc_port is None else grpc_port - watchdog_threshold_s = float(os.getenv("GRPC_WATCHDOG_STUCK_SECONDS", "180")) # 3 minutes default stuck threshold + watchdog_threshold_s = float(os.getenv("GRPC_WATCHDOG_STUCK_SECONDS", "180")) # 3 minutes default stuck threshold watchdog_interval_s = float(os.getenv("GRPC_WATCHDOG_INTERVAL_SECONDS", "5")) watchdog_exit_on_stuck = str(os.getenv("GRPC_WATCHDOG_EXIT_ON_STUCK", "0")).strip().lower() in {"1", "true", "yes", "on"} - watchdog_restart_threshold = int(os.getenv("GRPC_WATCHDOG_RESTART_THRESHOLD", "3")) # Restart after 3 unhealthy checks + watchdog_restart_threshold = int(os.getenv("GRPC_WATCHDOG_RESTART_THRESHOLD", "3")) # Restart after 3 unhealthy checks watchdog_details_limit = int(os.getenv("GRPC_WATCHDOG_INFLIGHT_DETAILS_LIMIT", "10")) - watchdog_disabled = str(os.getenv("WEIGHTSLAB_DISABLE_WATCHDOGS", "1")).strip().lower() in {"1", "true", "yes", "on"} # Default state: disabled + watchdog_disabled = str(os.getenv("WEIGHTSLAB_DISABLE_WATCHDOGS", "1")).strip().lower() in {"1", "true", "yes", "on"} # Default state: disabled config = get_hyperparams() grpc_tls_enabled = _resolve_bool_setting(config, "grpc_tls_enabled", "GRPC_TLS_ENABLED", "0") grpc_tls_key_file = _resolve_grpc_tls_path( @@ -528,7 +528,7 @@ def grpc_serve( ) watchdog.register_lock("weightslab_rlock", weightslab_rlock) - # Eval thread monitor — no timeout, just liveness. Lazy imports avoid + # Eval thread monitor — no timeout, just liveness. Lazy imports avoid # circular dependencies since weightslab.src imports trainer code. def _get_eval_controller(): from weightslab.components.evaluation_controller import eval_controller as _ec @@ -542,8 +542,8 @@ def _get_eval_thread(): get_controller=_get_eval_controller, get_thread=_get_eval_thread, ) - watchdog_state = watchdog.rpc_state # shared with RpcTimingAndWatchdogInterceptor - server_manager = watchdog.server_manager # shared with serving_thread_callback + watchdog_state = watchdog.rpc_state # shared with RpcTimingAndWatchdogInterceptor + server_manager = watchdog.server_manager # shared with serving_thread_callback logger.debug( f"grpc_serve called with parameters: n_workers_grpc={n_workers_grpc}, grpc_host={grpc_host}, grpc_port={grpc_port}, " f"watchdog_threshold_s={watchdog_threshold_s}, watchdog_interval_s={watchdog_interval_s}, watchdog_exit_on_stuck={watchdog_exit_on_stuck}, watchdog_restart_threshold={watchdog_restart_threshold}, " @@ -553,13 +553,13 @@ def _get_eval_thread(): def serving_thread_callback(): logger.info("[gRPC] Thread callback started") try: - while True: # Loop to allow restarts + while True: # Loop to allow restarts _effective_workers = n_workers_grpc or min(32, (os.cpu_count() or 1) + 4) logger.info( "[gRPC] Creating ThreadPoolExecutor with %d worker threads (n_workers_grpc=%s, max_concurrent_rpcs=%s)", _effective_workers, n_workers_grpc, max_concurrent_rpcs, ) - _max_msg = int(os.getenv("GRPC_MAX_MESSAGE_BYTES", 256 * 1024 * 1024)) # 256 MB + _max_msg = int(os.getenv("GRPC_MAX_MESSAGE_BYTES", 256 * 1024 * 1024)) # 256 MB server = grpc.server( futures.ThreadPoolExecutor( thread_name_prefix="WL-gRPC-Worker", @@ -636,16 +636,16 @@ def serving_thread_callback(): while not server_manager.should_restart(): time.sleep(0.5) - logger.watchdog("[gRPC] Restart requested. Gracefully shutting down (5s grace)...") # type: ignore[attr-defined] + logger.watchdog("[gRPC] Restart requested. Gracefully shutting down (5s grace)...") # type: ignore[attr-defined] stop_event = server.stop(grace=5) stopped = stop_event.wait(timeout=6.0) if not stopped: - logger.watchdog("[gRPC] Graceful stop timed out; forcing immediate stop.") # type: ignore[attr-defined] + logger.watchdog("[gRPC] Graceful stop timed out; forcing immediate stop.") # type: ignore[attr-defined] server.stop(grace=0).wait(timeout=1.0) cleared = watchdog_state.clear_for_restart() if cleared: - logger.watchdog("[gRPC] Cleared %d stale in-flight RPC records after restart.", cleared) # type: ignore[attr-defined] + logger.watchdog("[gRPC] Cleared %d stale in-flight RPC records after restart.", cleared) # type: ignore[attr-defined] server_manager.clear_restart_request() logger.info("[gRPC] Server stopped. Restarting in 2s...") time.sleep(2) diff --git a/weightslab/trainer/trainer_tools.py b/weightslab/trainer/trainer_tools.py index e017a279..f81e7f9b 100644 --- a/weightslab/trainer/trainer_tools.py +++ b/weightslab/trainer/trainer_tools.py @@ -59,7 +59,7 @@ def execute_df_operation(df, operation_str): except Exception as e: error_msg = f"Error executing DataFrame operation '{operation_str}': {e}" logger.error(error_msg, exc_info=True) - return df, error_msg # Return original df on error + return df, error_msg # Return original df on error def get_hyper_parameters_pb( @@ -82,7 +82,7 @@ def get_hyper_parameters_pb( # For numerical values, ensure we pass a float to gRPC to avoid "must be real number" errors try: if hasattr(value, "get"): - value = value.get() # unwrap if it's a wrapper object + value = value.get() # unwrap if it's a wrapper object if value is None or value == None: num_val = 'null' type_ = 'string' @@ -241,7 +241,7 @@ def _labels_from_mask_path_histogram(path, num_classes=None, ignore_index=255): with Image.open(path) as im: if im.mode not in ("P", "L"): im = im.convert("L") - hist = im.histogram() # length 256 + hist = im.histogram() # length 256 ub = 256 if num_classes is None else int(num_classes) ids = [i for i, cnt in enumerate(hist[:ub]) if cnt > 0] if ignore_index is not None: @@ -313,7 +313,7 @@ def _safe_dataset_length(ds): sample_stats.task_type = task_type ignore_index = getattr(dataset, "ignore_index", 255) - num_classes = getattr(dataset, "num_classes", getattr(experiment, "num_classes", None)) + num_classes = getattr(dataset, "num_classes", getattr(experiment, "num_classes", None)) # Safely iterate dataset records; if as_records isn't available or dataset is a placeholder # fall back to an empty iterator. @@ -343,9 +343,9 @@ def _safe_dataset_length(ds): pred_list = _class_ids(row.get("prediction_raw"), num_classes, ignore_index) else: target = row.get("label", row.get("target", -1)) - pred = row.get("prediction_raw", -1) + pred = row.get("prediction_raw", -1) target_list = [int(target)] if not isinstance(target, (list, np.ndarray)) else [int(np.array(target).item())] - pred_list = [int(pred)] if not isinstance(pred, (list, np.ndarray)) else [int(np.array(pred).item())] + pred_list = [int(pred)] if not isinstance(pred, (list, np.ndarray)) else [int(np.array(pred).item())] record.sample_label.extend(target_list) record.sample_prediction.extend(pred_list) @@ -526,18 +526,18 @@ def encode_image_to_raw_bytes( - All other cases (thumbnails, 2D): PIL image compressed to WebP (JPEG fallback). Args: - np_img: Numpy array of the image (required for the volumetric path). - middle_pil: PIL Image (required for the 2D / thumbnail path). - original_shape: Original tensor shape, used to derive [Z, H, W, C] for volumetric. - is_volumetric: True when the image has a depth (Z) dimension. + np_img: Numpy array of the image (required for the volumetric path). + middle_pil: PIL Image (required for the 2D / thumbnail path). + original_shape: Original tensor shape, used to derive [Z, H, W, C] for volumetric. + is_volumetric: True when the image has a depth (Z) dimension. is_full_resolution: True when sending the full modal view, False for grid thumbnails. - target_width: Width of the (possibly resized) output image. - target_height: Height of the (possibly resized) output image. + target_width: Width of the (possibly resized) output image. + target_height: Height of the (possibly resized) output image. Returns: raw_data_bytes: Encoded bytes ready for gRPC transfer. - raw_shape: [Z, H, W, C] or [H, W, C] shape of the encoded data. - encode_time_s: Seconds spent encoding (0.0 for the raw float32 path). + raw_shape: [Z, H, W, C] or [H, W, C] shape of the encoded data. + encode_time_s: Seconds spent encoding (0.0 for the raw float32 path). """ raw_data_bytes: bytes = b"" raw_shape: list = [] @@ -549,16 +549,16 @@ def encode_image_to_raw_bytes( if not np_img_f32.flags['C_CONTIGUOUS']: np_img_f32 = np.ascontiguousarray(np_img_f32) raw_data_bytes = np_img_f32.tobytes() - del np_img_f32 # release float32 copy immediately + del np_img_f32 # release float32 copy immediately # Normalise shape to [Z, H, W, C] from the original 4-D tensor. if len(original_shape) == 4: if original_shape[1] > original_shape[-1]: - raw_shape = list(original_shape) # already [Z, H, W, C] + raw_shape = list(original_shape) # already [Z, H, W, C] elif original_shape[1] < original_shape[-1]: - raw_shape = [original_shape[0], original_shape[2], original_shape[3], original_shape[1]] # [Z, C, H, W] -> [Z, H, W, C] + raw_shape = [original_shape[0], original_shape[2], original_shape[3], original_shape[1]] # [Z, C, H, W] -> [Z, H, W, C] else: - raw_shape = [original_shape[0], original_shape[1], original_shape[2], 1] # ambiguous: assume single channel + raw_shape = [original_shape[0], original_shape[1], original_shape[2], 1] # ambiguous: assume single channel logger.info( "[Volumetric] Sending full res: np_img.shape=%s, original_shape=%s, raw_shape=%s, bytes=%d", np_img.shape, original_shape, raw_shape, len(raw_data_bytes), @@ -567,7 +567,7 @@ def encode_image_to_raw_bytes( # Thumbnail (grid) or non-volumetric: compress with WebP, fall back to JPEG. # WebP is ~40-50 % smaller than JPEG at equivalent visual quality. _quality = 80 if is_full_resolution else 65 - _webp_method = 4 if is_full_resolution else 2 # 0 = fastest … 6 = smallest + _webp_method = 4 if is_full_resolution else 2 # 0 = fastest … 6 = smallest raw_buf = io.BytesIO() t0_enc = time.time() try: diff --git a/weightslab/ui_docker_bridge.py b/weightslab/ui_docker_bridge.py index 343b55c3..229fdb13 100644 --- a/weightslab/ui_docker_bridge.py +++ b/weightslab/ui_docker_bridge.py @@ -56,7 +56,7 @@ def _persist_certs_dir(certs_dir_str: str) -> None: """Persist WEIGHTSLAB_CERTS_DIR so future terminals and the training backend find it. - Windows — runs `setx` (permanent user env) and prints the PS one-liner for + Windows — runs `setx` (permanent user env) and prints the PS one-liner for the current session. Linux/macOS — appends an export line to ~/.bashrc (idempotent) and prints the source command for the current session. @@ -68,10 +68,10 @@ def _persist_certs_dir(certs_dir_str: str) -> None: stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) if result.returncode == 0: - logger.info("✓ WEIGHTSLAB_CERTS_DIR saved permanently via setx (new terminals will have it)") + logger.info(" WEIGHTSLAB_CERTS_DIR saved permanently via setx (new terminals will have it)") else: logger.warning(f"setx failed — set it manually: setx WEIGHTSLAB_CERTS_DIR \"{certs_dir_str}\"") - logger.info(f" Current terminal (PowerShell): $env:WEIGHTSLAB_CERTS_DIR = \"{certs_dir_str}\"") + logger.info(f" Current terminal (PowerShell): $env:WEIGHTSLAB_CERTS_DIR = \"{certs_dir_str}\"") else: bashrc = Path.home() / ".bashrc" try: @@ -79,13 +79,13 @@ def _persist_certs_dir(certs_dir_str: str) -> None: if export_line not in existing: with open(bashrc, "a", encoding="utf-8") as f: f.write(f"\n# Added by weightslab\n{export_line}\n") - logger.info(f"✓ WEIGHTSLAB_CERTS_DIR appended to {bashrc} (new terminals will have it)") + logger.info(f" WEIGHTSLAB_CERTS_DIR appended to {bashrc} (new terminals will have it)") else: - logger.info(f"✓ WEIGHTSLAB_CERTS_DIR already in {bashrc}") + logger.info(f" WEIGHTSLAB_CERTS_DIR already in {bashrc}") except OSError as e: logger.warning(f"Could not write to {bashrc}: {e}") - logger.info(f" Add manually: {export_line}") - logger.info(f" Current terminal: source ~/.bashrc (or open a new terminal)") + logger.info(f" Add manually: {export_line}") + logger.info(f" Current terminal: source ~/.bashrc (or open a new terminal)") def _strip_derived_deploy_env() -> None: @@ -113,40 +113,40 @@ def _banner() -> str: _EPILOG = """\ commands: - se Set up the secure environment: generate TLS + se Set up the secure environment: generate TLS certificates + a gRPC auth token in ~/.weightslab-certs. Then set WEIGHTSLAB_CERTS_DIR (the single source of truth) so the backend + new shells find them. - --force-certs regenerate even if certs exist + --force-certs regenerate even if certs exist - ui launch Purge stale weightslab/weights_studio Docker + ui launch Purge stale weightslab/weights_studio Docker resources, then build & start the UI stack. UNSECURED (HTTP) by default — no certs generated. - --certs generate (if missing) + use TLS + --certs generate (if missing) + use TLS certs + gRPC auth (HTTPS) - start example Run a bundled PyTorch example (foreground; stop with + start example Run a bundled PyTorch example (foreground; stop with Ctrl+C). Installs the example's requirements first, without prompting. Defaults to classification: - --cls classification example (default) - --seg segmentation example - --det detection example - --clus clustering example - --gen generation example - --3d_det 3D LiDAR point-cloud detection example - --2d_det 2D LiDAR point-cloud detection example + --cls classification example (default) + --seg segmentation example + --det detection example + --clus clustering example + --gen generation example + --3d_det 3D LiDAR point-cloud detection example + --2d_det 2D LiDAR point-cloud detection example examples: - weightslab se # one-time secure setup (then export WEIGHTSLAB_CERTS_DIR) - weightslab se --force-certs # regenerate the certs - weightslab ui launch # clean + launch (unsecured HTTP, default) - weightslab ui launch --certs # secured launch (HTTPS + gRPC auth) - weightslab start example # run the classification demo (default) - weightslab start example --seg # run the segmentation demo - weightslab start example --det # run the detection demo - weightslab start example --3d_det # run the 3D LiDAR detection demo - weightslab start example --2d_det # run the 2D LiDAR detection demo + weightslab se # one-time secure setup (then export WEIGHTSLAB_CERTS_DIR) + weightslab se --force-certs # regenerate the certs + weightslab ui launch # clean + launch (unsecured HTTP, default) + weightslab ui launch --certs # secured launch (HTTPS + gRPC auth) + weightslab start example # run the classification demo (default) + weightslab start example --seg # run the segmentation demo + weightslab start example --det # run the detection demo + weightslab start example --3d_det # run the 3D LiDAR detection demo + weightslab start example --2d_det # run the 2D LiDAR detection demo """ @@ -329,7 +329,7 @@ def _compose_cmd(compose_file, envoy_config, action): # locally), then `up` without the flag. v2 supports it inline, so leave it. if base == ["docker-compose"] and action and action[0] == "up" and "--pull" in action: i = action.index("--pull") - del action[i:i + 2] # drop '--pull' and its policy value (e.g. 'always') + del action[i:i + 2] # drop '--pull' and its policy value (e.g. 'always') logger.info("Docker Compose v1 detected — pulling images before 'up'...") pull_result = subprocess.run( base + ["-f", str(compose_file), "pull"], @@ -452,7 +452,7 @@ def _run_shell_script(script_path: str, args: list = None, env_vars: dict = None # Build bash command - pass Windows path directly, script will handle conversion # # Process path to ensure it's compatible with bash, especially on Windows if _is_windows() and '\\' in script_path: - script_path = script_path.replace("\\", "/") # Ensure path is Unix-style for bash + script_path = script_path.replace("\\", "/") # Ensure path is Unix-style for bash script_path = _convert_to_git_bash_path(script_path) logger.info(f"Converted script path for bash: {script_path}") logger.info(f"Running shell script: {script_path} with args: {args} and env_vars: {env_vars}") @@ -550,8 +550,8 @@ def _install_ca_trust(ca_file: Path) -> None: Idempotent and safe to call on every launch. Platform behavior: * Windows — adds to the CurrentUser\\Root store via the .NET X509Store API (silent, no prompt). - * macOS — adds to the login keychain (may show a one-time auth prompt). - * Linux — installs into the system trust store via sudo (one-time prompt) + * macOS — adds to the login keychain (may show a one-time auth prompt). + * Linux — installs into the system trust store via sudo (one-time prompt) and, best-effort, the user's NSS DB so Chrome/Firefox trust it too. A failure here is non-fatal: TLS still works, the browser just shows a @@ -584,7 +584,7 @@ def _install_ca_trust(ca_file: Path) -> None: stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, ) if result.returncode == 0: - logger.info("✓ Dev CA trusted in Windows CurrentUser\\Root store (restart browser to apply)") + logger.info(" Dev CA trusted in Windows CurrentUser\\Root store (restart browser to apply)") else: logger.warning(f"Could not auto-trust dev CA: {result.stderr.strip()}") return @@ -595,7 +595,7 @@ def _install_ca_trust(ca_file: Path) -> None: stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) if check.returncode == 0: - logger.info("✓ Dev CA already trusted (macOS keychain)") + logger.info(" Dev CA already trusted (macOS keychain)") return logger.info("Installing dev CA into macOS login keychain (may prompt)...") subprocess.run( @@ -615,7 +615,7 @@ def _install_ca_trust(ca_file: Path) -> None: subprocess.run(["sudo", "cp", str(ca_file), str(system_ca)]) subprocess.run(["sudo", "update-ca-certificates"]) else: - logger.info("✓ Dev CA already in Linux system trust store") + logger.info(" Dev CA already in Linux system trust store") # Browsers use their own NSS DB; add it there too if certutil is available. if shutil.which("certutil"): @@ -636,7 +636,7 @@ def _ensure_certificates(manager: CertAuthManager, force_certs: bool = False) -> truth). Returns True if certs are present afterwards, False otherwise. """ if manager.has_any_credentials() and not force_certs: - logger.info(f"✓ Using existing credentials in {manager.certs_dir}") + logger.info(f" Using existing credentials in {manager.certs_dir}") manager.get_or_create_auth_token() # Ensure the CA is trusted even when reusing certs from a prior run that # was generated via bash (which does not install OS trust). @@ -656,7 +656,7 @@ def _ensure_certificates(manager: CertAuthManager, force_certs: bool = False) -> manager.get_or_create_auth_token() _install_ca_trust(manager.ca_file) - logger.info(f"✓ Certificates ready in {manager.certs_dir}") + logger.info(f" Certificates ready in {manager.certs_dir}") return manager.has_valid_certs() @@ -724,10 +724,10 @@ def ui_launch(args): (file presence is the single source of truth) and are never deleted here. Flags (all optional, read defensively so legacy callers still work): - --certs generate (if missing) and use TLS certs + gRPC auth (HTTPS) - --force-certs with --certs, regenerate certificates even if they exist - --no-clean skip the stale Docker resource cleanup step - --dev use the dev compose overlay + --certs generate (if missing) and use TLS certs + gRPC auth (HTTPS) + --force-certs with --certs, regenerate certificates even if they exist + --no-clean skip the stale Docker resource cleanup step + --dev use the dev compose overlay """ _check_docker() # pip installs the bundled .sh scripts without the execute bit; make them @@ -869,10 +869,10 @@ def ui_launch(args): # The backend and any new shell must point at the same certs dir, or # they'll mismatch the UI's TLS/auth. Keep this the last thing printed. logger.warning("") - logger.warning("⚠ ACTION REQUIRED — TLS is ON. Set WEIGHTSLAB_CERTS_DIR so the " + logger.warning(" ACTION REQUIRED — TLS is ON. Set WEIGHTSLAB_CERTS_DIR so the " "training backend and new terminals use the same certificates:") - logger.warning(f" (bash) export WEIGHTSLAB_CERTS_DIR=\"{certs_dir_str}\"") - logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{certs_dir_str}\"") + logger.warning(f" (bash) export WEIGHTSLAB_CERTS_DIR=\"{certs_dir_str}\"") + logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{certs_dir_str}\"") else: logger.info("UI is running UNSECURED (HTTP, no gRPC auth). " "Re-run with `weightslab ui launch --certs` for TLS.") @@ -913,17 +913,17 @@ def ui_secure_environment(args): # Export ONLY the single source of truth for this process. os.environ["WEIGHTSLAB_CERTS_DIR"] = str(manager.certs_dir) - logger.info("✓ Certificates generated successfully") - logger.info("✓ gRPC auth token created") - logger.info(f"✓ Certs and token stored in: {manager.certs_dir}") - logger.info(f"✓ WEIGHTSLAB_CERTS_DIR exported for this process: {manager.certs_dir}") + logger.info(" Certificates generated successfully") + logger.info(" gRPC auth token created") + logger.info(f" Certs and token stored in: {manager.certs_dir}") + logger.info(f" WEIGHTSLAB_CERTS_DIR exported for this process: {manager.certs_dir}") logger.info("Then launch the secured UI with: weightslab ui launch --certs") # Keep this the FINAL output so the user can't miss the action they must take. logger.warning("") - logger.warning("⚠ ACTION REQUIRED — set WEIGHTSLAB_CERTS_DIR globally so new shells " + logger.warning(" ACTION REQUIRED — set WEIGHTSLAB_CERTS_DIR globally so new shells " "and the training backend find these certs (single source of truth):") - logger.warning(f" (bash) echo 'export WEIGHTSLAB_CERTS_DIR=\"{manager.certs_dir}\"' >> ~/.bashrc && source ~/.bashrc") - logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{manager.certs_dir}\"") + logger.warning(f" (bash) echo 'export WEIGHTSLAB_CERTS_DIR=\"{manager.certs_dir}\"' >> ~/.bashrc && source ~/.bashrc") + logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{manager.certs_dir}\"") # Bundled PyTorch examples, keyed by the CLI flag (e.g. --cls -> ws-classification). @@ -970,7 +970,7 @@ def _install_example_requirements(example_dir: Path) -> None: f"Failed to install requirements ({req}): {exc}. " "Continuing — the example may still run if deps are already installed." ) - return # only the first matching requirements file is used + return # only the first matching requirements file is used def example_start(args): @@ -994,7 +994,7 @@ def example_start(args): _install_example_requirements(example_dir) logger.info(f"Starting the WeightsLab {label} ({kind}) example...") - logger.info(f" {main_py}") + logger.info(f" {main_py}") logger.info("In another terminal, launch the UI with: weightslab ui launch") logger.info(f"Then open http://localhost:5173 — stop the example with Ctrl+C.") if not _CERTS_DIR_IN_ORIGINAL_ENV: diff --git a/weightslab/utils/computational_graph.py b/weightslab/utils/computational_graph.py index f6754658..34d1699f 100644 --- a/weightslab/utils/computational_graph.py +++ b/weightslab/utils/computational_graph.py @@ -144,11 +144,11 @@ def _generate_mappings( # Case 2: Many-to-one (src > dst) # A "batch" of source neurons maps to a single dstination neuron. # if len(src_channels) % len(dst_channels) != 0: - # raise ValueError( - # f"Source channels ({src_channels}) must be perfectly \ - # divisible by dstination channels ({dst_channels}) \ - # for many-to-one mapping." - # ) + # raise ValueError( + # f"Source channels ({src_channels}) must be perfectly \ + # divisible by dstination channels ({dst_channels}) \ + # for many-to-one mapping." + # ) # 1. Calculate the block size. # This determines how many linear layer neurons map to one convolution channel. @@ -184,7 +184,7 @@ def _generate_mappings( # We map the individual code back to the original index dst_to_src_mapping[code] = [index] - else: # src_channels < dst_channels + else: # src_channels < dst_channels # 1. Calculate the block size. # This determines how many linear layer neurons map to one convolution channel. # We use integer division to ensure a clean split. @@ -359,7 +359,7 @@ def _propagate_constraints_through_dependencies( propagated_constraints[module_id].get('outgoing', {})[cname] = cval # BFS to propagate OUTGOING constraints downstream - queue = [(module_id, native_constraints.copy(), {module_id})] # (current_id, outgoing_constraints, visited) + queue = [(module_id, native_constraints.copy(), {module_id})] # (current_id, outgoing_constraints, visited) while queue: current_id, current_constraints, visited_set = queue.pop(0) @@ -531,7 +531,7 @@ def _alias_from_tensor_name(tensor_name: str) -> Optional[str]: next_part = module_parts[i + 1] # If next part starts with current part + '.', it's redundant if next_part.startswith(part + '.'): - continue # Skip this redundant part + continue # Skip this redundant part deduplicated.append(part) return '.'.join(deduplicated) @@ -712,23 +712,23 @@ def generate_graph_dependencies_from_torchfx( # # SEED NEURONS: Use FX metadata to seed neurons if possible # for mod in make_safelist(current_module): - # if 'tensor_meta' in node.meta: - # meta = node.meta['tensor_meta'] - # if hasattr(meta, 'shape') and len(meta.shape) >= 2: - # out_ch = meta.shape[1] - # if out_ch is not None and out_ch > 0: - # mod.set_neurons('out_neurons', out_ch) - # if getattr(mod, 'wl_same_flag', False): - # mod.set_neurons('in_neurons', out_ch) - - # # Also check inputs to seed in_neurons - # for arg in node.args: - # if isinstance(arg, th.fx.Node) and 'tensor_meta' in arg.meta: - # meta_in = arg.meta['tensor_meta'] - # if hasattr(meta_in, 'shape') and len(meta_in.shape) >= 2: - # in_ch = meta_in.shape[1] - # if in_ch is not None and in_ch > 0: - # mod.set_neurons('in_neurons', in_ch) + # if 'tensor_meta' in node.meta: + # meta = node.meta['tensor_meta'] + # if hasattr(meta, 'shape') and len(meta.shape) >= 2: + # out_ch = meta.shape[1] + # if out_ch is not None and out_ch > 0: + # mod.set_neurons('out_neurons', out_ch) + # if getattr(mod, 'wl_same_flag', False): + # mod.set_neurons('in_neurons', out_ch) + + # # Also check inputs to seed in_neurons + # for arg in node.args: + # if isinstance(arg, th.fx.Node) and 'tensor_meta' in arg.meta: + # meta_in = arg.meta['tensor_meta'] + # if hasattr(meta_in, 'shape') and len(meta_in.shape) >= 2: + # in_ch = meta_in.shape[1] + # if in_ch is not None and in_ch > 0: + # mod.set_neurons('in_neurons', in_ch) # --- Handle General Merge Operations (Any call_function with multiple # module inputs) --- @@ -744,8 +744,8 @@ def generate_graph_dependencies_from_torchfx( # TODO (GP): cat of cat of cat, should be nested list also ? # TODO (GP): e.g., cat([conv1, conv2, cat([conv3, cat([conv4, # TODO (GP): conv5])])])]) - source_modules_ = [] # Collect modules to check for single input - source_nodes = [] # Collect nodes to check for single input + source_modules_ = [] # Collect modules to check for single input + source_nodes = [] # Collect nodes to check for single input for arg in node.args: if not isinstance(arg, list): arg = make_safelist(arg) @@ -795,7 +795,7 @@ def generate_graph_dependencies_from_torchfx( # dependent on the first module in the merge node_to_module[node] = distinct_source_modules else: - node_to_module[node] = None # Placeholder or constant input + node_to_module[node] = None # Placeholder or constant input # Clean dependencies (remove duplicates and self-loops) dependencies = _clean_dependencies(dependencies) @@ -835,9 +835,9 @@ def generate_layer_dependencies_from_onnx( Returns: [ - ('conv1', 'bn1', DepType.SAME), # Conv output = BN input/output - ('bn1', 'relu', DepType.SAME), # BN output = ReLU input/output - ('relu', 'conv2', DepType.INCOMING), # ReLU output = Conv2 input only + ('conv1', 'bn1', DepType.SAME), # Conv output = BN input/output + ('bn1', 'relu', DepType.SAME), # BN output = ReLU input/output + ('relu', 'conv2', DepType.INCOMING), # ReLU output = Conv2 input only ] Note: @@ -916,7 +916,7 @@ def get_channel_count(tensor_name: str) -> Optional[int]: if onnx_shapes_map: shape = onnx_shapes_map.get(tensor_name) if shape and len(shape) >= 2: - return shape[1] # NCHW format, C is dimension 1 + return shape[1] # NCHW format, C is dimension 1 # Fallback to node attributes producer = producer_for_tensor.get(tensor_name) @@ -925,19 +925,19 @@ def get_channel_count(tensor_name: str) -> Optional[int]: weight_name = producer.input[1] for init in graph.initializer: if init.name == weight_name and len(init.dims) >= 1: - return init.dims[0] # out_channels + return init.dims[0] # out_channels elif producer.op_type == 'Gemm' and len(producer.input) >= 2: weight_name = producer.input[1] for init in graph.initializer: if init.name == weight_name and len(init.dims) >= 1: - return init.dims[0] # out_features + return init.dims[0] # out_features elif producer.op_type == 'BatchNormalization' and len(producer.input) >= 2: weight_name = producer.input[1] for init in graph.initializer: if init.name == weight_name and len(init.dims) >= 1: - return init.dims[0] # num_features + return init.dims[0] # num_features return None @@ -1018,8 +1018,8 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: for node in graph.node: logger.debug(f"\nNode: {node.op_type} | name: {node.name}") - logger.debug(f" Inputs: {node.input[:3]}") # Show first 3 inputs - logger.debug(f" Outputs: {node.output}") + logger.debug(f" Inputs: {node.input[:3]}") # Show first 3 inputs + logger.debug(f" Outputs: {node.output}") # Find source modules from inputs src_modules = [] @@ -1028,7 +1028,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: for inp in node.input: # Skip constant/initializer inputs (weights, biases, etc.) if any(param in inp for param in ['.weight', '.bias', '.running_mean', '.running_var', '.num_batches_tracked']): - logger.debug(f" Skipping parameter input: {inp[:50]}") + logger.debug(f" Skipping parameter input: {inp[:50]}") continue src_mods = module_for_tensor(inp) @@ -1037,12 +1037,12 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: src_modules.append(src_mod) src_tensors.append(inp) src_name = module_to_name.get(src_mod, "") - logger.debug(f" Found source: {src_name} (from tensor: {inp[:50]})") + logger.debug(f" Found source: {src_name} (from tensor: {inp[:50]})") else: - logger.debug(f" Could not find source module for input: {inp[:50]}") + logger.debug(f" Could not find source module for input: {inp[:50]}") if not src_modules: - logger.debug(f" -> No source modules found, skipping") + logger.debug(f" -> No source modules found, skipping") continue # Handle merge operations (Add, Sub, Sum, Concat, Mul, Div) - create REC dependencies between branches @@ -1050,8 +1050,8 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: is_concat = "Concat" in node.op_type.capitalize() or "Cat" in node.op_type.capitalize() if is_merge: - logger.debug(f" Detected merge operation: {node.op_type} with {len(src_modules)} source modules") - logger.debug(f" Source modules: {[module_to_name.get(m, '?') for m in src_modules]}") + logger.debug(f" Detected merge operation: {node.op_type} with {len(src_modules)} source modules") + logger.debug(f" Source modules: {[module_to_name.get(m, '?') for m in src_modules]}") if is_merge and len(src_modules) >= 2: # Create REC dependencies between all pairs of source modules @@ -1073,15 +1073,15 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: channels_a = get_channel_count(tensor_a) channels_b = get_channel_count(tensor_b) - logger.debug(f" Checking REC: {name_a} (ch={channels_a}) <-> {name_b} (ch={channels_b})") + logger.debug(f" Checking REC: {name_a} (ch={channels_a}) <-> {name_b} (ch={channels_b})") # For Add/Sub/Mul/Div, channels must match # For Concat, channels can differ (concatenated along channel dim) create_rec = False if is_concat: - create_rec = True # Always create REC for concat + create_rec = True # Always create REC for concat elif channels_a is not None and channels_b is not None and channels_a == channels_b: - create_rec = True # Channels match for Add/Sub/etc + create_rec = True # Channels match for Add/Sub/etc elif channels_a is None or channels_b is None: # Can't verify channels, but merge operation requires compatibility create_rec = True @@ -1092,16 +1092,16 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: edge_key_ba = (name_b, name_a) if edge_key_ab not in seen_edges: - logger.debug(f" ✓ Adding REC dependency: {name_a} <-> {name_b}") + logger.debug(f" Adding REC dependency: {name_a} <-> {name_b}") dependencies.append((mod_a, mod_b, DepType.REC)) seen_edges.add(edge_key_ab) if edge_key_ba not in seen_edges: - logger.debug(f" ✓ Adding REC dependency: {name_b} <-> {name_a}") + logger.debug(f" Adding REC dependency: {name_b} <-> {name_a}") dependencies.append((mod_b, mod_a, DepType.REC)) seen_edges.add(edge_key_ba) else: - logger.debug(f" ✗ Skipping REC dependency (channel mismatch: {channels_a} vs {channels_b})") + logger.debug(f" Skipping REC dependency (channel mismatch: {channels_a} vs {channels_b})") # Find destination module from outputs # First try: direct alias from output tensor name @@ -1113,7 +1113,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: if alias and alias in name_to_module: dst_mod = name_to_module[alias] dst_tensor = out_name - logger.debug(f" Found dest (method 1 - alias): {alias}") + logger.debug(f" Found dest (method 1 - alias): {alias}") break # Second try: For ops that correspond to nn.Module, extract module from node itself @@ -1125,15 +1125,15 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: if len(parts) >= 2: # The module name is typically everything except the last part (op type) potential_names = [ - '.'.join(parts[:-1]), # e.g., 'model.bn1' from '/model/bn1/BatchNormalization' - parts[-2] if len(parts) >= 2 else parts[0], # e.g., 'bn1' from above + '.'.join(parts[:-1]), # e.g., 'model.bn1' from '/model/bn1/BatchNormalization' + parts[-2] if len(parts) >= 2 else parts[0], # e.g., 'bn1' from above ] for pname in potential_names: if pname in name_to_module: dst_mod = name_to_module[pname] if node.output: dst_tensor = node.output[0] - logger.debug(f" Found dest (method 2 - node name): {pname}") + logger.debug(f" Found dest (method 2 - node name): {pname}") break # Third try: For weight-based ops (Conv, Gemm, BatchNorm), check input parameters @@ -1148,7 +1148,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: dst_mod = name_to_module[potential_name] if node.output: dst_tensor = node.output[0] - logger.debug(f" Found dest (method 3 - params): {potential_name} (from {inp})") + logger.debug(f" Found dest (method 3 - params): {potential_name} (from {inp})") break if dst_mod: break @@ -1159,17 +1159,17 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: # if not hasattr(src_mod, 'bypass'): # src_mod.bypass = 0 bypassed.extend(make_safelist(list(node.output))) - logger.debug(f" Setting bypass=0 for module after Concat: {module_to_name.get(src_mod, '?')}") + logger.debug(f" Setting bypass=0 for module after Concat: {module_to_name.get(src_mod, '?')}") for k in make_safelist(list(node.input)): if k in bypassed: if dst_mod is not None: dst_mod.bypass = 0 - logger.debug(f" Setting bypass=0 for destination module: {module_to_name.get(dst_mod, '?')}") + logger.debug(f" Setting bypass=0 for destination module: {module_to_name.get(dst_mod, '?')}") break # If no destination module found, skip as it s the end if dst_mod is None: - logger.debug(f" -> No destination module found, skipping") + logger.debug(f" -> No destination module found, skipping") continue # Determine dependency type for each source -> destination connection @@ -1196,17 +1196,17 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: # # SEED NEURONS: Use ONNX metadata to seed neurons if possible # if src_channels is not None and src_channels > 0: - # src_mod.set_neurons('out_neurons', src_channels) - # if getattr(src_mod, 'wl_same_flag', False): - # src_mod.set_neurons('in_neurons', src_channels) + # src_mod.set_neurons('out_neurons', src_channels) + # if getattr(src_mod, 'wl_same_flag', False): + # src_mod.set_neurons('in_neurons', src_channels) # if dst_channels is not None and dst_channels > 0: - # dst_mod.set_neurons('in_neurons', dst_channels) - # if getattr(dst_mod, 'wl_same_flag', False): - # dst_mod.set_neurons('out_neurons', dst_channels) + # dst_mod.set_neurons('in_neurons', dst_channels) + # if getattr(dst_mod, 'wl_same_flag', False): + # dst_mod.set_neurons('out_neurons', dst_channels) logger.debug(f"Analyzing dependency {src_name} -> {dst_name}") - logger.debug(f" Source channels: {src_channels}, Destination channels: {dst_channels}") + logger.debug(f" Source channels: {src_channels}, Destination channels: {dst_channels}") # Use helper function to infer dependency type dep_type = _infer_dependency_type(dst_mod) @@ -1215,7 +1215,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: if dep_type == DepType.SAME: dst_mod.wl_same_flag = True - logger.debug(f" ✓ Adding dependency: {src_name} -> {dst_name} [{dep_type.name}]") + logger.debug(f" Adding dependency: {src_name} -> {dst_name} [{dep_type.name}]") dependencies.append((src_mod, dst_mod, dep_type)) seen_edges.add(edge_key) @@ -1247,7 +1247,7 @@ def generate_index_maps( for edge in dependencies: # Get src and dst modules and type src_mod, dst_mod, edge_label = edge[0], edge[1], edge[2] - recursive_dep = edge_label == DepType.REC # A recursive dependency ? + recursive_dep = edge_label == DepType.REC # A recursive dependency ? # 1.1. Determine the number of neurons in each direction # # Src - First will always be is not None and int @@ -1264,7 +1264,7 @@ def generate_index_maps( dst_mod.set_neurons( 'in_neurons' if not recursive_dep and not hasattr(dst_mod, 'wl_transposed') else 'out_neurons', dst_nb_neurons - ) # So next will have neurons + ) # So next will have neurons dst_mod_out_neurons = dst_mod.get_neurons( 'in_neurons' if not (not recursive_dep and not hasattr(dst_mod, 'wl_transposed')) else 'out_neurons' ) @@ -1279,7 +1279,7 @@ def generate_index_maps( src_mod.set_neurons( 'out_neurons' if not hasattr(src_mod, 'wl_transposed') else 'in_neurons', src_nb_neurons - ) # So next will have neurons + ) # So next will have neurons src_mod_out_neurons = src_mod.get_neurons( 'out_neurons' if not hasattr(src_mod, 'wl_transposed') else 'in_neurons' ) @@ -1360,7 +1360,7 @@ def extract_group_size(mod: nn.Module, incoming: bool) -> Optional[int]: dst_mod.get_name_wi_id(): deepcopy(src_to_dst_mapping_tnsr) } if not hasattr(dst_mod, 'bypass') else {} - ) # Child equivalent here + ) # Child equivalent here dst_mod.src_to_dst_mapping_tnsrs = normalize_dicts(dst_mod.src_to_dst_mapping_tnsrs) dst_mod.related_dst_to_src_mapping_tnsrs = normalize_dicts(dst_mod.related_dst_to_src_mapping_tnsrs) @@ -1392,7 +1392,7 @@ def extract_group_size(mod: nn.Module, incoming: bool) -> Optional[int]: """ # Enable debug logging logging.basicConfig( - level=logging.DEBUG, # Set to DEBUG to see detailed merge operation detection + level=logging.DEBUG, # Set to DEBUG to see detailed merge operation detection format='%(levelname)s - %(message)s' ) logger.setLevel(logging.DEBUG) @@ -1523,14 +1523,14 @@ def __init__(self, in_channels=3, num_classes=10): self.enc_relu1 = nn.ReLU() self.enc_residual1 = ResidualBlockWithUpsampling(64, 64) - self.enc_pool1 = nn.MaxPool2d(2, 2) # 32x32 -> 16x16 + self.enc_pool1 = nn.MaxPool2d(2, 2) # 32x32 -> 16x16 self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.enc_bn2 = nn.BatchNorm2d(128) self.enc_relu2 = nn.ReLU() self.enc_residual2 = ResidualBlockWithUpsampling(128, 128) - self.enc_pool2 = nn.MaxPool2d(2, 2) # 16x16 -> 8x8 + self.enc_pool2 = nn.MaxPool2d(2, 2) # 16x16 -> 8x8 # Bottleneck self.bottleneck_conv = nn.Conv2d(128, 256, kernel_size=3, padding=1) @@ -1539,8 +1539,8 @@ def __init__(self, in_channels=3, num_classes=10): self.bottleneck_residual = ResidualBlockWithUpsampling(256, 256) # Decoder: Upsampling path - self.dec_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 8x8 -> 16x16 - self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features + self.dec_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 8x8 -> 16x16 + self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features self.dec_bn1 = nn.BatchNorm2d(128) self.dec_residual1 = ResidualBlockWithUpsampling(128, 128) @@ -1569,7 +1569,7 @@ def forward(self, x): # Decoder with skip connections dec1 = self.dec_upsample1(bottleneck) - dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation + dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation dec1 = self.dec_conv1(dec1) dec1 = self.dec_bn1(dec1) dec1 = self.dec_residual1(dec1) @@ -1595,14 +1595,14 @@ def __init__(self, in_channels=3, num_classes=10): self.enc_relu1 = nn.ReLU() self.enc_residual1 = ResidualBlockWithUpsampling(64, 64) - self.enc_pool1 = nn.MaxPool2d(3, 3) # 27x27 -> 9x9 + self.enc_pool1 = nn.MaxPool2d(3, 3) # 27x27 -> 9x9 self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.enc_bn2 = nn.BatchNorm2d(128) self.enc_relu2 = nn.ReLU() self.enc_residual2 = ResidualBlockWithUpsampling(128, 128) - self.enc_pool2 = nn.MaxPool2d(3, 3) # 9x9 -> 3x3 + self.enc_pool2 = nn.MaxPool2d(3, 3) # 9x9 -> 3x3 # Bottleneck self.bottleneck_conv = nn.Conv2d(128, 256, kernel_size=3, padding=1) @@ -1611,13 +1611,13 @@ def __init__(self, in_channels=3, num_classes=10): self.bottleneck_residual = ResidualBlockWithUpsampling(256, 256) # Decoder: Mixed upsampling (3x and 2x) - self.dec_upsample1 = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False) # 3x3 -> 9x9 - self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features + self.dec_upsample1 = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False) # 3x3 -> 9x9 + self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features self.dec_bn1 = nn.BatchNorm2d(128) self.dec_residual1 = ResidualBlockWithUpsampling(128, 128) - self.dec_upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 9x9 -> 18x18 - self.dec_conv2 = nn.Conv2d(128 + 64, 64, kernel_size=3, padding=1) # Concatenate with encoder features + self.dec_upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 9x9 -> 18x18 + self.dec_conv2 = nn.Conv2d(128 + 64, 64, kernel_size=3, padding=1) # Concatenate with encoder features self.dec_bn2 = nn.BatchNorm2d(64) self.dec_residual2 = ResidualBlockWithUpsampling(64, 64) @@ -1646,13 +1646,13 @@ def forward(self, x): # Decoder with skip connections dec1 = self.dec_upsample1(bottleneck) - dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation + dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation dec1 = self.dec_conv1(dec1) dec1 = self.dec_bn1(dec1) dec1 = self.dec_residual1(dec1) dec2 = self.dec_upsample2(dec1) - dec2 = th.cat([dec2, enc1], dim=1) # Skip connection via concatenation + dec2 = th.cat([dec2, enc1], dim=1) # Skip connection via concatenation dec2 = self.dec_conv2(dec2) dec2 = self.dec_bn2(dec2) dec2 = self.dec_residual2(dec2) @@ -1694,23 +1694,23 @@ def forward(self, x): incoming_deps = [(s, d) for s, d, t in dependencies3 if t == DepType.INCOMING] rec_deps = [(s, d) for s, d, t in dependencies3 if t == DepType.REC] - print(f"\n SAME Dependencies ({len(same_deps)}):") - for src, dst in same_deps: # Show first 5 + print(f"\n SAME Dependencies ({len(same_deps)}):") + for src, dst in same_deps: # Show first 5 src_name = next((name for name, mod in model.named_modules() if mod is src), "") dst_name = next((name for name, mod in model.named_modules() if mod is dst), "") - print(f" [{src_name:30s}] --SAME-----> [{dst_name:30s}]") + print(f" [{src_name:30s}] --SAME-----> [{dst_name:30s}]") - print(f"\n INCOMING Dependencies ({len(incoming_deps)}):") - for src, dst in incoming_deps: # Show first 5 + print(f"\n INCOMING Dependencies ({len(incoming_deps)}):") + for src, dst in incoming_deps: # Show first 5 src_name = next((name for name, mod in model.named_modules() if mod is src), "") dst_name = next((name for name, mod in model.named_modules() if mod is dst), "") - print(f" [{src_name:30s}] --INCOMING-> [{dst_name:30s}]") + print(f" [{src_name:30s}] --INCOMING-> [{dst_name:30s}]") - print(f"\n REC Dependencies ({len(rec_deps)}):") - for src, dst in rec_deps: # Show first 5 + print(f"\n REC Dependencies ({len(rec_deps)}):") + for src, dst in rec_deps: # Show first 5 src_name = next((name for name, mod in model.named_modules() if mod is src), "") dst_name = next((name for name, mod in model.named_modules() if mod is dst), "") - print(f" [{src_name:30s}] <--REC----> [{dst_name:30s}]") + print(f" [{src_name:30s}] <--REC----> [{dst_name:30s}]") print(""" Model Architecture Notes: diff --git a/weightslab/utils/logs.py b/weightslab/utils/logs.py index c329710e..dd75872f 100644 --- a/weightslab/utils/logs.py +++ b/weightslab/utils/logs.py @@ -140,7 +140,7 @@ def setup_logging(level, log_to_file=True): _LOG_FILE_PATH = os.path.join(log_dir, f'weightslab_{timestamp}.log') _FILE_HANDLER = logging.FileHandler(_LOG_FILE_PATH, mode='w', encoding='utf-8') - _FILE_HANDLER.setLevel(logging.DEBUG) # Always log DEBUG+ to file + _FILE_HANDLER.setLevel(logging.DEBUG) # Always log DEBUG+ to file _FILE_HANDLER.setFormatter(formatter) root_logger.addHandler(_FILE_HANDLER) diff --git a/weightslab/utils/tools.py b/weightslab/utils/tools.py index 4e8fa025..8ae77b8f 100644 --- a/weightslab/utils/tools.py +++ b/weightslab/utils/tools.py @@ -29,7 +29,7 @@ def safe_reset_index(df: "pd.DataFrame") -> "pd.DataFrame": Plain ``df.reset_index()`` raises ``ValueError: cannot insert X, already exists`` when a MultiIndex level name (e.g. ``sample_id`` or ``annotation_id``) has already been materialised as a column — which - happens after ``_normalize_for_read`` in the H5 store. This helper only + happens after ``_normalize_for_read`` in the H5 store. This helper only promotes the levels that are actually missing from the column namespace. """ import pandas as _pd @@ -89,7 +89,7 @@ def normalize_config(obj: Any) -> Any: elif isinstance(obj, list): return [normalize_config(v) for v in obj] elif isinstance(obj, torch.device): - return str(obj) # e.g. "cuda" or "cuda:0" + return str(obj) # e.g. "cuda" or "cuda:0" elif isinstance(obj, pathlib.Path): return obj.as_posix() elif isinstance(obj, (bool, int, float, str)) or obj is None: @@ -176,7 +176,7 @@ def restore_rng_state(rng_state): # Restore Python random state if 'python_random' in rng_state: try: - random.setstate(tuple(tuple(i) if i is not None and not isinstance(i, (int, float)) else i for i in rng_state['python_random'])) # Conver to tuple of tuples + random.setstate(tuple(tuple(i) if i is not None and not isinstance(i, (int, float)) else i for i in rng_state['python_random'])) # Conver to tuple of tuples logger.debug("Restored Python random state") except Exception as e: logger.warning(f"Failed to restore Python random state: {e}") @@ -402,7 +402,7 @@ def model_op_neurons(model, layer_id=None, dummy_input=None, op=None, rand=False Test function to iteratively update neurons for each layer, then test inference. Everything match ? """ - seed_everything(42) if rand else None # Set seed for reproducibility + seed_everything(42) if rand else None # Set seed for reproducibility n_layers = len(model.layers) for n in range(n_layers-1, 0, -1): if rand and th.rand(1) > 0.5 and layer_id is None and dummy_input is None: @@ -412,7 +412,7 @@ def model_op_neurons(model, layer_id=None, dummy_input=None, op=None, rand=False if n != layer_id: continue else: - if n != n_layers + layer_id: # - -layer_id != + -layer_id + if n != n_layers + layer_id: # - -layer_id != + -layer_id continue logger.debug(f'\nOperate on neurons at layer {n}') if op is None: @@ -631,7 +631,7 @@ def array_id_2bytes( h = xxhash.xxh64() h.update(data) - digest8 = h.digest() # 8 bytes + digest8 = h.digest() # 8 bytes if return_hex: hexs = digest8.hex() @@ -648,10 +648,10 @@ def detach_to_cpu(obj: Any) -> Any: """Recursively detach tensors from the compute graph and move them to CPU. Handles: - - ``torch.Tensor`` → ``.detach().cpu()`` - - ``dict`` → recurse into values, preserve keys - - ``list`` / ``tuple`` → recurse element-wise, preserve type - - anything else → returned as-is + - ``torch.Tensor`` → ``.detach().cpu()`` + - ``dict`` → recurse into values, preserve keys + - ``list`` / ``tuple`` → recurse element-wise, preserve type + - anything else → returned as-is """ if isinstance(obj, th.Tensor): return obj.detach().cpu() @@ -681,7 +681,7 @@ def filter_kwargs_for_callable(func, kwargs): Examples: >>> def my_func(a, b, c=10): - ... return a + b + c + ... return a + b + c >>> all_kwargs = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5} >>> filtered = filter_kwargs_for_callable(my_func, all_kwargs) >>> filtered @@ -743,7 +743,7 @@ def safe_call_with_kwargs(func, *args, **kwargs): Examples: >>> def my_func(a, b, c=10): - ... return a + b + c + ... return a + b + c >>> safe_call_with_kwargs(my_func, 1, 2, c=3, d=4, e=5) 6 """ diff --git a/weightslab/watchdog/__init__.py b/weightslab/watchdog/__init__.py index 50c9e849..56e5e5a2 100644 --- a/weightslab/watchdog/__init__.py +++ b/weightslab/watchdog/__init__.py @@ -2,29 +2,29 @@ Public API ---------- -WATCHDOG : int — custom log level (35, between WARNING and ERROR) -MonitoredRLock : class — RLock with holder-thread tracking -raise_in_thread : func — deliver _WatchdogInterrupt to a thread by id -_WatchdogInterrupt : class — BaseException raised in stuck threads -RpcWatchdogState : class — tracks in-flight gRPC RPCs +WATCHDOG : int — custom log level (35, between WARNING and ERROR) +MonitoredRLock : class — RLock with holder-thread tracking +raise_in_thread : func — deliver _WatchdogInterrupt to a thread by id +_WatchdogInterrupt : class — BaseException raised in stuck threads +RpcWatchdogState : class — tracks in-flight gRPC RPCs RpcTimingAndWatchdogInterceptor : class — gRPC ServerInterceptor -GrpcServerManager : class — controls gRPC server lifecycle / restarts -WeighlabsWatchdog : class — unified watchdog (locks + gRPC) +GrpcServerManager : class — controls gRPC server lifecycle / restarts +WeighlabsWatchdog : class — unified watchdog (locks + gRPC) """ # Register WATCHDOG log level and logger.watchdog() method -from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 +from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 -from weightslab.watchdog.lock_monitor import ( # noqa: F401 +from weightslab.watchdog.lock_monitor import ( # noqa: F401 MonitoredRLock, _WatchdogInterrupt, raise_in_thread, ) -from weightslab.watchdog.grpc_watchdog import ( # noqa: F401 +from weightslab.watchdog.grpc_watchdog import ( # noqa: F401 RpcWatchdogState, RpcTimingAndWatchdogInterceptor, GrpcServerManager, ) -from weightslab.watchdog.watchdog import WeighlabsWatchdog # noqa: F401 +from weightslab.watchdog.watchdog import WeighlabsWatchdog # noqa: F401 diff --git a/weightslab/watchdog/grpc_watchdog.py b/weightslab/watchdog/grpc_watchdog.py index 94713d6c..979b90ba 100644 --- a/weightslab/watchdog/grpc_watchdog.py +++ b/weightslab/watchdog/grpc_watchdog.py @@ -12,7 +12,7 @@ from threading import Lock, Event -from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level +from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level logger = logging.getLogger(__name__) @@ -149,7 +149,7 @@ def set_server(self, server) -> None: def stop(self, grace: float = 5.0) -> None: with self._lock: if self._server: - logger.watchdog("[gRPC] Requesting graceful shutdown with %.1fs grace", grace) # type: ignore[attr-defined] + logger.watchdog("[gRPC] Requesting graceful shutdown with %.1fs grace", grace) # type: ignore[attr-defined] self._server.stop(grace=grace) self._server = None diff --git a/weightslab/watchdog/lock_monitor.py b/weightslab/watchdog/lock_monitor.py index 299ecc9b..3b00693a 100644 --- a/weightslab/watchdog/lock_monitor.py +++ b/weightslab/watchdog/lock_monitor.py @@ -1,11 +1,11 @@ """Lock monitoring for weightslab watchdog. Provides: - - MonitoredRLock : drop-in RLock replacement that tracks the holder thread + - MonitoredRLock : drop-in RLock replacement that tracks the holder thread and how long it has been held, so the watchdog can detect and recover from stuck locks. - _WatchdogInterrupt : BaseException raised asynchronously in stuck threads. - - raise_in_thread : deliver _WatchdogInterrupt to any thread by id. + - raise_in_thread : deliver _WatchdogInterrupt to any thread by id. When the watchdog raises _WatchdogInterrupt in a thread that holds a MonitoredRLock via ``with`` or a ``try/finally: release()``, Python's @@ -36,7 +36,7 @@ def raise_in_thread(tid: int, exc_type: type = _WatchdogInterrupt) -> bool: """Raise *exc_type* asynchronously in the thread identified by *tid*. Uses ``ctypes.pythonapi.PyThreadState_SetAsyncExc`` which delivers the - exception at the next Python bytecode boundary. Any active ``finally:`` + exception at the next Python bytecode boundary. Any active ``finally:`` or ``with`` block in the target thread will execute before the exception propagates, so held locks are released cleanly. @@ -48,7 +48,7 @@ def raise_in_thread(tid: int, exc_type: type = _WatchdogInterrupt) -> bool: ctypes.py_object(exc_type), ) if res == 0: - return False # thread not found + return False # thread not found if res > 1: # More than one state was modified — undo to be safe (shouldn't happen) ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None) @@ -72,17 +72,17 @@ class MonitoredRLock: whether to kill the holder via ``raise_in_thread``. Re-entrancy is fully supported: the same thread can acquire multiple - times. ``_acquired_at`` records the time of the *first* acquisition and + times. ``_acquired_at`` records the time of the *first* acquisition and is cleared only when the lock becomes fully free (count reaches 0). """ def __init__(self) -> None: self._lock = threading.RLock() - self._meta = threading.Lock() # guards the three fields below + self._meta = threading.Lock() # guards the three fields below self._holder_tid: Optional[int] = None self._acquired_at: Optional[float] = None self._count: int = 0 - self._timeout: Optional[float] = None # Optional per-lock timeout for watchdog (None to use global default) + self._timeout: Optional[float] = None # Optional per-lock timeout for watchdog (None to use global default) # ------------------------------------------------------------------ # Core acquire / release diff --git a/weightslab/watchdog/log_level.py b/weightslab/watchdog/log_level.py index 540e252f..48f4862f 100644 --- a/weightslab/watchdog/log_level.py +++ b/weightslab/watchdog/log_level.py @@ -22,4 +22,4 @@ def _watchdog(self: logging.Logger, message: str, *args, **kwargs) -> None: # Patch Logger class once so every logger instance gets the method -logging.Logger.watchdog = _watchdog # type: ignore[attr-defined] +logging.Logger.watchdog = _watchdog # type: ignore[attr-defined] diff --git a/weightslab/watchdog/watchdog.py b/weightslab/watchdog/watchdog.py index 136785af..8832471c 100644 --- a/weightslab/watchdog/watchdog.py +++ b/weightslab/watchdog/watchdog.py @@ -1,14 +1,14 @@ """WeighlabsWatchdog — unified watchdog for locks and gRPC threads. Combines: - 1. Lock monitoring — polls MonitoredRLock instances, raises _WatchdogInterrupt + 1. Lock monitoring — polls MonitoredRLock instances, raises _WatchdogInterrupt in the holder thread when the lock is held too long. - 2. gRPC monitoring — detects stuck in-flight RPCs via RpcWatchdogState and + 2. gRPC monitoring — detects stuck in-flight RPCs via RpcWatchdogState and requests a server restart when the threshold is exceeded. 3. Eval thread monitoring — checks that the evaluation worker thread is still alive whenever eval_controller reports is_running() or - is_pending(). If the thread is dead the controller is - transitioned to error state automatically. No timeout is + is_pending(). If the thread is dead the controller is + transitioned to error state automatically. No timeout is applied — evaluation may run for an arbitrarily long time. Typical usage (inside grpc_serve): @@ -25,7 +25,7 @@ get_thread=lambda: _EVAL_WORKER_THREAD, ) watchdog.start() - # watchdog.rpc_state → pass to RpcTimingAndWatchdogInterceptor + # watchdog.rpc_state → pass to RpcTimingAndWatchdogInterceptor # watchdog.server_manager → used by serving_thread_callback """ @@ -34,7 +34,7 @@ import threading from typing import Callable, Dict, List, Optional, Tuple -from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level +from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level from weightslab.watchdog.lock_monitor import MonitoredRLock, raise_in_thread from weightslab.watchdog.grpc_watchdog import RpcWatchdogState, GrpcServerManager @@ -91,12 +91,12 @@ def register_eval_monitor( The watchdog will call ``mark_error()`` on the controller when it reports ``is_running()`` or ``is_pending()`` but the worker thread is no longer - alive. **No timeout is applied** — evaluation is allowed to run for as + alive. **No timeout is applied** — evaluation is allowed to run for as long as needed. Args: get_controller: Zero-arg callable that returns the EvaluationController. - get_thread: Zero-arg callable that returns the current worker + get_thread: Zero-arg callable that returns the current worker ``threading.Thread`` (or ``None`` if not started yet). """ self._eval_monitors.append((get_controller, get_thread)) @@ -114,7 +114,7 @@ def start(self) -> None: daemon=True, ) self._thread.start() - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Started (threshold=%.1fs poll=%.1fs restart_after=%d exit_on_stuck=%s locks=%s)", self._stuck_threshold_s, self._poll_interval_s, @@ -152,7 +152,7 @@ def _loop(self) -> None: # Lock monitoring # ------------------------------------------------------------------ - def _check_locks(self) -> None: # noqa: C901 + def _check_locks(self) -> None: # noqa: C901 # Import here to avoid circular imports try: from weightslab.components.global_monitoring import is_in_evaluation @@ -185,19 +185,19 @@ def _check_locks(self) -> None: # noqa: C901 continue if duration >= effective_threshold: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Lock '%s' held for %.1fs by tid=%s — sending interrupt", name, duration, tid, ) if tid is not None: killed = raise_in_thread(tid) if killed: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Interrupt delivered to tid=%s (lock '%s' will be released by finally/with)", tid, name, ) else: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Could not deliver interrupt to tid=%s — thread may have already exited", tid, ) @@ -217,14 +217,14 @@ def _check_eval_threads(self) -> None: continue if not (controller.is_running() or controller.is_pending()): - continue # nothing active — nothing to check + continue # nothing active — nothing to check if thread is not None and thread.is_alive(): - continue # worker is alive — all good + continue # worker is alive — all good # Controller believes eval is active but the thread is dead or missing. status = controller.get_status() if hasattr(controller, "get_status") else "unknown" - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Eval controller is '%s' but worker thread is dead — marking error", status, ) @@ -253,7 +253,7 @@ def _check_grpc(self) -> None: self._unhealthy_count += 1 self.rpc_state.record_unhealthy() - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] gRPC unhealthy #%d: in_flight=%d oldest=%.1fs method=%s threshold=%.1fs | %s", self._unhealthy_count, snap["in_flight"], @@ -264,13 +264,13 @@ def _check_grpc(self) -> None: ) if self._exit_on_stuck: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] GRPC_WATCHDOG_EXIT_ON_STUCK=1 — calling os._exit(1)" ) os._exit(1) if self._unhealthy_count >= self._restart_threshold: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Restart threshold reached (%d/%d) — requesting server restart", self._unhealthy_count, self._restart_threshold, ) @@ -278,7 +278,7 @@ def _check_grpc(self) -> None: else: if self._unhealthy_count > 0: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] gRPC recovered after %d unhealthy checks", self._unhealthy_count ) self._unhealthy_count = 0 From 43daa75454d4a9d7060ad77149cfb20f9af6659f Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 19:13:37 +0200 Subject: [PATCH 13/16] Add weightslab logdir CLI command for offline explore mode weightslab logdir [--no-ui] [--certs] [--grpc-port PORT] Wires the existing load_experiment_for_explore / serve / keep_serving infrastructure into the CLI so a downloaded log dir can be explored without running the original training script. By default also launches the Weights Studio Docker UI stack; pass --no-ui to skip (useful when the UI is already running from a prior 'weightslab ui launch'). Co-Authored-By: Claude Sonnet 4.6 --- weightslab/ui_docker_bridge.py | 111 ++++++++++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 1 deletion(-) diff --git a/weightslab/ui_docker_bridge.py b/weightslab/ui_docker_bridge.py index 229fdb13..028a1d2c 100644 --- a/weightslab/ui_docker_bridge.py +++ b/weightslab/ui_docker_bridge.py @@ -1017,6 +1017,71 @@ def example_start(args): sys.exit(result.returncode) +def logdir_explore(args): + """`weightslab logdir [--no-ui]`: offline explore mode for a downloaded log dir. + + Loads the experiment from disk into a read-only ledger (no training script, + GPU, or original dataset required), starts the gRPC backend server, and — + unless ``--no-ui`` is given — also brings up the Weights Studio Docker UI + stack so the experiment can be browsed immediately. + + Intended workflow:: + + # On your dev machine, after rsync-ing the cluster run: + weightslab logdir ./root_log_dir + + Once running, open http://localhost:5173 (or the URL printed on startup). + Press Ctrl+C to stop the gRPC server (Docker UI keeps running in the + background; stop it separately with ``weightslab ui launch`` or + ``docker compose down`` if needed). + """ + root_log_dir = args.root_log_dir + + if not getattr(args, "no_ui", False): + # Launch the Weights Studio UI stack (Docker + Envoy) the same way + # `weightslab ui launch` does. Build an args namespace that is + # compatible with ui_launch, inheriting relevant flags. + ui_args = argparse.Namespace( + certs=getattr(args, "certs", False), + force_certs=False, + no_clean=False, + no_auth=False, + dev=False, + certs_dir=getattr(args, "certs_dir", None), + ) + ui_launch(ui_args) + + logger.info("Loading experiment from disk: %s", root_log_dir) + # Lazy import: pulls in torch and the full weightslab stack only when this + # command is actually invoked, keeping other commands fast. + try: + from weightslab.src import load_experiment_for_explore, serve, keep_serving + except ImportError as exc: + logger.error("Failed to import weightslab core: %s", exc) + sys.exit(1) + + try: + summary = load_experiment_for_explore( + root_log_dir, + exp_hash=getattr(args, "exp_hash", None), + ) + except FileNotFoundError as exc: + logger.error(str(exc)) + sys.exit(1) + + logger.info("Experiment loaded: hash=%s, origins=%s", + summary.get("experiment_hash"), summary.get("origins")) + + grpc_port = getattr(args, "grpc_port", None) or int(os.getenv("GRPC_BACKEND_PORT", 50051)) + os.environ["GRPC_BACKEND_PORT"] = str(grpc_port) + logger.info("Starting WeightsLab gRPC server on port %d (read-only explore mode)...", grpc_port) + serve(serving_grpc=True) + + logger.info("WeightsLab is running in read-only explore mode.") + logger.info("Open the UI, then press Ctrl+C to stop.") + keep_serving(release_gpu=False) + + def _add_example_kind_flags(p: argparse.ArgumentParser) -> None: """Attach the mutually-exclusive example-kind flags (default: classification).""" group = p.add_mutually_exclusive_group() @@ -1045,6 +1110,7 @@ def _build_parser() -> argparse.ArgumentParser: weightslab se [--force-certs] weightslab ui launch [--certs] weightslab start example [--cls|--seg|--det|--clus|--gen|--3d_det|--2d_det] + weightslab logdir [--no-ui] [--certs] [--grpc-port PORT] """ parser = argparse.ArgumentParser( prog="weightslab", @@ -1054,7 +1120,7 @@ def _build_parser() -> argparse.ArgumentParser: ) # metavar lists only the documented commands; the `example` alias is accepted # but intentionally omitted here (and help=SUPPRESS'd below) so it stays hidden. - sub = parser.add_subparsers(dest="command", metavar="{se,ui,start,help}") + sub = parser.add_subparsers(dest="command", metavar="{se,ui,start,logdir,help}") # weightslab se [--force-certs] [certs_dir] se_parser = sub.add_parser("se", help="Set up the secure environment (TLS certs + gRPC auth token)") @@ -1090,6 +1156,47 @@ def _build_parser() -> argparse.ArgumentParser: sub.add_parser("help", help="Show this help message") + # weightslab logdir [--no-ui] [--certs] [--grpc-port PORT] + logdir_parser = sub.add_parser( + "logdir", + help="Open a finished experiment from disk in read-only explore mode, " + "then serve it through Weights Studio", + ) + logdir_parser.add_argument( + "root_log_dir", + help="Path to the root_log_dir produced by a previous training run", + ) + logdir_parser.add_argument( + "--no-ui", + action="store_true", + help="Skip launching the Weights Studio Docker UI stack " + "(useful when the UI is already running)", + ) + logdir_parser.add_argument( + "--certs", + action="store_true", + help="Generate TLS certs + gRPC auth token if missing, then launch UI secured", + ) + logdir_parser.add_argument( + "--grpc-port", + type=int, + default=None, + metavar="PORT", + help=f"gRPC backend port (default: $GRPC_BACKEND_PORT or 50051)", + ) + logdir_parser.add_argument( + "--exp-hash", + default=None, + metavar="HASH", + help="Specific experiment hash to open (default: latest)", + ) + logdir_parser.add_argument( + "certs_dir", + nargs="?", + default=None, + help="Custom directory for certs/token (default: $WEIGHTSLAB_CERTS_DIR or ~/.weightslab-certs)", + ) + return parser, ui_parser, start_parser @@ -1115,6 +1222,8 @@ def main(): # Alias for `start example` — tolerate the swapped subcommand order # (`weightslab example start [flags]`) and the bare `weightslab example`. example_start(args) + elif args.command == "logdir": + logdir_explore(args) else: parser.print_help() From 30518b356994c8fca7a74c71475617722a0205ae Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 19:15:30 +0200 Subject: [PATCH 14/16] Add wl.write_history/write_dataframe to PyTorch examples (ratio=100) Both classification and segmentation examples now: - Export signal history + data grid to root_log_dir every 100 steps (configurable via write_export_ratio in config.yaml) - Also export once at end of training for a final snapshot This lets an offline explore session (weightslab logdir) immediately pick up the latest signal/dataframe state without waiting for a checkpoint. Co-Authored-By: Claude Sonnet 4.6 --- .../examples/PyTorch/ws-classification/config.yaml | 1 + weightslab/examples/PyTorch/ws-classification/main.py | 10 ++++++++++ weightslab/examples/PyTorch/ws-segmentation/main.py | 10 ++++++++++ 3 files changed, 21 insertions(+) diff --git a/weightslab/examples/PyTorch/ws-classification/config.yaml b/weightslab/examples/PyTorch/ws-classification/config.yaml index 7efe94e0..3a93f65e 100644 --- a/weightslab/examples/PyTorch/ws-classification/config.yaml +++ b/weightslab/examples/PyTorch/ws-classification/config.yaml @@ -10,6 +10,7 @@ compute_natural_sort: false # Experiment parameters eval_full_to_train_steps_ratio: 500 # was 100 — full 10k eval was the dominant wall-clock cost experiment_dump_to_train_steps_ratio: 250 # was 25 — frequent checkpoint dumps stalled training +write_export_ratio: 100 # Export signal history + data grid to JSON/CSV every N steps skip_checkpoint_load: false # If true restart the experiment from last state tqdm_display: true # Whether to use tqdm progress bars during training/evaluation is_training: false # Start training immediately or not diff --git a/weightslab/examples/PyTorch/ws-classification/main.py b/weightslab/examples/PyTorch/ws-classification/main.py index 056b9736..8aea5115 100644 --- a/weightslab/examples/PyTorch/ws-classification/main.py +++ b/weightslab/examples/PyTorch/ws-classification/main.py @@ -261,6 +261,7 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): log_dir = parameters["root_log_dir"] tqdm_display = parameters.get("tqdm_display", True) eval_full_to_train_steps_ratio = parameters.get("eval_full_to_train_steps_ratio", 50) + write_export_ratio = parameters.get("write_export_ratio", 100) enable_h5_persistence = parameters.get("enable_h5_persistence", True) training_steps_to_do = parameters.get("training_steps_to_do", 1000) @@ -406,6 +407,11 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): test_loader_len ) + # Periodic history + dataframe export (JSON/CSV snapshots to root_log_dir) + if age > 0 and age % write_export_ratio == 0: + wl.write_history() + wl.write_dataframe() + # Verbose if verbose and not tqdm_display: import sys @@ -432,5 +438,9 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): print(f" Logs saved to: {log_dir}") print("=" * 60) + # Final export of signal history and data grid to root_log_dir + wl.write_history() + wl.write_dataframe() + # Keep the main thread alive to allow background serving threads to run wl.keep_serving() diff --git a/weightslab/examples/PyTorch/ws-segmentation/main.py b/weightslab/examples/PyTorch/ws-segmentation/main.py index 3f3f7059..21eb0e0c 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/main.py +++ b/weightslab/examples/PyTorch/ws-segmentation/main.py @@ -194,6 +194,7 @@ def test(loader, model, sig, device, test_loader_len): log_dir = parameters["root_log_dir"] max_steps = parameters["training_steps_to_do"] eval_full_to_train_steps_ratio = parameters["eval_full_to_train_steps_ratio"] + write_export_ratio = parameters.get("write_export_ratio", 100) verbose = parameters.get("verbose", True) tqdm_display = parameters.get("tqdm_display", True) @@ -356,6 +357,11 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating") if tqdm_display else test_loader test_loss, test_metric = test(test_loader_it, model, test_sig, device, test_loader_len) + # Periodic history + dataframe export (JSON/CSV snapshots to root_log_dir) + if age > 0 and age % write_export_ratio == 0: + wl.write_history() + wl.write_dataframe() + # Verbose if verbose and not tqdm_display: print( @@ -378,5 +384,9 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 print(f" Logs saved to: {log_dir}") print("=" * 60) + # Final export of signal history and data grid to root_log_dir + wl.write_history() + wl.write_dataframe() + # Keep the main thread alive to allow background serving threads to run wl.keep_serving() From b562fff384f740b08a4d031dd0f5add64f8b96a1 Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 19:30:13 +0200 Subject: [PATCH 15/16] Add ws-multitask PyTorch example + fix GuardContext unresolved Proxy crash Adds a multi-task MNIST example (ws-multitask) demonstrating: - Shared CNN backbone with classification and localization heads - Two separately-tracked WeightsLab losses (cross-entropy + smooth-L1) - Detection-format targets (tight bounding box per digit) for UI bbox overlay - Per-sample accuracy signal for data grid inspection Fixes GuardContext.__enter__ crashing with "Proxy target not set" when no model is registered: get_model() returns a Proxy(None) placeholder; now checks the Proxy _obj target before assigning self.model so unregistered contexts (e.g., Lightning-only tests) work gracefully. Co-Authored-By: Claude Sonnet 4.6 --- weightslab/components/global_monitoring.py | 12 +- .../examples/PyTorch/ws-multitask/config.yaml | 38 ++ .../examples/PyTorch/ws-multitask/main.py | 338 ++++++++++++++++++ .../PyTorch/ws-multitask/utils/__init__.py | 0 .../PyTorch/ws-multitask/utils/criterions.py | 34 ++ .../PyTorch/ws-multitask/utils/data.py | 63 ++++ .../PyTorch/ws-multitask/utils/model.py | 38 ++ 7 files changed, 521 insertions(+), 2 deletions(-) create mode 100644 weightslab/examples/PyTorch/ws-multitask/config.yaml create mode 100644 weightslab/examples/PyTorch/ws-multitask/main.py create mode 100644 weightslab/examples/PyTorch/ws-multitask/utils/__init__.py create mode 100644 weightslab/examples/PyTorch/ws-multitask/utils/criterions.py create mode 100644 weightslab/examples/PyTorch/ws-multitask/utils/data.py create mode 100644 weightslab/examples/PyTorch/ws-multitask/utils/model.py diff --git a/weightslab/components/global_monitoring.py b/weightslab/components/global_monitoring.py index 20099775..423004a8 100644 --- a/weightslab/components/global_monitoring.py +++ b/weightslab/components/global_monitoring.py @@ -229,9 +229,17 @@ def __enter__(self, f: bool = False): context = Context.TRAINING if self.for_training else Context.TESTING self._context_token = set_current_context(context) - # Update model + # Update model — get_model() returns a Proxy(None) placeholder when nothing + # is registered; only use it if the proxy target is actually resolved. _model = get_model() - self.model = _model if _model != None else self.model + if _model is not None: + try: + _target = object.__getattribute__(_model, '_obj') + if _target is not None: + self.model = _model + except AttributeError: + # _model is a plain (non-Proxy) object; use it directly + self.model = _model # The exact logic requested by the user: if self.model is not None: diff --git a/weightslab/examples/PyTorch/ws-multitask/config.yaml b/weightslab/examples/PyTorch/ws-multitask/config.yaml new file mode 100644 index 00000000..8511a026 --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/config.yaml @@ -0,0 +1,38 @@ +# Multi-task MNIST: digit classification + bounding-box localization +experiment_name: mnist_multitask +device: auto +training_steps_to_do: null # null = run until stopped +# root_log_dir: ./logs/mnist_multitask + +# Task loss weights — increase loc_loss_weight if localization is underfitting +cls_loss_weight: 1.0 +loc_loss_weight: 5.0 + +num_classes: 10 +eval_full_to_train_steps_ratio: 500 +write_export_ratio: 100 + +tqdm_display: true +compute_natural_sort: false +skip_checkpoint_load: false + +# H5 / dataframe persistence +ledger_enable_flushing_threads: true +ledger_enable_h5_persistence: true +ledger_flush_max_rows: 15000 +ledger_flush_interval: 30.0 + +# Services +serving_grpc: true +serving_cli: false + +optimizer: + lr: 0.001 + +data: + train_loader: + batch_size: 32 + shuffle: true + test_loader: + batch_size: 64 + shuffle: false diff --git a/weightslab/examples/PyTorch/ws-multitask/main.py b/weightslab/examples/PyTorch/ws-multitask/main.py new file mode 100644 index 00000000..95bfad40 --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/main.py @@ -0,0 +1,338 @@ +""" +Multi-task learning with WeightsLab — MNIST digit classification + localization. + +This example demonstrates how to track a multi-head model with WeightsLab: + - A shared CNN backbone feeds two heads. + - Classification head: cross-entropy over 10 digit classes. + - Localization head: Smooth-L1 regression of the digit's tight bounding box. + +Both losses are tracked separately in WeightsLab so you can: + - Compare classification vs. localization learning curves in the plots board. + - Inspect per-sample loss breakdown (hardest-to-classify vs. hardest-to-locate). + - See predicted bounding boxes overlaid on each MNIST sample in the data grid. + +WeightsLab task_type="detection" enables bbox visualization in the UI grid. +""" + +import itertools +import os +import ssl +import time +import logging +import tempfile + +try: + ssl.create_default_context() +except ssl.SSLError: + ssl._create_default_https_context = ssl._create_unverified_context + +import yaml +import tqdm +import torch +import torch.optim as optim +from torchvision import transforms + +import weightslab as wl +from weightslab.components.global_monitoring import ( + guard_training_context, + guard_testing_context, +) + +from utils.data import MNISTMultiTaskDataset, multitask_collate +from utils.model import MNISTMultiTaskModel +from utils.criterions import PerSampleClassificationLoss, PerSampleLocalizationLoss + +logging.basicConfig(level=logging.ERROR) + + +# ============================================================================= +# Helpers +# ============================================================================= + +def _build_preds(cls_logits, bbox_pred): + """ + Build detection-format predictions for WeightsLab UI overlay. + + Returns a list of [1, 6] tensors — one per sample — with columns: + [x1, y1, x2, y2, predicted_class, confidence] + """ + classes = cls_logits.argmax(dim=1).float() + confs = cls_logits.softmax(dim=1).max(dim=1).values + return [ + torch.stack([ + bbox_pred[i, 0], bbox_pred[i, 1], + bbox_pred[i, 2], bbox_pred[i, 3], + classes[i], confs[i], + ]).unsqueeze(0) + for i in range(len(classes)) + ] + + +# ============================================================================= +# Train / Test loops +# ============================================================================= + +def train(loader, model, optimizer, sig, device, cls_weight, loc_weight): + """Single multi-task training step.""" + with guard_training_context: + inputs, ids, targets, _ = next(loader) + inputs = inputs.to(device) + targets = [t.to(device) for t in targets] + + optimizer.zero_grad() + cls_logits, bbox_pred = model(inputs) + + preds = _build_preds(cls_logits.detach(), bbox_pred.detach()) + + cls_loss_per_sample = sig["cls_loss"](cls_logits, targets, batch_ids=ids, preds=preds) + loc_loss_per_sample = sig["loc_loss"](bbox_pred, targets, batch_ids=ids, preds=preds) + + loss = (cls_weight * cls_loss_per_sample + loc_weight * loc_loss_per_sample).mean() + loss.backward() + optimizer.step() + + # Per-sample classification accuracy for inspection in the data grid. + labels = torch.stack([t[0, 4].long() for t in targets]).to(device) + preds_cls = cls_logits.argmax(dim=1) + wl.save_signals( + {"cls_correct_per_sample": (preds_cls == labels).float()}, + ids, + ) + + return float(loss.detach().cpu()) + + +def test(loader, model, sig, device, cls_weight, loc_weight, loader_len): + """Full evaluation pass.""" + total_cls_loss = 0.0 + total_loc_loss = 0.0 + correct = 0 + total = 0 + + with guard_testing_context, torch.no_grad(): + for inputs, ids, targets, _ in loader: + inputs = inputs.to(device) + targets = [t.to(device) for t in targets] + + cls_logits, bbox_pred = model(inputs) + preds = _build_preds(cls_logits, bbox_pred) + + cls_loss_per_sample = sig["cls_loss"](cls_logits, targets, batch_ids=ids, preds=preds) + loc_loss_per_sample = sig["loc_loss"](bbox_pred, targets, batch_ids=ids, preds=preds) + + total_cls_loss += cls_loss_per_sample.mean().item() + total_loc_loss += loc_loss_per_sample.mean().item() + + labels = torch.stack([t[0, 4].long() for t in targets]).to(device) + preds_cls = cls_logits.argmax(dim=1) + correct += (preds_cls == labels).sum().item() + total += len(labels) + + wl.save_signals( + {"cls_correct_per_sample": (preds_cls == labels).float()}, + ids, + ) + + cls_loss = total_cls_loss / loader_len + loc_loss = total_loc_loss / loader_len + accuracy = 100.0 * correct / total if total > 0 else 0.0 + return cls_loss, loc_loss, accuracy + + +# ============================================================================= +# Main +# ============================================================================= +if __name__ == "__main__": + start_time = time.time() + + config_path = os.path.join(os.path.dirname(__file__), "config.yaml") + if os.path.exists(config_path): + with open(config_path, "r") as fh: + parameters = yaml.safe_load(fh) or {} + else: + parameters = {} + + parameters.setdefault("experiment_name", "mnist_multitask") + parameters.setdefault("device", "auto") + parameters.setdefault("training_steps_to_do", None) + parameters.setdefault("eval_full_to_train_steps_ratio", 500) + parameters.setdefault("write_export_ratio", 100) + parameters.setdefault("num_classes", 10) + parameters.setdefault("cls_loss_weight", 1.0) + parameters.setdefault("loc_loss_weight", 5.0) + + wl.watch_or_edit( + parameters, + flag="hyperparameters", + name=parameters["experiment_name"], + defaults=parameters, + poll_interval=1.0, + ) + + if parameters.get("device", "auto") == "auto": + parameters["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = parameters["device"] + + if not parameters.get("root_log_dir"): + tmp_dir = tempfile.mkdtemp() + parameters["root_log_dir"] = tmp_dir + print(f"No root_log_dir specified, using temporary directory: {tmp_dir}") + os.makedirs(parameters["root_log_dir"], exist_ok=True) + + log_dir = parameters["root_log_dir"] + eval_ratio = parameters["eval_full_to_train_steps_ratio"] + write_export_ratio = parameters["write_export_ratio"] + training_steps_to_do = parameters.get("training_steps_to_do") + tqdm_display = parameters.get("tqdm_display", True) + verbose = parameters.get("verbose", True) + cls_weight = float(parameters["cls_loss_weight"]) + loc_weight = float(parameters["loc_loss_weight"]) + num_classes = int(parameters["num_classes"]) + enable_h5 = parameters.get("ledger_enable_h5_persistence", True) + + # -- Data ------------------------------------------------------------------- + if parameters.get("data_root"): + data_root = parameters["data_root"] + should_download = not os.path.exists(data_root) + else: + data_root = os.path.join(log_dir, "data") + should_download = True + os.makedirs(data_root, exist_ok=True) + + train_cfg = parameters.get("data", {}).get("train_loader", {}) + test_cfg = parameters.get("data", {}).get("test_loader", {}) + + tf = transforms.Compose([transforms.ToTensor()]) + + _train_dataset = MNISTMultiTaskDataset( + root=data_root, train=True, download=should_download, transform=tf, + max_samples=train_cfg.get("max_samples"), + ) + _test_dataset = MNISTMultiTaskDataset( + root=data_root, train=False, download=should_download, transform=tf, + max_samples=test_cfg.get("max_samples"), + ) + + # task_type="detection" tells the UI to render bbox overlays on each sample. + train_loader = wl.watch_or_edit( + _train_dataset, + flag="data", + loader_name="train_loader", + task_type="detection", + batch_size=train_cfg.get("batch_size", 32), + shuffle=train_cfg.get("shuffle", True), + is_training=True, + compute_hash=False, + preload_labels=False, + enable_h5_persistence=enable_h5, + collate_fn=multitask_collate, + ) + test_loader = wl.watch_or_edit( + _test_dataset, + flag="data", + loader_name="test_loader", + task_type="detection", + batch_size=test_cfg.get("batch_size", 64), + shuffle=False, + is_training=False, + compute_hash=False, + preload_labels=True, + enable_h5_persistence=enable_h5, + collate_fn=multitask_collate, + ) + + # -- Model ------------------------------------------------------------------ + _model = MNISTMultiTaskModel(num_classes=num_classes).to(device) + model = wl.watch_or_edit(_model, flag="model", device=device) + + lr = parameters.get("optimizer", {}).get("lr", 1e-3) + _optimizer = optim.Adam(model.parameters(), lr=lr) + optimizer = wl.watch_or_edit(_optimizer, flag="optimizer") + + # -- Losses (two separate tracked signals) ---------------------------------- + # Tracking each loss independently lets you inspect which task is harder, + # set per-task learning rate schedules, or diagnose multi-task trade-offs. + def _make_signals(split): + return { + "cls_loss": wl.watch_or_edit( + PerSampleClassificationLoss(), + flag="loss", + name=f"{split}_cls_loss", per_sample=True, log=True, + ), + "loc_loss": wl.watch_or_edit( + PerSampleLocalizationLoss(), + flag="loss", + name=f"{split}_loc_loss", per_sample=True, log=True, + ), + } + + train_sig = _make_signals("train") + test_sig = _make_signals("test") + + # -- Serving ---------------------------------------------------------------- + wl.serve( + serving_grpc=parameters.get("serving_grpc", True), + serving_cli=parameters.get("serving_cli", False), + ) + + print("=" * 60) + print(" STARTING MNIST MULTI-TASK TRAINING") + print(f" Tasks: classification (x{cls_weight}) + localization (x{loc_weight})") + print(f" Eval every {eval_ratio} steps | Export every {write_export_ratio} steps") + print(f" Train: {len(_train_dataset)} samples Test: {len(_test_dataset)} samples") + print(f" Logs: {log_dir}") + print("=" * 60 + "\n") + + wl.start_training(timeout=3) + + if tqdm_display: + train_range = tqdm.tqdm( + range(training_steps_to_do) if training_steps_to_do else itertools.count(), + desc="Training", ncols=140, + ) + else: + train_range = ( + range(training_steps_to_do) if training_steps_to_do else itertools.count() + ) + + test_cls_loss, test_loc_loss, test_acc = None, None, None + test_loader_len = len(test_loader) + + for train_step in train_range: + age = model.get_age() if hasattr(model, "get_age") else train_step + + train_loss = train(train_loader, model, optimizer, train_sig, device, cls_weight, loc_weight) + + if age == 0 or age % eval_ratio == 0: + test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating", leave=False) if tqdm_display else test_loader + test_cls_loss, test_loc_loss, test_acc = test( + test_loader_it, model, test_sig, device, cls_weight, loc_weight, test_loader_len + ) + + if age > 0 and age % write_export_ratio == 0: + wl.write_history() + wl.write_dataframe() + + if tqdm_display: + postfix = [f"train={train_loss:.4f}"] + if test_cls_loss is not None: + postfix.append(f"cls={test_cls_loss:.4f}") + if test_loc_loss is not None: + postfix.append(f"loc={test_loc_loss:.4f}") + if test_acc is not None: + postfix.append(f"acc={test_acc:.1f}%") + train_range.set_postfix_str(" | ".join(postfix)) + elif verbose: + msg = f"Step {train_step} (Age {age}): train={train_loss:.4f}" + if test_cls_loss is not None: + msg += f" | cls={test_cls_loss:.4f} loc={test_loc_loss:.4f} acc={test_acc:.1f}%" + print(f"\r{msg:<120}", end="", flush=True) + + print("\n" + "=" * 60) + print(f" Training completed in {time.time() - start_time:.2f}s") + print(f" Logs: {log_dir}") + print("=" * 60) + + wl.write_history() + wl.write_dataframe() + wl.keep_serving() diff --git a/weightslab/examples/PyTorch/ws-multitask/utils/__init__.py b/weightslab/examples/PyTorch/ws-multitask/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/weightslab/examples/PyTorch/ws-multitask/utils/criterions.py b/weightslab/examples/PyTorch/ws-multitask/utils/criterions.py new file mode 100644 index 00000000..dbbff6c1 --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/utils/criterions.py @@ -0,0 +1,34 @@ +""" +Multi-task loss functions for WeightsLab tracking. + +Both losses accept the standard WeightsLab call signature: + loss(preds_raw, targets, batch_ids=ids, preds=preds) + +where: + - preds_raw : raw model output for this head + - targets : list of [N, 6] detection tensors ([x1,y1,x2,y2,class_id,conf]) + - batch_ids : sample ids for per-sample logging + - preds : predicted boxes (list of [N,6] tensors) for UI overlay + +Both return a [B] per-sample loss tensor so WeightsLab records one value per sample. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PerSampleClassificationLoss(nn.Module): + """Cross-entropy loss per sample; class labels are extracted from targets.""" + + def forward(self, preds_raw, targets, batch_ids=None, preds=None): + labels = torch.stack([t[0, 4].long() for t in targets]).to(preds_raw.device) + return F.cross_entropy(preds_raw, labels, reduction="none") + + +class PerSampleLocalizationLoss(nn.Module): + """Smooth-L1 (Huber) bbox regression loss per sample; gt boxes from targets.""" + + def forward(self, preds_raw, targets, batch_ids=None, preds=None): + gt_boxes = torch.stack([t[0, :4] for t in targets]).to(preds_raw.device) + return F.smooth_l1_loss(preds_raw, gt_boxes, reduction="none").mean(dim=1) diff --git a/weightslab/examples/PyTorch/ws-multitask/utils/data.py b/weightslab/examples/PyTorch/ws-multitask/utils/data.py new file mode 100644 index 00000000..b185269c --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/utils/data.py @@ -0,0 +1,63 @@ +""" +MNIST multi-task dataset: each sample has a digit class label (classification) +and a tight bounding box of the non-zero pixels (localization). + +Target format follows the WeightsLab detection convention: + tensor of shape [N, 6] with columns [x1, y1, x2, y2, class_id, confidence] + all coordinates normalized to [0, 1]. + +This lets the WeightsLab UI render ground-truth bboxes over each sample. +""" + +import torch +from torch.utils.data import Dataset +from torchvision import datasets, transforms + + +class MNISTMultiTaskDataset(Dataset): + """MNIST with per-sample tight bounding boxes synthesized from pixel intensity.""" + + def __init__(self, root, train=True, download=True, transform=None, max_samples=None): + try: + self._mnist = datasets.MNIST(root=root, train=train, download=download, transform=None) + except RuntimeError: + self._mnist = datasets.MNIST(root=root, train=train, download=True, transform=None) + + self.transform = transform or transforms.ToTensor() + self.max_samples = max_samples + self._length = min(len(self._mnist), max_samples) if max_samples else len(self._mnist) + + def __len__(self): + return self._length + + def _compute_bbox(self, img_tensor): + """Return (x1, y1, x2, y2) normalized to [0,1] for the digit's tight bbox.""" + mask = img_tensor.squeeze(0) > 0.1 + if not mask.any(): + return 0.0, 0.0, 1.0, 1.0 + + rows_with_signal = mask.any(dim=1).nonzero(as_tuple=True)[0] + cols_with_signal = mask.any(dim=0).nonzero(as_tuple=True)[0] + + H, W = img_tensor.shape[-2], img_tensor.shape[-1] + y1 = float(rows_with_signal.min()) / H + y2 = float(rows_with_signal.max()) / H + x1 = float(cols_with_signal.min()) / W + x2 = float(cols_with_signal.max()) / W + return x1, y1, x2, y2 + + def __getitem__(self, idx): + """Returns (image, idx, target) where target is a [1, 6] detection tensor.""" + image, label = self._mnist[idx] + image = self.transform(image) + x1, y1, x2, y2 = self._compute_bbox(image) + target = torch.tensor( + [[x1, y1, x2, y2, float(label), 1.0]], dtype=torch.float32 + ) + return image, idx, target + + +def multitask_collate(batch): + """Collate for detection-format targets: targets remains a list of [N,6] tensors.""" + images, ids, targets = zip(*batch) + return torch.stack(images), torch.tensor(ids, dtype=torch.long), list(targets), {} diff --git a/weightslab/examples/PyTorch/ws-multitask/utils/model.py b/weightslab/examples/PyTorch/ws-multitask/utils/model.py new file mode 100644 index 00000000..601270a3 --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/utils/model.py @@ -0,0 +1,38 @@ +""" +Multi-task CNN for MNIST: shared backbone + classification head + localization head. +""" + +import torch.nn as nn + + +class MNISTMultiTaskModel(nn.Module): + """ + Shared CNN backbone with two heads: + - cls_head: digit classification (10 classes) + - loc_head: tight bounding-box regression (normalized x1,y1,x2,y2) + """ + + def __init__(self, num_classes=10): + super().__init__() + self.backbone = nn.Sequential( + nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2), + nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2), + nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), + nn.AdaptiveAvgPool2d(4), + ) + feat_dim = 128 * 4 * 4 + + self.cls_head = nn.Sequential( + nn.Flatten(), + nn.Linear(feat_dim, 256), nn.ReLU(), nn.Dropout(0.3), + nn.Linear(256, num_classes), + ) + self.loc_head = nn.Sequential( + nn.Flatten(), + nn.Linear(feat_dim, 128), nn.ReLU(), + nn.Linear(128, 4), nn.Sigmoid(), + ) + + def forward(self, x): + features = self.backbone(x) + return self.cls_head(features), self.loc_head(features) From d1201ae6e757de1ab2eb39e26605e113752c47ca Mon Sep 17 00:00:00 2001 From: GuillaumePELLUET Date: Mon, 22 Jun 2026 23:47:54 +0200 Subject: [PATCH 16/16] Fix GuardContext stale model ref across tests (always resolve from ledger) Previous fix only prevented assigning an unresolved Proxy but left self.model pointing to a stale proxy from a prior call. Now always derives self.model from the current ledger state on each __enter__: if get_model() returns a Proxy with no target (no model registered), self.model is set to None rather than falling back to a potentially-stale earlier value. Co-Authored-By: Claude Sonnet 4.6 --- weightslab/components/global_monitoring.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/weightslab/components/global_monitoring.py b/weightslab/components/global_monitoring.py index 423004a8..4dcfa184 100644 --- a/weightslab/components/global_monitoring.py +++ b/weightslab/components/global_monitoring.py @@ -229,17 +229,16 @@ def __enter__(self, f: bool = False): context = Context.TRAINING if self.for_training else Context.TESTING self._context_token = set_current_context(context) - # Update model — get_model() returns a Proxy(None) placeholder when nothing - # is registered; only use it if the proxy target is actually resolved. + # Update model — always resolve from the current ledger so stale references + # from previous calls never bleed through. get_model() returns a Proxy(None) + # placeholder when nothing is registered; treat that as "no model". _model = get_model() - if _model is not None: - try: - _target = object.__getattribute__(_model, '_obj') - if _target is not None: - self.model = _model - except AttributeError: - # _model is a plain (non-Proxy) object; use it directly - self.model = _model + try: + _target = object.__getattribute__(_model, '_obj') + self.model = _model if _target is not None else None + except AttributeError: + # _model is a plain (non-Proxy) object; use it directly + self.model = _model # The exact logic requested by the user: if self.model is not None: