diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00c53b8f..0f6ce350 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -223,25 +223,6 @@ jobs: # A per-test timeout guards against any regression that hangs a test. python -m pytest ./weightslab/tests -v --timeout=300 - # TODO (GP): WL CI do not find WS CI for now; token or visibility problem ?? - # - name: Trigger WeightsStudio CI - # env: - # WS_TOKEN: ${{ secrets.WEIGHTS_STUDIO_API_TOKEN }} - # run: | - # if [ -z "${WS_TOKEN}" ]; then - # echo "WEIGHTS_STUDIO_API_TOKEN not set; skipping WeightsStudio trigger." - # exit 0 - # fi - - # # Trigger the ws-ci workflow in the weights_studio repository on main. - # curl -fSs -X POST "https://api.github.com/repos/GrayboxTech/weights_studio/actions/workflows/ws-ci.yml/dispatches" \ - # -H "Authorization: Bearer ${WS_TOKEN}" \ - # -H "Accept: application/vnd.github+json" \ - # -H "Content-Type: application/json" \ - # -d '{"ref":"main"}' - - # echo "WeightsStudio workflow dispatch requested successfully." - build-and-publish-dev: # Only publish to TestPyPI when pushing to main (not on PRs or dev branch pushes). if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d4edd332..ae6155af 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,7 +47,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install . --extra-index-url https://download.pytorch.org/whl/cpu + + # Install the test extra so pytest, graphviz, torchmetrics, + # pytorch-lightning and tensorboard are available (several test modules + # import pytest / use pytest fixtures and cannot run under bare unittest).
+ python -m pip install .[utest] --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install pytest-timeout - name: Run tests run: | @@ -226,7 +231,7 @@ jobs: export PRS_JSON PRS_JSON=$(gh pr list --state merged --base "dev" --limit 100 \ - --json number,title,mergedAt,url,author,body \ + --json number,title,mergedAt,url,author,body,commits \ 2>/dev/null || echo "[]") python3 << 'PYEOF' @@ -249,6 +254,29 @@ jobs: line += f"\n\n > {desc}" return line filtered = [pr for pr in prs_data if pr.get("mergedAt", "") > prev_date] + # "What's Changed precisely": list each merged PR's own developer commits, + # NOT the squashed/merge/chore commits that land on the release branch + # (those are redundant with the PR list above and not useful). + _seen_commits = set() + _commit_blocks = [] + for pr in filtered: + _commit_lines = [] + for c in (pr.get("commits") or []): + oid = (c.get("oid") or "")[:7] + head = (c.get("messageHeadline") or "").strip() + if not head or head.startswith("Merge "): + continue + key = (oid, head) + if key in _seen_commits: + continue + _seen_commits.add(key) + _commit_lines.append(f" - `{oid}` {head}") + if _commit_lines: + _commit_blocks.append( + f"**[#{pr['number']}]({pr['url']}) {pr['title']}**\n" + "\n".join(_commit_lines) + ) + # Fall back to the raw git log only when no PR commits are available. 
+ commits_section = "\n\n".join(_commit_blocks) if _commit_blocks else commits if filtered: prs_lines = "\n".join(_pr_entry(pr) for pr in filtered) seen, clabels = set(), [] @@ -276,7 +304,7 @@ jobs: "Happy Training!\n\n" "---\n\n" "### What's Changed precisely:\n\n" - f"{commits}\n\n" + f"{commits_section}\n\n" "---\n\n" "### Thank you!\n\n" f"{contributors}\n" @@ -303,7 +331,8 @@ jobs: build-and-publish-main: name: Build & Publish Main (PyPI) - needs: [detect-target] + needs: [detect-target,test] + # needs: [detect-target] runs-on: ubuntu-latest if: ${{ needs.detect-target.outputs.is_main == 'true' }} permissions: @@ -458,7 +487,7 @@ jobs: export PRS_JSON PRS_JSON=$(gh pr list --state merged --base "main" --limit 100 \ - --json number,title,mergedAt,url,author,body \ + --json number,title,mergedAt,url,author,body,commits \ 2>/dev/null || echo "[]") python3 << 'PYEOF' @@ -481,6 +510,29 @@ jobs: line += f"\n\n > {desc}" return line filtered = [pr for pr in prs_data if pr.get("mergedAt", "") > prev_date] + # "What's Changed precisely": list each merged PR's own developer commits, + # NOT the squashed/merge/chore commits that land on the release branch + # (those are redundant with the PR list above and not useful). + _seen_commits = set() + _commit_blocks = [] + for pr in filtered: + _commit_lines = [] + for c in (pr.get("commits") or []): + oid = (c.get("oid") or "")[:7] + head = (c.get("messageHeadline") or "").strip() + if not head or head.startswith("Merge "): + continue + key = (oid, head) + if key in _seen_commits: + continue + _seen_commits.add(key) + _commit_lines.append(f" - `{oid}` {head}") + if _commit_lines: + _commit_blocks.append( + f"**[#{pr['number']}]({pr['url']}) {pr['title']}**\n" + "\n".join(_commit_lines) + ) + # Fall back to the raw git log only when no PR commits are available. 
+ commits_section = "\n\n".join(_commit_blocks) if _commit_blocks else commits if filtered: prs_lines = "\n".join(_pr_entry(pr) for pr in filtered) seen, clabels = set(), [] @@ -508,7 +560,7 @@ jobs: "Happy Training!\n\n" "---\n\n" "### What's Changed precisely:\n\n" - f"{commits}\n\n" + f"{commits_section}\n\n" "---\n\n" "### Thank you!\n\n" f"{contributors}\n" diff --git a/AGENTS.md b/AGENTS.md index 50773f7c..8b8eeda5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -155,11 +155,18 @@ ones when debugging: | `WS_HISTOGRAM_MAX_BINS` | `512` | Cap on metadata histogram bars. | | `BB_THUMB_RENDER` | `10` | Max bounding boxes drawn per **thumbnail**, per overlay (GT and PRED capped independently). | | `BB_MODAL_RENDER` | `100` | Max bounding boxes drawn per **modal** image, per overlay. A `?` button in the modal shows the active limit. | - -> **VITE_ vs WS_/BB_:** `VITE_*` variables are baked at **build time** (changing -> them needs a rebuild). `WS_*` / `BB_*` are injected at **container start** into -> `config.js` and read as `window.*` globals — changing them needs only a -> container restart + browser reload (see the caching note in §5). +| `ENABLE_PLOTS` | `1` | `0`/`false` removes the plots board + Signals card and stops plot auto-refresh. | +| `ENABLE_DATA_EXPLORATION` | `1` | `0`/`false` removes the data grid + metadata/details panel and stops the data/metadata auto-refresh. | +| `ENABLE_HYPERPARAMETERS_OPTIMIZATION` | `1` | `0`/`false` removes the Hyperparameters section, makes HP inputs read-only, and stops the HP poll. | +| `ENABLE_AGENT` | `1` | `0`/`false` removes the agent chat bar + history panel and stops the agent health poll. | + +> **VITE_ vs WS_/BB_/ENABLE_:** `VITE_*` variables are baked at **build time** +> (changing them needs a rebuild). 
`WS_*` / `BB_*` / `ENABLE_*` are injected at +> **container start** into `config.js` and read as `window.*` globals (the +> toggles as `window.WS_ENABLE_*`) — changing them needs only a container restart +> + browser reload (see the caching note in §5). Each `ENABLE_*` defaults to on; +> set it to `0`/`false`/`no`/`off` to disable. Full reference: +> `weightslab/docs/configuration.rst` (“Feature toggles”). --- diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..9a1a56ba --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,144 @@ +# WeightsLab Workspace — Project Knowledge + +Self-contained knowledge file for the WeightsLab workspace. One file, organized by topic so you can jump to the section you need. **Jump to:** + +| If you need… | Go to section | +|---|---| +| The 3 repos and how they relate | [1. Workspace layout](#1-workspace-layout) | +| How backend & frontend talk at runtime | [2. Runtime integration](#2-runtime-integration) | +| Where backend Python code lives | [3. weightslab backend module map](#3-weightslab-backend-module-map) | +| Where frontend TS code / tests live | [4. weights_studio frontend module map](#4-weights_studio-frontend-module-map) | +| How a user training script plugs in | [5. Integration API (the usecase pattern)](#5-integration-api-the-usecase-pattern) | +| Testing rules, data/H5/tags features | [6. Topic notes](#6-topic-notes) | + +> Paths/line claims are point-in-time — verify against current code before asserting as fact. + +--- + +## 1. Workspace layout + +Three sibling repos under `c:\Users\GuillaumePELLUET\Documents\Codes\`: + +- **weightslab** (`weightslab/`) — Python backend/core. ML training, data processing, gRPC API, the published `pip install weightslab` package. Python pkg root: `weightslab/weightslab/`. +- **weights_studio** (`weights_studio/`) — TypeScript/Vite web UI. Consumes the backend over grpc-web. 
**All Playwright/E2E user-simulation tests live here, not in weightslab.** +- **weightslab_kitchen** (`weightslab_kitchen/`) — private examples/reference, minimal docs. + +They must be checked out **side-by-side**: weights_studio's proto codegen reads into the weightslab directory (see §2). + +--- + +## 2. Runtime integration + +### Shared contract — one proto, two sides +- Source of truth: `weightslab/weightslab/proto/experiment_service.proto` defines `service ExperimentService` (~20 RPCs). +- **Backend** implements it in `weightslab/weightslab/trainer/services/experiment_service.py` (the gRPC servicer), delegating to `model_service.py`, `data_service.py`, `agent_service.py`. +- **Frontend** consumes it via generated client `weights_studio/src/experiment_service.client.ts` (+ `experiment_service.ts`), produced by `npm run generate-proto:data` → + `protoc --ts_out src/ --proto_path ../weightslab/weightslab/proto experiment_service.proto`. + +### Wire path (browser → training process) +``` +weights_studio (browser, GrpcWebFetchTransport, src/main.ts) + → http(s) :8080 Envoy proxy (grpc-web ↔ grpc transcoding) + → cluster grpc-backend :__GRPC_BACKEND_PORT__ (Python gRPC servicer) + → in-process training loop (watched model/optimizer/data/loss) +``` +Browsers can't speak raw gRPC, so Envoy translates. Frontend default server port **8080** (Envoy listener); admin **9901**. Backend gRPC port is templated (`__GRPC_BACKEND_PORT__`, substituted at deploy). `main.ts` supports path-based deploys (`//api`, `/demo//api`) and loopback/TLS host normalization. + +### RPC groups +- **Training control:** `ExperimentCommand` (pause/resume/…), `GetLatestLoggerData` (metric streaming). +- **Weights/arch:** `ManipulateWeights`, `GetWeights`, `GetActivations`. +- **Data:** `GetSamples`, `ApplyDataQuery`, `GetDataSamples`, `EditDataSample`, `GetDataSplits`. +- **Agent (LLM):** `CheckAgentHealth`, `InitializeAgent`, `ChangeAgentModel`, `GetAgentModels`, `ResetAgent`. 
+- **Checkpoint/eval:** `RestoreCheckpoint`, `TriggerEvaluation`, `GetEvaluationStatus`, `CancelEvaluation`. + +### Deployment +- `weightslab ui launch` → `weightslab/weightslab/ui_docker_bridge.py` brings up the bundled Docker stack (`weightslab/weightslab/ui/docker/docker-compose.yml` + `envoy.yaml`) with TLS via `weightslab/security/CertAuthManager`. This is how the published package serves the studio UI. +- weights_studio also ships its own dev/prod Docker + Envoy under `weights_studio/docker/` and `weights_studio/envoy/`. `npm run dev` = Vite on :5173. + +### Proto-change checklist (keep all three in sync) +1. Edit `.proto`. 2. Regenerate backend `*_pb2*.py`. 3. Run `npm run generate-proto:data` in weights_studio. + +--- + +## 3. weightslab backend module map + +Package root `weightslab/weightslab/`. Public API re-exported from `__init__.py` (← `src.py`). Used as `import weightslab as wl`. + +Layers (top depends on lower): +- **`src.py`** — facade implementing public verbs: `watch_or_edit`, `serve`, `keep_serving`, `save_signals`, `save_instance_signals`, `tag_samples`, `register_categorical_tag`, `discard_samples`, `query_signal_history` / `query_sample_history` / `query_instance_history`, `get_current_experiment_hash`, etc. +- **`trainer/`** — orchestration. `trainer_services.py`, `trainer_tools.py`, `experiment_context.py`. + - `services/experiment_service.py` — the gRPC servicer implementing `ExperimentService`. + - `services/{model_service,data_service,agent_service}.py` — per-domain delegates. + - `services/agent/` — LLM agent (configured by repo-root `agent_config.yaml`: `ollama` local / `openrouter` remote). + - `services/instance_merger.py` — multi-instance (detection/seg) handling. +- **`components/`** — cross-cutting runtime machinery. + - `global_monitoring.py` — `guard_training_context` / `guard_testing_context`, pause controller, the global rlock used by the servicer (training + serving run in one process, different threads). 
+ - `evaluation_controller.py` (`eval_controller`), `checkpoint_manager.py`, `tracking.py`, `experiment_hash.py`, `parallel_primitives.py`. +- **`models/`** — `model_with_ops.py` (watched/op-able model wrapper), `monkey_patcher.py`. +- **`data/`** — dataframe + storage backbone. `dataframe_manager.py`, `data_samples_with_ops.py`, `sample_stats.py` (`SampleStatsEx`); storage `h5_dataframe_store.py`, `h5_array_store.py`, `array_proxy.py`. +- **`backend/`** — primitives. `ledgers.py` (`GLOBAL_LEDGER`, hyperparameter registry: `get_hyperparams`/`set_hyperparam`/`Proxy`), `model_interface.py`, `optimizer_interface.py`, `dataloader_interface.py`, `audit_logger.py`, `logger.py`, `cli.py` (optional localhost TCP REPL). +- **`proto/`** — `.proto` + generated `*_pb2*.py` (shared with weights_studio). +- **`baseline_models/`** — ready nets (e.g. `baseline_models.pytorch.models.FashionCNN`). +- **`ui/`** — bundled Docker/Envoy/nginx assets. **`security/`** — `CertAuthManager`. **`examples/`** — see §5. + +**Key fan-in points:** `ledgers.GLOBAL_LEDGER` is the hub (`watch_or_edit` registers objects there; the servicer reads/mutates through it). `components/global_monitoring` locks coordinate the training thread with gRPC calls. + +--- + +## 4. weights_studio frontend module map + +Vite + TypeScript. Entry `index.html` → `src/main.ts`. + +- **`main.ts`** — bootstrap: infers server host/port (default :8080), builds `GrpcWebFetchTransport`, wires panels, handles path-based deploy + TLS host normalization. +- **`experiment_service.client.ts` + `experiment_service.ts`** — generated gRPC-web client/types (also under `src/proto/`). **Do not hand-edit;** regenerate via `generate-proto:data`. 
+- **`left_panel/`** (`leftPanel.ts`, `panelResizer.ts` — controls, class/tag prefs), **`main_area/`** (board resizers), **`plots/`** (Chart.js + zoom), **`grid_data/`** (sample grid/table), **`agent/agentPanel.ts`** (LLM agent UI), **`ui/`/`utils/`/`helpers.ts`/`ContextMenu.ts`/`darkMode.ts`/`resilience.ts`** (shared UI + reconnection), **`test/`** (vitest). + +### Build / proto scripts (package.json) +- `generate-proto:data` reads the sibling weightslab repo (must be side-by-side). +- `npm run dev` (Vite, `VITE_HOST` 0.0.0.0 / `VITE_PORT` 5173), `build`, `preview`. + +### Tests (see §6 for placement rule) +- Unit: `npm run test` (vitest). +- Managed realtime (spins backend, via `scripts/run-managed-playwright.mjs`): `test:realtime:cls`, `test:realtime:seg`. +- Real-usecase E2E: `test:e2e:detection` (`tests/playwright/real_usecases/user_detection_yolo.spec.ts`), `test:e2e:segmentation` (`...user_segmentation_bdd.spec.ts`). +- `test:all` = unit + realtime cls/seg + e2e. + +--- + +## 5. Integration API (the usecase pattern) + +How a user's own PyTorch script plugs in so weights_studio can inspect/edit it live. Examples: `weightslab/weightslab/examples/{PyTorch,PyTorch_Lightning}/<usecase>/` — each is `main.py` + `config.yaml`. Usecases: `ws-classification`, `ws-segmentation`, `ws-face_recognition-triplet_loss`, `ws-vad` (+ Lightning classification). + +### Pattern — `import weightslab as wl` +Wrap each training object with `wl.watch_or_edit(obj, flag=...)`; the returned tracked proxy is registered in the ledger so the gRPC service can read stats / apply edits at runtime: +- `flag="hyperparameters"` — HP dict (required flag for trainer-services/UI visibility). +- `flag="model"` — wraps `nn.Module` (`device=…`); enables weight inspection + arch ops + `.get_age()`. +- `flag="optimizer"`.
+- `flag="data"` — wraps a `Dataset` into a tracked loader: `loader_name`, `batch_size`, `shuffle`, `is_training`, `preload_labels`, `enable_h5_persistence`, … +- `flag="loss"` — wraps a `reduction="none"` criterion (`signal_name`, `log=True`); called `(preds_raw, targets, batch_ids=ids, preds=preds)` so per-sample loss maps to sample ids. +- `flag="metric"` — wraps a torchmetrics metric. + +### Dataset contract +`Dataset.__getitem__` returns **`(image, idx, label)`** — the sample id is threaded through training so per-sample signals attribute back to the sample. + +### Loop conventions +- `with guard_training_context:` (train step) / `with guard_testing_context:` (eval) — drives pause/resume + train/test stat separation. +- `wl.save_signals(preds_raw=, targets=, batch_ids=ids, signals={...}, preds=)` for extra per-sample signals. +- Use `model.get_age()` (steps actually trained, survives checkpoint reloads), not the raw loop counter. + +### Serving lifecycle +- `wl.serve(serving_grpc=…, serving_cli=…)` starts background serving threads **in the same process** as training. +- End the script with `wl.keep_serving()` to keep serving threads alive after the loop. +- Config from sibling `config.yaml`: `root_log_dir`, `device`, `training_steps_to_do`, `eval_full_to_train_steps_ratio`, `data.*_loader`, `optimizer.lr`, `enable_h5_persistence`, `serving_grpc`, … + +--- + +## 6. Topic notes + +- **Playwright test placement:** E2E/user-simulation tests belong in **weights_studio** (UI simulation), not weightslab. The Python backend now starts in **parallel** with Docker deployment (not sequentially) in the managed test runner. +- **Multi-instance dataframe:** MultiIndex `(sample_id, annotation_id)` supports per-instance data for detection/segmentation. +- **H5 storage:** `H5DataFrameStore` preserves the `(sample_id, annotation_id)` multi-index through write/read. `tag:xxx` columns are auto-optimized to categorical dtype (~90% memory savings). 
+- **Categorical tags:** planned support for multi-value tags with predefined categories; boolean tags unchanged. +- **Detection class colors:** class preferences from the left panel apply to detection bbox rendering. +- **Audit logger:** json/csv output configurable via `AUDIT_LOG_FORMAT` env var. +- **Docker-in-Docker:** envoy template mounting / file access fixed for GitHub Actions runner DinD environments. diff --git a/README.md b/README.md index 3481c80e..c9c0890a 100644 --- a/README.md +++ b/README.md @@ -266,6 +266,29 @@ def main(): total_loss += loss.item() + # Write the history of these samples every 100 steps + if model.get_age() % 100 == 0: + print(f'Dump signals history and dataframe at age {model.get_age()}') + wl.write_history( + # path=None, # Use root_log_dir by default, filename generated from parameters md5 hash + type_of_history="all", + graph_name=[ + 'train/clsf_instance', + 'val/clsf_instance' + ], + # experiment_hash=None, Default is 'last', i.e., current experiment hash + sample_id=['11', '29', '28', '27', '22'], + instance_id=[1, 2, 3] + ) + + # Dump the sample dataframe: all signals plus the loss_shape categorical tag. + wl.write_dataframe( + columns=["signals", "tag:loss_shape"], + format='csv' + # sample_id=['0', '28'] + # instance_id=[1, 2], + ) + avg_loss = total_loss / len(dataloader) print(f"Epoch {epoch+1}/5 - Loss: {avg_loss:.4f}") diff --git a/docs/configuration.rst b/docs/configuration.rst index fd23ee41..7d6a8fe5 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -245,6 +245,14 @@ Data and Cache uniformly downsampled — keeping the first and last point and an evenly-spaced subset in between (no values are interpolated/invented). Set to ``0`` to disable the cap and return every step of the mean curve. + * - ``WL_POINT_CLOUD_CHUNK_BYTES`` + - ``1048576`` + - Size, in bytes, of each chunk streamed by the ``GetPointCloud`` RPC + (raw ``float32`` point-cloud data is sent as a sequence of binary + messages).
Defaults to ``1048576`` (1 MiB). Larger chunks mean fewer + gRPC messages but more memory held per message; smaller chunks lower + peak memory at the cost of more round-trips. Must be a positive integer + — non-positive or non-numeric values fall back to the 1 MiB default. Evaluation Mode @@ -547,3 +555,101 @@ These variables are injected into the browser bundle at build / dev time. * - ``VITE_WS_MODAL_CACHE_MAX_MB`` - ``64`` - Maximum memory (MB) for the full-resolution modal image cache. + + +Bounding-box render limits +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Detection samples can carry many bounding boxes per image (dense scenes, +high-recall predictions). Drawing them all slows rendering and turns the +overlay into noise, so the number of boxes drawn per image is capped. The cap +is applied **separately** to ground-truth (GT) and predictions (PRED) — a value +of ``10`` allows up to 10 GT boxes *and* 10 PRED boxes per image. Boxes beyond +the cap are simply not drawn (predictions are typically score-ordered, so the +most confident ones are kept). + +These are set on the Weights Studio frontend container (for example in +``../weights_studio/docker/docker-compose.yml``) and injected into the page at +startup by the nginx entrypoint — changing them needs no rebuild, just a +container restart. For a local ``vite`` dev server, use the ``VITE_`` fallbacks +shown below. Values are clamped to a hard ceiling of ``10000``. + +.. list-table:: + :header-rows: 1 + :widths: 30 12 58 + + * - Variable + - Default + - Description + * - ``BB_THUMB_RENDER`` + - ``10`` + - Maximum bounding boxes drawn per image in the grid **thumbnails**, per + overlay (up to N ground-truth and N predictions). Dev-server fallback: + ``VITE_BB_THUMB_RENDER``. + * - ``BB_MODAL_RENDER`` + - ``100`` + - Maximum bounding boxes drawn per image in the **modal** detail view, per + overlay (up to N ground-truth and N predictions). 
A ``?`` button in the + top-right of the modal image surfaces the active limit on hover. + Dev-server fallback: ``VITE_BB_MODAL_RENDER``. + +.. note:: + + These caps only affect *rendering* — no sample data is dropped. They apply to + detection bounding-box overlays; segmentation masks are unaffected. + + +Feature toggles +~~~~~~~~~~~~~~~ + +Whole areas of the Studio UI can be turned off for a given deployment — for +example a read-only demo that only shows plots, or a labelling-only view with no +agent. Each toggle **removes the area from the UI** (the elements are hidden) +**and stops its background work** (auto-refresh timers and gRPC polls are never +started), so a disabled area costs nothing at runtime. + +Like the bounding-box render limits, these are set on the Weights Studio frontend +container (for example in ``../weights_studio/docker/docker-compose.yml``) and +injected into the page at startup by the nginx entrypoint — changing them needs +no rebuild, just a container restart + browser reload. For a local ``vite`` dev +server, use the ``VITE_`` fallbacks shown below. Every toggle **defaults to +enabled**; set it to ``0`` / ``false`` / ``no`` / ``off`` (any case) to disable. + +.. list-table:: + :header-rows: 1 + :widths: 38 10 52 + + * - Variable + - Default + - Description + * - ``ENABLE_PLOTS`` + - ``1`` + - When disabled, removes the plots board and the left-panel Signals/metrics + card, and stops the plot-data auto-refresh (the ``GetLatestLoggerData`` + poll and the chart redraw loop). Dev-server fallback: + ``VITE_ENABLE_PLOTS``. + * - ``ENABLE_DATA_EXPLORATION`` + - ``1`` + - When disabled, removes the data sample grid and the metadata / details + left panel, and stops the data auto-refresh (the ``GetDataSamples`` / + ``GetMetaData`` timers and the slider-histogram poll). Dev-server + fallback: ``VITE_ENABLE_DATA_EXPLORATION``. 
+ * - ``ENABLE_HYPERPARAMETERS_OPTIMIZATION`` + - ``1`` + - When disabled, removes the Hyperparameters section from the left panel, + makes the hyperparameter inputs read-only (no user edits are sent to the + backend), and stops the hyperparameter sync poll. Dev-server fallback: + ``VITE_ENABLE_HYPERPARAMETERS_OPTIMIZATION``. + * - ``ENABLE_AGENT`` + - ``1`` + - When disabled, removes the agent chat input bar (and its send button) and + the chat-history panel, and stops the agent health-check poll. Dev-server + fallback: ``VITE_ENABLE_AGENT``. + +.. note:: + + Each variable maps to a ``window.WS_ENABLE_*`` global injected into + ``config.js`` at container start (the same mechanism as the bounding-box + limits), with a build-time ``VITE_ENABLE_*`` fallback for the dev server. + Because ``config.js`` is served ``no-store``, a container restart + normal + reload is enough to pick up a change. diff --git a/docs/weights_studio.rst b/docs/weights_studio.rst index 79cba1e0..1d6036dc 100644 --- a/docs/weights_studio.rst +++ b/docs/weights_studio.rst @@ -71,6 +71,8 @@ Default values in ``../weights_studio/docker/.env``: - ``VITE_PORT=5173`` - ``VITE_HISTOGRAM_MAX_BINS=512`` +- ``BB_THUMB_RENDER=10`` (max bounding boxes drawn per thumbnail image, per overlay) +- ``BB_MODAL_RENDER=100`` (max bounding boxes drawn per modal image, per overlay) - ``WS_SERVER_HOST=localhost`` - ``WS_SERVER_PORT=8080`` - ``WS_SERVER_PROTOCOL=https`` @@ -313,6 +315,8 @@ consistently with your deployed endpoints: WS_SERVER_HOST=studio.your-domain.com WS_SERVER_PORT=443 VITE_HISTOGRAM_MAX_BINS=512 + BB_THUMB_RENDER=10 + BB_MODAL_RENDER=100 # envoy / backend internal wiring ENVOY_PORT=8080 @@ -369,7 +373,9 @@ Use this pattern for a simple single-VM production-like deployment. 
WS_SERVER_PROTOCOL=https WS_SERVER_HOST=studio.your-domain.com WS_SERVER_PORT=443 - VITE_HISTOGRAM_MAX_BINS=512 + VITE_HISTOGRAM_MAX_BINS=512 + BB_THUMB_RENDER=10 + BB_MODAL_RENDER=100 ENVOY_PORT=8080 ENVOY_ADMIN_PORT=9901 diff --git a/pyproject.toml b/pyproject.toml index 962e5b7f..71aef233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "numpy>=1.25.2,<2.0; python_version < '3.13'", "numpy>=2.1,<3; python_version >= '3.13'", "pandas>=2.2.3,<3", + "duckdb>=1.1,<2", # signal/sample/instance history store "PyYAML>=6.0.3,<7", "dill>=0.3.8,<0.5", "zstandard>=0.22,<1", @@ -52,7 +53,8 @@ dependencies = [ # Imaging "Pillow>=10,<12", - "opencv-python>=4.8,<5", + + # Modeling "onnx>=1.15,<=1.20", # Utility used in examples and progress reporting diff --git a/weightslab/__init__.py b/weightslab/__init__.py index 07262ba3..df7a8154 100644 --- a/weightslab/__init__.py +++ b/weightslab/__init__.py @@ -11,7 +11,7 @@ import logging import threading -from .src import watch_or_edit, start_training, serve, keep_serving, save_signals, save_instance_signals, save_group_signals, tag_samples, register_categorical_tag, set_categorical_tag, discard_samples, get_samples_by_tag, get_discarded_samples, signal, eval_fn, compute_signals, SignalContext, clear_all, run_pending_evaluation, trigger_pending_evaluation_async, query_signal_history, query_sample_history, query_instance_history, write_history, write_dataframe, get_current_experiment_hash, pointcloud_thumbnail, pointcloud_boxes +from .src import watch_or_edit, start_training, serve, keep_serving, load_experiment_for_explore, save_signals, save_instance_signals, save_group_signals, tag_samples, register_categorical_tag, set_categorical_tag, discard_samples, get_samples_by_tag, get_discarded_samples, signal, eval_fn, compute_signals, SignalContext, clear_all, run_pending_evaluation, trigger_pending_evaluation_async, query_signal_history, query_sample_history, query_instance_history, write_history, 
write_dataframe, get_current_experiment_hash, pointcloud_thumbnail, pointcloud_boxes from .backend.ledgers import GLOBAL_LEDGER as ledger from .art import _BANNER from .utils.logs import setup_logging, set_log_directory, is_main_process @@ -63,7 +63,7 @@ # Get Package Metadata try: # setuptools_scm will write weightslab/_version.py during build - from ._version import __version__ # type: ignore + from ._version import __version__ # type: ignore except Exception: # Fallback when developing locally or before build; keeps behavior stable. from datetime import datetime @@ -76,6 +76,7 @@ "watch_or_edit", "serve", "keep_serving", + "load_experiment_for_explore", "save_signals", "save_instance_signals", "save_group_signals", @@ -102,8 +103,10 @@ "query_signal_history", "query_sample_history", "query_instance_history", + "write_history", "write_dataframe", + "pointcloud_thumbnail", "pointcloud_boxes", diff --git a/weightslab/art.py b/weightslab/art.py index cf34b072..4436ee2f 100644 --- a/weightslab/art.py +++ b/weightslab/art.py @@ -10,11 +10,11 @@ def get_git_info(): git_root = current_dir # Traverse up to find .git directory - for _ in range(10): # Limit search depth + for _ in range(10): # Limit search depth if os.path.isdir(os.path.join(git_root, '.git')): break parent = os.path.dirname(git_root) - if parent == git_root: # Reached filesystem root + if parent == git_root: # Reached filesystem root git_root = None break git_root = parent @@ -40,19 +40,19 @@ def get_git_info(): branch, version, commit_hash = get_git_info() _BANNER = f""" -\x1b[31m /WW /WW\x1b[0m /$$ /$$ /$$ \x1b[32m/$$\x1b[0m /$$ -\x1b[31m| WW /W | WW\x1b[0m |__/ | $$ | $$ \x1b[32m| $$\x1b[0m | $$ -\x1b[31m| WW /WWW| WW\x1b[0m /$$$$$$ /$$ /$$$$$$ | $$$$$$$ /$$$$$$ /$$$$$$$\x1b[32m| $$\x1b[0m /$$$$$$ | $$$$$$$ -\x1b[31m| WW/WW WW WW\x1b[0m /$$__ $$| $$ /$$__ $$| $$__ $$|_ $$_/ /$$_____/\x1b[32m| $$\x1b[0m |____ $$| $$__ $$ -\x1b[31m| WWWW_ WWWW\x1b[0m| $$$$$$$$| $$| $$ \ $$| $$ \ $$ | $$ | $$$$$$ 
\x1b[32m| $$\x1b[0m /$$$$$$$| $$ \ $$ -\x1b[31m| WWW/ \ WWW\x1b[0m| $$_____/| $$| $$ | $$| $$ | $$ | $$ /$$ \____ $$\x1b[32m| $$\x1b[0m /$$__ $$| $$ | $$ -\x1b[31m| WW/ \ WW\x1b[0m| $$$$$$$| $$| $$$$$$$| $$ | $$ | $$$$/ /$$$$$$$/\x1b[32m| $$$$$$$$\x1b[0m $$$$$$$| $$$$$$$/ -\x1b[31m|__/ \__/\x1b[0m \_______/|__/ \____ $$|__/ |__/ \___/ |_______/ \x1b[32m|________/\x1b[0m \_______/|_______/ - /$$ \ $$ - | $$$$$$/ +\x1b[31m /WW /WW\x1b[0m /$$ /$$ /$$ \x1b[32m/$$\x1b[0m /$$ +\x1b[31m| WW /W | WW\x1b[0m |__/ | $$ | $$ \x1b[32m| $$\x1b[0m | $$ +\x1b[31m| WW /WWW| WW\x1b[0m /$$$$$$ /$$ /$$$$$$ | $$$$$$$ /$$$$$$ /$$$$$$$\x1b[32m| $$\x1b[0m /$$$$$$ | $$$$$$$ +\x1b[31m| WW/WW WW WW\x1b[0m /$$__ $$| $$ /$$__ $$| $$__ $$|_ $$_/ /$$_____/\x1b[32m| $$\x1b[0m |____ $$| $$__ $$ +\x1b[31m| WWWW_ WWWW\x1b[0m| $$$$$$$$| $$| $$ \ $$| $$ \ $$ | $$ | $$$$$$ \x1b[32m| $$\x1b[0m /$$$$$$$| $$ \ $$ +\x1b[31m| WWW/ \ WWW\x1b[0m| $$_____/| $$| $$ | $$| $$ | $$ | $$ /$$ \____ $$\x1b[32m| $$\x1b[0m /$$__ $$| $$ | $$ +\x1b[31m| WW/ \ WW\x1b[0m| $$$$$$$| $$| $$$$$$$| $$ | $$ | $$$$/ /$$$$$$$/\x1b[32m| $$$$$$$$\x1b[0m $$$$$$$| $$$$$$$/ +\x1b[31m|__/ \__/\x1b[0m \_______/|__/ \____ $$|__/ |__/ \___/ |_______/ \x1b[32m|________/\x1b[0m \_______/|_______/ + /$$ \ $$ + | $$$$$$/ \______/ By GrayBx """ if branch is not None and version is not None and commit_hash is not None: _BANNER += f"\nBranch: {branch} | Version: {version} | Commit: {commit_hash}\n" -_BANNER__ = _BANNER # Expose banner with a different name for external use and legacy +_BANNER__ = _BANNER # Expose banner with a different name for external use and legacy diff --git a/weightslab/backend/audit_logger.py b/weightslab/backend/audit_logger.py index 9930e927..fdcde841 100644 --- a/weightslab/backend/audit_logger.py +++ b/weightslab/backend/audit_logger.py @@ -14,9 +14,9 @@ @dataclass class AuditEvent: """Immutable audit event structure.""" - timestamp: str # ISO format string + timestamp: str # ISO format string action_type: str - 
status: str # "success" or "failed" + status: str # "success" or "failed" details: Optional[Dict[str, Any]] = None error: Optional[str] = None diff --git a/weightslab/backend/cli.py b/weightslab/backend/cli.py index acaf911e..1734a15c 100644 --- a/weightslab/backend/cli.py +++ b/weightslab/backend/cli.py @@ -172,7 +172,7 @@ def _handle_command(cmd: str) -> Any: 'hyperparams_examples': { 'list': 'hp', 'show': 'hp fashion_mnist', - 'set': "set_hp # e.g. set_hp fashion_mnist data.train_loader.batch_size 32", + 'set': "set_hp # e.g. set_hp fashion_mnist data.train_loader.batch_size 32", }, 'evaluate_examples': { 'eval on default split': 'evaluate', @@ -451,7 +451,7 @@ def _handle_command(cmd: str) -> Any: for name in snap.get(k, []): try: getter = { - # 'models': GLOBAL_LEDGER.get_model, # don't print the model out + # 'models': GLOBAL_LEDGER.get_model, # don't print the model out 'dataloaders': GLOBAL_LEDGER.get_dataloader, 'optimizers': GLOBAL_LEDGER.get_optimizer, }[k] @@ -823,7 +823,7 @@ def _handle_command(cmd: str) -> Any: names = GLOBAL_LEDGER.list_hyperparams() if hasattr(GLOBAL_LEDGER, 'list_hyperparams') else [] if len(parts) == 1: return {'ok': True, 'hyperparams': names} - # support: hp list -> same as hp + # support: hp list -> same as hp name = parts[1] if name.lower() in ('list', 'ls', 'all'): return {'ok': True, 'hyperparams': names} @@ -998,7 +998,7 @@ def _handle_command(cmd: str) -> Any: elif toggle in ('off', 'false', '0', 'disable', 'disabled'): value = False else: - return {'ok': False, 'error': f'Unknown audit toggle "{toggle}". Use: audit on or audit off'} + return {'ok': False, 'error': f'Unknown audit toggle "{toggle}". 
Use: audit on or audit off'} set_hyperparam(name=name, value=value, key_path='auditor_mode') label = 'enabled' if value else 'disabled' @@ -1123,7 +1123,7 @@ def cli_serve(cli_host: str = 'localhost', cli_port: int = 0, *, spawn_client: b pass srv = None if attempt < max_attempts - 1: - continue # Try next port + continue # Try next port else: # All attempts failed logger.exception("cli_bind_failed_all_attempts") diff --git a/weightslab/backend/dataloader_interface.py b/weightslab/backend/dataloader_interface.py index 2d066666..d9b16e94 100644 --- a/weightslab/backend/dataloader_interface.py +++ b/weightslab/backend/dataloader_interface.py @@ -101,7 +101,7 @@ class WeightsLabDataSampler(Sampler): loader = DataLoader(dataset, batch_sampler=sampler) # Toggle shuffle at runtime - sampler.shuffle = False # Switch to sequential + sampler.shuffle = False # Switch to sequential """ def __init__( @@ -132,7 +132,7 @@ def __init__( self._deny_listed_uids_cache: set[str] = set() self._deny_list_revision: Optional[tuple[str, int]] = None # Evaluation-mode allow-list: when set, only samples whose uid is in - # this set are yielded. None = no filter (normal behaviour). + # this set are yielded. None = no filter (normal behaviour). self._eval_allow_list: Optional[set] = None def _get_deny_listed_uids(self, origin: str = None) -> set: @@ -143,7 +143,7 @@ def _get_deny_listed_uids(self, origin: str = None) -> set: if origin is not None: df_view = self.tracked_dataset._get_df_view(column='origin', value=origin) else: - df_view = self.tracked_dataset._get_df_view() # get all by default + df_view = self.tracked_dataset._get_df_view() # get all by default if not df_view.empty and SampleStatsEx.DISCARDED.value in df_view.columns: discarded_rows = df_view[df_view[SampleStatsEx.DISCARDED.value] == True] @@ -360,7 +360,7 @@ class DataLoaderInterface: from where manual next() left off, and after epoch exhaustion, both patterns automatically reset on the next iteration. 
- ✅ CORRECT usage - for-loops continue from manual next() position: + CORRECT usage - for-loops continue from manual next() position: loader = DataLoaderInterface(dataset, batch_size=32) @@ -779,12 +779,12 @@ def __iter__(self) -> Iterator: reset if we're already mid-epoch, which allows for-loops to continue from where manual next() calls left off. - ✅ CORRECT usage: + CORRECT usage: while step < max_steps: - data = next(loader) # Manual next (mid-epoch) + data = next(loader) # Manual next (mid-epoch) if step % 5 == 0: - for batch in loader: # For-loop continues from batch position - process(batch) # Gets remaining batches until epoch end + for batch in loader: # For-loop continues from batch position + process(batch) # Gets remaining batches until epoch end step += 1 How it works: @@ -828,7 +828,7 @@ def __next__(self) -> Any: try: data = next(loader) except StopIteration: - data = next(loader) # Auto-resets, returns first batch of next epoch + data = next(loader) # Auto-resets, returns first batch of next epoch 2. For-loop with proper termination: for batch in loader: @@ -837,9 +837,9 @@ def __next__(self) -> Any: 3. Mixed usage: while training: - data = next(loader) # Auto-resets as needed + data = next(loader) # Auto-resets as needed if should_eval(): - for batch in loader: # Gets remaining epoch, exits on StopIteration + for batch in loader: # Gets remaining epoch, exits on StopIteration eval(batch) """ self._sync_batch_size_from_ledger() @@ -921,7 +921,7 @@ def _next_batch(self) -> Any: batch = next(self._iterator) except StopIteration: if hasattr(self, 'is_a_loop') and self.is_a_loop: - raise # Re-raise so __next__() can handle epoch exhaustion + raise # Re-raise so __next__() can handle epoch exhaustion else: self._reset_iterator() batch = next(self._iterator) @@ -931,30 +931,30 @@ def _next_batch(self) -> Any: return batch # def _execute_offset(self) -> None: - # """ - # Execute sample offset if set, skipping samples as needed. 
- # This is a fallback mechanism for user-supplied dataloaders where - # we cannot use an OffsetSampler. - - # TODO (GP): - # We can reproduce the random generation of samples by restoring RNG state, if during the previous checkpoints, batchsize changed dynamically and shuffle is True. - # """ - # if self._sample_offset > 0: - # current_bs = self.get_batch_size() - # # Fast-forward the iterator by the offset amount - # while len(self._skipped) < self._sample_offset: - # try: - # bs = 4 if self._sample_offset - len(self._skipped) >= 4 else self._sample_offset - len(self._skipped) # Autoscale bs to sample offset - # self.set_batch_size(bs) - # self._skipped.extend(next(self._iterator)[1].detach().cpu().tolist()) - # logger.debug(f"Offset sampler: skipped {len(self._skipped)}/{self._sample_offset}") - # except StopIteration as e: - # logger.debug(f"Offset sampler: reached end of iterator while skipping: {e}") - # self._reset_iterator() # Reset iterator and try again - - # self.set_batch_size(current_bs) - # self._skipped = [] - # self._sample_offset = 0 + # """ + # Execute sample offset if set, skipping samples as needed. + # This is a fallback mechanism for user-supplied dataloaders where + # we cannot use an OffsetSampler. + + # TODO (GP): + # We can reproduce the random generation of samples by restoring RNG state, if during the previous checkpoints, batchsize changed dynamically and shuffle is True. 
+ # """ + # if self._sample_offset > 0: + # current_bs = self.get_batch_size() + # # Fast-forward the iterator by the offset amount + # while len(self._skipped) < self._sample_offset: + # try: + # bs = 4 if self._sample_offset - len(self._skipped) >= 4 else self._sample_offset - len(self._skipped) # Autoscale bs to sample offset + # self.set_batch_size(bs) + # self._skipped.extend(next(self._iterator)[1].detach().cpu().tolist()) + # logger.debug(f"Offset sampler: skipped {len(self._skipped)}/{self._sample_offset}") + # except StopIteration as e: + # logger.debug(f"Offset sampler: reached end of iterator while skipping: {e}") + # self._reset_iterator() # Reset iterator and try again + + # self.set_batch_size(current_bs) + # self._skipped = [] + # self._sample_offset = 0 def _should_persist_workers(self, num_workers: int) -> bool: """Whether to keep DataLoader workers alive across iterator resets. @@ -1025,7 +1025,7 @@ def _reset_iterator(self) -> None: # Only relevant when workers are actually respawning; persistent workers # stay alive, so no settle-delay is needed. if respawning: - time.sleep(0.01) # 10ms delay for worker cleanup + time.sleep(0.01) # 10ms delay for worker cleanup # Create new iterator self._iterator = iter(self.dataloader) @@ -1040,8 +1040,8 @@ def reset_iterator(self) -> None: rng_state = capture_rng_state() batch1 = next(dataloader_interface) restore_rng_state(rng_state) - dataloader_interface.reset_iterator() # Create new iterator with restored RNG - batch1_repeat = next(dataloader_interface) # Same batches! + dataloader_interface.reset_iterator() # Create new iterator with restored RNG + batch1_repeat = next(dataloader_interface) # Same batches! """ self._reset_iterator() diff --git a/weightslab/backend/explore_mode.py b/weightslab/backend/explore_mode.py new file mode 100644 index 00000000..4effe651 --- /dev/null +++ b/weightslab/backend/explore_mode.py @@ -0,0 +1,43 @@ +"""Process-wide read-only "explore" mode. 
+ +When the backend is launched to browse a finished experiment loaded from disk +(``weightslab --logdir ``), it runs in *explore mode*: there is no +training loop, and the experiment is reconstructed from the checkpoints/logs on +disk so a user can inspect it in the UI while training continues elsewhere +(e.g. on a cluster). + +In this mode the backend refuses the actions that would mutate the model or the +training run — starting/resuming training, changing hyperparameters, and +loading/restoring/saving weights or checkpoints. Local **data management** +(tagging, discarding, queries, plot notes) and all **reads** stay available, +since the whole point is to manage and explore the data locally. + +This is a simple process-wide flag: a given backend process is either a live +training server or a read-only explorer for its whole lifetime. +""" + +import logging + +logger = logging.getLogger(__name__) + +_EXPLORE_MODE = False + +# Returned by guarded RPC handlers when a forbidden (mutating) action is attempted. +EXPLORE_BLOCKED_MESSAGE = ( + "This experiment is open in read-only explore mode (loaded from --logdir). " + "Training, hyperparameter changes, and weight/checkpoint loading are disabled." +) + + +def set_explore_mode(enabled: bool) -> None: + """Enable/disable the process-wide read-only explore mode.""" + global _EXPLORE_MODE + _EXPLORE_MODE = bool(enabled) + logger.info( + "Explore (read-only) mode %s", "ENABLED" if _EXPLORE_MODE else "disabled" + ) + + +def is_explore_mode() -> bool: + """True when the backend is serving a read-only experiment from disk.""" + return _EXPLORE_MODE diff --git a/weightslab/backend/ledgers.py b/weightslab/backend/ledgers.py index eb188dde..45a64f42 100644 --- a/weightslab/backend/ledgers.py +++ b/weightslab/backend/ledgers.py @@ -231,15 +231,28 @@ def __contains__(self, item: Any) -> bool: except TypeError: return False + def __getitem__(self, key: Any) -> Any: + """Support subscript access: ``proxy[key]``. 
+ + Equivalent to ``.get(key)`` — reaches into the resolved value with + ``[]`` and wraps nested dicts in a fresh _ValueProxy so chaining + (e.g. ``proxy['dataset']['batch_size']``) keeps resolving live. + Missing keys raise ``KeyError``, matching standard subscript access. + """ + key = self._unwrap(key) + v = self._resolve() + if v is None: + raise KeyError(key) + value = v[key] + if isinstance(value, dict): + return Proxy._ValueProxy(Proxy(v), key) + return value + def __int__(self) -> int: return int(self._resolve()) def __float__(self) -> float: - v = self._resolve() - try: - return float(v) - except (TypeError, ValueError): - return None + return float(self._resolve()) def __index__(self) -> int: v = self._resolve() @@ -435,14 +448,19 @@ def get(self, ref=None, default=None, proxy: bool = True) -> Any: except for ``str`` values, which are returned raw (see below). """ if ref is not None: - if proxy: + value = self._obj.get(ref, default) if hasattr(self._obj, "get") else default + # Avoid wrapping a python list, dict, or any callable in a live proxy — + # only "simple" values become live key proxies. (The previous form put + # `not callable(...)` INSIDE the isinstance group, so callables — which + # are not lists/dicts — slipped through and got proxied.) + if proxy and not (isinstance(value, (list, dict)) or callable(value)): vp = Proxy._ValueProxy(self, ref, default) # str and torch.device are handed back as plain values (see - # _plain_get_value); other types (dict/list/int/float/...) stay + # _plain_get_value); other simple types (int/float/bool/...) stay # live proxies so studio edits keep tracking. 
plain = _plain_get_value(vp._resolve()) return vp if plain is _KEEP_AS_PROXY else plain - return self._obj.get(ref, default) + return value return self._obj if self._obj is not None else default def __getattr__(self, item): @@ -528,7 +546,7 @@ def __next__(self): return next(self._it) except StopIteration: # Let StopIteration propagate naturally - self._it.is_a_loop = False # Loop ends here + self._it.is_a_loop = False # Loop ends here raise except KeyError: # Quiet by default; only surface this diagnostic when the user @@ -711,6 +729,68 @@ def __exit__(self, exc_type, exc, tb): MutableMapping.register(Proxy) +def _register_yaml_representers() -> None: + """Teach PyYAML to dump ledger proxies as their underlying value. + + A hyperparameter handle returned by ``watch_or_edit(..., flag='hyperparameters')`` + is a live ``Proxy`` / ``Proxy._ValueProxy``. When such a value flows into a + library that serializes it (e.g. Ultralytics writes its run args to ``args.yaml`` + via ``yaml.safe_dump``), PyYAML has no representer for the proxy type and raises + ``RepresenterError`` — and because the proxy masquerades as its wrapped type via + ``__class__``, callers' ``isinstance(x, (int, str, ...))`` "stringify" guards skip + it too. Registering representers that emit the *resolved* value makes proxies + transparently serializable everywhere, so the live-proxy HP design stays + compatible with such libraries. Registered on both the default and Safe dumpers. + """ + def _represent_obj_proxy(dumper, data): + return dumper.represent_data(data.get()) # Proxy -> wrapped object + + def _represent_value_proxy(dumper, data): + return dumper.represent_data(data._resolve()) # _ValueProxy -> resolved value + + # Register on every dumper variant: the pure-Python Dumper/SafeDumper, the + # libyaml C dumpers (CDumper/CSafeDumper — Ultralytics dumps with CSafeDumper), + # and the base Representer/SafeRepresenter. 
Each keeps its own representer + # table, so registering on one does not cover the others. + for _name in ("Dumper", "SafeDumper", "CDumper", "CSafeDumper", "Representer", "SafeRepresenter"): + _dumper = getattr(yaml, _name, None) + if _dumper is not None: + _dumper.add_representer(Proxy, _represent_obj_proxy) + _dumper.add_representer(Proxy._ValueProxy, _represent_value_proxy) + + +def _register_json_default() -> None: + """Teach the stdlib ``json`` encoder to serialize ledger proxies as their + underlying value, mirroring :func:`_register_yaml_representers`. + + ``json`` has no global representer registry, so we wrap ``JSONEncoder.default`` + (the hook called for objects the encoder doesn't natively handle). The C + encoder dispatches by concrete type, so a proxy — whose real type is not int/ + str/dict/... despite its ``__class__`` masquerade — reaches ``default`` and is + replaced by its resolved value. Makes ``json.dumps``/``json.dump`` of HP + proxies work everywhere (e.g. audit/JSON config dumps) without per-call hooks. + """ + import json + + if getattr(json.JSONEncoder, "_wl_proxy_patched", False): + return + _orig_default = json.JSONEncoder.default + + def default(self, obj): + if isinstance(obj, Proxy._ValueProxy): + return obj._resolve() + if isinstance(obj, Proxy): + return obj.get() + return _orig_default(self, obj) + + json.JSONEncoder.default = default + json.JSONEncoder._wl_proxy_patched = True + + +_register_yaml_representers() +_register_json_default() + + class Ledger: """Thread-safe ledger storing named registries for different object types. 
@@ -1177,7 +1257,7 @@ def list_models() -> List[str]: def register_model(model: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_model(name) # Init empty proxy + get_model(name) # Init empty proxy GLOBAL_LEDGER.register_model(model, weak=weak, name=name) @@ -1195,7 +1275,7 @@ def get_dataloaders(names: Optional[List[str]] = None) -> Dict[str, Any]: def register_dataloaders(dataloaders: Dict[str, Any], weak: bool = False) -> None: """Register multiple dataloaders from a dict, e.g., {'train': train_loader, 'val': val_loader}.""" for k in dataloaders.keys(): - get_dataloader(k) # Init empty proxy - get_dataloaders(list(dataloaders.keys())) + get_dataloader(k) # Init empty proxy - get_dataloaders(list(dataloaders.keys())) GLOBAL_LEDGER.register_dataloaders_dict(dataloaders, weak=weak) def list_dataloaders() -> List[str]: @@ -1203,7 +1283,7 @@ def list_dataloaders() -> List[str]: def register_dataloader(dataloader: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_dataloader(name) # Init the empty proxy first + get_dataloader(name) # Init the empty proxy first GLOBAL_LEDGER.register_dataloader(dataloader, weak=weak, name=name) @@ -1219,7 +1299,7 @@ def list_optimizers() -> List[str]: def register_optimizer(optimizer: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_optimizer(name) # Init the empty proxy first + get_optimizer(name) # Init the empty proxy first GLOBAL_LEDGER.register_optimizer(optimizer, weak=weak, name=name) # Hyperparameters @@ -1242,7 +1322,7 @@ def resolve_hp_name() -> str | None: return 'experiment' # If we have any names at all, returning the first one is better than returning None # and causing a "Cannot resolve hyperparams name" error in the UI. 
- return names[-1] # first is empty proxy parameters generated at init + return names[-1] # first is empty proxy parameters generated at init def set_hyperparam(key_path: str, value: Any, name: str = DEFAULT_NAME) -> None: try: @@ -1258,14 +1338,14 @@ def unwatch_hyperparams_file(name: str = DEFAULT_NAME) -> None: def register_hyperparams(params: Dict[str, Any] = None, weak: bool = False, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_hyperparams(name) # Init empty proxy + get_hyperparams(name) # Init empty proxy GLOBAL_LEDGER.register_hyperparams(params, weak=weak, name=name) # Logger def register_logger(logger: Any = None, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_logger(name) # Init empty proxy + get_logger(name) # Init empty proxy GLOBAL_LEDGER.register_logger(logger, name=name) def get_logger(name: str = DEFAULT_NAME) -> Any: @@ -1281,7 +1361,7 @@ def unregister_logger(name: str = DEFAULT_NAME) -> None: # Signals def register_signal(signal: Any = None, name: str = DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_signal(name) # Init empty proxy + get_signal(name) # Init empty proxy GLOBAL_LEDGER.register_signal(signal, name=name) def get_signal(name: str = DEFAULT_NAME) -> Any: @@ -1297,7 +1377,7 @@ def unregister_signal(name: str = DEFAULT_NAME) -> None: # Checkpoint managers def register_checkpoint_manager(manager: Any = None, weak: bool = False, name: str = DEFAULT_NAME) -> Any: name = DEFAULT_NAME if name is None else name - get_checkpoint_manager(name) # Init empty proxy + get_checkpoint_manager(name) # Init empty proxy return GLOBAL_LEDGER.register_checkpoint_manager(manager, weak=weak, name=name) def get_checkpoint_manager(name: str = DEFAULT_NAME) -> Any: @@ -1313,7 +1393,7 @@ def unregister_checkpoint_manager(name: str = DEFAULT_NAME) -> None: # DataFrames def register_dataframe(dataframe: Any = None, weak: bool = False, name: str = 
DEFAULT_NAME) -> None: name = DEFAULT_NAME if name is None else name - get_dataframe(name) # Init empty proxy + get_dataframe(name) # Init empty proxy return GLOBAL_LEDGER.register_dataframe(dataframe, weak=weak, name=name) def get_dataframe(name: str = DEFAULT_NAME) -> Any: diff --git a/weightslab/backend/logger.py b/weightslab/backend/logger.py index c4a780c9..dc077e6b 100644 --- a/weightslab/backend/logger.py +++ b/weightslab/backend/logger.py @@ -1,69 +1,88 @@ -import torch as th +"""DuckDB-backed signal history logger. + +``LoggerQueue`` is a thin interface that maps the logger's public methods onto +a DuckDB database holding three history tables: + +* ``signals`` — aggregated training-curve points (one row per averaged + step entry / evaluation marker). +* ``per_sample`` — per-sample signal values ``(sample_id, step, value)``. +* ``per_instance`` — per-instance values ``(sample_id, annotation_id, step, value)`` + for detection / segmentation. + +Design notes +------------ +* **Hot path is RAM, reads hit DuckDB.** ``add_scalars`` / + ``add_instance_scalars`` only append to in-memory staging lists (O(1), no SQL). + Rows are bulk-inserted into DuckDB lazily — right before any query, snapshot, + delete or update — via a single vectorized ``INSERT ... SELECT``. This keeps + per-step logging cheap while letting DuckDB do the heavy aggregation + (``GROUP BY step`` over millions of rows) in native code — exactly what + break-by-slices needs. +* **Transient runtime state stays in Python.** The live-streaming pending queue, + the per-step aggregation buffer and the evaluation accumulator are small and + short-lived, so they remain plain Python structures. +* **Persistence.** ``db_path`` defaults to ``":memory:"``. Pass a file path to + back the history with an on-disk DuckDB file. Either way ``save_snapshot`` / + ``load_snapshot`` round-trip the full history as a plain dict, so the + checkpoint manager's snapshotting is unchanged. 
+* **Thread-safety.** A single DuckDB connection is guarded by an ``RLock``; + staging appends and flushes take the same lock. +""" + +import json +import threading import time -from array import array as _array -from copy import deepcopy - -from weightslab.backend.ledgers import get_logger, register_logger, get_checkpoint_manager - - -def _make_per_sample_buf(): - """Compact storage for per-sample signals: three typed C arrays. - Uses array.array instead of a list of dicts to reduce memory by ~20-40x: - - list of dicts: ~400-600 bytes/entry (Python dict overhead + 6 string keys) - - compact arrays: 12 bytes/entry (int32 + int32 + float32) +import duckdb +import pandas as pd +import torch as th - Fields: - sample_ids: list of str - dataset sample index - steps: signed int32 - global training step - values: float32 - signal value at that step for that sample - """ - return { - "sample_ids": [], # str - "steps": _array('i'), # int32, 4 bytes each - "values": _array('f'), # float32, 4 bytes each - } +from weightslab.backend.ledgers import get_logger, register_logger, get_checkpoint_manager -def _make_per_instance_buf(): - """Compact storage for per-instance signals: four typed C arrays. +# Column order for each table's staging buffer / bulk insert. 
+_SIGNAL_COLS = [ + "metric_name", "experiment_hash", "step", "metric_value", "timestamp", + "audit_mode", "is_evaluation_marker", "split_name", "evaluation_tags", + "point_note", "seq", +] +_SAMPLE_COLS = ["metric_name", "experiment_hash", "sample_id", "step", "value", "seq"] +_INSTANCE_COLS = [ + "metric_name", "experiment_hash", "sample_id", "annotation_id", "step", "value", "seq", +] - Fields: - sample_ids: list of str - dataset sample index - annotation_ids: signed int32 - instance index within sample (1-based) - steps: signed int32 - global training step - values: float32 - signal value at that step for that instance - """ - return { - "sample_ids": [], # str - "annotation_ids": _array('i'), # int32, 4 bytes each - "steps": _array('i'), # int32, 4 bytes each - "values": _array('f'), # float32, 4 bytes each - } +# Auto-flush staged rows to DuckDB once the combined staging buffers exceed this +# many rows, to bound memory during long runs that never read history. +_STAGE_FLUSH_THRESHOLD = 50_000 class LoggerQueue: - def __init__(self, register: bool = True) -> None: + def __init__(self, register: bool = True, db_path: str = ":memory:") -> None: self.graph_names = set() self._current_step_buffer = {} self._last_step = None - self._signal_history = {} # Keep all signals in memory for persistence - self._signal_history_per_sample = {} # Keep all signals per sample in memory for persistence - self._signal_history_per_instance = {} # Keep all signals per instance in memory for persistence - # Reverse indices: O(1) lookup by sample_id or (sample_id, annotation_id) - # Structure: {graph_name: {exp_hash: {sample_id: [row_indices]}}} - self._sample_index = {} - # Structure: {graph_name: {exp_hash: {(sample_id, annotation_id): [row_indices]}}} - self._instance_index = {} - self._pending_queue = [] # Queue for new signals waiting to be sent to WeightsStudio + + # Live-streaming queue of new points waiting to be sent to WeightsStudio. 
+ self._pending_queue = [] self._buffered_step = None - # Evaluation mode state + # Evaluation mode state (transient). self._eval_mode_active: bool = False self._eval_mode_hash: str = "" self._eval_mode_split: str = "" self._eval_mode_tags: list[str] = [] - self._eval_accum: dict = {} # {graph_name: [sum, count]} + self._eval_accum: dict = {} # {graph_name: [sum, count]} + + # DuckDB connection + write-staging buffers. + self._lock = threading.RLock() + self._db_path = db_path + self._conn = duckdb.connect(database=db_path) + self._stage_signals: list = [] + self._stage_sample: list = [] + self._stage_instance: list = [] + self._seq = 0 + self._ensure_tables() + self._restore_runtime_state_from_db() lg = None if register: @@ -76,27 +95,155 @@ def __init__(self, register: bool = True) -> None: # Init checkpoint manager for experiment hash retrieval (if available) self.chkpt_manager = get_checkpoint_manager() + # ------------------------------------------------------------------ + # DuckDB plumbing + # ------------------------------------------------------------------ + def _ensure_tables(self) -> None: + with self._lock: + self._conn.execute( + """ + CREATE TABLE IF NOT EXISTS signals ( + metric_name VARCHAR, + experiment_hash VARCHAR, + step INTEGER, + metric_value DOUBLE, + timestamp BIGINT, + audit_mode BOOLEAN, + is_evaluation_marker BOOLEAN, + split_name VARCHAR, + evaluation_tags VARCHAR, + point_note VARCHAR, + seq BIGINT + ); + CREATE TABLE IF NOT EXISTS per_sample ( + metric_name VARCHAR, + experiment_hash VARCHAR, + sample_id VARCHAR, + step INTEGER, + value REAL, + seq BIGINT + ); + CREATE TABLE IF NOT EXISTS per_instance ( + metric_name VARCHAR, + experiment_hash VARCHAR, + sample_id VARCHAR, + annotation_id INTEGER, + step INTEGER, + value REAL, + seq BIGINT + ); + """ + ) + + def _restore_runtime_state_from_db(self) -> None: + """Repopulate seq counter and graph names from an existing (file) DB.""" + with self._lock: + max_seq = self._conn.execute( + 
""" + SELECT max(m) FROM ( + SELECT max(seq) AS m FROM signals + UNION ALL SELECT max(seq) FROM per_sample + UNION ALL SELECT max(seq) FROM per_instance + ) + """ + ).fetchone()[0] + self._seq = (int(max_seq) + 1) if max_seq is not None else 0 + + for tbl in ("signals", "per_sample", "per_instance"): + for (name,) in self._conn.execute( + f"SELECT DISTINCT metric_name FROM {tbl}" + ).fetchall(): + if name is not None: + self.graph_names.add(name) + + def _next_seq(self) -> int: + s = self._seq + self._seq += 1 + return s + + def _maybe_autoflush(self) -> None: + if (len(self._stage_signals) + len(self._stage_sample) + + len(self._stage_instance)) >= _STAGE_FLUSH_THRESHOLD: + self._flush_stage() + + def _flush_stage(self) -> None: + """Bulk-insert all staged rows into DuckDB and clear the buffers.""" + with self._lock: + if self._stage_signals: + df = pd.DataFrame(self._stage_signals, columns=_SIGNAL_COLS) + self._conn.register("_stg_sig", df) + self._conn.execute("INSERT INTO signals SELECT * FROM _stg_sig") + self._conn.unregister("_stg_sig") + self._stage_signals = [] + if self._stage_sample: + df = pd.DataFrame(self._stage_sample, columns=_SAMPLE_COLS) + self._conn.register("_stg_ps", df) + self._conn.execute("INSERT INTO per_sample SELECT * FROM _stg_ps") + self._conn.unregister("_stg_ps") + self._stage_sample = [] + if self._stage_instance: + df = pd.DataFrame(self._stage_instance, columns=_INSTANCE_COLS) + self._conn.register("_stg_pi", df) + self._conn.execute("INSERT INTO per_instance SELECT * FROM _stg_pi") + self._conn.unregister("_stg_pi") + self._stage_instance = [] + + def _stage_signal_row(self, graph_name, exp_hash, step, metric_value, timestamp, + audit_mode, is_marker, split_name, eval_tags, point_note): + self._stage_signals.append(( + graph_name, exp_hash, int(step), float(metric_value), int(timestamp), + bool(audit_mode), bool(is_marker), split_name or "", + json.dumps(list(eval_tags or [])), point_note or "", self._next_seq(), + )) + 
self._maybe_autoflush() + + def _stage_sample_row(self, graph_name, exp_hash, sample_id, step, value): + self._stage_sample.append(( + graph_name, exp_hash, str(sample_id), int(step), float(value), self._next_seq(), + )) + self._maybe_autoflush() + + def _stage_instance_row(self, graph_name, exp_hash, sample_id, annotation_id, step, value): + self._stage_instance.append(( + graph_name, exp_hash, str(sample_id), int(annotation_id), int(step), + float(value), self._next_seq(), + )) + self._maybe_autoflush() + + @staticmethod + def _hash_filter(exp_hash, params, table_alias=""): + """Append an experiment-hash WHERE fragment. ``None`` means 'all hashes'.""" + if exp_hash is None: + return "" + params.append(exp_hash) + col = f"{table_alias}experiment_hash" if table_alias else "experiment_hash" + return f" AND {col} = ?" + def __len__(self): - """Return logger length.""" - len_history = 0 - for k in self._signal_history: - for exp_hash in self._signal_history[k]: - l = len(self._signal_history[k][exp_hash]) - len_history = max(len_history, l) - return len_history - - # Clear history method (can be called by WeightsLabCallback at the start of a new experiment to reset state, - # while preserving graph names which are derived from signals and may be needed for future signals after clearing history) + """Max number of distinct steps recorded for any (metric, hash) curve.""" + with self._lock: + self._flush_stage() + row = self._conn.execute( + """ + SELECT max(cnt) FROM ( + SELECT count(DISTINCT step) AS cnt + FROM signals GROUP BY metric_name, experiment_hash + ) + """ + ).fetchone() + return int(row[0]) if row and row[0] is not None else 0 + def clear_signal_histories(self): - """Clear signal histories.""" - # Note: We do not clear graph names here as they are derived from signals and may be needed for future signals after clearing history. 
- self._signal_history.clear() - self._signal_history_per_sample.clear() - self._signal_history_per_instance.clear() - self._sample_index.clear() - self._instance_index.clear() - self._current_step_buffer.clear() - self._buffered_step = None + """Clear all signal histories (keeps graph names and runtime buffers reset).""" + with self._lock: + self._stage_signals = [] + self._stage_sample = [] + self._stage_instance = [] + self._conn.execute("DELETE FROM signals") + self._conn.execute("DELETE FROM per_sample") + self._conn.execute("DELETE FROM per_instance") + self._current_step_buffer.clear() + self._buffered_step = None def _to_float(self, value): if isinstance(value, th.Tensor): @@ -111,7 +258,6 @@ def _get_audit_mode(self): 2. Check hyperparams auditor_mode (fallback for legacy/hyperparams-based control) """ try: - # First priority: check registered model interface from weightslab.backend.ledgers import get_model model = get_model() if model is not None and hasattr(model, 'audit_mode'): @@ -120,7 +266,6 @@ def _get_audit_mode(self): pass try: - # Fallback: check hyperparams auditor_mode from weightslab.backend.ledgers import get_hyperparams hp = get_hyperparams() if hp is not None: @@ -129,27 +274,33 @@ def _get_audit_mode(self): pass return False - def _append_history_entry(self, graph_name, exp_hash, global_step, metric_value, audit_mode=None): + def _append_history_entry(self, graph_name, exp_hash, global_step, metric_value, + audit_mode=None, is_marker=False, split_name="", + evaluation_tags=None): + """Stage a signals row and return the live-queue entry dict.""" if audit_mode is None: audit_mode = self._get_audit_mode() + timestamp = int(time.time()) signal_entry = { "model_age": global_step, "metric_name": graph_name, "metric_value": metric_value, "experiment_hash": exp_hash, - "timestamp": int(time.time()), + "timestamp": timestamp, "audit_mode": audit_mode, } - - if graph_name not in self._signal_history: - self._signal_history[graph_name] = {} - if 
exp_hash not in self._signal_history[graph_name]: - self._signal_history[graph_name][exp_hash] = {} - if global_step not in self._signal_history[graph_name][exp_hash]: - self._signal_history[graph_name][exp_hash][global_step] = [] - - self._signal_history[graph_name][exp_hash][global_step].append(signal_entry) + if is_marker: + signal_entry["is_evaluation_marker"] = True + signal_entry["split_name"] = split_name + signal_entry["evaluation_tags"] = list(evaluation_tags or []) + + with self._lock: + self._stage_signal_row( + graph_name, exp_hash, global_step, metric_value, timestamp, + bool(audit_mode), bool(is_marker), split_name, + list(evaluation_tags or []), "", + ) return signal_entry def _flush_current_step_buffer(self, add_to_queue: bool): @@ -179,36 +330,38 @@ def _flush_current_step_buffer(self, add_to_queue: bool): def get_next_evaluation_count(self, base_hash: str) -> int: """Return the next unused evaluation index for *base_hash*. - Scans the current signal history for keys of the form + Scans recorded experiment hashes for keys of the form ``_`` and returns max(found) + 1 (or 1 if none). 
""" prefix = base_hash + "_" max_count = 0 - for gname in self._signal_history: - for hash_key in self._signal_history[gname]: - if isinstance(hash_key, str) and hash_key.startswith(prefix): - suffix = hash_key[len(prefix):] - try: - count = int(suffix) - if count > max_count: - max_count = count - except ValueError: - pass + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT DISTINCT experiment_hash FROM signals " + "WHERE experiment_hash LIKE ?", + [prefix + "%"], + ).fetchall() + for (hash_key,) in rows: + if isinstance(hash_key, str) and hash_key.startswith(prefix): + suffix = hash_key[len(prefix):] + try: + count = int(suffix) + if count > max_count: + max_count = count + except ValueError: + pass return max_count + 1 def start_evaluation_mode(self, split_name: str, eval_hash: str, evaluation_tags=None) -> None: """Redirect subsequent add_scalars() calls into the evaluation buffer. While evaluation mode is active, signals are NOT added to the normal - curve history. Instead they accumulate in an internal buffer. + curve history. Instead they accumulate in an internal buffer. ``stop_evaluation_mode()`` finalises the buffer into a single marker. Per-sample history *is* still updated (for Break-By-Slice on eval results), using *eval_hash* as the experiment key. - - Args: - split_name: Human-readable split name (e.g. ``"train_loader"``). - eval_hash: Modified experiment hash (e.g. ``"abc123_1"``). """ self._flush_current_step_buffer(add_to_queue=True) self._eval_mode_active = True @@ -225,9 +378,6 @@ def stop_evaluation_mode(self, model_age: int) -> dict: history under *eval_hash* and into the pending queue, then resets evaluation-mode state. - Args: - model_age: Current model age (training step) at time of evaluation. - Returns: Dict mapping graph_name → averaged value for all signals seen. 
""" @@ -248,26 +398,16 @@ def stop_evaluation_mode(self, model_age: int) -> dict: results[graph_name] = avg self.graph_names.add(graph_name) - # Store in signal history under eval_hash - if graph_name not in self._signal_history: - self._signal_history[graph_name] = {} - if eval_hash not in self._signal_history[graph_name]: - self._signal_history[graph_name][eval_hash] = {} - if model_age not in self._signal_history[graph_name][eval_hash]: - self._signal_history[graph_name][eval_hash][model_age] = [] - - entry = { - "model_age": model_age, - "metric_name": graph_name, - "metric_value": avg, - "experiment_hash": eval_hash, - "timestamp": int(time.time()), - "is_evaluation_marker": True, - "split_name": split_name, - "evaluation_tags": evaluation_tags, - "audit_mode": audit_mode, - } - self._signal_history[graph_name][eval_hash][model_age].append(entry) + entry = self._append_history_entry( + graph_name=graph_name, + exp_hash=eval_hash, + global_step=model_age, + metric_value=avg, + audit_mode=audit_mode, + is_marker=True, + split_name=split_name, + evaluation_tags=evaluation_tags, + ) self._pending_queue.append(entry) self._eval_accum = {} @@ -277,12 +417,7 @@ def stop_evaluation_mode(self, model_age: int) -> dict: return results def abort_evaluation_mode(self) -> None: - """Abort evaluation mode and drop all in-progress evaluation data. - - This is used when an evaluation is canceled or timed out. - It clears the accumulation buffer and removes any per-sample history - that may have been written under the in-flight evaluation hash. - """ + """Abort evaluation mode and drop all in-progress evaluation data.""" if not self._eval_mode_active: return @@ -304,19 +439,10 @@ def remove_evaluation_hash(self, eval_hash: str) -> None: if not eval_hash: return - # Remove any marker/history entries tied to the evaluation hash. 
- for graph_name in list(self._signal_history.keys()): - try: - self._signal_history[graph_name].pop(eval_hash, None) - except Exception: - pass - - # Remove per-sample traces recorded under the same hash. - for graph_name in list(self._signal_history_per_sample.keys()): - try: - self._signal_history_per_sample[graph_name].pop(eval_hash, None) - except Exception: - pass + with self._lock: + self._flush_stage() + self._conn.execute("DELETE FROM signals WHERE experiment_hash = ?", [eval_hash]) + self._conn.execute("DELETE FROM per_sample WHERE experiment_hash = ?", [eval_hash]) # Drop queued points that reference this hash. self._pending_queue = [ @@ -335,243 +461,270 @@ def add_scalars(self, graph_name, signal, global_step, signal_per_sample, aggreg - Evaluation mode active: accumulate into internal buffer; per-sample history still gets written under the eval hash for Break-By-Slice support. """ - self.graph_names.add(graph_name) - self._last_step = global_step - - # ---------------------------------------------------------------- - # Evaluation-mode interception - # ---------------------------------------------------------------- - if self._eval_mode_active: - # Collect scalar values to accumulate - values: list = [] - if aggregate_by_step and signal_per_sample and isinstance(signal_per_sample, dict): - values = [self._to_float(v) for v in signal_per_sample.values()] - elif signal and isinstance(signal, dict): - values = [self._to_float(v) for _, v in signal.items()] - - if values: - if graph_name not in self._eval_accum: - self._eval_accum[graph_name] = [0.0, 0] - self._eval_accum[graph_name][0] += sum(values) - self._eval_accum[graph_name][1] += len(values) - - # Still store per-sample signals under eval_hash (for Break-By-Slice) - if signal_per_sample and isinstance(signal_per_sample, dict): - eval_hash = self._eval_mode_hash - if graph_name not in self._signal_history_per_sample: - self._signal_history_per_sample[graph_name] = {} - if eval_hash not in 
self._signal_history_per_sample[graph_name]: - self._signal_history_per_sample[graph_name][eval_hash] = _make_per_sample_buf() - buf = self._signal_history_per_sample[graph_name][eval_hash] + with self._lock: + self.graph_names.add(graph_name) + self._last_step = global_step + + # ------------------------------------------------------------ + # Evaluation-mode interception + # ------------------------------------------------------------ + if self._eval_mode_active: + values: list = [] + if aggregate_by_step and signal_per_sample and isinstance(signal_per_sample, dict): + values = [self._to_float(v) for v in signal_per_sample.values()] + elif signal and isinstance(signal, dict): + values = [self._to_float(v) for _, v in signal.items()] + + if values: + if graph_name not in self._eval_accum: + self._eval_accum[graph_name] = [0.0, 0] + self._eval_accum[graph_name][0] += sum(values) + self._eval_accum[graph_name][1] += len(values) + + # Still store per-sample signals under eval_hash (for Break-By-Slice) + if signal_per_sample and isinstance(signal_per_sample, dict): + eval_hash = self._eval_mode_hash + step_i = int(global_step) + for sid, value in signal_per_sample.items(): + self._stage_sample_row(graph_name, eval_hash, sid, step_i, self._to_float(value)) + + return # Do NOT add to normal history during evaluation mode + # ------------------------------------------------------------ + + exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager else None + + if self._buffered_step is not None and global_step != self._buffered_step: + self._flush_current_step_buffer(add_to_queue=True) + + if not aggregate_by_step and self._current_step_buffer: + self._flush_current_step_buffer(add_to_queue=True) + + # Update per-sample signal history + if isinstance(signal_per_sample, dict) and len(signal_per_sample): step_i = int(global_step) - idx_map = self._sample_index.setdefault(graph_name, {}).setdefault(eval_hash, {}) for sid, value in 
signal_per_sample.items(): - row = len(buf["sample_ids"]) - buf["sample_ids"].append(sid) - buf["steps"].append(step_i) - buf["values"].append(self._to_float(value)) - idx_map.setdefault(str(sid), []).append(row) - - return # Do NOT add to normal history during evaluation mode - # ---------------------------------------------------------------- + self._stage_sample_row(graph_name, exp_hash, sid, step_i, self._to_float(value)) - exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager else None - - if self._buffered_step is not None and global_step != self._buffered_step: - self._flush_current_step_buffer(add_to_queue=True) + metric_values = [] + if isinstance(signal_per_sample, dict) and aggregate_by_step and len(signal_per_sample): + for value in signal_per_sample.values(): + metric_values.append(self._to_float(value)) + else: + for _, line_value in signal.items(): + metric_values.append(self._to_float(line_value)) + + if aggregate_by_step: + if metric_values: + self._buffered_step = global_step + buffer_key = (global_step, graph_name, exp_hash) + if buffer_key not in self._current_step_buffer: + self._current_step_buffer[buffer_key] = {"sum": 0.0, "count": 0} + self._current_step_buffer[buffer_key]["sum"] += sum(metric_values) + self._current_step_buffer[buffer_key]["count"] += len(metric_values) + return + + # Update averaged signal history immediately. Only emit when we have at + # least one valid metric value (signals carrying only per-sample data are + # stored separately in per_sample). 
+ signal_entry = None + if len(metric_values) > 0: + signal_entry = self._append_history_entry( + graph_name=graph_name, + exp_hash=exp_hash, + global_step=global_step, + metric_value=sum(metric_values) / len(metric_values) if len(metric_values) > 1 else metric_values[0], + ) + + if signal_entry is not None: + self._pending_queue.append(signal_entry) - if not aggregate_by_step and self._current_step_buffer: - self._flush_current_step_buffer(add_to_queue=True) + def ingest_per_sample(self, graph_name: str, exp_hash, triples) -> None: + """Insert per-sample ``(sample_id, step, value)`` triples, de-duplicating + on ``(sample_id, step)`` within ``(graph_name, exp_hash)``. - # Update per-sample signal history with compact array storage - if isinstance(signal_per_sample, dict) and len(signal_per_sample): - if graph_name not in self._signal_history_per_sample: - self._signal_history_per_sample[graph_name] = {} - if exp_hash not in self._signal_history_per_sample[graph_name]: - self._signal_history_per_sample[graph_name][exp_hash] = _make_per_sample_buf() + Unlike ``add_scalars`` (which always appends), this is idempotent on the + ``(sample_id, step)`` key: the first value wins and later duplicates are + ignored. Useful for back-filling / importing history without creating + repeated points. 
- buf = self._signal_history_per_sample[graph_name][exp_hash] - step_i = int(global_step) - idx_map = self._sample_index.setdefault(graph_name, {}).setdefault(exp_hash, {}) - for sid, value in signal_per_sample.items(): - row = len(buf["sample_ids"]) - buf["sample_ids"].append(sid) - buf["steps"].append(step_i) - buf["values"].append(self._to_float(value)) - idx_map.setdefault(str(sid), []).append(row) - - metric_values = [] - if isinstance(signal_per_sample, dict) and aggregate_by_step and len(signal_per_sample): - for value in signal_per_sample.values(): - metric_values.append(self._to_float(value)) - else: - for _, line_value in signal.items(): - metric_values.append(self._to_float(line_value)) - - if aggregate_by_step: - if metric_values: - self._buffered_step = global_step - buffer_key = (global_step, graph_name, exp_hash) - if buffer_key not in self._current_step_buffer: - self._current_step_buffer[buffer_key] = {"sum": 0.0, "count": 0} - self._current_step_buffer[buffer_key]["sum"] += sum(metric_values) - self._current_step_buffer[buffer_key]["count"] += len(metric_values) + Args: + graph_name: Signal name. + exp_hash: Experiment hash (``None`` allowed). + triples: Iterable of ``(sample_id, step, value)``. + """ + triples = list(triples) + if not triples: return - # Update averaged signal history immediately - signal_entry = None + with self._lock: + self.graph_names.add(graph_name) + self._flush_stage() - # Only add to history if we have at least one valid metric value (otherwise we may end up with empty/invalid entries from signals that only contain per-sample values, which are stored separately in _signal_history_per_sample) - if len(metric_values) > 0: - signal_entry = self._append_history_entry( - graph_name=graph_name, - exp_hash=exp_hash, - global_step=global_step, - metric_value=sum(metric_values) / len(metric_values) if len(metric_values) > 1 else metric_values[0], - ) + # Existing (sample_id, step) keys for this (graph, hash). 
+ params = [graph_name] + sql = "SELECT sample_id, step FROM per_sample WHERE metric_name = ?" + sql += self._hash_filter(exp_hash, params) + seen = {(str(s), int(t)) for s, t in self._conn.execute(sql, params).fetchall()} - # Add signal to pending queue for live incremental update to WeightsStudio - if signal_entry is not None: - self._pending_queue.append(signal_entry) + for sid, step, value in triples: + key = (str(sid), int(step)) + if key in seen: + continue + seen.add(key) + self._stage_sample_row(graph_name, exp_hash, sid, step, self._to_float(value)) - # Print methods for debugging/inspection of logger state + # ------------------------------------------------------------------ + # Print helpers (debug) + # ------------------------------------------------------------------ def print_history(self): - """Print all items in history.""" - for metric_name, experiments in self._signal_history.items(): + history = self.get_signal_history() + for metric_name, experiments in history.items(): print(f"Metric: {metric_name}") for exp_hash, steps in experiments.items(): - print(f" Experiment Hash: {exp_hash}") + print(f" Experiment Hash: {exp_hash}") for step, signals in steps.items(): - print(f" Step: {step}") + print(f" Step: {step}") for signal in signals: - print(f" Signal: {signal}") - return self._signal_history + print(f" Signal: {signal}") + return history def print_history_per_sample(self): - """Print all items in per-sample history.""" - for metric_name, exps in self._signal_history_per_sample.items(): + history = self.get_signal_history_per_sample() + for metric_name, exps in history.items(): print(f"Metric: {metric_name}") - for exp_hash, buf in exps.items(): - print(f" Experiment Hash: {exp_hash}") - for sid, step, val in zip(buf["sample_ids"], buf["steps"], buf["values"]): - print(f" Sample ID: {sid}, Step: {step}, Value: {val}") - return self._signal_history_per_sample + for exp_hash, entries in exps.items(): + print(f" Experiment Hash: {exp_hash}") + for 
e in entries: + print(f" Sample ID: {e['sample_id']}, Step: {e['model_age']}, Value: {e['metric_value']}") + return history def print_buffer(self): - """Print current step buffer contents.""" print(f"Current step: {self._last_step}") print(f"Buffered metrics: {self._current_step_buffer}") return self._current_step_buffer - # Accessor methods for retrieving logger state (e.g. for checkpoint saving or programmatic access) + # ------------------------------------------------------------------ + # Accessors + # ------------------------------------------------------------------ def get_graph_names(self): - """ - Get list of all graph names encountered in signals. - Returns: - List of graph names. - """ + """Get list of all graph names encountered in signals.""" return list(self.graph_names) + def list_sample_signal_names(self) -> list: + """Distinct signal names that have per-sample history.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute("SELECT DISTINCT metric_name FROM per_sample").fetchall() + return [r[0] for r in rows] + + def list_instance_signal_names(self) -> list: + """Distinct signal names that have per-instance history.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute("SELECT DISTINCT metric_name FROM per_instance").fetchall() + return [r[0] for r in rows] + def get_signal_history(self): - """Retrieve all accumulated signals from memory.""" - # self._flush_current_step_buffer(add_to_queue=False) # History should already be up to date since we flush on step change and on add_scalars when not aggregating by step, but we can flush here as well to be safe before retrieving history for checkpoint saving - return deepcopy(self._signal_history) + """Reconstruct aggregated history as ``{metric: {hash: {step: [entry, ...]}}}``.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + """ + SELECT metric_name, experiment_hash, step, metric_value, timestamp, + audit_mode, is_evaluation_marker, split_name, 
evaluation_tags, point_note + FROM signals ORDER BY seq + """ + ).fetchall() + + result: dict = {} + for (metric, h, step, val, ts, audit, marker, split, tags, note) in rows: + entry = { + "model_age": step, + "metric_name": metric, + "metric_value": val, + "experiment_hash": h, + "timestamp": int(ts) if ts is not None else 0, + "audit_mode": bool(audit), + "is_evaluation_marker": bool(marker), + "split_name": split or "", + "evaluation_tags": json.loads(tags) if tags else [], + } + if note: + entry["point_note"] = note + result.setdefault(metric, {}).setdefault(h, {}).setdefault(step, []).append(entry) + return result def get_current_signaL_history(self, graph_name: str, meta: bool = False): - """Get current history for a specific signal.""" - if graph_name not in self._signal_history: + """Get current-hash aggregated history for a specific signal.""" + if graph_name not in self.graph_names: return {} - # Get Current Hash exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager else None - # Process history + with self._lock: + self._flush_stage() + params = [graph_name] + sql = "SELECT step, metric_value FROM signals WHERE metric_name = ?" 
+ sql += self._hash_filter(exp_hash, params) + sql += " ORDER BY seq" + rows = self._conn.execute(sql, params).fetchall() + if meta: - return self._signal_history.get(graph_name, {}).get(exp_hash, {}) - else: - history = self._signal_history.get(graph_name, {}).get(exp_hash, {}) - result = [] - for _, entries in history.items(): - for entry in entries: - result.append({ - "model_age": entry.get("model_age"), - "metric_value": entry.get("metric_value"), - }) - return result + steps: dict = {} + for step, val in rows: + steps.setdefault(step, []).append({ + "model_age": step, "metric_value": val, + }) + return steps + + return [{"model_age": step, "metric_value": val} for step, val in rows] def get_signal_history_per_sample(self): - """Reconstruct per-sample history as list-of-dicts from compact array storage.""" - result = {} - for graph_name, exps in self._signal_history_per_sample.items(): - result[graph_name] = {} - for exp_hash, buf in exps.items(): - entries = [] - for sid, step, val in zip(buf["sample_ids"], buf["steps"], buf["values"]): - entries.append({ - "sample_id": sid, - "model_age": step, - "metric_name": graph_name, - "metric_value": float(val), - "experiment_hash": exp_hash, - }) - result[graph_name][exp_hash] = entries + """Per-sample history as ``{metric: {hash: [entry, ...]}}``.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT metric_name, experiment_hash, sample_id, step, value " + "FROM per_sample ORDER BY seq" + ).fetchall() + + result: dict = {} + for (metric, h, sid, step, val) in rows: + result.setdefault(metric, {}).setdefault(h, []).append({ + "sample_id": sid, + "model_age": step, + "metric_name": metric, + "metric_value": float(val), + "experiment_hash": h, + }) return result def get_current_signaL_history_per_sample(self, graph_name: str, sample_ids: list = None, exp_hash: str = None): - """Get current history for a specific signal.""" - if graph_name not in self._signal_history: + """Get current-hash 
per-sample history for a specific signal.""" + if graph_name not in self.graph_names: return {} - # Get Current Hash exp_hash = self.chkpt_manager.get_current_experiment_hash() if self.chkpt_manager and exp_hash is None else exp_hash - - # Return history for the specified graph name, filtered by sample IDs and experiment hash if provided. If meta=True, returns raw history dict; otherwise returns list of (sample_id, step, value) tuples. - result = self.query_per_sample( - graph_name, - sample_ids=sample_ids, - exp_hash=exp_hash - ) - return result + return self.query_per_sample(graph_name, sample_ids=sample_ids, exp_hash=exp_hash) def query_per_sample(self, graph_name: str, sample_ids=None, exp_hash=None): - """Efficiently query per-sample history for specific sample IDs. + """Query per-sample history. - Returns a dict mapping sample_id → list of {model_age, signal_value} dicts, - filtered by sample_ids and optionally by experiment hash. - Much faster than get_signal_history_per_sample() for targeted queries - (e.g., "show me only samples with label 8"). - - Args: - graph_name: Signal name (e.g., "loss", "accuracy"). - sample_ids: Collection of sample IDs to filter by. If None, returns all. - exp_hash: Specific experiment hash to query. If None, queries all hashes. - - Returns: - List of (sample_id, step, value, experiment_hash) tuples. + Returns a list of ``(sample_id, step, value, experiment_hash)`` tuples, + filtered by *sample_ids* and optionally *exp_hash* (``None`` = all hashes). """ - if graph_name not in self._signal_history_per_sample: - return [] - - exps = self._signal_history_per_sample[graph_name] - hashes = [exp_hash] if exp_hash is not None else list(exps.keys()) - # Stored ids are ints; callers pass str (df index is str-normalized) — compare as str. 
- sid_set = {str(s) for s in sample_ids} if sample_ids is not None else None - - results = [] - for h in hashes: - buf = exps.get(h) - if buf is None: - continue - if sid_set is None: - for sid, step, val in zip(buf["sample_ids"], buf["steps"], buf["values"]): - results.append((sid, step, float(val), h)) - else: - idx_map = self._sample_index.get(graph_name, {}).get(h, {}) - for sid in sid_set: - for row in idx_map.get(sid, []): - results.append((sid, buf["steps"][row], float(buf["values"][row]), h)) - - return results + with self._lock: + self._flush_stage() + params = [graph_name] + sql = "SELECT sample_id, step, value, experiment_hash FROM per_sample WHERE metric_name = ?" + sql += self._hash_filter(exp_hash, params) + if sample_ids is not None: + sql += " AND sample_id IN (SELECT UNNEST(?))" + params.append([str(s) for s in sample_ids]) + sql += " ORDER BY seq" + rows = self._conn.execute(sql, params).fetchall() + + return [(sid, int(step), float(val), h) for (sid, step, val, h) in rows] def query_per_instance( self, @@ -583,54 +736,26 @@ def query_per_instance( """Query per-instance signal history. Returns a list of ``(sample_id, annotation_id, step, value, exp_hash)`` - tuples. Any of *sample_id*, *annotation_id*, *exp_hash* may be ``None`` + tuples. Any of *sample_id*, *annotation_id*, *exp_hash* may be ``None`` to return all values along that dimension. - - Args: - graph_name: Signal name (e.g. ``"confidence"``). - sample_id: Filter to a single sample. ``None`` returns all samples. - annotation_id: Filter to a single instance (1-based). ``None`` = all. - exp_hash: Filter to one experiment hash. ``None`` = all. 
""" - if graph_name not in self._signal_history_per_instance: - return [] - - exps = self._signal_history_per_instance[graph_name] - hashes = [exp_hash] if exp_hash is not None else list(exps.keys()) - sid_filter = str(sample_id) if sample_id is not None else None - aid_filter = int(annotation_id) if annotation_id is not None else None - - results = [] - for h in hashes: - buf = exps.get(h) - if buf is None: - continue - if sid_filter is None and aid_filter is None: - # No filter: full scan - for sid, aid, step, val in zip( - buf["sample_ids"], buf["annotation_ids"], buf["steps"], buf["values"] - ): - results.append((str(sid), int(aid), int(step), float(val), h)) - elif sid_filter is not None and aid_filter is not None: - # Both filters: O(1) index lookup - idx_map = self._instance_index.get(graph_name, {}).get(h, {}) - for row in idx_map.get((sid_filter, aid_filter), []): - results.append((sid_filter, aid_filter, int(buf["steps"][row]), float(buf["values"][row]), h)) - elif sid_filter is not None: - # Sample filter only: collect all annotation_ids for this sample - idx_map = self._instance_index.get(graph_name, {}).get(h, {}) - for (sid_k, aid_k), rows in idx_map.items(): - if sid_k == sid_filter: - for row in rows: - results.append((sid_filter, aid_k, int(buf["steps"][row]), float(buf["values"][row]), h)) - else: - # annotation_id filter only: scan index keys - idx_map = self._instance_index.get(graph_name, {}).get(h, {}) - for (sid_k, aid_k), rows in idx_map.items(): - if aid_k == aid_filter: - for row in rows: - results.append((sid_k, aid_filter, int(buf["steps"][row]), float(buf["values"][row]), h)) - return results + with self._lock: + self._flush_stage() + params = [graph_name] + sql = ("SELECT sample_id, annotation_id, step, value, experiment_hash " + "FROM per_instance WHERE metric_name = ?") + sql += self._hash_filter(exp_hash, params) + if sample_id is not None: + sql += " AND sample_id = ?" 
+ params.append(str(sample_id)) + if annotation_id is not None: + sql += " AND annotation_id = ?" + params.append(int(annotation_id)) + sql += " ORDER BY seq" + rows = self._conn.execute(sql, params).fetchall() + + return [(str(sid), int(aid), int(step), float(val), h) + for (sid, aid, step, val, h) in rows] def aggregate_per_sample_by_step( self, @@ -640,58 +765,29 @@ def aggregate_per_sample_by_step( ) -> dict: """Return mean signal value per step, aggregated over matching samples. - Uses numpy vectorized operations instead of a Python loop — ~100× faster - than iterating ``query_per_sample`` results for large sample counts. - - Args: - graph_name: Signal name. - sample_ids: Samples to include. ``None`` = all samples. - exp_hash: Filter to one experiment hash. ``None`` = all hashes. + DuckDB performs the ``GROUP BY step`` average natively, which scales to + millions of rows far better than a Python loop — this is the path used + by break-by-slices. Returns: - ``{exp_hash: [(step, mean_value), ...]}`` — one sorted series per hash. + ``{exp_hash: [(step, mean_value), ...]}`` — one step-sorted series + per hash. 
""" - import numpy as _np - - if graph_name not in self._signal_history_per_sample: - return {} - - exps = self._signal_history_per_sample[graph_name] - hashes = [exp_hash] if exp_hash is not None else list(exps.keys()) - sid_set = {str(s) for s in sample_ids} if sample_ids is not None else None - - result = {} - for h in hashes: - buf = exps.get(h) - if buf is None: - continue - - # Convert typed C arrays to numpy with zero-copy (frombuffer gives a read-only view) - steps_np = _np.frombuffer(buf["steps"], dtype=_np.int32).copy() - values_np = _np.frombuffer(buf["values"], dtype=_np.float32).copy() - - if sid_set is not None: - idx_map = self._sample_index.get(graph_name, {}).get(h, {}) - rows = [] - for sid in sid_set: - rows.extend(idx_map.get(sid, [])) - if not rows: - continue - row_idx = _np.array(rows, dtype=_np.intp) - steps_np = steps_np[row_idx] - values_np = values_np[row_idx] - - if len(steps_np) == 0: - continue - - # Vectorized group-by step → mean - unique_steps, inverse = _np.unique(steps_np, return_inverse=True) - sums = _np.bincount(inverse, weights=values_np.astype(_np.float64)) - counts = _np.bincount(inverse) - means = sums / counts - - result[h] = list(zip(unique_steps.tolist(), means.tolist())) - + with self._lock: + self._flush_stage() + params = [graph_name] + sql = ("SELECT experiment_hash, step, avg(value) AS mean_value " + "FROM per_sample WHERE metric_name = ?") + sql += self._hash_filter(exp_hash, params) + if sample_ids is not None: + sql += " AND sample_id IN (SELECT UNNEST(?))" + params.append([str(s) for s in sample_ids]) + sql += " GROUP BY experiment_hash, step ORDER BY experiment_hash, step" + rows = self._conn.execute(sql, params).fetchall() + + result: dict = {} + for (h, step, mean_val) in rows: + result.setdefault(h, []).append((int(step), float(mean_val))) return result def add_instance_scalars( @@ -703,11 +799,10 @@ def add_instance_scalars( global_step: int, exp_hash: str | None = None, ) -> None: - """Record per-instance 
scalar values in compact storage. + """Record per-instance scalar values. - Call this from ``save_instance_signals`` once per scalar signal per - batch. Each element of *sample_ids*, *annotation_ids*, *values* - corresponds to one detection / segmentation instance. + Each element of *sample_ids*, *annotation_ids*, *values* corresponds to + one detection / segmentation instance. Args: graph_name: Signal name (e.g. ``"confidence"``). @@ -724,76 +819,63 @@ def add_instance_scalars( else None ) - if graph_name not in self._signal_history_per_instance: - self._signal_history_per_instance[graph_name] = {} - if exp_hash not in self._signal_history_per_instance[graph_name]: - self._signal_history_per_instance[graph_name][exp_hash] = _make_per_instance_buf() - - buf = self._signal_history_per_instance[graph_name][exp_hash] - step_i = int(global_step) - idx_map = self._instance_index.setdefault(graph_name, {}).setdefault(exp_hash, {}) try: import numpy as _np vals = _np.asarray(values, dtype=_np.float32).ravel() except Exception: vals = [float(v) for v in values] - for sid, aid, val in zip(sample_ids, annotation_ids, vals): - row = len(buf["sample_ids"]) - sid_s, aid_i = str(sid), int(aid) - buf["sample_ids"].append(sid_s) - buf["annotation_ids"].append(aid_i) - buf["steps"].append(step_i) - buf["values"].append(float(val)) - idx_map.setdefault((sid_s, aid_i), []).append(row) + with self._lock: + step_i = int(global_step) + for sid, aid, val in zip(sample_ids, annotation_ids, vals): + self._stage_instance_row(graph_name, exp_hash, sid, aid, step_i, float(val)) def get_signal_history_per_instance(self) -> dict: - """Reconstruct per-instance history as list-of-dicts from compact array storage.""" - result = {} - for graph_name, exps in self._signal_history_per_instance.items(): - result[graph_name] = {} - for exp_hash, buf in exps.items(): - entries = [] - for sid, aid, step, val in zip( - buf["sample_ids"], buf["annotation_ids"], buf["steps"], buf["values"] - ): - 
entries.append({ - "sample_id": str(sid), - "annotation_id": int(aid), - "model_age": int(step), - "metric_name": graph_name, - "metric_value": float(val), - "experiment_hash": exp_hash, - }) - result[graph_name][exp_hash] = entries + """Per-instance history as ``{metric: {hash: [entry, ...]}}``.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT metric_name, experiment_hash, sample_id, annotation_id, step, value " + "FROM per_instance ORDER BY seq" + ).fetchall() + + result: dict = {} + for (metric, h, sid, aid, step, val) in rows: + result.setdefault(metric, {}).setdefault(h, []).append({ + "sample_id": str(sid), + "annotation_id": int(aid), + "model_age": int(step), + "metric_name": metric, + "metric_value": float(val), + "experiment_hash": h, + }) return result def save_snapshot(self) -> dict: - """Build a serializable snapshot of the logger state.""" + """Build a serializable snapshot of the logger state (compact format).""" self._flush_current_step_buffer(add_to_queue=False) - # Compact serialization: store parallel lists instead of list-of-dicts - per_sample_compact = {} - for graph_name, exps in self._signal_history_per_sample.items(): + per_sample_compact: dict = {} + for graph_name, exps in self.get_signal_history_per_sample().items(): per_sample_compact[graph_name] = {} - for exp_hash, buf in exps.items(): + for exp_hash, entries in exps.items(): per_sample_compact[graph_name][exp_hash] = { "_compact": True, - "sample_ids": list(buf["sample_ids"]), - "steps": list(buf["steps"]), - "values": list(buf["values"]), + "sample_ids": [e["sample_id"] for e in entries], + "steps": [e["model_age"] for e in entries], + "values": [e["metric_value"] for e in entries], } - per_instance_compact = {} - for graph_name, exps in self._signal_history_per_instance.items(): + per_instance_compact: dict = {} + for graph_name, exps in self.get_signal_history_per_instance().items(): per_instance_compact[graph_name] = {} - for exp_hash, buf in 
exps.items(): + for exp_hash, entries in exps.items(): per_instance_compact[graph_name][exp_hash] = { - "_compact": True, - "sample_ids": list(buf["sample_ids"]), - "annotation_ids": list(buf["annotation_ids"]), - "steps": list(buf["steps"]), - "values": list(buf["values"]), + "_compact": True, + "sample_ids": [e["sample_id"] for e in entries], + "annotation_ids": [e["annotation_id"] for e in entries], + "steps": [e["model_age"] for e in entries], + "values": [e["metric_value"] for e in entries], } return { @@ -803,32 +885,34 @@ def save_snapshot(self) -> dict: "signal_history_per_instance": per_instance_compact, } - # ------------------------------------------------------------------ - # Convenience: list all evaluation-marker hashes in history - # ------------------------------------------------------------------ def get_evaluation_marker_hashes(self) -> list: - """Return all experiment hashes that correspond to evaluation markers.""" + """Return all experiment hashes of the form ``_`` in history.""" + with self._lock: + self._flush_stage() + rows = self._conn.execute( + "SELECT DISTINCT experiment_hash FROM signals WHERE experiment_hash IS NOT NULL" + ).fetchall() + hashes = set() - for gname in self._signal_history: - for hash_key in self._signal_history[gname]: - if isinstance(hash_key, str) and "_" in hash_key: - # Check that the suffix is a pure integer - suffix = hash_key.rsplit("_", 1)[-1] - try: - int(suffix) - hashes.add(hash_key) - except ValueError: - pass + for (hash_key,) in rows: + if isinstance(hash_key, str) and "_" in hash_key: + suffix = hash_key.rsplit("_", 1)[-1] + try: + int(suffix) + hashes.add(hash_key) + except ValueError: + pass return sorted(hashes) def get_and_clear_queue(self): """Get pending queue and clear it (for incremental updates to WeightsStudio).""" - queue_copy = list(self._pending_queue) - self._pending_queue.clear() + with self._lock: + queue_copy = list(self._pending_queue) + self._pending_queue.clear() return queue_copy 
def set_point_note(self, metric_name: str, experiment_hash: str, model_age: int, note: str) -> bool: - """Attach or clear a note for a specific signal point identified by metric/hash/step.""" + """Attach or clear a note for a signal point identified by metric/hash/step.""" metric_name = str(metric_name or "") experiment_hash = str(experiment_hash or "") if not metric_name or not experiment_hash: @@ -836,55 +920,64 @@ def set_point_note(self, metric_name: str, experiment_hash: str, model_age: int, normalized_step = int(model_age) cleaned_note = str(note or "").strip() - updated = False - - entries = ( - self._signal_history.get(metric_name, {}) - .get(experiment_hash, {}) - .get(normalized_step, []) - ) - for entry in entries: - if not isinstance(entry, dict): - continue - if cleaned_note: - entry["point_note"] = cleaned_note - else: - entry.pop("point_note", None) - updated = True - for entry in self._pending_queue: - if not isinstance(entry, dict): - continue - if str(entry.get("metric_name", "")) != metric_name: - continue - if str(entry.get("experiment_hash", "")) != experiment_hash: - continue - try: - if int(entry.get("model_age", -1)) != normalized_step: + with self._lock: + self._flush_stage() + matched = self._conn.execute( + "SELECT count(*) FROM signals " + "WHERE metric_name = ? AND experiment_hash = ? AND step = ?", + [metric_name, experiment_hash, normalized_step], + ).fetchone()[0] + if matched: + self._conn.execute( + "UPDATE signals SET point_note = ? " + "WHERE metric_name = ? AND experiment_hash = ? 
AND step = ?", + [cleaned_note, metric_name, experiment_hash, normalized_step], + ) + + for entry in self._pending_queue: + if not isinstance(entry, dict): continue - except Exception: - continue - if cleaned_note: - entry["point_note"] = cleaned_note - else: - entry.pop("point_note", None) + if str(entry.get("metric_name", "")) != metric_name: + continue + if str(entry.get("experiment_hash", "")) != experiment_hash: + continue + try: + if int(entry.get("model_age", -1)) != normalized_step: + continue + except Exception: + continue + if cleaned_note: + entry["point_note"] = cleaned_note + else: + entry.pop("point_note", None) - return updated + return bool(matched) - # Logger saving/loading methods for checkpoint persistence (used in WeightsLabCallback) + # ------------------------------------------------------------------ + # Snapshot loading (checkpoint persistence) + # ------------------------------------------------------------------ def load_signal_history(self, signals): - """Load signal history into memory (supports legacy and nested formats).""" + """Load aggregated signal history (supports legacy list and nested dict).""" if not signals: return - def _append_signal_entry(metric_name, exp_hash, step, signal_entry): - if metric_name not in self._signal_history: - self._signal_history[metric_name] = {} - if exp_hash not in self._signal_history[metric_name]: - self._signal_history[metric_name][exp_hash] = {} - if step not in self._signal_history[metric_name][exp_hash]: - self._signal_history[metric_name][exp_hash][step] = [] - self._signal_history[metric_name][exp_hash][step].append(signal_entry) + def _stage_entry(metric_name, exp_hash, step, entry): + try: + step_i = int(step) + except (TypeError, ValueError): + return + with self._lock: + self._stage_signal_row( + metric_name, exp_hash, step_i, + float(entry.get("metric_value", 0.0)), + int(entry.get("timestamp", int(time.time()))), + bool(entry.get("audit_mode", False)), + 
bool(entry.get("is_evaluation_marker", False)), + entry.get("split_name", ""), + entry.get("evaluation_tags", []), + entry.get("point_note", "") or "", + ) if isinstance(signals, dict): for metric_name, experiments in signals.items(): @@ -895,23 +988,10 @@ def _append_signal_entry(metric_name, exp_hash, step, signal_entry): if not isinstance(steps, dict): continue for step_key, entries in steps.items(): - step = step_key - if isinstance(step_key, str): - try: - step = int(step_key) - except Exception: - step = step_key - entries_list = entries if isinstance(entries, list) else [entries] for entry in entries_list: - if not isinstance(entry, dict): - continue - signal_entry = dict(entry) - signal_entry.setdefault("metric_name", metric_name) - signal_entry.setdefault("model_age", step) - signal_entry.setdefault("experiment_hash", exp_hash) - signal_entry.setdefault("timestamp", int(time.time())) - _append_signal_entry(metric_name, exp_hash, step, signal_entry) + if isinstance(entry, dict): + _stage_entry(metric_name, exp_hash, step_key, entry) return if isinstance(signals, list): @@ -921,138 +1001,100 @@ def _append_signal_entry(metric_name, exp_hash, step, signal_entry): metric_name = signal.get("metric_name") if not metric_name: continue - exp_hash = signal.get("experiment_hash") - step = signal.get("model_age") - signal_entry = dict(signal) - signal_entry.setdefault("metric_name", metric_name) - signal_entry.setdefault("model_age", step) - signal_entry.setdefault("experiment_hash", exp_hash) - signal_entry.setdefault("timestamp", int(time.time())) self.graph_names.add(metric_name) - _append_signal_entry(metric_name, exp_hash, step, signal_entry) + _stage_entry( + metric_name, + signal.get("experiment_hash"), + signal.get("model_age", 0), + signal, + ) def load_signal_history_per_sample(self, signals_per_sample): - """Load per-sample history into compact array storage. + """Load per-sample history. 
Handles three formats: - - New compact: {graph_name: {exp_hash: {"_compact": True, "sample_ids": [...], "steps": [...], "values": [...]}}} - - Legacy list: {graph_name: {exp_hash: [{sample_id, model_age, metric_value, ...}, ...]}} - - Legacy dict: {graph_name: {sample_id_as_key: {model_age, metric_value, ...}}} → stored under None key + - Compact: {graph: {hash: {"_compact": True, "sample_ids": [...], "steps": [...], "values": [...]}}} + - Legacy list: {graph: {hash: [{sample_id, model_age, metric_value, ...}, ...]}} + - Legacy dict: {graph: {sample_id_as_key: {model_age, metric_value, ...}}} → stored under None hash """ if not signals_per_sample: return for metric_name, samples_by_exp in signals_per_sample.items(): self.graph_names.add(metric_name) - if metric_name not in self._signal_history_per_sample: - self._signal_history_per_sample[metric_name] = {} - if not isinstance(samples_by_exp, dict): continue for exp_hash, entries in samples_by_exp.items(): - # --- New compact format --- + # --- Compact format --- if isinstance(entries, dict) and entries.get("_compact"): - if exp_hash not in self._signal_history_per_sample[metric_name]: - self._signal_history_per_sample[metric_name][exp_hash] = _make_per_sample_buf() - buf = self._signal_history_per_sample[metric_name][exp_hash] - ids = entries.get("sample_ids", []) + ids = entries.get("sample_ids", []) steps = entries.get("steps", []) - vals = entries.get("values", []) - idx_map = self._sample_index.setdefault(metric_name, {}).setdefault(exp_hash, {}) - for s, t, v in zip(ids, steps, vals): - try: - row = len(buf["sample_ids"]) - sid_s = str(s) - buf["sample_ids"].append(sid_s) - buf["steps"].append(int(t)) - buf["values"].append(float(v)) - idx_map.setdefault(sid_s, []).append(row) - except (TypeError, ValueError): - pass + vals = entries.get("values", []) + with self._lock: + for s, t, v in zip(ids, steps, vals): + try: + self._stage_sample_row(metric_name, exp_hash, s, int(t), float(v)) + except (TypeError, 
ValueError): + pass - # --- Legacy list-of-dicts format --- + # --- Legacy list-of-dicts --- elif isinstance(entries, list): - if exp_hash not in self._signal_history_per_sample[metric_name]: - self._signal_history_per_sample[metric_name][exp_hash] = _make_per_sample_buf() - buf = self._signal_history_per_sample[metric_name][exp_hash] - idx_map = self._sample_index.setdefault(metric_name, {}).setdefault(exp_hash, {}) - for entry in entries: - if not isinstance(entry, dict): - continue + with self._lock: + for entry in entries: + if not isinstance(entry, dict): + continue + try: + self._stage_sample_row( + metric_name, exp_hash, + entry.get("sample_id", -1), + int(entry.get("model_age", 0)), + float(entry.get("metric_value", 0.0)), + ) + except (TypeError, ValueError): + pass + + # --- Legacy single-dict (exp_hash key was actually the sample_id) --- + elif isinstance(entries, dict): + sid = str(exp_hash) if isinstance(exp_hash, (int, float)) else str(-1) + with self._lock: try: - row = len(buf["sample_ids"]) - sid_s = str(entry.get("sample_id", -1)) - buf["sample_ids"].append(sid_s) - buf["steps"].append(int(entry.get("model_age", 0))) - buf["values"].append(float(entry.get("metric_value", 0.0))) - idx_map.setdefault(sid_s, []).append(row) + self._stage_sample_row( + metric_name, None, sid, + int(entries.get("model_age", 0)), + float(entries.get("metric_value", 0.0)), + ) except (TypeError, ValueError): pass - # --- Legacy single-dict format (exp_hash key was actually the sample_id) --- - elif isinstance(entries, dict): - null_key = None - if null_key not in self._signal_history_per_sample[metric_name]: - self._signal_history_per_sample[metric_name][null_key] = _make_per_sample_buf() - buf = self._signal_history_per_sample[metric_name][null_key] - idx_map = self._sample_index.setdefault(metric_name, {}).setdefault(null_key, {}) - try: - row = len(buf["sample_ids"]) - sid = str(exp_hash) if isinstance(exp_hash, (int, float)) else str(-1) - 
buf["sample_ids"].append(sid) - buf["steps"].append(int(entries.get("model_age", 0))) - buf["values"].append(float(entries.get("metric_value", 0.0))) - idx_map.setdefault(sid, []).append(row) - except (TypeError, ValueError): - pass - def load_signal_history_per_instance(self, signals_per_instance: dict) -> None: """Load per-instance history from a compact snapshot dict.""" if not signals_per_instance: return for metric_name, exps in signals_per_instance.items(): self.graph_names.add(metric_name) - if metric_name not in self._signal_history_per_instance: - self._signal_history_per_instance[metric_name] = {} if not isinstance(exps, dict): continue for exp_hash, entries in exps.items(): if not (isinstance(entries, dict) and entries.get("_compact")): continue - if exp_hash not in self._signal_history_per_instance[metric_name]: - self._signal_history_per_instance[metric_name][exp_hash] = _make_per_instance_buf() - buf = self._signal_history_per_instance[metric_name][exp_hash] - ids = entries.get("sample_ids", []) + ids = entries.get("sample_ids", []) aids = entries.get("annotation_ids", []) steps = entries.get("steps", []) - vals = entries.get("values", []) - idx_map = self._instance_index.setdefault(metric_name, {}).setdefault(exp_hash, {}) - for s, a, t, v in zip(ids, aids, steps, vals): - try: - row = len(buf["sample_ids"]) - sid_s, aid_i = str(s), int(a) - buf["sample_ids"].append(sid_s) - buf["annotation_ids"].append(aid_i) - buf["steps"].append(int(t)) - buf["values"].append(float(v)) - idx_map.setdefault((sid_s, aid_i), []).append(row) - except (TypeError, ValueError): - pass + vals = entries.get("values", []) + with self._lock: + for s, a, t, v in zip(ids, aids, steps, vals): + try: + self._stage_instance_row(metric_name, exp_hash, s, int(a), int(t), float(v)) + except (TypeError, ValueError): + pass def load_snapshot(self, snapshot: dict): """Restore logger state from a snapshot dict.""" if not snapshot: return - graph_names = snapshot.get("graph_names", []) - 
self.graph_names.update(graph_names) - - signals = snapshot.get("signal_history", []) - self.load_signal_history(signals) - - signals_per_sample = snapshot.get("signal_history_per_sample", {}) - self.load_signal_history_per_sample(signals_per_sample) - - signals_per_instance = snapshot.get("signal_history_per_instance", {}) - self.load_signal_history_per_instance(signals_per_instance) + self.graph_names.update(snapshot.get("graph_names", [])) + self.load_signal_history(snapshot.get("signal_history", [])) + self.load_signal_history_per_sample(snapshot.get("signal_history_per_sample", {})) + self.load_signal_history_per_instance(snapshot.get("signal_history_per_instance", {})) diff --git a/weightslab/backend/model_interface.py b/weightslab/backend/model_interface.py index 73803e7f..773d290a 100755 --- a/weightslab/backend/model_interface.py +++ b/weightslab/backend/model_interface.py @@ -134,7 +134,7 @@ def __init__( if dummy_input is None: raise ValueError("Model object must have 'input_shape' attribute for proper registration with WeightsLab.") else: - self.model.input_shape = tuple(dummy_input.shape[1:]) # Exclude batch dimension + self.model.input_shape = tuple(dummy_input.shape[1:]) # Exclude batch dimension # Move dummy input to the correct device, or create a default one if not provided if dummy_input is not None: @@ -443,7 +443,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): True if an exception occurred and it was successfully handled by this method, preventing it from being re-raised. """ - self.visited_nodes = set() # Reset NetworkWithOps nodes visited + self.visited_nodes = set() # Reset NetworkWithOps nodes visited if exc_type is not None: logger.error( f"[{self.__class__.__name__}]: An exception occurred: \ @@ -492,7 +492,7 @@ def load_state_dict(self, state_dict, strict: bool = True): Note: `assign=False` is explicitly passed so that parameter tensors are updated **in-place** (data copy) rather than replaced with new - objects. 
Replacing parameter objects (assign=True, the NetworkWithOps + objects. Replacing parameter objects (assign=True, the NetworkWithOps default) would silently invalidate any optimizer that was created before this load_state_dict call, because the optimizer holds references to the old Parameter objects. @@ -885,9 +885,9 @@ def __repr__(self): pass seq_lines = seq_module_repr.split('\n') # The first line is formatted with the name, the rest are indented - seq_string += f" ({seq_name}): {seq_lines[0]}\n" + seq_string += f" ({seq_name}): {seq_lines[0]}\n" for seq_line in seq_lines[1:]: - seq_string += f" {seq_line}\n" + seq_string += f" {seq_line}\n" module_repr = f"{seq_string}" else: module_repr = f"ID=None | {module_repr}" @@ -899,9 +899,9 @@ def __repr__(self): lines = module_repr.split('\n') # The first line is formatted with the name, the rest are indented - string += f" ({name}): {lines[0]}\n" + string += f" ({name}): {lines[0]}\n" for line in lines[1:]: - string += f" {line}\n" + string += f" {line}\n" string += ")" return string diff --git a/weightslab/baseline_models/pytorch/models.py b/weightslab/baseline_models/pytorch/models.py index 9a4dd041..0b82a0b7 100644 --- a/weightslab/baseline_models/pytorch/models.py +++ b/weightslab/baseline_models/pytorch/models.py @@ -38,7 +38,7 @@ def __init__(self): self.m1 = nn.MaxPool2d(2) # Block 2 - self.c2 = nn.Conv2d(4, 4, 3) # Default stride=1, no padding + self.c2 = nn.Conv2d(4, 4, 3) # Default stride=1, no padding self.b2 = nn.BatchNorm2d(4) self.r2 = nn.ReLU() self.m2 = nn.MaxPool2d(2) @@ -154,21 +154,21 @@ def __init__(self): self.input_shape = (1, 1, 28, 28) # Block 1 (Path A) - self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Id 0 + self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Id 0 # Block 2 (Residual/Skip Path) # Note: c2 takes b1's output. c3 takes c2's output. 
- self.c2 = nn.Conv2d(4, 8, 3, padding=1) # Id 2 - self.c3 = nn.Conv2d(8, 4, 3, padding=1) # Id 3 + self.c2 = nn.Conv2d(4, 8, 3, padding=1) # Id 2 + self.c3 = nn.Conv2d(8, 4, 3, padding=1) # Id 3 def forward(self, x): # Path A - x1 = self.c1(x) # [4, 28, 28] - x2 = self.c2(x1) # [8, 28, 28] - x3 = self.c3(x2) # [4, 28, 28] + x1 = self.c1(x) # [4, 28, 28] + x2 = self.c2(x1) # [8, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual Connection (Add operation) - x_out = x1 + x3 # The output of b1 and c3 both flow into the add op + x_out = x1 + x3 # The output of b1 and c3 both flow into the add op return x_out @@ -199,15 +199,15 @@ def forward(self, x): # Main Path (where the skip connection comes from) x2 = self.c2(x1) - x3 = self.c3(x2) # [4, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x4 = self.c4(x1) - x5 = self.c5(self.b1(x4)) # [4, 28, 28] + x5 = self.c5(self.b1(x4)) # [4, 28, 28] # Residual Connection (Add operation) # Now x3 and x5 have the same shape: B x 4 x 28 x 28 - x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 + x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 return x_out @@ -237,15 +237,15 @@ def forward(self, x): # Main Path (where the skip connection comes from) x2 = self.c2(x1) - x3 = self.c3(x2) # [4, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x4 = self.c4(x1) - x5 = self.b1(x4) # [4, 28, 28] + x5 = self.b1(x4) # [4, 28, 28] # Residual Connection (Add operation) # Now x3 and x5 have the same shape: B x 4 x 28 x 28 - x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 + x_out = x3 + x5 # Assuming you intended to add x3 and x5/x4 return x_out @@ -259,7 +259,7 @@ def __init__(self): self.input_shape = (1, 1, 28, 28) # Block 1 (Path A) - Stays the same - self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Input (1), Output (4) + self.c1 = nn.Conv2d(1, 4, 3, padding=1) # Input (1), Output (4) # Block 2 (Main Path) - Stays the same 
self.c2 = nn.Conv2d(4, 8, 3, padding=1) @@ -279,19 +279,19 @@ def forward(self, x): # Main Path (where the skip connection comes from) x2 = self.c2(x1) - x3 = self.c3(x2) # [4, 28, 28] + x3 = self.c3(x2) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x4 = self.c4(x1) - x5 = self.b1(x4) # [4, 28, 28] + x5 = self.b1(x4) # [4, 28, 28] # Residual connection path (Transform x1 to match x3) x6 = self.c5(x1) - x7 = self.b2(x6) # [4, 28, 28] + x7 = self.b2(x6) # [4, 28, 28] # Residual Connection (Add operation) # Now x3 and x5 have the same shape: B x 4 x 28 x 28 - x_out = x3 + x5 - x7 # Assuming you intended to add x3 and x5/x4 + x_out = x3 + x5 - x7 # Assuming you intended to add x3 and x5/x4 return x_out @@ -381,7 +381,7 @@ def forward(self, x): # Main path out = self.block_conv1(x) out = self.block_bn1(out) - out = self.block_bn3(out) # Second BN to match original code + out = self.block_bn3(out) # Second BN to match original code out = self.relu(out) out = self.block_conv2(out) @@ -478,7 +478,7 @@ def __init__(self, in_channels=1, out_classes=1): nn.BatchNorm2d(c[1]), nn.ReLU(inplace=True) ) - self.pool1 = nn.MaxPool2d(2) # Downsample 1 + self.pool1 = nn.MaxPool2d(2) # Downsample 1 # --- B. BOTTLENECK --- # 2. BOTTLENECK: Conv -> 16 canaux @@ -516,7 +516,7 @@ def __init__(self, in_channels=1, out_classes=1): def forward(self, x): # 1. ENCODER x1 = self.enc1(x) - p1 = self.pool1(x1) # Skip x1 + p1 = self.pool1(x1) # Skip x1 # 2. 
BOTTLENECK bottleneck = self.bottleneck(p1) @@ -613,7 +613,7 @@ def __init__(self, *args, **kwargs): ) def forward(self, input): - input = torch.cat([input,]*3, dim=1) # Add channels dim + input = torch.cat([input,]*3, dim=1) # Add channels dim return self.model(input) @@ -880,7 +880,7 @@ def _init_generator_sequential(self, z_dim, img_channels, features_g): # Final Conv: N x 64 x 32 x 32 -> N x 3 x 64 x 64 nn.ConvTranspose2d(features_g, img_channels, kernel_size=4, stride=2, padding=1), - nn.Tanh() # Output range [-1, 1] + nn.Tanh() # Output range [-1, 1] ) def _init_discriminator_sequential(self, img_channels, features_d): @@ -888,7 +888,7 @@ def _init_discriminator_sequential(self, img_channels, features_d): return nn.Sequential( # Input: N x C x 64 x 64 nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, - padding=1), # Output: N x 64 x 32 x 32 + padding=1), # Output: N x 64 x 32 x 32 nn.LeakyReLU(0.2, inplace=True), # Block 2: N x 64 x 32 x 32 -> N x 128 x 16 x 16 @@ -960,7 +960,7 @@ def __init__(self, image_size=784, h_dim=200, z_dim=20): nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, image_size), - nn.Sigmoid() # Sigmoid to output pixel values in the range [0, 1] + nn.Sigmoid() # Sigmoid to output pixel values in the range [0, 1] ) def reparameterize(self, mu, log_var): @@ -1041,7 +1041,7 @@ def forward(self, x): x = self.conv2(x) # Flatten layer - x = x.view(x.shape[0], -1) # Flatten the tensor + x = x.view(x.shape[0], -1) # Flatten the tensor # Linear layers and ReLU activation function h_relu = self.linear1(x).clamp(min=0) @@ -1150,7 +1150,7 @@ def double_conv(in_c, out_c): # ------------------ DECODER (Upsampling Path) ------------------ # 4. 
Up 4 self.up4_up = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=1) - self.up4_conv = double_conv(192, 64) # Input channels: 64 + 64 = 128 + self.up4_conv = double_conv(192, 64) # Input channels: 64 + 64 = 128 # ------------------ OUTPUT Layer ------------------ # 1x1 convolution to map the final feature channels (64) to the number of classes @@ -1165,7 +1165,7 @@ def forward(self, x): # ------------------ DECODER (Concatenate and Convolve) ------------------ # Up 4 - x = self.up4_up(x2) # B3 + x = self.up4_up(x2) # B3 x = self._align_and_concat(x, x1) x = self.up4_conv(x) @@ -1192,7 +1192,7 @@ def _align_and_concat(self, upsampled, skip): upsampled, size=skip.shape[-2:], mode='bilinear', - align_corners=False # Set to False for compatibility and best practice + align_corners=False # Set to False for compatibility and best practice ) # Concatenate along the channel dimension (dim=1) @@ -1246,19 +1246,19 @@ def double_conv(in_c, out_c): # ------------------ DECODER (Upsampling Path) ------------------ # 1. Up 1 (Upsample + Conv + Skip Connection) self.up1_up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2) - self.up1_conv = double_conv(1024, 512) # Input channels: 512 (from up) + 512 (from skip) = 1024 + self.up1_conv = double_conv(1024, 512) # Input channels: 512 (from up) + 512 (from skip) = 1024 # 2. Up 2 self.up2_up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) - self.up2_conv = double_conv(512, 256) # Input channels: 256 + 256 = 512 + self.up2_conv = double_conv(512, 256) # Input channels: 256 + 256 = 512 # 3. Up 3 self.up3_up = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) - self.up3_conv = double_conv(256, 128) # Input channels: 128 + 128 = 256 + self.up3_conv = double_conv(256, 128) # Input channels: 128 + 128 = 256 # 4. 
Up 4 self.up4_up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) - self.up4_conv = double_conv(128, 64) # Input channels: 64 + 64 = 128 + self.up4_conv = double_conv(128, 64) # Input channels: 64 + 64 = 128 # ------------------ OUTPUT Layer ------------------ # 1x1 convolution to map the final feature channels (64) to the number of classes @@ -1278,11 +1278,11 @@ def forward(self, x): x4 = self.down3_conv(x4) x5 = self.down4_pool(x4) - x5 = self.down4_conv(x5) # This is the bottleneck feature map (lowest resolution) + x5 = self.down4_conv(x5) # This is the bottleneck feature map (lowest resolution) # ------------------ DECODER (Concatenate and Convolve) ------------------ # Up 1 - x = self.up1_up(x5) # Upsample x5 + x = self.up1_up(x5) # Upsample x5 x = self._align_and_concat(x, x4) x = self.up1_conv(x) @@ -1324,7 +1324,7 @@ def _align_and_concat(self, upsampled, skip): upsampled, size=skip.shape[-2:], mode='bilinear', - align_corners=False # Set to False for compatibility and best practice + align_corners=False # Set to False for compatibility and best practice ) # Concatenate along the channel dimension (dim=1) @@ -1346,8 +1346,8 @@ def __init__(self, n_channels=3, n_classes=1, filter_list=[64, 128, 256, 512, 10 self.input_shape = (1, n_channels, 256, 256) self.n_channels = n_channels self.n_classes = n_classes - self.filters = filter_list # [F1, F2, F3, F4, F5] - self.F_cat = self.filters[0] * 5 # Total channels in feature concatenation (e.g., 64 * 5 = 320) + self.filters = filter_list # [F1, F2, F3, F4, F5] + self.F_cat = self.filters[0] * 5 # Total channels in feature concatenation (e.g., 64 * 5 = 320) # ------------------- Internal Building Blocks ------------------- @@ -1537,7 +1537,7 @@ def forward(self, x): # Concatenate and convolve d1 = torch.cat((h1_d1, h2_d1, h3_d1, h4_d1, h5_d1), dim=1) - d1 = self.up_conv1(d1) # Output D1 (64 channels, Size H/1) + d1 = self.up_conv1(d1) # Output D1 (64 channels, Size H/1) # ------------------- OUTPUT 
------------------- logits = self.outc(d1) @@ -1569,7 +1569,7 @@ def double_conv_3d(in_c, out_c): # --- ENCODER (Contracting Path) --- # Initial convolution and first block - self.inc = double_conv_3d(input_channels, base_channels) # C -> 32 + self.inc = double_conv_3d(input_channels, base_channels) # C -> 32 # Down 1 self.down1_pool = nn.MaxPool3d(kernel_size=2, stride=2) @@ -1606,39 +1606,39 @@ def forward(self, x): # x shape: (B, C, D, H, W) # --- ENCODER --- - x1 = self.inc(x) # B x 32 x D x H x W (Skip 1) + x1 = self.inc(x) # B x 32 x D x H x W (Skip 1) x2 = self.down1_pool(x1) - x2 = self.down1_conv(x2) # B x 64 x D/2 x H/2 x W/2 (Skip 2) + x2 = self.down1_conv(x2) # B x 64 x D/2 x H/2 x W/2 (Skip 2) x3 = self.down2_pool(x2) - x3 = self.down2_conv(x3) # B x 128 x D/4 x H/4 x W/4 (Skip 3) + x3 = self.down2_conv(x3) # B x 128 x D/4 x H/4 x W/4 (Skip 3) x4 = self.down3_pool(x3) - x4 = self.down3_conv(x4) # B x 256 x D/8 x H/8 x W/8 (Bottleneck) + x4 = self.down3_conv(x4) # B x 256 x D/8 x H/8 x W/8 (Bottleneck) # --- DECODER --- # Up 3 - up3 = self.up3_upsample(x4) # B x 128 x D/4 x H/4 x W/4 (Upsampled) + up3 = self.up3_upsample(x4) # B x 128 x D/4 x H/4 x W/4 (Upsampled) # Skip connection: Concatenate with x3 (128 channels) cat3 = torch.cat([x3, up3], dim=1) # B x 256 x D/4 x H/4 x W/4 - x = self.up3_conv(cat3) # B x 128 x D/4 x H/4 x W/4 + x = self.up3_conv(cat3) # B x 128 x D/4 x H/4 x W/4 # Up 2 - up2 = self.up2_upsample(x) # B x 64 x D/2 x H/2 x W/2 + up2 = self.up2_upsample(x) # B x 64 x D/2 x H/2 x W/2 # Skip connection: Concatenate with x2 (64 channels) cat2 = torch.cat([x2, up2], dim=1) # B x 128 x D/2 x H/2 x W/2 - x = self.up2_conv(cat2) # B x 64 x D/2 x H/2 x W/2 + x = self.up2_conv(cat2) # B x 64 x D/2 x H/2 x W/2 # Up 1 - up1 = self.up1_upsample(x) # B x 32 x D x H x W + up1 = self.up1_upsample(x) # B x 32 x D x H x W # Skip connection: Concatenate with x1 (32 channels) cat1 = torch.cat([x1, up1], dim=1) # B x 64 x D x H x W - x = 
self.up1_conv(cat1) # B x 32 x D x H x W + x = self.up1_conv(cat1) # B x 32 x D x H x W # Final Output - logits = self.out_conv(x) # B x C_out x D x H x W + logits = self.out_conv(x) # B x C_out x D x H x W return logits @@ -1694,7 +1694,7 @@ def __init__(self, S: int = 7, B: int = 2, C: int = 1, image_size: int = 224): self.preprocess = T.Compose([ T.Resize((image_size, image_size)), T.ToTensor(), - T.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet mean/std + T.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet mean/std std=[0.229, 0.224, 0.225]), ]) @@ -1734,12 +1734,12 @@ def decode_preds(self, pred: torch.Tensor, conf_thresh: float = 0.25) -> List[Li grid_y = grid_y.float() for i in range(N): - img_pred = pred[i] # S x S x (B*5 + C) + img_pred = pred[i] # S x S x (B*5 + C) img_boxes = [] # class score per cell (for single class we use sigmoid) if self.C == 1: - cls_score = torch.sigmoid(img_pred[..., -1]) # (S, S) + cls_score = torch.sigmoid(img_pred[..., -1]) # (S, S) else: cls_logits = img_pred[..., self.B * 5:] cls_score_all = F.softmax(cls_logits, dim=-1) @@ -1762,7 +1762,7 @@ def decode_preds(self, pred: torch.Tensor, conf_thresh: float = 0.25) -> List[Li w = tw / S h = th / S - final_conf = conf * cls_score # class-aware confidence + final_conf = conf * cls_score # class-aware confidence mask = final_conf > conf_thresh if mask.any(): @@ -1842,8 +1842,8 @@ class Yolov11(nn.Module): def __init__(self, variant: str = "yolo11n.pt", device=None, img_size: int = 640): super().__init__() try: - from ultralytics import YOLO # type: ignore - except Exception as exc: # pragma: no cover - optional dependency + from ultralytics import YOLO # type: ignore + except Exception as exc: # pragma: no cover - optional dependency raise ImportError( "Ultralytics is required for Yolov11 baseline. 
Install with: pip install ultralytics" ) from exc @@ -1871,13 +1871,13 @@ def predict(self, x, **kwargs): model = Yolov11() # Predict with the model - results = model.yolo("https://ultralytics.com/images/bus.jpg") # predict on an image + results = model.yolo("https://ultralytics.com/images/bus.jpg") # predict on an image # Access the results for result in results: - xywh = result.boxes.xywh # center-x, center-y, width, height - xywhn = result.boxes.xywhn # normalized - xyxy = result.boxes.xyxy # top-left-x, top-left-y, bottom-right-x, bottom-right-y - xyxyn = result.boxes.xyxyn # normalized - names = [result.names[cls.item()] for cls in result.boxes.cls.int()] # class name of each box - confs = result.boxes.conf # confidence score of each box \ No newline at end of file + xywh = result.boxes.xywh # center-x, center-y, width, height + xywhn = result.boxes.xywhn # normalized + xyxy = result.boxes.xyxy # top-left-x, top-left-y, bottom-right-x, bottom-right-y + xyxyn = result.boxes.xyxyn # normalized + names = [result.names[cls.item()] for cls in result.boxes.cls.int()] # class name of each box + confs = result.boxes.conf # confidence score of each box \ No newline at end of file diff --git a/weightslab/components/__init__.py b/weightslab/components/__init__.py index cbb7d872..a41418b3 100644 --- a/weightslab/components/__init__.py +++ b/weightslab/components/__init__.py @@ -11,18 +11,18 @@ # # Other components # from weightslab.components.tracking import Tracker, TrackingMode -# # from weightslab.components.global_monitoring import GlobalMonitoring # TODO: Fix missing GlobalMonitoring class +# # from weightslab.components.global_monitoring import GlobalMonitoring # TODO: Fix missing GlobalMonitoring class # __all__ = [ -# # Checkpoint management -# 'CheckpointManager', # Manual checkpoint system -# 'ExperimentHashGenerator', +# # Checkpoint management +# 'CheckpointManager', # Manual checkpoint system +# 'ExperimentHashGenerator', -# # Tracking -# 'Tracker', -# 
'TrackingMode', +# # Tracking +# 'Tracker', +# 'TrackingMode', -# # Monitoring - commented out until GlobalMonitoring is implemented -# # 'GlobalMonitoring', +# # Monitoring - commented out until GlobalMonitoring is implemented +# # 'GlobalMonitoring', # ] diff --git a/weightslab/components/checkpoint_manager.py b/weightslab/components/checkpoint_manager.py index ca5a8e4a..6dfca3c8 100644 --- a/weightslab/components/checkpoint_manager.py +++ b/weightslab/components/checkpoint_manager.py @@ -6,12 +6,12 @@ Directory Structure: root_log_dir/ - data/ # Data-related files (global) - logs/ # Training logs (global) + data/ # Data-related files (global) + logs/ # Training logs (global) checkpoints/ - manifest.yaml # Tracks all hashes with timestamps + manifest.yaml # Tracks all hashes with timestamps models/ - {hash}/ # 24-byte hash: HP_MODEL_DATA + {hash}/ # 24-byte hash: HP_MODEL_DATA {hash}_step_000100.pt {hash}_architecture.pkl HP/ @@ -120,12 +120,12 @@ def __init__(self, root_log_dir: str = 'root_experiment', load_model: bool = Tru self.hash_generator = ExperimentHashGenerator() self.current_exp_hash: Optional[str] = None self.previous_exp_hash: Optional[str] = None - self.hash_by_module: list = [None, None, None] # HP, MODEL, DATA + self.hash_by_module: list = [None, None, None] # HP, MODEL, DATA # Step tracking self._step_counter = None self._model_init_step = 0 - self._last_time_loaded: Optional[float] = time.time() # Track last load time for model hash uniqueness + self._last_time_loaded: Optional[float] = time.time() # Track last load time for model hash uniqueness # First time only self.first_time = True @@ -151,9 +151,9 @@ def __init__(self, root_log_dir: str = 'root_experiment', load_model: bool = Tru def __repr__(self) -> str: return ( f"CheckpointManager(\n" - f" root_log_dir={self.root_log_dir}\n" - f" current_exp_hash={self.current_exp_hash}\n" - f" step_counter={self._step_counter}\n" + f" root_log_dir={self.root_log_dir}\n" + f" 
current_exp_hash={self.current_exp_hash}\n" + f" step_counter={self._step_counter}\n" f")" ) @@ -805,7 +805,7 @@ def _save_changes( # Get checkpoint manager hp manager_hp = config.get('checkpoint_manager', {}) if config else {} enable_checkpoints = manager_hp.get('enable_checkpoints', True) - dump_model_architecture = manager_hp.get('dump_model_architecture', False) # Set to false by default + dump_model_architecture = manager_hp.get('dump_model_architecture', False) # Set to false by default dump_model_state = manager_hp.get('dump_model_state', True) dump_optimizer_state = manager_hp.get('dump_optimizer_state', True) dump_data_state = manager_hp.get('dump_data_state', True) @@ -925,7 +925,7 @@ def save_model_checkpoint( 'model_state_dict': model.state_dict(), 'timestamp': datetime.now().isoformat(), 'exp_hash': self.current_exp_hash, - 'rng_state': capture_rng_state(), # Capture RNG state for reproducible training + 'rng_state': capture_rng_state(), # Capture RNG state for reproducible training } # Capture dataloader iteration state(s) for reproducible resume (support multiple loaders) @@ -976,7 +976,7 @@ def save_model_checkpoint( # If model architecture doesn't exist in this hash directory, save a reference to where it is if self.config.get('checkpoint_manager', {}).get('dump_model_architecture', False): - self._save_architecture_reference_if_needed() # TODO (GP): Disable for now because it adds complexity for big models, and we want to ensure architecture is always saved with weights for simplicity + self._save_architecture_reference_if_needed() # TODO (GP): Disable for now because it adds complexity for big models, and we want to ensure architecture is always saved with weights for simplicity # Persist logger queues alongside weight checkpoints try: @@ -1379,7 +1379,7 @@ def _load_architecture_with_retry(self, arch_file: Path, max_retries: int = 5, b break sleep_time = base_delay * (2 ** (attempt - 1)) logger.warning( - f" [WARN] Architecture load locked 
(attempt {attempt}/{max_retries}). " + f" [WARN] Architecture load locked (attempt {attempt}/{max_retries}). " f"Retrying in {sleep_time:.2f}s..." ) time.sleep(sleep_time) @@ -1388,7 +1388,7 @@ def _load_architecture_with_retry(self, arch_file: Path, max_retries: int = 5, b if isinstance(e, EOFError): sleep_time = base_delay * (2 ** (attempt - 1)) logger.warning( - f" [WARN] Architecture load incomplete (attempt {attempt}/{max_retries}). " + f" [WARN] Architecture load incomplete (attempt {attempt}/{max_retries}). " f"Retrying in {sleep_time:.2f}s..." ) time.sleep(sleep_time) @@ -1635,8 +1635,8 @@ def load_checkpoint(self, # Logger logger.info(f"Loading checkpoint {exp_hash[:16]}...") - logger.info(f" Target: HP={target_hp_hash} MODEL={target_model_hash} DATA={target_data_hash}") - logger.info(f" Current: HP={current_hp_hash} MODEL={current_model_hash} DATA={current_data_hash}") + logger.info(f" Target: HP={target_hp_hash} MODEL={target_model_hash} DATA={target_data_hash}") + logger.info(f" Current: HP={current_hp_hash} MODEL={current_model_hash} DATA={current_data_hash}") # Load model architecture if different, or load only RNG state for reproducibility if model hash is unchanged model_rng_loaded = False @@ -1651,7 +1651,7 @@ def load_checkpoint(self, with open(arch_ref_file, 'r') as f: ref_data = json.load(f) actual_arch_hash = ref_data.get('architecture_hash', exp_hash[8:-8]) - logger.debug(f" Architecture reference found: pointing to hash {actual_arch_hash}") + logger.debug(f" Architecture reference found: pointing to hash {actual_arch_hash}") except Exception as e: logger.warning(f"Failed to load architecture reference: {e}") @@ -1669,12 +1669,12 @@ def load_checkpoint(self, result['model'].guard_testing_context = guard_testing_context result['loaded_components'].add('model') - logger.info(f" [OK] Loaded model architecture from hash {actual_arch_hash[:16]}") - self._last_time_loaded = time.time() # Update last loaded time after successful load + logger.info(f" 
[OK] Loaded model architecture from hash {actual_arch_hash[:16]}") + self._last_time_loaded = time.time() # Update last loaded time after successful load except Exception as e: - logger.error(f" [ERROR] Failed to load model architecture: {e}") + logger.error(f" [ERROR] Failed to load model architecture: {e}") else: - logger.warning(f" [WARNING] Model architecture file not found: {actual_arch_file}") + logger.warning(f" [WARNING] Model architecture file not found: {actual_arch_file}") elif load_model and (target_model_hash == current_model_hash and not force): # Try to load only the RNG state from the latest model checkpoint for reproducibility @@ -1690,14 +1690,14 @@ def load_checkpoint(self, if rng_state: result['rng_state'] = rng_state model_rng_loaded = True - logger.info(f" [OK] Loaded model RNG state for reproducibility (model unchanged)") - self._last_time_loaded = time.time() # Update last loaded time after successful load + logger.info(f" [OK] Loaded model RNG state for reproducibility (model unchanged)") + self._last_time_loaded = time.time() # Update last loaded time after successful load except Exception as e: - logger.debug(f" [WARNING] Could not load model RNG state: {e}") + logger.debug(f" [WARNING] Could not load model RNG state: {e}") if not model_rng_loaded: - logger.info(f" [-] Model architecture unchanged, using current model") + logger.info(f" [-] Model architecture unchanged, using current model") else: - logger.info(f" [-] Model architecture unchanged, using current model") + logger.info(f" [-] Model architecture unchanged, using current model") # Load model weights (always if requested) if load_weights: @@ -1712,16 +1712,16 @@ def load_checkpoint(self, checkpoint_path = model_dir / manifest_weight_checkpoint if checkpoint_path.exists(): checkpoint_file_to_load = checkpoint_path - logger.debug(f" Using weight checkpoint from manifest: {manifest_weight_checkpoint}") + logger.debug(f" Using weight checkpoint from manifest: 
{manifest_weight_checkpoint}") # Fallback: scan for weight files (old behavior for backward compatibility) if checkpoint_file_to_load is None: checkpoint_file_to_load = self._select_weight_checkpoint_file(exp_hash, target_step=target_step) if checkpoint_file_to_load is not None: if target_step is None: - logger.debug(f" Using latest weight checkpoint from directory scan: {checkpoint_file_to_load.name}") + logger.debug(f" Using latest weight checkpoint from directory scan: {checkpoint_file_to_load.name}") else: - logger.debug(f" Using closest weight checkpoint for target step {target_step}: {checkpoint_file_to_load.name}") + logger.debug(f" Using closest weight checkpoint for target step {target_step}: {checkpoint_file_to_load.name}") if checkpoint_file_to_load: try: @@ -1733,9 +1733,9 @@ def load_checkpoint(self, checkpoint_rng_state = result['weights'].get('rng_state') if checkpoint_rng_state: result['rng_state'] = checkpoint_rng_state - logger.info(f" [OK] Loaded weights from step {step} with RNG state") + logger.info(f" [OK] Loaded weights from step {step} with RNG state") else: - logger.info(f" [OK] Loaded weights from step {step}") + logger.info(f" [OK] Loaded weights from step {step}") # Extract dataloader iteration state if available dataloader_iter_state = result['weights'].get('dataloader_iteration_state') @@ -1749,12 +1749,12 @@ def load_checkpoint(self, iter_state_map = {'default': dataloader_iter_state} result['dataloader_iteration_state'] = iter_state_map - logger.debug(f" [OK] Found dataloader iteration state(s): {iter_state_map}") + logger.debug(f" [OK] Found dataloader iteration state(s): {iter_state_map}") except Exception as e: - logger.error(f" [ERROR] Failed to load weights: {e}") + logger.error(f" [ERROR] Failed to load weights: {e}") self._last_time_loaded = time.time() else: - logger.warning(f" [WARNING] No weight files found for {exp_hash[8:-8]}") + logger.warning(f" [WARNING] No weight files found for {exp_hash[8:-8]}") # Load config if 
different if load_config and (target_hp_hash != current_hp_hash or force): @@ -1767,13 +1767,13 @@ def load_checkpoint(self, config_data = yaml.safe_load(f) result['config'] = config_data.get('hyperparameters', config_data) result['loaded_components'].add('config') - logger.info(f" [OK] Loaded config (hash changed)") + logger.info(f" [OK] Loaded config (hash changed)") except Exception as e: - logger.error(f" [ERROR] Failed to load config: {e}") + logger.error(f" [ERROR] Failed to load config: {e}") else: - logger.warning(f" [WARNING] Config file not found: {config_file}") + logger.warning(f" [WARNING] Config file not found: {config_file}") else: - logger.info(f" [-] Config unchanged, using current config") + logger.info(f" [-] Config unchanged, using current config") # Load data snapshot if different, or if only RNG state changed (for reproducibility) if load_data: @@ -1798,19 +1798,19 @@ def load_checkpoint(self, result['loaded_components'].add('data') if rng_state: result['rng_state'] = rng_state - logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows) with RNG state") + logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows) with RNG state") else: - logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows)") + logger.info(f" [OK] Loaded data snapshot ({len(snapshot_df)} rows)") elif load_rng_only and rng_state: # Only RNG state is needed for reproducibility result['rng_state'] = rng_state - logger.info(f" [OK] Loaded RNG state for reproducibility (data unchanged)") + logger.info(f" [OK] Loaded RNG state for reproducibility (data unchanged)") else: - logger.info(f" [-] Data state unchanged, using current data") + logger.info(f" [-] Data state unchanged, using current data") except Exception as e: - logger.error(f" [ERROR] Failed to load data snapshot: {e}") + logger.error(f" [ERROR] Failed to load data snapshot: {e}") else: - logger.warning(f" [WARNING] Data snapshot file not found: {json_file}") + logger.warning(f" [WARNING] Data 
snapshot file not found: {json_file}") logger.info(f"Loaded components: {result['loaded_components']}") return result @@ -1863,8 +1863,8 @@ def load_state( # Apply model (architecture + weights) if 'model' in checkpoint_data['loaded_components']: try: - model = checkpoint_data['model'] # Include model architecture, weights, and optimizer state at this level - # model.update_optimizer() # Update optimizer with new model parameters if needed + model = checkpoint_data['model'] # Include model architecture, weights, and optimizer state at this level + # model.update_optimizer() # Update optimizer with new model parameters if needed # Remove existing locks if hasattr(model, 'guard_testing_context'): @@ -1876,8 +1876,8 @@ def load_state( ledgers.register_model(model) # Set Model Training Guard - guard_training_context.model = model # Train - guard_testing_context.model = model # Eval + guard_training_context.model = model # Train + guard_testing_context.model = model # Eval loaded_step = None if checkpoint_data.get('weights') is not None: @@ -1923,8 +1923,8 @@ def load_state( logger.warning(f"Could not load optimizer state: {e}") # Set Model Training Guard - guard_training_context.model = model # Train - guard_testing_context.model = model # Eval + guard_training_context.model = model # Train + guard_testing_context.model = model # Eval except Exception: if 'model' not in checkpoint_data['loaded_components']: @@ -1956,14 +1956,14 @@ def load_state( setattr(model, 'current_step', step) except Exception: pass - # model.update_optimizer() # Update optimizer with new model parameters if needed + # model.update_optimizer() # Update optimizer with new model parameters if needed logger.info(f"[OK] Applied weights to reloaded model (step {step})") self._model_init_step = step logger.info("Successfully recovered by reloading full checkpoint with architecture and weights") # Set Model Training Guard - guard_training_context.model = model # Train - guard_testing_context.model = 
model # Eval + guard_training_context.model = model # Train + guard_testing_context.model = model # Eval self.error_loading_checkpoint.remove('weights') if 'weights' in self.error_loading_checkpoint else None except Exception as e: @@ -1979,7 +1979,7 @@ def load_state( self.error_loading_checkpoint.remove('config') if 'config' in self.error_loading_checkpoint else None except Exception as e: logger.error(f"[ERROR] Failed to apply config: {e}") - self.error_loading_checkpoint.append('config') if 'config' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if config application failed + self.error_loading_checkpoint.append('config') if 'config' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if config application failed # Apply data (merge snapshot columns into current dataframe) if 'data' in checkpoint_data['loaded_components']: @@ -2001,7 +2001,7 @@ def load_state( self.error_loading_checkpoint.remove('data') if 'data' in self.error_loading_checkpoint else None except Exception as e: logger.error(f"[ERROR] Failed to apply data: {e}") - self.error_loading_checkpoint.append('data') if 'data' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if data application failed + self.error_loading_checkpoint.append('data') if 'data' not in self.error_loading_checkpoint else None # Reset first_time to allow future auto-resume attempts if data application failed # Restore RNG state if provided and not already restored if checkpoint_data.get('rng_state'): @@ -2013,13 +2013,13 @@ def load_state( # # We should re-enable this in the future once we have a more robust solution for managing dataloader iteration state across different types of dataloaders and shuffling state or not. 
# # Reset dataloaders iterators to ensure reproducibility # for loader_name in ledgers.get_dataloaders(): - # loader = ledgers.get_dataloader(loader_name) + # loader = ledgers.get_dataloader(loader_name) - # if loader is not None: - # # Resume loader state - # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): - # loader.reset_iterator() - # logger.debug(f"Reset iterator for dataloader: {loader}") + # if loader is not None: + # # Resume loader state + # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): + # loader.reset_iterator() + # logger.debug(f"Reset iterator for dataloader: {loader}") # # Restore RNG state again after resetting dataloaders # restore_rng_state(checkpoint_data['rng_state']) @@ -2034,46 +2034,46 @@ def load_state( # # We should re-enable this in the future once we have a more robust solution for managing dataloader iteration state across different types of dataloaders and shuffling state or not. # # Restore dataloader iteration state if provided # if checkpoint_data.get('dataloader_iteration_state'): - # try: - # iter_state_raw = checkpoint_data['dataloader_iteration_state'] - - # # Normalize to mapping loader_name -> state for backward compatibility - # if isinstance(iter_state_raw, dict) and 'samples_yielded' in iter_state_raw: - # state_map = {'default': iter_state_raw} - # elif isinstance(iter_state_raw, dict): - # state_map = iter_state_raw - # else: - # state_map = {'default': iter_state_raw} - - # restored_any = False - # for loader_name in ledgers.get_dataloaders(): - # loader = ledgers.get_dataloader(loader_name) - # if loader is None or not hasattr(loader, 'restore_iteration_state'): - # continue - - # state_for_loader = state_map.get(loader_name) or state_map.get('default') - # if state_for_loader: - # try: - # loader.restore_iteration_state(state_for_loader) - # # Resume loader state - # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): - # loader.reset_iterator() 
- # logger.debug(f"Reset iterator for dataloader: {loader}") - # logger.info(f"[OK] Restored dataloader iteration state for {loader_name}: {state_for_loader}") - # restored_any = True - # except Exception as inner_e: - # logger.warning(f"[WARNING] Failed to restore iteration state for {loader_name}: {inner_e}") - - # if not restored_any: - # logger.warning("No dataloader iteration state could be applied to registered loaders") - # self.error_loading_checkpoint.remove('dataloader_iteration') if 'dataloader_iteration' in self.error_loading_checkpoint else None - # except Exception as e: - # logger.error(f"[ERROR] Failed to restore dataloader iteration state: {e}") - # self.error_loading_checkpoint.append('dataloader_iteration') if 'dataloader_iteration' not in self.error_loading_checkpoint else None + # try: + # iter_state_raw = checkpoint_data['dataloader_iteration_state'] + + # # Normalize to mapping loader_name -> state for backward compatibility + # if isinstance(iter_state_raw, dict) and 'samples_yielded' in iter_state_raw: + # state_map = {'default': iter_state_raw} + # elif isinstance(iter_state_raw, dict): + # state_map = iter_state_raw + # else: + # state_map = {'default': iter_state_raw} + + # restored_any = False + # for loader_name in ledgers.get_dataloaders(): + # loader = ledgers.get_dataloader(loader_name) + # if loader is None or not hasattr(loader, 'restore_iteration_state'): + # continue + + # state_for_loader = state_map.get(loader_name) or state_map.get('default') + # if state_for_loader: + # try: + # loader.restore_iteration_state(state_for_loader) + # # Resume loader state + # if hasattr(loader, 'reset_iterator') and callable(loader.reset_iterator): + # loader.reset_iterator() + # logger.debug(f"Reset iterator for dataloader: {loader}") + # logger.info(f"[OK] Restored dataloader iteration state for {loader_name}: {state_for_loader}") + # restored_any = True + # except Exception as inner_e: + # logger.warning(f"[WARNING] Failed to restore 
iteration state for {loader_name}: {inner_e}") + + # if not restored_any: + # logger.warning("No dataloader iteration state could be applied to registered loaders") + # self.error_loading_checkpoint.remove('dataloader_iteration') if 'dataloader_iteration' in self.error_loading_checkpoint else None + # except Exception as e: + # logger.error(f"[ERROR] Failed to restore dataloader iteration state: {e}") + # self.error_loading_checkpoint.append('dataloader_iteration') if 'dataloader_iteration' not in self.error_loading_checkpoint else None # Restore logger snapshot for this experiment if available logger_len = self.get_logger_length() - if load_logger and logger_len == 0: # Load logger if requested and logger is currently empty (e.g. on fresh start) + if load_logger and logger_len == 0: # Load logger if requested and logger is currently empty (e.g. on fresh start) try: self.load_logger_snapshot() self.error_loading_checkpoint.remove('logger') if 'logger' in self.error_loading_checkpoint else None diff --git a/weightslab/components/experiment_hash.py b/weightslab/components/experiment_hash.py index 5f7ac621..48d8ef0f 100644 --- a/weightslab/components/experiment_hash.py +++ b/weightslab/components/experiment_hash.py @@ -72,15 +72,15 @@ def generate_hash( if config != -1: hp_hash = self._hash_config(config) if config is not None else "00000000" else: - hp_hash = self._last_hp_hash or "00000000" # If config is -1, keep previous HP hash to avoid marking as changed + hp_hash = self._last_hp_hash or "00000000" # If config is -1, keep previous HP hash to avoid marking as changed if model != -1: model_hash = self._hash_model(model, model_init_step=model_init_step, _last_time_loaded=_last_time_loaded) if model is not None else "00000000" else: - model_hash = self._last_model_hash or "00000000" # If model is -1, keep previous model hash to avoid marking as changed + model_hash = self._last_model_hash or "00000000" # If model is -1, keep previous model hash to avoid marking as 
changed if data_state != -1: data_hash = self._hash_data_state(data_state) if data_state is not None else "00000000" else: - data_hash = self._last_data_hash or "00000000" # If data_state is -1, keep previous data hash to avoid marking as changed + data_hash = self._last_data_hash or "00000000" # If data_state is -1, keep previous data hash to avoid marking as changed # Combine into 24-byte hash: HP (8) + MODEL (8) + DATA (8) final_hash = f"{hp_hash}{model_hash}{data_hash}" @@ -92,9 +92,9 @@ def generate_hash( self._last_data_hash = data_hash logger.info(f"Generated experiment hash: {final_hash}- (HP: {hp_hash}, Model: {model_hash}, Data: {data_hash})") - logger.debug(f" HP hash: {hp_hash}") - logger.debug(f" Model hash: {model_hash}") - logger.debug(f" Data hash: {data_hash}") + logger.debug(f" HP hash: {hp_hash}") + logger.debug(f" Model hash: {model_hash}") + logger.debug(f" Data hash: {data_hash}") return final_hash @@ -187,7 +187,7 @@ def _hash_model(self, model: th.nn.Module, model_init_step: int = 0, _last_time_ arch_info = [] # Model class name - arch_info.append(f"previously_loaded:{_last_time_loaded}") # Add a unique timestamp to ensure different hash for each load, even if architecture is the same + arch_info.append(f"previously_loaded:{_last_time_loaded}") # Add a unique timestamp to ensure different hash for each load, even if architecture is the same arch_info.append(f"class:{model.__class__.__name__}") arch_info.append(f"init_step:{int(model_init_step)}") @@ -197,7 +197,7 @@ def _hash_model(self, model: th.nn.Module, model_init_step: int = 0, _last_time_ # Remove these trackers from hash if 'train_dataset_tracker' in name or 'eval_dataset_tracker' in name: continue - if name: # Skip root module + if name: # Skip root module module_info = f"{name}:{module.__class__.__name__}" # Add key parameters for common layer types @@ -236,7 +236,7 @@ def _hash_config(self, config: Dict[str, Any]) -> str: config_cp.pop('root_log_dir', None) 
config_cp.pop('is_training', None) config_cp.pop('pause_at_step', None) - # config_cp.pop('auditor_mode', None) # Audit should be another state + # config_cp.pop('auditor_mode', None) # Audit should be another state if 'auditor_mode' not in config_cp: config_cp['auditor_mode'] = False diff --git a/weightslab/components/global_monitoring.py b/weightslab/components/global_monitoring.py index 17e9e029..4dcfa184 100644 --- a/weightslab/components/global_monitoring.py +++ b/weightslab/components/global_monitoring.py @@ -146,7 +146,7 @@ def resume(self, force: bool = False) -> bool: self.checkpoint_manager = get_checkpoint_manager() if self.checkpoint_manager != None: self.checkpoint_manager.update_experiment_hash(first_time=True) - self.checkpoint_manager.save_pending_changes() # Write pending change to disk + self.checkpoint_manager.save_pending_changes() # Write pending change to disk hash_by_module = self.checkpoint_manager.hash_by_module else: logger.warning('Cannot access checkpoint manager on resume.') @@ -229,9 +229,16 @@ def __enter__(self, f: bool = False): context = Context.TRAINING if self.for_training else Context.TESTING self._context_token = set_current_context(context) - # Update model + # Update model — always resolve from the current ledger so stale references + # from previous calls never bleed through. get_model() returns a Proxy(None) + # placeholder when nothing is registered; treat that as "no model". 
_model = get_model() - self.model = _model if _model != None else self.model + try: + _target = object.__getattribute__(_model, '_obj') + self.model = _model if _target is not None else None + except AttributeError: + # _model is a plain (non-Proxy) object; use it directly + self.model = _model # The exact logic requested by the user: if self.model is not None: @@ -291,7 +298,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any, f: bool = Fals logger.debug(f"Suppressing exception: {exc_value} in GuardContext.__exit__:") traceback.print_exc() if os.getenv("WL_DEBUG", "0") == "1" else None self.architecture_guard.__exit__(exc_type, exc_value, traceback) - return True # suppress the exception + return True # suppress the exception self.architecture_guard.__exit__(exc_type, exc_value, traceback) @@ -357,11 +364,11 @@ def _pause_hp_sync_loop(poll_interval: float = 3): # # Drive controller from ledger when ledger explicitly sets the flag # controller_running = not controller_paused # if isinstance(hp_is_training, bool): - # if controller_paused and hp_is_training: - # resumed = pause_controller.resume() - # firstresume = False if resumed else True - # elif controller_running and not hp_is_training: - # pause_controller.pause() + # if controller_paused and hp_is_training: + # resumed = pause_controller.resume() + # firstresume = False if resumed else True + # elif controller_running and not hp_is_training: + # pause_controller.pause() # Re-evaluate controller state after potential changes controller_paused = pause_controller.is_paused() @@ -386,5 +393,5 @@ def start_hp_sync_thread_event(): # Start sync thread once at module import if _pause_sync_thread_started: - _pause_sync_thread_started = False # already activated + _pause_sync_thread_started = False # already activated start_hp_sync_thread_event() diff --git a/weightslab/components/parallel_primitives.py b/weightslab/components/parallel_primitives.py index 6d1feeae..63a98ca2 100644 --- 
a/weightslab/components/parallel_primitives.py +++ b/weightslab/components/parallel_primitives.py @@ -41,7 +41,7 @@ import logging import os -from weightslab.utils import ddp_info # single source of truth for (rank, world) +from weightslab.utils import ddp_info # single source of truth for (rank, world) logger = logging.getLogger(__name__) @@ -70,7 +70,7 @@ def ddp_log(msg): print(f"[ddp r{r}/{w}] {msg}", flush=True) -_collectives = 0 # collectives since the last reset (i.e. during this step) +_collectives = 0 # collectives since the last reset (i.e. during this step) def reset_collectives(): @@ -149,7 +149,7 @@ def _gather(obj, what): # --------------------------------------------------------------------------- # State registry — reconciled in ONE bundled broadcast at the anchor. # --------------------------------------------------------------------------- -_REGISTRY = [] # (name, snapshot, apply) +_REGISTRY = [] # (name, snapshot, apply) def register_consistent_state(name, snapshot, apply): @@ -182,7 +182,7 @@ def reconcile_all(): payload = [snapshot] else: payload = [None] - bundle = _broadcast(payload, what="reconcile_all") # collective ALWAYS reached + bundle = _broadcast(payload, what="reconcile_all") # collective ALWAYS reached if r != 0 and bundle: for name, _snap, apply in _REGISTRY: if name in bundle: @@ -222,7 +222,7 @@ def clear_registry(): # payload bounded by the per-step change set, not the dataset size. Merge MUST # be idempotent (a delta may re-flush on retry / respawn). 
# --------------------------------------------------------------------------- -_OUTBOXES = [] # (name, local_dump, merge) +_OUTBOXES = [] # (name, local_dump, merge) def register_outbox(name, local_dump, merge): @@ -252,7 +252,7 @@ def flush_outbox(): except Exception as exc: payload[name] = None logger.debug("[flush_outbox] dump '%s' failed: %s", name, exc) - bucket = _gather(payload, what="flush_outbox") # collective ALWAYS reached + bucket = _gather(payload, what="flush_outbox") # collective ALWAYS reached if r != 0 or not bucket: return for name, _dump, merge in _OUTBOXES: @@ -278,10 +278,10 @@ def _ensure_core_ddp_registered(): `guard_training_context.__enter__` on first entry per process — by that point the hparam store + dataloaders + pause_controller are all wired up. - - "hparams" — rank 0's hyperparams dict; children diff-apply each leaf. + - "hparams" — rank 0's hyperparams dict; children diff-apply each leaf. - "deny-list" — {origin: discarded sample-id set} across all known loaders; children mirror via the WL discard_samples API. - - "paused" — rank 0's pause_controller.is_paused(); rides in the same + - "paused" — rank 0's pause_controller.is_paused(); rides in the same bundle so sync_step's spin uses ONE broadcast per iter. """ global _CORE_REGISTERED @@ -290,10 +290,10 @@ def _ensure_core_ddp_registered(): # Imports are deferred — this module must stay import-light + cycle-free. 
from weightslab.components.global_monitoring import pause_controller from weightslab.components.parallel_state import ( - rank0_hparams, apply_hparams, # CONFIG plane ↓ - rank0_df_down_state, apply_df_down_state, # DATAFRAME plane ↓ - local_df_writes, merge_df_writes, # DATAFRAME plane ↑ - local_signal_triples, merge_signal_triples_into_logger, # LOGGER plane ↑ + rank0_hparams, apply_hparams, # CONFIG plane ↓ + rank0_df_down_state, apply_df_down_state, # DATAFRAME plane ↓ + local_df_writes, merge_df_writes, # DATAFRAME plane ↑ + local_signal_triples, merge_signal_triples_into_logger, # LOGGER plane ↑ ) # DOWN reconcile — CONFIG + CONTROL + DATAFRAME (DOWN_ONLY cols) — 1 broadcast register_consistent_state("hparams", rank0_hparams, apply_hparams) @@ -322,11 +322,11 @@ def sync_step(spin_wait=0.5): if not _active(): return rank, _ = ddp_info() - reset_collectives() # logs prior step's count (down+up), then resets + reset_collectives() # logs prior step's count (down+up), then resets while True: - bundle = reconcile_all() # DOWN: 1 broadcast, ALL consistent states + bundle = reconcile_all() # DOWN: 1 broadcast, ALL consistent states if not bundle or not bundle.get("paused", False): - return # → step body runs; UP flush happens in __exit__ + return # → step body runs; UP flush happens in __exit__ # Paused: no busy-spin. Rank 0 blocks on the resume Event (wakes on the gRPC # resume); rank-1+ block inside the next reconcile_all broadcast. Cheap only on # gloo (socket-wait); NCCL would spin (NCCL_BLOCKING_WAIT). 
The bounded timeout diff --git a/weightslab/components/tracking.py b/weightslab/components/tracking.py index fd56aa38..a3e2c98c 100644 --- a/weightslab/components/tracking.py +++ b/weightslab/components/tracking.py @@ -211,13 +211,13 @@ def update(self, tensor: th.Tensor): # Shape is expected to be in the form [batch_size x neuron_count] if len(tensor.shape) > 2: # raise ValueError( - # f"Neuron stats are updated on a per neuron level, hence only " - # f"two dims are expected [batch_size x neuron_count] but " - # f"activation map has shape: {str(tensor.shape)}") + # f"Neuron stats are updated on a per neuron level, hence only " + # f"two dims are expected [batch_size x neuron_count] but " + # f"activation map has shape: {str(tensor.shape)}") tensor = tensor.view(-1, self.number_of_neurons) try: if tensor.shape == th.Size([]): - tensor = tensor[None, None] # Add one dim + tensor = tensor[None, None] # Add one dim bs = tensor.shape[0] self.triggrs_by_neuron += th.sum( tensor, dim=(0, )).view(-1).long() @@ -286,9 +286,9 @@ def get_neuron_age(self, neuron_id: int): return self.updates_by_neuron[neuron_id].item() # def get_neuron_stats(self, neuron_id: int): - # """ Get how often did this neuron trigger on average. """ - # return self.get_neuron_triggers(neuron_id) / \ - # max(self.get_neuron_age(neuron_id), 1) + # """ Get how often did this neuron trigger on average. """ + # return self.get_neuron_triggers(neuron_id) / \ + # max(self.get_neuron_age(neuron_id), 1) def get_neuron_stats(self, neuron_id: int): """ Get how often did this neuron trigger on average. """ @@ -409,9 +409,9 @@ def update(self, tensor: th.Tensor): super().update(tensor) # Update trackers with class and sample ids. 
The shapes are expected # to be in the following form: - # * tensor: [batch_size x neuron_count] - # * tensor.in_id_batch: [batch_size] - # * tensor.label_batch: [batch_size] + # * tensor: [batch_size x neuron_count] + # * tensor.in_id_batch: [batch_size] + # * tensor.label_batch: [batch_size] if not hasattr(tensor, 'in_id_batch') or \ not hasattr(tensor, 'label_batch'): diff --git a/weightslab/data/array_proxy.py b/weightslab/data/array_proxy.py index cc8b76db..efec4c97 100644 --- a/weightslab/data/array_proxy.py +++ b/weightslab/data/array_proxy.py @@ -138,9 +138,9 @@ class ArrayAccessor: Pandas DataFrame accessor for automatic array loading. Usage: - df.arrays.load('prediction') # Load all prediction arrays - df.arrays.load_sample(sample_id, 'prediction') # Load specific array - df.arrays.set_store(array_store) # Configure array store + df.arrays.load('prediction') # Load all prediction arrays + df.arrays.load_sample(sample_id, 'prediction') # Load specific array + df.arrays.set_store(array_store) # Configure array store """ def __init__(self, pandas_obj): diff --git a/weightslab/data/data_samples_with_ops.py b/weightslab/data/data_samples_with_ops.py index 72f92795..8236fe5f 100644 --- a/weightslab/data/data_samples_with_ops.py +++ b/weightslab/data/data_samples_with_ops.py @@ -59,7 +59,7 @@ def _match_column_patterns(col: str, patterns: list) -> bool: if re.search(pattern, col): return True except re.error: - pass # Invalid regex, skip + pass # Invalid regex, skip return False @@ -109,18 +109,18 @@ class DataSampleTrackingWrapper(Dataset): Examples: Binary classification based on tags: >>> dataset = DataSampleTrackingWrapper( - ... mnist_train, - ... root_log_dir="./logs", - ... use_tags=True, - ... tags_mapping={'huge': 1} # Images tagged 'huge' → label 1, others → 0 + ... mnist_train, + ... root_log_dir="./logs", + ... use_tags=True, + ... tags_mapping={'huge': 1} # Images tagged 'huge' → label 1, others → 0 ... 
) Multiclass classification based on tags: >>> dataset = DataSampleTrackingWrapper( - ... mnist_train, - ... root_log_dir="./logs", - ... use_tags=True, - ... tags_mapping={'small': 0, 'medium': 1, 'large': 2} + ... mnist_train, + ... root_log_dir="./logs", + ... use_tags=True, + ... tags_mapping={'small': 0, 'medium': 1, 'large': 2} ... ) """ def __init__( @@ -167,7 +167,7 @@ def __init__( # Setup H5 persistence path self._root_log_dir = Path(root_log_dir) if root_log_dir else self._resolve_root_log_dir() self._h5_path = None - self._h5_pending_uids = set() # Track UIDs with pending H5 saves + self._h5_pending_uids = set() # Track UIDs with pending H5 saves self._stats_store = stats_store self._enable_h5_persistence = enable_h5_persistence @@ -281,7 +281,7 @@ def __init__( self._map_updates_hook_fns = [] self._df_lock = threading.RLock() self.is_training = is_training - self._dataset_split = split # Store for H5 filename (can be train, test, val, validation, eval, etc.) + self._dataset_split = split # Store for H5 filename (can be train, test, val, validation, eval, etc.) # Initialize DataFrame as single source of truth # Start with defaults for all UIDs (single dict build per row to trim overhead) @@ -315,10 +315,10 @@ def __init__( # to get_items() which may load images and run heavy transforms. raw_item = wrapped_dataset.fast_get_label(p_idx) elif hasattr(wrapped_dataset, 'get_items'): - raw_item = wrapped_dataset.get_items(p_idx, include_metadata=preload_metadata, include_labels=preload_labels, include_images=False) # Try to get metadata if supported + raw_item = wrapped_dataset.get_items(p_idx, include_metadata=preload_metadata, include_labels=preload_labels, include_images=False) # Try to get metadata if supported else: # logger.warning(f"Wrapped dataset for split '{split}' does not have get_items method. Falling back to __getitem__, which may cause issues if the dataset is not designed for it. 
Consider implementing get_items for better performance and compatibility.") - raw_item = wrapped_dataset[p_idx] # By default load everything + raw_item = wrapped_dataset[p_idx] # By default load everything except Exception as e: logger.error(f"Failed to load physical index {p_idx} during initialization: {e}") continue @@ -363,7 +363,7 @@ def __init__( row = SampleStats.DEFAULTS.copy() row.update({ SampleStatsEx.SAMPLE_ID.value: sid, - # SampleStatsEx.INSTANCE_ID.value: str(0), # Added later in the preprocessing during df registration + # SampleStatsEx.INSTANCE_ID.value: str(0), # Added later in the preprocessing during df registration SampleStatsEx.ORIGIN.value: split, SampleStatsEx.GROUP_ID.value: str(group_id), SampleStatsEx.MEMBER_RANK.value: rank @@ -540,7 +540,7 @@ def _getitem_raw(self, index: int = None, id: int = None): target = self._tags_mapping[tag] break else: - target = 0 # Default to 0 if no tags match the mapping + target = 0 # Default to 0 if no tags match the mapping else: # No mapping provided but use_tags=True: keep original target logger.warning(f"use_tags=True but no tags_mapping provided for sample {id}") @@ -658,7 +658,7 @@ def _generate_unique_ids_parallel(self, dataset: Callable = None) -> np.ndarray: dataset = self.wrapped_dataset if dataset is None else dataset n_samples = len(dataset) - unique_ids = [i for i in range(n_samples)] # Initialize with indices as fallback IDs + unique_ids = [i for i in range(n_samples)] # Initialize with indices as fallback IDs unique_id_to_index = {} def compute_id(idx): @@ -681,11 +681,11 @@ def compute_id(idx): # Generate the ID uid = array_id_2bytes(data_array, return_hex=False, tronc_1byte=True) - uid = str(uid) # Convert to string for consistent handling + uid = str(uid) # Convert to string for consistent handling return idx, uid except Exception as e: logger.warning(f"Failed to generate ID for sample {idx}: {e}") - return idx, str(idx) # Fallback to index as ID + return idx, str(idx) # Fallback to index 
as ID # Use ThreadPoolExecutor; track progress on completed tasks. with ThreadPoolExecutor(thread_name_prefix="unique_id_generator") as executor: @@ -695,10 +695,10 @@ def compute_id(idx): # Collect results as they complete for future in tqdm(as_completed(futures), total=n_samples, desc="Generating unique IDs", unit="sample"): idx, uid = future.result() - uid = str(uid) # Ensure UID is a string for consistent handling + uid = str(uid) # Ensure UID is a string for consistent handling unique_ids[idx] = uid unique_id_to_index[uid] = idx if uid not in unique_id_to_index else unique_id_to_index[uid] - unique_ids = np.asanyarray(unique_ids, dtype=object) # Use object dtype for string UIDs + unique_ids = np.asanyarray(unique_ids, dtype=object) # Use object dtype for string UIDs return unique_ids, unique_id_to_index def _get_df_view(self, limit: int = -1, column: str = None, value: str = None) -> pd.DataFrame: diff --git a/weightslab/data/data_utils.py b/weightslab/data/data_utils.py index 9ce19864..9c0b256a 100644 --- a/weightslab/data/data_utils.py +++ b/weightslab/data/data_utils.py @@ -36,10 +36,10 @@ def _to_uint8_image(img_float: np.ndarray) -> np.ndarray: img = np.asarray(img_float) if img.ndim == 2: - img = img[..., None] # HxWx1 + img = img[..., None] # HxWx1 if img.shape[-1] == 1: - img = np.repeat(img, 3, axis=-1) # grayscale -> RGB + img = np.repeat(img, 3, axis=-1) # grayscale -> RGB if img.shape[-1] != 3: raise ValueError(f"Expected image with 1 or 3 channels, got shape {img.shape}") @@ -68,8 +68,8 @@ def overlay_gt_pred( pred_value=None, alpha_gt=0.45, alpha_pred=0.45, - color_gt=(0, 255, 0), # green - color_pred=(255, 0, 0), # red + color_gt=(0, 255, 0), # green + color_pred=(255, 0, 0), # red show_overlap_as_yellow=True ) -> np.ndarray: """ @@ -280,11 +280,11 @@ def get_mask(raw, dataset=None, dataset_index=None, raw_data=None): segmentation_map = np.zeros((height, width), dtype=np.int64) # Return segmentation map directly if it matches raw shape - if 
segmentation_map.shape == raw.shape[-2:]: # B, C, H, W + if segmentation_map.shape == raw.shape[-2:]: # B, C, H, W return raw # Generate segmentation map from bboxes - raw = raw[0] if raw.ndim == 3 else raw # Handle batch dimension if present + raw = raw[0] if raw.ndim == 3 else raw # Handle batch dimension if present for bbox_data in raw: x1, y1, x2, y2 = bbox_data[:4].astype(int) if bbox_data.max() > 1 else (bbox_data[:4] * [width, height, width, height]).astype(int) # Extract class id if available, otherwise use 1 @@ -344,10 +344,10 @@ def _extract_slice_from_4d(np_img: np.ndarray, slice_idx: int = None) -> np.ndar # Now we should have (Z, H, W) or (Z, H, W, C) z_dim = np_img.shape[0] if slice_idx is None: - slice_idx = z_dim // 2 # Middle slice + slice_idx = z_dim // 2 # Middle slice slice_idx = max(0, min(slice_idx, z_dim - 1)) - return np_img[slice_idx] # Returns (H, W) or (H, W, C) + return np_img[slice_idx] # Returns (H, W) or (H, W, C) def _get_image_array_and_metadata(wrapped, index, rank: int = 0) -> tuple: @@ -383,7 +383,7 @@ def _get_image_array_and_metadata(wrapped, index, rank: int = 0) -> tuple: if hasattr(np_img, 'numpy'): np_img = np_img.numpy() - is_volumetric = np_img.ndim >= 4 # 3 is for RGB; while 4 is 3D # TODO (GP): Should be fix because this will not work with grayscale image wo. color channel + is_volumetric = np_img.ndim >= 4 # 3 is for RGB; while 4 is 3D # TODO (GP): Should be fix because this will not work with grayscale image wo. color channel # For 4D volumetric data, detect and transpose channel-first formats: # 1. 
(C, Z, H, W) → (Z, H, W, C) - channels first in all dimensions @@ -421,7 +421,7 @@ def to_uint8(np_img: np.ndarray) -> np.ndarray: if np.issubdtype(np_img.dtype, np.floating): min_v = float(np.nanmin(np_img)) if np_img.size else 0.0 max_v = float(np.nanmax(np_img)) if np_img.size else 1.0 - if max_v <= 128.0: # Scale floats in [0, ~1] to [0, 255] + if max_v <= 128.0: # Scale floats in [0, ~1] to [0, 255] np_img = (np_img - min_v) / (max_v - min_v + 1e-8) * 255.0 # Clip to valid byte range then cast np_img = np.clip(np_img, 0, 255) @@ -453,30 +453,35 @@ def load_label(dataset, sample_id): # Get dataset wrapper if exists wrapped = getattr(dataset, "wrapped_dataset", dataset) + def _convert_label(lbl): + if isinstance(lbl, list) and len(lbl) and isinstance(lbl[0], (th.Tensor, np.ndarray)): + label = to_numpy_safe(lbl).max(0) # Aggr. instances + else: + label = to_numpy_safe(lbl) # Third element is typically the label + return label + # Try common dataset patterns first if hasattr(wrapped, '__getitem__'): data = wrapped.get_items(index, include_metadata=False, include_labels=True, include_images=False) if hasattr(wrapped, 'get_items') else wrapped[index] if isinstance(data, (list, tuple)): if len(data) == 1: - return None # Only data, no label - elif len(data) == 2: # Commonly (data, label) in standard PyTorch datasets - label = to_numpy_safe(data[1]) - elif len(data) == 3: # if len==3, data, uids, label, no extra info - label = to_numpy_safe(data[2]) # Third element is typically the label - elif len(data) > 3: # if len>3, data, uids, label, classes, extra info + return None # Only data, no label + elif len(data) <= 3: # if len==2|3, data, uids, label, no extra info + label = _convert_label(data[2]) + elif len(data) > 3: # if len>3, data, uids, label, classes, extra info if len(data) == 4: - label = to_numpy_safe(data[2]) # Third element is typically the label metadata = data[3] classes = to_numpy_safe(metadata['classes']) if isinstance(metadata, dict) and 'classes' in 
metadata else None if classes is not None: - label = to_numpy_safe(data[2]) # Second element is typically the label + label = _convert_label(data[2]) + # Concat label with classes if available (bbox detection, i.e., (4,) -> (5,) with class id) label = np.concatenate([label, classes[..., None]], axis=1) else: - label = to_numpy_safe(data[2]) # Second element is typically the label + label = _convert_label(data[2]) else: - label = to_numpy_safe(data[2]) # Third element is typically the label + label = _convert_label(data[2]) metadata = data[3:] if label is not None: return label[0] if label.ndim == 1 and label.shape[0] == 1 else label @@ -515,12 +520,12 @@ def load_metadata(dataset, sample_id): if isinstance(data, (list, tuple)): if len(data) == 1: - return None # Only data, no metadata - elif len(data) == 2: # if len==2, only data and uid, no extra info - return None # No metadata, only data and uid - elif len(data) == 3: # if len==3, data, uids, label, no extra info - return None # No metadata, only data, uid, and label - elif len(data) > 3: # if len>3, data, uids, label, classes, extra info + return None # Only data, no metadata + elif len(data) == 2: # if len==2, only data and uid, no extra info + return None # No metadata, only data and uid + elif len(data) == 3: # if len==3, data, uids, label, no extra info + return None # No metadata, only data, uid, and label + elif len(data) > 3: # if len>3, data, uids, label, classes, extra info metadata = {} for item in data[3:]: if isinstance(item, dict): @@ -649,7 +654,7 @@ def load_raw_image_array(dataset, index, rank: int = 0) -> tuple: elif channels == 4: middle_pil = Image.fromarray(middle_slice_uint8, mode="RGBA") else: - middle_pil = Image.fromarray(middle_slice_uint8[..., 0], mode="L") # Fallback + middle_pil = Image.fromarray(middle_slice_uint8[..., 0], mode="L") # Fallback return np_img, is_volumetric, original_shape, middle_pil @@ -705,7 +710,7 @@ def load_uid(dataset, sample_id): if isinstance(data, (list, 
tuple)): if len(data) == 1: - return None # Only data, no metadata - elif len(data) >= 2: # if len==2, only data and uid, no extra info - return data[1] # Second element is typically the uid + return None # Only data, no metadata + elif len(data) >= 2: # if len==2, only data and uid, no extra info + return data[1] # Second element is typically the uid return None diff --git a/weightslab/data/dataframe_manager.py b/weightslab/data/dataframe_manager.py index fc6c8fd3..3185d785 100644 --- a/weightslab/data/dataframe_manager.py +++ b/weightslab/data/dataframe_manager.py @@ -28,7 +28,7 @@ pd.set_option('future.no_silent_downcasting', True) -logger = logging.getLogger(__name__) # Set up logger +logger = logging.getLogger(__name__) # Set up logger def _safe_update(target: pd.DataFrame, source: pd.DataFrame) -> None: @@ -98,10 +98,10 @@ def __init__(self, flush_interval: float = 3.0, flush_max_rows: int = 100, enabl self._flush_max_rows = flush_max_rows self._flush_thread: threading.Thread | None = None self._flush_stop = threading.Event() - self._flush_event = threading.Event() # Event to wake thread for force flush + self._flush_event = threading.Event() # Event to wake thread for force flush self._flush_queue_count = 0 self._dense_store: Dict[str, Dict[int, np.ndarray]] = {} - self._buffer: Dict[int, Dict[str, Any]] = {} # {sample_id: {col: value}} + self._buffer: Dict[int, Dict[str, Any]] = {} # {sample_id: {col: value}} # Registry of categorical tags: tag_name (without "tag:" prefix) -> ordered # list of allowed category values. Distinguishes multi-value string tags # (e.g. weather -> [rainy, sunny]) from the legacy boolean tags. 
@@ -175,7 +175,7 @@ def _count_instances(target: Any) -> int: # Check if all items are scalar-like all_scalar = all(isinstance(item, (int, float, np.integer, np.floating)) for item in target) if all_scalar: - return 1 # Single instance with multiple values + return 1 # Single instance with multiple values except Exception: pass @@ -198,8 +198,8 @@ def _instance_targets_list(target: Any) -> list: A single array/tensor/scalar/label is the sample's OWN target and lives on the sample row (instance_id 0), so it yields no separate instance rows. - - list/tuple of array-likes → the list (one entry per instance, rows 1..N) - - everything else → ``[]`` (single-target / classification → only the sample row) + - list/tuple of array-likes → the list (one entry per instance, rows 1..N) + - everything else → ``[]`` (single-target / classification → only the sample row) """ if isinstance(target, (list, tuple)) and len(target) > 0 \ and isinstance(target[0], (np.ndarray, torch.Tensor, list)): @@ -249,10 +249,10 @@ def _expand_records_to_multi_index(self, records: List[Dict[str, Any]]) -> pd.Da sid = self._normalize_sample_id(rec.get(SID)) inst_targets = self._instance_targets_list(rec.get(TARGET)) n_inst = len(inst_targets) - total = n_inst + 1 # +1 for the sample row at instance_id 0 + total = n_inst + 1 # +1 for the sample row at instance_id 0 sample_ids.extend([sid] * total) - annotation_ids.extend(range(total)) # 0 (sample), 1..N (instances) + annotation_ids.extend(range(total)) # 0 (sample), 1..N (instances) for key in keys: val = rec.get(key) @@ -465,7 +465,7 @@ def _auto_register_categorical_tags(self, df: pd.DataFrame) -> None: if isinstance(s.dtype, pd.CategoricalDtype): cats = s.dtype.categories.tolist() if any(isinstance(c, bool) for c in cats): - continue # boolean-style categorical → not a categorical tag + continue # boolean-style categorical → not a categorical tag candidate = cats elif pd.api.types.is_bool_dtype(s.dtype): continue @@ -595,12 +595,12 @@ def 
_load_existing_data(self, origin: str = None, autoload_arrays: bool | list | _safe_update(self._df, loaded_df) # 2) Append rows that exist ONLY in the loaded df. This is the key - # fix: a freshly-registered loader has just the sample row - # (annotation_id == 0) per sample, while the persisted df from a - # previous run also has the instance rows (annotation_id >= 1). - # Those instance rows must be added back, not dropped. Use a - # boolean mask (not .loc[difference]) so duplicate keys can't be - # re-expanded by the label lookup. + # fix: a freshly-registered loader has just the sample row + # (annotation_id == 0) per sample, while the persisted df from a + # previous run also has the instance rows (annotation_id >= 1). + # Those instance rows must be added back, not dropped. Use a + # boolean mask (not .loc[difference]) so duplicate keys can't be + # re-expanded by the label lookup. new_rows = loaded_df[~loaded_df.index.isin(self._df.index)] if not new_rows.empty: self._df = pd.concat([self._df, new_rows]) @@ -754,7 +754,7 @@ def _is_array_column_to_norm(self, column_name: str, value: Any) -> bool: def _should_array_be_stored(self, array_name) -> bool: """Check if array storage is enabled.""" - return array_name in SAMPLES_STATS_TO_SAVE_TO_H5 # Regexed signals are not considered here + return array_name in SAMPLES_STATS_TO_SAVE_TO_H5 # Regexed signals are not considered here def _is_array_column_to_norm(self, column_name: str, value: Any) -> bool: """True if ``column_name`` is an array column whose ``value`` is an array @@ -921,7 +921,7 @@ def enqueue_batch( preds_raw: np.ndarray | dict | None, preds: np.ndarray | dict | None, losses: Dict[str, Any] | None, - targets: np.ndarray | dict | None = None, + targets: np.ndarray | dict | None = None, step: int | None = None ): """ @@ -970,9 +970,9 @@ def index_batch(obj, batch_index, rec=False): pred = index_batch(pred, batch_index) else: pred = index_batch(preds, batch_index) - pred = pred if is_meaningful(pred) else 
None # Replace nan by None + pred = pred if is_meaningful(pred) else None # Replace nan by None if pred is not None: - rec[SampleStats.Ex.PREDICTION.value] = self._normalize_preds_raw_uint16(pred) # Not normalized as already integer + rec[SampleStats.Ex.PREDICTION.value] = self._normalize_preds_raw_uint16(pred) # Not normalized as already integer ## Target if targets is not None: target = None @@ -986,9 +986,9 @@ def index_batch(obj, batch_index, rec=False): target = torch.cat((target, targets['cls'][mask]), -1) else: target = index_batch(targets, batch_index) - target = target if is_meaningful(target) else None # Replace nan by None + target = target if is_meaningful(target) else None # Replace nan by None if target is not None: - rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer + rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer ## Step if step is not None and is_meaningful(step): rec[SampleStats.Ex.LAST_SEEN.value] = int(step) @@ -1005,7 +1005,7 @@ def index_batch(obj, batch_index, rec=False): for sample_id, record in records_to_add.items(): self._buffer.setdefault(sample_id, {}).update(record) logger.debug(f"Enqueued {len(records_to_add)} records to buffer. 
Buffer size is now {len(self._buffer)}.") - should_flush = len(self._buffer) >= self._flush_max_rows or self.first_init # Check buffer size and trigger flush if needed + should_flush = len(self._buffer) >= self._flush_max_rows or self.first_init # Check buffer size and trigger flush if needed # Trigger flush outside lock if should_flush: @@ -1130,11 +1130,11 @@ def _index_target(obj, i): bid = int(targets['batch_idx'][mask].ravel()[0].item()) if sid == usid[bid]: if 'bboxes' in targets: - target = targets['bboxes'][mask][aid_i-1] # aid start to 1 for instance rows + target = targets['bboxes'][mask][aid_i-1] # aid start to 1 for instance rows if 'cls' in targets: - target = torch.cat((target, targets['cls'][mask][aid_i-1]), -1) # aid start to 1 for instance rows + target = torch.cat((target, targets['cls'][mask][aid_i-1]), -1) # aid start to 1 for instance rows if target is not None: - rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer + rec[SampleStats.Ex.TARGET.value] = self._normalize_preds_raw_uint16(target) # Not normalized as already integer else: # Nested-list targets are flattened sample-major (targets_rav), so the # i-th flat entry is this i-th instance's target — index by i, not the @@ -1279,7 +1279,7 @@ def update_by_groups_bulk(self, origin: str, group_ids: List[Any], updates_list: self._df.loc[indices, col] = val affected_ids.extend(indices) else: - if not affected_ids: # Only print once to avoid log spam + if not affected_ids: # Only print once to avoid log spam print(f"[DEBUG] Could not find gid {repr(gid)} in gid_to_indices keys. 
Sample key: {repr(list(gid_to_indices.keys())[0]) if gid_to_indices else 'None'}") if affected_ids: @@ -1720,7 +1720,7 @@ def _apply_buffer_records(self, records: List[Dict[str, Any]]): if not records: return - current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds + current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds logger.debug(f"[{current_time}] [LedgeredDataFrameManager] Applying {len(records)} buffered records to Global DataFrame.") sample_ids = [rec["sample_id"] for rec in records] @@ -1740,7 +1740,7 @@ def _apply_buffer_records(self, records: List[Dict[str, Any]]): # Mark all as pending for h5 flush (outside lock) self.mark_dirty_batch(sample_ids) - current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds + current_time = datetime.now().strftime("%H:%M:%S.%f")[:-3] # Milliseconds logger.debug(f"[{current_time}] [LedgeredDataFrameManager] Applied {len(records)} buffered records to Global DataFrame.") def _apply_buffer_records_nonblocking(self, records: List[Dict[str, Any]]): @@ -1995,8 +1995,8 @@ def _optimize_dataframe_memory(self, df: pd.DataFrame, categorical_tags: Dict[st # === 3) Categorical conversion (LAST step) === # Columns that are typically repetitive (good candidates for categorical) categorical_candidates = [ - SampleStats.Ex.ORIGIN.value, # Alias for origin (if different column name) - SampleStats.Ex.TASK_TYPE.value, # Task type (e.g. 'classification', 'segmentation') + SampleStats.Ex.ORIGIN.value, # Alias for origin (if different column name) + SampleStats.Ex.TASK_TYPE.value, # Task type (e.g. 
'classification', 'segmentation') ] for col in categorical_candidates: @@ -2021,7 +2021,7 @@ def _optimize_dataframe_memory(self, df: pd.DataFrame, categorical_tags: Dict[st n_rows = len(df) compression_ratio = n_unique / n_rows if n_rows > 0 else 1.0 - if compression_ratio < 0.5 and n_unique > 1: # Worth compressing if < 50% unique + if compression_ratio < 0.5 and n_unique > 1: # Worth compressing if < 50% unique try: df[col] = df[col].astype('category') logger.debug( @@ -2086,7 +2086,7 @@ def _worker(): # Forced when buffer is full if force_requested: - self._flush_event.clear() # Clear before flush + self._flush_event.clear() # Clear before flush self.flush() # Wait for flush event (force) or timeout (periodic) @@ -2098,11 +2098,11 @@ def _worker(): if not self._flush_stop.is_set(): self.flush_if_needed_nonblocking(force=True) - self._flush_queue_count = 0 # Reset queue count after periodic flush + self._flush_queue_count = 0 # Reset queue count after periodic flush except Exception as e: traceback_str = traceback.format_exc() logger.error(f"[LedgeredDataFrameManager] Flush loop error: {e}\n{traceback_str}") - st = time.time() # Reset start time after each loop + st = time.time() # Reset start time after each loop self._flush_thread = threading.Thread(target=_worker, name="WL-Ledger_Dataframe_Flush", daemon=True) self._flush_thread.start() @@ -2153,7 +2153,7 @@ def get_collapse_annotations_to_samples_df(self, iid: str = None) -> pd.DataFram The shared dataframe manager now expands every sample into one row per instance/annotation using a ``(sample_id, annotation_id)`` MultiIndex. The studio UI and the agent, however, are sample-centric: they expect a - single row per sample. This helper folds the annotation rows back down: + single row per sample. This helper folds the annotation rows back down: - Sample-level columns (metadata, target, prediction, tags, ...) are duplicated identically on every annotation row, so we keep the first. 
@@ -2197,7 +2197,7 @@ def get_collapse_annotations_to_samples_df(self, iid: str = None) -> pd.DataFram # straight off the index (or columns) as numpy arrays — no reindex of ``df``. if has_annot_index: if SID not in (df.index.names or []): - return df # Cannot locate the sample level — leave untouched. + return df # Cannot locate the sample level — leave untouched. sid_arr = df.index.get_level_values(SID).to_numpy() annot_arr = np.asarray(df.index.get_level_values(ANNOT).to_numpy()) else: @@ -2270,7 +2270,7 @@ def get_collapse_annotations_to_samples_df(self, iid: str = None) -> pd.DataFram vals = None for c in per_instance_cols: v = col_lists[c][i] - if v is None or v != v: # None or NaN + if v is None or v != v: # None or NaN continue if vals is None: vals = {} @@ -2332,7 +2332,7 @@ def _coerce_df_for_h5(self, df: pd.DataFrame) -> pd.DataFrame: target_dtype = SAMPLES_STATS_DEFAULTS_TYPES[col] # Handle union types (e.g., int | list, str | list) - if hasattr(target_dtype, '__origin__'): # Python 3.10+ union types + if hasattr(target_dtype, '__origin__'): # Python 3.10+ union types if hasattr(target_dtype, '__args__'): target_dtype = target_dtype.__args__[0] @@ -2391,7 +2391,7 @@ def flush_async(self): """Signal flush thread. Returns once buffer has been drained (not after H5 write). Training is only blocked for the brief buffer-drain window (~1ms), not for the - full DF→H5 write. If the buffer refills before the flush thread loops back, the + full DF→H5 write. If the buffer refills before the flush thread loops back, the next call will wait again — bounding in-memory usage to 2×flush_max_rows records. 
""" with self._queue_lock: @@ -2474,7 +2474,7 @@ def create_ledger_manager(): enable_flushing_threads=enable_flush ) except Exception: - pass # Use defaults if hyperparams not available + pass # Use defaults if hyperparams not available return None @@ -2483,6 +2483,6 @@ def create_ledger_manager(): # from weightslab.backend import ledgers # LM = create_ledger_manager() # try: -# ledgers.register_dataframe(LM) +# ledgers.register_dataframe(LM) # except Exception as e: -# logger.debug(f"Failed to register LedgeredDataFrameManager in ledger: {e}") +# logger.debug(f"Failed to register LedgeredDataFrameManager in ledger: {e}") diff --git a/weightslab/data/h5_array_store.py b/weightslab/data/h5_array_store.py index 716b0bab..346e3b86 100644 --- a/weightslab/data/h5_array_store.py +++ b/weightslab/data/h5_array_store.py @@ -23,7 +23,7 @@ # Config global logger logger = logging.getLogger(__name__) -UINT_DEFAULT = 16 # Default to uint8 for array normalization +UINT_DEFAULT = 16 # Default to uint8 for array normalization class LRUArrayCache: @@ -34,7 +34,7 @@ class LRUArrayCache: Tracks memory usage and provides cache statistics. """ - def __init__(self, max_size_bytes: int = 2 * 1024**3): # 2GB default + def __init__(self, max_size_bytes: int = 2 * 1024**3): # 2GB default """ Initialize LRU cache. 
@@ -42,7 +42,7 @@ def __init__(self, max_size_bytes: int = 2 * 1024**3): # 2GB default max_size_bytes: Maximum total memory for cached arrays in bytes """ self._max_size = max_size_bytes - self._cache = OrderedDict() # Maintains insertion/access order + self._cache = OrderedDict() # Maintains insertion/access order self._current_size = 0 self._lock = threading.RLock() self._hits = 0 @@ -93,7 +93,7 @@ def put(self, key: str, array: np.ndarray) -> None: # Evict LRU entries until there's space while self._current_size + array_size > self._max_size and self._cache: - lru_key, lru_array = self._cache.popitem(last=False) # FIFO = LRU + lru_key, lru_array = self._cache.popitem(last=False) # FIFO = LRU self._current_size -= self._array_size(lru_array) logger.debug(f"[LRUArrayCache] Evicted {lru_key} to free memory (cache size: {self._current_size / 1024**2:.1f}MB)") @@ -283,7 +283,7 @@ def normalize_array_to_uint(arr: np.ndarray, preserve_original: bool = False, ui if arr_max == arr_min: if arr_max == 0: # All zeros, can store as uint with zero values - metadata['normalized'] = False # No need to normalize if all values are the same + metadata['normalized'] = False # No need to normalize if all values are the same return np.zeros(arr.shape, dtype=uint_dtype), metadata elif arr_max < 2**uint - 1: # Constant array @@ -320,7 +320,7 @@ def denormalize_array(arr: np.ndarray, metadata: Dict[str, Any], uint: int = 16) # Denormalize from uint range arr_min = metadata['min'] arr_max = metadata['max'] - uint = metadata.get('uint', uint) # Default to 16 if not specified + uint = metadata.get('uint', uint) # Default to 16 if not specified original_dtype = np.dtype(metadata['original_dtype']) # Scale back from 0-65535 to original range @@ -361,7 +361,7 @@ def __init__( """ self._path = Path(path) self._local_lock = threading.RLock() - self._rw_lock = _ReadWriteLock() # Read-write lock for concurrent reads + self._rw_lock = _ReadWriteLock() # Read-write lock for concurrent reads 
self._lock_path = self._path.with_suffix(self._path.suffix + ".lock") self._lock_timeout = lock_timeout self._poll_interval = poll_interval diff --git a/weightslab/data/h5_dataframe_store.py b/weightslab/data/h5_dataframe_store.py index 72c13a8e..6a797a96 100644 --- a/weightslab/data/h5_dataframe_store.py +++ b/weightslab/data/h5_dataframe_store.py @@ -15,7 +15,7 @@ from weightslab.data.sample_stats import SampleStats -logger = logging.getLogger(__name__) # Initialize logger +logger = logging.getLogger(__name__) # Initialize logger # WL signal columns use dotted names (e.g. "signals.defaults.brightness"), which # PyTables flags with NaturalNameWarning because they aren't valid Python @@ -42,7 +42,7 @@ def _align_col_dtype_for_assign(existing: pd.DataFrame, source: pd.DataFrame, co Best-effort: dtype alignment must never break a merge. """ try: - src_kind = source[col].dtype.kind # 'O' object, 'b' bool, 'i'/'u'/'f' numeric + src_kind = source[col].dtype.kind # 'O' object, 'b' bool, 'i'/'u'/'f' numeric tgt_dtype = existing[col].dtype if src_kind in ("O", "b") and tgt_dtype != object and tgt_dtype.kind != src_kind: existing[col] = existing[col].astype(object) @@ -101,7 +101,7 @@ def _unlock(): except OSError: pass - self._unlock = _unlock # type: ignore[attr-defined] + self._unlock = _unlock # type: ignore[attr-defined] while True: if _try_lock(): @@ -113,7 +113,7 @@ def _unlock(): def __exit__(self, exc_type, exc_val, exc_tb): try: if hasattr(self, "_unlock"): - self._unlock() # type: ignore[attr-defined] + self._unlock() # type: ignore[attr-defined] finally: if self._fh: try: @@ -184,9 +184,9 @@ def _extract_tag_columns(self, df: pd.DataFrame) -> dict: # Detect tag type from data non_null = df[col].dropna() if non_null.empty: - tag_cols[col] = None # Auto-detect if all null + tag_cols[col] = None # Auto-detect if all null elif all(isinstance(v, bool) for v in non_null): - tag_cols[col] = [True, False] # Boolean tag + tag_cols[col] = [True, False] # Boolean tag else: 
# String tag: use unique values as categories tag_cols[col] = non_null.unique().tolist() @@ -446,9 +446,9 @@ def deserialize_value(val): # Everything was persisted as plain strings (see upsert). Reconstruct the # in-memory representation for tag/discarded columns: - # 1. missing tokens ("nan"/"none"/"") → real NaN - # 2. boolean columns ("True"/"False") → real bool (bool('False') is truthy, - # so this MUST run before any bool checks) + # 1. missing tokens ("nan"/"none"/"") → real NaN + # 2. boolean columns ("True"/"False") → real bool (bool('False') is truthy, + # so this MUST run before any bool checks) # String categorical tags keep their string values here; their categorical # dtype + full category set is restored by _optimize_categorical_tags below. _BOOL_TOKENS = {"true": True, "false": False, "1": True, "0": False} @@ -505,7 +505,7 @@ def _verify_checksum(self, store: pd.HDFStore, key: str, expected_checksum: str) try: checksum_key = f"{key}/_checksum" if checksum_key not in store: - return True # No checksum to verify + return True # No checksum to verify checksum_df = store.get(checksum_key) stored_checksum = checksum_df["checksum"].iloc[0] return stored_checksum == expected_checksum @@ -562,7 +562,7 @@ def load(self, origin: str, columns: Optional[Iterable[str]] = None, start: Opti return pd.DataFrame() df = store.select(key, start=start, stop=stop, columns=list(columns) if columns else None) except (FileNotFoundError, OSError, KeyError) as exc: - if not non_blocking: # Only warn on blocking reads + if not non_blocking: # Only warn on blocking reads logger.warning(f"[H5DataFrameStore] Failed to load {key} from {self._path}: {exc}") return pd.DataFrame() except TimeoutError: @@ -811,7 +811,7 @@ def delete_column(self, column_name: str, origins: Optional[Iterable[str]] = Non True if successful, False otherwise """ if not self._path.exists(): - return True # Nothing to delete + return True # Nothing to delete # Create backup BEFORE any modifications 
backup_path = self._create_backup() diff --git a/weightslab/data/point_cloud_utils.py b/weightslab/data/point_cloud_utils.py index 521a1fec..dc090ae8 100644 --- a/weightslab/data/point_cloud_utils.py +++ b/weightslab/data/point_cloud_utils.py @@ -5,7 +5,7 @@ dimensionality) cannot be PIL-encoded directly, so the studio pipeline previews them as a server-rendered BEV (bird's-eye-view) image: - * thumbnails / preview cache / modal image -> ``point_cloud_to_bev_image`` + * thumbnails / preview cache / modal image -> ``point_cloud_to_bev_image`` * GT / prediction boxes overlaid on the BEV -> ``project_boxes_to_bev`` (3D boxes [cx, cy, cz, dx, dy, dz, yaw, cls?, conf?] or 2D metric boxes [cx, cy, dx, dy, cls?, conf?] -> normalized [x1, y1, x2, y2, cls, conf] @@ -78,9 +78,9 @@ def _default_feature_names(num_features: int) -> list: extra = num_features - len(base) if extra == 1: base = base + ["intensity"] - elif extra == 4: # intensity + normals + elif extra == 4: # intensity + normals base = base + ["intensity", "nx", "ny", "nz"] - elif extra == 3: # normals OR rgb — ambiguous, label generically + elif extra == 3: # normals OR rgb — ambiguous, label generically base = base + ["c0", "c1", "c2"] elif extra > 0: base = base + [f"c{i}" for i in range(extra)] @@ -140,11 +140,11 @@ def compute_point_normals(points: np.ndarray, k: int = 16) -> np.ndarray: k = int(max(3, min(k, n))) tree = cKDTree(xyz) _, idx = tree.query(xyz, k=k) - neigh = xyz[idx] # [M, k, 3] + neigh = xyz[idx] # [M, k, 3] centered = neigh - neigh.mean(axis=1, keepdims=True) cov = np.einsum("mki,mkj->mij", centered, centered) / k # Smallest-eigenvector of each 3x3 covariance is the surface normal. - eigvals, eigvecs = np.linalg.eigh(cov) # ascending eigenvalues + eigvals, eigvecs = np.linalg.eigh(cov) # ascending eigenvalues normals = eigvecs[:, :, 0] # Orient toward the sensor (origin) so shading is consistent. 
flip = np.einsum("mi,mi->m", normals, -xyz) < 0 @@ -175,7 +175,7 @@ def colorize_from_image(points_xyz, image, project_fn): Args: points_xyz: [M, 3] points in the LiDAR frame. - image: [H, W, 3] uint8 camera image (e.g. KITTI image_2). + image: [H, W, 3] uint8 camera image (e.g. KITTI image_2). project_fn: callable(points_xyz) -> ([M, 2] pixel uv, [M] bool valid) mapping LiDAR points to image pixels (dataset-specific, uses the calibration). Points that fall outside the image / behind @@ -213,7 +213,7 @@ def colorize_from_image(points_xyz, image, project_fn): dtype=np.float32, ) -_BEV_BACKGROUND = (13, 17, 23) # dark slate, matches the studio dark theme +_BEV_BACKGROUND = (13, 17, 23) # dark slate, matches the studio dark theme def default_bev_image_size() -> int: @@ -340,8 +340,8 @@ def point_cloud_to_bev_image( brightness grows with point density. +x is right, +y is up. Args: - points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop; derived + points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop; derived from the points when None. image_size: output resolution (default: WL_BEV_IMAGE_SIZE env or 640). """ @@ -408,12 +408,12 @@ def point_cloud_to_range_image( - Pixel value: distance (and optionally intensity) Args: - points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. - image_height: vertical resolution (elevation bins). - image_width: horizontal resolution (azimuth bins, default 512 like KITTI). - fov_up: max elevation angle in degrees (default 3.0°). - fov_down: min elevation angle in degrees (default -25.0°, typical LiDAR). - mode: "distance" (grayscale distance), "intensity" (intensity with hue), + points: [M, 2..4] (x, y, (z), (intensity)) metric coordinates. + image_height: vertical resolution (elevation bins). + image_width: horizontal resolution (azimuth bins, default 512 like KITTI). 
+ fov_up: max elevation angle in degrees (default 3.0°). + fov_down: min elevation angle in degrees (default -25.0°, typical LiDAR). + mode: "distance" (grayscale distance), "intensity" (intensity with hue), or "distance+intensity" (default: distance as brightness, z/intensity as hue). Returns: @@ -433,8 +433,8 @@ def point_cloud_to_range_image( distance = np.sqrt(x**2 + y**2 + z**2) distance = np.maximum(distance, 1e-6) - azimuth = np.arctan2(y, x) # [-pi, pi] - elevation = np.arcsin(np.clip(z / distance, -1.0, 1.0)) # [-pi/2, pi/2] in radians + azimuth = np.arctan2(y, x) # [-pi, pi] + elevation = np.arcsin(np.clip(z / distance, -1.0, 1.0)) # [-pi/2, pi/2] in radians elevation_deg = np.degrees(elevation) # Map to image coordinates @@ -462,7 +462,7 @@ def point_cloud_to_range_image( intensity_norm = np.clip(intensity / (intensity.max() + 1e-6), 0.3, 1.0) colors = np.clip(colors * intensity_norm[:, None], 0, 255).astype(np.uint8) canvas[v, u] = colors - else: # "distance+intensity" (default) + else: # "distance+intensity" (default) # Distance as brightness (grayscale), height/intensity for hue dist_norm = distance / (distance.max() + 1e-6) z_norm = np.clip((z - np.percentile(z, 5)) / (np.percentile(z, 95) - np.percentile(z, 5) + 1e-6), 0.0, 1.0) @@ -494,10 +494,10 @@ def project_boxes_to_bev( """Project metric 3D/2D point-cloud boxes into the BEV image frame. Args: - boxes: [N, C] rows; C >= 7 -> 3D (cx, cy, cz, dx, dy, dz, yaw, + boxes: [N, C] rows; C >= 7 -> 3D (cx, cy, cz, dx, dy, dz, yaw, cls?, conf?), C <= 6 -> 2D metric (cx, cy, dx, dy, cls?, conf?). - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) of the + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) of the rendered BEV image. min_norm_size: minimum normalized box width/height (~2 px at 256) so distant pedestrians stay clickable in thumbnails. 
@@ -716,7 +716,7 @@ def pack_point_cloud(points: np.ndarray, max_points: int = 0, seed: int = 0): if max_points and pts.shape[0] > max_points: rng = np.random.default_rng(seed) keep = rng.choice(pts.shape[0], int(max_points), replace=False) - keep.sort() # preserve original ordering for cache-friendly decode + keep.sort() # preserve original ordering for cache-friendly decode pts = pts[keep] pts = np.ascontiguousarray(pts, dtype=" List[str]: """Return list of stats to save to H5, conditionally including predictions and targets.""" base_list = [ - "signals.*", # Prefix for dynamic signals - "SIGNALS.*", # Prefix for dynamic signals - "tag.*", # Prefix for dynamic TAG - "TAG.*", # Prefix for dynamic TAG + "signals.*", # Prefix for dynamic signals + "SIGNALS.*", # Prefix for dynamic signals + "tag.*", # Prefix for dynamic TAG + "TAG.*", # Prefix for dynamic TAG cls.Ex.DISCARDED.value, cls.Ex.TAG.value, diff --git a/weightslab/examples/Docker_training/docker_in_docker/Dockerfile b/weightslab/examples/Docker_training/docker_in_docker/Dockerfile index 722a0b32..cab3af85 100644 --- a/weightslab/examples/Docker_training/docker_in_docker/Dockerfile +++ b/weightslab/examples/Docker_training/docker_in_docker/Dockerfile @@ -17,10 +17,9 @@ FROM python:3.11-slim # --- System deps ------------------------------------------------------------- # - docker engine (dockerd + CLI + compose plugin + containerd): installed via # the official convenience script. We need the *daemon* here (DinD). -# - libgl1/libglib2.0-0: runtime libs for opencv-python (a weightslab dep). # - curl/ca-certificates/git: fetch the docker installer + optional dev install. 
RUN apt-get update && apt-get install -y --no-install-recommends \ - curl sudo ca-certificates git libgl1 libglib2.0-0 \ + curl sudo ca-certificates git \ && curl -fsSL https://get.docker.com | sh \ && rm -rf /var/lib/apt/lists/* diff --git a/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile b/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile index 15d43c6b..b66e2162 100644 --- a/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile +++ b/weightslab/examples/Docker_training/siblings_self_contained/Dockerfile @@ -16,9 +16,8 @@ FROM python:3.11-slim # --- System deps ------------------------------------------------------------- # docker CLI + compose plugin ONLY (no daemon — we use the host's daemon). -# libgl1/libglib2.0-0: runtime libs for opencv-python (a weightslab dep). RUN apt-get update && apt-get install -y --no-install-recommends \ - curl sudo ca-certificates gnupg libgl1 libglib2.0-0 \ + curl sudo ca-certificates gnupg \ && install -m 0755 -d /etc/apt/keyrings \ && curl -fsSL https://download.docker.com/linux/debian/gpg \ | gpg --dearmor -o /etc/apt/keyrings/docker.gpg \ diff --git a/weightslab/examples/Lightning/ws-classification/main.py b/weightslab/examples/Lightning/ws-classification/main.py index 9c81a46d..2bc9dc6e 100644 --- a/weightslab/examples/Lightning/ws-classification/main.py +++ b/weightslab/examples/Lightning/ws-classification/main.py @@ -44,7 +44,7 @@ def __init__(self, root, train=True, download=False, transform=None): root=root, train=train, download=download, - transform=None # We'll apply transform manually to track filepath + transform=None # We'll apply transform manually to track filepath ) self.transform = transform self.train = train @@ -111,7 +111,7 @@ def forward(self, x): def training_step(self, batch): with guard_training_context: x, ids, y = batch - logits = self(x) # forward pass + logits = self(x) # forward pass preds = torch.argmax(logits, dim=1) # WeightsLab tracked 
loss @@ -368,14 +368,14 @@ def main(): ) print("=" * 60) - print("🚀 STARTING TRAINING (PyTorch Lightning)") - print(f"📊 Max epochs: {max_epochs}") + print(" STARTING TRAINING (PyTorch Lightning)") + print(f" Max epochs: {max_epochs}") print( - f"⚙️ Trainer: accelerator={trainer_accelerator}, devices={trainer_devices}, " + f" Trainer: accelerator={trainer_accelerator}, devices={trainer_devices}, " f"strategy={trainer_strategy}" ) - print(f"📦 Dataset splits: train={len(_train_dataset)}, val={len(_val_dataset)}") - print(f"💾 Logs will be saved to: {log_dir}") + print(f" Dataset splits: train={len(_train_dataset)}, val={len(_val_dataset)}") + print(f" Logs will be saved to: {log_dir}") print("=" * 60 + "\n") # PyTorch Lightning Trainer @@ -383,7 +383,7 @@ def main(): # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. 
trainer = pl.Trainer( max_epochs=max_epochs, @@ -401,8 +401,8 @@ def main(): trainer.fit(L_model, train_loader, val_loader) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/PyTorch/ws-classification/config.yaml b/weightslab/examples/PyTorch/ws-classification/config.yaml index 655cff99..3a93f65e 100644 --- a/weightslab/examples/PyTorch/ws-classification/config.yaml +++ b/weightslab/examples/PyTorch/ws-classification/config.yaml @@ -8,8 +8,9 @@ training_steps_to_do: null # Set to null for infinite training until manually s compute_natural_sort: false # Experiment parameters -eval_full_to_train_steps_ratio: 100 -experiment_dump_to_train_steps_ratio: 25 +eval_full_to_train_steps_ratio: 500 # was 100 — full 10k eval was the dominant wall-clock cost +experiment_dump_to_train_steps_ratio: 250 # was 25 — frequent checkpoint dumps stalled training +write_export_ratio: 100 # Export signal history + data grid to JSON/CSV every N steps skip_checkpoint_load: false # If true restart the experiment from last state tqdm_display: true # Whether to use tqdm progress bars during training/evaluation is_training: false # Start training immediately or not @@ -35,5 +36,5 @@ data: batch_size: 16 test_loader: shuffle: false - batch_size: 16 + batch_size: 128 # was 16 — bigger eval batches => ~8x fewer eval steps per pass drop_last: false diff --git a/weightslab/examples/PyTorch/ws-classification/main.py b/weightslab/examples/PyTorch/ws-classification/main.py index 8f2ec0d0..8aea5115 100644 --- a/weightslab/examples/PyTorch/ws-classification/main.py +++ b/weightslab/examples/PyTorch/ws-classification/main.py @@ -63,7 +63,7 @@ def __init__(self, root, train=True, 
download=False, transform=None, max_samples root=root, train=train, download=download, - transform=None # We'll apply transform manually to track filepath + transform=None # We'll apply transform manually to track filepath ) except RuntimeError as e: logger.error(f"Error loading MNIST dataset: {e}") @@ -71,7 +71,7 @@ def __init__(self, root, train=True, download=False, transform=None, max_samples root=root, train=train, download=True, - transform=None # We'll apply transform manually to track filepath + transform=None # We'll apply transform manually to track filepath ) self.transform = transform self.train = train @@ -89,7 +89,7 @@ def _build_filepath_mapping(self): # For each index, construct a meaningful filepath # MNIST doesn't have original individual files, so we create virtual paths for idx in range(len(self.mnist)): - if self.max_samples is not None and idx >= self.max_samples: + if self.max_samples != None and idx >= self.max_samples: break label = self.mnist.targets[idx].item() if hasattr(self.mnist.targets[idx], 'item') else self.mnist.targets[idx] split = 'train' if self.train else 'test' @@ -105,7 +105,7 @@ def _build_filepath_mapping(self): self.filepaths[idx] = virtual_path def __len__(self): - if self.max_samples is not None: + if self.max_samples != None: return min(len(self.mnist), self.max_samples) return len(self.mnist) @@ -152,7 +152,7 @@ def train(loader, model, optimizer, criterion_mlt, device): batch_ids=ids, preds=preds ) - total_loss = loss_batch_mlt.mean() # Final scalar loss + total_loss = loss_batch_mlt.mean() # Final scalar loss # Model total_loss.backward() @@ -261,6 +261,7 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): log_dir = parameters["root_log_dir"] tqdm_display = parameters.get("tqdm_display", True) eval_full_to_train_steps_ratio = parameters.get("eval_full_to_train_steps_ratio", 50) + write_export_ratio = parameters.get("write_export_ratio", 100) enable_h5_persistence = 
parameters.get("enable_h5_persistence", True) training_steps_to_do = parameters.get("training_steps_to_do", 1000) @@ -362,16 +363,16 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): ) print("=" * 60) - print("🚀 STARTING TRAINING") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") + print(" STARTING TRAINING") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") print(f"� Dataset splits: train={len(_train_dataset)}, test={len(_test_dataset)}") - print(f"💾 Logs will be saved to: {log_dir}") + print(f" Logs will be saved to: {log_dir}") print("=" * 60 + "\n") # Setup clean progress bar with custom format if tqdm_display: train_range = tqdm.tqdm( - range(training_steps_to_do) if training_steps_to_do is not None else itertools.count(), + range(training_steps_to_do) if training_steps_to_do != None else itertools.count(), desc="Training", bar_format="{desc}: {n}/{total} [{elapsed}<{remaining}, {rate_fmt}] {bar} | {postfix}", ncols=140, @@ -379,17 +380,17 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): leave=True ) else: - train_range = range(training_steps_to_do) if training_steps_to_do is not None else itertools.count() + train_range = range(training_steps_to_do) if training_steps_to_do != None else itertools.count() # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. 
train_loss = None test_loss, test_metric = None, None - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm for train_step in train_range: - age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) + age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) # Train one step train_loss = train(train_loader, model, optimizer, train_criterion, device) @@ -406,6 +407,11 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): test_loader_len ) + # Periodic history + dataframe export (JSON/CSV snapshots to root_log_dir) + if age > 0 and age % write_export_ratio == 0: + wl.write_history() + wl.write_dataframe() + # Verbose if verbose and not tqdm_display: import sys @@ -428,9 +434,13 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len): train_range.set_postfix_str(" | ".join(postfix_parts)) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) + # Final export of signal history and data grid to root_log_dir + wl.write_history() + wl.write_dataframe() + # Keep the main thread alive to allow background serving threads to run wl.keep_serving() diff --git a/weightslab/examples/PyTorch/ws-clustering/face/data.py b/weightslab/examples/PyTorch/ws-clustering/face/data.py index 672913f4..a41eb97c 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/data.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/data.py @@ -3,17 +3,17 @@ Supported 
back-ends ------------------- -"olivetti" Olivetti Faces from sklearn (40 identities, 400 images, 64×64 grey). - Self-contained; requires only scikit-learn. Good default toy set. -"lfw" Labeled Faces in the Wild (LFW) via torchvision. Downloaded on - first use to *root*. Much larger but realistic. -"folder" Generic ImageFolder layout: root/{split}/{class_name}/*.jpg +"olivetti" Olivetti Faces from sklearn (40 identities, 400 images, 64×64 grey). + Self-contained; requires only scikit-learn. Good default toy set. +"lfw" Labeled Faces in the Wild (LFW) via torchvision. Downloaded on + first use to *root*. Much larger but realistic. +"folder" Generic ImageFolder layout: root/{split}/{class_name}/*.jpg Every sample is returned as: (image_tensor: Tensor[C,H,W], - uid: str, - label: int, - metadata: dict) + uid: str, + label: int, + metadata: dict) These map directly onto the (data, uid, target, metadata) convention used throughout the WeightsLAB kitchen examples so the same training loop works @@ -41,16 +41,16 @@ class FaceDataset(Dataset): """Unified face recognition dataset wrapper. Args: - root: Download / data root (only used for lfw / folder). - dataset_type: One of "olivetti", "lfw", "folder". - split: "train" or "test" (ignored for pre-split sources). - image_size: Spatial size; images are resized to (image_size, image_size). - train_ratio: Fraction of per-class samples used for training + root: Download / data root (only used for lfw / folder). + dataset_type: One of "olivetti", "lfw", "folder". + split: "train" or "test" (ignored for pre-split sources). + image_size: Spatial size; images are resized to (image_size, image_size). + train_ratio: Fraction of per-class samples used for training (Olivetti only). - min_images_per_class: Classes with fewer samples are discarded. - transform: Optional torchvision transform; defaults to + min_images_per_class: Classes with fewer samples are discarded. 
+ transform: Optional torchvision transform; defaults to Resize → ToTensor → Normalize([0.5], [0.5]). - seed: RNG seed for reproducible train/test splits. + seed: RNG seed for reproducible train/test splits. """ def __init__( @@ -65,12 +65,12 @@ def __init__( seed: int = 42, ): self.dataset_type = dataset_type - self.split = split - self.image_size = image_size - self.transform = transform or self._default_transform(image_size) + self.split = split + self.image_size = image_size + self.transform = transform or self._default_transform(image_size) # These are populated by each loader - self.images: Optional[np.ndarray] = None # (N, H, W) float [0,1] — Olivetti only + self.images: Optional[np.ndarray] = None # (N, H, W) float [0,1] — Olivetti only self.img_paths: Optional[np.ndarray] = None self.labels: np.ndarray = np.array([], dtype=np.int64) self.num_classes: int = 0 @@ -108,67 +108,67 @@ def _load_olivetti(self, train_ratio: float, min_images: int, seed: int): """Load and split the Olivetti Faces dataset (sklearn).""" from sklearn.datasets import fetch_olivetti_faces - data = fetch_olivetti_faces(shuffle=True, random_state=seed) - images = data.images # (400, 64, 64) float [0,1] + data = fetch_olivetti_faces(shuffle=True, random_state=seed) + images = data.images # (400, 64, 64) float [0,1] labels = data.target.astype(np.int64) # Drop classes with insufficient samples unique, counts = np.unique(labels, return_counts=True) - valid_classes = unique[counts >= min_images] - mask = np.isin(labels, valid_classes) + valid_classes = unique[counts >= min_images] + mask = np.isin(labels, valid_classes) images, labels = images[mask], labels[mask] # Remap labels to a contiguous 0…N-1 range mapping = {int(c): i for i, c in enumerate(sorted(valid_classes.tolist()))} - labels = np.array([mapping[int(l)] for l in labels], dtype=np.int64) + labels = np.array([mapping[int(l)] for l in labels], dtype=np.int64) # Per-class stratified train/test split rng = np.random.RandomState(seed) 
train_idx, test_idx = [], [] for cls in np.unique(labels): - idx = np.where(labels == cls)[0] + idx = np.where(labels == cls)[0] n_train = max(1, int(len(idx) * train_ratio)) - perm = rng.permutation(len(idx)) + perm = rng.permutation(len(idx)) train_idx.extend(idx[perm[:n_train]].tolist()) test_idx.extend(idx[perm[n_train:]].tolist()) - indices = train_idx if self.split == "train" else test_idx - self.images = images[indices] - self.labels = labels[indices] + indices = train_idx if self.split == "train" else test_idx + self.images = images[indices] + self.labels = labels[indices] self.num_classes = len(mapping) def _load_lfw(self, root: str, min_images: int, split: str): """Load LFW People via torchvision (downloads on first call).""" from torchvision.datasets import LFWPeople - split_map = {"train": "train", "test": "test", "val": "10fold"} - lfw_split = split_map.get(split, "train") - ds = LFWPeople(root=root, split=lfw_split, download=True, transform=None) + split_map = {"train": "train", "test": "test", "val": "10fold"} + lfw_split = split_map.get(split, "train") + ds = LFWPeople(root=root, split=lfw_split, download=True, transform=None) paths, lbls = zip(*ds.imgs) - lbls = np.array(lbls, dtype=np.int64) + lbls = np.array(lbls, dtype=np.int64) # Filter low-shot identities unique, counts = np.unique(lbls, return_counts=True) - valid = set(unique[counts >= min_images].tolist()) - mask = np.array([int(l) in valid for l in lbls]) + valid = set(unique[counts >= min_images].tolist()) + mask = np.array([int(l) in valid for l in lbls]) self.img_paths = np.array(paths)[mask] - lbls = lbls[mask] + lbls = lbls[mask] - mapping = {int(c): i for i, c in enumerate(sorted(valid))} - self.labels = np.array([mapping[int(l)] for l in lbls], dtype=np.int64) + mapping = {int(c): i for i, c in enumerate(sorted(valid))} + self.labels = np.array([mapping[int(l)] for l in lbls], dtype=np.int64) self.num_classes = len(mapping) def _load_folder(self, root: str, split: str): """Load from 
a torchvision ImageFolder directory.""" from torchvision.datasets import ImageFolder - split_dir = os.path.join(root, split) - ds = ImageFolder(split_dir) - paths, lbls = zip(*ds.imgs) + split_dir = os.path.join(root, split) + ds = ImageFolder(split_dir) + paths, lbls = zip(*ds.imgs) self.img_paths = list(paths) - self.labels = np.array(lbls, dtype=np.int64) + self.labels = np.array(lbls, dtype=np.int64) self.num_classes = len(ds.classes) # ---------------------------------------------------------- @@ -188,7 +188,7 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, int, Dict]: if self.dataset_type == "olivetti": from PIL import Image as PILImage - img_np = self.images[idx] # (H, W) float [0,1] + img_np = self.images[idx] # (H, W) float [0,1] img_pil = PILImage.fromarray( (img_np * 255).astype(np.uint8), mode="L" ).convert("RGB") @@ -200,9 +200,9 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, int, Dict]: uid = f"{self.split}_cls{label:04d}_idx{idx:06d}" metadata = { - "split": self.split, - "label_id": label, - "idx": idx, + "split": self.split, + "label_id": label, + "idx": idx, "dataset_type": self.dataset_type, } return image_tensor, uid, label, metadata diff --git a/weightslab/examples/PyTorch/ws-clustering/face/model.py b/weightslab/examples/PyTorch/ws-clustering/face/model.py index 5a7e51d0..5c9974f4 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/model.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/model.py @@ -3,20 +3,20 @@ Architecture ------------ -Pretrained backbone (ResNet-18 / ResNet-50 / MobileNet-V3-Small) +Pretrained backbone (ResNet-18 / ResNet-50 / MobileNet-V3-Small) ? -EmbeddingHead Linear ? BN ? ReLU ? Linear +EmbeddingHead Linear ? BN ? ReLU ? Linear ? L2-normalised D-dimensional embedding The backbone is optionally frozen so that only the lightweight head is trained -(recommended toy-example setup). The combined graph is registered with +(recommended toy-example setup). 
The combined graph is registered with WeightsLAB for model tracking. Public interface ---------------- -FaceEmbeddingModel.get_embeddings(images) ? normalised embeddings (B, D) -FaceEmbeddingModel.train_step(images, labels, ...) ? scalar loss float +FaceEmbeddingModel.get_embeddings(images) ? normalised embeddings (B, D) +FaceEmbeddingModel.train_step(images, labels, ...) ? scalar loss float """ import logging @@ -64,8 +64,8 @@ def __init__(self, backbone: nn.Module, head: EmbeddingHead): self.head = head def forward(self, x: torch.Tensor) -> torch.Tensor: - features = self.backbone(x) # (B, feature_dim) - embeddings = self.head(features) # (B, embedding_dim), L2-normalised + features = self.backbone(x) # (B, feature_dim) + embeddings = self.head(features) # (B, embedding_dim), L2-normalised return embeddings @@ -126,16 +126,16 @@ class FaceEmbeddingModel: """Wrapper that manages the backbone + head, optimiser, and WeightsLAB tracking. Args: - backbone_name: "resnet18" | "resnet50" | "mobilenet_v3_small" - embedding_dim: Output embedding dimensionality (default 128). + backbone_name: "resnet18" | "resnet50" | "mobilenet_v3_small" + embedding_dim: Output embedding dimensionality (default 128). head_hidden_dim: Hidden size of the projection MLP (default 256). - lr: Learning rate for AdamW (default 1e-3). - weight_decay: AdamW weight decay (default 1e-4). + lr: Learning rate for AdamW (default 1e-3). + weight_decay: AdamW weight decay (default 1e-4). freeze_backbone: When True, only the head's parameters receive gradients ? recommended for quick toy runs. - device: "cpu", "cuda", or "cuda:N". - pretrained: Load ImageNet-pretrained weights for the backbone. - margin: Triplet margin (default 0.3). + device: "cpu", "cuda", or "cuda:N". + pretrained: Load ImageNet-pretrained weights for the backbone. + margin: Triplet margin (default 0.3). 
""" def __init__( @@ -203,11 +203,11 @@ def __init__( f"trainable_params={n_trainable:,}" ) print( - f" Backbone : {backbone_name} (pretrained={pretrained}, frozen={freeze_backbone})\n" - f" Emb dim : {embedding_dim}\n" - f" Head dim : {head_hidden_dim}\n" - f" Trainable : {n_trainable:,} params\n" - f" Device : {self.device}" + f" Backbone : {backbone_name} (pretrained={pretrained}, frozen={freeze_backbone})\n" + f" Emb dim : {embedding_dim}\n" + f" Head dim : {head_hidden_dim}\n" + f" Trainable : {n_trainable:,} params\n" + f" Device : {self.device}" ) def _build_backbone( @@ -278,8 +278,8 @@ def train_step( """One gradient update using online batch-hard triplet mining. Args: - images: (B, C, H, W) float tensor - labels: (B,) long tensor of identity ids + images: (B, C, H, W) float tensor + labels: (B,) long tensor of identity ids batch_ids: list of sample UIDs for WeightsLAB signal logging loss_name: "triplet" (contrastive support planned) @@ -292,7 +292,7 @@ def train_step( images = images.to(self.device) labels = labels.to(self.device) - embeddings = self.net(images) # (B, D) + embeddings = self.net(images) # (B, D) # Mine hardest triplets in the batch anc_idx, pos_idx, neg_idx = mine_batch_hard(embeddings, labels) diff --git a/weightslab/examples/PyTorch/ws-clustering/face/signals.py b/weightslab/examples/PyTorch/ws-clustering/face/signals.py index 9b14f270..2ef47045 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/signals.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/signals.py @@ -6,8 +6,8 @@ Classes ------- -TripletLosses differentiable loss functions (return torch.Tensor) -FaceMetrics evaluation metrics and clustering-oriented test signals +TripletLosses differentiable loss functions (return torch.Tensor) +FaceMetrics evaluation metrics and clustering-oriented test signals """ import numpy as np diff --git a/weightslab/examples/PyTorch/ws-clustering/face/utils.py b/weightslab/examples/PyTorch/ws-clustering/face/utils.py index 
acc77706..9e9fa306 100644 --- a/weightslab/examples/PyTorch/ws-clustering/face/utils.py +++ b/weightslab/examples/PyTorch/ws-clustering/face/utils.py @@ -21,13 +21,13 @@ def pairwise_distances(embeddings: torch.Tensor, squared: bool = False) -> torch Args: embeddings: (B, D) tensor - squared: return squared L2 distances when True + squared: return squared L2 distances when True Returns: (B, B) distance matrix """ dot = torch.matmul(embeddings, embeddings.t()) - sq_norms = torch.diagonal(dot) # (B,) + sq_norms = torch.diagonal(dot) # (B,) distances = sq_norms.unsqueeze(0) - 2.0 * dot + sq_norms.unsqueeze(1) distances = distances.clamp(min=0.0) @@ -58,8 +58,8 @@ def mine_batch_hard( Args: embeddings: (B, D) detached from graph during mining - labels: (B,) integer class ids - squared: use squared L2 distances for mining + labels: (B,) integer class ids + squared: use squared L2 distances for mining Returns: anc_idx, pos_idx, neg_idx 1-D LongTensors; only valid anchors @@ -69,7 +69,7 @@ def mine_batch_hard( B = labels.shape[0] device = labels.device - same = labels.unsqueeze(0) == labels.unsqueeze(1) # (B, B) + same = labels.unsqueeze(0) == labels.unsqueeze(1) # (B, B) diff = ~same eye = torch.eye(B, dtype=torch.bool, device=device) @@ -105,12 +105,12 @@ def compute_verification_metrics( n = len(embeddings) # Pairwise L2 distances - dot = embeddings @ embeddings.T # (N, N) + dot = embeddings @ embeddings.T # (N, N) sq = np.sum(embeddings ** 2, axis=1) dist_mat = (sq[:, None] - 2.0 * dot + sq[None, :]).clip(min=0.0) dist_mat = np.sqrt(dist_mat.clip(min=1e-16)) * (dist_mat != 0.0) - same_pair = labels[:, None] == labels[None, :] # (N, N) + same_pair = labels[:, None] == labels[None, :] # (N, N) # Upper triangle only (avoid double-counting) iu = np.triu_indices(n, k=1) diff --git a/weightslab/examples/PyTorch/ws-clustering/main.py b/weightslab/examples/PyTorch/ws-clustering/main.py index d7c03615..a1f62461 100644 --- a/weightslab/examples/PyTorch/ws-clustering/main.py +++ 
b/weightslab/examples/PyTorch/ws-clustering/main.py @@ -6,9 +6,9 @@ trained with online batch-hard triplet loss on the Olivetti Faces dataset. Dataset options (set in config.yaml -> data.dataset_type): - "olivetti" - sklearn Olivetti (40 ids, 400 imgs) - works offline, default - "lfw" - LFW People via torchvision (download required) - "folder" - any ImageFolder-style directory + "olivetti" - sklearn Olivetti (40 ids, 400 imgs) - works offline, default + "lfw" - LFW People via torchvision (download required) + "folder" - any ImageFolder-style directory Training flow ------------- @@ -62,7 +62,7 @@ def evaluate( all_uids: List[str] = [] for images, uids, labels, _metadata in loader: - emb = model.get_embeddings(images) # (B, D) + emb = model.get_embeddings(images) # (B, D) all_embeddings.append(emb.numpy()) if isinstance(labels, torch.Tensor): all_labels.append(labels.numpy()) @@ -80,15 +80,15 @@ def evaluate( name=name, ) - print(f" verification_accuracy : {metrics.get('verification_accuracy', float('nan')):.4f}") - print(f" rank1_accuracy : {metrics.get('rank1_accuracy', float('nan')):.4f}") - print(f" FAR : {metrics.get('far', float('nan')):.4f}") - print(f" FRR : {metrics.get('frr', float('nan')):.4f}") - print(f" best_threshold : {metrics.get('best_threshold', float('nan')):.4f}") + print(f" verification_accuracy : {metrics.get('verification_accuracy', float('nan')):.4f}") + print(f" rank1_accuracy : {metrics.get('rank1_accuracy', float('nan')):.4f}") + print(f" FAR : {metrics.get('far', float('nan')):.4f}") + print(f" FRR : {metrics.get('frr', float('nan')):.4f}") + print(f" best_threshold : {metrics.get('best_threshold', float('nan')):.4f}") if "num_clusters" in metrics: - print(f" num_clusters : {metrics['num_clusters']:.0f}") - print(f" noise_ratio : {metrics['noise_ratio']:.4f}") - print(f" mean_nn1_distance : {metrics['mean_nn1_distance']:.4f}") + print(f" num_clusters : {metrics['num_clusters']:.0f}") + print(f" noise_ratio : 
{metrics['noise_ratio']:.4f}") + print(f" mean_nn1_distance : {metrics['mean_nn1_distance']:.4f}") return metrics @@ -111,10 +111,10 @@ def train( performed every eval_full_to_train_steps_ratio steps when test_loader is provided. """ print("\n" + "=" * 60) - print("Face Recognition Training (open-ended while loop)") - print(f" Loss : {loss_name}") - print(f" Eval every : {eval_full_to_train_steps_ratio} steps") - print(" Max steps : infinite (stop with Ctrl+C)") + print("Face Recognition Training (open-ended while loop)") + print(f" Loss : {loss_name}") + print(f" Eval every : {eval_full_to_train_steps_ratio} steps") + print(" Max steps : infinite (stop with Ctrl+C)") print("=" * 60) data_iter = iter(train_loader) @@ -183,7 +183,7 @@ def train( print("\nTraining summary:") for k, v in summary.items(): - print(f" {k}: {v}") + print(f" {k}: {v}") return summary @@ -313,16 +313,16 @@ def train( print("\n" + "=" * 60) print("STARTING FACE RECOGNITION TRAINING") - print(f" Experiment : {parameters['experiment_name']}") - print(f" Device : {device}") - print(f" Steps : infinite | eval_full_to_train_steps_ratio={eval_full_to_train_steps_ratio}") - print(f" Loss : {model_cfg.get('loss', 'triplet')}") - print(f" Logs : {parameters['root_log_dir']}") + print(f" Experiment : {parameters['experiment_name']}") + print(f" Device : {device}") + print(f" Steps : infinite | eval_full_to_train_steps_ratio={eval_full_to_train_steps_ratio}") + print(f" Loss : {model_cfg.get('loss', 'triplet')}") + print(f" Logs : {parameters['root_log_dir']}") print("=" * 60) # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. 
train( model=model, diff --git a/weightslab/examples/PyTorch/ws-detection/main.py b/weightslab/examples/PyTorch/ws-detection/main.py index 63710bd4..9ff5357e 100644 --- a/weightslab/examples/PyTorch/ws-detection/main.py +++ b/weightslab/examples/PyTorch/ws-detection/main.py @@ -49,7 +49,7 @@ def train(loader, model, optimizer, sig, device, grid_size, conf_thresh): targets = [t.to(device) for t in targets] optimizer.zero_grad() - outputs = model(inputs) # [B, S, S, 5 + num_classes] + outputs = model(inputs) # [B, S, S, 5 + num_classes] # Decoded boxes for the UI overlay (detached — display only). preds = decode_predictions(outputs.detach(), grid_size, conf_thresh=conf_thresh) @@ -90,7 +90,7 @@ def test(loader, model, sig, device, grid_size, conf_thresh, test_loader_len): loss = float((losses / test_loader_len).detach().cpu().item()) iou = float((ious / test_loader_len).detach().cpu().item()) - return loss, iou * 100.0 # Return mean IoU as percentage + return loss, iou * 100.0 # Return mean IoU as percentage # ============================================================================= @@ -111,7 +111,7 @@ def test(loader, model, sig, device, grid_size, conf_thresh, test_loader_len): parameters.setdefault("training_steps_to_do", 500) parameters.setdefault("eval_full_to_train_steps_ratio", 50) parameters.setdefault("number_of_workers", 4) - parameters.setdefault("num_classes", 1) # Penn-Fudan: single class (person) + parameters.setdefault("num_classes", 1) # Penn-Fudan: single class (person) parameters.setdefault("image_size", 256) parameters.setdefault("grid_size", 8) parameters.setdefault("conf_thresh", 0.3) @@ -255,7 +255,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): class_counts = np.zeros(num_classes, dtype=np.float64) num_samples = min(len(dataset), max_samples) - for idx in tqdm.tqdm(range(num_samples), desc="📊 Analyzing Distribution"): + for idx in tqdm.tqdm(range(num_samples), desc=" Analyzing Distribution"): _, _, target, _ = 
dataset.get_items(idx, include_labels=True) if target is None or len(target) == 0: continue @@ -263,10 +263,10 @@ def compute_class_weights(dataset, num_classes, max_samples=200): if 0 <= c < num_classes: class_counts[c] += 1 - class_counts = np.maximum(class_counts, 1) # Avoid div by zero + class_counts = np.maximum(class_counts, 1) # Avoid div by zero total = class_counts.sum() class_weights = total / (num_classes * class_counts) - class_weights = class_weights / class_weights.mean() # Normalize + class_weights = class_weights / class_weights.mean() # Normalize print("\nClass distribution and weights:", flush=True) for c in range(num_classes): @@ -287,16 +287,16 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("=" * 60) - print("🚀 STARTING PENN-FUDAN PEDESTRIAN DETECTION TRAINING") - print(f"📈 Total steps: {max_steps}") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") - print(f"💾 Logs will be saved to: {log_dir}") - print(f"📂 Data root: {data_root}") + print(" STARTING PENN-FUDAN PEDESTRIAN DETECTION TRAINING") + print(f" Total steps: {max_steps}") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") + print(f" Logs will be saved to: {log_dir}") + print(f" Data root: {data_root}") print("=" * 60 + "\n") # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. 
# ================ train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() @@ -310,7 +310,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): # Test if age == 0 or age % eval_full_to_train_steps_ratio == 0: - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating") if tqdm_display else test_loader test_loss, test_metric = test(test_loader_it, model, test_sig, device, grid_size, conf_thresh, test_loader_len) @@ -332,8 +332,8 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/PyTorch/ws-detection/utils/criterions.py b/weightslab/examples/PyTorch/ws-detection/utils/criterions.py index 700b9099..6feeab6e 100644 --- a/weightslab/examples/PyTorch/ws-detection/utils/criterions.py +++ b/weightslab/examples/PyTorch/ws-detection/utils/criterions.py @@ -13,13 +13,13 @@ # is assigned to the grid cell containing its center; that cell is "responsible" # for predicting the box. # -# * PerSampleDetectionLoss -> one differentiable loss scalar per sample ([B]), -# wrapped with ``per_sample=True`` (the value WL backprops + dashboards). -# * PerSampleIoU -> mean IoU over a sample's boxes ([B]), a metric. -# * PerInstanceIoU -> flat tensor of one IoU per GT box (sample-major -# order), wrapped with ``per_instance=True`` so WL auto-saves it at -# (sample_id, annotation_id). 
The ordering matches the per-sample target -# iteration, so the wrapper's auto ``batch_idx`` maps each value correctly. +# * PerSampleDetectionLoss -> one differentiable loss scalar per sample ([B]), +# wrapped with ``per_sample=True`` (the value WL backprops + dashboards). +# * PerSampleIoU -> mean IoU over a sample's boxes ([B]), a metric. +# * PerInstanceIoU -> flat tensor of one IoU per GT box (sample-major +# order), wrapped with ``per_instance=True`` so WL auto-saves it at +# (sample_id, annotation_id). The ordering matches the per-sample target +# iteration, so the wrapper's auto ``batch_idx`` maps each value correctly. _EPS = 1e-6 _LAMBDA_COORD = 5.0 @@ -45,12 +45,12 @@ def _responsible_cells(boxes, S): Args: boxes: [N, 4] xyxy in [0, 1]. - S: grid size. + S: grid size. Returns: - rows, cols: [N] long, the responsible cell indices. - off_x, off_y: [N] center offset within the cell, in [0, 1). - w, h: [N] box size as a fraction of the image. + rows, cols: [N] long, the responsible cell indices. + off_x, off_y: [N] center offset within the cell, in [0, 1). + w, h: [N] box size as a fraction of the image. """ cx = (boxes[:, 0] + boxes[:, 2]) / 2 cy = (boxes[:, 1] + boxes[:, 3]) / 2 @@ -69,12 +69,12 @@ def _per_sample_loss(outputs, targets, num_classes, weights=None): B, S = outputs.shape[0], outputs.shape[1] device = outputs.device - obj_logit = outputs[..., 0] # [B, S, S] + obj_logit = outputs[..., 0] # [B, S, S] tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) w_pred = torch.sigmoid(outputs[..., 3]) h_pred = torch.sigmoid(outputs[..., 4]) - cls_logits = outputs[..., 5:] # [B, S, S, C] + cls_logits = outputs[..., 5:] # [B, S, S, C] if weights is not None: weights = torch.as_tensor(weights, device=device, dtype=outputs.dtype) @@ -138,7 +138,7 @@ def _per_box_iou(outputs, targets, grid_size): Returns a list[B] of 1-D tensors (one IoU per box for that sample, in annotation order). Detached — this is a metric, not a loss. 
""" - boxes_grid, _, _ = decode_grid(outputs, grid_size) # [B, S, S, 4] + boxes_grid, _, _ = decode_grid(outputs, grid_size) # [B, S, S, 4] B = outputs.shape[0] S = grid_size device = outputs.device @@ -154,8 +154,8 @@ def _per_box_iou(outputs, targets, grid_size): gt_boxes = tgt[:, :4] rows, cols, _, _, _, _ = _responsible_cells(gt_boxes, S) - pred_boxes = boxes_grid[s, rows, cols] # [N, 4] - ious = box_iou_xyxy(pred_boxes, gt_boxes) # [N] + pred_boxes = boxes_grid[s, rows, cols] # [N, 4] + ious = box_iou_xyxy(pred_boxes, gt_boxes) # [N] per_sample.append(ious.detach()) return per_sample @@ -222,8 +222,8 @@ def decode_predictions(outputs, grid_size, conf_thresh=0.3, max_det=10): boxes_grid, obj, cls_probs = decode_grid(outputs, grid_size) B, S = outputs.shape[0], grid_size - cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] - score = obj * cls_conf # combined confidence + cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] + score = obj * cls_conf # combined confidence flat_boxes = boxes_grid.view(B, S * S, 4) flat_score = score.view(B, S * S) diff --git a/weightslab/examples/PyTorch/ws-detection/utils/data.py b/weightslab/examples/PyTorch/ws-detection/utils/data.py index 93e60dd1..dbff4a02 100644 --- a/weightslab/examples/PyTorch/ws-detection/utils/data.py +++ b/weightslab/examples/PyTorch/ws-detection/utils/data.py @@ -21,9 +21,9 @@ # box per pedestrian from each mask. Downloaded + extracted on first use. # # On-disk layout after extraction: -# /PennFudanPed/ -# PNGImages/FudanPed00001.png ... -# PedMasks/FudanPed00001_mask.png ... # pixel value k = k-th pedestrian, 0 = bg +# /PennFudanPed/ +# PNGImages/FudanPed00001.png ... +# PedMasks/FudanPed00001_mask.png ... 
# pixel value k = k-th pedestrian, 0 = bg # # WL renders detection targets/predictions from a per-sample [N, 6] array # ``[x1, y1, x2, y2, class_id, confidence]`` normalized to [0, 1] (GT conf = 1.0) @@ -73,7 +73,7 @@ def _boxes_from_mask(mask_path): mask = np.array(Image.open(mask_path)) h, w = mask.shape[:2] obj_ids = np.unique(mask) - obj_ids = obj_ids[obj_ids != 0] # drop background + obj_ids = obj_ids[obj_ids != 0] # drop background boxes = [] for oid in obj_ids: @@ -91,9 +91,9 @@ class PennFudanDetectionDataset(Dataset): """Pedestrian bounding-box detection over the Penn-Fudan images. Args: - root: directory to download/extract the dataset into. - split: "train" or "val" (deterministic split of the 170 images). - image_size: square resize fed to the model. + root: directory to download/extract the dataset into. + split: "train" or "val" (deterministic split of the 170 images). + image_size: square resize fed to the model. val_fraction: fraction of images held out for validation. max_samples: optional cap on the split size (for quick runs). """ @@ -131,7 +131,7 @@ def __init__( val_set = set(all_imgs[::k]) selected = [f for f in all_imgs if f not in val_set] - selected = selected[:max_samples] if max_samples is not None else selected + selected = selected[:max_samples] if max_samples != None else selected self.images = [] self.masks = [] @@ -165,16 +165,16 @@ def _load_boxes(self, mask_path): norm[:, [0, 2]] /= float(w) norm[:, [1, 3]] /= float(h) n = norm.shape[0] - cls = np.zeros((n, 1), dtype=np.float32) # single class: person + cls = np.zeros((n, 1), dtype=np.float32) # single class: person conf = np.ones((n, 1), dtype=np.float32) return np.concatenate([norm, cls, conf], axis=1).astype(np.float32) def __getitem__(self, idx): """Returns (item, uid, target, metadata). 
- - item: normalized image tensor [C, H, W] - - uid: unique sample id (string) - - target: [N, 6] float32 = [x1, y1, x2, y2, class_id, confidence] + - item: normalized image tensor [C, H, W] + - uid: unique sample id (string) + - target: [N, 6] float32 = [x1, y1, x2, y2, class_id, confidence] - metadata: dict with source paths """ return self.get_items(idx, include_metadata=True, include_labels=True, include_images=True) @@ -210,10 +210,10 @@ def det_collate(batch): sample's boxes in annotation order). Returns: - images: FloatTensor [B, C, H, W] - ids: list[str] of length B + images: FloatTensor [B, C, H, W] + ids: list[str] of length B targets: list[B] of [N_i, 6] float tensors ([x1, y1, x2, y2, cls, conf]) - metas: list[B] of metadata dicts + metas: list[B] of metadata dicts """ images = torch.stack([b[0] for b in batch], dim=0) ids = [b[1] for b in batch] diff --git a/weightslab/examples/PyTorch/ws-detection/utils/model.py b/weightslab/examples/PyTorch/ws-detection/utils/model.py index effbc5f0..ed204a66 100644 --- a/weightslab/examples/PyTorch/ws-detection/utils/model.py +++ b/weightslab/examples/PyTorch/ws-detection/utils/model.py @@ -6,12 +6,12 @@ # (objectness, tx, ty, tw, th, class_logits...). 
# # Encoding (all coordinates normalized to the [0, 1] image frame): -# * objectness = sigmoid(t_obj) -> P(box present in this cell) -# * cx = (col + sigmoid(tx)) / S -> box center, x -# * cy = (row + sigmoid(ty)) / S -> box center, y -# * w = sigmoid(tw) -> box width (fraction of image) -# * h = sigmoid(th) -> box height (fraction of image) -# * class = softmax(class_logits) +# * objectness = sigmoid(t_obj) -> P(box present in this cell) +# * cx = (col + sigmoid(tx)) / S -> box center, x +# * cy = (row + sigmoid(ty)) / S -> box center, y +# * w = sigmoid(tw) -> box width (fraction of image) +# * h = sigmoid(th) -> box height (fraction of image) +# * class = softmax(class_logits) # # Raw forward output keeps logits (loss applies the activations); `decode` # turns logits into xyxy boxes for metrics and UI rendering. @@ -28,12 +28,12 @@ def decode_grid(outputs, grid_size): encoding lives in exactly one place. Args: - outputs: [B, S, S, 5 + num_classes] raw logits. + outputs: [B, S, S, 5 + num_classes] raw logits. grid_size: S. Returns: - boxes: [B, S, S, 4] xyxy in [0, 1] - obj: [B, S, S] objectness probability + boxes: [B, S, S, 4] xyxy in [0, 1] + obj: [B, S, S] objectness probability cls_probs: [B, S, S, num_classes] class probabilities """ B, S, _, _ = outputs.shape @@ -86,7 +86,7 @@ def __init__( # --- Pretrained backbone (ImageNet) --- weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1 if pretrained else None backbone = mobilenet_v3_small(weights=weights) - self.backbone = backbone.features # [B, 576, H/32, W/32] + self.backbone = backbone.features # [B, 576, H/32, W/32] backbone_out_ch = 576 self.freeze_backbone = freeze_backbone @@ -121,7 +121,7 @@ def forward(self, x): feat = self.backbone(x) feat = self.neck(feat) - out = self.head(feat) # [B, preds_per_cell, S', S'] + out = self.head(feat) # [B, preds_per_cell, S', S'] # Resize feature grid to the configured grid_size. 
if out.shape[-1] != self.grid_size or out.shape[-2] != self.grid_size: diff --git a/weightslab/examples/PyTorch/ws-generation/main.py b/weightslab/examples/PyTorch/ws-generation/main.py index b1890d5b..a2315af4 100644 --- a/weightslab/examples/PyTorch/ws-generation/main.py +++ b/weightslab/examples/PyTorch/ws-generation/main.py @@ -40,22 +40,22 @@ def __init__(self, in_ch: int = 3, base: int = 4, bottleneck: int = 512, image_s # ---- Encoder ---- self.enc1_conv = nn.Conv2d(in_ch, C1, kernel_size=3, padding=1) - self.enc1_bn = nn.BatchNorm2d(C1) + self.enc1_bn = nn.BatchNorm2d(C1) self.enc1_pool = nn.MaxPool2d(2) self.enc2_conv = nn.Conv2d(C1, C2, kernel_size=3, padding=1) - self.enc2_bn = nn.BatchNorm2d(C2) + self.enc2_bn = nn.BatchNorm2d(C2) self.enc2_pool = nn.MaxPool2d(2) self.enc3_conv = nn.Conv2d(C2, C3, kernel_size=3, padding=1) - self.enc3_bn = nn.BatchNorm2d(C3) + self.enc3_bn = nn.BatchNorm2d(C3) self.enc3_pool = nn.MaxPool2d(2) # ---- Mid / Bottleneck ---- - self.mid_conv3 = nn.Conv2d(C3, C3, kernel_size=3, padding=1) - self.mid_conv5 = nn.Conv2d(C3, C3, kernel_size=5, padding=2) - self.mid_conv7 = nn.Conv2d(C3, C3, kernel_size=7, padding=3) - self.mid_bn = nn.BatchNorm2d(C3 * 3) + self.mid_conv3 = nn.Conv2d(C3, C3, kernel_size=3, padding=1) + self.mid_conv5 = nn.Conv2d(C3, C3, kernel_size=5, padding=2) + self.mid_conv7 = nn.Conv2d(C3, C3, kernel_size=7, padding=3) + self.mid_bn = nn.BatchNorm2d(C3 * 3) # NEW: Spatial path for reconstruction (preserves 2D structure) self.spatial_bottleneck = nn.Conv2d(C3 * 3, C3, kernel_size=1) @@ -66,15 +66,15 @@ def __init__(self, in_ch: int = 3, base: int = 4, bottleneck: int = 512, image_s # ---- Decoder ---- self.up1_conv = nn.Conv2d(C3, C2, kernel_size=3, padding=1) - self.up1_bn = nn.BatchNorm2d(C2) + self.up1_bn = nn.BatchNorm2d(C2) self.up2_conv = nn.Conv2d(C2, C1, kernel_size=3, padding=1) - self.up2_bn = nn.BatchNorm2d(C1) + self.up2_bn = nn.BatchNorm2d(C1) # ---- Heads ---- - self.cls_head = 
nn.Linear(bottleneck, 1) # anomaly classification - self.recon_head = nn.Conv2d(C1, in_ch, kernel_size=1) # reconstruction - self.embed_head = nn.Linear(bottleneck, 64) # contrastive embedding + self.cls_head = nn.Linear(bottleneck, 1) # anomaly classification + self.recon_head = nn.Conv2d(C1, in_ch, kernel_size=1) # reconstruction + self.embed_head = nn.Linear(bottleneck, 64) # contrastive embedding def forward(self, x): # Encoder @@ -139,11 +139,11 @@ def __init__(self, root, split="train", transform=None): print(f"Warning: split directory {split_dir} not found.") return - for folder in sorted(os.listdir(split_dir)): # sorted for determinism + for folder in sorted(os.listdir(split_dir)): # sorted for determinism folder_path = os.path.join(split_dir, folder) if not os.path.isdir(folder_path): continue - for fname in sorted(os.listdir(folder_path)): # sorted for determinism + for fname in sorted(os.listdir(folder_path)): # sorted for determinism if fname.lower().endswith(('.png', '.jpg', '.jpeg')): full_path = os.path.join(folder_path, fname) if folder == 'good': @@ -176,9 +176,9 @@ def __getitem__(self, idx): # Fully Balanced Pairing Strategy (50/50 Pairs, 50/50 Labels) # Cycle through 4 types of pairs based on idx % 4: # 0: Good + Good (Positive Contrastive, 100% Good Labels) - # 1: Bad + Bad (Positive Contrastive, 100% Bad Labels) - # 2: Good + Bad (Negative Contrastive, 50/50 Labels) - # 3: Bad + Good (Negative Contrastive, 50/50 Labels) + # 1: Bad + Bad (Positive Contrastive, 100% Bad Labels) + # 2: Good + Bad (Negative Contrastive, 50/50 Labels) + # 3: Bad + Good (Negative Contrastive, 50/50 Labels) p_type = idx % 4 if p_type == 0: @@ -220,10 +220,10 @@ def __getitem__(self, idx): group_id = f"{self.split}_pair_{uid1}_{uid2}" return ( - [img1_t, img2_t], # The pair of inputs - [idx1, idx2], # Relative indices - [label1, label2], # Individual labels - { # Metadata — "uids" causes 2 ledger rows per pair + [img1_t, img2_t], # The pair of inputs + [idx1, idx2], # 
Relative indices + [label1, label2], # Individual labels + { # Metadata — "uids" causes 2 ledger rows per pair "group_id": group_id, "uids": [uid1, uid2], } @@ -443,7 +443,7 @@ def evaluate_all(loader, model, cls_criterion, contrastive_criterion, metric, de ]) _train_ds = VADDataset(data_root, split="train", transform=transform) - _test_ds = VADDataset(data_root, split="test", transform=transform) + _test_ds = VADDataset(data_root, split="test", transform=transform) train_loader = wl.watch_or_edit( _train_ds, flag="data", loader_name="train_loader", @@ -486,7 +486,7 @@ def evaluate_all(loader, model, cls_criterion, contrastive_criterion, metric, de # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. pbar = tqdm(range(training_steps), desc="Training") for step in pbar: diff --git a/weightslab/examples/PyTorch/ws-multitask/config.yaml b/weightslab/examples/PyTorch/ws-multitask/config.yaml new file mode 100644 index 00000000..8511a026 --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/config.yaml @@ -0,0 +1,38 @@ +# Multi-task MNIST: digit classification + bounding-box localization +experiment_name: mnist_multitask +device: auto +training_steps_to_do: null # null = run until stopped +# root_log_dir: ./logs/mnist_multitask + +# Task loss weights — increase loc_loss_weight if localization is underfitting +cls_loss_weight: 1.0 +loc_loss_weight: 5.0 + +num_classes: 10 +eval_full_to_train_steps_ratio: 500 +write_export_ratio: 100 + +tqdm_display: true +compute_natural_sort: false +skip_checkpoint_load: false + +# H5 / dataframe persistence +ledger_enable_flushing_threads: true +ledger_enable_h5_persistence: true +ledger_flush_max_rows: 15000 +ledger_flush_interval: 30.0 + +# 
Services +serving_grpc: true +serving_cli: false + +optimizer: + lr: 0.001 + +data: + train_loader: + batch_size: 32 + shuffle: true + test_loader: + batch_size: 64 + shuffle: false diff --git a/weightslab/examples/PyTorch/ws-multitask/main.py b/weightslab/examples/PyTorch/ws-multitask/main.py new file mode 100644 index 00000000..95bfad40 --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/main.py @@ -0,0 +1,338 @@ +""" +Multi-task learning with WeightsLab — MNIST digit classification + localization. + +This example demonstrates how to track a multi-head model with WeightsLab: + - A shared CNN backbone feeds two heads. + - Classification head: cross-entropy over 10 digit classes. + - Localization head: Smooth-L1 regression of the digit's tight bounding box. + +Both losses are tracked separately in WeightsLab so you can: + - Compare classification vs. localization learning curves in the plots board. + - Inspect per-sample loss breakdown (hardest-to-classify vs. hardest-to-locate). + - See predicted bounding boxes overlaid on each MNIST sample in the data grid. + +WeightsLab task_type="detection" enables bbox visualization in the UI grid. 
+""" + +import itertools +import os +import ssl +import time +import logging +import tempfile + +try: + ssl.create_default_context() +except ssl.SSLError: + ssl._create_default_https_context = ssl._create_unverified_context + +import yaml +import tqdm +import torch +import torch.optim as optim +from torchvision import transforms + +import weightslab as wl +from weightslab.components.global_monitoring import ( + guard_training_context, + guard_testing_context, +) + +from utils.data import MNISTMultiTaskDataset, multitask_collate +from utils.model import MNISTMultiTaskModel +from utils.criterions import PerSampleClassificationLoss, PerSampleLocalizationLoss + +logging.basicConfig(level=logging.ERROR) + + +# ============================================================================= +# Helpers +# ============================================================================= + +def _build_preds(cls_logits, bbox_pred): + """ + Build detection-format predictions for WeightsLab UI overlay. + + Returns a list of [1, 6] tensors — one per sample — with columns: + [x1, y1, x2, y2, predicted_class, confidence] + """ + classes = cls_logits.argmax(dim=1).float() + confs = cls_logits.softmax(dim=1).max(dim=1).values + return [ + torch.stack([ + bbox_pred[i, 0], bbox_pred[i, 1], + bbox_pred[i, 2], bbox_pred[i, 3], + classes[i], confs[i], + ]).unsqueeze(0) + for i in range(len(classes)) + ] + + +# ============================================================================= +# Train / Test loops +# ============================================================================= + +def train(loader, model, optimizer, sig, device, cls_weight, loc_weight): + """Single multi-task training step.""" + with guard_training_context: + inputs, ids, targets, _ = next(loader) + inputs = inputs.to(device) + targets = [t.to(device) for t in targets] + + optimizer.zero_grad() + cls_logits, bbox_pred = model(inputs) + + preds = _build_preds(cls_logits.detach(), bbox_pred.detach()) + + 
cls_loss_per_sample = sig["cls_loss"](cls_logits, targets, batch_ids=ids, preds=preds) + loc_loss_per_sample = sig["loc_loss"](bbox_pred, targets, batch_ids=ids, preds=preds) + + loss = (cls_weight * cls_loss_per_sample + loc_weight * loc_loss_per_sample).mean() + loss.backward() + optimizer.step() + + # Per-sample classification accuracy for inspection in the data grid. + labels = torch.stack([t[0, 4].long() for t in targets]).to(device) + preds_cls = cls_logits.argmax(dim=1) + wl.save_signals( + {"cls_correct_per_sample": (preds_cls == labels).float()}, + ids, + ) + + return float(loss.detach().cpu()) + + +def test(loader, model, sig, device, cls_weight, loc_weight, loader_len): + """Full evaluation pass.""" + total_cls_loss = 0.0 + total_loc_loss = 0.0 + correct = 0 + total = 0 + + with guard_testing_context, torch.no_grad(): + for inputs, ids, targets, _ in loader: + inputs = inputs.to(device) + targets = [t.to(device) for t in targets] + + cls_logits, bbox_pred = model(inputs) + preds = _build_preds(cls_logits, bbox_pred) + + cls_loss_per_sample = sig["cls_loss"](cls_logits, targets, batch_ids=ids, preds=preds) + loc_loss_per_sample = sig["loc_loss"](bbox_pred, targets, batch_ids=ids, preds=preds) + + total_cls_loss += cls_loss_per_sample.mean().item() + total_loc_loss += loc_loss_per_sample.mean().item() + + labels = torch.stack([t[0, 4].long() for t in targets]).to(device) + preds_cls = cls_logits.argmax(dim=1) + correct += (preds_cls == labels).sum().item() + total += len(labels) + + wl.save_signals( + {"cls_correct_per_sample": (preds_cls == labels).float()}, + ids, + ) + + cls_loss = total_cls_loss / loader_len + loc_loss = total_loc_loss / loader_len + accuracy = 100.0 * correct / total if total > 0 else 0.0 + return cls_loss, loc_loss, accuracy + + +# ============================================================================= +# Main +# ============================================================================= +if __name__ == "__main__": + 
start_time = time.time() + + config_path = os.path.join(os.path.dirname(__file__), "config.yaml") + if os.path.exists(config_path): + with open(config_path, "r") as fh: + parameters = yaml.safe_load(fh) or {} + else: + parameters = {} + + parameters.setdefault("experiment_name", "mnist_multitask") + parameters.setdefault("device", "auto") + parameters.setdefault("training_steps_to_do", None) + parameters.setdefault("eval_full_to_train_steps_ratio", 500) + parameters.setdefault("write_export_ratio", 100) + parameters.setdefault("num_classes", 10) + parameters.setdefault("cls_loss_weight", 1.0) + parameters.setdefault("loc_loss_weight", 5.0) + + wl.watch_or_edit( + parameters, + flag="hyperparameters", + name=parameters["experiment_name"], + defaults=parameters, + poll_interval=1.0, + ) + + if parameters.get("device", "auto") == "auto": + parameters["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = parameters["device"] + + if not parameters.get("root_log_dir"): + tmp_dir = tempfile.mkdtemp() + parameters["root_log_dir"] = tmp_dir + print(f"No root_log_dir specified, using temporary directory: {tmp_dir}") + os.makedirs(parameters["root_log_dir"], exist_ok=True) + + log_dir = parameters["root_log_dir"] + eval_ratio = parameters["eval_full_to_train_steps_ratio"] + write_export_ratio = parameters["write_export_ratio"] + training_steps_to_do = parameters.get("training_steps_to_do") + tqdm_display = parameters.get("tqdm_display", True) + verbose = parameters.get("verbose", True) + cls_weight = float(parameters["cls_loss_weight"]) + loc_weight = float(parameters["loc_loss_weight"]) + num_classes = int(parameters["num_classes"]) + enable_h5 = parameters.get("ledger_enable_h5_persistence", True) + + # -- Data ------------------------------------------------------------------- + if parameters.get("data_root"): + data_root = parameters["data_root"] + should_download = not os.path.exists(data_root) + else: + data_root = os.path.join(log_dir, 
"data") + should_download = True + os.makedirs(data_root, exist_ok=True) + + train_cfg = parameters.get("data", {}).get("train_loader", {}) + test_cfg = parameters.get("data", {}).get("test_loader", {}) + + tf = transforms.Compose([transforms.ToTensor()]) + + _train_dataset = MNISTMultiTaskDataset( + root=data_root, train=True, download=should_download, transform=tf, + max_samples=train_cfg.get("max_samples"), + ) + _test_dataset = MNISTMultiTaskDataset( + root=data_root, train=False, download=should_download, transform=tf, + max_samples=test_cfg.get("max_samples"), + ) + + # task_type="detection" tells the UI to render bbox overlays on each sample. + train_loader = wl.watch_or_edit( + _train_dataset, + flag="data", + loader_name="train_loader", + task_type="detection", + batch_size=train_cfg.get("batch_size", 32), + shuffle=train_cfg.get("shuffle", True), + is_training=True, + compute_hash=False, + preload_labels=False, + enable_h5_persistence=enable_h5, + collate_fn=multitask_collate, + ) + test_loader = wl.watch_or_edit( + _test_dataset, + flag="data", + loader_name="test_loader", + task_type="detection", + batch_size=test_cfg.get("batch_size", 64), + shuffle=False, + is_training=False, + compute_hash=False, + preload_labels=True, + enable_h5_persistence=enable_h5, + collate_fn=multitask_collate, + ) + + # -- Model ------------------------------------------------------------------ + _model = MNISTMultiTaskModel(num_classes=num_classes).to(device) + model = wl.watch_or_edit(_model, flag="model", device=device) + + lr = parameters.get("optimizer", {}).get("lr", 1e-3) + _optimizer = optim.Adam(model.parameters(), lr=lr) + optimizer = wl.watch_or_edit(_optimizer, flag="optimizer") + + # -- Losses (two separate tracked signals) ---------------------------------- + # Tracking each loss independently lets you inspect which task is harder, + # set per-task learning rate schedules, or diagnose multi-task trade-offs. 
+ def _make_signals(split): + return { + "cls_loss": wl.watch_or_edit( + PerSampleClassificationLoss(), + flag="loss", + name=f"{split}_cls_loss", per_sample=True, log=True, + ), + "loc_loss": wl.watch_or_edit( + PerSampleLocalizationLoss(), + flag="loss", + name=f"{split}_loc_loss", per_sample=True, log=True, + ), + } + + train_sig = _make_signals("train") + test_sig = _make_signals("test") + + # -- Serving ---------------------------------------------------------------- + wl.serve( + serving_grpc=parameters.get("serving_grpc", True), + serving_cli=parameters.get("serving_cli", False), + ) + + print("=" * 60) + print(" STARTING MNIST MULTI-TASK TRAINING") + print(f" Tasks: classification (x{cls_weight}) + localization (x{loc_weight})") + print(f" Eval every {eval_ratio} steps | Export every {write_export_ratio} steps") + print(f" Train: {len(_train_dataset)} samples Test: {len(_test_dataset)} samples") + print(f" Logs: {log_dir}") + print("=" * 60 + "\n") + + wl.start_training(timeout=3) + + if tqdm_display: + train_range = tqdm.tqdm( + range(training_steps_to_do) if training_steps_to_do else itertools.count(), + desc="Training", ncols=140, + ) + else: + train_range = ( + range(training_steps_to_do) if training_steps_to_do else itertools.count() + ) + + test_cls_loss, test_loc_loss, test_acc = None, None, None + test_loader_len = len(test_loader) + + for train_step in train_range: + age = model.get_age() if hasattr(model, "get_age") else train_step + + train_loss = train(train_loader, model, optimizer, train_sig, device, cls_weight, loc_weight) + + if age == 0 or age % eval_ratio == 0: + test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating", leave=False) if tqdm_display else test_loader + test_cls_loss, test_loc_loss, test_acc = test( + test_loader_it, model, test_sig, device, cls_weight, loc_weight, test_loader_len + ) + + if age > 0 and age % write_export_ratio == 0: + wl.write_history() + wl.write_dataframe() + + if tqdm_display: + postfix = 
[f"train={train_loss:.4f}"] + if test_cls_loss is not None: + postfix.append(f"cls={test_cls_loss:.4f}") + if test_loc_loss is not None: + postfix.append(f"loc={test_loc_loss:.4f}") + if test_acc is not None: + postfix.append(f"acc={test_acc:.1f}%") + train_range.set_postfix_str(" | ".join(postfix)) + elif verbose: + msg = f"Step {train_step} (Age {age}): train={train_loss:.4f}" + if test_cls_loss is not None: + msg += f" | cls={test_cls_loss:.4f} loc={test_loc_loss:.4f} acc={test_acc:.1f}%" + print(f"\r{msg:<120}", end="", flush=True) + + print("\n" + "=" * 60) + print(f" Training completed in {time.time() - start_time:.2f}s") + print(f" Logs: {log_dir}") + print("=" * 60) + + wl.write_history() + wl.write_dataframe() + wl.keep_serving() diff --git a/weightslab/examples/PyTorch/ws-multitask/utils/__init__.py b/weightslab/examples/PyTorch/ws-multitask/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/weightslab/examples/PyTorch/ws-multitask/utils/criterions.py b/weightslab/examples/PyTorch/ws-multitask/utils/criterions.py new file mode 100644 index 00000000..dbbff6c1 --- /dev/null +++ b/weightslab/examples/PyTorch/ws-multitask/utils/criterions.py @@ -0,0 +1,34 @@ +""" +Multi-task loss functions for WeightsLab tracking. + +Both losses accept the standard WeightsLab call signature: + loss(preds_raw, targets, batch_ids=ids, preds=preds) + +where: + - preds_raw : raw model output for this head + - targets : list of [N, 6] detection tensors ([x1,y1,x2,y2,class_id,conf]) + - batch_ids : sample ids for per-sample logging + - preds : predicted boxes (list of [N,6] tensors) for UI overlay + +Both return a [B] per-sample loss tensor so WeightsLab records one value per sample. 
class PerSampleClassificationLoss(nn.Module):
    """Per-sample cross-entropy for the classification head.

    The class id of each sample is read from column 4 of the first row of its
    detection-format target tensor.
    """

    def forward(self, preds_raw, targets, batch_ids=None, preds=None):
        """Return one unreduced cross-entropy value per sample.

        Args:
            preds_raw: class logits for the batch
                (assumed ``[B, num_classes]`` -- TODO confirm against the model).
            targets: iterable of per-sample target tensors; only ``t[0, 4]``
                (the class id of the first row) is read.
            batch_ids: accepted for call-signature compatibility; unused here.
            preds: accepted for call-signature compatibility; unused here.

        Returns:
            Tensor with one loss value per sample (``reduction="none"``).
        """
        class_ids = [sample[0, 4] for sample in targets]
        # Cross-entropy needs integer class indices living on the logits' device.
        labels = torch.stack(class_ids).long().to(preds_raw.device)
        per_sample = F.cross_entropy(preds_raw, labels, reduction="none")
        return per_sample


class PerSampleLocalizationLoss(nn.Module):
    """Per-sample Smooth-L1 regression loss for the localization head.

    The ground-truth box of each sample is read from the first four columns of
    the first row of its detection-format target tensor.
    """

    def forward(self, preds_raw, targets, batch_ids=None, preds=None):
        """Return the Smooth-L1 loss averaged over the box coordinates.

        Args:
            preds_raw: predicted boxes for the batch
                (assumed ``[B, 4]`` -- TODO confirm against the model).
            targets: iterable of per-sample target tensors; only ``t[0, :4]``
                (the box of the first row) is read.
            batch_ids: accepted for call-signature compatibility; unused here.
            preds: accepted for call-signature compatibility; unused here.

        Returns:
            Tensor with one loss value per sample (mean over ``dim=1``).
        """
        boxes = [sample[0, :4] for sample in targets]
        gt_boxes = torch.stack(boxes).to(preds_raw.device)
        elementwise = F.smooth_l1_loss(preds_raw, gt_boxes, reduction="none")
        # Collapse the coordinate axis so exactly one value remains per sample.
        return elementwise.mean(dim=1)
class MNISTMultiTaskDataset(Dataset):
    """MNIST digits paired with a tight bounding box synthesized from pixel intensity.

    Each item is ``(image, idx, target)`` where ``target`` is a ``[1, 6]`` float
    tensor ``[x1, y1, x2, y2, class_id, confidence]`` with box coordinates
    normalized to ``[0, 1]``.
    """

    def __init__(self, root, train=True, download=True, transform=None, max_samples=None):
        """Load MNIST and optionally cap the number of exposed samples.

        Args:
            root: directory handed to ``datasets.MNIST``.
            train: select the train split (``True``) or the test split.
            download: forwarded to ``datasets.MNIST`` on the first attempt.
            transform: callable applied to each raw image; defaults to
                ``transforms.ToTensor()``.
            max_samples: when truthy, the dataset length is capped at this value.
        """
        try:
            self._mnist = datasets.MNIST(root=root, train=train, download=download, transform=None)
        except RuntimeError:
            # Presumably raised when the files are missing locally and
            # download=False; retry once with the download forced on.
            self._mnist = datasets.MNIST(root=root, train=train, download=True, transform=None)

        self.transform = transform or transforms.ToTensor()
        self.max_samples = max_samples
        self._length = min(len(self._mnist), max_samples) if max_samples else len(self._mnist)

    def __len__(self):
        """Return the number of exposed samples (after the ``max_samples`` cap)."""
        return self._length

    def _compute_bbox(self, img_tensor):
        """Return ``(x1, y1, x2, y2)`` normalized to [0, 1] for the digit's tight bbox.

        Pixels brighter than 0.1 count as signal. The near edge is the first
        signal row/column; the far edge is one past the last signal row/column,
        so the box covers that pixel and a fully lit image maps to
        ``(0, 0, 1, 1)`` -- the same value used as fallback for an empty image.

        Args:
            img_tensor: image tensor whose last two dims are ``(H, W)``
                (assumes a single leading channel dim, i.e. ``[1, H, W]`` --
                TODO confirm for custom transforms).
        """
        mask = img_tensor.squeeze(0) > 0.1
        if not mask.any():
            # No signal at all: fall back to the whole image.
            return 0.0, 0.0, 1.0, 1.0

        rows_with_signal = mask.any(dim=1).nonzero(as_tuple=True)[0]
        cols_with_signal = mask.any(dim=0).nonzero(as_tuple=True)[0]

        H, W = img_tensor.shape[-2], img_tensor.shape[-1]
        y1 = float(rows_with_signal.min()) / H
        x1 = float(cols_with_signal.min()) / W
        # "+ 1": pixel index k spans [k / size, (k + 1) / size]; without it the
        # box stops one pixel short and degenerates for 1-pixel-wide content.
        y2 = float(rows_with_signal.max() + 1) / H
        x2 = float(cols_with_signal.max() + 1) / W
        return x1, y1, x2, y2

    def __getitem__(self, idx):
        """Return ``(image, idx, target)`` where target is a ``[1, 6]`` detection tensor.

        Negative indices are resolved against ``len(self)``.

        Raises:
            IndexError: if ``idx`` falls outside ``[-len(self), len(self))``.
                This enforces the ``max_samples`` cap and lets plain
                ``for sample in dataset`` iteration terminate at ``len(self)``.
        """
        if idx < 0:
            idx += self._length
        if not 0 <= idx < self._length:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {self._length}"
            )

        image, label = self._mnist[idx]
        image = self.transform(image)
        x1, y1, x2, y2 = self._compute_bbox(image)
        target = torch.tensor(
            [[x1, y1, x2, y2, float(label), 1.0]], dtype=torch.float32
        )
        return image, idx, target


def multitask_collate(batch):
    """Collate ``(image, idx, target)`` samples without merging the targets.

    Args:
        batch: sequence of ``(image, idx, target)`` tuples.

    Returns:
        ``(images, ids, targets, meta)``: images stacked along a new batch dim,
        ids as a ``torch.long`` tensor, targets kept as a list of per-sample
        tensors, and an empty dict in the fourth slot.
    """
    images, ids, targets = zip(*batch)
    return torch.stack(images), torch.tensor(ids, dtype=torch.long), list(targets), {}
class MNISTMultiTaskModel(nn.Module):
    """
    Two-headed CNN for MNIST built on one shared convolutional trunk:
        - cls_head: produces ``num_classes`` digit logits
        - loc_head: produces 4 sigmoid-squashed box values (x1, y1, x2, y2)
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Three conv stages; only the first two are followed by 2x2 max-pooling.
        trunk = []
        in_channels = 1
        for out_channels, pooled in ((32, True), (64, True), (128, False)):
            trunk.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
            trunk.append(nn.BatchNorm2d(out_channels))
            trunk.append(nn.ReLU())
            if pooled:
                trunk.append(nn.MaxPool2d(2))
            in_channels = out_channels
        # Fixed 4x4 spatial output so the heads see a constant feature size.
        trunk.append(nn.AdaptiveAvgPool2d(4))
        self.backbone = nn.Sequential(*trunk)

        feat_dim = 128 * 4 * 4

        self.cls_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )
        self.loc_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the shared trunk once and return ``(class_logits, boxes)``."""
        shared = self.backbone(x)
        class_logits = self.cls_head(shared)
        boxes = self.loc_head(shared)
        return class_logits, boxes
c3a4eb48..21eb0e0c 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/main.py +++ b/weightslab/examples/PyTorch/ws-segmentation/main.py @@ -45,13 +45,13 @@ def _instance_batch_idx(labels): def _run_instance_signals(sig, outputs, labels, ids, preds, return_metric=False): """Compute + log/save the per-sample AND per-instance Dice (metric) and BCE (loss).""" bce_sample = sig["bce_sample"](outputs, labels, batch_ids=ids, preds=preds) - dice_sample = sig["dice_sample"](outputs, labels, batch_ids=ids) # Register processed predictions one time only + dice_sample = sig["dice_sample"](outputs, labels, batch_ids=ids) # Register processed predictions one time only - sig["dice_instance"](outputs, labels, batch_ids=ids) # Register processed predictions one time only + sig["dice_instance"](outputs, labels, batch_ids=ids) # Register processed predictions one time only sig["bce_instance"](outputs, labels, batch_ids=ids) avg_loss = 0.5 * dice_sample + 0.5 * bce_sample - wl.save_signals({"combined_bce_dice_per_sample": avg_loss}, ids) # Save the per-sample aggregate loss for backward step + wl.save_signals({"combined_bce_dice_per_sample": avg_loss}, ids) # Save the per-sample aggregate loss for backward step if return_metric: return avg_loss, dice_sample return avg_loss @@ -91,11 +91,11 @@ def train(loader, model, optimizer, sig, device): with guard_training_context: (inputs, ids, labels, _) = next(loader) inputs = inputs.to(device) - labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances + labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances optimizer.zero_grad() - outputs = model(inputs) # [B,C,H,W] - preds = outputs.argmax(dim=1) # [B,H,W] + outputs = model(inputs) # [B,C,H,W] + preds = outputs.argmax(dim=1) # [B,H,W] # Per-instance + per-sample Dice/BCE (tracked & saved at annotation level). 
loss_per_sample = _run_instance_signals(sig, outputs, labels, ids, preds=preds) @@ -110,7 +110,7 @@ def train(loader, model, optimizer, sig, device): wl.save_signals( _user_custom_signals(preds, labels), ids - ) # Save the per-sample predictions for visualization + ) # Save the per-sample predictions for visualization return float(loss.detach().cpu().item()) @@ -122,23 +122,23 @@ def test(loader, model, sig, device, test_loader_len): with guard_testing_context, torch.no_grad(): for inputs, ids, labels, _ in loader: inputs = inputs.to(device) - labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances + labels = [[m.to(device) for m in insts] for insts in labels] # per-sample list of instances outputs = model(inputs) - preds = outputs.argmax(dim=1) # [B,H,W] + preds = outputs.argmax(dim=1) # [B,H,W] # Per-instance + per-sample Dice/BCE (tracked & saved at annotation level). loss_per_sample, dice_sample = _run_instance_signals(sig, outputs, labels, ids, preds=preds, return_metric=True) - losses += torch.mean(loss_per_sample) # Average over the batch and accumulate - dices += torch.mean(dice_sample) # Average over the batch and accumulate + losses += torch.mean(loss_per_sample) # Average over the batch and accumulate + dices += torch.mean(dice_sample) # Average over the batch and accumulate # I want to see in the UI the per-sample classes predicted by the model - wl.save_signals(_user_custom_signals(preds, labels), ids) # Save the per-sample predictions for visualization + wl.save_signals(_user_custom_signals(preds, labels), ids) # Save the per-sample predictions for visualization loss = float((losses / test_loader_len).detach().cpu().item()) dice = float((dices / test_loader_len).detach().cpu().item()) - return loss, dice * 100.0 # Return average Dice as percentage + return loss, dice * 100.0 # Return average Dice as percentage # ============================================================================= @@ -159,8 +159,8 @@ def 
test(loader, model, sig, device, test_loader_len): parameters.setdefault("training_steps_to_do", 500) parameters.setdefault("eval_full_to_train_steps_ratio", 50) parameters.setdefault("number_of_workers", 4) - parameters.setdefault("num_classes", 6) # adjust to your label set - parameters.setdefault("ignore_index", 255) # if you have void pixels + parameters.setdefault("num_classes", 6) # adjust to your label set + parameters.setdefault("ignore_index", 255) # if you have void pixels parameters.setdefault("image_size", 256) parameters.setdefault("compute_natural_sort", True) @@ -194,6 +194,7 @@ def test(loader, model, sig, device, test_loader_len): log_dir = parameters["root_log_dir"] max_steps = parameters["training_steps_to_do"] eval_full_to_train_steps_ratio = parameters["eval_full_to_train_steps_ratio"] + write_export_ratio = parameters.get("write_export_ratio", 100) verbose = parameters.get("verbose", True) tqdm_display = parameters.get("tqdm_display", True) @@ -212,7 +213,7 @@ def test(loader, model, sig, device, test_loader_len): num_classes=num_classes, ignore_index=ignore_index, image_size=image_size, - max_samples=train_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing + max_samples=train_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing ) _val_dataset = BDD100kSegDataset( root=data_root, @@ -220,7 +221,7 @@ def test(loader, model, sig, device, test_loader_len): num_classes=num_classes, ignore_index=ignore_index, image_size=image_size, - max_samples=test_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing + max_samples=test_cfg.get("max_samples", None) # Optionally limit number of samples for faster testing ) train_loader = wl.watch_or_edit( @@ -299,8 +300,8 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 class_counts = np.zeros(num_classes, dtype=np.float64) num_samples = min(len(dataset), max_samples) - for idx in 
tqdm.tqdm(range(num_samples), desc="📊 Analyzing Distribution"): - _, _, label, _ = dataset.get_items(idx, include_labels=True) # Get the label/mask for this sample + for idx in tqdm.tqdm(range(num_samples), desc=" Analyzing Distribution"): + _, _, label, _ = dataset.get_items(idx, include_labels=True) # Get the label/mask for this sample label_np = label.numpy() if hasattr(label, 'numpy') else np.array(label) for c in range(num_classes): class_counts[c] += (label_np == c).sum() @@ -329,33 +330,38 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 ) print("=" * 60) - print("🚀 STARTING BDD100k SEGMENTATION TRAINING") - print(f"📈 Total steps: {max_steps}") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") - print(f"💾 Logs will be saved to: {log_dir}") - print(f"📂 Data root: {data_root}") + print(" STARTING BDD100k SEGMENTATION TRAINING") + print(f" Total steps: {max_steps}") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") + print(f" Logs will be saved to: {log_dir}") + print(f" Data root: {data_root}") print("=" * 60 + "\n") - # ================ - # Training Loop - wl.start_training(timeout=3) # This will block and keep the main thread alive while background services run. You can optionally set a timeout (in seconds) to automatically stop after a certain duration. + # # ================ + # # Training Loop + # wl.start_training(timeout=3) # This will block and keep the main thread alive while background services run. You can optionally set a timeout (in seconds) to automatically stop after a certain duration. 
# ================ train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() test_loss, test_metric = None, None start_time = time.time() for train_step in train_range: - age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) + age = model.get_age() if hasattr(model, "get_age") else train_step # Get model age in steps (not necessarily equal to train_step if model was reloaded or has seen more data than training steps) # Train train_loss = train(train_loader, model, optimizer, train_sig, device) # Test if age == 0 or age % eval_full_to_train_steps_ratio == 0: - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating") if tqdm_display else test_loader test_loss, test_metric = test(test_loader_it, model, test_sig, device, test_loader_len) + # Periodic history + dataframe export (JSON/CSV snapshots to root_log_dir) + if age > 0 and age % write_export_ratio == 0: + wl.write_history() + wl.write_dataframe() + # Verbose if verbose and not tqdm_display: print( @@ -374,9 +380,13 @@ def compute_class_weights(dataset, num_classes, ignore_index=255, max_samples=10 ) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) + # Final export of signal history and data grid to root_log_dir + wl.write_history() + wl.write_dataframe() + # Keep the main thread alive to allow background serving threads to run wl.keep_serving() diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py 
b/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py index 43358196..e137dd32 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/criterions.py @@ -9,11 +9,11 @@ # The segmentation dataset yields, per sample, a LIST of instance masks # (each [H, W] with pixel value = class id). These criterions compute Dice and # BCE for every instance against the model's per-class probability map, then: -# * PerInstance* returns a flat tensor (one value per instance, ordered -# sample-major) — wrapped with `per_instance=True` so WL auto-saves it at -# (sample_id, annotation_id). -# * PerSample* aggregates instances to one value per sample (mean) — wrapped -# with `per_sample=True` for the per-sample dashboards. +# * PerInstance* returns a flat tensor (one value per instance, ordered +# sample-major) — wrapped with `per_instance=True` so WL auto-saves it at +# (sample_id, annotation_id). +# * PerSample* aggregates instances to one value per sample (mean) — wrapped +# with `per_sample=True` for the per-sample dashboards. # The instance ordering matches the `batch_idx` passed by the training loop # (built from the same per-sample instance lists), so WL maps each value to the # correct annotation. @@ -26,14 +26,14 @@ def _instance_dice_bce(outputs, labels, **kwargs): Args: outputs: logits [B, C, H, W]. - labels: list[B]; labels[s] is a list of instance masks ([H, W], value = class id). + labels: list[B]; labels[s] is a list of instance masks ([H, W], value = class id). Returns: (dice_per_sample, bce_per_sample) where each is a list[B] of 1-D tensors holding one value per instance for that sample (empty tensor if none). Values are kept on the outputs' device; BCE retains grad, Dice is a metric. 
""" - probs = torch.softmax(outputs, dim=1) # [B, C, H, W], differentiable + probs = torch.softmax(outputs, dim=1) # [B, C, H, W], differentiable B, C = probs.shape[0], probs.shape[1] device = outputs.device @@ -58,12 +58,12 @@ def _instance_dice_bce(outputs, labels, **kwargs): cls = int(m.max().item()) ch = cls if 0 <= cls < C else 0 gt = (m > 0).float() - p = probs[s, ch].clamp(_EPS, 1.0 - _EPS) # [H, W] + p = probs[s, ch].clamp(_EPS, 1.0 - _EPS) # [H, W] inter = (p * gt).sum() dice = (2.0 * inter + _EPS) / (p.sum() + gt.sum() + _EPS) bce = F.binary_cross_entropy(p, gt) if weights is not None: - bce = bce * weights[ch] # scalar class weight for this instance + bce = bce * weights[ch] # scalar class weight for this instance dices.append(dice) bces.append(bce) diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py index 864c6853..656c0180 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/data.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/data.py @@ -51,7 +51,7 @@ def __init__( for f in os.listdir(img_dir) if f.lower().endswith((".jpg", ".jpeg", ".png")) ] - image_files = sorted(set(image_files))[:max_samples] if max_samples is not None else sorted(set(image_files)) # Optionally limit number of samples for faster testing + image_files = sorted(set(image_files))[:max_samples] if max_samples != None else sorted(set(image_files)) # Optionally limit number of samples for faster testing self.images = [] self.masks = [] @@ -114,22 +114,31 @@ def get_items(self, idx, include_metadata=False, include_labels=False, include_i img_t = self.image_transform(img) # Process labels/masks - mask_t_instances = list() + # # Sample wise segmentation mask_t = None if include_labels: mask = Image.open(mask_path) mask_r = self.mask_resize(mask) mask_np = np.array(mask_r, dtype=np.int64) - mask_t = torch.from_numpy(mask_np) # [H, W] int64 - - # Format labels to register multiple 
instance_ids - lbl_max = mask_t.max().item() - for i in range(1, lbl_max + 1): - m = torch.zeros_like(mask_t) - m[mask_t == i] = i # Assign class ID as instance ID for simplicity; if set to 1, all instances of the same class would be merged... - mask_t_instances.append(m) - return img_t, uid, mask_t_instances, metadata - + mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 + return img_t, uid, mask_t, metadata + # # # Instance wise segmentaiton + # # Process labels/masks + # mask_t_instances = list() + # mask_t = None + # if include_labels: + # mask = Image.open(mask_path) + # mask_r = self.mask_resize(mask) + # mask_np = np.array(mask_r, dtype=np.int64) + # mask_t = torch.from_numpy(mask_np)[None] # [H, W] int64 + + # # Format labels to register multiple instance_ids + # lbl_max = mask_t.max().item() + # for i in range(1, lbl_max + 1): + # m = torch.zeros_like(mask_t) + # m[mask_t == i] = i # Assign class ID as instance ID for simplicity; if set to 1, all instances of the same class would be merged... + # mask_t_instances.append(m) + # return img_t, uid, mask_t_instances, metadata def seg_collate(batch): """Collate WL per-sample tuples for instance-segmentation. @@ -141,10 +150,10 @@ def seg_collate(batch): background) are filtered out so every kept instance is a real annotation. 
Returns: - images: FloatTensor [B, C, H, W] - ids: list[str] of length B - labels: list[B] where labels[s] is a list of instance mask tensors - metas: list[B] of metadata dicts + images: FloatTensor [B, C, H, W] + ids: list[str] of length B + labels: list[B] where labels[s] is a list of instance mask tensors + metas: list[B] of metadata dicts """ images = torch.stack([b[0] for b in batch], dim=0) ids = [b[1] for b in batch] diff --git a/weightslab/examples/PyTorch/ws-segmentation/utils/model.py b/weightslab/examples/PyTorch/ws-segmentation/utils/model.py index 26339e9d..d45311f5 100644 --- a/weightslab/examples/PyTorch/ws-segmentation/utils/model.py +++ b/weightslab/examples/PyTorch/ws-segmentation/utils/model.py @@ -58,7 +58,7 @@ def forward(self, x): # Decoder u2 = self.up2(b) - # ⚠️ Important: no `if` on shapes; always interpolate + # Important: no `if` on shapes; always interpolate u2 = F.interpolate(u2, size=e2.shape[-2:], mode="bilinear", align_corners=False) d2 = self.dec2(torch.cat([u2, e2], dim=1)) @@ -66,5 +66,5 @@ def forward(self, x): u1 = F.interpolate(u1, size=e1.shape[-2:], mode="bilinear", align_corners=False) d1 = self.dec1(torch.cat([u1, e1], dim=1)) - logits = self.head(d1) # [B, C, H, W] + logits = self.head(d1) # [B, C, H, W] return logits diff --git a/weightslab/examples/Ultralytics/ws-detection/config.yaml b/weightslab/examples/Ultralytics/ws-detection/config.yaml index 2ba8184e..9c07dfb9 100644 --- a/weightslab/examples/Ultralytics/ws-detection/config.yaml +++ b/weightslab/examples/Ultralytics/ws-detection/config.yaml @@ -38,7 +38,7 @@ ledger_flush_interval: 60.0 # Data num_classes: 2 image_size: 320 -data_root: .\data\data.yaml # Uncomment and set the path to your data.yaml file. YOLO format. +data_root: C:\Users\GuillaumePELLUET\Documents\Codes\weightslab_kitchen\guillaume_playground\ws-ultralytics_yolo\data\data.yaml # Uncomment and set the path to your data.yaml file. YOLO format. 
data: train_loader: batch_size: 4 diff --git a/weightslab/examples/Ultralytics/ws-detection/main.py b/weightslab/examples/Ultralytics/ws-detection/main.py index d2ba1a61..3174059f 100644 --- a/weightslab/examples/Ultralytics/ws-detection/main.py +++ b/weightslab/examples/Ultralytics/ws-detection/main.py @@ -41,8 +41,11 @@ def main(): os.makedirs(cfg["root_log_dir"], exist_ok=True) wl.watch_or_edit(cfg, flag="hyperparameters", defaults=cfg, poll_interval=1.0) - # Read raw config values BEFORE wrapping so YOLO.train kwargs are plain - # Python (avoids ProxyValue.__gt__ during max()/comparisons). + # After watch_or_edit, `cfg` is the live hyperparameter proxy, so these reads + # return ledger handles (e.g. image_size is a ValueProxy) that stay in sync + # with studio edits. They are passed straight to YOLO.train(...): ValueProxy + # supports int/compare ops for YOLO's imgsz handling, and the ledger registers + # a YAML representer (see ledgers.py) so Ultralytics can dump its run args. model_name = cfg["model"]["name"] data_root = str(cfg["data_root"]) image_size = cfg.get("image_size") @@ -52,21 +55,20 @@ def main(): serving_cli = cfg.get("serving_cli", False) project = cfg["root_log_dir"] name = cfg["experiment_name"] - signals_cfg = cfg.get('signals_cfg', {}) wl.serve(serving_grpc=serving_grpc, serving_cli=serving_cli) # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. 
YOLO(model_name).train( trainer=WLAwareTrainer, data=data_root, imgsz=image_size, - epochs=1000 if max_steps is None else max(1, int(max_steps)), + epochs=1000 if max_steps == None else max(1, int(max_steps)), device=device, - project=project, name=name, # → UL save_dir → WL logger log_dir/name + project=project, name=name, # → UL save_dir → WL logger log_dir/name resume=False, cache=False, optimizer="SGD", @@ -78,11 +80,13 @@ def main(): degrees=0.0, translate=0.0, scale=0.0, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.0, erasing=0.0, auto_augment=None, - # Signals cfg - **signals_cfg + # NOTE: signals_cfg (e.g. train_nms) is NOT passed here — it is read by + # WLAwareTrainer from the registered hyperparameters + # (ledgers.get_hyperparams()['signals_cfg']). Spreading it into .train() + # would make Ultralytics reject keys like `train_nms` as invalid YOLO args. ) - wl.keep_serving() # Keep main thread alive to analyze training results directly + wl.keep_serving() # Keep main thread alive to analyze training results directly if __name__ == "__main__": diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py index 7b48ec51..41c62365 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/main.py @@ -34,7 +34,7 @@ def train(loader, model, optimizer, sig, device, grid_size, pc_range, conf_thres points = points.to(device) targets = [t.to(device) for t in targets] optimizer.zero_grad() - outputs = model(points) # [B, S, S, 5 + num_classes] + outputs = model(points) # [B, S, S, 5 + num_classes] preds = decode_predictions(outputs.detach(), grid_size, pc_range, conf_thresh=conf_thresh) loss_per_sample = sig["loss"](outputs, targets, batch_ids=ids, preds=preds) sig["iou_sample"](outputs, targets, batch_ids=ids) @@ -162,15 +162,15 @@ def _make_det_signals(split): serving_cli=parameters.get("serving_cli", True)) print("=" * 
60) - print("🚀 STARTING 2D LiDAR DETECTION TRAINING (Pillars2D-lite)") - print(f"📡 {len(_train_dataset)} train / {len(_val_dataset)} val scans") - print(f"💾 Logs: {log_dir}") + print(" STARTING 2D LiDAR DETECTION TRAINING (Pillars2D-lite)") + print(f" {len(_train_dataset)} train / {len(_val_dataset)} val scans") + print(f" Logs: {log_dir}") print("=" * 60 + "\n") # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() test_loss, test_metric = None, None @@ -193,5 +193,5 @@ def _make_det_signals(split): + (f" test_loss={test_loss:.4f}" if test_loss is not None else "") + (f" IoU={test_metric:.2f}%" if test_metric is not None else "")) - print(f"\n✅ Done in {time.time() - start_time:.1f}s; logs at {log_dir}") + print(f"\n Done in {time.time() - start_time:.1f}s; logs at {log_dir}") wl.keep_serving() diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py index 7c5252ca..5c893bba 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/criterions.py @@ -11,9 +11,9 @@ # Targets are [N, 6] rows [cx, cy, dx, dy, class_id, confidence] (metric). Each # GT box is assigned to the grid cell containing its (cx, cy) centre. # -# * PerSampleDetection2DLoss -> one differentiable scalar per sample ([B]). -# * PerSampleIoU2D -> mean axis-aligned IoU over a sample's boxes. -# * PerInstanceIoU2D -> one IoU per GT box (sample-major order). +# * PerSampleDetection2DLoss -> one differentiable scalar per sample ([B]). 
+# * PerSampleIoU2D -> mean axis-aligned IoU over a sample's boxes. +# * PerInstanceIoU2D -> one IoU per GT box (sample-major order). _EPS = 1e-6 _LAMBDA_COORD = 2.0 @@ -95,7 +95,7 @@ def _per_box_iou(outputs, targets, grid_size, pc_range): if tgt.ndim == 1: tgt = tgt.view(-1, 6) rows, cols, _, _ = _responsible_cells(tgt, S, pc_range) - pred = boxes_grid[s, rows, cols] # [N, 4] (cx,cy,w,h) + pred = boxes_grid[s, rows, cols] # [N, 4] (cx,cy,w,h) gt = torch.stack([tgt[:, 0], tgt[:, 1], tgt[:, 2], tgt[:, 3]], dim=1) per_sample.append(iou_2d_axis_aligned(pred, gt).detach()) return per_sample diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py index 10e19736..40c79b1c 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/data.py @@ -12,10 +12,10 @@ # task is to detect axis-aligned 2D boxes around object clusters. # # Per sample: -# * cloud: [M, 2] float32 (x, y) — genuinely 2D (the studio viewer renders -# it top-down; no z channel, so it is treated as a 2D cloud). -# * target: [N, 6] float32 = [cx, cy, dx, dy, class_id, confidence] -# (metric units; 2D box schema — exactly 6 columns). +# * cloud: [M, 2] float32 (x, y) — genuinely 2D (the studio viewer renders +# it top-down; no z channel, so it is treated as a 2D cloud). +# * target: [N, 6] float32 = [cx, cy, dx, dy, class_id, confidence] +# (metric units; 2D box schema — exactly 6 columns). # # task_type "detection_pointcloud" is shared with the 3D example; the box-row # column count (<= 6) is what marks this as 2D. @@ -29,7 +29,7 @@ PAD_VALUE = -1000.0 # Typical (length, width) per class for the generator. 
-_CLASS_DIMS = np.array([[3.6, 1.7], [0.7, 0.7]], dtype=np.float32) # Vehicle, Pedestrian +_CLASS_DIMS = np.array([[3.6, 1.7], [0.7, 0.7]], dtype=np.float32) # Vehicle, Pedestrian def _sample_rect_perimeter(rng, dims, n): @@ -74,7 +74,7 @@ def generate_synthetic_scene(seed, pc_range): n_pts = int(np.clip(400.0 / (1.0 + dist / 6.0), 20, 200)) local = _sample_rect_perimeter(rng, dims, n_pts) world = local + np.array([cx, cy], dtype=np.float32) - world += rng.normal(0.0, 0.03, world.shape).astype(np.float32) # sensor noise + world += rng.normal(0.0, 0.03, world.shape).astype(np.float32) # sensor noise clouds.append(world) boxes.append([cx, cy, dims[0], dims[1], float(cls), 1.0]) @@ -96,7 +96,7 @@ def __init__( max_samples=None, seed=0, thumbnail_projection="bev", - **_ignored, # tolerate shared kwargs (kitti_*, extra_features) for parity + **_ignored, # tolerate shared kwargs (kitti_*, extra_features) for parity ): super().__init__() self.split = split @@ -116,7 +116,7 @@ def __init__( else: val_set = set(frames[::k]) selected = [f for f in frames if f not in val_set] - self.frames = selected[:max_samples] if max_samples is not None else selected + self.frames = selected[:max_samples] if max_samples != None else selected def __len__(self): return len(self.frames) diff --git a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py index 2ea4e3b9..50df19b2 100644 --- a/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py +++ b/weightslab/examples/Usecases/ws-2d-lidar-detection/utils/model.py @@ -3,13 +3,13 @@ # ============================================================================= # The 2D analogue of the 3D PointPillars-lite, with z and yaw dropped: # -# 1. 
Point Feature Net: points are binned into grid cells on the (x, y) plane; -# each point gets 6 features (x, y, offsets to the cell's point mean, -# offsets to the cell center), runs a shared Linear+BN+ReLU, and is -# max-pooled per cell -> a [C, H, W] feature image. -# 2. A tiny 2D CNN backbone. -# 3. A YOLO-style grid head: each S x S cell predicts ONE 2D box -# (objectness, tx, ty, log w, log h, class_logits...). +# 1. Point Feature Net: points are binned into grid cells on the (x, y) plane; +# each point gets 6 features (x, y, offsets to the cell's point mean, +# offsets to the cell center), runs a shared Linear+BN+ReLU, and is +# max-pooled per cell -> a [C, H, W] feature image. +# 2. A tiny 2D CNN backbone. +# 3. A YOLO-style grid head: each S x S cell predicts ONE 2D box +# (objectness, tx, ty, log w, log h, class_logits...). # # decode_grid_2d turns logits into metric (cx, cy, w, h) boxes. import math @@ -24,8 +24,8 @@ def decode_grid_2d(outputs, grid_size, pc_range): """Decode raw grid logits -> per-cell 2D boxes, objectness, class probs. 
Returns: - boxes: [B, S, S, 4] (cx, cy, w, h) in meters - obj: [B, S, S] objectness probability + boxes: [B, S, S, 4] (cx, cy, w, h) in meters + obj: [B, S, S] objectness probability cls_probs: [B, S, S, num_classes] """ B, S = outputs.shape[0], grid_size @@ -35,7 +35,7 @@ def decode_grid_2d(outputs, grid_size, pc_range): obj = torch.sigmoid(outputs[..., 0]) tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) - dims = torch.exp(outputs[..., 3:5].clamp(-4.0, 4.0)) # (w_x, w_y), meters + dims = torch.exp(outputs[..., 3:5].clamp(-4.0, 4.0)) # (w_x, w_y), meters cls_probs = torch.softmax(outputs[..., 5:], dim=-1) cols = torch.arange(S, device=device).view(1, 1, S).expand(B, S, S) @@ -62,7 +62,7 @@ def __init__(self, num_classes=2, pc_range=DEFAULT_PC_RANGE, voxel_size=0.5, x_min, y_min, _, x_max, y_max, _ = self.pc_range self.nx = int(round((x_max - x_min) / voxel_size)) self.ny = int(round((y_max - y_min) / voxel_size)) - self.preds_per_cell = 5 + num_classes # obj + (tx,ty,log w,log h) + classes + self.preds_per_cell = 5 + num_classes # obj + (tx,ty,log w,log h) + classes self.pfn_channels = pfn_channels self.pfn = nn.Sequential( @@ -109,7 +109,7 @@ def _augment_points(self, points): cy = y_min + (iy.to(pts.dtype) + 0.5) * self.voxel_size f_center = torch.stack([pts[:, 0] - cx, pts[:, 1] - cy], dim=1) - feats = torch.cat([pts[:, :2], f_cluster, f_center], dim=1) # [M, 6] + feats = torch.cat([pts[:, :2], f_cluster, f_center], dim=1) # [M, 6] return feats, flat def _scatter_to_canvas(self, point_feats, flat): diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py index d0d41eb7..7dfc79ae 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/main.py @@ -59,32 +59,32 @@ def render_thumbnail_2d(self, points): # You can customize parameters here (resolution, FOV, rendering mode): return 
point_cloud_to_range_image( points, - image_height=80, # Custom height (default 64) - image_width=512, # Custom width (default 512, like KITTI) - fov_up=3.0, # Max elevation angle in degrees - fov_down=-25.0, # Min elevation angle (typical LiDAR) - mode="distance+intensity", # or "distance", "intensity" + image_height=80, # Custom height (default 64) + image_width=512, # Custom width (default 512, like KITTI) + fov_up=3.0, # Max elevation angle in degrees + fov_down=-25.0, # Min elevation angle (typical LiDAR) + mode="distance+intensity", # or "distance", "intensity" ) # Optional: override box projection for your custom 2D frame. # Uncomment if needed: # # def project_boxes_2d(self, boxes_3d): - # """Custom box projection to your 2D frame. + # """Custom box projection to your 2D frame. # - # Args: - # boxes_3d: [N, C] where C >= 7 is 3D ([cx,cy,cz,dx,dy,dz,yaw,...]) - # or C <= 6 is 2D ([cx,cy,dx,dy,...]) + # Args: + # boxes_3d: [N, C] where C >= 7 is 3D ([cx,cy,cz,dx,dy,dz,yaw,...]) + # or C <= 6 is 2D ([cx,cy,dx,dy,...]) # - # Returns: - # [N, 6] normalized xyxy boxes [x1, y1, x2, y2, class_id, confidence] - # in [0, 1] range (image coordinates, y down). - # """ - # from weightslab.data.point_cloud_utils import project_boxes_to_bev, get_pc_range - # # For now, just use the standard BEV projection as fallback. - # # Implement your custom projection here. - # pc_range = get_pc_range(self) - # return project_boxes_to_bev(boxes_3d, pc_range) + # Returns: + # [N, 6] normalized xyxy boxes [x1, y1, x2, y2, class_id, confidence] + # in [0, 1] range (image coordinates, y down). + # """ + # from weightslab.data.point_cloud_utils import project_boxes_to_bev, get_pc_range + # # For now, just use the standard BEV projection as fallback. + # # Implement your custom projection here. 
+ # pc_range = get_pc_range(self) + # return project_boxes_to_bev(boxes_3d, pc_range) # ============================================================================= @@ -105,7 +105,7 @@ def train(loader, model, optimizer, sig, device, grid_size, pc_range, conf_thres targets = [t.to(device) for t in targets] optimizer.zero_grad() - outputs = model(points) # [B, S, S, 9 + num_classes] + outputs = model(points) # [B, S, S, 9 + num_classes] # Decoded 3D boxes (detached — stored alongside the loss for analysis). preds = decode_predictions( @@ -147,7 +147,7 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load loss = float((losses / test_loader_len).detach().cpu().item()) iou = float((ious / test_loader_len).detach().cpu().item()) - return loss, iou * 100.0 # Return mean BEV IoU as percentage + return loss, iou * 100.0 # Return mean BEV IoU as percentage # ============================================================================= @@ -168,7 +168,7 @@ def test(loader, model, sig, device, grid_size, pc_range, conf_thresh, test_load parameters.setdefault("training_steps_to_do", 500) parameters.setdefault("eval_full_to_train_steps_ratio", 50) parameters.setdefault("number_of_workers", 4) - parameters.setdefault("num_classes", 3) # Car, Pedestrian, Cyclist + parameters.setdefault("num_classes", 3) # Car, Pedestrian, Cyclist parameters.setdefault("point_cloud_range", list(DEFAULT_PC_RANGE)) parameters.setdefault("voxel_size", 0.5) parameters.setdefault("grid_size", 32) @@ -340,7 +340,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): class_counts = np.zeros(num_classes, dtype=np.float64) num_samples = min(len(dataset), max_samples) - for idx in tqdm.tqdm(range(num_samples), desc="📊 Analyzing Distribution"): + for idx in tqdm.tqdm(range(num_samples), desc=" Analyzing Distribution"): _, _, target, _ = dataset.get_items(idx, include_labels=True) if target is None or len(target) == 0: continue @@ -348,10 +348,10 @@ def 
compute_class_weights(dataset, num_classes, max_samples=200): if 0 <= c < num_classes: class_counts[c] += 1 - class_counts = np.maximum(class_counts, 1) # Avoid div by zero + class_counts = np.maximum(class_counts, 1) # Avoid div by zero total = class_counts.sum() class_weights = total / (num_classes * class_counts) - class_weights = class_weights / class_weights.mean() # Normalize + class_weights = class_weights / class_weights.mean() # Normalize print("\nClass distribution and weights:", flush=True) for c in range(num_classes): @@ -372,18 +372,18 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("=" * 60) - print("🚀 STARTING LIDAR 3D DETECTION TRAINING (PointPillars-lite)") - print(f"📡 Data source: {_train_dataset.source} " + print(" STARTING LIDAR 3D DETECTION TRAINING (PointPillars-lite)") + print(f" Data source: {_train_dataset.source} " f"({len(_train_dataset)} train / {len(_val_dataset)} val frames)") - print(f"📈 Total steps: {max_steps}") - print(f"🔄 Evaluation every {eval_full_to_train_steps_ratio} steps") - print(f"💾 Logs will be saved to: {log_dir}") - print(f"📂 Data root: {data_root}") + print(f" Total steps: {max_steps}") + print(f" Evaluation every {eval_full_to_train_steps_ratio} steps") + print(f" Logs will be saved to: {log_dir}") + print(f" Data root: {data_root}") print("=" * 60 + "\n") # ================ # Training Loop - wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. + wl.start_training(timeout=3) # Blocks and keeps the main thread alive while background services run. Optionally set a timeout (seconds) to auto-stop. 
# ================ train_range = tqdm.tqdm(itertools.count(), desc="Training") if tqdm_display else itertools.count() @@ -399,7 +399,7 @@ def compute_class_weights(dataset, num_classes, max_samples=200): # Test if age == 0 or age % eval_full_to_train_steps_ratio == 0: - test_loader_len = len(test_loader) # Store length before wrapping with tqdm + test_loader_len = len(test_loader) # Store length before wrapping with tqdm test_loader_it = tqdm.tqdm(test_loader, desc="Evaluating") if tqdm_display else test_loader test_loss, test_metric = test( test_loader_it, model, test_sig, device, @@ -423,8 +423,8 @@ def compute_class_weights(dataset, num_classes, max_samples=200): ) print("\n" + "=" * 60) - print(f"✅ Training completed in {time.time() - start_time:.2f} seconds") - print(f"💾 Logs saved to: {log_dir}") + print(f" Training completed in {time.time() - start_time:.2f} seconds") + print(f" Logs saved to: {log_dir}") print("=" * 60) # Keep the main thread alive to allow background serving threads to run diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py index 74809175..7a34da0f 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/criterions.py @@ -13,25 +13,25 @@ # coordinates. Each GT box is assigned to the BEV grid cell containing its # (cx, cy) center; that cell is "responsible" for predicting the box. # -# * PerSampleDetection3DLoss -> one differentiable loss scalar per sample -# ([B]), wrapped with ``per_sample=True`` (the value WL backprops + -# dashboards). -# * PerSampleBevIoU -> mean BEV IoU over a sample's boxes ([B]). -# * PerInstanceBevIoU -> flat tensor of one IoU per GT box -# (sample-major order), wrapped with ``per_instance=True`` so WL auto-saves -# it at (sample_id, annotation_id). 
The ordering matches the per-sample -# target iteration, so the wrapper's auto ``batch_idx`` maps each value -# correctly. +# * PerSampleDetection3DLoss -> one differentiable loss scalar per sample +# ([B]), wrapped with ``per_sample=True`` (the value WL backprops + +# dashboards). +# * PerSampleBevIoU -> mean BEV IoU over a sample's boxes ([B]). +# * PerInstanceBevIoU -> flat tensor of one IoU per GT box +# (sample-major order), wrapped with ``per_instance=True`` so WL auto-saves +# it at (sample_id, annotation_id). The ordering matches the per-sample +# target iteration, so the wrapper's auto ``batch_idx`` maps each value +# correctly. # # The IoU metric is axis-aligned in the BEV plane (yaw ignored) — a cheap, # dependency-free proxy for rotated-box IoU that is monotone enough to rank # samples / instances in the dashboards. _EPS = 1e-6 -_LAMBDA_COORD = 2.0 # x, y, z localization -_LAMBDA_SIZE = 1.0 # log-dims -_LAMBDA_YAW = 1.0 # sin / cos regression -_LAMBDA_NOOBJ = 0.5 # empty-cell objectness down-weighting +_LAMBDA_COORD = 2.0 # x, y, z localization +_LAMBDA_SIZE = 1.0 # log-dims +_LAMBDA_YAW = 1.0 # sin / cos regression +_LAMBDA_NOOBJ = 0.5 # empty-cell objectness down-weighting def bev_iou_axis_aligned(a, b): @@ -55,14 +55,14 @@ def _responsible_cells(boxes, grid_size, pc_range): """Map GT boxes -> their responsible BEV (row, col) cell and cell offsets. Args: - boxes: [N, 9] target rows (metric). + boxes: [N, 9] target rows (metric). grid_size: S. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). Returns: - rows, cols: [N] long, the responsible cell indices. - off_x, off_y: [N] center offset within the cell, in [0, 1). - z_t: [N] z center normalized to [0, 1] over the z range. + rows, cols: [N] long, the responsible cell indices. + off_x, off_y: [N] center offset within the cell, in [0, 1). + z_t: [N] z center normalized to [0, 1] over the z range. 
""" x_min, y_min, z_min, x_max, y_max, z_max = pc_range S = grid_size @@ -82,14 +82,14 @@ def _per_sample_loss(outputs, targets, num_classes, grid_size, pc_range, weights B, S = outputs.shape[0], grid_size device = outputs.device - obj_logit = outputs[..., 0] # [B, S, S] + obj_logit = outputs[..., 0] # [B, S, S] tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) tz = torch.sigmoid(outputs[..., 3]) - log_dims = outputs[..., 4:7] # [B, S, S, 3] + log_dims = outputs[..., 4:7] # [B, S, S, 3] t_sin = outputs[..., 7] t_cos = outputs[..., 8] - cls_logits = outputs[..., 9:] # [B, S, S, C] + cls_logits = outputs[..., 9:] # [B, S, S, C] if weights is not None: weights = torch.as_tensor(weights, device=device, dtype=outputs.dtype) @@ -162,7 +162,7 @@ def _per_box_bev_iou(outputs, targets, grid_size, pc_range): Returns a list[B] of 1-D tensors (one IoU per box for that sample, in annotation order). Detached — this is a metric, not a loss. """ - boxes_grid, _, _ = decode_grid_3d(outputs, grid_size, pc_range) # [B, S, S, 7] + boxes_grid, _, _ = decode_grid_3d(outputs, grid_size, pc_range) # [B, S, S, 7] B, S = outputs.shape[0], grid_size device = outputs.device @@ -176,7 +176,7 @@ def _per_box_bev_iou(outputs, targets, grid_size, pc_range): tgt = tgt.view(-1, 9) rows, cols, _, _, _ = _responsible_cells(tgt, S, pc_range) - pred = boxes_grid[s, rows, cols] # [N, 7] + pred = boxes_grid[s, rows, cols] # [N, 7] pred_bev = torch.stack( [pred[:, 0], pred[:, 1], pred[:, 3], pred[:, 4]], dim=1) gt_bev = torch.stack( @@ -254,8 +254,8 @@ def decode_predictions(outputs, grid_size, pc_range, conf_thresh=0.3, max_det=20 boxes_grid, obj, cls_probs = decode_grid_3d(outputs, grid_size, pc_range) B, S = outputs.shape[0], grid_size - cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] - score = obj * cls_conf # combined confidence + cls_conf, cls_id = cls_probs.max(dim=-1) # [B, S, S] + score = obj * cls_conf # combined confidence flat_boxes = boxes_grid.view(B, S * S, 7) 
flat_score = score.view(B, S * S) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py index 77df470b..7abb9f81 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/data.py @@ -11,23 +11,23 @@ # ============================================================================= # Self-driving 3D detection over LiDAR point clouds. Two sources: # -# * "kitti": the KITTI 3D Object Detection benchmark. Expected layout -# (download from https://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d): -# /kitti/training/velodyne/000000.bin ... (x, y, z, intensity float32) -# /kitti/training/label_2/000000.txt ... (camera-frame 3D boxes) -# /kitti/training/calib/000000.txt ... (velo->cam calibration) -# * "synthetic": procedurally generated road scenes (ground plane + car / -# pedestrian / cyclist point clusters). Lets the example run -# out-of-the-box with zero download; useful to validate the -# whole WL pipeline before pointing it at real data. +# * "kitti": the KITTI 3D Object Detection benchmark. Expected layout +# (download from https://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d): +# /kitti/training/velodyne/000000.bin ... (x, y, z, intensity float32) +# /kitti/training/label_2/000000.txt ... (camera-frame 3D boxes) +# /kitti/training/calib/000000.txt ... (velo->cam calibration) +# * "synthetic": procedurally generated road scenes (ground plane + car / +# pedestrian / cyclist point clusters). Lets the example run +# out-of-the-box with zero download; useful to validate the +# whole WL pipeline before pointing it at real data. 
# # Per-sample target is a [N, 9] float32 array, one row per ground-truth box, # all in the LiDAR (velodyne) frame, metric units: # -# [cx, cy, cz, dx, dy, dz, yaw, class_id, confidence] +# [cx, cy, cz, dx, dy, dz, yaw, class_id, confidence] # -# cx/cy/cz: box center (m); dx/dy/dz: size along the object's x/y/z axes -# (length, width, height); yaw: rotation around +z; GT confidence = 1.0. +# cx/cy/cz: box center (m); dx/dy/dz: size along the object's x/y/z axes +# (length, width, height); yaw: rotation around +z; GT confidence = 1.0. CLASS_NAMES = ["Car", "Pedestrian", "Cyclist"] @@ -42,14 +42,14 @@ # Typical (length, width, height) per class, used by the synthetic generator. _CLASS_DIMS = np.array( [ - [4.0, 1.7, 1.5], # Car - [0.8, 0.6, 1.75], # Pedestrian - [1.8, 0.6, 1.7], # Cyclist + [4.0, 1.7, 1.5], # Car + [0.8, 0.6, 1.75], # Pedestrian + [1.8, 0.6, 1.7], # Cyclist ], dtype=np.float32, ) -_GROUND_Z = -1.7 # LiDAR is mounted ~1.7 m above the road in KITTI. +_GROUND_Z = -1.7 # LiDAR is mounted ~1.7 m above the road in KITTI. # ============================================================================= @@ -78,7 +78,7 @@ def read_kitti_calib(path): m[:3, :4] = vals.reshape(3, 4) mats["Tr_velo_to_cam"] = m elif key.strip() == "P2": - mats["P2"] = vals.reshape(3, 4) # left colour camera projection + mats["P2"] = vals.reshape(3, 4) # left colour camera projection return mats @@ -89,10 +89,10 @@ def project_velo_to_image(points_xyz, calib): the camera (positive depth). Used to colourise the cloud from image_2. 
""" n = points_xyz.shape[0] - homo = np.concatenate([points_xyz, np.ones((n, 1))], axis=1) # [N, 4] - cam = (calib["R0_rect"] @ calib["Tr_velo_to_cam"] @ homo.T) # [4, N] + homo = np.concatenate([points_xyz, np.ones((n, 1))], axis=1) # [N, 4] + cam = (calib["R0_rect"] @ calib["Tr_velo_to_cam"] @ homo.T) # [4, N] depth = cam[2] - pix = calib["P2"] @ cam # [3, N] + pix = calib["P2"] @ cam # [3, N] valid = depth > 1e-3 uv = np.zeros((n, 2), dtype=np.float32) uv[valid] = (pix[:2, valid] / pix[2, valid]).T @@ -113,7 +113,7 @@ def _read_kitti_kv_file(path): try: out[key.strip()] = np.array([float(v) for v in vals.split()], dtype=np.float64) except ValueError: - pass # non-numeric header lines (calib_time, etc.) + pass # non-numeric header lines (calib_time, etc.) return out @@ -158,7 +158,7 @@ def parse_tracklets(xml_path): for tracklet in root.iter("item"): otype = tracklet.findtext("objectType") if otype is None or otype not in _TRACKLET_CLASS_MAP: - continue # not a tracklet item, or a class we don't keep + continue # not a tracklet item, or a class we don't keep h = tracklet.findtext("h"); w = tracklet.findtext("w"); l = tracklet.findtext("l") first = tracklet.findtext("first_frame") poses = tracklet.find("poses") @@ -207,8 +207,8 @@ def read_kitti_label(label_path, calib, pc_range): ry = float(parts[14]) center = _cam_to_velo(loc_cam, calib)[0] - center[2] += h / 2.0 # KITTI location is the bottom face center - yaw = -ry - np.pi / 2.0 # camera rotation_y -> velo-frame yaw + center[2] += h / 2.0 # KITTI location is the bottom face center + yaw = -ry - np.pi / 2.0 # camera rotation_y -> velo-frame yaw x_min, y_min, z_min, x_max, y_max, z_max = pc_range if not (x_min <= center[0] <= x_max and y_min <= center[1] <= y_max @@ -225,14 +225,14 @@ def read_kitti_label(label_path, calib, pc_range): def _sample_box_surface(rng, dims, n): """Uniformly sample n points on the surface of an axis-aligned box at origin.""" l, w, h = dims - areas = np.array([w * h, w * h, l * h, l * 
h, l * w, l * w]) # +-x, +-y, +-z faces + areas = np.array([w * h, w * h, l * h, l * h, l * w, l * w]) # +-x, +-y, +-z faces face = rng.choice(6, size=n, p=areas / areas.sum()) u = rng.uniform(-0.5, 0.5, size=n) v = rng.uniform(-0.5, 0.5, size=n) pts = np.zeros((n, 3), dtype=np.float32) sign = np.where(face % 2 == 0, 0.5, -0.5) - ax = face // 2 # 0: x faces, 1: y faces, 2: z faces + ax = face // 2 # 0: x faces, 1: y faces, 2: z faces pts[ax == 0] = np.stack( [sign[ax == 0] * l, u[ax == 0] * w, v[ax == 0] * h], axis=1) pts[ax == 1] = np.stack( @@ -307,15 +307,15 @@ class Lidar3DDetectionDataset(Dataset): """LiDAR 3D box detection over KITTI scans or synthetic scenes. Args: - root: data directory (expects /kitti/training/* for KITTI). - split: "train" or "val" (deterministic split). - source: "kitti", "synthetic", or "auto" (kitti if present on disk). - num_classes: how many of CLASS_NAMES to keep. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop, meters. - max_points: random subsample cap per cloud (speed / memory). + root: data directory (expects /kitti/training/* for KITTI). + split: "train" or "val" (deterministic split). + source: "kitti", "synthetic", or "auto" (kitti if present on disk). + num_classes: how many of CLASS_NAMES to keep. + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max) crop, meters. + max_points: random subsample cap per cloud (speed / memory). num_synthetic: number of generated scenes when source is synthetic. - val_fraction: fraction of frames held out for validation. - max_samples: optional cap on the split size (for quick runs). + val_fraction: fraction of frames held out for validation. + max_samples: optional cap on the split size (for quick runs). """ def __init__( @@ -353,9 +353,9 @@ def __init__( # Per-point channels. 
xyz + intensity are always present (the model # consumes the first 4 columns); ``extra_features`` appends extra # VISUALISATION-only channels the studio viewer can colour/shade by: - # "normals" -> nx, ny, nz (PCA over neighbours) - # "rgb" -> r, g, b (camera image projection; KITTI only, - # synthetic falls back to a height pseudo-colour) + # "normals" -> nx, ny, nz (PCA over neighbours) + # "rgb" -> r, g, b (camera image projection; KITTI only, + # synthetic falls back to a height pseudo-colour) self.extra_features = tuple(str(f).strip().lower() for f in (extra_features or ())) # Real KITTI drives ship camera images + calibration, so colourise by # default (set extra_features explicitly to override, e.g. [] or [normals]). @@ -407,7 +407,7 @@ def __init__( download_dir = kitti_download_dir or default_download_dir() drives = list(kitti_raw_drives) or ["drive_0001"] frames = [] - self._raw_tracklets = {} # drive -> {frame_index: [N, 9] GT boxes} + self._raw_tracklets = {} # drive -> {frame_index: [N, 9] GT boxes} for drive in drives: if download: self._raw_date_dir = ensure_sequence(kitti_raw_date, drive, dest_dir=download_dir) @@ -446,7 +446,7 @@ def __init__( else: val_set = set(frames[::k]) selected = [f for f in frames if f not in val_set] - self.frames = selected[:max_samples] if max_samples is not None else selected + self.frames = selected[:max_samples] if max_samples is not None else selected if len(self.frames) == 0: raise RuntimeError(f"No LiDAR frames found (source={source}, root={root})") @@ -513,7 +513,7 @@ def _enrich_features(self, points, calib, image_path): """Append the configured visualisation channels (normals, rgb) to [M, 4].""" if points.shape[0] == 0 or not self.extra_features: return points.astype(np.float32) - channels = [points[:, :4]] # x, y, z, intensity (always) + channels = [points[:, :4]] # x, y, z, intensity (always) if "normals" in self.extra_features: from weightslab.data.point_cloud_utils import compute_point_normals @@ -536,7 +536,7 
@@ def _point_rgb(self, points, calib, image_path): points[:, :3], image, lambda p: project_velo_to_image(p, calib)) except Exception: - pass # fall through to pseudo-colour + pass # fall through to pseudo-colour # Synthetic / no image: pseudo-colour from height so the channel is useful. z_min, z_max = self.pc_range[2], self.pc_range[5] @@ -546,9 +546,9 @@ def _point_rgb(self, points, calib, image_path): def __getitem__(self, idx): """Returns (item, uid, target, metadata). - - item: point cloud FloatTensor [M, 4] (x, y, z, intensity) - - uid: unique sample id (string) - - target: [N, 9] float32 = [cx, cy, cz, dx, dy, dz, yaw, cls, conf] + - item: point cloud FloatTensor [M, 4] (x, y, z, intensity) + - uid: unique sample id (string) + - target: [N, 9] float32 = [cx, cy, cz, dx, dy, dz, yaw, cls, conf] - metadata: dict with source paths / generation seed """ return self.get_items(idx, include_metadata=True, include_labels=True, include_images=True) @@ -571,10 +571,10 @@ def lidar_collate(batch): layout WL's per-instance helpers expect. Returns: - points: FloatTensor [B, M_max, 4] - ids: list[str] of length B + points: FloatTensor [B, M_max, 4] + ids: list[str] of length B targets: list[B] of [N_i, 9] float tensors - metas: list[B] of metadata dicts + metas: list[B] of metadata dicts """ clouds = [ b[0] if isinstance(b[0], torch.Tensor) else torch.as_tensor(b[0], dtype=torch.float32) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py index 372e8a1d..e70c63df 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/kitti_download.py @@ -13,8 +13,8 @@ calib_velo_to_cam.txt calib_imu_to_velo.txt __sync/ - velodyne_points/data/0000000000.bin ... (x, y, z, reflectance float32) - image_02/data/0000000000.png ... 
(left colour camera) + velodyne_points/data/0000000000.bin ... (x, y, z, reflectance float32) + image_02/data/0000000000.png ... (left colour camera) ... Downloads stream to disk with a tqdm progress bar and are idempotent (a @@ -149,8 +149,8 @@ def ensure_sequence(date, drive, dest_dir=None, keep_zip=False): """Download + extract one raw sequence (idempotent). Returns the date dir. Args: - date: e.g. "2011_09_26". - drive: e.g. "drive_0001". + date: e.g. "2011_09_26". + drive: e.g. "drive_0001". dest_dir: where to download/extract (default: a temp dir). keep_zip: keep the downloaded .zip after extraction (default: delete). @@ -162,7 +162,7 @@ def ensure_sequence(date, drive, dest_dir=None, keep_zip=False): seq_dir = os.path.join(dest_dir, date, f"{date}_{drive}_sync") if os.path.isdir(os.path.join(seq_dir, "velodyne_points", "data")): - return os.path.join(dest_dir, date) # already extracted + return os.path.join(dest_dir, date) # already extracted filename = f"{date}_{drive}_sync.zip" zip_path = os.path.join(dest_dir, filename) diff --git a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py index 6cf83231..0af99b0d 100644 --- a/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py +++ b/weightslab/examples/Usecases/ws-3d-lidar-detection/utils/model.py @@ -4,24 +4,24 @@ # Three stages, following the PointPillars recipe (Lang et al., CVPR 2019) but # heavily slimmed down: # -# 1. Pillar Feature Net: points are grouped into vertical columns ("pillars") -# on a BEV grid; each point gets 9 features (x, y, z, intensity, offsets -# to the pillar's point mean, offsets to the pillar center), runs through -# a shared Linear+BN+ReLU, and is max-pooled per pillar -> a sparse -# [C, H, W] BEV pseudo-image. -# 2. A tiny 2D CNN backbone over the BEV pseudo-image (2 stride-2 blocks). -# 3. 
A YOLO-v1-style grid head: each S x S BEV cell predicts ONE 3D box: -# (objectness, tx, ty, tz, log l, log w, log h, sin yaw, cos yaw, -# class_logits...). +# 1. Pillar Feature Net: points are grouped into vertical columns ("pillars") +# on a BEV grid; each point gets 9 features (x, y, z, intensity, offsets +# to the pillar's point mean, offsets to the pillar center), runs through +# a shared Linear+BN+ReLU, and is max-pooled per pillar -> a sparse +# [C, H, W] BEV pseudo-image. +# 2. A tiny 2D CNN backbone over the BEV pseudo-image (2 stride-2 blocks). +# 3. A YOLO-v1-style grid head: each S x S BEV cell predicts ONE 3D box: +# (objectness, tx, ty, tz, log l, log w, log h, sin yaw, cos yaw, +# class_logits...). # # Encoding (BEV cell-relative, mirrors the 2D ws-detection example): -# * objectness = sigmoid(t_obj) -> P(box centered in cell) -# * cx = x_min + (col + sigmoid(tx)) / S * range_x -# * cy = y_min + (row + sigmoid(ty)) / S * range_y -# * cz = z_min + sigmoid(tz) * range_z -# * (l, w, h) = exp(t_l, t_w, t_h) -> size in meters -# * yaw = atan2(t_sin, t_cos) -# * class = softmax(class_logits) +# * objectness = sigmoid(t_obj) -> P(box centered in cell) +# * cx = x_min + (col + sigmoid(tx)) / S * range_x +# * cy = y_min + (row + sigmoid(ty)) / S * range_y +# * cz = z_min + sigmoid(tz) * range_z +# * (l, w, h) = exp(t_l, t_w, t_h) -> size in meters +# * yaw = atan2(t_sin, t_cos) +# * class = softmax(class_logits) # # Raw forward output keeps logits (the loss applies activations); `decode_grid_3d` # turns logits into metric 3D boxes for metrics and prediction dumps. @@ -39,13 +39,13 @@ def decode_grid_3d(outputs, grid_size, pc_range): Shared by the model and the criterions so the encoding lives in one place. Args: - outputs: [B, S, S, 9 + num_classes] raw logits. + outputs: [B, S, S, 9 + num_classes] raw logits. grid_size: S. - pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). + pc_range: (x_min, y_min, z_min, x_max, y_max, z_max). 
Returns: - boxes: [B, S, S, 7] (cx, cy, cz, l, w, h, yaw) in meters - obj: [B, S, S] objectness probability + boxes: [B, S, S, 7] (cx, cy, cz, l, w, h, yaw) in meters + obj: [B, S, S] objectness probability cls_probs: [B, S, S, num_classes] class probabilities """ B, S = outputs.shape[0], grid_size @@ -56,7 +56,7 @@ def decode_grid_3d(outputs, grid_size, pc_range): tx = torch.sigmoid(outputs[..., 1]) ty = torch.sigmoid(outputs[..., 2]) tz = torch.sigmoid(outputs[..., 3]) - dims = torch.exp(outputs[..., 4:7].clamp(-4.0, 4.0)) # (l, w, h), meters + dims = torch.exp(outputs[..., 4:7].clamp(-4.0, 4.0)) # (l, w, h), meters yaw = torch.atan2(outputs[..., 7], outputs[..., 8]) cls_probs = torch.softmax(outputs[..., 9:], dim=-1) @@ -94,11 +94,11 @@ def __init__( self.pc_range = tuple(pc_range) self.voxel_size = float(voxel_size) self.pad_value = float(pad_value) - self.input_shape = (1, 4096, 4) # padded cloud [B, M, 4] for summaries + self.input_shape = (1, 4096, 4) # padded cloud [B, M, 4] for summaries x_min, y_min, _, x_max, y_max, _ = self.pc_range - self.nx = int(round((x_max - x_min) / voxel_size)) # BEV canvas cols - self.ny = int(round((y_max - y_min) / voxel_size)) # BEV canvas rows + self.nx = int(round((x_max - x_min) / voxel_size)) # BEV canvas cols + self.ny = int(round((y_max - y_min) / voxel_size)) # BEV canvas rows # Channels per head cell: obj(1) + box(8: tx ty tz, log lwh, sin cos) # + class logits(num_classes) @@ -176,7 +176,7 @@ def _augment_points(self, points): cy = y_min + (iy.to(pts.dtype) + 0.5) * self.voxel_size f_center = torch.stack([pts[:, 0] - cx, pts[:, 1] - cy], dim=1) - feats = torch.cat([pts, f_cluster, f_center], dim=1) # [M, 9] + feats = torch.cat([pts, f_cluster, f_center], dim=1) # [M, 9] return feats, flat def _scatter_to_canvas(self, point_feats, pillar_idx): @@ -217,9 +217,9 @@ def forward(self, points): else: canvases.append(self._scatter_to_canvas(chunks.pop(0), flat)) - x = torch.stack(canvases, dim=0) # [B, C, ny, nx] - x = 
self.backbone(x) # [B, 128, ny/4, nx/4] - out = self.head(x) # [B, preds_per_cell, S', S'] + x = torch.stack(canvases, dim=0) # [B, C, ny, nx] + x = self.backbone(x) # [B, 128, ny/4, nx/4] + out = self.head(x) # [B, preds_per_cell, S', S'] # Resize the feature grid to the configured head grid_size. if out.shape[-1] != self.grid_size or out.shape[-2] != self.grid_size: diff --git a/weightslab/integrations/ultralytics/__init__.py b/weightslab/integrations/ultralytics/__init__.py index 7bb3989e..15f7c57e 100644 --- a/weightslab/integrations/ultralytics/__init__.py +++ b/weightslab/integrations/ultralytics/__init__.py @@ -21,8 +21,8 @@ YOLO(cfg["model"]).train( trainer=WLAwareTrainer, data=cfg["data_root"], imgsz=640, epochs=1000, batch=4, - project="./logs", name="exp", # → WL log_dir/name - workers=0, # WL invariant (parent-process uid counter) + project="./logs", name="exp", # → WL log_dir/name + workers=0, # WL invariant (parent-process uid counter) ) wl.keep_serving() diff --git a/weightslab/integrations/ultralytics/dataset.py b/weightslab/integrations/ultralytics/dataset.py index 4d819b59..ff4be229 100644 --- a/weightslab/integrations/ultralytics/dataset.py +++ b/weightslab/integrations/ultralytics/dataset.py @@ -60,9 +60,9 @@ def fast_get_label(self, i): if shp is None: from PIL import Image as _PIL with _PIL.open(lab["im_file"]) as im: - w0, h0 = im.size # PIL: (w, h) + w0, h0 = im.size # PIL: (w, h) shp = (h0, w0) - lab["shape"] = shp # memoize + lab["shape"] = shp # memoize h0, w0 = shp new = self.imgsz r = min(new / h0, new / w0) diff --git a/weightslab/integrations/ultralytics/signals.py b/weightslab/integrations/ultralytics/signals.py index c1610a8f..e6a733f7 100644 --- a/weightslab/integrations/ultralytics/signals.py +++ b/weightslab/integrations/ultralytics/signals.py @@ -47,12 +47,12 @@ # fallbacks kick in only if `model.args.{conf,iou}` is `None` — which # happens for training because UL only auto-populates those for predict. 
# -# OVERLAY_CONF_FALLBACK — tiny so early-epoch overlays aren't empty. -# UL's predict default of 0.25 would hide the -# model entirely while it's still learning. -# OVERLAY_IOU_FALLBACK — matches UL's default inference IoU. -# OVERLAY_MAX_DETS — readability cap; UL's NMS otherwise produces -# up to 300 boxes per image, flooding the studio. +# OVERLAY_CONF_FALLBACK — tiny so early-epoch overlays aren't empty. +# UL's predict default of 0.25 would hide the +# model entirely while it's still learning. +# OVERLAY_IOU_FALLBACK — matches UL's default inference IoU. +# OVERLAY_MAX_DETS — readability cap; UL's NMS otherwise produces +# up to 300 boxes per image, flooding the studio. OVERLAY_CONF_FALLBACK = 1e-4 OVERLAY_IOU_FALLBACK = 0.45 OVERLAY_MAX_DETS = 50 @@ -157,7 +157,7 @@ class Signal: `preds=` kwarg. """ name: str - flag: str # "loss" | "metric" + flag: str # "loss" | "metric" reduce: Callable[[dict], Optional[th.Tensor]] preds: Optional[Callable[[dict], Optional[dict]]] = None @@ -211,8 +211,8 @@ def install_val_pipeline(validator, signals: list[Signal]): channels = _make_channels(signals) _orig = validator.update_metrics def _ship(preds, batch): - validator._wl_preds = preds # exposed to signal reducers/predsers - res = _orig(preds, batch) # runs first — fills _process_batch buf + validator._wl_preds = preds # exposed to signal reducers/predsers + res = _orig(preds, batch) # runs first — fills _process_batch buf _ship_round(signals, channels, batch) return res validator.update_metrics = _ship @@ -255,9 +255,9 @@ def default_train_signals(model, signals_cfg: dict = {}) -> list[Signal]: bl = crit.bbox_loss detect_head = next((m for m in model.modules() if isinstance(m, Detect)), None) - get_bce = fwd_hook(crit.bce) # bce is a plain nn.Module - get_iou = fn_tap(ul_loss, "bbox_iou") # bbox_iou is a plain function - get_dfl = method_call_tap(bl, "dfl_loss") # DFLoss overrides __call__ + get_bce = fwd_hook(crit.bce) # bce is a plain nn.Module + get_iou = 
fn_tap(ul_loss, "bbox_iou") # bbox_iou is a plain function + get_dfl = method_call_tap(bl, "dfl_loss") # DFLoss overrides __call__ get_bl_args = pre_hook(bl) def _fg_state(): diff --git a/weightslab/integrations/ultralytics/trainer.py b/weightslab/integrations/ultralytics/trainer.py index 95a54dfa..6efc7aaa 100644 --- a/weightslab/integrations/ultralytics/trainer.py +++ b/weightslab/integrations/ultralytics/trainer.py @@ -89,7 +89,7 @@ def _validate(loader): except Exception as e: raised_exc = e finally: - trainer.validator.dataloader = val_loader # Reset val loader + trainer.validator.dataloader = val_loader # Reset val loader # Finally raise exc. if raised_exc is not None: @@ -120,20 +120,20 @@ def _on_val_end(validator): return for ul_key, wl_key in ( ("metrics/precision(B)", "val/precision"), - ("metrics/recall(B)", "val/recall"), - ("metrics/mAP50(B)", "val/mAP50"), - ("metrics/mAP50-95(B)", "val/mAP50-95"), - ("fitness", "val/fitness"), + ("metrics/recall(B)", "val/recall"), + ("metrics/mAP50(B)", "val/mAP50"), + ("metrics/mAP50-95(B)", "val/mAP50-95"), + ("fitness", "val/fitness"), ): if ul_key in rd and wl_key in ch: ch[wl_key](torch.tensor([float(rd[ul_key])])) - self.add_callback("on_train_start", _on_train_start) + self.add_callback("on_train_start", _on_train_start) self.add_callback("on_train_batch_start", _on_train_batch_start) - self.add_callback("on_train_batch_end", _on_train_batch_end) - self.add_callback("on_val_batch_start", _on_val_batch_start) - self.add_callback("on_val_batch_end", _on_val_batch_end) - self.add_callback("on_val_end", _on_val_end) + self.add_callback("on_train_batch_end", _on_train_batch_end) + self.add_callback("on_val_batch_start", _on_val_batch_start) + self.add_callback("on_val_batch_end", _on_val_batch_end) + self.add_callback("on_val_end", _on_val_end) def validate(self): # UL's metrics.process does np.concatenate([]) → ValueError when val diff --git a/weightslab/models/model_with_ops.py 
b/weightslab/models/model_with_ops.py index c3a7b13e..322abcff 100755 --- a/weightslab/models/model_with_ops.py +++ b/weightslab/models/model_with_ops.py @@ -20,9 +20,9 @@ def __init__(self): # Initialize variables self.current_step = 0 - self.visited_nodes = set() # Memory trace of explored nodes - self.visited_incoming_nodes = set() # Memory trace of explored nodes - self.name = self._get_name() # Name of the model + self.visited_nodes = set() # Memory trace of explored nodes + self.visited_incoming_nodes = set() # Memory trace of explored nodes + self.name = self._get_name() # Name of the model self.linearized_layers = [] self._architecture_change_hook_fns = [] self.tracking_mode = TrackingMode.DISABLED @@ -450,7 +450,7 @@ def _operate( elif current_child_name is not None and current_child_name in module.src_to_dst_mapping_tnsrs: kwargs['current_child_name'] = current_child_name else: - kwargs['current_child_name'] = None # Its child is an Orphan node + kwargs['current_child_name'] = None # Its child is an Orphan node # # Operate module.operate( neuron_indices, diff --git a/weightslab/models/monkey_patcher.py b/weightslab/models/monkey_patcher.py index 8a19ce4c..13e7bd79 100644 --- a/weightslab/models/monkey_patcher.py +++ b/weightslab/models/monkey_patcher.py @@ -74,7 +74,7 @@ def wrapped_forward(self, input): data=input ) return output - module.forward = types.MethodType(wrapped_forward, module) # Monkey patch + module.forward = types.MethodType(wrapped_forward, module) # Monkey patch module.is_leaf = True return module diff --git a/weightslab/modules/modules_with_ops.py b/weightslab/modules/modules_with_ops.py index c4c54c7e..a4b96763 100644 --- a/weightslab/modules/modules_with_ops.py +++ b/weightslab/modules/modules_with_ops.py @@ -46,7 +46,7 @@ def __init__( self.module_name = module_name self.device = device self.tracking_mode = TrackingMode.DISABLED - self.operation_age = {op.name: 0 for op in ArchitectureNeuronsOpType} # keep track of all operations 
performed + self.operation_age = {op.name: 0 for op in ArchitectureNeuronsOpType} # keep track of all operations performed # IN/OUT neurons indexing & mapping dictionary self.src_to_dst_mapping_tnsrs = {} @@ -77,7 +77,7 @@ def __init__( } # Naming - self.assign_id() # assign ids + self.assign_id() # assign ids # Tracking self.register_trackers() @@ -183,7 +183,7 @@ def __hash__(self) -> int: # Trackers Functions # ================== def register_trackers(self): - is_disabled = bool(getattr(self, "wl_same_flag", False)) # Remove SAME layer like BN from neurons stats ..etc + is_disabled = bool(getattr(self, "wl_same_flag", False)) # Remove SAME layer like BN from neurons stats ..etc # Train if self.get_neurons('out_neurons') is not None: @@ -246,7 +246,7 @@ def get_operation( Callable: The operation function. """ if callable(op_type): - return op_type # if already got, just return the fct + return op_type # if already got, just return the fct elif op_type == ArchitectureNeuronsOpType.ADD or \ op_type == ArchitectureNeuronsOpType.ADD.value: return self._add_neurons @@ -435,7 +435,7 @@ def _process_neurons_indices( elif not isinstance(neuron_indices, set): # If it's a single int, wrap; if it's iterable, cast to set try: - neuron_indices = set(neuron_indices) # type: ignore[arg-type] + neuron_indices = set(neuron_indices) # type: ignore[arg-type] except TypeError: neuron_indices = {neuron_indices} @@ -470,7 +470,7 @@ def _process_neurons_indices( if mapped_indices_dict is not None: mapped_indexs = normalize_dicts( {"mapped": mapped_indices_dict} - )["mapped"] # TODO (GP): Improve this function + )["mapped"] # TODO (GP): Improve this function else: # No mapping tensors available: fall back to identity mapping n_neurons = self.get_neurons( @@ -526,7 +526,7 @@ def register( tracker = self.get_tracker() if tracker is None or activation_map is None or input is None: return - activation_map = (activation_map > 0).long() # bool to int + activation_map = (activation_map > 
0).long() # bool to int processed_activation_map = th.sum(activation_map, dim=(-2, -1)) if len(activation_map.shape) > 2 else activation_map copy_forward_tracked_attrs(processed_activation_map, activation_map) tracker.update(processed_activation_map) @@ -995,7 +995,7 @@ def _add_neurons( self.related_dst_to_src_mapping_tnsrs[ current_name ].keys() - )[neuron_indice + -1 + length] # get new index + )[neuron_indice + -1 + length] # get new index # Update the mapping tensor with 1 or range(x) neurons self.related_dst_to_src_mapping_tnsrs[ @@ -1013,7 +1013,7 @@ def _add_neurons( channel_size, mapped_neuron_indice * channel_size ) - ) # in range of x neurons + ) # in range of x neurons ] } ) @@ -1042,13 +1042,13 @@ def _add_neurons( self.super_in_name, self.get_neurons(self.super_in_name) + nb_neurons ) - ) # Update neurons count + ) # Update neurons count elif dependency == DepType.SAME: if self.get_neurons(self.super_out_name) is not None: self.set_neurons( attr_name='in_neurons', new_value=self.get_neurons(self.super_out_name) - ) # Update neurons count + ) # Update neurons count # By default get deps name from current relation deps_names = list(self.dst_to_src_mapping_tnsrs.keys()) @@ -1071,7 +1071,7 @@ def _add_neurons( ) if index >= len(mapped_neuron_indice): logger.warning( - f"Index {index} out of range for " + + f"Index {index} out of range for " + f"mapped_neuron_indice with length " + f"{len(mapped_neuron_indice)}" ) @@ -1086,8 +1086,8 @@ def _add_neurons( self.dst_to_src_mapping_tnsrs[ deps_name ][mapped_neuron_indice] - ) for i in range(0, nb_neurons) # neurons - ] for j in range(0, nb_neurons) # neurons | chan. + ) for i in range(0, nb_neurons) # neurons + ] for j in range(0, nb_neurons) # neurons | chan. 
} ) @@ -1195,7 +1195,7 @@ def _prune_neurons( f"overlap: {neuron_indices} & {neurons} => " f"{neuron_indices & neurons}" ) - return # Do not change + return # Do not change # # Enough neurons to operate if len(neurons) <= 1: @@ -1221,7 +1221,7 @@ def _prune_neurons( ) self.weight = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if self.weight.grad is not None: with th.no_grad(): @@ -1232,7 +1232,7 @@ def _prune_neurons( ) self.weight.grad = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if hasattr(self, 'bias') and self.bias is not None and \ not is_incoming: @@ -1244,7 +1244,7 @@ def _prune_neurons( ) self.bias.data = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if self.bias.grad is not None: with th.no_grad(): @@ -1255,7 +1255,7 @@ def _prune_neurons( ) self.bias.grad = nn.Parameter( tmp_tsnr.clone().detach() - ).to(self.device) # Safe approach + ).to(self.device) # Safe approach if hasattr(self, 'running_mean'): tmp_tsnr = th.index_select( @@ -1263,7 +1263,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_mean = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_mean = tmp_tsnr.clone().detach().to(self.device) # Safe approach if self.running_mean.grad is not None: tmp_tsnr = th.index_select( @@ -1271,7 +1271,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_mean.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_mean.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach if hasattr(self, 'running_var'): tmp_tsnr = th.index_select( @@ -1279,7 +1279,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_var = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_var = tmp_tsnr.clone().detach().to(self.device) # Safe approach if self.running_var.grad is not 
None: tmp_tsnr = th.index_select( @@ -1287,7 +1287,7 @@ def _prune_neurons( dim=0, index=idx_tnsr ) - self.running_var.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach + self.running_var.grad = tmp_tsnr.clone().detach().to(self.device) # Safe approach # Sort indices to prune from last to first to maintain # the original order @@ -1385,12 +1385,12 @@ def _prune_neurons( self.super_in_name, len(idx_tokeep) ) - ) # Update neurons count + ) # Update neurons count elif dependency == DepType.SAME: self.set_neurons( attr_name='in_neurons', new_value=self.get_neurons(self.super_out_name) - ) # Update neurons count + ) # Update neurons count # By default get deps name from current relation deps_names = self.dst_to_src_mapping_tnsrs.keys() @@ -1494,7 +1494,7 @@ def _freeze_neurons( # Work on the output tensors_name = self.learnable_tensors_name if not is_incoming \ - else ['weight'] # Weight is the only learnable tensor input + else ['weight'] # Weight is the only learnable tensor input for tensor_name in tensors_name: neurons_lr = { neuron_indices[n]: diff --git a/weightslab/proto/experiment_service.proto b/weightslab/proto/experiment_service.proto index 73de39d7..770965fd 100644 --- a/weightslab/proto/experiment_service.proto +++ b/weightslab/proto/experiment_service.proto @@ -17,6 +17,13 @@ service ExperimentService { // Data Service (for weights_studio UI) rpc ApplyDataQuery (DataQueryRequest) returns (DataQueryResponse); rpc GetDataSamples (DataSamplesRequest) returns (DataSamplesResponse); + // Server-side histogram binning of one metadata/signal column. + rpc GetHistogram (HistogramRequest) returns (HistogramResponse); + // Metadata-only retrieval (dataframe columns). Returns every metadata column + // name for the WHOLE dataset, the current grid slice's per-sample metadata, and + // the open modal sample's metadata. Separated from GetDataSamples, which now + // returns only image / label / prediction data. 
+ rpc GetMetaData (GetMetaDataRequest) returns (GetMetaDataResponse); // Raw point cloud of one sample (task_type "detection_pointcloud"), server-streamed // in binary chunks for the interactive 3D viewer. rpc GetPointCloud (PointCloudRequest) returns (stream PointCloudChunk); @@ -162,6 +169,13 @@ message PlotNoteOperation { string note = 4; } +// Manual "save now" trigger: force a checkpoint of the current model weights +// (and, when requested, the architecture) regardless of pending-change tracking. +message SaveCheckpointOperation { + bool save_architecture = 1; // force re-dump architecture even if a file already exists + bool save_optimizer = 2; // also persist optimizer state +} + message TrainerCommand { bool get_hyper_parameters = 4; bool get_interactive_layers = 5; @@ -174,6 +188,7 @@ message TrainerCommand { optional DenySamplesOperation remove_from_denylist_operation = 11; optional DenySamplesOperation remove_eval_from_denylist_operation = 12; optional PlotNoteOperation plot_note_operation = 13; + optional SaveCheckpointOperation save_checkpoint_operation = 14; } message HyperParameterDesc { @@ -382,6 +397,52 @@ message DataSamplesResponse { repeated DataRecord data_records = 3; } +// --- Server-side histogram binning --- +// One stacked sub-segment of a bar: count of samples in this bin for a given +// (origin, discarded) combination (used to colour train/eval/discarded splits). +message HistogramSubBar { + string origin = 1; + bool discarded = 2; + int64 count = 3; +} + +// One histogram bar: aggregate stats over the samples whose row index falls in +// this bin's range, plus the per-(origin,discarded) breakdown. 
+message HistogramBin { + double min = 1; + double max = 2; + double avg = 3; + int64 count = 4; + repeated HistogramSubBar sub_bars = 5; +} + +message HistogramRequest { + string column = 1; // dataframe/signal column to histogram + int32 max_bins = 2; // 0 => server default (512) +} + +message HistogramResponse { + bool success = 1; + string message = 2; + int64 total_rows = 3; // rows in the view that were binned + repeated HistogramBin bins = 4; +} + +// --- Metadata retrieval (separated from GetDataSamples) --- +message GetMetaDataRequest { + int32 start_index = 1; // grid slice start (current view order) + int32 records_cnt = 2; // grid slice size + string modal_sample_id = 3; // optional: sample_id of the open modal ("" = none) +} + +message GetMetaDataResponse { + bool success = 1; + string message = 2; + repeated string all_metadata_names = 3; // every metadata column for the WHOLE dataset + repeated DataRecord grid_records = 4; // per-sample metadata for the requested slice + DataRecord modal_record = 5; // metadata for the open modal sample (if found) +} + // --- Point cloud transfer (task_type "detection_pointcloud") --- message PointCloudRequest { string sample_id = 1; diff --git a/weightslab/proto/experiment_service_pb2.py b/weightslab/proto/experiment_service_pb2.py index 9ea162d1..14d171a0 100644 --- a/weightslab/proto/experiment_service_pb2.py +++ b/weightslab/proto/experiment_service_pb2.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! +# Generated by the protocol buffer compiler. DO NOT EDIT! 
# NO CHECKED-IN PROTOBUF GENCODE # source: weightslab/proto/experiment_service.proto # Protobuf Python Version: 6.31.1 @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 
\x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"\xdc\x06\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r 
\x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n 
\x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c 
\x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 
\x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 
\n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\x84\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentModelRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEva
luationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)weightslab/proto/experiment_service.proto\"\x89\x01\n\x1aGetLatestLoggerDataRequest\x12\x1c\n\x14request_full_history\x18\x01 \x01(\x08\x12\x12\n\nmax_points\x18\x02 \x01(\x05\x12\x17\n\x0f\x62reak_by_slices\x18\x03 \x01(\x08\x12\x0c\n\x04tags\x18\x04 \x03(\t\x12\x12\n\ngraph_name\x18\x05 \x01(\t\"\x81\x02\n\x0fLoggerDataPoint\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x11\n\tmodel_age\x18\x02 \x01(\x05\x12\x14\n\x0cmetric_value\x18\x03 \x01(\x02\x12\x17\n\x0f\x65xperiment_hash\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\x03\x12\x11\n\tsample_id\x18\x06 \x01(\t\x12\x1c\n\x14is_evaluation_marker\x18\x07 \x01(\x08\x12\x12\n\nsplit_name\x18\x08 \x01(\t\x12\x17\n\x0f\x65valuation_tags\x18\t \x03(\t\x12\x12\n\npoint_note\x18\n \x01(\t\x12\x12\n\naudit_mode\x18\x0b \x01(\x08\"?\n\x1bGetLatestLoggerDataResponse\x12 \n\x06points\x18\x01 \x03(\x0b\x32\x10.LoggerDataPoint\"\x07\n\x05\x45mpty\"/\n\x08NeuronId\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tneuron_id\x18\x02 \x01(\x05\"\x91\x02\n\x0fWeightOperation\x12*\n\x07op_type\x18\x01 \x01(\x0e\x32\x14.WeightOperationTypeH\x00\x88\x01\x01\x12\x15\n\x08layer_id\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1d\n\nneuron_ids\x18\x03 \x03(\x0b\x32\t.NeuronId\x12\x16\n\x0eneurons_to_add\x18\t \x01(\x05\x12 \n\x18zerofy_from_incoming_ids\x18\x0b \x03(\x05\x12\x1c\n\x14zerofy_to_neuron_ids\x18\x0c \x03(\x05\x12+\n\x11zerofy_predicates\x18\r \x03(\x0e\x32\x10.ZerofyPredicateB\n\n\x08_op_typeB\x0b\n\t_layer_id\"_\n\x17WeightsOperationRequest\x12/\n\x10weight_operation\x18\x01 \x01(\x0b\x32\x10.WeightOperationH\x00\x88\x01\x01\x42\x13\n\x11_weight_operation\"<\n\x18WeightsOperationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t\"\xc1\x05\n\x0fHyperParameters\x12\x1c\n\x0f\x65xperiment_name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12!\n\x14training_steps_to_do\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x17\n\nbatch_size\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12 \n\x13\x66ull_eval_frequency\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12 \n\x13\x63heckpont_frequency\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x18\n\x0bis_training\x18\x07 \x01(\x08H\x06\x88\x01\x01\x12\x15\n\x08nb_steps\x18\x08 \x01(\x05H\x07\x88\x01\x01\x12\x19\n\x0c\x61uditor_mode\x18\t \x01(\x08H\x08\x88\x01\x01\x12\x1d\n\x10train_batch_size\x18\n \x01(\x05H\t\x88\x01\x01\x12\x1b\n\x0eval_batch_size\x18\x0b \x01(\x05H\n\x88\x01\x01\x12\x1c\n\x0ftest_batch_size\x18\x0c \x01(\x05H\x0b\x88\x01\x01\x12\x1c\n\x0f\x65valuation_mode\x18\r \x01(\x08H\x0c\x88\x01\x01\x12\x1e\n\x11\x65valuation_config\x18\x0e \x01(\tH\r\x88\x01\x01\x42\x12\n\x10_experiment_nameB\x17\n\x15_training_steps_to_doB\x10\n\x0e_learning_rateB\r\n\x0b_batch_sizeB\x16\n\x14_full_eval_frequencyB\x16\n\x14_checkpont_frequencyB\x0e\n\x0c_is_trainingB\x0b\n\t_nb_stepsB\x0f\n\r_auditor_modeB\x13\n\x11_train_batch_sizeB\x11\n\x0f_val_batch_sizeB\x12\n\x10_test_batch_sizeB\x12\n\x10_evaluation_modeB\x14\n\x12_evaluation_config\",\n\rMetricsStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"~\n\rAnnotatStatus\x12\x0c\n\x04name\x18\x01 \x01(\t\x12.\n\x08metadata\x18\x02 \x03(\x0b\x32\x1c.AnnotatStatus.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x90\x02\n\x10TrainingStatusEx\x12\x16\n\ttimestamp\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0f\x65xperiment_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x16\n\tmodel_age\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12+\n\x0emetrics_status\x18\x04 \x01(\x0b\x32\x0e.MetricsStatusH\x03\x88\x01\x01\x12+\n\x0e\x61nnotat_status\x18\x05 
\x01(\x0b\x32\x0e.AnnotatStatusH\x04\x88\x01\x01\x42\x0c\n\n_timestampB\x12\n\x10_experiment_nameB\x0c\n\n_model_ageB\x11\n\x0f_metrics_statusB\x11\n\x0f_annotat_status\"]\n\x15HyperParameterCommand\x12/\n\x10hyper_parameters\x18\x01 \x01(\x0b\x32\x10.HyperParametersH\x00\x88\x01\x01\x42\x13\n\x11_hyper_parameters\">\n\x14\x44\x65nySamplesOperation\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\"0\n\x17LoadCheckpointOperation\x12\x15\n\rcheckpoint_id\x18\x01 \x01(\x05\"b\n\x11PlotNoteOperation\x12\x13\n\x0bmetric_name\x18\x01 \x01(\t\x12\x17\n\x0f\x65xperiment_hash\x18\x02 \x01(\t\x12\x11\n\tmodel_age\x18\x03 \x01(\x05\x12\x0c\n\x04note\x18\x04 \x01(\t\"L\n\x17SaveCheckpointOperation\x12\x19\n\x11save_architecture\x18\x01 \x01(\x08\x12\x16\n\x0esave_optimizer\x18\x02 \x01(\x08\"\xbc\x07\n\x0eTrainerCommand\x12\x1c\n\x14get_hyper_parameters\x18\x04 \x01(\x08\x12\x1e\n\x16get_interactive_layers\x18\x05 \x01(\x08\x12\x1d\n\x10get_data_records\x18\x06 \x01(\tH\x00\x88\x01\x01\x12%\n\x18get_single_layer_info_id\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12;\n\x16hyper_parameter_change\x18\x01 \x01(\x0b\x32\x16.HyperParameterCommandH\x02\x88\x01\x01\x12:\n\x16\x64\x65ny_samples_operation\x18\x07 \x01(\x0b\x32\x15.DenySamplesOperationH\x03\x88\x01\x01\x12?\n\x1b\x64\x65ny_eval_samples_operation\x18\n \x01(\x0b\x32\x15.DenySamplesOperationH\x04\x88\x01\x01\x12@\n\x19load_checkpoint_operation\x18\t \x01(\x0b\x32\x18.LoadCheckpointOperationH\x05\x88\x01\x01\x12\x42\n\x1eremove_from_denylist_operation\x18\x0b \x01(\x0b\x32\x15.DenySamplesOperationH\x06\x88\x01\x01\x12G\n#remove_eval_from_denylist_operation\x18\x0c \x01(\x0b\x32\x15.DenySamplesOperationH\x07\x88\x01\x01\x12\x34\n\x13plot_note_operation\x18\r \x01(\x0b\x32\x12.PlotNoteOperationH\x08\x88\x01\x01\x12@\n\x19save_checkpoint_operation\x18\x0e 
\x01(\x0b\x32\x18.SaveCheckpointOperationH\t\x88\x01\x01\x42\x13\n\x11_get_data_recordsB\x1b\n\x19_get_single_layer_info_idB\x19\n\x17_hyper_parameter_changeB\x19\n\x17_deny_samples_operationB\x1e\n\x1c_deny_eval_samples_operationB\x1c\n\x1a_load_checkpoint_operationB!\n\x1f_remove_from_denylist_operationB&\n$_remove_eval_from_denylist_operationB\x16\n\x14_plot_note_operationB\x1c\n\x1a_save_checkpoint_operation\"\x9d\x01\n\x12HyperParameterDesc\x12\r\n\x05label\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x1c\n\x0fnumerical_value\x18\x04 \x01(\x02H\x00\x88\x01\x01\x12\x19\n\x0cstring_value\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_numerical_valueB\x0f\n\r_string_value\"\xf2\x02\n\x10NeuronStatistics\x12!\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronIdH\x00\x88\x01\x01\x12\x17\n\nneuron_age\x18\x02 \x01(\x05H\x01\x88\x01\x01\x12\x1f\n\x12train_trigger_rate\x18\x03 \x01(\x02H\x02\x88\x01\x01\x12\x1e\n\x11\x65val_trigger_rate\x18\x04 \x01(\x02H\x03\x88\x01\x01\x12\x1a\n\rlearning_rate\x18\x07 \x01(\x02H\x04\x88\x01\x01\x12\x36\n\x0bincoming_lr\x18\x08 \x03(\x0b\x32!.NeuronStatistics.IncomingLrEntry\x1a\x31\n\x0fIncomingLrEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\n_neuron_idB\r\n\x0b_neuron_ageB\x15\n\x13_train_trigger_rateB\x14\n\x12_eval_trigger_rateB\x10\n\x0e_learning_rate\"\xf0\x02\n\x13LayerRepresentation\x12\x15\n\x08layer_id\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x1a\n\rneurons_count\x18\x04 \x01(\x05H\x03\x88\x01\x01\x12#\n\x16incoming_neurons_count\x18\x05 \x01(\x05H\x04\x88\x01\x01\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x05\x88\x01\x01\x12\x13\n\x06stride\x18\x07 \x01(\x05H\x06\x88\x01\x01\x12-\n\x12neurons_statistics\x18\n 
\x03(\x0b\x32\x11.NeuronStatisticsB\x0b\n\t_layer_idB\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x10\n\x0e_neurons_countB\x19\n\x17_incoming_neurons_countB\x0e\n\x0c_kernel_sizeB\t\n\x07_stride\"H\n\x11\x41\x63tivationRequest\x12\x10\n\x08layer_id\x18\x01 \x01(\x05\x12\x11\n\tsample_id\x18\x02 \x01(\t\x12\x0e\n\x06origin\x18\x03 \x01(\t\"H\n\rActivationMap\x12\x11\n\tneuron_id\x18\x01 \x01(\x05\x12\x0e\n\x06values\x18\x02 \x03(\x02\x12\t\n\x01H\x18\x03 \x01(\x05\x12\t\n\x01W\x18\x04 \x01(\x05\"d\n\x12\x41\x63tivationResponse\x12\x12\n\nlayer_type\x18\x01 \x01(\t\x12\x15\n\rneurons_count\x18\x02 \x01(\x05\x12#\n\x0b\x61\x63tivations\x18\x03 \x03(\x0b\x32\x0e.ActivationMap\"\x93\x01\n\tTaskField\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x03 \x01(\x05H\x00\x12\x16\n\x0cstring_value\x18\x04 \x01(\tH\x00\x12\x15\n\x0b\x62ytes_value\x18\x05 \x01(\x0cH\x00\x12\x14\n\nbool_value\x18\x06 \x01(\x08H\x00\x42\x07\n\x05value\"\xcc\x02\n\x0eRecordMetadata\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x14\n\x0csample_label\x18\x02 \x03(\x05\x12\x19\n\x11sample_prediction\x18\x03 \x03(\x05\x12=\n\x10sample_last_loss\x18\x04 \x03(\x0b\x32#.RecordMetadata.SampleLastLossEntry\x12\x19\n\x11sample_encounters\x18\x05 \x01(\x05\x12\x18\n\x10sample_discarded\x18\x06 \x01(\x08\x12 \n\x0c\x65xtra_fields\x18\x07 \x03(\x0b\x32\n.TaskField\x12\x16\n\x0eprediction_raw\x18\t \x01(\x0c\x12\x11\n\ttask_type\x18\n \x01(\t\x1a\x35\n\x13SampleLastLossEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x93\x01\n\x10SampleStatistics\x12\x13\n\x06origin\x18\x06 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0csample_count\x18\x07 \x01(\x05H\x01\x88\x01\x01\x12\x11\n\ttask_type\x18\t \x01(\t\x12 \n\x07records\x18\x08 \x03(\x0b\x32\x0f.RecordMetadataB\t\n\x07_originB\x0f\n\r_sample_count\"\xe6\x01\n\x0f\x43ommandResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t\x12\x33\n\x16hyper_parameters_descs\x18\x03 \x03(\x0b\x32\x13.HyperParameterDesc\x12\x33\n\x15layer_representations\x18\x04 \x03(\x0b\x32\x14.LayerRepresentation\x12\x31\n\x11sample_statistics\x18\x05 \x01(\x0b\x32\x11.SampleStatisticsH\x00\x88\x01\x01\x42\x14\n\x12_sample_statistics\"U\n\rSampleRequest\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_origin\"\xad\x02\n\x15SampleRequestResponse\x12\x16\n\tsample_id\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06origin\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x12\n\x05label\x18\x03 \x01(\x05H\x02\x88\x01\x01\x12\x11\n\x04\x64\x61ta\x18\x04 \x01(\x0cH\x03\x88\x01\x01\x12\x1a\n\rerror_message\x18\x05 \x01(\tH\x04\x88\x01\x01\x12\x15\n\x08raw_data\x18\x06 \x01(\x0cH\x05\x88\x01\x01\x12\x11\n\x04mask\x18\x07 \x01(\x0cH\x06\x88\x01\x01\x12\x17\n\nprediction\x18\x08 \x01(\x0cH\x07\x88\x01\x01\x42\x0c\n\n_sample_idB\t\n\x07_originB\x08\n\x06_labelB\x07\n\x05_dataB\x10\n\x0e_error_messageB\x0b\n\t_raw_dataB\x07\n\x05_maskB\r\n\x0b_prediction\"\x92\x01\n\x12\x42\x61tchSampleRequest\x12\x12\n\nsample_ids\x18\x01 \x03(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x19\n\x0cresize_width\x18\x03 \x01(\x05H\x00\x88\x01\x01\x12\x1a\n\rresize_height\x18\x04 \x01(\x05H\x01\x88\x01\x01\x42\x0f\n\r_resize_widthB\x10\n\x0e_resize_height\">\n\x13\x42\x61tchSampleResponse\x12\'\n\x07samples\x18\x01 \x03(\x0b\x32\x16.SampleRequestResponse\".\n\x0eWeightsRequest\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\"\x9d\x02\n\x0fWeightsResponse\x12\x1c\n\tneuron_id\x18\x01 \x01(\x0b\x32\t.NeuronId\x12\x17\n\nlayer_name\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nlayer_type\x18\x03 \x01(\tH\x01\x88\x01\x01\x12\x10\n\x08incoming\x18\x04 \x01(\x05\x12\x10\n\x08outgoing\x18\x05 \x01(\x05\x12\x18\n\x0bkernel_size\x18\x06 \x01(\x05H\x02\x88\x01\x01\x12\x0f\n\x07weights\x18\x07 \x03(\x02\x12\x0f\n\x07success\x18\x0b \x01(\x08\x12\x1a\n\rerror_message\x18\x0c 
\x01(\tH\x03\x88\x01\x01\x42\r\n\x0b_layer_nameB\r\n\x0b_layer_typeB\x0e\n\x0c_kernel_sizeB\x10\n\x0e_error_message\"R\n\x10\x44\x61taQueryRequest\x12\r\n\x05query\x18\x01 \x01(\t\x12\x12\n\naccumulate\x18\x02 \x01(\x08\x12\x1b\n\x13is_natural_language\x18\x03 \x01(\x08\"5\n\x11\x43\x61tegoricalTagDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ncategories\x18\x02 \x03(\t\"\xa9\x02\n\x11\x44\x61taQueryResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1d\n\x15number_of_all_samples\x18\x03 \x01(\x05\x12%\n\x1dnumber_of_samples_in_the_loop\x18\x04 \x01(\x05\x12#\n\x1bnumber_of_discarded_samples\x18\x05 \x01(\x05\x12\x13\n\x0bunique_tags\x18\x06 \x03(\t\x12+\n\x11\x61gent_intent_type\x18\x07 \x01(\x0e\x32\x10.AgentIntentType\x12\x17\n\x0f\x61nalysis_result\x18\x08 \x01(\t\x12,\n\x10\x63\x61tegorical_tags\x18\t \x03(\x0b\x32\x12.CategoricalTagDef\"\xc2\x01\n\x12\x44\x61taSamplesRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12 \n\x18include_transformed_data\x18\x03 \x01(\x08\x12\x18\n\x10include_raw_data\x18\x04 \x01(\x08\x12\x19\n\x11stats_to_retrieve\x18\x05 \x03(\t\x12\x14\n\x0cresize_width\x18\x06 \x01(\x05\x12\x15\n\rresize_height\x18\x07 \x01(\x05\"m\n\x08\x44\x61taStat\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x05\x12\r\n\x05value\x18\x04 \x03(\x02\x12\x14\n\x0cvalue_string\x18\x05 \x01(\t\x12\x11\n\tthumbnail\x18\x06 \x01(\x0c\">\n\nDataRecord\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x1d\n\ndata_stats\x18\x02 \x03(\x0b\x32\t.DataStat\"Z\n\x13\x44\x61taSamplesResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12!\n\x0c\x64\x61ta_records\x18\x03 \x03(\x0b\x32\x0b.DataRecord\"C\n\x0fHistogramSubBar\x12\x0e\n\x06origin\x18\x01 \x01(\t\x12\x11\n\tdiscarded\x18\x02 \x01(\x08\x12\r\n\x05\x63ount\x18\x03 \x01(\x03\"h\n\x0cHistogramBin\x12\x0b\n\x03min\x18\x01 \x01(\x01\x12\x0b\n\x03max\x18\x02 
\x01(\x01\x12\x0b\n\x03\x61vg\x18\x03 \x01(\x01\x12\r\n\x05\x63ount\x18\x04 \x01(\x03\x12\"\n\x08sub_bars\x18\x05 \x03(\x0b\x32\x10.HistogramSubBar\"4\n\x10HistogramRequest\x12\x0e\n\x06\x63olumn\x18\x01 \x01(\t\x12\x10\n\x08max_bins\x18\x02 \x01(\x05\"f\n\x11HistogramResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\ntotal_rows\x18\x03 \x01(\x03\x12\x1b\n\x04\x62ins\x18\x04 \x03(\x0b\x32\r.HistogramBin\"W\n\x12GetMetaDataRequest\x12\x13\n\x0bstart_index\x18\x01 \x01(\x05\x12\x13\n\x0brecords_cnt\x18\x02 \x01(\x05\x12\x17\n\x0fmodal_sample_id\x18\x03 \x01(\t\"\x99\x01\n\x13GetMetaDataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x1a\n\x12\x61ll_metadata_names\x18\x03 \x03(\t\x12!\n\x0cgrid_records\x18\x04 \x03(\x0b\x32\x0b.DataRecord\x12!\n\x0cmodal_record\x18\x05 \x01(\x0b\x32\x0b.DataRecord\"J\n\x11PointCloudRequest\x12\x11\n\tsample_id\x18\x01 \x01(\t\x12\x0e\n\x06origin\x18\x02 \x01(\t\x12\x12\n\nmax_points\x18\x03 \x01(\x05\"\xbf\x01\n\x0fPointCloudChunk\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x12\n\nnum_points\x18\x03 \x01(\x05\x12\x14\n\x0cnum_features\x18\x04 \x01(\x05\x12\x10\n\x08pc_range\x18\x05 \x03(\x02\x12\x0c\n\x04\x64\x61ta\x18\x06 \x01(\x0c\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\x05\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\x05\x12\x15\n\rfeature_names\x18\t \x03(\t\"\xdc\x01\n\x10\x44\x61taEditsRequest\x12\x11\n\tstat_name\x18\x01 \x01(\t\x12\x13\n\x0b\x66loat_value\x18\x02 \x01(\x02\x12\x14\n\x0cstring_value\x18\x03 \x01(\t\x12\x12\n\nbool_value\x18\x04 \x01(\x08\x12\x1d\n\x04type\x18\x05 \x01(\x0e\x32\x0f.SampleEditType\x12\x13\n\x0bsamples_ids\x18\x06 \x03(\t\x12\x16\n\x0esample_origins\x18\x07 \x03(\t\x12\x16\n\x0eis_categorical\x18\x08 \x01(\x08\x12\x12\n\ncategories\x18\t \x03(\t\"5\n\x11\x44\x61taEditsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t\":\n\x12\x44\x61taSplitsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x13\n\x0bsplit_names\x18\x02 \x03(\t\"9\n\x13\x41gentHealthResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"^\n\x16InitializeAgentRequest\x12\x0f\n\x07\x61pi_key\x18\x01 \x01(\t\x12$\n\x08provider\x18\x02 \x01(\x0e\x32\x12.AgentProviderType\x12\r\n\x05model\x18\x03 \x01(\t\";\n\x17InitializeAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x17\x43hangeAgentModelRequest\x12\r\n\x05model\x18\x01 \x01(\t\"<\n\x18\x43hangeAgentModelResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x17\n\x15GetAgentModelsRequest\"J\n\x16GetAgentModelsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0e\n\x06models\x18\x02 \x03(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"6\n\x12ResetAgentResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"3\n\x18RestoreCheckpointRequest\x12\x17\n\x0f\x65xperiment_hash\x18\x01 \x01(\t\"=\n\x19RestoreCheckpointResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"R\n\x18TriggerEvaluationRequest\x12\x12\n\nsplit_name\x18\x01 \x01(\t\x12\x0c\n\x04tags\x18\x02 \x03(\t\x12\x14\n\x0cuse_full_set\x18\x03 \x01(\x08\"=\n\x19TriggerEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x1c\n\x1aGetEvaluationStatusRequest\"\x81\x01\n\x1bGetEvaluationStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x0f\n\x07\x63urrent\x18\x02 \x01(\x05\x12\r\n\x05total\x18\x03 \x01(\x05\x12\x0f\n\x07message\x18\x04 \x01(\t\x12\r\n\x05\x65rror\x18\x05 \x01(\t\x12\x12\n\nsplit_name\x18\x06 \x01(\t\")\n\x17\x43\x61ncelEvaluationRequest\x12\x0e\n\x06reason\x18\x01 \x01(\t\"<\n\x18\x43\x61ncelEvaluationResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t*d\n\x13WeightOperationType\x12\n\n\x06ZEROFY\x10\x00\x12\x10\n\x0cREINITIALIZE\x10\x01\x12\n\n\x06\x46REEZE\x10\x02\x12\x12\n\x0eREMOVE_NEURONS\x10\t\x12\x0f\n\x0b\x41\x44\x44_NEURONS\x10\n*o\n\x0fZerofyPredicate\x12\x19\n\x15ZEROFY_PREDICATE_NONE\x10\x00\x12 \n\x1cZEROFY_PREDICATE_WITH_FROZEN\x10\x01\x12\x1f\n\x1bZEROFY_PREDICATE_WITH_OLDER\x10\x02*M\n\x0f\x41gentIntentType\x12\x12\n\x0eINTENT_UNKNOWN\x10\x00\x12\x11\n\rINTENT_FILTER\x10\x01\x12\x13\n\x0fINTENT_ANALYSIS\x10\x02*I\n\x0eSampleEditType\x12\x11\n\rEDIT_OVERRIDE\x10\x00\x12\x13\n\x0f\x45\x44IT_ACCUMULATE\x10\x01\x12\x0f\n\x0b\x45\x44IT_REMOVE\x10\x02*,\n\x11\x41gentProviderType\x12\x17\n\x13PROVIDER_OPENROUTER\x10\x00\x32\xf5\n\n\x11\x45xperimentService\x12P\n\x13GetLatestLoggerData\x12\x1b.GetLatestLoggerDataRequest\x1a\x1c.GetLatestLoggerDataResponse\x12\x36\n\x11\x45xperimentCommand\x12\x0f.TrainerCommand\x1a\x10.CommandResponse\x12H\n\x11ManipulateWeights\x12\x18.WeightsOperationRequest\x1a\x19.WeightsOperationResponse\x12/\n\nGetWeights\x12\x0f.WeightsRequest\x1a\x10.WeightsResponse\x12\x39\n\x0eGetActivations\x12\x12.ActivationRequest\x1a\x13.ActivationResponse\x12\x37\n\nGetSamples\x12\x13.BatchSampleRequest\x1a\x14.BatchSampleResponse\x12\x37\n\x0e\x41pplyDataQuery\x12\x11.DataQueryRequest\x1a\x12.DataQueryResponse\x12;\n\x0eGetDataSamples\x12\x13.DataSamplesRequest\x1a\x14.DataSamplesResponse\x12\x35\n\x0cGetHistogram\x12\x11.HistogramRequest\x1a\x12.HistogramResponse\x12\x38\n\x0bGetMetaData\x12\x13.GetMetaDataRequest\x1a\x14.GetMetaDataResponse\x12\x37\n\rGetPointCloud\x12\x12.PointCloudRequest\x1a\x10.PointCloudChunk0\x01\x12\x37\n\x0e\x45\x64itDataSample\x12\x11.DataEditsRequest\x1a\x12.DataEditsResponse\x12,\n\rGetDataSplits\x12\x06.Empty\x1a\x13.DataSplitsResponse\x12\x30\n\x10\x43heckAgentHealth\x12\x06.Empty\x1a\x14.AgentHealthResponse\x12\x44\n\x0fInitializeAgent\x12\x17.InitializeAgentRequest\x1a\x18.InitializeAgentResponse\x12G\n\x10\x43hangeAgentModel\x12\x18.ChangeAgentMode
lRequest\x1a\x19.ChangeAgentModelResponse\x12\x41\n\x0eGetAgentModels\x12\x16.GetAgentModelsRequest\x1a\x17.GetAgentModelsResponse\x12)\n\nResetAgent\x12\x06.Empty\x1a\x13.ResetAgentResponse\x12J\n\x11RestoreCheckpoint\x12\x19.RestoreCheckpointRequest\x1a\x1a.RestoreCheckpointResponse\x12J\n\x11TriggerEvaluation\x12\x19.TriggerEvaluationRequest\x1a\x1a.TriggerEvaluationResponse\x12P\n\x13GetEvaluationStatus\x12\x1b.GetEvaluationStatusRequest\x1a\x1c.GetEvaluationStatusResponse\x12G\n\x10\x43\x61ncelEvaluation\x12\x18.CancelEvaluationRequest\x1a\x19.CancelEvaluationResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -37,16 +37,16 @@ _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_options = b'8\001' _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._loaded_options = None _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_options = b'8\001' - _globals['_WEIGHTOPERATIONTYPE']._serialized_start=8812 - _globals['_WEIGHTOPERATIONTYPE']._serialized_end=8912 - _globals['_ZEROFYPREDICATE']._serialized_start=8914 - _globals['_ZEROFYPREDICATE']._serialized_end=9025 - _globals['_AGENTINTENTTYPE']._serialized_start=9027 - _globals['_AGENTINTENTTYPE']._serialized_end=9104 - _globals['_SAMPLEEDITTYPE']._serialized_start=9106 - _globals['_SAMPLEEDITTYPE']._serialized_end=9179 - _globals['_AGENTPROVIDERTYPE']._serialized_start=9181 - _globals['_AGENTPROVIDERTYPE']._serialized_end=9225 + _globals['_WEIGHTOPERATIONTYPE']._serialized_start=9564 + _globals['_WEIGHTOPERATIONTYPE']._serialized_end=9664 + _globals['_ZEROFYPREDICATE']._serialized_start=9666 + _globals['_ZEROFYPREDICATE']._serialized_end=9777 + _globals['_AGENTINTENTTYPE']._serialized_start=9779 + _globals['_AGENTINTENTTYPE']._serialized_end=9856 + _globals['_SAMPLEEDITTYPE']._serialized_start=9858 + _globals['_SAMPLEEDITTYPE']._serialized_end=9931 + _globals['_AGENTPROVIDERTYPE']._serialized_start=9933 + 
_globals['_AGENTPROVIDERTYPE']._serialized_end=9977 _globals['_GETLATESTLOGGERDATAREQUEST']._serialized_start=46 _globals['_GETLATESTLOGGERDATAREQUEST']._serialized_end=183 _globals['_LOGGERDATAPOINT']._serialized_start=186 @@ -81,100 +81,114 @@ _globals['_LOADCHECKPOINTOPERATION']._serialized_end=2367 _globals['_PLOTNOTEOPERATION']._serialized_start=2369 _globals['_PLOTNOTEOPERATION']._serialized_end=2467 - _globals['_TRAINERCOMMAND']._serialized_start=2470 - _globals['_TRAINERCOMMAND']._serialized_end=3330 - _globals['_HYPERPARAMETERDESC']._serialized_start=3333 - _globals['_HYPERPARAMETERDESC']._serialized_end=3490 - _globals['_NEURONSTATISTICS']._serialized_start=3493 - _globals['_NEURONSTATISTICS']._serialized_end=3863 - _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_start=3722 - _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_end=3771 - _globals['_LAYERREPRESENTATION']._serialized_start=3866 - _globals['_LAYERREPRESENTATION']._serialized_end=4234 - _globals['_ACTIVATIONREQUEST']._serialized_start=4236 - _globals['_ACTIVATIONREQUEST']._serialized_end=4308 - _globals['_ACTIVATIONMAP']._serialized_start=4310 - _globals['_ACTIVATIONMAP']._serialized_end=4382 - _globals['_ACTIVATIONRESPONSE']._serialized_start=4384 - _globals['_ACTIVATIONRESPONSE']._serialized_end=4484 - _globals['_TASKFIELD']._serialized_start=4487 - _globals['_TASKFIELD']._serialized_end=4634 - _globals['_RECORDMETADATA']._serialized_start=4637 - _globals['_RECORDMETADATA']._serialized_end=4969 - _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_start=4916 - _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_end=4969 - _globals['_SAMPLESTATISTICS']._serialized_start=4972 - _globals['_SAMPLESTATISTICS']._serialized_end=5119 - _globals['_COMMANDRESPONSE']._serialized_start=5122 - _globals['_COMMANDRESPONSE']._serialized_end=5352 - _globals['_SAMPLEREQUEST']._serialized_start=5354 - _globals['_SAMPLEREQUEST']._serialized_end=5439 - 
_globals['_SAMPLEREQUESTRESPONSE']._serialized_start=5442 - _globals['_SAMPLEREQUESTRESPONSE']._serialized_end=5743 - _globals['_BATCHSAMPLEREQUEST']._serialized_start=5746 - _globals['_BATCHSAMPLEREQUEST']._serialized_end=5892 - _globals['_BATCHSAMPLERESPONSE']._serialized_start=5894 - _globals['_BATCHSAMPLERESPONSE']._serialized_end=5956 - _globals['_WEIGHTSREQUEST']._serialized_start=5958 - _globals['_WEIGHTSREQUEST']._serialized_end=6004 - _globals['_WEIGHTSRESPONSE']._serialized_start=6007 - _globals['_WEIGHTSRESPONSE']._serialized_end=6292 - _globals['_DATAQUERYREQUEST']._serialized_start=6294 - _globals['_DATAQUERYREQUEST']._serialized_end=6376 - _globals['_CATEGORICALTAGDEF']._serialized_start=6378 - _globals['_CATEGORICALTAGDEF']._serialized_end=6431 - _globals['_DATAQUERYRESPONSE']._serialized_start=6434 - _globals['_DATAQUERYRESPONSE']._serialized_end=6731 - _globals['_DATASAMPLESREQUEST']._serialized_start=6734 - _globals['_DATASAMPLESREQUEST']._serialized_end=6928 - _globals['_DATASTAT']._serialized_start=6930 - _globals['_DATASTAT']._serialized_end=7039 - _globals['_DATARECORD']._serialized_start=7041 - _globals['_DATARECORD']._serialized_end=7103 - _globals['_DATASAMPLESRESPONSE']._serialized_start=7105 - _globals['_DATASAMPLESRESPONSE']._serialized_end=7195 - _globals['_POINTCLOUDREQUEST']._serialized_start=7197 - _globals['_POINTCLOUDREQUEST']._serialized_end=7271 - _globals['_POINTCLOUDCHUNK']._serialized_start=7274 - _globals['_POINTCLOUDCHUNK']._serialized_end=7465 - _globals['_DATAEDITSREQUEST']._serialized_start=7468 - _globals['_DATAEDITSREQUEST']._serialized_end=7688 - _globals['_DATAEDITSRESPONSE']._serialized_start=7690 - _globals['_DATAEDITSRESPONSE']._serialized_end=7743 - _globals['_DATASPLITSRESPONSE']._serialized_start=7745 - _globals['_DATASPLITSRESPONSE']._serialized_end=7803 - _globals['_AGENTHEALTHRESPONSE']._serialized_start=7805 - _globals['_AGENTHEALTHRESPONSE']._serialized_end=7862 - 
_globals['_INITIALIZEAGENTREQUEST']._serialized_start=7864 - _globals['_INITIALIZEAGENTREQUEST']._serialized_end=7958 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=7960 - _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8019 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8021 - _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8061 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8063 - _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8123 - _globals['_GETAGENTMODELSREQUEST']._serialized_start=8125 - _globals['_GETAGENTMODELSREQUEST']._serialized_end=8148 - _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8150 - _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8224 - _globals['_RESETAGENTRESPONSE']._serialized_start=8226 - _globals['_RESETAGENTRESPONSE']._serialized_end=8280 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=8282 - _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=8333 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=8335 - _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=8396 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=8398 - _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=8480 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=8482 - _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=8543 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=8545 - _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=8573 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=8576 - _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=8705 - _globals['_CANCELEVALUATIONREQUEST']._serialized_start=8707 - _globals['_CANCELEVALUATIONREQUEST']._serialized_end=8748 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=8750 - _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=8810 - _globals['_EXPERIMENTSERVICE']._serialized_start=9228 - _globals['_EXPERIMENTSERVICE']._serialized_end=10512 + 
_globals['_SAVECHECKPOINTOPERATION']._serialized_start=2469 + _globals['_SAVECHECKPOINTOPERATION']._serialized_end=2545 + _globals['_TRAINERCOMMAND']._serialized_start=2548 + _globals['_TRAINERCOMMAND']._serialized_end=3504 + _globals['_HYPERPARAMETERDESC']._serialized_start=3507 + _globals['_HYPERPARAMETERDESC']._serialized_end=3664 + _globals['_NEURONSTATISTICS']._serialized_start=3667 + _globals['_NEURONSTATISTICS']._serialized_end=4037 + _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_start=3896 + _globals['_NEURONSTATISTICS_INCOMINGLRENTRY']._serialized_end=3945 + _globals['_LAYERREPRESENTATION']._serialized_start=4040 + _globals['_LAYERREPRESENTATION']._serialized_end=4408 + _globals['_ACTIVATIONREQUEST']._serialized_start=4410 + _globals['_ACTIVATIONREQUEST']._serialized_end=4482 + _globals['_ACTIVATIONMAP']._serialized_start=4484 + _globals['_ACTIVATIONMAP']._serialized_end=4556 + _globals['_ACTIVATIONRESPONSE']._serialized_start=4558 + _globals['_ACTIVATIONRESPONSE']._serialized_end=4658 + _globals['_TASKFIELD']._serialized_start=4661 + _globals['_TASKFIELD']._serialized_end=4808 + _globals['_RECORDMETADATA']._serialized_start=4811 + _globals['_RECORDMETADATA']._serialized_end=5143 + _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_start=5090 + _globals['_RECORDMETADATA_SAMPLELASTLOSSENTRY']._serialized_end=5143 + _globals['_SAMPLESTATISTICS']._serialized_start=5146 + _globals['_SAMPLESTATISTICS']._serialized_end=5293 + _globals['_COMMANDRESPONSE']._serialized_start=5296 + _globals['_COMMANDRESPONSE']._serialized_end=5526 + _globals['_SAMPLEREQUEST']._serialized_start=5528 + _globals['_SAMPLEREQUEST']._serialized_end=5613 + _globals['_SAMPLEREQUESTRESPONSE']._serialized_start=5616 + _globals['_SAMPLEREQUESTRESPONSE']._serialized_end=5917 + _globals['_BATCHSAMPLEREQUEST']._serialized_start=5920 + _globals['_BATCHSAMPLEREQUEST']._serialized_end=6066 + _globals['_BATCHSAMPLERESPONSE']._serialized_start=6068 + 
_globals['_BATCHSAMPLERESPONSE']._serialized_end=6130 + _globals['_WEIGHTSREQUEST']._serialized_start=6132 + _globals['_WEIGHTSREQUEST']._serialized_end=6178 + _globals['_WEIGHTSRESPONSE']._serialized_start=6181 + _globals['_WEIGHTSRESPONSE']._serialized_end=6466 + _globals['_DATAQUERYREQUEST']._serialized_start=6468 + _globals['_DATAQUERYREQUEST']._serialized_end=6550 + _globals['_CATEGORICALTAGDEF']._serialized_start=6552 + _globals['_CATEGORICALTAGDEF']._serialized_end=6605 + _globals['_DATAQUERYRESPONSE']._serialized_start=6608 + _globals['_DATAQUERYRESPONSE']._serialized_end=6905 + _globals['_DATASAMPLESREQUEST']._serialized_start=6908 + _globals['_DATASAMPLESREQUEST']._serialized_end=7102 + _globals['_DATASTAT']._serialized_start=7104 + _globals['_DATASTAT']._serialized_end=7213 + _globals['_DATARECORD']._serialized_start=7215 + _globals['_DATARECORD']._serialized_end=7277 + _globals['_DATASAMPLESRESPONSE']._serialized_start=7279 + _globals['_DATASAMPLESRESPONSE']._serialized_end=7369 + _globals['_HISTOGRAMSUBBAR']._serialized_start=7371 + _globals['_HISTOGRAMSUBBAR']._serialized_end=7438 + _globals['_HISTOGRAMBIN']._serialized_start=7440 + _globals['_HISTOGRAMBIN']._serialized_end=7544 + _globals['_HISTOGRAMREQUEST']._serialized_start=7546 + _globals['_HISTOGRAMREQUEST']._serialized_end=7598 + _globals['_HISTOGRAMRESPONSE']._serialized_start=7600 + _globals['_HISTOGRAMRESPONSE']._serialized_end=7702 + _globals['_GETMETADATAREQUEST']._serialized_start=7704 + _globals['_GETMETADATAREQUEST']._serialized_end=7791 + _globals['_GETMETADATARESPONSE']._serialized_start=7794 + _globals['_GETMETADATARESPONSE']._serialized_end=7947 + _globals['_POINTCLOUDREQUEST']._serialized_start=7949 + _globals['_POINTCLOUDREQUEST']._serialized_end=8023 + _globals['_POINTCLOUDCHUNK']._serialized_start=8026 + _globals['_POINTCLOUDCHUNK']._serialized_end=8217 + _globals['_DATAEDITSREQUEST']._serialized_start=8220 + _globals['_DATAEDITSREQUEST']._serialized_end=8440 + 
_globals['_DATAEDITSRESPONSE']._serialized_start=8442 + _globals['_DATAEDITSRESPONSE']._serialized_end=8495 + _globals['_DATASPLITSRESPONSE']._serialized_start=8497 + _globals['_DATASPLITSRESPONSE']._serialized_end=8555 + _globals['_AGENTHEALTHRESPONSE']._serialized_start=8557 + _globals['_AGENTHEALTHRESPONSE']._serialized_end=8614 + _globals['_INITIALIZEAGENTREQUEST']._serialized_start=8616 + _globals['_INITIALIZEAGENTREQUEST']._serialized_end=8710 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_start=8712 + _globals['_INITIALIZEAGENTRESPONSE']._serialized_end=8771 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_start=8773 + _globals['_CHANGEAGENTMODELREQUEST']._serialized_end=8813 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_start=8815 + _globals['_CHANGEAGENTMODELRESPONSE']._serialized_end=8875 + _globals['_GETAGENTMODELSREQUEST']._serialized_start=8877 + _globals['_GETAGENTMODELSREQUEST']._serialized_end=8900 + _globals['_GETAGENTMODELSRESPONSE']._serialized_start=8902 + _globals['_GETAGENTMODELSRESPONSE']._serialized_end=8976 + _globals['_RESETAGENTRESPONSE']._serialized_start=8978 + _globals['_RESETAGENTRESPONSE']._serialized_end=9032 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_start=9034 + _globals['_RESTORECHECKPOINTREQUEST']._serialized_end=9085 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_start=9087 + _globals['_RESTORECHECKPOINTRESPONSE']._serialized_end=9148 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_start=9150 + _globals['_TRIGGEREVALUATIONREQUEST']._serialized_end=9232 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_start=9234 + _globals['_TRIGGEREVALUATIONRESPONSE']._serialized_end=9295 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_start=9297 + _globals['_GETEVALUATIONSTATUSREQUEST']._serialized_end=9325 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_start=9328 + _globals['_GETEVALUATIONSTATUSRESPONSE']._serialized_end=9457 + _globals['_CANCELEVALUATIONREQUEST']._serialized_start=9459 + 
_globals['_CANCELEVALUATIONREQUEST']._serialized_end=9500 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_start=9502 + _globals['_CANCELEVALUATIONRESPONSE']._serialized_end=9562 + _globals['_EXPERIMENTSERVICE']._serialized_start=9980 + _globals['_EXPERIMENTSERVICE']._serialized_end=11377 # @@protoc_insertion_point(module_scope) diff --git a/weightslab/proto/experiment_service_pb2_grpc.py b/weightslab/proto/experiment_service_pb2_grpc.py index 004c4d93..e75d391c 100644 --- a/weightslab/proto/experiment_service_pb2_grpc.py +++ b/weightslab/proto/experiment_service_pb2_grpc.py @@ -74,6 +74,16 @@ def __init__(self, channel): request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesRequest.SerializeToString, response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesResponse.FromString, _registered_method=True) + self.GetHistogram = channel.unary_unary( + '/ExperimentService/GetHistogram', + request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString, + response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString, + _registered_method=True) + self.GetMetaData = channel.unary_unary( + '/ExperimentService/GetMetaData', + request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, + response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.FromString, + _registered_method=True) self.GetPointCloud = channel.unary_stream( '/ExperimentService/GetPointCloud', request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.PointCloudRequest.SerializeToString, @@ -188,6 +198,23 @@ def GetDataSamples(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetHistogram(self, request, context): + """Server-side histogram binning of one metadata/signal column. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetMetaData(self, request, context): + """Metadata-only retrieval (dataframe columns). Returns every metadata column + name for the WHOLE dataset, the current grid slice's per-sample metadata, and + the open modal sample's metadata. Separated from GetDataSamples, which now + returns only image / label / prediction data. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def GetPointCloud(self, request, context): """Raw point cloud of one sample (task_type "detection_pointcloud"), server-streamed in binary chunks for the interactive 3D viewer. @@ -307,6 +334,16 @@ def add_ExperimentServiceServicer_to_server(servicer, server): request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesRequest.FromString, response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.DataSamplesResponse.SerializeToString, ), + 'GetHistogram': grpc.unary_unary_rpc_method_handler( + servicer.GetHistogram, + request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.FromString, + response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.SerializeToString, + ), + 'GetMetaData': grpc.unary_unary_rpc_method_handler( + servicer.GetMetaData, + request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.FromString, + response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.SerializeToString, + ), 'GetPointCloud': grpc.unary_stream_rpc_method_handler( servicer.GetPointCloud, request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.PointCloudRequest.FromString, @@ -594,6 +631,60 @@ def GetDataSamples(request, metadata, _registered_method=True) + @staticmethod 
+ def GetHistogram(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/ExperimentService/GetHistogram', + weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString, + weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetMetaData(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/ExperimentService/GetMetaData', + weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString, + weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def GetPointCloud(request, target, diff --git a/weightslab/security/cert_auth_manager.py b/weightslab/security/cert_auth_manager.py index f4b5331e..358dd86b 100644 --- a/weightslab/security/cert_auth_manager.py +++ b/weightslab/security/cert_auth_manager.py @@ -193,7 +193,7 @@ def get_or_create_auth_token(self) -> str: try: os.chmod(self.token_file, 0o600) except Exception: - pass # Windows doesn't support chmod + pass # Windows doesn't support chmod logger.info(f"Wrote gRPC auth token to {self.token_file}") except Exception as e: logger.error(f"Could not save token: {e}") diff --git a/weightslab/src.py b/weightslab/src.py index d3600fc6..35c1a420 100644 --- a/weightslab/src.py +++ b/weightslab/src.py @@ -43,11 +43,11 
@@ def _rebind_caller_local(original_obj: Any, new_obj: Any) -> None: This lets ``wl.watch_or_edit(parameters, ...)`` (without capturing the return value) transparently replace ``parameters`` with the returned Proxy in the - calling scope. Silently does nothing on non-CPython runtimes. + calling scope. Silently does nothing on non-CPython runtimes. """ try: # frame 0 = _rebind_caller_local - # frame 1 = watch_or_edit (or whatever internal caller) + # frame 1 = watch_or_edit (or whatever internal caller) # frame 2 = user code frame = sys._getframe(2) changed = False @@ -217,26 +217,26 @@ def _get_step(step: int | None = None) -> int: if m is not None: # Safe attribute access (handle Proxy returning None for missing attr) if hasattr(m, 'get_age'): - val = m.get_age() -1 # At this point, model already saw one batch, except if we started by evaluation + val = m.get_age() -1 # At this point, model already saw one batch, except if we started by evaluation if val is not None: - step = max([int(val), 0]) # Use age-1 as step to reflect completed step; ensure non-negative + step = max([int(val), 0]) # Use age-1 as step to reflect completed step; ensure non-negative elif hasattr(m, 'current_step'): val = m.current_step if val is not None: - step = max([int(val), 0]) # Use current_step-1 as step to reflect completed step; ensure non-negative + step = max([int(val), 0]) # Use current_step-1 as step to reflect completed step; ensure non-negative elif step is not None: # step = step # fallback to provided step - m.current_step = step # add current_step attribute to model for future tracking + m.current_step = step # add current_step attribute to model for future tracking - m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType + m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType elif step is not None: # If model doesn't have 
current_step, force it to 0 or try to infer from checkpoint manager - m.current_step = step # add current_step attribute to model for future tracking - m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType + m.current_step = step # add current_step attribute to model for future tracking + m.get_age = types.MethodType(_get_age, m) # To make a proper bound method so `self` is passed correctly, we use types.MethodType return step @@ -334,7 +334,7 @@ def _log_signal(scalar: float, signal_per_sample: dict, reg_name: str, step: int {reg_name: scalar}, global_step=step, signal_per_sample=signal_per_sample, - aggregate_by_step=kwargs.get('per_sample', True) # Aggregate per-sample signals by step for logging if per_sample is True, + aggregate_by_step=kwargs.get('per_sample', True) # Aggregate per-sample signals by step for logging if per_sample is True, ) except Exception: traceback.print_exc() @@ -516,9 +516,9 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): if instance_batch_idx is None and 'batch_idx' in kw: instance_batch_idx = kw['batch_idx'] elif instance_batch_idx is None and targets is not None and isinstance(targets, list): - instance_batch_idx = [i for i, tars in enumerate(targets) for _ in tars] # Auto determine batch_idx from targets if not explicitly provided (assumes targets is list of lists of annotations) + instance_batch_idx = [i for i, tars in enumerate(targets) for _ in tars] # Auto determine batch_idx from targets if not explicitly provided (assumes targets is list of lists of annotations) else: - instance_batch_idx = ledgers.get_dataframe()._df.loc[batch_ids].index.get_level_values(1).tolist() # Query directly instance_ids related and ordered to the samples_ids in the batch + instance_batch_idx = ledgers.get_dataframe()._df.loc[batch_ids].index.get_level_values(1).tolist() # Query directly instance_ids related and ordered to the samples_ids in the batch 
batch_ids = ledgers.get_dataframe()._df.loc[batch_ids].index.get_level_values(0).tolist() # If output is a dict (from PerInstanceDetectionLoss), pick 'sample' @@ -533,7 +533,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): if kwargs.get('per_sample', False) and not isinstance(out, dict): if hasattr(out, 'ndim') and out.ndim > 1: - out = out.mean(dim=tuple(range(1, out.ndim))) # Reduce to [B,]0 + out = out.mean(dim=tuple(range(1, out.ndim))) # Reduce to [B,]0 # Extract scalar from tensor scalar, batch_scalar = _extract_scalar_from_tensor(batch_scalar, out, batch_ids) @@ -550,7 +550,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): batch_idx=instance_batch_idx, targets=targets, step=step, - log=False, # already logged sample-level above + log=False, # already logged sample-level above ) except Exception as e: traceback.print_exc() if os.environ.get('WEIGHTSLAB_LOG_LEVEL') == 'DEBUG' else None @@ -640,7 +640,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): origin=kwargs.get('origin', 'train') ) try: - res = func(ctx) # Compute per sample result with unified context + res = func(ctx) # Compute per sample result with unified context except TypeError: # Fallback for legacy subscriber functions res = func(sample_id=int(uid), value=val, dataframe=df_proxy) @@ -650,10 +650,10 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): dynamic_updates[name] = signal_value if dynamic_updates and meta.get('log', True): logger.debug(f"Dynamic updates computed for signal '{reg_name}': {list(dynamic_updates.keys())}") - _log_signal(sum(signal_value)/len(signal_value), signal_value, name, step=step, **kwargs) # Log custom subscribed signals + _log_signal(sum(signal_value)/len(signal_value), signal_value, name, step=step, **kwargs) # Log custom subscribed signals except Exception as e: logger.debug(f"Dynamic signal {name} failed: {e}") - pass # User function error, skip + pass # User function error, skip # Save 
statistics if requested and applicable. # Skip the per-sample save path when per_instance=True — instance values @@ -676,7 +676,7 @@ def wrappered_fwd(original_forward, kwargs, reg_name, *a, **kw): preds_raw=preds_raw, preds=preds, targets=targets, - log=False # Already logged above, no need to log again in save_signals; set to False to avoid duplicate logging if save_signals is called separately without logging + log=False # Already logged above, no need to log again in save_signals; set to False to avoid duplicate logging if save_signals is called separately without logging ) # Return the original output (dict for per-instance losses so caller can @@ -748,7 +748,7 @@ def watch_or_edit(obj: Callable, obj_name: str = None, flag: str = None, **kwarg forced_model_wrapping = kwargs.pop('forced_model_wrapping', False) # Now construct the wrapper and let it register into the ledger. - wrapper = ModelInterface(obj, **kwargs) if forced_model_wrapping or _model == None else _model + wrapper = ModelInterface(obj, **kwargs) if forced_model_wrapping or _model == None else _model # No rebind here since the model wrapper is designed to be a drop-in replacement for the original model @@ -791,11 +791,11 @@ def watch_or_edit(obj: Callable, obj_name: str = None, flag: str = None, **kwarg if 'loader_name' not in kwargs and 'name' in kwargs: kwargs['loader_name'] = kwargs['name'] except Exception: - pass # If we can't get hyperparameters, continue without root_log_dir + pass # If we can't get hyperparameters, continue without root_log_dir # Now construct the wrapper and let it register into the ledger. 
wrapper = DataLoaderInterface(obj, **kwargs) - _dataloader.__pl_saved_kwargs = kwargs # Force pytorch lightning compatibility + _dataloader.__pl_saved_kwargs = kwargs # Force pytorch lightning compatibility # There is not rebind here because obj can be a dataloader or a dataset @@ -981,7 +981,7 @@ def new_forward(*a, **kw): logger.info(f"Loaded hyperparameters from checkpoint {latest_hash[:16]}") checkpoint_hp_loaded = True except Exception: - pass # If checkpoint loading fails, proceed with normal registration + pass # If checkpoint loading fails, proceed with normal registration defaults = kwargs.get('defaults', None) if not checkpoint_hp_loaded: @@ -1071,7 +1071,7 @@ def start_training(timeout: int = None) -> None: if timeout is not None and isinstance(timeout, int) and timeout > 0: logger.info(f"Starting WeightsLab training mode with a timeout of {timeout} seconds.") time.sleep(timeout) - pause_ctrl.resume() # Ensure we're not paused if start_training is called after serve + pause_ctrl.resume() # Ensure we're not paused if start_training is called after serve def serve(serving_cli: bool = False, serving_grpc: bool = False, **kwargs) -> None: """Start WeightsLab services. @@ -1113,6 +1113,116 @@ def keep_serving(timeout: int = None, release_gpu: bool = True) -> None: logger.info("Shutting down WeightsLab services.") +def _rehydrate_dataframe_from_disk(root_log_dir) -> list: + """Best-effort: rebuild the data grid from the persisted H5 store so samples + are browsable in explore mode without the original ``Dataset`` object. + + Returns the list of data origins (splits) that were rehydrated. Any failure + is non-fatal — logs/plots still work without the data grid. 
+ """ + from pathlib import Path + import pandas as _pd + from weightslab.backend import ledgers as _ledgers + from weightslab.data.h5_dataframe_store import H5DataFrameStore + from weightslab.data.dataframe_manager import LedgeredDataFrameManager + + data_h5 = Path(str(root_log_dir)) / "checkpoints" / "data" / "data.h5" + if not data_h5.exists(): + logger.info("Explore: no persisted data store at %s; data grid unavailable.", data_h5) + return [] + try: + store = H5DataFrameStore(str(data_h5)) + # Enumerate origins (HDF groups under the store's key prefix). + prefix = f"/{getattr(store, '_key_prefix', 'stats')}_" + origins: list = [] + try: + with _pd.HDFStore(str(data_h5), mode="r") as h5: + origins = sorted({k[len(prefix):] for k in h5.keys() if k.startswith(prefix)}) + except Exception: + logger.debug("Explore: could not enumerate data origins", exc_info=True) + if not origins: + return [] + + # No flush threads in a read-only explorer; data writes (tags/discard) are + # still applied in-memory and persisted on demand. + dfm = LedgeredDataFrameManager(enable_flushing_threads=False, enable_h5_persistence=True) + dfm.set_store(store) + loaded = [] + for origin in origins: + try: + dfm.register_split(origin, _pd.DataFrame(), store, autoload_arrays=False) + loaded.append(origin) + except Exception: + logger.warning("Explore: failed to rehydrate data split '%s'", origin, exc_info=True) + if loaded: + _ledgers.register_dataframe(dfm) + return loaded + except Exception: + logger.warning("Explore: data rehydration from disk failed; logs/plots still available.", exc_info=True) + return [] + + +def load_experiment_for_explore(root_log_dir, exp_hash: str = None) -> dict: + """Load a finished experiment from ``root_log_dir`` into a fresh, read-only ledger. 
+ + Reconstructs hyperparameters, logger history, the checkpoint manager (and, + best-effort, the model and the data grid) purely from disk, then flips the + process into read-only **explore mode** (see + :mod:`weightslab.backend.explore_mode`). No training script, dataset, GPU, or + network is required — intended for browsing a run that is finished or still + training elsewhere (e.g. on a cluster). + + After this returns, start the gRPC server with :func:`serve` (``serving_grpc=True``) + and the UI can read everything while training/HP/weight mutations are refused. + + Args: + root_log_dir: An experiment ``root_log_dir`` produced by a previous run. + exp_hash: Optional specific experiment hash to open (defaults to the latest). + + Returns: + A dict summary: ``{root_log_dir, experiment_hash, has_logger, origins}``. + """ + from pathlib import Path + from weightslab.backend import ledgers as _ledgers + from weightslab.backend.explore_mode import set_explore_mode + from weightslab.components.checkpoint_manager import CheckpointManager + + root = Path(str(root_log_dir)).absolute() + if not root.exists(): + raise FileNotFoundError(f"root_log_dir does not exist: {root}") + + # A read-only explorer must not inherit any live training objects. + _ledgers.clear_all() + + # CheckpointManager.__init__ loads the manifest + logger snapshots (registering + # a logger with the saved history) and bootstraps the latest experiment state + # (config/HP, model best-effort, data snapshot) from disk. 
+ manager = CheckpointManager(str(root)) + _ledgers.register_checkpoint_manager(manager) + + if exp_hash: + try: + manager.load_state(exp_hash) + except Exception: + logger.warning( + "Explore: could not load requested hash %s; using bootstrapped state.", + exp_hash, exc_info=True, + ) + + origins = _rehydrate_dataframe_from_disk(root) + + set_explore_mode(True) + + summary = { + "root_log_dir": str(root), + "experiment_hash": manager.get_current_experiment_hash(), + "has_logger": _ledgers.get_logger() is not None, + "origins": origins, + } + logger.info("Explore mode ready: %s", summary) + return summary + + def signal(name: str, subscribe_to: str = None, compute_every_n_steps: int = 1, **kwargs): """ Decorator that registers a custom signal function. @@ -1632,14 +1742,14 @@ def save_signals( Examples: Classification — one loss scalar per image:: - for inputs, targets, ids in train_loader: # ids: sample IDs, len B - logits = model(inputs) # (B, num_classes) - loss = loss_fn(logits, targets) # (B,) per-sample loss + for inputs, targets, ids in train_loader: # ids: sample IDs, len B + logits = model(inputs) # (B, num_classes) + loss = loss_fn(logits, targets) # (B,) per-sample loss wl.save_signals( - signals={"train_loss": loss}, # (B,) -> signals//train_loss + signals={"train_loss": loss}, # (B,) -> signals//train_loss batch_ids=ids, - preds_raw=logits, # (B, num_classes) - targets=targets, # (B,) + preds_raw=logits, # (B, num_classes) + targets=targets, # (B,) step=current_step, log=True, ) @@ -1647,7 +1757,7 @@ def save_signals( Several named per-sample metrics at once:: wl.save_signals( - signals={"iou": iou_per_image, "dice": dice_per_image}, # each (B,) + signals={"iou": iou_per_image, "dice": dice_per_image}, # each (B,) batch_ids=ids, ) """ @@ -1704,9 +1814,9 @@ def expand_dim(x): return x[:, np.newaxis] return x - preds_np = normalize(preds) + preds_np = normalize(preds) preds_raw_np = normalize(preds_raw) - target_np = normalize(targets) + target_np = 
normalize(targets) # Processing signals if isinstance(signals, dict): @@ -1726,8 +1836,8 @@ def expand_dim(x): losses_data = None # Expand dims for 1D arrays (skipped for lists) - target_np = expand_dim(target_np) - preds_np = expand_dim(preds_np) + target_np = expand_dim(target_np) + preds_np = expand_dim(preds_np) preds_raw_np = expand_dim(preds_raw_np) # Enqueue to dataframe manager buffer for efficiency @@ -1780,29 +1890,29 @@ def save_instance_signals( Worked example — ``batch_ids = ["img7", "img3"]`` (B = 2), 5 boxes total:: - # box: 0 1 2 3 4 - batch_idx = [ 0, 0, 1, 1, 1 ] # boxes 0-1 -> img7, 2-4 -> img3 - ious = [0.91, 0.62, 0.50, 0.74, 0.30] # one IoU per box + # box: 0 1 2 3 4 + batch_idx = [ 0, 0, 1, 1, 1 ] # boxes 0-1 -> img7, 2-4 -> img3 + ious = [0.91, 0.62, 0.50, 0.74, 0.30] # one IoU per box wl.save_instance_signals( - signals={"iou_instance": ious}, # -> signals//iou_instance + signals={"iou_instance": ious}, # -> signals//iou_instance batch_ids=["img7", "img3"], batch_idx=batch_idx, origin="train", ) # writes: - # ("img7", 1)=0.91 ("img7", 2)=0.62 - # ("img3", 1)=0.50 ("img3", 2)=0.74 ("img3", 3)=0.30 + # ("img7", 1)=0.91 ("img7", 2)=0.62 + # ("img3", 1)=0.50 ("img3", 2)=0.74 ("img3", 3)=0.30 Typical detection loop using the Ultralytics batch dict directly:: image, batch_ids, batch = inputs[0], inputs[1], inputs[3]["batch"] raw_preds = model(image) - iou_per_box = compute_iou(raw_preds, batch) # flat [total_instances] + iou_per_box = compute_iou(raw_preds, batch) # flat [total_instances] wl.save_instance_signals( signals={"iou_instance": iou_per_box}, batch_ids=batch_ids, - batch_idx=batch["batch_idx"], # Ultralytics flat index + batch_idx=batch["batch_idx"], # Ultralytics flat index step=current_step, ) @@ -1811,9 +1921,9 @@ def save_instance_signals( in the same per-sample order ``batch_idx`` implies. 
It is flattened sample-major to align with the instances:: - targets = [ # batch_ids = ["img7", "img3"] - [box7_0, box7_1], # img7's two boxes -> annotation_id 1, 2 - [box3_0, box3_1, box3_2], # img3's three boxes -> annotation_id 1, 2, 3 + targets = [ # batch_ids = ["img7", "img3"] + [box7_0, box7_1], # img7's two boxes -> annotation_id 1, 2 + [box3_0, box3_1, box3_2], # img3's three boxes -> annotation_id 1, 2, 3 ] wl.save_instance_signals(signals={"iou_instance": ious}, batch_ids=["img7", "img3"], @@ -1986,7 +2096,7 @@ def get_active_group_mask( Example:: # Cosine embedding loss — one value per pair in the batch - loss_embed = loss_cosine(e1, e2, y) # shape: (B/2,) + loss_embed = loss_cosine(e1, e2, y) # shape: (B/2,) group_mask = wl.get_active_group_mask(group_ids, origin="train_loader") # Zero out tainted pairs so they don't update weights n_active = group_mask.sum().clamp(min=1) @@ -2009,7 +2119,7 @@ def get_active_group_mask( if gid in tainted: mask[i] = 0.0 except Exception: - pass # Fail-safe: if check fails, treat all groups as active + pass # Fail-safe: if check fails, treat all groups as active return mask @@ -2127,14 +2237,14 @@ def save_group_signals( try: tainted_group_ids = DATAFRAME_M.get_tainted_group_ids(group_ids, origin) except Exception: - pass # Never block training on best-effort discard check + pass # Never block training on best-effort discard check # Broadcast to all members in ledger (skip tainted groups) all_updates = [] active_group_ids = [] for i, gid in enumerate(group_ids): if gid in tainted_group_ids: - continue # Skip: at least one member was discarded; group loss is undefined + continue # Skip: at least one member was discarded; group loss is undefined # We also record the last seen step for all members updates = scalar_signals.copy() @@ -2148,7 +2258,7 @@ def save_group_signals( active_group_ids.append(gid) if not active_group_ids: - return # All groups were tainted; nothing to write + return # All groups were tainted; nothing 
to write # Bulk update for performance (avoids repeated dataframe scans) DATAFRAME_M.update_by_groups_bulk(origin=origin, group_ids=active_group_ids, updates_list=all_updates) @@ -2221,7 +2331,7 @@ def _unpack_batch(batch, device=None): def _make_default_eval_fn(model): """Return a default evaluation callable that uses all registered ledger signals. - This is used when no ``@wl.eval_fn`` decorator was applied. For every + This is used when no ``@wl.eval_fn`` decorator was applied. For every batch it: 1. Unpacks ``(inputs, targets, ids)`` using a heuristic (tuple/list/dict). @@ -2232,7 +2342,7 @@ def _make_default_eval_fn(model): evaluation-mode buffer. Loss-style signals (wrapped ``forward``) and metric-style signals - (wrapped ``compute``) are both handled. Per-signal errors are silently + (wrapped ``compute``) are both handled. Per-signal errors are silently skipped so a missing target or shape mismatch does not abort the whole evaluation. """ @@ -2283,7 +2393,7 @@ def _default_eval(loader): except Exception: pass - preds = model(inputs) # infer predictions + preds = model(inputs) # infer predictions # Call each registered signal so its wrapped forward/compute # fires and feeds into the evaluation-mode logger buffer. @@ -2347,8 +2457,8 @@ def eval_fn(func): The decorated function receives a single *loader* argument — a ``_EvalManagedLoader`` wrapping the requested split's - ``DataLoaderInterface``. It should iterate that loader and compute - the watched criteria / metrics exactly as in a normal test pass. All + ``DataLoaderInterface``. It should iterate that loader and compute + the watched criteria / metrics exactly as in a normal test pass. All ``add_scalars`` calls are intercepted by the logger's evaluation-mode buffer. @@ -2377,8 +2487,8 @@ def pointcloud_thumbnail(func): own function, e.g. 
a range/spherical projection: @wl.pointcloud_thumbnail - def to_range_image(points): # points: [M, 2..F] float - return my_range_projection(points) # -> (H, W, 3) uint8 or PIL.Image + def to_range_image(points): # points: [M, 2..F] float + return my_range_projection(points) # -> (H, W, 3) uint8 or PIL.Image (Note: ``@wl.3d_pc_thumb`` isn't valid Python — identifiers can't start with a digit — so the verb is spelled out.) A ``render_thumbnail_2d`` @@ -2398,7 +2508,7 @@ def pointcloud_boxes(func): @wl.pointcloud_boxes def boxes_to_range(boxes): - return my_boxes_in_range_frame(boxes) # -> [N, 6] normalized + return my_boxes_in_range_frame(boxes) # -> [N, 6] normalized A ``project_boxes_2d`` method on the dataset takes precedence. """ @@ -2477,19 +2587,19 @@ def run_pending_evaluation( Can still be called from the training loop with explicit arguments for backwards-compatibility:: - if wl.run_pending_evaluation(): # ledger mode — no args needed + if wl.run_pending_evaluation(): # ledger mode — no args needed continue Args: - loaders: Optional mapping of *loader_name* → ``DataLoaderInterface``. + loaders: Optional mapping of *loader_name* → ``DataLoaderInterface``. When ``None``, the loader is looked up by split name from the ledger. - model: Optional tracked model instance (used to read ``get_age()``). + model: Optional tracked model instance (used to read ``get_age()``). When ``None``, resolved from the ledger. - eval_fn: Optional callable with signature ``eval_fn(loader) -> None``. + eval_fn: Optional callable with signature ``eval_fn(loader) -> None``. When ``None``, the function registered via ``@wl.eval_fn`` is used. - device: Unused; kept for API symmetry. + device: Unused; kept for API symmetry. 
Returns: ``True`` if an evaluation was executed (caller should ``continue`` @@ -2730,7 +2840,7 @@ def run_pending_evaluation( model_age = 0 try: - model_age = _model.get_age() - 1 if _model is not None and hasattr(_model, "get_age") else 0 # Model anticipates a step after eval, so subtract 1 to report the age corresponding to the just-evaluated checkpoint. + model_age = _model.get_age() - 1 if _model is not None and hasattr(_model, "get_age") else 0 # Model anticipates a step after eval, so subtract 1 to report the age corresponding to the just-evaluated checkpoint. except Exception: pass @@ -2799,42 +2909,42 @@ def run_pending_evaluation( logger.info(f"\n{'='*70}") logger.info(f"[WeightsLab] Evaluation Results") logger.info(f"{'='*70}") - logger.info(f" Split: {split_name}") - logger.info(f" Model Step: {model_age}") - logger.info(f" Tags: {tags}") - logger.info(f" Total Samples: {filtered_count if filtered_count is not None else 'unknown'}") - logger.info(f" Total Batches: {total_batches}") - logger.info(f" Eval Hash: {eval_hash}") + logger.info(f" Split: {split_name}") + logger.info(f" Model Step: {model_age}") + logger.info(f" Tags: {tags}") + logger.info(f" Total Samples: {filtered_count if filtered_count is not None else 'unknown'}") + logger.info(f" Total Batches: {total_batches}") + logger.info(f" Eval Hash: {eval_hash}") if result: - logger.info(f" Metrics:\n") + logger.info(f" Metrics:\n") for k, v in result.items(): if isinstance(v, float): - logger.info(f" {k:30s} = {v:.6f}") + logger.info(f" {k:30s} = {v:.6f}") else: - logger.info(f" {k:30s} = {v}") + logger.info(f" {k:30s} = {v}") else: - logger.info(f" Status: No metrics recorded") + logger.info(f" Status: No metrics recorded") error_msg = ( f"Evaluation did not produce any metrics.\n" - f" Possible causes:\n" - f" • Evaluation function is not compatible with the experiment setup\n" - f" • No signals were computed during evaluation\n" - f" • Model or data loader not registered in the ledger\n\n" - f" 
Solution: Create a custom evaluation function decorated with @wl.eval_fn.\n" - f" This function should:\n" - f" 1. Accept only one parameter: loader\n" - f" 2. Be fully based on the WeightsLab ledger\n" - f" 3. Retrieve model, device, and metrics from wl.ledger.*\n" - f" 4. Register loss/metric functions with wl.watch_or_edit(..., flag='loss/metric')\n\n" - f" Example from detection use case:\n" - f" @wl.eval_fn\n" - f" def validate(loader):\n" - f" model = wl.ledger.get_model()\n" - f" device = wl.ledger.get_device()\n" - f" for batch in loader:\n" - f" ...\n\n" - f" See documentation: https://grayboxtech.github.io/weightslab/latest/index.html" + f" Possible causes:\n" + f" • Evaluation function is not compatible with the experiment setup\n" + f" • No signals were computed during evaluation\n" + f" • Model or data loader not registered in the ledger\n\n" + f" Solution: Create a custom evaluation function decorated with @wl.eval_fn.\n" + f" This function should:\n" + f" 1. Accept only one parameter: loader\n" + f" 2. Be fully based on the WeightsLab ledger\n" + f" 3. Retrieve model, device, and metrics from wl.ledger.*\n" + f" 4. 
Register loss/metric functions with wl.watch_or_edit(..., flag='loss/metric')\n\n" + f" Example from detection use case:\n" + f" @wl.eval_fn\n" + f" def validate(loader):\n" + f" model = wl.ledger.get_model()\n" + f" device = wl.ledger.get_device()\n" + f" for batch in loader:\n" + f" ...\n\n" + f" See documentation: https://grayboxtech.github.io/weightslab/latest/index.html" ) logger.warning(error_msg) @@ -2875,7 +2985,7 @@ def _build_eval_allow_list(loader_if, tags: list, split_name: str) -> set: for tag in tags: col = f"{SampleStatsEx.TAG.value}:{tag}" if col in df.columns: - col_mask = df[col] == True # noqa: E712 + col_mask = df[col] == True # noqa: E712 mask = col_mask if mask is None else (mask & col_mask) if mask is None: @@ -3062,7 +3172,7 @@ def __iter__(self): def get_current_experiment_hash() -> str | None: """Return the hash of the currently active experiment run. - Reads the hash from the registered checkpoint manager. Returns ``None`` + Reads the hash from the registered checkpoint manager. Returns ``None`` when no experiment is active or no checkpoint manager has been registered. Example:: @@ -3123,7 +3233,7 @@ def query_sample_history( names = ( [signal_name] if signal_name - else list(_lg._signal_history_per_sample.keys()) + else _lg.list_sample_signal_names() ) results = [] for name in names: @@ -3156,7 +3266,7 @@ def query_instance_history( names = ( [signal_name] if signal_name - else list(_lg._signal_history_per_instance.keys()) + else _lg.list_instance_signal_names() ) results = [] for name in names: @@ -3181,7 +3291,7 @@ def write_history( Parameters ---------- path : str, optional - Output file path **or** directory. When omitted (``None``), the + Output file path **or** directory. When omitted (``None``), the ``root_log_dir`` from the active checkpoint manager is used as the output directory. 
@@ -3191,14 +3301,14 @@ def write_history( is auto-generated as ``_history.`` inside that directory, where ```` is an 8-character hex MD5 of the normalized call parameters (*type_of_history*, *graph_name*, - *experiment_hash*, *sample_id*, *instance_id*). The same filter + *experiment_hash*, *sample_id*, *instance_id*). The same filter combination always produces the same filename; different filters produce different filenames. - The directory is created automatically if it does not exist. format : {"json", "csv"} Output format (default ``"json"``). type_of_history : {None, "all", "global", "sample", "instance", "instances"} - Which history to include. ``None`` or ``"all"`` writes every type. + Which history to include. ``None`` or ``"all"`` writes every type. ``"global"`` writes the aggregated training-curve history. ``"sample"`` writes per-sample history. ``"instance"`` / ``"instances"`` writes per-instance history. @@ -3206,7 +3316,7 @@ def write_history( Restrict to one or more signal / metric names. experiment_hash : str, optional ``None`` (default) — use the current experiment hash from the - checkpoint manager. ``"all"`` — include every hash. + checkpoint manager. ``"all"`` — include every hash. Any other string — restrict to that specific experiment run. sample_id : str or list of str, optional Restrict per-sample and per-instance rows to one or more sample IDs. 
@@ -3279,9 +3389,9 @@ def write_history( # --- Normalize all parameters first (needed for the auto-filename hash) --- # Resolve experiment_hash: - # None → use the current hash from the checkpoint manager (default) - # "all" → no filter, include every hash - # any str → filter to that specific hash + # None → use the current hash from the checkpoint manager (default) + # "all" → no filter, include every hash + # any str → filter to that specific hash if experiment_hash is None or experiment_hash == 'last': try: _current = ( @@ -3293,7 +3403,7 @@ def write_history( except Exception: experiment_hash = None elif experiment_hash == "all": - experiment_hash = None # sentinel: skip hash filtering below + experiment_hash = None # sentinel: skip hash filtering below # Normalize graph_name → set or None _gn_filter = None @@ -3352,7 +3462,7 @@ def write_history( instance_rows: list = [] if write_global: - for gn, hashes in _lg._signal_history.items(): + for gn, hashes in _lg.get_signal_history().items(): if _gn_filter is not None and gn not in _gn_filter: continue for h, steps in hashes.items(): @@ -3378,7 +3488,7 @@ def write_history( graphs_s = ( list(_gn_filter) if _gn_filter is not None - else list(_lg._signal_history_per_sample.keys()) + else _lg.list_sample_signal_names() ) for gn in graphs_s: for sid, step, val, h in _lg.query_per_sample( @@ -3400,7 +3510,7 @@ def write_history( graphs_i = ( list(_gn_filter) if _gn_filter is not None - else list(_lg._signal_history_per_instance.keys()) + else _lg.list_instance_signal_names() ) # query_per_instance filters by a single (sample_id, annotation_id); iterate when multiple given _sid_iter = _sid_filter if _sid_filter is not None else [None] @@ -3488,19 +3598,19 @@ def write_dataframe( Parameters ---------- path : str, optional - Output file path **or** directory. When omitted (``None``), the + Output file path **or** directory. When omitted (``None``), the ``root_log_dir`` from the active checkpoint manager is used. 
- If *path* has a file extension the file is written directly. - If *path* has no extension or is an existing directory, a filename is auto-generated as ``_dataframe.`` inside that directory. ```` is an 8-character MD5 hex digest of the normalized call - parameters (*columns*, *sample_id*, *instance_id*). Same filters → + parameters (*columns*, *sample_id*, *instance_id*). Same filters → same filename (idempotent overwrite); different filters → different file. - The directory is created automatically if it does not exist. format : {"json", "csv"} - Output format. Default ``"json"``. + Output format. Default ``"json"``. columns : str or list of str, optional Which columns to include (index levels ``sample_id`` / ``annotation_id`` are always written). @@ -3513,10 +3623,10 @@ def write_dataframe( - ``"discarded"`` — only the boolean ``discarded`` column. - A list of any mix of the above group names and/or exact column names. sample_id : str or list of str, optional - Restrict to one or more sample IDs (index level 0). ``None`` keeps all. + Restrict to one or more sample IDs (index level 0). ``None`` keeps all. instance_id : int or list of int, optional Restrict to one or more annotation IDs (index level 1, 0 = sample row, - ≥ 1 = per-instance rows). ``None`` keeps all. + ≥ 1 = per-instance rows). ``None`` keeps all. Returns ------- @@ -3526,7 +3636,7 @@ def write_dataframe( Notes ----- The function calls ``flush()`` on the dataframe manager before reading so - that any in-flight writes are included in the output. Pass + that any in-flight writes are included in the output. Pass ``instance_id=0`` to keep only sample-level rows; pass ``instance_id=[1,2]`` to keep specific annotation rows. 
@@ -3655,7 +3765,7 @@ def write_dataframe( mask = df_out.index.get_level_values(1).astype(int).isin(_iid_set) df_out = df_out.loc[mask] except Exception: - pass # non-integer annotation_ids — skip this filter + pass # non-integer annotation_ids — skip this filter logger.debug("write_dataframe: after instance_id filter → %d row(s).", len(df_out)) # Filter columns by group or exact name @@ -3679,7 +3789,7 @@ def write_dataframe( else: if _item in df_out.columns: _selected.append(_item) - _selected = list(dict.fromkeys(_selected)) # deduplicate, preserve order + _selected = list(dict.fromkeys(_selected)) # deduplicate, preserve order df_out = df_out[_selected] if _selected else df_out[[]] logger.debug("write_dataframe: column filter → %d column(s): %s", len(_selected), _selected) diff --git a/weightslab/tests/backend/test_compare_dataloaders.py b/weightslab/tests/backend/test_compare_dataloaders.py index 0b4d137b..b17469c3 100644 --- a/weightslab/tests/backend/test_compare_dataloaders.py +++ b/weightslab/tests/backend/test_compare_dataloaders.py @@ -8,7 +8,7 @@ import unittest # On Windows, DataLoader workers use spawn: each worker re-imports the heavy -# weightslab package (torch + cv2 + onnx + langchain + cert/banner setup), so a +# weightslab package (torch + onnx + langchain + cert/banner setup), so a # multi-worker loader takes far longer than any sane test timeout. These tests # are meaningful on Linux/CI (cheap fork workers); skip the num_workers>0 cases # on Windows. Single-worker correctness still runs everywhere. @@ -62,7 +62,7 @@ def setUp(self): # worker parallelism measurable for the multi-worker throughput test. 
self.dataset_size = 256 self.batch_size = 32 - self.delay_per_sample = 0.01 # 10ms per sample to justify worker overhead + self.delay_per_sample = 0.01 # 10ms per sample to justify worker overhead pause_controller.resume() def _create_torch_dataloader(self, num_workers=0): @@ -123,7 +123,7 @@ def test_single_worker_correctness(self): self.assertTrue(torch.equal(torch_target, wl_target), f"Batch {i} target mismatch") - print(f"✓ Single worker: {len(torch_batches)} batches match perfectly") + print(f" Single worker: {len(torch_batches)} batches match perfectly") @_SKIP_MULTIWORKER_ON_WIN def test_multi_worker_correctness(self): @@ -141,7 +141,7 @@ def test_multi_worker_correctness(self): for batch in torch_loader: torch_batches.append(batch) - wl_loader.reset_iterator() # Reset for fresh iteration + wl_loader.reset_iterator() # Reset for fresh iteration for batch in wl_loader: wl_batches.append(batch) @@ -158,67 +158,67 @@ def test_multi_worker_correctness(self): self.assertTrue(torch.allclose(torch_sorted, wl_sorted), "All data samples must be present") - print(f"✓ Multi-worker: {len(torch_batches)} batches, all samples present") + print(f" Multi-worker: {len(torch_batches)} batches, all samples present") # def test_throughput_comparison(self): - # """Compare throughput: single worker vs multi-worker.""" - # print("\n" + "="*70) - # print("TEST: Throughput Comparison") - # print("="*70) - - # results = {} - - # # Torch DataLoader: Single Worker - # torch_loader = self._create_torch_dataloader(num_workers=0) - # start = time.time() - # for _ in torch_loader: - # pass - # torch_single_time = time.time() - start - # results['PyTorch (1 worker)'] = torch_single_time - - # # Torch DataLoader: Multiple Workers - # torch_loader = self._create_torch_dataloader(num_workers=4) - # start = time.time() - # for _ in torch_loader: - # pass - # torch_multi_time = time.time() - start - # results['PyTorch (4 workers)'] = torch_multi_time - - # # WeightsLab DataLoaderInterface: 
Single Worker - # wl_loader = self._create_weightslab_dataloader(num_workers=0) - # start = time.time() - # for _ in wl_loader: - # pass - # wl_single_time = time.time() - start - # results['WeightsLab (1 worker)'] = wl_single_time - - # # WeightsLab DataLoaderInterface: Multiple Workers - # wl_loader = self._create_weightslab_dataloader(num_workers=4) - # wl_loader.reset_iterator() # Ensure fresh start - # start = time.time() - # for _ in wl_loader: - # pass - # wl_multi_time = time.time() - start - # results['WeightsLab (4 workers)'] = wl_multi_time - - # # Print comparison - # print("\nThroughput Results (loading {} batches):".format(self.dataset_size // self.batch_size)) - # print("-" * 70) - # for name, elapsed in results.items(): - # throughput = (self.dataset_size / self.batch_size) / elapsed if elapsed > 0 else 0 - # print(f"{name:35} {elapsed:8.3f}s ({throughput:6.2f} batches/sec)") - - # print("-" * 70) - # speedup_wl = results['WeightsLab (1 worker)'] / results['WeightsLab (4 workers)'] - # speedup_torch = results['PyTorch (1 worker)'] / results['PyTorch (4 workers)'] - - # print(f"Multi-worker speedup:") - # print(f" PyTorch: {speedup_torch:.2f}x faster") - # print(f" WeightsLab: {speedup_wl:.2f}x faster") - - # # Verify multi-worker is faster than single-worker - # self.assertGreater(wl_single_time, wl_multi_time * 0.8, - # "Multi-worker should be faster or comparable to single-worker") + # """Compare throughput: single worker vs multi-worker.""" + # print("\n" + "="*70) + # print("TEST: Throughput Comparison") + # print("="*70) + + # results = {} + + # # Torch DataLoader: Single Worker + # torch_loader = self._create_torch_dataloader(num_workers=0) + # start = time.time() + # for _ in torch_loader: + # pass + # torch_single_time = time.time() - start + # results['PyTorch (1 worker)'] = torch_single_time + + # # Torch DataLoader: Multiple Workers + # torch_loader = self._create_torch_dataloader(num_workers=4) + # start = time.time() + # for _ in 
torch_loader: + # pass + # torch_multi_time = time.time() - start + # results['PyTorch (4 workers)'] = torch_multi_time + + # # WeightsLab DataLoaderInterface: Single Worker + # wl_loader = self._create_weightslab_dataloader(num_workers=0) + # start = time.time() + # for _ in wl_loader: + # pass + # wl_single_time = time.time() - start + # results['WeightsLab (1 worker)'] = wl_single_time + + # # WeightsLab DataLoaderInterface: Multiple Workers + # wl_loader = self._create_weightslab_dataloader(num_workers=4) + # wl_loader.reset_iterator() # Ensure fresh start + # start = time.time() + # for _ in wl_loader: + # pass + # wl_multi_time = time.time() - start + # results['WeightsLab (4 workers)'] = wl_multi_time + + # # Print comparison + # print("\nThroughput Results (loading {} batches):".format(self.dataset_size // self.batch_size)) + # print("-" * 70) + # for name, elapsed in results.items(): + # throughput = (self.dataset_size / self.batch_size) / elapsed if elapsed > 0 else 0 + # print(f"{name:35} {elapsed:8.3f}s ({throughput:6.2f} batches/sec)") + + # print("-" * 70) + # speedup_wl = results['WeightsLab (1 worker)'] / results['WeightsLab (4 workers)'] + # speedup_torch = results['PyTorch (1 worker)'] / results['PyTorch (4 workers)'] + + # print(f"Multi-worker speedup:") + # print(f" PyTorch: {speedup_torch:.2f}x faster") + # print(f" WeightsLab: {speedup_wl:.2f}x faster") + + # # Verify multi-worker is faster than single-worker + # self.assertGreater(wl_single_time, wl_multi_time * 0.8, + # "Multi-worker should be faster or comparable to single-worker") @_SKIP_MULTIWORKER_ON_WIN def test_correctness_with_reset(self): @@ -233,7 +233,7 @@ def test_correctness_with_reset(self): first_iteration = [] for i, batch in enumerate(wl_loader): first_iteration.append(batch[0].clone()) - if i >= 5: # Just collect a few batches + if i >= 5: # Just collect a few batches break # Reset and iterate again @@ -250,7 +250,7 @@ def test_correctness_with_reset(self): 
self.assertTrue(torch.allclose(first, second), f"Batch {i} differs after reset") - print(f"✓ Reset iterator works: {len(first_iteration)} batches verified") + print(f" Reset iterator works: {len(first_iteration)} batches verified") if __name__ == '__main__': diff --git a/weightslab/tests/backend/test_data_loader_interface.py b/weightslab/tests/backend/test_data_loader_interface.py index c77cee72..4da94914 100644 --- a/weightslab/tests/backend/test_data_loader_interface.py +++ b/weightslab/tests/backend/test_data_loader_interface.py @@ -265,17 +265,17 @@ def test_mixed_manual_and_for_loop_iteration(self): Pattern: step = 0 while step < max_steps: - data = next(loader) # Manual iteration with auto-reset after epoch + data = next(loader) # Manual iteration with auto-reset after epoch if step % 5 == 0: - for batches in loader: # For-loop continues from current position - process(batches) # Gets remaining batches, ends with StopIteration + for batches in loader: # For-loop continues from current position + process(batches) # Gets remaining batches, ends with StopIteration step += 1 """ iface = DataLoaderInterface(self.train_ds, batch_size=self.batch_size, is_training=False, compute_hash=True) loader = ledgers.get_dataloader() batches_per_epoch = len(iface.dataloader) step = 0 - max_steps = 30 # Run for multiple epochs + max_steps = 30 # Run for multiple epochs manual_batches_collected = 0 for_loop_batches_collected = 0 @@ -368,10 +368,10 @@ def setUpClass(cls): # Auto register hp = ledgers.get_hyperparams() - hp['ledger_flush_interval'] = 10 # Disable flushing threads for tests - hp['ledger_flush_max_rows'] = 15 # Disable flushing threads for tests - hp['ledger_enable_h5_persistence'] = False # Disable flushing threads for tests - hp['ledger_enable_flushing_threads'] = False # Disable flushing threads for tests + hp['ledger_flush_interval'] = 10 # Disable flushing threads for tests + hp['ledger_flush_max_rows'] = 15 # Disable flushing threads for tests + 
hp['ledger_enable_h5_persistence'] = False # Disable flushing threads for tests + hp['ledger_enable_flushing_threads'] = False # Disable flushing threads for tests # Set controller to resumed state pause_controller._resume() @@ -429,7 +429,7 @@ def test_rng_reproducibility_with_shuffle(self): # 2. Capture RNG state print("\n2. Capturing RNG state...") rng_state = capture_rng_state() - dataloader.reset_iterator() # Reset to use captured RNG + dataloader.reset_iterator() # Reset to use captured RNG print(f"[OK] RNG state captured and iterator reset") # 3. Generate batches with current RNG @@ -455,65 +455,65 @@ def test_rng_reproducibility_with_shuffle(self): b2_check = np.array_equal(bids_2, bids_2_repeat) print(f"\n{'='*60}") print("Verification:") - print(f" Batch 1 match: {b1_check}") - print(f" Batch 2 match: {b2_check}") + print(f" Batch 1 match: {b1_check}") + print(f" Batch 2 match: {b2_check}") self.assertTrue(b1_check, "First batches should be identical") self.assertTrue(b2_check, "Second batches should be identical") print(f"[OK] RNG reproducibility verified!\n") # TODO (GP): Re-enable once OffsetSampler is implemented and tested # def test_iteration_state_reproducibility_without_shuffle(self): - # """Test dataloader reproducibility without shuffle: capture iteration state → resume identically. - - # With shuffle disabled, RNG is irrelevant. We capture the iteration position - # (number of batches yielded) and restore that position efficiently using - # OffsetSampler to skip samples at the index level without data reprocessing. - # """ - # print(f"\n{'='*60}") - # print("Iteration State Reproducibility - No Shuffle") - # print(f"{'='*60}\n") - - # print("1. Creating dataloader (shuffle=False)...") - # dataloader = DataLoaderInterface( - # self.dataset, - - # batch_size=2, - # shuffle=False, - # num_workers=0 - # ) - # print(f"[OK] DataLoader created (batch_size=2, shuffle=False)") - - # # 2. Consume two batches, then capture state - # print("\n2. 
Consuming first 2 batches...") - # _, bids_1, _ = next(dataloader) - # _, bids_2, _ = next(dataloader) - # print(f"Batches 1-2: {bids_1}, {bids_2}") - - # iter_state = dataloader.capture_iteration_state() - # print(f"[OK] Iteration state captured: {iter_state}") - - # # 3. Consume next two batches - # print("\n3. Consuming batches 3-4...") - # _, bids_3, _ = next(dataloader) - # _, bids_4, _ = next(dataloader) - # print(f"Batches 3-4: {bids_3}, {bids_4}") - - # # 4. Restore iteration state - # print(f"\n4. Restoring to position after batch 2...") - # dataloader.restore_iteration_state(iter_state) - # print(f"[OK] Iteration state restored (skipped first 2 batches efficiently)") - - # # 5. Generate batches again - should match 3 and 4 - # print("\n5. Generating next batches (should match 3-4)...") - # _, bids_3_repeat, _ = next(dataloader) - # _, bids_4_repeat, _ = next(dataloader) - # print(f"Repeated batches: {bids_3_repeat}, {bids_4_repeat}") - - # # Verify - # print(f"\n{'='*60}") - # print("Verification:") - # print(f" Batch 3 match: {torch.equal(bids_3, bids_3_repeat)}") - # print(f" Batch 4 match: {torch.equal(bids_4, bids_4_repeat)}") - # self.assertTrue(torch.equal(bids_3, bids_3_repeat), "Batch 3 should be identical") - # self.assertTrue(torch.equal(bids_4, bids_4_repeat), "Batch 4 should be identical") - # print(f"[OK] Iteration state reproducibility verified!\n") + # """Test dataloader reproducibility without shuffle: capture iteration state → resume identically. + + # With shuffle disabled, RNG is irrelevant. We capture the iteration position + # (number of batches yielded) and restore that position efficiently using + # OffsetSampler to skip samples at the index level without data reprocessing. + # """ + # print(f"\n{'='*60}") + # print("Iteration State Reproducibility - No Shuffle") + # print(f"{'='*60}\n") + + # print("1. 
Creating dataloader (shuffle=False)...") + # dataloader = DataLoaderInterface( + # self.dataset, + + # batch_size=2, + # shuffle=False, + # num_workers=0 + # ) + # print(f"[OK] DataLoader created (batch_size=2, shuffle=False)") + + # # 2. Consume two batches, then capture state + # print("\n2. Consuming first 2 batches...") + # _, bids_1, _ = next(dataloader) + # _, bids_2, _ = next(dataloader) + # print(f"Batches 1-2: {bids_1}, {bids_2}") + + # iter_state = dataloader.capture_iteration_state() + # print(f"[OK] Iteration state captured: {iter_state}") + + # # 3. Consume next two batches + # print("\n3. Consuming batches 3-4...") + # _, bids_3, _ = next(dataloader) + # _, bids_4, _ = next(dataloader) + # print(f"Batches 3-4: {bids_3}, {bids_4}") + + # # 4. Restore iteration state + # print(f"\n4. Restoring to position after batch 2...") + # dataloader.restore_iteration_state(iter_state) + # print(f"[OK] Iteration state restored (skipped first 2 batches efficiently)") + + # # 5. Generate batches again - should match 3 and 4 + # print("\n5. 
Generating next batches (should match 3-4)...") + # _, bids_3_repeat, _ = next(dataloader) + # _, bids_4_repeat, _ = next(dataloader) + # print(f"Repeated batches: {bids_3_repeat}, {bids_4_repeat}") + + # # Verify + # print(f"\n{'='*60}") + # print("Verification:") + # print(f" Batch 3 match: {torch.equal(bids_3, bids_3_repeat)}") + # print(f" Batch 4 match: {torch.equal(bids_4, bids_4_repeat)}") + # self.assertTrue(torch.equal(bids_3, bids_3_repeat), "Batch 3 should be identical") + # self.assertTrue(torch.equal(bids_4, bids_4_repeat), "Batch 4 should be identical") + # print(f"[OK] Iteration state reproducibility verified!\n") diff --git a/weightslab/tests/backend/test_instance_signal_logger.py b/weightslab/tests/backend/test_instance_signal_logger.py index 60e6613c..9d4b8dcc 100644 --- a/weightslab/tests/backend/test_instance_signal_logger.py +++ b/weightslab/tests/backend/test_instance_signal_logger.py @@ -341,8 +341,9 @@ def test_sample_index_built_on_add(self): lg = _fresh_logger() lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"img0": 0.5, "img1": 0.3}, aggregate_by_step=False) - idx = lg._sample_index.get("loss", {}) - self.assertTrue(any("img0" in h_idx for h_idx in idx.values())) + rows = lg.query_per_sample("loss", sample_ids=["img0"]) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], "img0") def test_sample_index_points_to_correct_rows(self): lg = _fresh_logger() @@ -350,28 +351,24 @@ def test_sample_index_points_to_correct_rows(self): signal_per_sample={"img0": 0.5}, aggregate_by_step=False) lg.add_scalars("loss", {"loss": 0.3}, 2, signal_per_sample={"img0": 0.3}, aggregate_by_step=False) - # img0 appears twice — both rows should be indexed - h = list(lg._sample_index["loss"].keys())[0] - rows = lg._sample_index["loss"][h]["img0"] + # img0 appears twice — both rows should be returned, at steps 1 and 2 + rows = lg.query_per_sample("loss", sample_ids=["img0"]) self.assertEqual(len(rows), 2) - buf = 
list(lg._signal_history_per_sample["loss"].values())[0] - self.assertEqual(buf["steps"][rows[0]], 1) - self.assertEqual(buf["steps"][rows[1]], 2) + self.assertEqual({r[1] for r in rows}, {1, 2}) def test_instance_index_built_on_add(self): lg = _fresh_logger() lg.add_instance_scalars("iou", ["s0", "s0", "s1"], [1, 2, 1], [0.9, 0.8, 0.7], 5, "h1") - idx = lg._instance_index["iou"]["h1"] - self.assertIn(("s0", 1), idx) - self.assertIn(("s0", 2), idx) - self.assertIn(("s1", 1), idx) + keys = {(r[0], r[1]) for r in lg.query_per_instance("iou")} + self.assertIn(("s0", 1), keys) + self.assertIn(("s0", 2), keys) + self.assertIn(("s1", 1), keys) def test_instance_index_points_to_correct_values(self): lg = _fresh_logger() lg.add_instance_scalars("iou", ["s0", "s0"], [1, 1], [0.9, 0.8], 5, "h1") - # Same (s0, 1) at two different steps → two rows - idx = lg._instance_index["iou"]["h1"] - rows = idx[("s0", 1)] + # Same (s0, 1) recorded twice → two rows returned + rows = lg.query_per_instance("iou", sample_id="s0", annotation_id=1) self.assertEqual(len(rows), 2) def test_sample_index_rebuilt_after_snapshot_load(self): @@ -381,8 +378,8 @@ def test_sample_index_rebuilt_after_snapshot_load(self): snap = lg.save_snapshot() lg2 = _fresh_logger() lg2.load_snapshot(snap) - self.assertIn("img0", list(lg2._sample_index.get("loss", {}).values())[0]) - self.assertIn("img1", list(lg2._sample_index.get("loss", {}).values())[0]) + self.assertEqual(len(lg2.query_per_sample("loss", sample_ids=["img0"])), 1) + self.assertEqual(len(lg2.query_per_sample("loss", sample_ids=["img1"])), 1) def test_instance_index_rebuilt_after_snapshot_load(self): lg = _fresh_logger() @@ -390,9 +387,9 @@ def test_instance_index_rebuilt_after_snapshot_load(self): snap = lg.save_snapshot() lg2 = _fresh_logger() lg2.load_snapshot(snap) - idx = lg2._instance_index.get("iou", {}).get("h1", {}) - self.assertIn(("s0", 1), idx) - self.assertIn(("s1", 2), idx) + keys = {(r[0], r[1]) for r in lg2.query_per_instance("iou", 
exp_hash="h1")} + self.assertIn(("s0", 1), keys) + self.assertIn(("s1", 2), keys) def test_clear_signal_histories_also_clears_indices(self): lg = _fresh_logger() @@ -400,8 +397,8 @@ def test_clear_signal_histories_also_clears_indices(self): signal_per_sample={"img0": 0.5}, aggregate_by_step=False) lg.add_instance_scalars("iou", ["s0"], [1], [0.8], 1, "h1") lg.clear_signal_histories() - self.assertEqual(lg._sample_index, {}) - self.assertEqual(lg._instance_index, {}) + self.assertEqual(lg.query_per_sample("loss"), []) + self.assertEqual(lg.query_per_instance("iou"), []) def test_query_uses_index_not_full_scan(self): """query_per_sample with filter returns correct results via index path.""" @@ -422,10 +419,10 @@ def test_eval_mode_also_updates_sample_index(self): lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"imgA": 0.5, "imgB": 0.3}, aggregate_by_step=False) lg._eval_mode_active = False - idx = lg._sample_index.get("loss", {}).get("eval_h1", {}) - self.assertIn("imgA", idx) - self.assertIn("imgB", idx) - # Query must find them + # Per-sample data was written under the eval hash and is queryable + under_eval = {r[0] for r in lg.query_per_sample("loss", exp_hash="eval_h1")} + self.assertIn("imgA", under_eval) + self.assertIn("imgB", under_eval) rows = lg.query_per_sample("loss", sample_ids=["imgA"]) self.assertEqual(len(rows), 1) @@ -445,9 +442,9 @@ def test_legacy_list_of_dicts_snapshot_rebuilds_index(self): }, } lg.load_snapshot(legacy_snap) - idx = lg._sample_index.get("loss", {}).get("h1", {}) - self.assertIn("img0", idx) - self.assertIn("img1", idx) + under_h1 = {r[0] for r in lg.query_per_sample("loss", exp_hash="h1")} + self.assertIn("img0", under_h1) + self.assertIn("img1", under_h1) rows = lg.query_per_sample("loss", sample_ids=["img0"]) self.assertEqual(len(rows), 1) @@ -456,14 +453,8 @@ def test_multi_exp_hash_filter(self): lg = _fresh_logger() lg.add_scalars("loss", {"loss": 0.1}, 1, signal_per_sample={"s0": 0.1}, aggregate_by_step=False) - # 
Manually inject a second hash entry to simulate two runs - from array import array as _array - lg._signal_history_per_sample["loss"]["h2"] = { - "sample_ids": ["s0"], - "steps": _array('i', [1]), - "values": _array('f', [0.9]), - } - lg._sample_index.setdefault("loss", {}).setdefault("h2", {})["s0"] = [0] + # Add a second run's data under hash "h2" + lg.ingest_per_sample("loss", "h2", [("s0", 1, 0.9)]) rows_h2 = lg.query_per_sample("loss", sample_ids=["s0"], exp_hash="h2") self.assertEqual(len(rows_h2), 1) self.assertAlmostEqual(rows_h2[0][2], 0.9, places=4) diff --git a/weightslab/tests/backend/test_ledgers.py b/weightslab/tests/backend/test_ledgers.py index 07f679d7..8db41501 100644 --- a/weightslab/tests/backend/test_ledgers.py +++ b/weightslab/tests/backend/test_ledgers.py @@ -32,13 +32,13 @@ def test_default_name_usage(self): # Register without providing name - should use DEFAULT_NAME GLOBAL_LEDGER.register_model(model=d) self.assertIn(DEFAULT_NAME, GLOBAL_LEDGER.list_models()) - got = GLOBAL_LEDGER.get_model() # Should get 'main' by default + got = GLOBAL_LEDGER.get_model() # Should get 'main' by default self.assertIs(got, d) def test_proxy_initialization_pattern(self): """Test that get before register returns Proxy(None), then updates on register.""" # Get before register - should return Proxy(None) - hp = GLOBAL_LEDGER.get_hyperparams() # Uses DEFAULT_NAME + hp = GLOBAL_LEDGER.get_hyperparams() # Uses DEFAULT_NAME # Proxy should exist but not have underlying object yet self.assertEqual(hp.get(), {}) @@ -107,7 +107,7 @@ def test_weak_registration(self): self.assertNotIn("w", names) def test_optimizer_live_update_through_proxy(self): - GLOBAL_LEDGER.get_optimizer('opt_live') # Init opt with a proxy entry + GLOBAL_LEDGER.get_optimizer('opt_live') # Init opt with a proxy entry # define a simple optimizer-like object class DummyOpt: @@ -197,6 +197,29 @@ def test_proxy_get_key_default_mode_returns_live_proxy(self): hp_handle["lr"] = 0.02 
self.assertEqual(lr.get(), 0.02) + def test_dict_value_returned_raw_not_proxied(self): + """Dict (and list / callable) values come back RAW, not wrapped in a live + proxy (see Proxy.get's list/dict/callable exclusion), so subscripting just + reads the plain mapping.""" + hp_handle = GLOBAL_LEDGER.get_hyperparams() + GLOBAL_LEDGER.register_hyperparams( + params={"dataset": {"batch_size": 32, "splits": {"train": 0.8}}} + ) + + dataset = hp_handle.get("dataset") + # A dict value is handed back raw, not as a live ValueProxy. + self.assertIsInstance(dataset, dict) + self.assertFalse(hasattr(dataset, "set")) + + # Plain mapping access (and nesting) works. + self.assertEqual(dataset["batch_size"], dataset.get("batch_size")) + self.assertEqual(dataset["batch_size"], 32) + self.assertEqual(dataset["splits"]["train"], 0.8) + + # Missing keys raise KeyError, matching standard dict subscript semantics. + with self.assertRaises(KeyError): + dataset["missing"] + def test_proxy_pickles_and_restores(self): proxy = Proxy({"flag": True, "count": 3}) @@ -231,6 +254,37 @@ def test_proxy_get_key_explicit_plain_value_mode(self): hp_handle["data_root"] = "C:/data/v2" self.assertEqual(data_root, "C:/data/v1") + def test_proxy_yaml_and_json_serialization(self): + """Ledger proxies serialize to their underlying value for both YAML and + JSON, so libraries that dump their config (e.g. Ultralytics' args.yaml, + or JSON audit/config dumps) don't choke on a live hyperparameter proxy.""" + import json + import yaml + + hp = GLOBAL_LEDGER.get_hyperparams() + GLOBAL_LEDGER.register_hyperparams(params={"image_size": 320, "lr": 0.01}) + + img = hp.get("image_size") # a live ValueProxy, not a plain int + self.assertEqual(type(img).__name__, "_ValueProxy") + + # YAML: cover every dumper variant — Ultralytics dumps with CSafeDumper + # (the libyaml C dumper), which keeps its own representer table. 
+ for dumper_name in ("Dumper", "SafeDumper", "CDumper", "CSafeDumper"): + dumper = getattr(yaml, dumper_name, None) + if dumper is None: + continue + self.assertEqual( + yaml.dump({"imgsz": img}, Dumper=dumper).strip(), "imgsz: 320", + f"{dumper_name} did not serialize the proxy", + ) + self.assertEqual( + yaml.safe_load(yaml.safe_dump(hp)), {"image_size": 320, "lr": 0.01} + ) + + # JSON: json.dumps of a scalar proxy and of the whole HP proxy. + self.assertEqual(json.loads(json.dumps({"imgsz": img})), {"imgsz": 320}) + self.assertEqual(json.loads(json.dumps(hp)), {"image_size": 320, "lr": 0.01}) + def test_value_proxy_numeric_comparisons(self): """ValueProxy supports all standard numeric and string comparison operators.""" hp_handle = GLOBAL_LEDGER.get_hyperparams() diff --git a/weightslab/tests/backend/test_logger_core.py b/weightslab/tests/backend/test_logger_core.py index 867f6d37..f1b0cc80 100644 --- a/weightslab/tests/backend/test_logger_core.py +++ b/weightslab/tests/backend/test_logger_core.py @@ -39,6 +39,17 @@ def _add(lg, sig, sid, step, val, aggregate_by_step=False): aggregate_by_step=aggregate_by_step) +def _seed_eval_hash(lg, exp_hash, sig="loss", val=0.5, step=1): + """Write an aggregated marker under *exp_hash* via the evaluation lifecycle. + + Replaces white-box seeding of the (now DuckDB-backed) history dict. + """ + lg.start_evaluation_mode("val", exp_hash) + lg.add_scalars(sig, {sig: val}, step, + signal_per_sample=None, aggregate_by_step=False) + lg.stop_evaluation_mode(model_age=step) + + # --------------------------------------------------------------------------- # 1. 
__len__ # --------------------------------------------------------------------------- @@ -82,22 +93,24 @@ def test_no_existing_evals_returns_1(self): def test_existing_h1_1_returns_2(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {}} + _seed_eval_hash(lg, "h1_1") self.assertEqual(lg.get_next_evaluation_count("h1"), 2) def test_existing_h1_1_and_h1_3_returns_4(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {}, "h1_3": {}} + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1_3") self.assertEqual(lg.get_next_evaluation_count("h1"), 4) def test_non_int_suffix_is_ignored(self): lg = _lg() - lg._signal_history["loss"] = {"h1_abc": {}, "h1_1": {}} + _seed_eval_hash(lg, "h1_abc") + _seed_eval_hash(lg, "h1_1") self.assertEqual(lg.get_next_evaluation_count("h1"), 2) def test_different_base_hash_not_counted(self): lg = _lg() - lg._signal_history["loss"] = {"h2_5": {}} + _seed_eval_hash(lg, "h2_5") self.assertEqual(lg.get_next_evaluation_count("h1"), 1) @@ -126,7 +139,8 @@ def test_add_scalars_during_eval_goes_to_accum_not_history(self): lg.add_scalars("loss", {"loss": 0.4}, 10, signal_per_sample=None, aggregate_by_step=False) self.assertIn("loss", lg._eval_accum) - self.assertNotIn("loss", lg._signal_history) + # Nothing written to the aggregated history during evaluation mode + self.assertEqual(lg.get_signal_history(), {}) def test_add_scalars_during_eval_accumulates_values(self): lg = _lg() @@ -147,15 +161,15 @@ def test_stop_computes_mean_and_writes_history(self): results = lg.stop_evaluation_mode(model_age=10) self.assertIn("loss", results) self.assertAlmostEqual(results["loss"], 0.5, places=5) - # Written into _signal_history under eval_hash - self.assertIn("h1_1", lg._signal_history.get("loss", {})) + # Written into history under eval_hash + self.assertIn("h1_1", lg.get_signal_history().get("loss", {})) def test_stop_emits_is_evaluation_marker(self): lg = _lg() lg.start_evaluation_mode("val", "h1_1") lg.add_scalars("loss", {"loss": 0.5}, 10, 
signal_per_sample=None, aggregate_by_step=False) lg.stop_evaluation_mode(model_age=10) - entries = lg._signal_history["loss"]["h1_1"][10] + entries = lg.get_signal_history()["loss"]["h1_1"][10] self.assertTrue(entries[0].get("is_evaluation_marker")) def test_stop_adds_to_pending_queue(self): @@ -181,7 +195,7 @@ def test_stop_when_not_active_returns_empty(self): def test_stop_skips_zero_count_signals(self): lg = _lg() lg.start_evaluation_mode("val", "h1_1") - lg._eval_accum["loss"] = [0.0, 0] # injected directly with count=0 + lg._eval_accum["loss"] = [0.0, 0] # injected directly with count=0 results = lg.stop_evaluation_mode(model_age=1) self.assertNotIn("loss", results) @@ -190,7 +204,7 @@ def test_stop_stores_split_name_and_tags(self): lg.start_evaluation_mode("val", "h1_1", evaluation_tags=["hard", "easy"]) lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample=None, aggregate_by_step=False) lg.stop_evaluation_mode(model_age=1) - entry = lg._signal_history["loss"]["h1_1"][1][0] + entry = lg.get_signal_history()["loss"]["h1_1"][1][0] self.assertEqual(entry["split_name"], "val") self.assertEqual(entry["evaluation_tags"], ["hard", "easy"]) @@ -212,7 +226,7 @@ class TestAbortEvaluationMode(unittest.TestCase): def test_abort_when_not_active_is_noop(self): lg = _lg() - lg.abort_evaluation_mode() # should not raise + lg.abort_evaluation_mode() # should not raise self.assertFalse(lg._eval_mode_active) def test_abort_clears_active_flag_and_accum(self): @@ -230,7 +244,7 @@ def test_abort_removes_per_sample_written_during_eval(self): signal_per_sample={"img0": 0.5}, aggregate_by_step=True) lg.abort_evaluation_mode() # Per-sample history under "h1_1" should be gone - self.assertNotIn("h1_1", lg._signal_history_per_sample.get("loss", {})) + self.assertEqual(lg.query_per_sample("loss", exp_hash="h1_1"), []) def test_abort_removes_queue_entries_for_eval_hash(self): lg = _lg() @@ -239,7 +253,7 @@ def test_abort_removes_queue_entries_for_eval_hash(self): 
lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample=None, aggregate_by_step=False) # Manually inject a queue entry for the eval hash lg._pending_queue.append({"experiment_hash": "h1_1", "metric_name": "loss"}) - lg._eval_mode_active = True # re-arm + lg._eval_mode_active = True # re-arm lg.abort_evaluation_mode() hashes_in_queue = {e.get("experiment_hash") for e in lg._pending_queue} self.assertNotIn("h1_1", hashes_in_queue) @@ -253,27 +267,30 @@ class TestRemoveEvaluationHash(unittest.TestCase): def test_removes_from_signal_history(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {1: []}, "h1": {1: []}} + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1") lg.remove_evaluation_hash("h1_1") - self.assertNotIn("h1_1", lg._signal_history["loss"]) - self.assertIn("h1", lg._signal_history["loss"]) + hist = lg.get_signal_history()["loss"] + self.assertNotIn("h1_1", hist) + self.assertIn("h1", hist) def test_removes_from_per_sample_history(self): lg = _lg() _add(lg, "loss", "s0", 1, 0.5) - # manually inject an eval hash entry - from array import array as _array - lg._signal_history_per_sample["loss"]["h1_1"] = { - "sample_ids": ["s0"], "steps": _array('i', [1]), "values": _array('f', [0.5]) - } + # write per-sample data under an eval hash via evaluation mode + lg.start_evaluation_mode("val", "h1_1") + lg.add_scalars("loss", {"loss": 0.5}, 1, + signal_per_sample={"s0": 0.5}, aggregate_by_step=True) + lg._eval_mode_active = False + self.assertEqual(len(lg.query_per_sample("loss", exp_hash="h1_1")), 1) lg.remove_evaluation_hash("h1_1") - self.assertNotIn("h1_1", lg._signal_history_per_sample["loss"]) + self.assertEqual(lg.query_per_sample("loss", exp_hash="h1_1"), []) def test_removes_matching_entries_from_queue(self): lg = _lg() lg._pending_queue = [ {"experiment_hash": "h1_1", "metric_name": "loss"}, - {"experiment_hash": "h1", "metric_name": "loss"}, + {"experiment_hash": "h1", "metric_name": "loss"}, ] lg.remove_evaluation_hash("h1_1") 
self.assertEqual(len(lg._pending_queue), 1) @@ -281,13 +298,13 @@ def test_removes_matching_entries_from_queue(self): def test_empty_hash_is_noop(self): lg = _lg() - lg._signal_history["loss"] = {"h1": {}} + _seed_eval_hash(lg, "h1") lg.remove_evaluation_hash("") - self.assertIn("h1", lg._signal_history["loss"]) + self.assertIn("h1", lg.get_signal_history()["loss"]) def test_missing_hash_does_not_raise(self): lg = _lg() - lg.remove_evaluation_hash("nonexistent_hash_1") # must not raise + lg.remove_evaluation_hash("nonexistent_hash_1") # must not raise # --------------------------------------------------------------------------- @@ -300,8 +317,9 @@ def test_immediate_mode_writes_to_history(self): lg = _lg() lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"s0": 0.5}, aggregate_by_step=False) - self.assertIn(None, lg._signal_history.get("loss", {})) - self.assertEqual(lg._signal_history["loss"][None][1][0]["metric_value"], 0.5) + hist = lg.get_signal_history() + self.assertIn(None, hist.get("loss", {})) + self.assertEqual(hist["loss"][None][1][0]["metric_value"], 0.5) def test_immediate_mode_adds_to_queue(self): lg = _lg() @@ -314,7 +332,7 @@ def test_aggregate_mode_buffers_not_writes(self): lg.add_scalars("loss", {"loss": 0.5}, 1, signal_per_sample={"s0": 0.5}, aggregate_by_step=True) # Not in history yet — buffered - self.assertNotIn("loss", lg._signal_history) + self.assertEqual(lg.get_signal_history(), {}) self.assertIn((1, "loss", None), lg._current_step_buffer) def test_aggregate_mode_step_change_flushes_to_history(self): @@ -327,7 +345,7 @@ def test_aggregate_mode_step_change_flushes_to_history(self): lg.add_scalars("loss", {"loss": 0.9}, 2, signal_per_sample={"s0": 0.9}, aggregate_by_step=True) # Step 1 should now be averaged in history - entries = lg._signal_history["loss"][None][1] + entries = lg.get_signal_history()["loss"][None][1] self.assertAlmostEqual(entries[0]["metric_value"], 0.5, places=5) def 
test_aggregate_mode_averages_multiple_calls_same_step(self): @@ -338,7 +356,7 @@ def test_aggregate_mode_averages_multiple_calls_same_step(self): # Force flush lg.add_scalars("acc", {"acc": 0.9}, 2, signal_per_sample=None, aggregate_by_step=False) - entries = lg._signal_history["loss"][None][1] + entries = lg.get_signal_history()["loss"][None][1] self.assertAlmostEqual(entries[0]["metric_value"], 0.4, places=5) def test_per_sample_written_even_in_aggregate_mode(self): @@ -375,7 +393,7 @@ def test_adds_new_triples(self): def test_dedup_same_sample_and_step(self): lg = _lg() lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.4)]) - lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # same (sid, step) → ignored + lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # same (sid, step) → ignored rows = lg.query_per_sample("loss") self.assertEqual(len(rows), 1) self.assertAlmostEqual(rows[0][2], 0.4, places=4) @@ -383,7 +401,7 @@ def test_dedup_same_sample_and_step(self): def test_different_step_is_not_dedup(self): lg = _lg() lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.4)]) - lg.ingest_per_sample("loss", "h1", [("s0", 2, 0.9)]) # different step → accepted + lg.ingest_per_sample("loss", "h1", [("s0", 2, 0.9)]) # different step → accepted rows = lg.query_per_sample("loss") self.assertEqual(len(rows), 2) @@ -401,10 +419,10 @@ def test_updates_sample_index(self): def test_dedup_does_not_corrupt_index(self): lg = _lg() lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.4)]) - lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # duplicate ignored - # index should still point to exactly 1 row - idx = lg._sample_index["loss"]["h1"]["s0"] - self.assertEqual(len(idx), 1) + lg.ingest_per_sample("loss", "h1", [("s0", 1, 0.9)]) # duplicate ignored + # exactly one row should remain queryable for (s0, h1) + rows = lg.query_per_sample("loss", sample_ids=["s0"], exp_hash="h1") + self.assertEqual(len(rows), 1) # --------------------------------------------------------------------------- 
@@ -479,9 +497,10 @@ def test_returns_deepcopy(self): lg = _lg() _add(lg, "loss", "s0", 1, 0.5) hist = lg.get_signal_history() - # Mutate the copy — internal state must not change + # Mutate the returned copy — a fresh read must not reflect the mutation hist["loss"][None][1][0]["metric_value"] = 999.0 - self.assertNotEqual(lg._signal_history["loss"][None][1][0]["metric_value"], 999.0) + fresh = lg.get_signal_history() + self.assertNotEqual(fresh["loss"][None][1][0]["metric_value"], 999.0) def test_empty_when_nothing_added(self): lg = _lg() @@ -531,7 +550,9 @@ def test_empty_when_no_eval(self): def test_returns_eval_hashes(self): lg = _lg() - lg._signal_history["loss"] = {"h1_1": {}, "h1_2": {}, "h1": {}} + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1_2") + _seed_eval_hash(lg, "h1") hashes = lg.get_evaluation_marker_hashes() self.assertIn("h1_1", hashes) self.assertIn("h1_2", hashes) @@ -539,12 +560,15 @@ def test_returns_eval_hashes(self): def test_returns_sorted(self): lg = _lg() - lg._signal_history["loss"] = {"h1_3": {}, "h1_1": {}, "h1_2": {}} + _seed_eval_hash(lg, "h1_3") + _seed_eval_hash(lg, "h1_1") + _seed_eval_hash(lg, "h1_2") self.assertEqual(lg.get_evaluation_marker_hashes(), ["h1_1", "h1_2", "h1_3"]) def test_non_int_suffix_excluded(self): lg = _lg() - lg._signal_history["loss"] = {"h1_abc": {}, "h1_1": {}} + _seed_eval_hash(lg, "h1_abc") + _seed_eval_hash(lg, "h1_1") hashes = lg.get_evaluation_marker_hashes() self.assertNotIn("h1_abc", hashes) self.assertIn("h1_1", hashes) @@ -622,7 +646,7 @@ def test_set_note_on_history_entry(self): _add(lg, "loss", "s0", 5, 0.4) result = lg.set_point_note("loss", "run1", 5, "my note") self.assertTrue(result) - entry = lg._signal_history["loss"]["run1"][5][0] + entry = lg.get_signal_history()["loss"]["run1"][5][0] self.assertEqual(entry["point_note"], "my note") def test_clear_note_with_empty_string(self): @@ -630,7 +654,7 @@ def test_clear_note_with_empty_string(self): _add(lg, "loss", "s0", 5, 0.4) 
lg.set_point_note("loss", "run1", 5, "my note") lg.set_point_note("loss", "run1", 5, "") - entry = lg._signal_history["loss"]["run1"][5][0] + entry = lg.get_signal_history()["loss"]["run1"][5][0] self.assertNotIn("point_note", entry) def test_updates_pending_queue_entry(self): @@ -649,7 +673,7 @@ def test_nonexistent_step_returns_false(self): def test_does_not_modify_non_matching_queue_entries(self): lg = self._lg_with_hash("run1") _add(lg, "loss", "s0", 5, 0.4) - _add(lg, "acc", "s0", 5, 0.9) + _add(lg, "acc", "s0", 5, 0.9) lg.set_point_note("loss", "run1", 5, "only loss") acc_entry = next(e for e in lg._pending_queue if e["metric_name"] == "acc") self.assertNotIn("point_note", acc_entry) @@ -671,9 +695,10 @@ def test_dict_format_loads_correctly(self): } } }) - self.assertIn("loss", lg._signal_history) - self.assertIn("h1", lg._signal_history["loss"]) - self.assertEqual(lg._signal_history["loss"]["h1"][1][0]["metric_value"], 0.5) + hist = lg.get_signal_history() + self.assertIn("loss", hist) + self.assertIn("h1", hist["loss"]) + self.assertEqual(hist["loss"]["h1"][1][0]["metric_value"], 0.5) def test_dict_format_string_step_key_converted_to_int(self): lg = _lg() @@ -685,7 +710,7 @@ def test_dict_format_string_step_key_converted_to_int(self): } } }) - self.assertIn(42, lg._signal_history["loss"]["h1"]) + self.assertIn(42, lg.get_signal_history()["loss"]["h1"]) def test_list_format_loads_correctly(self): lg = _lg() @@ -693,21 +718,22 @@ def test_list_format_loads_correctly(self): {"metric_name": "acc", "experiment_hash": "h1", "model_age": 3, "metric_value": 0.9, "timestamp": 0}, ]) - self.assertIn("acc", lg._signal_history) - self.assertEqual(lg._signal_history["acc"]["h1"][3][0]["metric_value"], 0.9) + hist = lg.get_signal_history() + self.assertIn("acc", hist) + self.assertEqual(hist["acc"]["h1"][3][0]["metric_value"], 0.9) def test_list_format_skips_entries_without_metric_name(self): lg = _lg() lg.load_signal_history([ {"experiment_hash": "h1", "model_age": 1, 
"metric_value": 0.5}, ]) - self.assertEqual(lg._signal_history, {}) + self.assertEqual(lg.get_signal_history(), {}) def test_empty_input_is_noop(self): lg = _lg() lg.load_signal_history({}) lg.load_signal_history([]) - self.assertEqual(lg._signal_history, {}) + self.assertEqual(lg.get_signal_history(), {}) def test_adds_to_graph_names(self): lg = _lg() @@ -723,7 +749,7 @@ def test_missing_fields_get_defaults(self): {"metric_name": "loss", "model_age": 5, "metric_value": 0.1}, ]) # experiment_hash defaults to None - self.assertIn(None, lg._signal_history["loss"]) + self.assertIn(None, lg.get_signal_history()["loss"]) if __name__ == "__main__": diff --git a/weightslab/tests/backend/test_ui_docker_bridge.py b/weightslab/tests/backend/test_ui_docker_bridge.py index 01102b3d..c43f2bd5 100644 --- a/weightslab/tests/backend/test_ui_docker_bridge.py +++ b/weightslab/tests/backend/test_ui_docker_bridge.py @@ -150,7 +150,7 @@ def test_make_executable_is_noop_on_windows(self): def test_make_executable_swallows_oserror(self): # A non-chmod-able path (e.g. root-owned system install) must not raise. 
with patch("weightslab.ui_docker_bridge.os.stat", side_effect=OSError("denied")): - _make_executable("/root/owned.sh") # should not raise + _make_executable("/root/owned.sh") # should not raise @unittest.skipIf(sys.platform == "win32", "execute bit is POSIX-only") def test_ensure_scripts_executable_marks_bundled_scripts(self): @@ -185,15 +185,15 @@ def test_launch_default_no_cert_gen_cleans_and_launches_unsecured( _mock_shell, _gb, mock_mgr, ): mgr = MagicMock() - mgr.has_valid_certs.return_value = False # no certs on disk -> unsecured + mgr.has_valid_certs.return_value = False # no certs on disk -> unsecured mock_mgr.from_env_or_default.return_value = mgr with patch.dict(os.environ, {}, clear=False): os.environ.pop("VITE_PORT", None) with self.assertLogs("weightslab.ui_docker_bridge", level="INFO") as log_context: ui_launch(argparse.Namespace()) mock_check.assert_called_once() - mock_ensure.assert_not_called() # certs NOT generated by default - mock_clean.assert_called_once() # stale cleanup ran + mock_ensure.assert_not_called() # certs NOT generated by default + mock_clean.assert_called_once() # stale cleanup ran mock_compose.assert_called_once_with( "/fake/docker-compose.yml", "/fake/envoy.yaml", @@ -237,13 +237,13 @@ def test_launch_certs_flag_generates_and_runs_secured( _mock_shell, _gb, mock_mgr, ): mgr = MagicMock(certs_dir="/fake/certs") - mgr.has_valid_certs.return_value = True # certs present after generation + mgr.has_valid_certs.return_value = True # certs present after generation mock_mgr.from_env_or_default.return_value = mgr with patch.dict(os.environ, {}, clear=False): os.environ.pop("VITE_PORT", None) with self.assertLogs("weightslab.ui_docker_bridge", level="INFO") as log_context: ui_launch(argparse.Namespace(certs=True)) - mock_ensure.assert_called_once() # --certs generates certs + mock_ensure.assert_called_once() # --certs generates certs self.assertTrue(any("https://localhost:5173" in msg for msg in log_context.output)) @@ -310,7 +310,7 @@ def 
test_removes_when_present(self, mock_run): def test_noop_when_absent(self, mock_run): mock_run.return_value = MagicMock(stdout="") _remove_docker_image(_FRONTEND_IMAGE) - mock_run.assert_called_once() # only the 'docker images -q' query, no rmi + mock_run.assert_called_once() # only the 'docker images -q' query, no rmi class TestCleanStaleDockerResources(unittest.TestCase): @@ -354,7 +354,7 @@ class TestUiSecureEnvironment(unittest.TestCase): @patch("weightslab.ui_docker_bridge._generate_certs_with_fallback", return_value=0) def test_ui_secure_environment_success(self, mock_gen_certs, mock_cert_manager): """`weightslab se`: generate certs + token, export WEIGHTSLAB_CERTS_DIR.""" - mock_manager_instance = MagicMock() # certs_dir is a MagicMock (supports .mkdir) + mock_manager_instance = MagicMock() # certs_dir is a MagicMock (supports .mkdir) mock_manager_instance.get_or_create_auth_token.return_value = "fake_token" mock_cert_manager.return_value = mock_manager_instance @@ -411,19 +411,19 @@ def test_main_dispatches_start_example(self, mock_example): def test_main_ui_without_action_does_not_crash(self): with patch("sys.argv", ["weightslab", "ui"]): - main() # should print ui help, not raise + main() # should print ui help, not raise def test_main_start_without_target_does_not_crash(self): with patch("sys.argv", ["weightslab", "start"]): - main() # should print start help, not raise + main() # should print start help, not raise def test_main_help_does_not_crash(self): with patch("sys.argv", ["weightslab", "help"]): - main() # should not raise + main() # should not raise def test_main_no_args_does_not_crash(self): with patch("sys.argv", ["weightslab"]): - main() # should not raise + main() # should not raise class TestUserOnboardingFlow(unittest.TestCase): @@ -637,7 +637,7 @@ def test_installs_requirements_non_interactively_when_present(self, mock_run): mock_run.assert_called_once() cmd = mock_run.call_args.args[0] self.assertEqual(cmd[:5], [sys.executable, "-m", 
"pip", "install", "-r"]) - self.assertIn("--no-input", cmd) # never prompts + self.assertIn("--no-input", cmd) # never prompts self.assertTrue(mock_run.call_args.kwargs.get("check")) @patch("weightslab.ui_docker_bridge.subprocess.run") @@ -659,12 +659,12 @@ def _capture_main(argv): try: main() except SystemExit: - pass # argparse -h/--help exits 0 + pass # argparse -h/--help exits 0 return buf.getvalue() def test_dash_h_shows_banner_and_command_reference(self): out = self._capture_main(["weightslab", "-h"]) - self.assertIn("WeightsLab", out) # tagline from description + self.assertIn("WeightsLab", out) # tagline from description self.assertIn("ui launch", out) self.assertIn("--certs", out) self.assertIn("start example", out) diff --git a/weightslab/tests/backend/test_write_dataframe.py b/weightslab/tests/backend/test_write_dataframe.py index 882b0a21..163eec03 100644 --- a/weightslab/tests/backend/test_write_dataframe.py +++ b/weightslab/tests/backend/test_write_dataframe.py @@ -197,7 +197,7 @@ def test_returns_path_when_no_manager(self, tmp_json): patch("weightslab.src.get_logger", return_value=None): result = write_dataframe(tmp_json) assert result == tmp_json - assert not os.path.isfile(tmp_json) # nothing written + assert not os.path.isfile(tmp_json) # nothing written # --------------------------------------------------------------------------- @@ -264,7 +264,7 @@ def test_sample_id_single(self, mgr, tmp_json): _call(tmp_json, mgr, sample_id="s1") data = json.loads(open(tmp_json).read()) assert all(r["sample_id"] == "s1" for r in data) - assert len(data) == 2 # s1 has annotation_ids 0 and 1 + assert len(data) == 2 # s1 has annotation_ids 0 and 1 def test_sample_id_list(self, mgr, tmp_json): _call(tmp_json, mgr, sample_id=["s1", "s2"]) @@ -277,7 +277,7 @@ def test_instance_id_zero_keeps_sample_rows(self, mgr, tmp_json): _call(tmp_json, mgr, instance_id=0) data = json.loads(open(tmp_json).read()) assert all(r["annotation_id"] == 0 for r in data) - assert 
len(data) == 2 # s1 and s2 both have annotation_id=0 + assert len(data) == 2 # s1 and s2 both have annotation_id=0 def test_instance_id_list(self, mgr, tmp_json): _call(tmp_json, mgr, instance_id=[1, 2]) diff --git a/weightslab/tests/backend/test_write_history.py b/weightslab/tests/backend/test_write_history.py index 8bc15616..3608fc70 100644 --- a/weightslab/tests/backend/test_write_history.py +++ b/weightslab/tests/backend/test_write_history.py @@ -24,7 +24,7 @@ def _make_logger(): """Return a fresh LoggerQueue with data under two hashes. h1: loss (steps 1+2), acc (step 1), iou instances (annotation_ids 1,2) - h2: loss (step 1), acc (step 1), iou instance (annotation_id 3) ← current hash + h2: loss (step 1), acc (step 1), iou instance (annotation_id 3) ← current hash """ lg = LoggerQueue(register=False) ckpt = _mock_chkpt("h1") @@ -33,7 +33,7 @@ def _make_logger(): # h1 data lg.add_scalars("loss", {"loss": 1.0}, 1, signal_per_sample={"s1": 1.0}, aggregate_by_step=False) lg.add_scalars("loss", {"loss": 2.0}, 2, signal_per_sample={"s2": 2.0}, aggregate_by_step=False) - lg.add_scalars("acc", {"acc": 0.9}, 1, signal_per_sample={"s1": 0.9}, aggregate_by_step=False) + lg.add_scalars("acc", {"acc": 0.9}, 1, signal_per_sample={"s1": 0.9}, aggregate_by_step=False) lg.add_instance_scalars("iou", sample_ids=["s1"], annotation_ids=[1], values=[0.8], global_step=1, exp_hash="h1") lg.add_instance_scalars("iou", sample_ids=["s2"], annotation_ids=[2], @@ -42,7 +42,7 @@ def _make_logger(): # h2 data — left as the "current" hash after setup ckpt.get_current_experiment_hash.return_value = "h2" lg.add_scalars("loss", {"loss": 3.0}, 1, signal_per_sample={"s1": 3.0}, aggregate_by_step=False) - lg.add_scalars("acc", {"acc": 0.7}, 1, signal_per_sample={"s2": 0.7}, aggregate_by_step=False) + lg.add_scalars("acc", {"acc": 0.7}, 1, signal_per_sample={"s2": 0.7}, aggregate_by_step=False) lg.add_instance_scalars("iou", sample_ids=["s1"], annotation_ids=[3], values=[0.7], global_step=1, 
exp_hash="h2") @@ -117,7 +117,7 @@ def test_instance_row_keys(self, lg, tmp_json): def test_file_is_valid_json(self, lg, tmp_json): _call(tmp_json, lg) - json.loads(open(tmp_json).read()) # must not raise + json.loads(open(tmp_json).read()) # must not raise def test_output_file_created(self, lg, tmp_json): _call(tmp_json, lg) @@ -395,7 +395,7 @@ def test_empty_logger_writes_header_only_csv(self, tmp_csv): write_history(tmp_csv, format="csv", experiment_hash="all") with open(tmp_csv, newline="", encoding="utf-8") as fh: rows = list(csv.reader(fh)) - assert len(rows) == 1 # header only + assert len(rows) == 1 # header only def test_case_insensitive_type(self, lg, tmp_json): _call(tmp_json, lg, type_of_history="SAMPLE") diff --git a/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py b/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py index 2ae9118a..412b379f 100644 --- a/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py +++ b/weightslab/tests/chaos_monkeys_utests/test_grpc_chaos_monkey_robustness.py @@ -209,7 +209,7 @@ def run_one_call_with_watchdog(): def _invoke(): try: result_holder["value"] = wrapped.unary_unary(request={}, context=SimpleNamespace()) - except Exception as exc: # expected on first attempt + except Exception as exc: # expected on first attempt error_holder["error"] = exc worker = threading.Thread(target=_invoke, name="WL-Test-gRPC-Worker", daemon=True) diff --git a/weightslab/tests/components/test_checkpoint_workflow.py b/weightslab/tests/components/test_checkpoint_workflow.py index 1f824018..c7f003c7 100644 --- a/weightslab/tests/components/test_checkpoint_workflow.py +++ b/weightslab/tests/components/test_checkpoint_workflow.py @@ -102,7 +102,7 @@ class SimpleCNN(nn.Module): def __init__(self, conv1_out=8, conv2_out=16): super(SimpleCNN, self).__init__() - self.input_shape = (1, 1, 28, 28) # MNIST input shape + self.input_shape = (1, 1, 28, 28) # MNIST input shape self.conv1 
= nn.Conv2d(1, conv1_out, kernel_size=3, padding=1) self.pool1 = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(conv1_out, conv2_out, kernel_size=3, padding=1) @@ -272,27 +272,27 @@ def train_epochs(self, model, loader, optimizer, criterion, num_epochs, criterio def check_reproducibility(self, original_loss, reloaded_loss, original_uids=None, reloaded_uids=None, loss_tol=0.1, uids_msg=None): """Common reproducibility check for losses and UIDs""" return - # # Check reproducibility of losses and UIDs + # # Check reproducibility of losses and UIDs # if isinstance(original_loss, (list, tuple)): - # original_loss_sum = sum(original_loss)/len(original_loss) + # original_loss_sum = sum(original_loss)/len(original_loss) # else: - # original_loss_sum = original_loss + # original_loss_sum = original_loss # if isinstance(reloaded_loss, (list, tuple)): - # reloaded_loss_sum = sum(reloaded_loss)/len(reloaded_loss) + # reloaded_loss_sum = sum(reloaded_loss)/len(reloaded_loss) # else: - # reloaded_loss_sum = reloaded_loss + # reloaded_loss_sum = reloaded_loss # loss_diff = abs(original_loss_sum - reloaded_loss_sum) # loss_relative_diff = loss_diff / original_loss_sum if original_loss_sum != 0 else 0 # print(f"[OK] Loss comparison:") - # print(f" Original: {original_loss_sum:.6f}") - # print(f" Reloaded: {reloaded_loss_sum:.6f}") - # print(f" Relative difference: {loss_relative_diff*100:.3f}%") + # print(f" Original: {original_loss_sum:.6f}") + # print(f" Reloaded: {reloaded_loss_sum:.6f}") + # print(f" Relative difference: {loss_relative_diff*100:.3f}%") # self.assertLess(loss_relative_diff, loss_tol, msg=f"Training should be reproducible within {loss_tol*100:.1f}%") # if original_uids is not None and reloaded_uids is not None: - # print(f"[OK] UIDs comparison:") - # print(f" Original: {original_uids}") - # print(f" Reloaded: {reloaded_uids}") - # self.assertListEqual(reloaded_uids, original_uids, msg=uids_msg or "Sample UIDs should match for reproducibility") + # print(f"[OK] UIDs 
comparison:") + # print(f" Original: {original_uids}") + # print(f" Reloaded: {reloaded_uids}") + # self.assertListEqual(reloaded_uids, original_uids, msg=uids_msg or "Sample UIDs should match for reproducibility") @classmethod def setUpClass(cls): @@ -365,8 +365,8 @@ def setUpClass(cls): download=True, transform=transform ) - mnist_subset = Subset(full_dataset, list(range(10))) # Create subset with 10 samples - cls.dataset = TaggableDataset(mnist_subset) # Wrap in taggable dataset + mnist_subset = Subset(full_dataset, list(range(10))) # Create subset with 10 samples + cls.dataset = TaggableDataset(mnist_subset) # Wrap in taggable dataset # ================= # Initialize Logger @@ -383,7 +383,7 @@ def setUpClass(cls): # Initialize Model # ================ model = SimpleCNN(conv1_out=8, conv2_out=16) - model = register_in_ledger(model, flag="model", device=DEVICE, skip_previous_auto_load=True, compute_dependencies=False) # Compute dependencies is disabled + model = register_in_ledger(model, flag="model", device=DEVICE, skip_previous_auto_load=True, compute_dependencies=False) # Compute dependencies is disabled # ===================== # Initialize DataLoader @@ -518,7 +518,7 @@ def test_01_train_A(self): self.state['uids_a'] = uids_A # Final verbose - print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") + print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") print(f"\n[OK] TEST A PASSED - Initial training completed") # ============================= @@ -546,15 +546,15 @@ def test_02_train_B_model_change(self): # Modify model architecture # TODO (GP): Still have the pb btw model archi. 
and torch cache # 11/05/2026-11:20:11.207 DEBUG:weightslab.components.global_monitoring:__exit__: Suppressing exception: Function ConvolutionBackward0 returned an invalid gradient at index 2 - got [15] but expected shape compatible with [16] in GuardContext.__exit__ - # model.operate(0, {-1, -2, -3, -4}, 1) # Increase conv1 out channels by 2 - # model.operate(2, {-1}, 2) # Freeze fc1 layer - # model.operate(-2, {}, 3) # Freeze fc1 layer - # # model.operate(-1, {1}, 4) # Reset fc2 layer + # model.operate(0, {-1, -2, -3, -4}, 1) # Increase conv1 out channels by 2 + # model.operate(2, {-1}, 2) # Freeze fc1 layer + # model.operate(-2, {}, 3) # Freeze fc1 layer + # # model.operate(-1, {1}, 4) # Reset fc2 layer - # print(f" Conv1: 8 -> 12 channels") - # print(f" Conv2: 16 -> 15 channels") - # print(f" FC1: Frozen") - # print(f" FC2: Reset") + # print(f" Conv1: 8 -> 12 channels") + # print(f" Conv2: 16 -> 15 channels") + # print(f" FC1: Frozen") + # print(f" FC2: Reset") # Update hash here to get hash exp_hash_b, _, changed = self.chkpt_manager.update_experiment_hash(force=True) @@ -587,7 +587,7 @@ def test_02_train_B_model_change(self): # Final verbose print(f"\n[OK] TEST B PASSED - Model architecture updated") - print(f" Final model_age: {model.get_age()}") + print(f" Final model_age: {model.get_age()}") # ======================================================================== # Test: 03_train_C_hyperparams_change @@ -614,7 +614,7 @@ def test_03_train_C_hyperparams_change(self): # Change batch size new_bs = 3 self.config['data']['train_loader']['batch_size'] = new_bs - print(f" Batch size: 2 -> 4") + print(f" Batch size: 2 -> 4") # Update hash exp_hash_c, _, _ = self.chkpt_manager.update_experiment_hash() @@ -651,7 +651,7 @@ def test_03_train_C_hyperparams_change(self): # Final verbose print(f"\n[OK] TEST C PASSED - Hyperparameters updated") - print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") + print(f" Final model_age (i.e., how 
many epochs lived by the model): {model.get_age()}") # ======================================================================== # Test: 04_train_D_data_change @@ -666,8 +666,8 @@ def test_04_train_D_data_change(self): model = ledgers.get_model() # Data - dataloader = ledgers.get_dataloader() # Get dataloader - dfm = ledgers.get_dataframe() # Get dataframe manager + dataloader = ledgers.get_dataloader() # Get dataloader + dfm = ledgers.get_dataframe() # Get dataframe manager # Optimizer and criterion optimizer = ledgers.get_optimizer() @@ -686,8 +686,8 @@ def test_04_train_D_data_change(self): rows.append( { SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation - f"{SampleStatsEx.TAG.value}:ugly": True, # Random tag with 'ugly' + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + f"{SampleStatsEx.TAG.value}:ugly": True, # Random tag with 'ugly' SampleStatsEx.DISCARDED.value: bool(1 - dfm.get_df_view()[SampleStatsEx.DISCARDED.value].iloc[idx]) } ) @@ -698,8 +698,8 @@ def test_04_train_D_data_change(self): dfm.upsert_df(df_update, origin='train_loader', force_flush=True) # Changes will be pending - print(f" Added 'ugly' tag to 20 samples") - print(f" Discarded 20 samples") + print(f" Added 'ugly' tag to 20 samples") + print(f" Discarded 20 samples") # Update hash exp_hash_d, _, changed = self.chkpt_manager.update_experiment_hash() @@ -710,7 +710,7 @@ def test_04_train_D_data_change(self): self.assertNotEqual(self.state['exp_hash_c'], exp_hash_d, "Hash should be different") print("\nResuming training for 11 epochs...") - pause_controller.resume() # Pending changes to dump: data state + pause_controller.resume() # Pending changes to dump: data state loss_D, uids_D = self.train_epochs( model, dataloader, optimizer, criterion, num_epochs=self.config['training']['num_epochs'], @@ -739,7 +739,7 @@ def test_04_train_D_data_change(self): # Final verbose print(f"\n[OK] 
TEST D PASSED - Data state updated") - print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") + print(f" Final model_age (i.e., how many epochs lived by the model): {model.get_age()}") # ======================================================================== # Test: 05_train_E_reload_and_branch @@ -759,7 +759,7 @@ def test_05_train_E_reload_and_branch(self): all_hashes = self.chkpt_manager.get_all_hashes(sort_by='created') print(f"\n[OK] Found {len(all_hashes)} experiment states:") for i, entry in enumerate(all_hashes): - print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") + print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") # Reload state B (second state created) hash_a_from_manifest = self.state['exp_hash_a'] @@ -784,19 +784,19 @@ def test_05_train_E_reload_and_branch(self): if 'data' in hp_reloaded and 'train_loader' in hp_reloaded['data']: hp_reloaded['data']['train_loader']['batch_size'] = 1 old_batch_size = hp_original.get('data', {}).get('train_loader', {}).get('batch_size', 2) - print(f" Batch size: {old_batch_size} -> 1") + print(f" Batch size: {old_batch_size} -> 1") # Discard more data # Add 20 random tags with 'ugly' tagged_samples = random.sample(range(10), 1) rows = [] - dfm = ledgers.get_dataframe() # Get dataframe manager + dfm = ledgers.get_dataframe() # Get dataframe manager for idx in tagged_samples: uid, in_uid = dfm.get_df_view().index[idx] rows.append( { SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation f"{SampleStatsEx.TAG.value}:ugly": True, SampleStatsEx.DISCARDED.value: bool(1 - dfm.get_df_view(SampleStatsEx.DISCARDED.value).iloc[idx]) } @@ -846,7 +846,7 @@ def test_05_train_E_reload_and_branch(self): self.state['uids_e'] = uids_E print(f"\n[OK] TEST E PASSED - Reloaded and generate a new train 
branch successfully") - print(f" Final model_age: {model.get_age()}") + print(f" Final model_age: {model.get_age()}") # ======================================================================== # Test: 06_reload_before_model_change @@ -857,9 +857,9 @@ def test_06_reload_before_model_change(self): print("TEST 06: Reload Before Model Change - Fix Conv Size with RNG State") print(f"{'='*80}\n") - hash_A_original = self.state['exp_hash_a'] # Before model change - loss_A_original = self.state['losses_a'] # Before model change - uids_A_original = self.state['uids_a'] # Before model change + hash_A_original = self.state['exp_hash_a'] # Before model change + loss_A_original = self.state['losses_a'] # Before model change + uids_A_original = self.state['uids_a'] # Before model change print(f"Reloading state A (before model change) for verification: {hash_A_original[:16]}...") success = self.chkpt_manager.load_state(exp_hash=hash_A_original) @@ -901,7 +901,7 @@ def test_06_reload_before_model_change(self): # Fix model conv size - create new model with different architecture print("\nFixing model architecture...") model = ledgers.get_model() - # model.operate(0, {-1}, 1) # Commented; see test 2 - still have the pb btw model archi. and torch cache + # model.operate(0, {-1}, 1) # Commented; see test 2 - still have the pb btw model archi. 
and torch cache # model.operate(2, {-1}, 2) # model.operate(-2, {}, 3) # model.operate(-1, {-1 }, 4) @@ -929,9 +929,9 @@ def test_06_reload_before_model_change(self): # Compare: First batch should be same, but losses differ due to different model print(f"\n[OK] Reproducibility verified:") - print(f" Original model first batch loss: {loss_A_reloaded}") - print(f" Fixed model first batch loss: {loss_H}") - print(f" (Same RNG = same batches, different losses due to model change)") + print(f" Original model first batch loss: {loss_A_reloaded}") + print(f" Fixed model first batch loss: {loss_H}") + print(f" (Same RNG = same batches, different losses due to model change)") # Store state self.state['losses_h'] = loss_H @@ -949,7 +949,7 @@ def test_07_change_data_from_test06(self): print("TEST 07: Change Data from Test 06 - Discard More Data") print(f"{'='*80}\n") - hash_H = self.state['exp_hash_h'] # From test 06 + hash_H = self.state['exp_hash_h'] # From test 06 print(f"Starting from state H: {hash_H[:16]}...") @@ -962,7 +962,7 @@ def test_07_change_data_from_test06(self): uid, in_uid = dfm.get_df_view().index[idx] rows.append({ SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation f"{SampleStatsEx.TAG.value}:discard_25pct": True, SampleStatsEx.DISCARDED.value: True }) @@ -1004,7 +1004,7 @@ def test_08_reload_before_data_change_verify_and_modify(self): print("TEST 08: Reload Before Data Change - Verify and Modify Model") print(f"{'='*80}\n") - hash_c = self.state['exp_hash_c'] # Before data change (after HP change) + hash_c = self.state['exp_hash_c'] # Before data change (after HP change) print(f"Part A: Reloading state C and verifying training reproducibility...") print(f"Reloading state C: {hash_c[:16]}...") @@ -1068,7 +1068,7 @@ def test_09_reload_before_hp_change_verify_and_modify(self): print("TEST 09: Reload 
Before HP Change - Verify and Fix Everything") print(f"{'='*80}\n") - hash_b = self.state['exp_hash_b'] # Before HP change (after model change) + hash_b = self.state['exp_hash_b'] # Before HP change (after model change) loss_b = self.state['losses_b'] print(f"Part A: Reloading state B and verifying training reproducibility...") @@ -1100,11 +1100,11 @@ def test_09_reload_before_hp_change_verify_and_modify(self): # Fix HP hp = ledgers.get_hyperparams() - hp['data']['train_loader']['batch_size'] = 7 # Change batch size + hp['data']['train_loader']['batch_size'] = 7 # Change batch size # Fix model model = ledgers.get_model() - # # model.operate(0, {-3}, 1) # Further modify conv1 + # # model.operate(0, {-3}, 1) # Further modify conv1 # model.operate(-1, {-1 }, 4) # Fix data - discard 5 samples @@ -1115,7 +1115,7 @@ def test_09_reload_before_hp_change_verify_and_modify(self): uid, in_uid = dfm.get_df_view().index[idx] rows.append({ SampleStatsEx.SAMPLE_ID.value: uid, - SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation + SampleStatsEx.INSTANCE_ID.value: in_uid, # For simplicity, use same UID for annotation f"{SampleStatsEx.TAG.value}:discard_fix": True, SampleStatsEx.DISCARDED.value: True }) @@ -1156,7 +1156,7 @@ def test_10_reload_branch_j_verify_reproducibility(self): print("TEST 10: Reload Branch J - Verify Training Reproducibility") print(f"{'='*80}\n") - hash_j = self.state['exp_hash_j'] # From test 08.b + hash_j = self.state['exp_hash_j'] # From test 08.b print(f"Reloading branch J: {hash_j[:16]}...") @@ -1191,7 +1191,7 @@ def test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): print(f"{'='*80}\n") # Reference variables - target_hash = self.state['exp_hash_d'] # Target is branch_d + target_hash = self.state['exp_hash_d'] # Target is branch_d print(f"Simulating fresh restart: loading everything from config...") print(f"Target state: {target_hash[:16]} (branch_d)") @@ -1215,7 +1215,7 @@ def 
test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): print("[OK] Hyperparameters re-registered") # Create fresh model - model_restarted = SimpleCNN(conv1_out=8, conv2_out=16) # Match branch_d architecture + model_restarted = SimpleCNN(conv1_out=8, conv2_out=16) # Match branch_d architecture # # Model arch. and weights are updated at the init of model interface model_restarted = register_in_ledger(model_restarted, flag="model", device=DEVICE) @@ -1257,7 +1257,7 @@ def test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): all_hashes = self.chkpt_manager.get_all_hashes(sort_by='created') print(f"\n[OK] Found {len(all_hashes)} experiment states:") for i, entry in enumerate(all_hashes): - print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") + print(f" {i+1}. {entry['hash'][:16]}... (created: {entry['created'][:19]})") # Reload state B (second state created) hash_a_from_manifest = self.state['exp_hash_a'] @@ -1271,18 +1271,18 @@ def test_11_restart_from_scratch_to_hash_d_and_verify_reproducibility(self): print(f"[OK] Checkpoint loaded to reach target state {target_hash[:16]}") print("\nTraining for 11 epochs to verify reproducibility...") pause_controller.resume() - model_restarted = ledgers.get_model() # Get model after loading state + model_restarted = ledgers.get_model() # Get model after loading state _, _ = self.train_epochs(model_restarted, dataloader, optimizer_restarted, criterion, num_epochs=self.config['training']['num_epochs'], criterion_bin=criterion_bin) pause_controller.pause() # # Check reproducibility with original loss and UIDs # self.assertEqual(model_restarted.layers[-1].operation_age['FREEZE'], 1, - # "Model architecture should match state in D") + # "Model architecture should match state in D") # self.assertEqual(model_restarted.layers[-1].operation_age['RESET'], 1, - # "Model architecture should match state in D") + # "Model architecture should match state in D") # 
self.assertEqual(model_restarted.layers[0].out_neurons, 12, - # "Model architecture should match state in D") + # "Model architecture should match state in D") # Not possible as data are generated randomly without reproducibility now # self.check_reproducibility(loss_d_original, loss_d_verify, originals_uids, None, loss_tol=1e-1) diff --git a/weightslab/tests/components/test_global_monitoring_unit.py b/weightslab/tests/components/test_global_monitoring_unit.py index 07650885..ee88c790 100644 --- a/weightslab/tests/components/test_global_monitoring_unit.py +++ b/weightslab/tests/components/test_global_monitoring_unit.py @@ -38,31 +38,31 @@ def test_contextvar_set_and_restore(self): self.assertIn(get_current_context(), {Context.UNKNOWN, Context.TESTING, Context.TRAINING}) # def test_guard_context_training_non_audit(self): - # model = _DummyModel() - # gc = GuardContext(for_training=True) - # gc.model = model + # model = _DummyModel() + # gc = GuardContext(for_training=True) + # gc.model = model - # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ - # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value=None), \ - # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={}): - # gc.__enter__() - # self.assertEqual(get_current_context(), Context.TRAINING) - # self.assertIn(True, model.train_calls) - # result = gc.__exit__(None, None, None) + # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ + # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value=None), \ + # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={}): + # gc.__enter__() + # self.assertEqual(get_current_context(), Context.TRAINING) + # self.assertIn(True, model.train_calls) + # result = gc.__exit__(None, None, None) - # self.assertFalse(result) + # self.assertFalse(result) # def test_guard_context_training_audit_uses_eval(self): 
- # model = _DummyModel() - # gc = GuardContext(for_training=True) - # gc.model = model - - # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ - # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value="hp"), \ - # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={"auditorMode": True}): - # gc.__enter__() - # self.assertEqual(model.eval_calls, 1) - # gc.__exit__(None, None, None) + # model = _DummyModel() + # gc = GuardContext(for_training=True) + # gc.model = model + + # with patch("weightslab.components.global_monitoring.pause_controller.wait_if_paused"), \ + # patch("weightslab.components.global_monitoring.resolve_hp_name", return_value="hp"), \ + # patch("weightslab.components.global_monitoring.get_hyperparams", return_value={"auditorMode": True}): + # gc.__enter__() + # self.assertEqual(model.eval_calls, 1) + # gc.__exit__(None, None, None) def test_guard_context_suppresses_runtime_error(self): gc = GuardContext(for_training=False) diff --git a/weightslab/tests/data/test_data_samples_with_ops.py b/weightslab/tests/data/test_data_samples_with_ops.py index 93700647..702cd1f8 100644 --- a/weightslab/tests/data/test_data_samples_with_ops.py +++ b/weightslab/tests/data/test_data_samples_with_ops.py @@ -46,7 +46,7 @@ def __getitem__(self, idx): # Return random data with shape (3, 32, 32) to simulate images data = np.random.randn(3, 32, 32).astype(np.float32) uid = str(idx) # Consistent with string UID preference - label = idx % 10 # Simulate 10 classes + label = idx % 10 # Simulate 10 classes return data, uid, label @@ -214,7 +214,7 @@ def test_getitem_returns_data_and_id(self): # Should return tuple with (data, id, target, ...) 
self.assertIsInstance(result, tuple) - self.assertGreaterEqual(len(result), 3) # data, id, target at minimum + self.assertGreaterEqual(len(result), 3) # data, id, target at minimum # First element should be numpy array or tensor self.assertTrue(isinstance(result[0], (np.ndarray, torch.Tensor))) @@ -315,7 +315,7 @@ def test_binary_tag_labeling_single_tag(self): """ self.temp_dir = tempfile.mkdtemp() - tags_mapping = {"target_tag": 1} # Binary: only 1 tag in mapping + tags_mapping = {"target_tag": 1} # Binary: only 1 tag in mapping wrapper = DataSampleTrackingWrapper( wrapped_dataset=self.dataset, root_log_dir=self.temp_dir, diff --git a/weightslab/tests/data/test_dataframe_manager_unit.py b/weightslab/tests/data/test_dataframe_manager_unit.py index 7329d7fb..4fd4b74c 100644 --- a/weightslab/tests/data/test_dataframe_manager_unit.py +++ b/weightslab/tests/data/test_dataframe_manager_unit.py @@ -99,7 +99,7 @@ def test_enqueue_instance_batch_buffers_records(self): self.assertEqual(rec["signal:bbox_loss"], 0.2) self.assertEqual(rec[SampleStats.Ex.LAST_SEEN.value], 3) self.assertIn(SampleStats.Ex.TARGET.value, rec) - self.assertTrue(mgr._df.empty) # df untouched until flush + self.assertTrue(mgr._df.empty) # df untouched until flush def test_flush_applies_instance_records(self): """Flushing instance records writes per-(sample_id, annotation_id) values.""" @@ -119,7 +119,7 @@ def test_flush_applies_instance_records(self): mgr.flush() result = mgr.get_df_view() - self.assertEqual(len(result), 4) # sample row (0) + 3 instance rows + self.assertEqual(len(result), 4) # sample row (0) + 3 instance rows self.assertAlmostEqual(float(result.loc[("1", 1), "signal:il"]), 0.5) self.assertAlmostEqual(float(result.loc[("1", 2), "signal:il"]), 0.6) self.assertAlmostEqual(float(result.loc[("1", 3), "signal:il"]), 0.7) @@ -178,9 +178,9 @@ def test_multi_instance_expansion(self): # Single sample with 3 instances (detections/annotations) # Use list of arrays to indicate multiple instances 
target = [ - np.array([10, 20, 30, 40]), # instance 0 - np.array([50, 60, 70, 80]), # instance 1 - np.array([90, 100, 110, 120]) # instance 2 + np.array([10, 20, 30, 40]), # instance 0 + np.array([50, 60, 70, 80]), # instance 1 + np.array([90, 100, 110, 120]) # instance 2 ] df = pd.DataFrame([{ "sample_id": 1, @@ -296,7 +296,7 @@ def test_categorical_memory_optimization(self): self.assertTrue((result_df["metadata"] == "urban").sum() > 0) # Memory usage comparison - # original_bytes = 100 * (len("train") + len("urban")) # Rough estimate + # original_bytes = 100 * (len("train") + len("urban")) # Rough estimate # With categorical: ~100 bytes for codes + ~40 bytes for categories = ~140 bytes # Real compression achieved by pandas @@ -357,7 +357,7 @@ def test_per_sample_buffer_into_multi_index_does_not_corrupt(self): losses={"signals//train/clsf_sample": np.array([0.99])}, step=11, ) - mgr.flush() # Would raise if bug regressed + mgr.flush() # Would raise if bug regressed result = mgr.get_df_view() self.assertAlmostEqual(result.loc[("1", 0), col], 0.99) @@ -385,7 +385,7 @@ def get_index_from_sample_id(self, sid): "origin": "train", SampleStats.Ex.TARGET.value: np.zeros((30, 30), dtype=np.float32), }) - row.name = ("12", 0) # MultiIndex-style row.name + row.name = ("12", 0) # MultiIndex-style row.name # Should not raise and should pass just the sample_id, not the tuple mgr._normalize_arrays_for_storage(row) diff --git a/weightslab/tests/data/test_flush_pipeline.py b/weightslab/tests/data/test_flush_pipeline.py index 0398876b..a6001c5b 100644 --- a/weightslab/tests/data/test_flush_pipeline.py +++ b/weightslab/tests/data/test_flush_pipeline.py @@ -25,7 +25,7 @@ def _make_mgr(flush_max_rows=4, enable_flushing_threads=False) -> LedgeredDataFrameManager: mgr = LedgeredDataFrameManager( - flush_interval=60.0, # disable periodic timer during tests + flush_interval=60.0, # disable periodic timer during tests flush_max_rows=flush_max_rows, 
enable_flushing_threads=enable_flushing_threads, enable_h5_persistence=False, @@ -125,7 +125,7 @@ class TestFlushAsyncReturnsAfterBufferDrain(unittest.TestCase): """ def test_flush_async_does_not_wait_for_h5(self): - H5_WRITE_DELAY = 1.0 # seconds — intentionally slow + H5_WRITE_DELAY = 1.0 # seconds — intentionally slow mgr = _make_mgr(flush_max_rows=4, enable_flushing_threads=True) @@ -166,7 +166,7 @@ class TestBufferRefillDuringH5Write(unittest.TestCase): def test_training_resumes_after_second_drain(self): FLUSH_MAX = 4 - H5_WRITE_DELAY = 0.3 # seconds + H5_WRITE_DELAY = 0.3 # seconds mgr = _make_mgr(flush_max_rows=FLUSH_MAX, enable_flushing_threads=True) @@ -181,9 +181,9 @@ def counting_slow_h5(*args, **kwargs): second_enqueue_returned = threading.Event() def training_sim(): - _enqueue(mgr, [str(i) for i in range(FLUSH_MAX)]) # fills buffer, triggers flush - time.sleep(0.05) # let flush thread start H5 write - _enqueue(mgr, [str(i) for i in range(FLUSH_MAX, FLUSH_MAX * 2)]) # refill + _enqueue(mgr, [str(i) for i in range(FLUSH_MAX)]) # fills buffer, triggers flush + time.sleep(0.05) # let flush thread start H5 write + _enqueue(mgr, [str(i) for i in range(FLUSH_MAX, FLUSH_MAX * 2)]) # refill second_enqueue_returned.set() with patch.object(mgr, "_flush_to_h5_if_needed", side_effect=counting_slow_h5): diff --git a/weightslab/tests/data/test_h5_array_store.py b/weightslab/tests/data/test_h5_array_store.py index 0307b267..5e05bc47 100644 --- a/weightslab/tests/data/test_h5_array_store.py +++ b/weightslab/tests/data/test_h5_array_store.py @@ -210,7 +210,7 @@ def test_clean_write_leaves_no_temp_or_backup(self): def test_recover_safe_on_empty_directory(self): """recover() must not raise when arrays.h5 does not exist yet.""" store = self._make_store() - store.recover() # Should complete without error + store.recover() # Should complete without error if __name__ == "__main__": diff --git a/weightslab/tests/data/test_h5_dataframe_store.py 
b/weightslab/tests/data/test_h5_dataframe_store.py index 94f38e52..3c6457a9 100644 --- a/weightslab/tests/data/test_h5_dataframe_store.py +++ b/weightslab/tests/data/test_h5_dataframe_store.py @@ -108,8 +108,8 @@ def test_categorical_tags_preservation(self): df = pd.DataFrame({ 'sample_id': [1, 2, 3], 'brightness': [0.75, 0.82, 0.65], - 'tag:quality': ['high', 'low', 'high'], # String tag - 'tag:outdoor': [True, False, True], # Boolean tag + 'tag:quality': ['high', 'low', 'high'], # String tag + 'tag:outdoor': [True, False, True], # Boolean tag }).set_index('sample_id') # Write (should optimize to categorical) @@ -194,7 +194,7 @@ def test_upsert_merge_multi_index(self): # Update with new data for same sample but different annotation df2 = pd.DataFrame({ - 'brightness': [0.80], # Update brightness for annotation 1 + 'brightness': [0.80], # Update brightness for annotation 1 'iou': [0.60], }) df2.index = pd.MultiIndex.from_arrays( diff --git a/weightslab/tests/data/test_point_cloud_utils.py b/weightslab/tests/data/test_point_cloud_utils.py index 8dac4a6b..62243a9d 100644 --- a/weightslab/tests/data/test_point_cloud_utils.py +++ b/weightslab/tests/data/test_point_cloud_utils.py @@ -65,7 +65,7 @@ def test_is_point_cloud_task(): def test_is_point_cloud_detection_task(): assert is_point_cloud_detection_task("detection_pointcloud") assert is_point_cloud_detection_task("Detection_PointCloud") - assert is_point_cloud_detection_task("detection_3d") # legacy alias + assert is_point_cloud_detection_task("detection_3d") # legacy alias assert not is_point_cloud_detection_task("detection") assert not is_point_cloud_detection_task("segmentation") assert not is_point_cloud_detection_task(None) @@ -77,10 +77,10 @@ def test_looks_like_point_cloud(): assert looks_like_point_cloud(_cloud()[:, :2]) # Multi-channel clouds (xyz + intensity + normals + rgb = 10 cols) qualify. 
assert looks_like_point_cloud(np.zeros((100, 10), np.float32)) - assert not looks_like_point_cloud(_cloud()[:8]) # too few rows - assert not looks_like_point_cloud(np.zeros((100, 20), np.float32)) # too many cols - assert not looks_like_point_cloud(np.zeros((64, 64), np.uint8)) # int image - assert not looks_like_point_cloud(np.zeros((64, 64, 3), np.float32)) # 3D array + assert not looks_like_point_cloud(_cloud()[:8]) # too few rows + assert not looks_like_point_cloud(np.zeros((100, 20), np.float32)) # too many cols + assert not looks_like_point_cloud(np.zeros((64, 64), np.uint8)) # int image + assert not looks_like_point_cloud(np.zeros((64, 64, 3), np.float32)) # 3D array def test_point_distances(): @@ -97,7 +97,7 @@ def test_compute_point_normals_planar(): normals = compute_point_normals(pts, k=12) assert normals.shape == (500, 3) np.testing.assert_allclose(np.linalg.norm(normals, axis=1), 1.0, atol=1e-4) - assert np.abs(normals[:, 2]).mean() > 0.95 # mostly aligned with z + assert np.abs(normals[:, 2]).mean() > 0.95 # mostly aligned with z def test_voxel_downsample_reduces_points(): @@ -106,12 +106,12 @@ def test_voxel_downsample_reduces_points(): out = voxel_downsample(pts, voxel_size=0.25) assert out.shape[1] == 4 assert out.shape[0] < pts.shape[0] - assert out.shape[0] <= 4 ** 3 # at most one point per 0.25 voxel in the unit cube + assert out.shape[0] <= 4 ** 3 # at most one point per 0.25 voxel in the unit cube def test_colorize_from_image(): image = np.zeros((10, 20, 3), np.uint8) - image[:, :, 0] = 255 # all red + image[:, :, 0] = 255 # all red pts = np.array([[1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], np.float32) def project(p): @@ -119,15 +119,15 @@ def project(p): return uv, np.array([True, False]) rgb = colorize_from_image(pts, image, project) - np.testing.assert_allclose(rgb[0], [1.0, 0.0, 0.0], atol=1e-5) # sampled red - np.testing.assert_allclose(rgb[1], [0.5, 0.5, 0.5], atol=1e-5) # invalid -> grey + np.testing.assert_allclose(rgb[0], [1.0, 0.0, 0.0], 
atol=1e-5) # sampled red + np.testing.assert_allclose(rgb[1], [0.5, 0.5, 0.5], atol=1e-5) # invalid -> grey def test_range_image_shape(): img = point_cloud_to_range_image(_cloud(2000), image_height=48, image_width=256) assert img.size == (256, 48) arr = np.asarray(img) - assert (arr != arr[0, 0]).any() # some points were projected + assert (arr != arr[0, 0]).any() # some points were projected def test_get_point_feature_names_from_dataset_and_default(): @@ -151,7 +151,7 @@ def my_thumb(points): img = render_thumbnail_2d_for_dataset(object(), _cloud()) assert marker["called"] and np.asarray(img)[0, 0, 0] == 9 finally: - register_thumbnail_fn(None) # reset global state + register_thumbnail_fn(None) # reset global state def my_boxes(boxes): return np.zeros((len(boxes), 6), np.float32) @@ -167,7 +167,7 @@ def my_boxes(boxes): def test_filter_valid_points_drops_pads_and_nonfinite(): pts = _cloud(100) - pts[10] = -1000.0 # pad row (all coords at PAD_VALUE) + pts[10] = -1000.0 # pad row (all coords at PAD_VALUE) pts[20, 2] = np.nan out = filter_valid_points(pts) assert out.shape[0] == 98 @@ -234,7 +234,7 @@ def test_project_boxes_min_size_clamp(): def test_project_boxes_2d_rows(): - boxes = np.array([[10.0, 5.0, 2.0, 2.0, 2.0, 0.7]], np.float32) # cx,cy,dx,dy,cls,conf + boxes = np.array([[10.0, 5.0, 2.0, 2.0, 2.0, 0.7]], np.float32) # cx,cy,dx,dy,cls,conf assert boxes_dimensionality(boxes) == 2 bev = project_boxes_to_bev(boxes, PC_RANGE, 0.0) assert bev[0, 4] == 2.0 @@ -311,7 +311,7 @@ def get_items(self, idx, include_metadata=False, include_labels=False, include_i np_img, is_volumetric, shape, pil = load_raw_image_array(PcDataset(), 0) assert not is_volumetric assert pil is not None and pil.mode == "RGB" - assert pil.size[0] == pil.size[1] # square BEV render + assert pil.size[0] == pil.size[1] # square BEV render assert np_img.ndim == 3 and np_img.shape[2] == 3 diff --git a/weightslab/tests/gRPC/test_get_point_cloud.py b/weightslab/tests/gRPC/test_get_point_cloud.py 
index 1e8ba1ba..e74f6629 100644 --- a/weightslab/tests/gRPC/test_get_point_cloud.py +++ b/weightslab/tests/gRPC/test_get_point_cloud.py @@ -3,7 +3,11 @@ import weightslab.proto.experiment_service_pb2 as pb2 -from weightslab.trainer.services.data_service import DataService +from weightslab.trainer.services.data_service import ( + DataService, + _DEFAULT_POINT_CLOUD_CHUNK_BYTES, + _point_cloud_chunk_bytes, +) PC_RANGE = (0.0, -32.0, -3.0, 64.0, 32.0, 1.0) @@ -84,6 +88,40 @@ def test_get_point_cloud_unknown_sample_fails_gracefully(): assert "not found" in chunks[0].message +def test_point_cloud_chunk_bytes_default(monkeypatch): + monkeypatch.delenv("WL_POINT_CLOUD_CHUNK_BYTES", raising=False) + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES == (1 << 20) + + +def test_point_cloud_chunk_bytes_env_override(monkeypatch): + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "4096") + assert _point_cloud_chunk_bytes() == 4096 + + +def test_point_cloud_chunk_bytes_invalid_falls_back(monkeypatch): + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "not-a-number") + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "0") + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES + monkeypatch.setenv("WL_POINT_CLOUD_CHUNK_BYTES", "-10") + assert _point_cloud_chunk_bytes() == _DEFAULT_POINT_CLOUD_CHUNK_BYTES + + +def test_get_point_cloud_honours_configured_chunk_size(): + """A smaller chunk size splits the same cloud into more (correct) messages.""" + class _SmallChunkService(_StubService): + _POINT_CLOUD_CHUNK_BYTES = 4096 # bytes + + stub = _SmallChunkService(_FakeLidarDataset()) + chunks = _collect(stub, pb2.PointCloudRequest(sample_id="7", origin="train_loader")) + + total_bytes = 50_000 * 4 * 4 + assert len(chunks) > 1 + assert all(len(c.data) <= 4096 for c in chunks) + assert chunks[0].total_chunks == len(chunks) + assert sum(len(c.data) for c in chunks) == total_bytes + 
+ def test_get_point_cloud_non_pointcloud_sample_fails_gracefully(): class ImgDataset(_FakeLidarDataset): def get_items(self, idx, **kwargs): diff --git a/weightslab/tests/gRPC/test_grpc_tag_operations.py b/weightslab/tests/gRPC/test_grpc_tag_operations.py index 0f94643a..b3edc63e 100644 --- a/weightslab/tests/gRPC/test_grpc_tag_operations.py +++ b/weightslab/tests/gRPC/test_grpc_tag_operations.py @@ -164,7 +164,7 @@ def setUpClass(cls): flag="model", device=DEVICE, skip_previous_auto_load=True, - compute_dependencies=False, # dependency analysis is currently disabled + compute_dependencies=False, # dependency analysis is currently disabled ) # Register dataloader @@ -269,7 +269,7 @@ def test_01_add_tags_accumulate(self): response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to add tags: {response.message}") - print(f"✓ Successfully added tag 'test_tag' to 10 samples") + print(f" Successfully added tag 'test_tag' to 10 samples") # Verify tags were added by checking the dataframe df = self.data_service._all_datasets_df @@ -300,7 +300,7 @@ def test_01_add_tags_accumulate(self): self.assertTrue(value, f"Sample {sample_id} should have tag 'test_tag'") - print(f"✓ Verified tag column exists and has correct values") + print(f" Verified tag column exists and has correct values") def test_02_add_multiple_tags(self): """Test adding multiple different tags""" @@ -319,7 +319,7 @@ def test_02_add_multiple_tags(self): response1 = self.data_service.EditDataSample(request1, self.mock_context) self.assertTrue(response1.success) - print(f"✓ Added tag 'difficult' to samples 0-4") + print(f" Added tag 'difficult' to samples 0-4") # Add "outlier" tag to samples 5-9 request2 = pb2.DataEditsRequest( @@ -334,13 +334,13 @@ def test_02_add_multiple_tags(self): response2 = self.data_service.EditDataSample(request2, self.mock_context) self.assertTrue(response2.success) - print(f"✓ Added tag 'outlier' to samples 5-9") + print(f" 
Added tag 'outlier' to samples 5-9") # Verify both tags exist df = self.data_service._all_datasets_df self.assertIn("tag:difficult", df.columns) self.assertIn("tag:outlier", df.columns) - print(f"✓ Both tag columns exist in dataframe") + print(f" Both tag columns exist in dataframe") def test_03_remove_tag_from_samples(self): """Test removing a tag from specific samples using EDIT_REMOVE""" @@ -359,7 +359,7 @@ def test_03_remove_tag_from_samples(self): response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to remove tag: {response.message}") - print(f"✓ Removed tag 'test_tag' from samples 0-4") + print(f" Removed tag 'test_tag' from samples 0-4") # Verify tag was removed from those samples df = self.data_service._all_datasets_df @@ -408,7 +408,7 @@ def test_03_remove_tag_from_samples(self): self.assertTrue(value, f"Sample {sample_id} should still have tag 'test_tag'") - print(f"✓ Verified tag removal worked correctly") + print(f" Verified tag removal worked correctly") def test_04_delete_entire_tag_column(self): """Test deleting an entire tag column using EDIT_REMOVE with value=-1""" @@ -417,22 +417,22 @@ def test_04_delete_entire_tag_column(self): # Delete the "difficult" tag column completely request = pb2.DataEditsRequest( stat_name="tag:difficult", - float_value=-1, # Signal for column deletion + float_value=-1, # Signal for column deletion string_value="", bool_value=False, type=SampleEditType.EDIT_REMOVE, - samples_ids=["0"], # Just need one sample as reference + samples_ids=["0"], # Just need one sample as reference sample_origins=["test"] ) response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to delete tag column: {response.message}") - print(f"✓ Deleted entire 'difficult' tag column") + print(f" Deleted entire 'difficult' tag column") # Verify column no longer exists df = self.data_service._all_datasets_df 
self.assertNotIn("tag:difficult", df.columns, "Tag column should be deleted") - print(f"✓ Verified tag column no longer exists in dataframe") + print(f" Verified tag column no longer exists in dataframe") def test_05_deny_listed_operations(self): """Test discarded (discard/restore) operations""" @@ -455,7 +455,7 @@ def test_05_deny_listed_operations(self): stat_name=SampleStatsEx.DISCARDED.value, float_value=0, string_value="", - bool_value=True, # True = discarded + bool_value=True, # True = discarded type=SampleEditType.EDIT_OVERRIDE, samples_ids=sample_ids, sample_origins=origins @@ -463,7 +463,7 @@ def test_05_deny_listed_operations(self): response = self.data_service.EditDataSample(request_discard, self.mock_context) self.assertTrue(response.success, f"Failed to discard samples: {response.message}") - print(f"✓ Marked samples 10-14 as discarded") + print(f" Marked samples 10-14 as discarded") # Verify samples are marked as discarded df = self.data_service._all_datasets_df @@ -488,14 +488,14 @@ def test_05_deny_listed_operations(self): self.assertTrue(value, f"Sample {sample_id} should be discarded") - print(f"✓ Verified samples are discarded") + print(f" Verified samples are discarded") # Now restore samples 10-12 request_restore = pb2.DataEditsRequest( stat_name=SampleStatsEx.DISCARDED.value, float_value=0, string_value="", - bool_value=False, # False = restored + bool_value=False, # False = restored type=SampleEditType.EDIT_OVERRIDE, samples_ids=[str(i) for i in range(10, 13)], sample_origins=["test"] * 3 @@ -503,7 +503,7 @@ def test_05_deny_listed_operations(self): response = self.data_service.EditDataSample(request_restore, self.mock_context) self.assertTrue(response.success, f"Failed to restore samples: {response.message}") - print(f"✓ Restored samples 10-12") + print(f" Restored samples 10-12") # Verify restoration df = self.data_service._all_datasets_df @@ -550,7 +550,7 @@ def test_05_deny_listed_operations(self): self.assertTrue(value, f"Sample 
{sample_id} should still be discarded") - print(f"✓ Verified restoration worked correctly") + print(f" Verified restoration worked correctly") def test_06_batch_tag_operations(self): """Test batch operations on many samples at once""" @@ -583,7 +583,7 @@ def test_06_batch_tag_operations(self): response = self.data_service.EditDataSample(request, self.mock_context) self.assertTrue(response.success, f"Failed to add batch tag: {response.message}") - print(f"✓ Added 'batch_tag' to 50 samples in one operation") + print(f" Added 'batch_tag' to 50 samples in one operation") # Verify all 50 samples have the tag df = self.data_service._all_datasets_df @@ -614,7 +614,7 @@ def test_06_batch_tag_operations(self): success_count += 1 self.assertGreaterEqual(success_count, 45, f"Expected at least 45 samples to have batch_tag, got {success_count}") - print(f"✓ Verified {success_count}/50 samples have the batch tag") + print(f" Verified {success_count}/50 samples have the batch tag") def test_07_tag_persistence(self): """Test that tags persist and can be queried""" @@ -624,14 +624,14 @@ def test_07_tag_persistence(self): # Count tag columns tag_columns = [col for col in df.columns if col.startswith(f"{SampleStatsEx.TAG.value}:")] - print(f"✓ Found {len(tag_columns)} tag columns: {tag_columns}") + print(f" Found {len(tag_columns)} tag columns: {tag_columns}") self.assertGreater(len(tag_columns), 0, "Should have at least one tag column") # Verify we can query tagged samples for tag_col in tag_columns: tagged_samples = df[df[tag_col] == True] - print(f" - {tag_col}: {len(tagged_samples)} samples") + print(f" - {tag_col}: {len(tagged_samples)} samples") self.assertGreaterEqual(len(tagged_samples), 0) @@ -661,14 +661,14 @@ def run_tests(): print("="*80 + "\n") if result.wasSuccessful(): - print("✅ ALL TESTS PASSED!") + print(" ALL TESTS PASSED!") else: - print("❌ SOME TESTS FAILED") + print(" SOME TESTS FAILED") if result.failures: print("\nFailures:") for test, traceback in 
result.failures: - print(f" - {test}: {traceback}") + print(f" - {test}: {traceback}") if result.errors: print("\nErrors:") for test, traceback in result.errors: - print(f" - {test}: {traceback}") + print(f" - {test}: {traceback}") diff --git a/weightslab/tests/gRPC/test_grpc_user_actions.py b/weightslab/tests/gRPC/test_grpc_user_actions.py index 05bfc639..1b71da74 100644 --- a/weightslab/tests/gRPC/test_grpc_user_actions.py +++ b/weightslab/tests/gRPC/test_grpc_user_actions.py @@ -280,11 +280,15 @@ def _make_real_data_service(self): # the first call proceeds (mirrors DataService.__init__). ds._update_done = threading.Event() ds._update_done.set() + ds._refresh_in_flight = threading.Lock() # mirrors __init__: bg view-refresh guard ds._df_manager = df_manager ds._all_datasets_df = df.copy() ds._compute_natural_sort = False ds._is_filtered = False ds._last_internals_update_time = 0 + # Thread pool used by GetDataSamples' per-sample path (GetMetaData uses the + # vectorized path and doesn't need it, but GetDataSamples does). + ds._data_executor = ThreadPoolExecutor(max_workers=2) ds._agent = MagicMock() ds._agent.is_ollama_available.return_value = True ds.audit_logger = MagicMock() @@ -426,9 +430,9 @@ def test_grpc_apply_data_query_direct_filter_reduces_view(self): kept = [idx[1] for idx in data_service._all_datasets_df.index.tolist()] self.assertEqual(sorted(kept), ["2", "3"]) - def test_grpc_get_data_samples_returns_scalar_stats(self): - """GetDataSamples must stream per-sample records with requested scalar stats - (no raw images needed for this path).""" + def test_grpc_get_data_samples_excludes_metadata(self): + """GetDataSamples returns image / label / prediction data only — metadata + columns (e.g. 
'loss') are now served exclusively by GetMetaData.""" data_service, _ = self._make_real_data_service() servicer = self._make_servicer_with_real_data_service(data_service) @@ -444,10 +448,42 @@ def test_grpc_get_data_samples_returns_scalar_stats(self): self.assertEqual(len(response.data_records), 3) returned_ids = {r.sample_id for r in response.data_records} self.assertEqual(returned_ids, {"1", "2", "3"}) - # The requested 'loss' stat is present on each record. for rec in response.data_records: + names = {s.name for s in rec.data_stats} + # The 'loss' metadata stat must NOT leak through GetDataSamples anymore. + self.assertNotIn("loss", names) + # Rendering flags (origin/task_type/discarded) still travel with image data. + self.assertIn("origin", names) + self.assertIn("discarded", names) + + def test_grpc_get_metadata_returns_names_records_and_modal(self): + """GetMetaData returns whole-dataset column names, per-sample grid metadata + for the slice, and the open modal sample's metadata.""" + data_service, _ = self._make_real_data_service() + servicer = self._make_servicer_with_real_data_service(data_service) + + request = pb2.GetMetaDataRequest( + start_index=0, + records_cnt=10, + modal_sample_id="2", + ) + response = servicer.GetMetaData(request, _MockContext()) + + self.assertTrue(response.success) + # All metadata column names for the whole dataset include 'loss'. + self.assertIn("loss", list(response.all_metadata_names)) + # Grid records cover the slice and carry the 'loss' metadata stat. + self.assertEqual(len(response.grid_records), 3) + returned_ids = {r.sample_id for r in response.grid_records} + self.assertEqual(returned_ids, {"1", "2", "3"}) + for rec in response.grid_records: names = {s.name for s in rec.data_stats} self.assertIn("loss", names) + # Modal record resolves the requested sample_id with its metadata. 
+ self.assertTrue(response.HasField("modal_record")) + self.assertEqual(response.modal_record.sample_id, "2") + modal_names = {s.name for s in response.modal_record.data_stats} + self.assertIn("loss", modal_names) class TestGRPCLoggerOutputIntegration(_TimeoutMixin, unittest.TestCase): @@ -496,16 +532,22 @@ def test_break_by_slices_from_tags_filters_expected_sample(self): {"tag:hard": [True, False]}, index=[11, 12], ) - # break-by-slices reads compact (sample_id, step, value, hash) tuples via - # query_per_sample (filtered by the tag-derived sample_ids), then aggregates - # the matching samples into a single MEAN curve per experiment_hash. + # break-by-slices aggregates the tag-derived sample_ids into a single MEAN + # curve per experiment_hash via aggregate_per_sample_by_step. _pts = [("11", 5, 0.2, "exp-1"), ("12", 5, 0.8, "exp-1")] - def _qps(graph_name, sample_ids=None, exp_hash=None): + def _agg(graph_name, sample_ids=None, exp_hash=None): wanted = {str(s) for s in sample_ids} if sample_ids is not None else None - return [t for t in _pts if wanted is None or str(t[0]) in wanted] + rows = [t for t in _pts if wanted is None or str(t[0]) in wanted] + by_hash: dict = {} + for sid, step, val, h in rows: + by_hash.setdefault(h, {}).setdefault(step, []).append(val) + return { + h: sorted((s, sum(v) / len(v)) for s, v in steps.items()) + for h, steps in by_hash.items() + } - signal_logger.query_per_sample.side_effect = _qps + signal_logger.aggregate_per_sample_by_step.side_effect = _agg signal_logger.get_evaluation_marker_hashes.return_value = [] servicer = ExperimentServiceServicer(exp_service=exp_service) @@ -519,7 +561,7 @@ def _qps(graph_name, sample_ids=None, exp_hash=None): # Only sample 11 is 'hard'-tagged → mean curve over {11} = one aggregated point. 
self.assertEqual(len(response.points), 1) - self.assertEqual(response.points[0].sample_id, "") # aggregated mean curve + self.assertEqual(response.points[0].sample_id, "") # aggregated mean curve self.assertEqual(response.points[0].metric_name, "test/loss") self.assertAlmostEqual(response.points[0].metric_value, 0.2, places=5) diff --git a/weightslab/tests/general/test_cli.py b/weightslab/tests/general/test_cli.py index 0d34322c..c1d2a9dc 100644 --- a/weightslab/tests/general/test_cli.py +++ b/weightslab/tests/general/test_cli.py @@ -112,7 +112,7 @@ def test_empty_command(self): """Test that empty command returns ok.""" result = _handle_command('') self.assertTrue(result['ok']) - result = _handle_command(' ') + result = _handle_command(' ') self.assertTrue(result['ok']) def test_unknown_command(self): @@ -183,7 +183,7 @@ def test_plot_model_with_model(self): """Test plot_model with registered model.""" # Create a mock model with __str__ method mock_model = MagicMock() - mock_model.__str__ = MagicMock(return_value="Model(\n Layer1\n Layer2\n)") + mock_model.__str__ = MagicMock(return_value="Model(\n Layer1\n Layer2\n)") GLOBAL_LEDGER.register_model(mock_model, name='test_model') @@ -339,90 +339,90 @@ def test_cli_serve_port_binding(self): # TODO (GP): Fix CLI initialization takes too long for integration tests - need to ensure server is fully ready before client tests run, and possibly optimize server startup time for testing purposes # Not working yet - needs check first initialization and teardown of server between tests, and some tweaks to client connection logic to ensure it waits for server to be ready before connecting # class TestCLIIntegration(unittest.TestCase): -# """Integration tests for CLI server-client communication.""" - -# @classmethod -# def setUpClass(cls): -# """Start CLI server for integration tests.""" -# cls.server_info = cli_serve(cli_host='127.0.0.1', cli_port=0, spawn_client=False) -# if not cls.server_info['ok']: -# raise 
RuntimeError("Failed to start CLI server for integration tests") -# time.sleep(0.2) # Give server time to fully start - -# @classmethod -# def tearDownClass(cls): -# """Stop CLI server after integration tests.""" -# global _server_sock -# if _server_sock: -# try: -# _server_sock.close() -# except Exception: -# pass - -# def _send_command(self, cmd: str) -> dict: -# """Helper to send command to server and get response.""" -# sock = socket.create_connection( -# (self.server_info['host'], self.server_info['port']), -# timeout=5 -# ) -# f = sock.makefile('rwb') - -# # Send command -# f.write((cmd + '\n').encode('utf8')) -# f.flush() - -# # Read response -# response_line = f.readline() -# response = json.loads(response_line.decode('utf8')) - -# f.close() -# sock.close() - -# return response - -# def test_integration_help(self): -# """Test help command through server.""" -# response = self._send_command('help') -# self.assertTrue(response['ok']) -# self.assertIn('commands', response) - -# def test_integration_status(self): -# """Test status command through server.""" -# response = self._send_command('status') -# self.assertTrue(response['ok']) -# self.assertIn('snapshot', response) - -# def test_integration_list_models(self): -# """Test list_models through server.""" -# response = self._send_command('list_models') -# self.assertTrue(response['ok']) -# self.assertIn('models', response) - -# def test_integration_unknown_command(self): -# """Test unknown command through server.""" -# response = self._send_command('invalid_command_xyz') -# self.assertFalse(response['ok']) -# self.assertIn('error', response) - -# def test_integration_quit(self): -# """Test quit command closes connection.""" -# sock = socket.create_connection( -# (self.server_info['host'], self.server_info['port']), -# timeout=5 -# ) -# f = sock.makefile('rwb') - -# # Send quit -# f.write(b'quit\n') -# f.flush() - -# # Read goodbye -# response = json.loads(f.readline().decode('utf8')) -# 
self.assertTrue(response['ok']) -# self.assertTrue(response.get('bye')) - -# f.close() -# sock.close() +# """Integration tests for CLI server-client communication.""" + +# @classmethod +# def setUpClass(cls): +# """Start CLI server for integration tests.""" +# cls.server_info = cli_serve(cli_host='127.0.0.1', cli_port=0, spawn_client=False) +# if not cls.server_info['ok']: +# raise RuntimeError("Failed to start CLI server for integration tests") +# time.sleep(0.2) # Give server time to fully start + +# @classmethod +# def tearDownClass(cls): +# """Stop CLI server after integration tests.""" +# global _server_sock +# if _server_sock: +# try: +# _server_sock.close() +# except Exception: +# pass + +# def _send_command(self, cmd: str) -> dict: +# """Helper to send command to server and get response.""" +# sock = socket.create_connection( +# (self.server_info['host'], self.server_info['port']), +# timeout=5 +# ) +# f = sock.makefile('rwb') + +# # Send command +# f.write((cmd + '\n').encode('utf8')) +# f.flush() + +# # Read response +# response_line = f.readline() +# response = json.loads(response_line.decode('utf8')) + +# f.close() +# sock.close() + +# return response + +# def test_integration_help(self): +# """Test help command through server.""" +# response = self._send_command('help') +# self.assertTrue(response['ok']) +# self.assertIn('commands', response) + +# def test_integration_status(self): +# """Test status command through server.""" +# response = self._send_command('status') +# self.assertTrue(response['ok']) +# self.assertIn('snapshot', response) + +# def test_integration_list_models(self): +# """Test list_models through server.""" +# response = self._send_command('list_models') +# self.assertTrue(response['ok']) +# self.assertIn('models', response) + +# def test_integration_unknown_command(self): +# """Test unknown command through server.""" +# response = self._send_command('invalid_command_xyz') +# self.assertFalse(response['ok']) +# self.assertIn('error', 
response) + +# def test_integration_quit(self): +# """Test quit command closes connection.""" +# sock = socket.create_connection( +# (self.server_info['host'], self.server_info['port']), +# timeout=5 +# ) +# f = sock.makefile('rwb') + +# # Send quit +# f.write(b'quit\n') +# f.flush() + +# # Read goodbye +# response = json.loads(f.readline().decode('utf8')) +# self.assertTrue(response['ok']) +# self.assertTrue(response.get('bye')) + +# f.close() +# sock.close() def run_tests(): diff --git a/weightslab/tests/general/test_signals.py b/weightslab/tests/general/test_signals.py index 4c6e5893..0d32127d 100644 --- a/weightslab/tests/general/test_signals.py +++ b/weightslab/tests/general/test_signals.py @@ -57,7 +57,7 @@ def test_signals_with_list_batch_ids(self): mock_gm.return_value = mock_model with patch("weightslab.src.DATAFRAME_M", mock_df): - batch_ids = [20, 21, 22] # list instead of tensor + batch_ids = [20, 21, 22] # list instead of tensor signals = {"loss": 0.3} wl.save_signals( @@ -84,7 +84,7 @@ def test_signals_with_scalar_values(self): with patch("weightslab.src.DATAFRAME_M", mock_df): batch_ids = torch.tensor([30, 31]) signals = { - "loss": 0.25, # scalar float + "loss": 0.25, # scalar float "accuracy": 0.95, "f1": np.float32(0.92) } @@ -462,8 +462,8 @@ def test_save_signals_batch_processing(self, mock_gm, mock_get_dataframe): wl.save_signals( signals={"det_loss": torch.tensor(0.2)}, batch_ids=torch.tensor([6, 7]), - preds=torch.rand((2, 5, 4)), # 5 boxes, 4 coords - targets=torch.rand((2, 4, 4)), # 4 boxes, 4 coords + preds=torch.rand((2, 5, 4)), # 5 boxes, 4 coords + targets=torch.rand((2, 4, 4)), # 4 boxes, 4 coords log=True ) @@ -514,7 +514,7 @@ def test_signals_with_none_batch_ids(self): wl.save_signals( signals=signals, batch_ids=None, - log=False # Don't log without IDs + log=False # Don't log without IDs ) def test_signal_with_mixed_data_types(self): @@ -680,11 +680,11 @@ def test_detection_signals_with_variable_boxes(self, mock_gm, 
mock_get_dataframe # Variable number of boxes: img1 has 3, img2 has 1, img3 has 5, img4 has 2 preds = [ - torch.tensor([[10, 20, 110, 120], [50, 60, 150, 160], [200, 210, 300, 310]]), # 3 boxes - torch.tensor([[15, 25, 115, 125]]), # 1 box + torch.tensor([[10, 20, 110, 120], [50, 60, 150, 160], [200, 210, 300, 310]]), # 3 boxes + torch.tensor([[15, 25, 115, 125]]), # 1 box torch.tensor([[30, 40, 130, 140], [70, 80, 170, 180], [250, 260, 350, 360], - [100, 110, 200, 210], [180, 190, 280, 290]]), # 5 boxes - torch.tensor([[45, 55, 145, 155], [220, 230, 320, 330]]) # 2 boxes + [100, 110, 200, 210], [180, 190, 280, 290]]), # 5 boxes + torch.tensor([[45, 55, 145, 155], [220, 230, 320, 330]]) # 2 boxes ] targets = [ @@ -800,7 +800,7 @@ def test_signals_for_binary_classification(self, mock_gm, mock_get_dataframe): self.assertTrue(mock_df.enqueue_batch.called) call_kwargs = mock_df.enqueue_batch.call_args[1] losses = call_kwargs['losses'] - self.assertEqual(len(losses), 5) # All signals should be saved + self.assertEqual(len(losses), 5) # All signals should be saved if __name__ == "__main__": unittest.main() diff --git a/weightslab/tests/general/test_signals_wrapping.py b/weightslab/tests/general/test_signals_wrapping.py index 3a307b97..aeac7f48 100644 --- a/weightslab/tests/general/test_signals_wrapping.py +++ b/weightslab/tests/general/test_signals_wrapping.py @@ -58,7 +58,7 @@ def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 32, 3, padding=1) self.pool = nn.AdaptiveAvgPool2d((8, 8)) - self.fc = nn.Linear(32 * 8 * 8, 100) # 100 outputs for bbox/conf + self.fc = nn.Linear(32 * 8 * 8, 100) # 100 outputs for bbox/conf self.task_type = "detection" def forward(self, x): @@ -270,9 +270,9 @@ def test_save_detection_signals_with_variable_boxes(self, mock_gm, mock_get_df): # Variable boxes preds = [ - torch.tensor([[10, 20, 100, 150], [200, 250, 400, 450]]), # 2 boxes - torch.tensor([[15, 25, 110, 160]]), # 1 box - torch.tensor([[30, 40, 130, 140], [70, 80, 170, 
180], [250, 260, 350, 360]]) # 3 boxes + torch.tensor([[10, 20, 100, 150], [200, 250, 400, 450]]), # 2 boxes + torch.tensor([[15, 25, 110, 160]]), # 1 box + torch.tensor([[30, 40, 130, 140], [70, 80, 170, 180], [250, 260, 350, 360]]) # 3 boxes ] targets = [ diff --git a/weightslab/tests/integrations/test_explore_mode.py b/weightslab/tests/integrations/test_explore_mode.py new file mode 100644 index 00000000..36daf3ab --- /dev/null +++ b/weightslab/tests/integrations/test_explore_mode.py @@ -0,0 +1,210 @@ +"""Integration test for read-only "explore" mode (``weightslab --logdir``). + +Simulates the real workflow: train a small experiment with weightslab (writing +checkpoints + logger snapshots + the H5 data store to a ``root_log_dir``), then +"kill" the training (clear the ledger, like a fresh process), and finally load +the experiment purely from disk via ``wl.load_experiment_for_explore`` and serve +it read-only. + +Asserts that, after loading: +- the logged history is readable through the real gRPC servicer (the "access the + logs through the UI" requirement); +- the data splits are browsable; +- every mutating action a user must NOT be able to do — start training, change + hyperparameters, load/restore/save weights — is refused, while reads and data + management still work. 
+""" + +import os +import tempfile +import shutil +import unittest +import warnings + +warnings.filterwarnings("ignore") + +import torch as th +import torch.nn as nn + +import weightslab as wl +import weightslab.proto.experiment_service_pb2 as pb2 +from weightslab.backend import ledgers +from weightslab.backend import explore_mode +from weightslab.components.global_monitoring import ( + guard_training_context, + pause_controller, + start_hp_sync_thread_event, +) +from weightslab.trainer.experiment_context import ExperimentContext +from weightslab.trainer.services.experiment_service import ExperimentService +from weightslab.utils.tools import seed_everything + + +start_hp_sync_thread_event() + + +class _TinyDataset: + """Minimal (data, uid, target) dataset — no downloads, fully synthetic.""" + + def __init__(self, n=8, dim=4, num_classes=3): + g = th.Generator().manual_seed(0) + self._x = th.randn(n, dim, generator=g) + self._y = th.randint(0, num_classes, (n,), generator=g) + + def __len__(self): + return len(self._x) + + def __getitem__(self, idx): + return self._x[idx], th.tensor(idx, dtype=th.long), self._y[idx] + + +class _TinyNet(nn.Module): + def __init__(self, dim=4, num_classes=3): + super().__init__() + self.input_shape = (1, dim) + self.fc = nn.Linear(dim, num_classes) + + def forward(self, x): + return self.fc(x) + + +class ExploreModeTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + seed_everything() + cls.temp_dir = tempfile.mkdtemp(prefix="wl_explore_test_") + cls.root_log_dir = os.path.join(cls.temp_dir, "experiments") + + cls.config = { + "experiment_name": "explore_test", + "device": "cpu", + "root_log_dir": cls.root_log_dir, + "experiment_dump_to_train_steps_ratio": 2, + "data": {"train_loader": {"batch_size": 2, "shuffle": False}}, + "checkpoint_manager": {"dump_model_architecture": True}, + "ledger_enable_flushing_threads": True, + "ledger_enable_h5_persistence": True, + "ledger_flush_max_rows": 2, + "ledger_flush_interval": 1.0, 
+ "serving_grpc": False, + "serving_cli": False, + "optimizer": {"lr": 0.01}, + } + + # ---- Train a small experiment (produces on-disk artifacts) ----------- + pause_controller.pause() + cls.dataset = _TinyDataset() + cls.logger = __import__( + "weightslab.backend.logger", fromlist=["LoggerQueue"] + ).LoggerQueue(register=True) + + cls.config = wl.watch_or_edit( + cls.config, flag="hyperparameters", defaults=cls.config, poll_interval=1.0 + ) + model = wl.watch_or_edit( + _TinyNet(), flag="model", device="cpu", + skip_previous_auto_load=True, compute_dependencies=False, + ) + wl.watch_or_edit( + cls.dataset, flag="data", compute_hash=False, is_training=True, + batch_size=2, shuffle=False, + ) + wl.watch_or_edit( + th.optim.Adam(model.parameters(), lr=0.01), flag="optimizer" + ) + wl.watch_or_edit( + nn.CrossEntropyLoss(reduction="none"), flag="signal", + log=True, name="train/loss", + ) + + cls.chkpt = ledgers.get_checkpoint_manager() + cls.chkpt.update_experiment_hash(first_time=True) + + loader = ledgers.get_dataloader() + optimizer = ledgers.get_optimizer() + criterion = ledgers.get_signal(name="train/loss") + + pause_controller.resume() + for _ in range(6): + with guard_training_context: + inputs, ids, labels = next(loader) + optimizer.zero_grad() + preds_raw = model(inputs) + preds = preds_raw.argmax(dim=1, keepdim=True) + loss = criterion(preds_raw, labels, batch_ids=ids, preds=preds) + loss.mean().backward() + optimizer.step() + pause_controller.pause() + + # Ensure everything is flushed to disk (checkpoints, logger, data). 
+ cls.chkpt.save_model_checkpoint() + cls.chkpt.save_logger_snapshot() + cls.chkpt.save_pending_changes(force=True) + cls.trained_hash = cls.chkpt.get_current_experiment_hash() + + @classmethod + def tearDownClass(cls): + explore_mode.set_explore_mode(False) + ledgers.clear_all() + shutil.rmtree(cls.temp_dir, ignore_errors=True) + + def setUp(self): + # Each test starts from a freshly loaded, read-only explorer (simulates a + # new `weightslab --logdir` process attaching to the killed run). + explore_mode.set_explore_mode(False) + self.summary = wl.load_experiment_for_explore(self.root_log_dir) + self.ctx = ExperimentContext() + self.service = ExperimentService(self.ctx) + + def test_explore_mode_is_enabled_and_experiment_loaded(self): + self.assertTrue(explore_mode.is_explore_mode()) + self.assertTrue(self.summary["has_logger"]) + self.assertIsNotNone(self.summary["experiment_hash"]) + + def test_logger_history_is_readable_through_servicer(self): + resp = self.service.GetLatestLoggerData( + pb2.GetLatestLoggerDataRequest( + request_full_history=True, max_points=1000, break_by_slices=False + ), + None, + ) + # The training above logged "train/loss" each step; it must survive the + # save→fresh-process→load round trip and be visible in the UI. + self.assertGreater(len(resp.points), 0) + + def test_data_is_rehydrated_from_disk(self): + # The persisted H5 data store is rebuilt into the ledger so the sample + # grid is browsable without the original Dataset object. (The split name + # is auto-derived from the dataset, so we don't assert a specific name.) 
+ self.assertTrue(self.summary["origins"], "expected at least one data split") + dfm = ledgers.get_dataframe() + self.assertIsNotNone(dfm) + self.assertEqual(len(dfm.get_df_view()), len(self.dataset)) + + def test_blocks_training_start(self): + resp = self.service.ExperimentCommand( + pb2.TrainerCommand( + hyper_parameter_change=pb2.HyperParameterCommand( + hyper_parameters=pb2.HyperParameters(is_training=True) + ) + ), + None, + ) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_weight_restore(self): + resp = self.service.RestoreCheckpoint( + pb2.RestoreCheckpointRequest(experiment_hash=self.trained_hash), None + ) + self.assertFalse(resp.success) + + def test_reads_still_work(self): + resp = self.service.ExperimentCommand( + pb2.TrainerCommand(get_hyper_parameters=True), None + ) + self.assertTrue(resp.success) + + +if __name__ == "__main__": + unittest.main() diff --git a/weightslab/tests/integrations/test_pytorch_lightning_integration.py b/weightslab/tests/integrations/test_pytorch_lightning_integration.py index 704eee47..6942b1e4 100644 --- a/weightslab/tests/integrations/test_pytorch_lightning_integration.py +++ b/weightslab/tests/integrations/test_pytorch_lightning_integration.py @@ -154,127 +154,127 @@ def tearDown(self): # These 3 next tests were removed as they are covering disabled feature. We can re-enable them once the feature is re-enabled. 
# def test_proxy_hashable_in_lightning(self): - # """Test that Proxy objects are hashable and work with Lightning's module system.""" - # model = SimpleCNN() - # print(wl.__file__) - # model_wl = wl.watch_or_edit(model, flag="model", device=self.device) + # """Test that Proxy objects are hashable and work with Lightning's module system.""" + # model = SimpleCNN() + # print(wl.__file__) + # model_wl = wl.watch_or_edit(model, flag="model", device=self.device) - # # Test that proxy can be used in sets (requires __hash__) - # proxy_set = {model_wl} - # self.assertIn(model_wl, proxy_set) + # # Test that proxy can be used in sets (requires __hash__) + # proxy_set = {model_wl} + # self.assertIn(model_wl, proxy_set) - # # Test that proxy can be used as dict key - # proxy_dict = {model_wl: "test_value"} - # self.assertEqual(proxy_dict[model_wl], "test_value") + # # Test that proxy can be used as dict key + # proxy_dict = {model_wl: "test_value"} + # self.assertEqual(proxy_dict[model_wl], "test_value") # def test_lightning_module_with_weightslab_tracking(self): - # """Test that Lightning module can be created with WeightsLab tracked objects.""" - # pause_controller.resume(force=True) # Ensure not pausedv - # # Create model and wrap with WeightsLab - # _model = SimpleCNN().to(self.device) - # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) - - # # Create tracked loss and metrics - # criterion = wl.watch_or_edit( - # nn.CrossEntropyLoss(reduction="none"), - # flag="loss", signal_name="loss-CE", log=True - # ) - - # metric = wl.watch_or_edit( - # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), - # flag="metric", signal_name="metric-ACC", log=True - # ) - - # # Create optimizer - # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) - # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") - - # # Create Lightning module with tracked objects - # lit_model = LitTestModel( - # model=model_wl, - # 
optimizer=optimizer_wl, - # criterion_wl=criterion, - # metric_wl=metric - # ) - - # # Verify Lightning module was created successfully - # self.assertIsInstance(lit_model, pl.LightningModule) - # self.assertIsInstance(lit_model.model, Proxy) + # """Test that Lightning module can be created with WeightsLab tracked objects.""" + # pause_controller.resume(force=True) # Ensure not pausedv + # # Create model and wrap with WeightsLab + # _model = SimpleCNN().to(self.device) + # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) + + # # Create tracked loss and metrics + # criterion = wl.watch_or_edit( + # nn.CrossEntropyLoss(reduction="none"), + # flag="loss", signal_name="loss-CE", log=True + # ) + + # metric = wl.watch_or_edit( + # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), + # flag="metric", signal_name="metric-ACC", log=True + # ) + + # # Create optimizer + # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) + # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") + + # # Create Lightning module with tracked objects + # lit_model = LitTestModel( + # model=model_wl, + # optimizer=optimizer_wl, + # criterion_wl=criterion, + # metric_wl=metric + # ) + + # # Verify Lightning module was created successfully + # self.assertIsInstance(lit_model, pl.LightningModule) + # self.assertIsInstance(lit_model.model, Proxy) # def test_lightning_training_with_weightslab_loaders(self): - # """Test full training loop with WeightsLab tracked data loaders.""" - # pause_controller.resume(force=True) # Ensure not paused - - # # Create tracked loaders - # train_loader = wl.watch_or_edit( - # self.train_dataset, - # flag="data", - # loader_name="train_loader", - # batch_size=16, - # shuffle=True, - # is_training=True, - # compute_hash=False, - # enable_h5_persistence=False - # ) - - # val_loader = wl.watch_or_edit( - # self.val_dataset, - # flag="data", - # loader_name="val_loader", - # batch_size=16, - # shuffle=False, - # 
is_training=False, - # compute_hash=False, - # enable_h5_persistence=False - # ) - - # # Create model with tracked components - # _model = SimpleCNN().to(self.device) - # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) - - # criterion = wl.watch_or_edit( - # nn.CrossEntropyLoss(reduction="none"), - # flag="loss", signal_name="loss-CE", log=True - # ) - - # metric = wl.watch_or_edit( - # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), - # flag="metric", signal_name="metric-ACC", log=True - # ) - - # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) - # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") - - # lit_model = LitTestModel( - # model=model_wl, - # optimizer=optimizer_wl, - # criterion_wl=criterion, - # metric_wl=metric - # ) - - # # Create Lightning trainer with minimal configuration - # trainer = pl.Trainer( - # max_epochs=2, - # accelerator=self.device if self.device in ["cpu", "cuda"] else "auto", - # devices=1, - # enable_checkpointing=False, - # logger=False, - # enable_progress_bar=False, - # ) - - # # Train the model - this should complete without errors - # try: - # trainer.fit(lit_model, train_loader, val_loader) - # training_succeeded = True - # except Exception as e: - # training_succeeded = False - # self.fail(f"Training failed with error: {e}") - - # self.assertTrue(training_succeeded, "Training should complete successfully") + # """Test full training loop with WeightsLab tracked data loaders.""" + # pause_controller.resume(force=True) # Ensure not paused + + # # Create tracked loaders + # train_loader = wl.watch_or_edit( + # self.train_dataset, + # flag="data", + # loader_name="train_loader", + # batch_size=16, + # shuffle=True, + # is_training=True, + # compute_hash=False, + # enable_h5_persistence=False + # ) + + # val_loader = wl.watch_or_edit( + # self.val_dataset, + # flag="data", + # loader_name="val_loader", + # batch_size=16, + # shuffle=False, + # is_training=False, 
+ # compute_hash=False, + # enable_h5_persistence=False + # ) + + # # Create model with tracked components + # _model = SimpleCNN().to(self.device) + # model_wl = wl.watch_or_edit(_model, flag="model", device=self.device) + + # criterion = wl.watch_or_edit( + # nn.CrossEntropyLoss(reduction="none"), + # flag="loss", signal_name="loss-CE", log=True + # ) + + # metric = wl.watch_or_edit( + # Accuracy(task="multiclass", num_classes=self.n_classes).to(self.device), + # flag="metric", signal_name="metric-ACC", log=True + # ) + + # optimizer = torch.optim.Adam(model_wl.parameters(), lr=0.001) + # optimizer_wl = wl.watch_or_edit(optimizer, flag="optimizer") + + # lit_model = LitTestModel( + # model=model_wl, + # optimizer=optimizer_wl, + # criterion_wl=criterion, + # metric_wl=metric + # ) + + # # Create Lightning trainer with minimal configuration + # trainer = pl.Trainer( + # max_epochs=2, + # accelerator=self.device if self.device in ["cpu", "cuda"] else "auto", + # devices=1, + # enable_checkpointing=False, + # logger=False, + # enable_progress_bar=False, + # ) + + # # Train the model - this should complete without errors + # try: + # trainer.fit(lit_model, train_loader, val_loader) + # training_succeeded = True + # except Exception as e: + # training_succeeded = False + # self.fail(f"Training failed with error: {e}") + + # self.assertTrue(training_succeeded, "Training should complete successfully") def test_weightslab_context_guards_in_lightning(self): """Test that WeightsLab context guards work correctly in Lightning steps.""" - pause_controller.resume(force=True) # Ensure not paused + pause_controller.resume(force=True) # Ensure not paused context_log = [] class ContextTestModule(pl.LightningModule): diff --git a/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py b/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py index 74c45d72..3c268e84 100644 --- a/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py +++ 
b/weightslab/tests/integrations/ultralytics/ddp/ddp_ablation.py @@ -3,9 +3,9 @@ The fair baseline isn't "no logging" — anyone wanting per-sample signals must decode preds, compute per-sample loss/metrics, and store them. So we compare two modes with identical model / batch / imgsz / data: - ulmanual — ultralytics + a HAND-ROLLED minimal per-sample logger: decode + per-sample + ulmanual — ultralytics + a HAND-ROLLED minimal per-sample logger: decode + per-sample loss/IoU + append the scalars to a plain list. The "classic" way. - wl — full WL pipeline: wrapped model/loss/loader, save_signals, anchor + wl — full WL pipeline: wrapped model/loss/loader, save_signals, anchor (reconcile DOWN + flush UP), decode-for-logging. (wl - ulmanual) = WL's internal machinery (dataframe upserts + ledger/H5 + the DDP @@ -14,7 +14,7 @@ ms/step + rank-0 RSS; `wl` also prints the global dataframe RAM + H5 store sizes. WL_ABLATE=ulmanual WL_DDP_CUDA=1 python ddp_ablation.py - WL_ABLATE=wl WL_DDP_CUDA=1 python ddp_ablation.py + WL_ABLATE=wl WL_DDP_CUDA=1 python ddp_ablation.py """ import os os.environ.setdefault("WEIGHTSLAB_SKIP_SECURE_INIT", "true") @@ -54,7 +54,7 @@ def _rss_mb(): with open("/proc/self/status") as f: for ln in f: if ln.startswith("VmRSS:"): - return int(ln.split()[1]) / 1024.0 # KB -> MB + return int(ln.split()[1]) / 1024.0 # KB -> MB except Exception: pass return -1.0 @@ -134,7 +134,7 @@ def _worker(rank, world, master_port): batch_size = int(os.environ.get("WL_DDP_BATCH", "16")) num_workers = int(os.environ.get("WL_DDP_WORKERS", "0")) - is_wl = MODE == "wl" # else: ulmanual (the hand-rolled classic baseline) + is_wl = MODE == "wl" # else: ulmanual (the hand-rolled classic baseline) if is_wl: import yolo_pipeline cfg["compute_natural_sort"] = False @@ -154,7 +154,7 @@ def _worker(rank, world, master_port): else: model, loader, crit, iou, optimizer = _build_ul(cfg, device, batch_size, num_workers) from yolo_pipeline import _decode_preds_to_6col as decode - _manual_store 
= [] # the "classic" sink: a plain in-memory list + _manual_store = [] # the "classic" sink: a plain in-memory list # identical initial weights on every rank (flattened broadcast) with torch.no_grad(): @@ -191,7 +191,7 @@ def _inf(ld): for step in range(_WARMUP + _STEPS): timed = step >= _WARMUP if step == _WARMUP: - io0 = _proc_io() # I/O counters at the start of the timed window + io0 = _proc_io() # I/O counters at the start of the timed window t0 = time.perf_counter() inputs = next(batches) if is_wl: @@ -282,21 +282,21 @@ def _inf(ld): # Each rank prints its OWN per-rank line (no gather collective — it was flaky on # gloo+CUDA; per-process I/O reads are independent anyway). io = io_d - print(f"[mode={MODE} rank {rank}] RSS={rss:7.0f}MB anchor={t.ms('anchor(WL)'):6.1f}ms " + print(f"[mode={MODE} rank {rank}] RSS={rss:7.0f}MB anchor={t.ms('anchor(WL)'):6.1f}ms " f"IO(MB): rchar={io.get('rchar',0)/1e6:7.1f} wchar={io.get('wchar',0)/1e6:7.1f} " f"read_dsk={io.get('read_bytes',0)/1e6:6.1f} write_dsk={io.get('write_bytes',0)/1e6:6.1f}", flush=True) if rank == 0: total = sum(t.ms(k) for k in order) print("\n" + "=" * 74) - print(f"ABLATION mode={MODE} device={device} world={world} batch={batch_size} steps={_STEPS}") + print(f"ABLATION mode={MODE} device={device} world={world} batch={batch_size} steps={_STEPS}") print("=" * 74) for k in order: - print(f" {k:18s} {t.ms(k):8.1f} ms/step") - print(f" {'STEP TOTAL':18s} {total:8.1f} ms/step") - print(f" {'grad on the wire':18s} {grad_bytes/1e6:8.1f} MB/step") + print(f" {k:18s} {t.ms(k):8.1f} ms/step") + print(f" {'STEP TOTAL':18s} {total:8.1f} ms/step") + print(f" {'grad on the wire':18s} {grad_bytes/1e6:8.1f} MB/step") if is_wl: - print(f" WL df RAM {df_mb:.1f} MB | WL H5 {h5_mb:.1f} MB disk | " + print(f" WL df RAM {df_mb:.1f} MB | WL H5 {h5_mb:.1f} MB disk | " f"H5 cfg: persist={cfg.get('ledger_enable_h5_persistence')} " f"max_rows={cfg.get('ledger_flush_max_rows')} " f"interval={cfg.get('ledger_flush_interval')}s " 
diff --git a/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py b/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py index fb8efb42..c9608081 100644 --- a/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py +++ b/weightslab/tests/integrations/ultralytics/ddp/ddp_test_suite.py @@ -7,8 +7,8 @@ measure time/processes, and later extract the reusable bits into the SDK. Layout: - * _train_worker(rank, world, ...) -- the distributed SERVER (spawned, rank 0 serves). - * Client + scenarios -- run in the PARENT process, talk gRPC to rank 0. + * _train_worker(rank, world, ...) -- the distributed SERVER (spawned, rank 0 serves). + * Client + scenarios -- run in the PARENT process, talk gRPC to rank 0. First scenario (scenario_epoch_then_pause): 1. spawn `world` ranks (rank 0 serves gRPC plaintext); parent = client. @@ -19,13 +19,13 @@ 5. wait ~50% of the epoch's wall time (training is paused) and assert the last_seen map is byte-identical => pause truly froze training. -Run: python ddp_test_suite.py (WL_DDP_WORLD_SIZE=2, imgsz 96, num_workers 0) +Run: python ddp_test_suite.py (WL_DDP_WORLD_SIZE=2, imgsz 96, num_workers 0) """ import os # --- test-mode env (must be set before importing weightslab) --------------- -os.environ["WEIGHTSLAB_SKIP_SECURE_INIT"] = "true" # plaintext gRPC for the client +os.environ["WEIGHTSLAB_SKIP_SECURE_INIT"] = "true" # plaintext gRPC for the client os.environ["GRPC_TLS_ENABLED"] = "0" -os.environ.setdefault("WL_DDP_IMGSZ", "96") # small images for speed +os.environ.setdefault("WL_DDP_IMGSZ", "96") # small images for speed os.environ.setdefault("WL_DDP_COLLECTIVE_LOG", "/tmp/wl_collective_log.txt") os.environ.setdefault("WL_PRELOAD_IMAGE_OVERVIEW", "0") os.environ.setdefault("WEIGHTSLAB_LOG_LEVEL", "WARNING") @@ -44,7 +44,7 @@ # usecase modules (yolo_pipeline, utils.*) and its config/data/ddp_run resolve. 
sys.path.insert(0, os.path.abspath(os.path.join( os.path.dirname(__file__), "../../../../examples/PyTorch/ws-detection/src"))) -import yolo_pipeline # reuse _build_pipeline / _decode_preds_to_6col / _HERE / _LOSS_PARTS +import yolo_pipeline # reuse _build_pipeline / _decode_preds_to_6col / _HERE / _LOSS_PARTS import weightslab.proto.experiment_service_pb2 as pb2 import weightslab.proto.experiment_service_pb2_grpc as pb2_grpc @@ -54,7 +54,7 @@ # =========================================================================== -# SERVER (spawned ranks; rank 0 serves gRPC) +# SERVER (spawned ranks; rank 0 serves gRPC) # =========================================================================== def _train_worker(rank, world, master_port, grpc_port): """Spawned per rank. Delegates to main_ddp.train_worker — the clean @@ -80,11 +80,11 @@ def _train_worker(rank, world, master_port, grpc_port): # =========================================================================== -# CLIENT (parent process) +# CLIENT (parent process) # =========================================================================== class Client: def __init__(self, port): - self._port = int(port) # exposed for topology-style scenarios + self._port = int(port) # exposed for topology-style scenarios self.channel = grpc.insecure_channel( f"{_HOST}:{port}", options=[("grpc.max_receive_message_length", 256 * 1024 * 1024)], @@ -298,7 +298,7 @@ def _wait_until_paused(client, n, min_step, timeout=600.0, poll=5.0): server is paused, not just mid-step). 
drop_last means the per-rank epoch is floor(shard/batch), so we don't require an exact step count.""" _t0 = time.time() - _last_change = _t0 # wall-time of the most recent last_seen-max change + _last_change = _t0 # wall-time of the most recent last_seen-max change deadline = time.time() + timeout prev = None stable = 0 @@ -310,9 +310,9 @@ def _wait_until_paused(client, n, min_step, timeout=600.0, poll=5.0): if cur >= min_step and stable >= 2: if _SCN_TIMING: tot = time.time() - _t0 - active = _last_change - _t0 # last_seen advancing = training/observed work - settle = time.time() - _last_change # stable-confirm + snapshot-lag = observability - print(f"[scn_timing] wait_until_paused total={tot:6.1f}s " + active = _last_change - _t0 # last_seen advancing = training/observed work + settle = time.time() - _last_change # stable-confirm + snapshot-lag = observability + print(f"[scn_timing] wait_until_paused total={tot:6.1f}s " f"active(train)={active:6.1f}s settle(obs)={settle:6.1f}s", flush=True) return cur prev = cur @@ -354,8 +354,8 @@ def scenario_epoch_then_pause(client, world, batch): # auto-pause (pause_at_step) fire at the epoch boundary WITHOUT crossing into # epoch 2 (which would force a sampler re-iteration mid-test). 
epoch_steps = (n // world) // batch - epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) # fast-debug override - print(f"[client] universe N={n} world={world} batch={batch} -> epoch_steps={epoch_steps}") + epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) # fast-debug override + print(f"[client] universe N={n} world={world} batch={batch} -> epoch_steps={epoch_steps}") t0 = time.time() client.train_steps(epoch_steps) @@ -366,7 +366,7 @@ def scenario_epoch_then_pause(client, world, batch): epoch_secs = time.time() - t0 print(f"[client] epoch done: max last_seen={reached} in {epoch_secs:.1f}s") - s1 = _settled_last_seen(client, n) # wait out the DataService snapshot throttle + s1 = _settled_last_seen(client, n) # wait out the DataService snapshot throttle populated = {k: v for k, v in s1.items() if v is not None and v >= 0} # With the children->rank0 gather (fired on pause), rank 0 sees what ALL ranks # trained: ~reached*batch*world distinct samples (capped at the universe N, @@ -391,8 +391,8 @@ def scenario_epoch_then_pause(client, world, batch): print(f"[client] FROZEN CHECK FAILED, {len(diff)} changed e.g. 
{list(diff.items())[:5]}") ok = a1 and a1b and a2 - print(f"[1] EPOCH COVERAGE populated>0={a1} populated~=shard={a1b} -> {'PASS' if (a1 and a1b) else 'FAIL'}") - print(f"[2] PAUSE FREEZES last_seen identical after wait={a2} -> {'PASS' if a2 else 'FAIL'}") + print(f"[1] EPOCH COVERAGE populated>0={a1} populated~=shard={a1b} -> {'PASS' if (a1 and a1b) else 'FAIL'}") + print(f"[2] PAUSE FREEZES last_seen identical after wait={a2} -> {'PASS' if a2 else 'FAIL'}") return ok @@ -443,11 +443,11 @@ def scenario_discard_subset_freezes(client, world, batch, n_discard=5): most_advanced = advanced >= int(0.8 * non_discarded_pop1) ok = a0 and all_frozen and most_advanced and (m2 > m1) - print(f"[1] DISCARD REGISTERED exactly {n_discard} added={a0}") - print(f"[2] SUBSET FROZEN all {n_discard} unchanged={all_frozen} values {L} -> {frozen}") - print(f"[3] MOST ADVANCED {advanced}/{non_discarded_pop1} non-discarded advanced " - f"(>=80%)={most_advanced} (epoch max {m1}->{m2})") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f"[1] DISCARD REGISTERED exactly {n_discard} added={a0}") + print(f"[2] SUBSET FROZEN all {n_discard} unchanged={all_frozen} values {L} -> {frozen}") + print(f"[3] MOST ADVANCED {advanced}/{non_discarded_pop1} non-discarded advanced " + f"(>=80%)={most_advanced} (epoch max {m1}->{m2})") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -458,7 +458,7 @@ def scenario_break_by_slice(client, world, batch): epoch_steps = (n // world) // batch epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) origin = client.train_origin() - graph = "train/bbxs" # a per-sample loss component logged by the criterions + graph = "train/bbxs" # a per-sample loss component logged by the criterions client.train_steps(epoch_steps) _wait_until_paused(client, n, min_step=max(1, epoch_steps - batch)) @@ -472,25 +472,25 @@ def scenario_break_by_slice(client, world, batch): # trained set. 
Measure it directly with a 'uni' slice over all trained samples # (this is the GLOBAL loss universe on rank 0 thanks to the per-sample gather), # then check the 'even' slice returns exactly the even members that have loss. - even = set(trained[::2]) # every other trained sample — spans both ranks' shards + even = set(trained[::2]) # every other trained sample — spans both ranks' shards client.tag(trained, "uni", origin) client.tag(even, "even", origin) print(f"[client] trained={len(trained)} tagged uni + even={len(even)} (origin={origin})") uni_sids = {p[0] for p in client.break_by_slice(graph, ["uni"])} even_sids = {p[0] for p in client.break_by_slice(graph, ["even"])} - expected_even = even & uni_sids # even-tagged samples that actually have loss + expected_even = even & uni_sids # even-tagged samples that actually have loss a1 = len(even_sids) > 0 - a2 = (even_sids == expected_even) # break-by-slice slices correctly + a2 = (even_sids == expected_even) # break-by-slice slices correctly ok = a1 and a2 - print(f"[1] BREAK-BY-SLICE even returned {len(even_sids)} samples (graph={graph})={a1}") - print(f"[2] SLICE CORRECT even == even-with-loss ({len(expected_even)})={a2}") + print(f"[1] BREAK-BY-SLICE even returned {len(even_sids)} samples (graph={graph})={a1}") + print(f"[2] SLICE CORRECT even == even-with-loss ({len(expected_even)})={a2}") # Cross-rank is evidenced by the server-side [siggather] log (rank 0 receives the # children's triples). It's not cleanly black-box-assertable without a per-rank # baseline, and becomes STRUCTURAL once writes go through sync_to_rank0 on rank 0. 
print(f"[i] loss universe on rank 0 = {len(uni_sids)} samples (spans both ranks via the gather)") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -501,7 +501,7 @@ def scenario_lr_batch_propagate(client, world, batch): epoch_steps = (n // world) // batch epoch_steps = int(os.environ.get("WL_DDP_TEST_STEPS", epoch_steps)) new_batch = batch * 2 - phase2 = 4 # short, so trained-count doesn't wrap the universe + phase2 = 4 # short, so trained-count doesn't wrap the universe # phase 1 at the original batch client.train_steps(epoch_steps) @@ -521,8 +521,8 @@ def scenario_lr_batch_propagate(client, world, batch): steps2 = a1 - a0 trained2 = sum(1 for v in s1.values() if v is not None and v > a0) rate = trained2 / steps2 if steps2 > 0 else 0.0 - expected = new_batch * world # both ranks switched - rank0_only = new_batch + batch # only rank 0 switched (the bug we fixed) + expected = new_batch * world # both ranks switched + rank0_only = new_batch + batch # only rank 0 switched (the bug we fixed) # Threshold: must be clearly ABOVE the rank0-only failure mode. We don't # require hitting the full `expected` because under drop_last=False the # DistributedSampler pads the per-rank shard with re-yields of samples @@ -531,10 +531,10 @@ def scenario_lr_batch_propagate(client, world, batch): # rank0_only + 1 cleanly distinguishes "both ranks doubled" (rate ≈ 13–16) # from "only rank-0 doubled" (rate ≈ 12). 
a1ok = steps2 > 0 and rate >= rank0_only + 1 - print(f"[1] BATCH PROPAGATED {trained2} samples / {steps2} steps = {rate:.1f}/step " + print(f"[1] BATCH PROPAGATED {trained2} samples / {steps2} steps = {rate:.1f}/step " f"(expect ~{expected} all-ranks vs ~{rank0_only} rank0-only)={a1ok}") print(f"[i] lr=0.05 rode the same hparam broadcast that carried batch (proven above)") - print(f" -> {'PASS' if a1ok else 'FAIL'}") + print(f" -> {'PASS' if a1ok else 'FAIL'}") return a1ok @@ -556,11 +556,11 @@ def scenario_checkpoint_data_roundtrip(client, world, batch): client.discard([A], origin) # 2) short resume -> save_pending_changes writes a FULL checkpoint (model+config+ - # data{A}) with non-null weights; then read its combined hash from the manifest. + # data{A}) with non-null weights; then read its combined hash from the manifest. client.train_steps(2) _wait_until_paused(client, n, min_step=a0 + 1) saved_hash = client.latest_full_checkpoint_hash() - time.sleep(12) # clear the DataService snapshot throttle before reading + time.sleep(12) # clear the DataService snapshot throttle before reading disc_save = client.discarded_set(n) print(f"[client] discarded A={A}; full-ckpt hash={saved_hash}; discarded@save={sorted(disc_save)}") @@ -580,13 +580,13 @@ def scenario_checkpoint_data_roundtrip(client, world, batch): f"msg={getattr(resp, 'message', '')[:70]}; discarded@post={sorted(disc_post)}") restore_ok = bool(getattr(resp, "success", False)) - a0c = (C in disc_change) # the divergent discard registered - a1 = (C not in disc_post) # restore undid it - a2 = (A in disc_post) # the saved discard survived the roundtrip + a0c = (C in disc_change) # the divergent discard registered + a1 = (C not in disc_post) # restore undid it + a2 = (A in disc_post) # the saved discard survived the roundtrip ok = restore_ok and a0c and a1 and a2 - print(f"[1] DATA ROUNDTRIP restore_ok={restore_ok} C-registered={a0c} " + print(f"[1] DATA ROUNDTRIP restore_ok={restore_ok} C-registered={a0c} " 
f"C-reverted={a1} A-intact={a2}") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -618,7 +618,7 @@ def scenario_signal_coverage_all_graphs(client, world, batch): if len(trained) < 10: print(f"[client] too few trained ({len(trained)})"); return False client.tag(trained, "uni", origin) - print(f"[client] trained={len(trained)} tagged 'uni' steps={epoch_steps}") + print(f"[client] trained={len(trained)} tagged 'uni' steps={epoch_steps}") graphs = ["train/bbxs", "train/clsf", "train/dfl", "miou/train"] per_sample_min = max(1, int(0.3 * len(trained))) @@ -631,10 +631,10 @@ def scenario_signal_coverage_all_graphs(client, world, batch): plot_ok = len(plot_points) >= plot_min ok = ps_ok and plot_ok all_ok &= ok - print(f"[1] {g:<18s} per-sample={len(per_sample_sids)}/{len(trained)} " - f"≥{per_sample_min}={ps_ok} plot={len(plot_points)} " - f"≥{plot_min}={plot_ok} both={ok}") - print(f" -> {'PASS' if all_ok else 'FAIL'}") + print(f"[1] {g:<18s} per-sample={len(per_sample_sids)}/{len(trained)} " + f"≥{per_sample_min}={ps_ok} plot={len(plot_points)} " + f"≥{plot_min}={plot_ok} both={ok}") + print(f" -> {'PASS' if all_ok else 'FAIL'}") return all_ok @@ -693,12 +693,12 @@ def scenario_resume_continues_curve(client, world, batch): _wait_until_paused(client, n, min_step=age_diverged + 1) post_train_plot = client.scalar_plot("train/bbxs") a3 = len(post_train_plot) > len(pre_restore_plot) - print(f"[3] PLOT GROWS pre={len(pre_restore_plot)} post={len(post_train_plot)} → {a3}") + print(f"[3] PLOT GROWS pre={len(pre_restore_plot)} post={len(post_train_plot)} → {a3}") ok = a1 and a2 and a3 - print(f"[1] RESTORE OK success={a1}") - print(f"[2] SERVER ALIVE universe={n_after}/{n} → {a2}") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f"[1] RESTORE OK success={a1}") + print(f"[2] SERVER ALIVE universe={n_after}/{n} → {a2}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -711,7 +711,7 @@ def 
scenario_process_topology(client, world, batch): import re import subprocess - _ = client.universe_size() # confirm we're connected; ranks are alive + _ = client.universe_size() # confirm we're connected; ranks are alive grpc_port = getattr(client, "_port", None) # Walk the descendant tree of the suite process to find spawned ranks. @@ -770,9 +770,9 @@ def listeners_of(pid): # Sanity: at least one PID does listen (otherwise the gRPC server is dead). a2 = len(listening) >= 1 ok = a1 and a2 - print(f"[1] gRPC OWNER PIDs owning port {grpc_port}: {grpc_owners} (==1) → {a1}") - print(f"[2] HAS LISTENERS {len(listening)} PID(s) with TCP sockets (≥1) → {a2}") - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f"[1] gRPC OWNER PIDs owning port {grpc_port}: {grpc_owners} (==1) → {a1}") + print(f"[2] HAS LISTENERS {len(listening)} PID(s) with TCP sockets (≥1) → {a2}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -807,17 +807,17 @@ def scenario_multi_epoch_stability(client, world, batch): # Per-graph: no duplicate (sid, age) entries dedup_ok = True for g in ["train/bbxs", "train/clsf", "train/dfl", "miou/train"]: - entries = client.break_by_slice(g, ["uni"]) # [(sid, age, val), ...] + entries = client.break_by_slice(g, ["uni"]) # [(sid, age, val), ...] 
keys = [(sid, age) for sid, age, _ in entries] unique, total = len(set(keys)), len(keys) ok = (unique == total) dedup_ok &= ok - print(f"[1] {g:<18s} {total} entries, {unique} unique (sid,age) → {ok}") + print(f"[1] {g:<18s} {total} entries, {unique} unique (sid,age) → {ok}") age_mono = ages[0] < ages[1] < ages[2] - print(f"[2] AGE MONOTONIC ages={ages} (strictly increasing) → {age_mono}") + print(f"[2] AGE MONOTONIC ages={ages} (strictly increasing) → {age_mono}") ok = dedup_ok and age_mono - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -825,17 +825,17 @@ def scenario_curate_lifecycle(client, world, batch): """End-to-end UI curation workflow under DDP — multiple composing edits and the loss trajectory tells the story: - epoch 1 (warm up: all populated samples accumulate train/bbxs entries) + epoch 1 (warm up: all populated samples accumulate train/bbxs entries) → tag 3 samples 'suspect' → discard those 3 - epoch 2 (the 3 suspects must produce NO new train/bbxs entries — + epoch 2 (the 3 suspects must produce NO new train/bbxs entries — their slot in the loss trajectory has a gap) → un-discard the 3 - → tag them additionally 'verified' (so each carries BOTH tags) - epoch 3 (the 3 resume; new entries appear beyond the discard age) + → tag them additionally 'verified' (so each carries BOTH tags) + epoch 3 (the 3 resume; new entries appear beyond the discard age) Assertions: - [1] LIFECYCLE — for each suspect: pre-discard entries exist AND + [1] LIFECYCLE — for each suspect: pre-discard entries exist AND no entries in the (discard_age, undiscard_age] window AND post-resume entries exist. 
The gap is the proof that discard reached the worker fast-path (the shm @@ -895,15 +895,15 @@ def scenario_curate_lifecycle(client, world, batch): ages_by_sid.setdefault(sid, []).append(age) # Per-suspect trajectory: - # pre — every suspect must have ≥1 entry before discard (proves we're - # tracking a sample that was actually trained on); - # gap — NO suspect may have an entry in (discard, undiscard] (proves the - # discard reached the sampler/worker fast-path); - # post — AT LEAST ONE suspect must have a post-undiscard entry (proves - # un-discard reaches the sampler). The shuffled sampler in a - # short 20-step epoch won't yield every sample, so requiring ALL - # suspects to resume would be a shuffle-luck check, not a - # correctness check. + # pre — every suspect must have ≥1 entry before discard (proves we're + # tracking a sample that was actually trained on); + # gap — NO suspect may have an entry in (discard, undiscard] (proves the + # discard reached the sampler/worker fast-path); + # post — AT LEAST ONE suspect must have a post-undiscard entry (proves + # un-discard reaches the sampler). The shuffled sampler in a + # short 20-step epoch won't yield every sample, so requiring ALL + # suspects to resume would be a shuffle-luck check, not a + # correctness check. 
pre_ok, gap_ok = True, True any_post = False for sid in suspects: @@ -912,26 +912,26 @@ def scenario_curate_lifecycle(client, world, batch): gap = [a for a in ages if age_at_discard < a <= age_at_undiscard] post = [a for a in ages if a > age_at_undiscard] if not pre: pre_ok = False - if gap: gap_ok = False - if post: any_post = True - print(f" sid={sid}: pre={pre[-3:]} gap={gap} post={post[:3]}") + if gap: gap_ok = False + if post: any_post = True + print(f" sid={sid}: pre={pre[-3:]} gap={gap} post={post[:3]}") post_ok = any_post verified_sids = {p[0] for p in client.break_by_slice("train/bbxs", ["verified"])} tag_compose = set(suspects).issubset(verified_sids) plot = client.scalar_plot("train/bbxs") - plot_ok = len(plot) >= 3 * 1 # at least one point per epoch (loose) + plot_ok = len(plot) >= 3 * 1 # at least one point per epoch (loose) a1ok = pre_ok and gap_ok and post_ok a2ok = tag_compose a3ok = plot_ok - print(f"[1] LIFECYCLE pre={pre_ok} gap-empty={gap_ok} any-post={post_ok} → {a1ok}") - print(f"[2] TAG COMPOSE verified⊇suspects ({len(verified_sids)} verified, " + print(f"[1] LIFECYCLE pre={pre_ok} gap-empty={gap_ok} any-post={post_ok} → {a1ok}") + print(f"[2] TAG COMPOSE verified⊇suspects ({len(verified_sids)} verified, " f"{len(set(suspects) & verified_sids)}/3 suspects tagged) → {a2ok}") - print(f"[3] PLOT METRICS scalar_plot has {len(plot)} entries → {a3ok}") + print(f"[3] PLOT METRICS scalar_plot has {len(plot)} entries → {a3ok}") ok = a1ok and a2ok and a3ok - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -948,7 +948,7 @@ def scenario_collective_budget(client, world, batch): log_path = os.environ.get("WL_DDP_COLLECTIVE_LOG") if not log_path: print(f"[client] WL_DDP_COLLECTIVE_LOG not set; skipping"); return False - open(log_path, "w").close() # truncate before this scenario's window + open(log_path, "w").close() # truncate before this scenario's window n = client.universe_size() epoch_steps = (n // 
world) // batch @@ -966,19 +966,19 @@ def scenario_collective_budget(client, world, batch): # First few entries can include pause-spin reconciles (many per "step" while # the trainer is waiting for the resume signal). Take a slice from the tail # corresponding to clearly-in-the-body steps. - body = [c for c in counts if c <= 5] # drop the spin-inflated outliers + body = [c for c in counts if c <= 5] # drop the spin-inflated outliers spin = [c for c in counts if c > 5] avg_body = (sum(body) / len(body)) if body else float("inf") max_body = max(body) if body else 0 a1 = max_body <= 2 a2 = avg_body <= 2.0 - print(f"[1] BUDGET PER STEP body samples={len(body)}, max={max_body}, " + print(f"[1] BUDGET PER STEP body samples={len(body)}, max={max_body}, " f"avg={avg_body:.2f}, spin samples={len(spin)} (excluded) " f"max-over-budget→{a1}") - print(f"[2] AVG ≤ 2 {avg_body:.2f} → {a2}") + print(f"[2] AVG ≤ 2 {avg_body:.2f} → {a2}") ok = a1 and a2 - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -1013,13 +1013,13 @@ def scenario_seed_determinism(client, world, batch): a1 = len(pull1) > 0 and len(pull1) == len(pull2) a2 = all(p1 == p2 for p1, p2 in zip(pull1, pull2)) - print(f"[1] STABLE LEN pull1={len(pull1)} == pull2={len(pull2)} → {a1}") - print(f"[2] BIT-IDENTICAL every (sid, age, val) matches → {a2}") + print(f"[1] STABLE LEN pull1={len(pull1)} == pull2={len(pull2)} → {a1}") + print(f"[2] BIT-IDENTICAL every (sid, age, val) matches → {a2}") # Spot-check first 3 entries for i in range(min(3, len(pull1))): - print(f" p1[{i}]={pull1[i]} p2[{i}]={pull2[i]}") + print(f" p1[{i}]={pull1[i]} p2[{i}]={pull2[i]}") ok = a1 and a2 - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -1060,7 +1060,7 @@ def scenario_empty_shard_starvation(client, world, batch): else: to_discard = populated[:-keep] client.discard(to_discard, origin) - print(f"[client] discarded {len(to_discard)} (keep={keep}, 
force={bool(_force)})") + print(f"[client] discarded {len(to_discard)} (keep={keep}, force={bool(_force)})") # Short post-discard train. Bounded timeout: if it hangs, assertion fires. K = min(epoch_steps, 8) @@ -1072,13 +1072,13 @@ def scenario_empty_shard_starvation(client, world, batch): timeout=180.0, poll=3.0) except TimeoutError: elapsed = time.time() - t0 - print(f"[client] HUNG no model_age advance in {elapsed:.0f}s (a0={a0})") - print(f" -> FAIL") + print(f"[client] HUNG no model_age advance in {elapsed:.0f}s (a0={a0})") + print(f" -> FAIL") return False elapsed = time.time() - t0 advanced = a1 > a0 - print(f"[1] NO HANG age advanced {a0}→{a1} in {elapsed:.1f}s → {advanced}") - print(f" -> {'PASS' if advanced else 'FAIL'}") + print(f"[1] NO HANG age advanced {a0}→{a1} in {elapsed:.1f}s → {advanced}") + print(f" -> {'PASS' if advanced else 'FAIL'}") return advanced @@ -1098,7 +1098,7 @@ def scenario_progressive_resample(client, world, batch): def _epoch_steps(live): return max(1, (live // world) // batch) - def _ls(d, sid): # None-safe last_seen ( -1 == never seen ) + def _ls(d, sid): # None-safe last_seen ( -1 == never seen ) v = d.get(sid) return v if v is not None else -1 @@ -1114,8 +1114,8 @@ def _run_epochs(n_ep, steps_each, label, m_start): timeout=180.0, poll=3.0) dt = time.perf_counter() - t0 total = n_ep * steps_each - print(f"[time] {label:24s} {n_ep}ep x {steps_each:>3}st = {total:>4} steps " - f"{dt:6.1f}s ({dt/max(1,total):.2f}s/step)") + print(f"[time] {label:24s} {n_ep}ep x {steps_each:>3}st = {total:>4} steps " + f"{dt:6.1f}s ({dt/max(1,total):.2f}s/step)") return m, dt # --- warm-up: 1 full epoch (100% live) --- @@ -1124,7 +1124,7 @@ def _run_epochs(n_ep, steps_each, label, m_start): m0 = _wait_until_paused(client, n, min_step=max(1, full_epoch_steps - batch)) t_warm = time.perf_counter() - t0 print(f"[time] {'warmup (1 x 100%)':24s} 1ep x {full_epoch_steps:>3}st = {full_epoch_steps:>4} " - f"steps {t_warm:6.1f}s 
({t_warm/full_epoch_steps:.2f}s/step)") + f"steps {t_warm:6.1f}s ({t_warm/full_epoch_steps:.2f}s/step)") s0 = _settled_last_seen(client, n) all_ids = sorted(s0.keys(), key=lambda k: int(k)) if sum(1 for k in all_ids if _ls(s0, k) >= 0) < 40: @@ -1146,7 +1146,7 @@ def _run_epochs(n_ep, steps_each, label, m_start): disc_frozen = sum(1 for sid in discard_ids if _ls(s1, sid) == _ls(s0, sid)) a1 = (kept_adv >= int(0.8 * len(keep)) and disc_frozen >= int(0.95 * len(discard_ids))) - print(f"[1] DISCARD SHIFT kept advanced {kept_adv}/{len(keep)} (>=80%), " + print(f"[1] DISCARD SHIFT kept advanced {kept_adv}/{len(keep)} (>=80%), " f"discarded frozen {disc_frozen}/{len(discard_ids)} (>=95%) -> {a1}") # --- grow: un-discard up to ~50% live --- @@ -1165,14 +1165,14 @@ def _run_epochs(n_ep, steps_each, label, m_start): still_frozen = sum(1 for sid in still_disc if _ls(s2, sid) == _ls(s1, sid)) a2 = ((not re_add or readd_adv >= int(0.8 * len(re_add))) and (not still_disc or still_frozen >= int(0.95 * len(still_disc)))) - print(f"[2] GROWTH HANDLED re-added advanced {readd_adv}/{len(re_add)} (>=80%), " + print(f"[2] GROWTH HANDLED re-added advanced {readd_adv}/{len(re_add)} (>=80%), " f"still-discarded frozen {still_frozen}/{len(still_disc)} (>=95%) -> {a2}") - print(f"[time] SUMMARY warmup={t_warm:.0f}s post-discard(2x10%)={t_lo:.0f}s " - f"post-readd(2x50%)={t_hi:.0f}s (warmup/post-discard ~= " + print(f"[time] SUMMARY warmup={t_warm:.0f}s post-discard(2x10%)={t_lo:.0f}s " + f"post-readd(2x50%)={t_hi:.0f}s (warmup/post-discard ~= " f"{t_warm/max(0.1,t_lo):.1f}x, expect ~5x if per-step cost is flat)") ok = a1 and a2 - print(f" -> {'PASS' if ok else 'FAIL'}") + print(f" -> {'PASS' if ok else 'FAIL'}") return ok @@ -1203,7 +1203,7 @@ def _free_port(): def _run_one(scn, batch): """Spawn a FRESH server (isolation), run one scenario, tear the server down.""" master_port, grpc_port = _free_port(), _free_port() - print(f"\n[suite] === {scn.__name__} === spawning {_WORLD} ranks, gRPC 
:{grpc_port}, " + print(f"\n[suite] === {scn.__name__} === spawning {_WORLD} ranks, gRPC :{grpc_port}, " f"imgsz={os.environ['WL_DDP_IMGSZ']}") ctx = mp.spawn(_train_worker, args=(_WORLD, master_port, grpc_port), nprocs=_WORLD, join=False) client = Client(grpc_port) @@ -1235,7 +1235,7 @@ def main(): _cfg_batch = yaml.safe_load(open(os.path.join(yolo_pipeline._HERE, "config.yaml")) )["data"]["train_loader"]["batch_size"] batch = int(os.environ.get("WL_DDP_BATCH", _cfg_batch)) - only = os.environ.get("WL_DDP_ONLY") # substring filter to run a single scenario + only = os.environ.get("WL_DDP_ONLY") # substring filter to run a single scenario # WL_DDP_SKIP: comma-separated substrings to EXCLUDE — lets a killed run resume # by skipping the scenarios that already passed (the suite has no checkpoint). skip = [s.strip() for s in os.environ.get("WL_DDP_SKIP", "").split(",") if s.strip()] @@ -1246,9 +1246,9 @@ def main(): print("\n" + "=" * 64) for name, ok in results.items(): - print(f" {name:42s} -> {'PASS' if ok else 'FAIL'}") + print(f" {name:42s} -> {'PASS' if ok else 'FAIL'}") allok = bool(results) and all(results.values()) - print(f" RESULT: {'ALL PASS' if allok else 'FAILURES ABOVE'}") + print(f" RESULT: {'ALL PASS' if allok else 'FAILURES ABOVE'}") print("=" * 64) raise SystemExit(0 if allok else 1) diff --git a/weightslab/tests/model/test_constraint_generation.py b/weightslab/tests/model/test_constraint_generation.py index 0b627a22..ddd00a94 100644 --- a/weightslab/tests/model/test_constraint_generation.py +++ b/weightslab/tests/model/test_constraint_generation.py @@ -42,7 +42,7 @@ def __init__(self): self.grouped_conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, groups=2) self.bn = nn.BatchNorm2d(16) self.relu = nn.ReLU() - self.regular_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1) # No groups + self.regular_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1) # No groups def forward(self, x): x = self.grouped_conv(x) @@ -62,8 +62,8 @@ class 
DepthwisePointwiseModel(nn.Module): """ def __init__(self): super().__init__() - self.dw = nn.Conv2d(16, 16, kernel_size=3, padding=1, groups=16) # Depthwise - self.pw = nn.Conv2d(16, 32, kernel_size=1) # Pointwise + self.dw = nn.Conv2d(16, 16, kernel_size=3, padding=1, groups=16) # Depthwise + self.pw = nn.Conv2d(16, 32, kernel_size=1) # Pointwise def forward(self, x): x = self.dw(x) @@ -227,7 +227,7 @@ def test_constraint_no_hardcoding(self): """Constraints are detected via introspection, not hardcoding on names""" # Create a custom conv with groups but no special name conv_with_groups = nn.Conv2d(4, 8, 3, padding=1, groups=2) - conv_with_groups.__class__.__name__ = "CustomConv" # Change class name + conv_with_groups.__class__.__name__ = "CustomConv" # Change class name constraints = _detect_layer_constraints(conv_with_groups) diff --git a/weightslab/tests/model/test_dependency_patterns.py b/weightslab/tests/model/test_dependency_patterns.py index f1955282..7d62cd31 100644 --- a/weightslab/tests/model/test_dependency_patterns.py +++ b/weightslab/tests/model/test_dependency_patterns.py @@ -114,7 +114,7 @@ def forward(self, x): out = self.relu(out) out = self.conv2(out) out = self.bn2(out) - out = out + out1 # Residual connection (REC dependency) + out = out + out1 # Residual connection (REC dependency) return out @@ -142,7 +142,7 @@ def forward(self, x): branch2 = self.conv2(x) branch2 = self.relu2(branch2) - merged = torch.cat([branch1, branch2], dim=1) # REC: both branches constrained + merged = torch.cat([branch1, branch2], dim=1) # REC: both branches constrained out = self.conv_merged(merged) return out @@ -341,10 +341,10 @@ def forward(self, x): out_id = out out = self.conv2(out) out = self.bn2(out) - out = out + out_id # inner residual + out = out + out_id # inner residual out = self.conv3(out) out = self.bn3(out) - out = out + out1 # outer residual + out = out + out1 # outer residual return out @@ -965,7 +965,7 @@ def setUp(self): self.model = 
MinimalConv1DChain() self.model.eval() self._model = self.model - self.dummy_input = torch.randn(1, 4, 64) # N, C, L + self.dummy_input = torch.randn(1, 4, 64) # N, C, L def test_conv1d_onnx(self): self.model = self.get_dependencies_onnx(self.model, self.dummy_input) @@ -998,7 +998,7 @@ def setUp(self): self.model = MinimalConv3DChain() self.model.eval() self._model = self.model - self.dummy_input = torch.randn(1, 2, 8, 16, 16) # N, C, D, H, W + self.dummy_input = torch.randn(1, 2, 8, 16, 16) # N, C, D, H, W def test_conv3d_onnx(self): self.model = self.get_dependencies_onnx(self.model, self.dummy_input) diff --git a/weightslab/tests/model/test_model_with_ops.py b/weightslab/tests/model/test_model_with_ops.py index 84a9a2a9..f381c608 100644 --- a/weightslab/tests/model/test_model_with_ops.py +++ b/weightslab/tests/model/test_model_with_ops.py @@ -21,7 +21,7 @@ # Set Global Default Settings -th.manual_seed(42) # Set SEED +th.manual_seed(42) # Set SEED TMP_DIR = '/tmp/utests/'; os.makedirs('/tmp/utests/', exist_ok=True) diff --git a/weightslab/tests/model/test_tracking.py b/weightslab/tests/model/test_tracking.py index 10ca164d..411730f5 100644 --- a/weightslab/tests/model/test_tracking.py +++ b/weightslab/tests/model/test_tracking.py @@ -17,7 +17,7 @@ # Set Global Default Settings DEVICE = 'cpu' if not th.cuda.is_available() else 'cuda' -th.manual_seed(42) # Set SEED +th.manual_seed(42) # Set SEED @unittest.skip("Constraint detection and propagation tests are currently skipped due to ongoing refactor and potential changes in the underlying implementation. 
Will be re-enabled once the new system is in place more modeling.") diff --git a/weightslab/tests/modules/test_modules_with_ops.py b/weightslab/tests/modules/test_modules_with_ops.py index c1ce3067..8ceab6a5 100644 --- a/weightslab/tests/modules/test_modules_with_ops.py +++ b/weightslab/tests/modules/test_modules_with_ops.py @@ -13,7 +13,7 @@ # Set Global Default Settings -th.manual_seed(42) # Set SEED +th.manual_seed(42) # Set SEED class LayerWiseOperationsTest(unittest.TestCase): @@ -85,7 +85,7 @@ def _test_operation_core( self._create_layers(device=device) layer_instance = self.all_layers.get(layer_key) - layer_instance.to(device) # Update tracker device + layer_instance.to(device) # Update tracker device if layer_instance == None: self.fail(f"Layer key '{layer_key}' not found in setup.") @@ -134,7 +134,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to increase parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # ADD must strictly increase the count + ) # ADD must strictly increase the count self.assertEqual( layer_instance.get_neurons(attr_name='out_neurons'), initial_nb_out_neurons + len(neuron_indices), @@ -142,7 +142,7 @@ def _test_operation_core( "by 2." + f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # ADD 2 neurons must increase the count by 2 + ) # ADD 2 neurons must increase the count by 2 # --- Incoming --- if len(layer_instance.weight.shape) > 1: @@ -166,7 +166,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to increase parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # ADD must strictly increase the count + ) # ADD must strictly increase the count self.assertEqual( layer_instance.get_neurons(attr_name='in_neurons'), initial_nb_in_neurons + len(neuron_indices), @@ -174,7 +174,7 @@ def _test_operation_core( "by 2." 
+ f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # ADD 2 neurons must increase the count by 2 + ) # ADD 2 neurons must increase the count by 2 elif op == ArchitectureNeuronsOpType.PRUNE: # --- Not Incoming --- @@ -198,7 +198,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # PRUNE must strictly decrease the count + ) # PRUNE must strictly decrease the count self.assertEqual( layer_instance.get_neurons(attr_name='out_neurons'), initial_nb_out_neurons - len(neuron_indices), @@ -206,7 +206,7 @@ def _test_operation_core( "by 2." + f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # PRUNE 2 neurons must decrease the count by 2 + ) # PRUNE 2 neurons must decrease the count by 2 # --- Incoming --- if len(layer_instance.weight.shape) > 1: @@ -230,7 +230,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # PRUNE must strictly decrease the count + ) # PRUNE must strictly decrease the count self.assertEqual( layer_instance.get_neurons(attr_name='in_neurons'), initial_nb_in_neurons - len(neuron_indices), @@ -238,7 +238,7 @@ def _test_operation_core( "by 2." + f"Init:{initial_nb_out_neurons}," + f"Final:{layer_instance.get_neurons(attr_name='out_neurons')}" - ) # PRUNE 2 neurons must decrease the count by 2 + ) # PRUNE 2 neurons must decrease the count by 2 elif op == ArchitectureNeuronsOpType.FREEZE: # --- Not Incoming --- @@ -261,7 +261,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." 
+ f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE must strictly decrease the count + ) # FREEZE must strictly decrease the count # for tensor_name in layer_instance.learnable_tensors_name: # reverse neuron index @@ -300,7 +300,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to unfreeze parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE must match initial count + ) # UNFREEZE must match initial count # # FREEZE & UNFREEZE every neurons # # FREEZE @@ -319,7 +319,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to freeze every params." + f"Init:{layer_instance.get_neurons(attr_name='in_neurons')}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE every out neurons + ) # FREEZE every out neurons # # # UNFREEZE layer_instance.operate( @@ -337,7 +337,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to unfreeze every params." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE every out neurons + ) # UNFREEZE every out neurons # --- Incoming --- if len(layer_instance.weight.shape) > 1: @@ -360,7 +360,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to decrease parameters." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE must strictly decrease the count + ) # FREEZE must strictly decrease the count # for tensor_name in layer_instance.learnable_tensors_name: # reverse neuron index @@ -400,7 +400,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to unfreeze parameters." 
+ f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE must match initial count + ) # UNFREEZE must match initial count # # FREEZE & UNFREEZE every neurons # # FREEZE @@ -419,7 +419,7 @@ def _test_operation_core( f"[{layer_key}/{op_name}] failed to freeze every params." + f"Init:{layer_instance.get_neurons(attr_name='in_neurons')}," + f"Final:{final_nb_trainable_parameters}" - ) # FREEZE every out neurons + ) # FREEZE every out neurons # # # UNFREEZE layer_instance.operate( @@ -438,7 +438,7 @@ def _test_operation_core( "params." + f"Init:{initial_nb_trainable_parameters}," + f"Final:{final_nb_trainable_parameters}" - ) # UNFREEZE every out neurons + ) # UNFREEZE every out neurons elif op == ArchitectureNeuronsOpType.RESET: # RESET must preserve the number of parameters diff --git a/weightslab/tests/test_secure_docker.py b/weightslab/tests/test_secure_docker.py index cee59f9b..d7b9bd30 100644 --- a/weightslab/tests/test_secure_docker.py +++ b/weightslab/tests/test_secure_docker.py @@ -153,7 +153,7 @@ def test_backend_connection_timeout(self): """Test backend connection with timeout.""" result = _test_backend_connection( host='127.0.0.1', - port=59999, # Likely not listening + port=59999, # Likely not listening timeout=0.5 ) assert result is False diff --git a/weightslab/tests/trainer/services/test_agent_prompt_unit.py b/weightslab/tests/trainer/services/test_agent_prompt_unit.py index 96179db8..52941897 100644 --- a/weightslab/tests/trainer/services/test_agent_prompt_unit.py +++ b/weightslab/tests/trainer/services/test_agent_prompt_unit.py @@ -455,7 +455,7 @@ def test_tag_outliers_by_stddev(self): ctx = SimpleNamespace( _all_datasets_df=agent_mod.pd.DataFrame( { - "signals//train_loss": [0.1, 0.15, 0.12, 0.14, 1.5], # Last one is outlier + "signals//train_loss": [0.1, 0.15, 0.12, 0.14, 1.5], # Last one is outlier }, index=agent_mod.pd.MultiIndex.from_tuples( [(f"train", i) for i in range(5)], diff --git 
a/weightslab/tests/trainer/services/test_trainer_services_server.py b/weightslab/tests/trainer/services/test_trainer_services_server.py index 288ba402..61e38134 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_server.py +++ b/weightslab/tests/trainer/services/test_trainer_services_server.py @@ -4,7 +4,7 @@ import weightslab.trainer.trainer_services as trainer_services -# Default per-test timeout in seconds. Override with WL_TEST_TIMEOUT env var. +# Default per-test timeout in seconds. Override with WL_TEST_TIMEOUT env var. import os _TEST_TIMEOUT = int(os.getenv("WL_TEST_TIMEOUT", "30")) diff --git a/weightslab/tests/trainer/services/test_trainer_services_unit.py b/weightslab/tests/trainer/services/test_trainer_services_unit.py index d1454521..1e6b0a6f 100644 --- a/weightslab/tests/trainer/services/test_trainer_services_unit.py +++ b/weightslab/tests/trainer/services/test_trainer_services_unit.py @@ -77,16 +77,23 @@ def test_get_latest_logger_data_queue_mode(self): ) def test_get_latest_logger_data_break_by_slices(self): signal_logger = MagicMock() - # break-by-slices reads compact (sample_id, step, value, hash) tuples via - # query_per_sample (filtered by the tag-derived sample_ids), then aggregates - # the matching samples into a single MEAN curve per experiment_hash. + # break-by-slices aggregates the tag-derived sample_ids into a single MEAN + # curve per experiment_hash via aggregate_per_sample_by_step (DuckDB does the + # GROUP BY step / AVG natively). 
_pts = [("11", 3, 0.3, "exp"), ("12", 3, 0.6, "exp")] - def _qps(graph_name, sample_ids=None, exp_hash=None): + def _agg(graph_name, sample_ids=None, exp_hash=None): wanted = {str(s) for s in sample_ids} if sample_ids is not None else None - return [t for t in _pts if wanted is None or str(t[0]) in wanted] + rows = [t for t in _pts if wanted is None or str(t[0]) in wanted] + by_hash: dict = {} + for sid, step, val, h in rows: + by_hash.setdefault(h, {}).setdefault(step, []).append(val) + return { + h: sorted((s, sum(v) / len(v)) for s, v in steps.items()) + for h, steps in by_hash.items() + } - signal_logger.query_per_sample.side_effect = _qps + signal_logger.aggregate_per_sample_by_step.side_effect = _agg signal_logger.get_evaluation_marker_hashes.return_value = [] df_manager = MagicMock() df_manager.get_df_view.return_value = pd.DataFrame( @@ -108,7 +115,7 @@ def _qps(graph_name, sample_ids=None, exp_hash=None): # Only sample 11 is 'hard'-tagged → mean curve over {11} = one aggregated point. 
self.assertEqual(len(response.points), 1) - self.assertEqual(response.points[0].sample_id, "") # aggregated, not a single sample + self.assertEqual(response.points[0].sample_id, "") # aggregated, not a single sample self.assertEqual(response.points[0].metric_name, "test/loss") self.assertAlmostEqual(response.points[0].metric_value, 0.3, places=5) @@ -148,6 +155,107 @@ def test_restore_checkpoint_weights_step_mode(self): _, kwargs = checkpoint_manager.load_state.call_args self.assertEqual(kwargs.get("target_step"), 5) + def _make_save_service(self, components): + ctx = _DummyCtx(components=components) + with patch("weightslab.trainer.services.experiment_service.DataService"): + return ExperimentService(ctx) + + def test_save_checkpoint_with_optimizer_and_architecture(self): + """Manual save with both optimizer + architecture: pauses, then dumps all three.""" + trainer = MagicMock() + model = MagicMock() + checkpoint_manager = MagicMock() + checkpoint_manager.save_model_checkpoint.return_value = "/tmp/exp/weights_step_000010.pt" + checkpoint_manager.save_model_architecture.return_value = "/tmp/exp/arch.pkl" + hp = {"is_training": True} + + service = self._make_save_service({ + "trainer": trainer, + "model": model, + "checkpoint_manager": checkpoint_manager, + "hyperparams": hp, + }) + + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation( + save_architecture=True, save_optimizer=True + ) + ) + response = service.ExperimentCommand(request, None) + + self.assertTrue(response.success) + # Training is paused BEFORE the dump, and is_training is cleared. + trainer.pause.assert_called_once() + self.assertFalse(hp["is_training"]) + # Weights dumped with optimizer; architecture dumped too. 
+ checkpoint_manager.save_model_checkpoint.assert_called_once() + _, kwargs = checkpoint_manager.save_model_checkpoint.call_args + self.assertTrue(kwargs.get("save_optimizer")) + self.assertIs(kwargs.get("model"), model) + checkpoint_manager.save_model_architecture.assert_called_once_with(model) + self.assertIn("optimizer", response.message) + self.assertIn("architecture", response.message) + + def test_save_checkpoint_weights_only(self): + """Manual save without optimizer/architecture: only weights are dumped.""" + trainer = MagicMock() + model = MagicMock() + checkpoint_manager = MagicMock() + checkpoint_manager.save_model_checkpoint.return_value = "/tmp/exp/weights_step_000010.pt" + hp = {"is_training": True} + + service = self._make_save_service({ + "trainer": trainer, + "model": model, + "checkpoint_manager": checkpoint_manager, + "hyperparams": hp, + }) + + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation( + save_architecture=False, save_optimizer=False + ) + ) + response = service.ExperimentCommand(request, None) + + self.assertTrue(response.success) + trainer.pause.assert_called_once() + checkpoint_manager.save_model_checkpoint.assert_called_once() + _, kwargs = checkpoint_manager.save_model_checkpoint.call_args + self.assertFalse(kwargs.get("save_optimizer")) + # No architecture dump requested. 
+ checkpoint_manager.save_model_architecture.assert_not_called() + self.assertNotIn("optimizer", response.message) + self.assertNotIn("architecture", response.message) + + def test_save_checkpoint_no_model_registered(self): + """No registered model: fails clearly and does NOT pause or dump.""" + trainer = MagicMock() + checkpoint_manager = MagicMock() + + service = self._make_save_service({ + "trainer": trainer, + "model": None, + "checkpoint_manager": checkpoint_manager, + "hyperparams": {}, + }) + + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation( + save_architecture=True, save_optimizer=True + ) + ) + # The ledger fallback must also report no model. + with patch("weightslab.trainer.services.experiment_service.ledgers.get_model", return_value=None): + response = service.ExperimentCommand(request, None) + + self.assertFalse(response.success) + self.assertIn("No model registered", response.message) + # Nothing was dumped, and a running experiment is left untouched. + trainer.pause.assert_not_called() + checkpoint_manager.save_model_checkpoint.assert_not_called() + checkpoint_manager.save_model_architecture.assert_not_called() + class TestModelServiceUnit(unittest.TestCase): def test_get_weights_success(self): @@ -365,15 +473,10 @@ def test_metadata_only_response_uses_dataframe_columns(self): "quality": [0.1, 0.9], } ) - request = type( - "Req", - (), - { - "stats_to_retrieve": ["quality"], - }, - )() - response = service._build_metadata_only_response(df_slice, request) + # _build_metadata_only_response now takes an explicit requested_cols list + # (it is the building block reused by the GetMetaData RPC). 
+ response = service._build_metadata_only_response(df_slice, ["quality"]) self.assertTrue(response.success) self.assertEqual(len(response.data_records), 2) @@ -408,5 +511,108 @@ def test_manual_save_data_state_force_enables_h5_and_flushes(self): self.assertTrue(checkpoint_manager.save_data_snapshot.called) +class TestExploreModeGuards(unittest.TestCase): + """Read-only explore mode: training/HP/weight mutations are refused; data + management and reads stay allowed.""" + + def setUp(self): + from weightslab.backend import explore_mode + self._explore_mode = explore_mode + explore_mode.set_explore_mode(True) + self.addCleanup(lambda: explore_mode.set_explore_mode(False)) + + def _experiment_service(self, components=None): + ctx = _DummyCtx(components=components or {}) + with patch("weightslab.trainer.services.experiment_service.DataService"): + return ExperimentService(ctx) + + def test_explore_mode_off_by_default(self): + # Sanity: the cleanup of other tests must restore the disabled state. 
+ self._explore_mode.set_explore_mode(False) + self.assertFalse(self._explore_mode.is_explore_mode()) + self._explore_mode.set_explore_mode(True) + self.assertTrue(self._explore_mode.is_explore_mode()) + + def test_blocks_hyperparameter_change(self): + service = self._experiment_service() + request = pb2.TrainerCommand( + hyper_parameter_change=pb2.HyperParameterCommand( + hyper_parameters=pb2.HyperParameters(is_training=True) + ) + ) + resp = service.ExperimentCommand(request, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_save_checkpoint(self): + service = self._experiment_service() + request = pb2.TrainerCommand( + save_checkpoint_operation=pb2.SaveCheckpointOperation(save_optimizer=True) + ) + resp = service.ExperimentCommand(request, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_load_checkpoint(self): + service = self._experiment_service() + request = pb2.TrainerCommand( + load_checkpoint_operation=pb2.LoadCheckpointOperation(checkpoint_id=3) + ) + resp = service.ExperimentCommand(request, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_restore_checkpoint(self): + service = self._experiment_service() + resp = service.RestoreCheckpoint( + pb2.RestoreCheckpointRequest(experiment_hash="abc"), None + ) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_trigger_evaluation(self): + service = self._experiment_service() + resp = service.TriggerEvaluation( + pb2.TriggerEvaluationRequest(split_name="test"), None + ) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_blocks_manipulate_weights(self): + ctx = _DummyCtx(components={"model": MagicMock()}) + service = ModelService(ctx) + req = pb2.WeightsOperationRequest() + req.weight_operation.op_type = pb2.WeightOperationType.FREEZE + resp = 
service.ManipulateWeights(req, None) + self.assertFalse(resp.success) + self.assertIn("explore mode", resp.message) + + def test_allows_plot_note_data_annotation(self): + # Data/annotation writes (plot notes) are NOT blocked in explore mode. + signal_logger = MagicMock() + signal_logger.set_point_note.return_value = True + service = self._experiment_service(components={"signal_logger": signal_logger}) + + request = pb2.TrainerCommand( + plot_note_operation=pb2.PlotNoteOperation( + metric_name="train/loss", + experiment_hash="abcdef", + model_age=5, + note="checkpoint of interest", + ) + ) + resp = service.ExperimentCommand(request, None) + self.assertTrue(resp.success) + signal_logger.set_point_note.assert_called_once() + + def test_allows_reads(self): + # Read requests are unaffected by explore mode. + service = self._experiment_service() + resp = service.ExperimentCommand( + pb2.TrainerCommand(get_hyper_parameters=True), None + ) + self.assertTrue(resp.success) + + if __name__ == "__main__": unittest.main() diff --git a/weightslab/tests/watchdog/test_lock_monitor.py b/weightslab/tests/watchdog/test_lock_monitor.py index 29a07bec..279cfcc5 100644 --- a/weightslab/tests/watchdog/test_lock_monitor.py +++ b/weightslab/tests/watchdog/test_lock_monitor.py @@ -90,7 +90,7 @@ class TestMonitoredRLockReentrant(unittest.TestCase): def test_same_thread_can_reacquire(self): lock = MonitoredRLock() lock.acquire() - lock.acquire() # reentrant — must not deadlock + lock.acquire() # reentrant — must not deadlock try: self.assertTrue(lock.is_held()) finally: diff --git a/weightslab/tests/watchdog/test_watchdog.py b/weightslab/tests/watchdog/test_watchdog.py index 1601a3cc..cb915736 100644 --- a/weightslab/tests/watchdog/test_watchdog.py +++ b/weightslab/tests/watchdog/test_watchdog.py @@ -46,7 +46,7 @@ def emit(self, record): log.addHandler(handler) log.setLevel(logging.DEBUG) try: - log.watchdog("hello %s", "world") # type: ignore[attr-defined] + log.watchdog("hello %s", "world") 
# type: ignore[attr-defined] finally: log.removeHandler(handler) @@ -104,11 +104,11 @@ def test_healthy_lock_not_interrupted(self): def quick_holder(): lock.acquire() - time.sleep(0.02) # well below threshold + time.sleep(0.02) # well below threshold lock.release() watchdog = WeighlabsWatchdog( - stuck_threshold_s=5.0, # high threshold — should not fire + stuck_threshold_s=5.0, # high threshold — should not fire poll_interval_s=0.05, ) watchdog.register_lock("safe_lock", lock) @@ -147,16 +147,16 @@ def test_stuck_rpc_triggers_restart_request(self): def test_healthy_rpc_does_not_trigger_restart(self): watchdog = WeighlabsWatchdog( - stuck_threshold_s=5.0, # high threshold + stuck_threshold_s=5.0, # high threshold poll_interval_s=0.05, restart_threshold=1, ) watchdog.start() rpc_id = watchdog.rpc_state.begin("/test/FastMethod") - time.sleep(0.02) # much less than threshold + time.sleep(0.02) # much less than threshold watchdog.rpc_state.end(rpc_id) - time.sleep(0.1) # let watchdog tick + time.sleep(0.1) # let watchdog tick watchdog.stop() self.assertFalse(watchdog.server_manager.should_restart()) @@ -165,14 +165,14 @@ def test_unhealthy_count_resets_on_recovery(self): watchdog = WeighlabsWatchdog( stuck_threshold_s=0.02, poll_interval_s=0.02, - restart_threshold=10, # high — won't restart + restart_threshold=10, # high — won't restart ) watchdog.start() rpc_id = watchdog.rpc_state.begin("/test/SlowThenFast") - time.sleep(0.12) # trigger unhealthy + time.sleep(0.12) # trigger unhealthy watchdog.rpc_state.end(rpc_id) - time.sleep(0.15) # let watchdog see healthy state + time.sleep(0.15) # let watchdog see healthy state watchdog.stop() self.assertEqual(watchdog._unhealthy_count, 0, "unhealthy_count must reset to 0 on recovery") @@ -194,7 +194,7 @@ class TestWatchdogConfigurability(unittest.TestCase): def test_per_lock_timeout_overrides_global_threshold(self): """A per-lock set_timeout() must take precedence over the global threshold.""" lock = MonitoredRLock() - 
lock.set_timeout(0.05) # this lock is allowed only 50ms, regardless of global + lock.set_timeout(0.05) # this lock is allowed only 50ms, regardless of global released = threading.Event() started = threading.Event() @@ -230,7 +230,7 @@ def test_restart_threshold_requires_n_consecutive_unhealthy(self): watchdog = WeighlabsWatchdog(stuck_threshold_s=0.01, poll_interval_s=0.03, restart_threshold=3) watchdog.start() rpc_id = watchdog.rpc_state.begin("/test/SlowMethod") - time.sleep(0.3) # many poll cycles → unhealthy count climbs past 3 + time.sleep(0.3) # many poll cycles → unhealthy count climbs past 3 watchdog.stop() watchdog.rpc_state.end(rpc_id) @@ -250,7 +250,7 @@ def holder(): started.set() try: for _ in range(40): - time.sleep(0.02) # ~0.8s, far over the threshold + time.sleep(0.02) # ~0.8s, far over the threshold except _WatchdogInterrupt: interrupted.set() finally: @@ -263,7 +263,7 @@ def holder(): watchdog = WeighlabsWatchdog(stuck_threshold_s=0.05, poll_interval_s=0.05) watchdog.register_lock("eval_lock", lock) watchdog.start() - time.sleep(0.4) # let several poll cycles run + time.sleep(0.4) # let several poll cycles run watchdog.stop() t.join(timeout=2.0) @@ -294,7 +294,7 @@ def test_dead_worker_with_running_controller_is_marked_error(self): controller = self._FakeController() dead_thread = threading.Thread(target=lambda: None) dead_thread.start() - dead_thread.join() # now not alive + dead_thread.join() # now not alive watchdog = WeighlabsWatchdog(poll_interval_s=0.05) watchdog.register_eval_monitor(lambda: controller, lambda: dead_thread) diff --git a/weightslab/trainer/experiment_context.py b/weightslab/trainer/experiment_context.py index 13a5f83e..52899f06 100644 --- a/weightslab/trainer/experiment_context.py +++ b/weightslab/trainer/experiment_context.py @@ -83,7 +83,7 @@ def ensure_components(self, force: bool = False): try: dnames = list_dataloaders() for dname in dnames: - data_loaders[dname] = get_dataloader(dname) # pre-load to catch errors early 
+ data_loaders[dname] = get_dataloader(dname) # pre-load to catch errors early except Exception: logger.error("Error while listing/resolving dataloaders", exc_info=True) pass @@ -154,7 +154,7 @@ def ensure_components(self, force: bool = False): "checkpoint_manager": checkpoint_manager, "df_manager": df_manager } - self._components.update(data_loaders) # add all dataloaders found + self._components.update(data_loaders) # add all dataloaders found self._last_resolve_time = now # Build hyper-parameter descriptors used by the protocol. Use diff --git a/weightslab/trainer/services/agent/agent.py b/weightslab/trainer/services/agent/agent.py index 715e52f4..05a80621 100644 --- a/weightslab/trainer/services/agent/agent.py +++ b/weightslab/trainer/services/agent/agent.py @@ -194,9 +194,9 @@ def build_op(self, step: AtomicIntent, context: Intent) -> Optional[dict]: pattern = r"(\[\s*['\"])(.*?)(['\"]\s*\])" def replace_col(match): - prefix = match.group(1) # e.g. [' + prefix = match.group(1) # e.g. [' content = match.group(2) # e.g. signals//train_loss - suffix = match.group(3) # e.g. '] + suffix = match.group(3) # e.g. '] resolved = self.agent._resolve_column(content) # Try to resolve the content to a real column @@ -373,7 +373,7 @@ def _setup_schema(self): self._build_column_index() def _load_config(self): - self.preferred_provider = os.environ.get("PREFERRED_PROVIDER", "openrouter") # Default to OpenRouter if API key is provided, otherwise fallback to local Ollama. This can be overridden by config file or env variable. + self.preferred_provider = os.environ.get("PREFERRED_PROVIDER", "openrouter") # Default to OpenRouter if API key is provided, otherwise fallback to local Ollama. This can be overridden by config file or env variable. # Cloud provider settings with sensible defaults. OpenRouter is the default cloud provider if API key is provided. 
self.openrouter_model = os.environ.get("OPENROUTER_MODEL", "meta-llama/llama-3.3-70b-instruct") @@ -382,12 +382,12 @@ def _load_config(self): self.openrouter_request_timeout = float(os.environ.get("OPENROUTER_REQUEST_TIMEOUT", "15.0")) # Local fallback if no cloud (OpenRouter) is available or if the user prefers it. Ollama is the default local provider. - self.fallback_to_local = True # Default to allowing fallback to local Ollama if OpenRouter fails + self.fallback_to_local = True # Default to allowing fallback to local Ollama if OpenRouter fails self.ollama_host = "localhost" self.ollama_port = "11435" self.ollama_model = "llama3.2:3b" - repo_root = Path(__file__).resolve().parents[4] # weightslab/ root + repo_root = Path(__file__).resolve().parents[4] # weightslab/ root inner_pkg = Path(__file__).resolve().parents[3] env_paths = [repo_root / ".env", inner_pkg / ".env"] @@ -574,9 +574,9 @@ def initialize_with_cloud_key(self, api_key: str, provider: str, model: Optional Initialize (or reinitialize) the OpenRouter cloud provider. Args: - api_key: The API key obtained from the provider's website. + api_key: The API key obtained from the provider's website. provider: Must be ``"openrouter"``. - model: OpenRouter model identifier chosen by the user. + model: OpenRouter model identifier chosen by the user. Returns: ``(True, success_message)`` or ``(False, error_message)``. @@ -688,7 +688,7 @@ def _resolve_column(self, user_name: str) -> Optional[str]: # Normalize Input: lowercase, replace spaces AND SLASHES with underscores user_lower = user_name.strip().lower() - user_clean = re.sub(r"[ /_]+", "_", user_lower) # "signals//train_loss" -> "signals_train_loss" + user_clean = re.sub(r"[ /_]+", "_", user_lower) # "signals//train_loss" -> "signals_train_loss" # 1. Exact Match (Fast path) if user_name in self._cols: return user_name @@ -735,7 +735,7 @@ def _build_python_mask(self, conditions: List[Condition], n: Optional[int] = Non # 2. 
Normalize Operator op = cond.op.lower() - if op == "=" or op == "equals": op = "==" # Fix "equals" + if op == "=" or op == "equals": op = "==" # Fix "equals" val = cond.value diff --git a/weightslab/trainer/services/agent_service.py b/weightslab/trainer/services/agent_service.py index 29ad164e..e39998b0 100644 --- a/weightslab/trainer/services/agent_service.py +++ b/weightslab/trainer/services/agent_service.py @@ -4,16 +4,16 @@ gRPC surface for AI-agent lifecycle management. Responsibilities: - - CheckAgentHealth : report whether any LLM provider is ready. - - InitializeAgent : wire up a cloud provider from a user-supplied API key. + - CheckAgentHealth : report whether any LLM provider is ready. + - InitializeAgent : wire up a cloud provider from a user-supplied API key. The actual ``DataManipulationAgent`` instance lives inside ``DataService`` because it requires the live dataframe context (schema, column index, etc.) -that ``DataService`` owns. ``AgentService`` receives a reference to +that ``DataService`` owns. ``AgentService`` receives a reference to ``DataService`` at construction time and delegates to its agent. Wire-up (in ExperimentService): - data_service = DataService(ctx) + data_service = DataService(ctx) agent_service = AgentService(data_service) """ @@ -62,7 +62,7 @@ def CheckAgentHealth(self, request, context): Returns: AgentHealthResponse { available: bool, message: str } - - available=True → "Ready to help you." + - available=True → "Ready to help you." - available=False → "Agent not configured. Type /init to set up." """ available = self._is_available() diff --git a/weightslab/trainer/services/data_image_utils.py b/weightslab/trainer/services/data_image_utils.py index 872e8f3d..81358b32 100644 --- a/weightslab/trainer/services/data_image_utils.py +++ b/weightslab/trainer/services/data_image_utils.py @@ -2,7 +2,7 @@ data_image_utils — Image encoding, mask compression, and proto helpers for gRPC data serving. 
Extracted from data_service.py to keep image-specific logic separate from the -DataService orchestration class. All functions here are pure (stateless) and +DataService orchestration class. All functions here are pure (stateless) and safe to call from any thread. """ @@ -47,8 +47,8 @@ def rle_encode_mask(mask_flat: np.ndarray) -> bytes: ends = np.empty_like(starts) ends[:-1] = starts[1:] ends[-1] = mask_flat.size - lengths = ends - starts # numpy int array - values = mask_flat[starts] # numpy uint8 array + lengths = ends - starts # numpy int array + values = mask_flat[starts] # numpy uint8 array # Split any runs > 65535 into multiple segments out_vals: list[int] = [] diff --git a/weightslab/trainer/services/data_service.py b/weightslab/trainer/services/data_service.py index b12495c3..f30f8831 100755 --- a/weightslab/trainer/services/data_service.py +++ b/weightslab/trainer/services/data_service.py @@ -52,6 +52,34 @@ logger = logging.getLogger(__name__) +# Streamed chunk size for GetPointCloud (raw float32 bytes per gRPC message). +# Larger chunks mean fewer messages but more memory per message. Override with +# the WL_POINT_CLOUD_CHUNK_BYTES env variable (see docs/configuration.rst). 
+_DEFAULT_POINT_CLOUD_CHUNK_BYTES = 1 << 20 # 1 MiB + + +def _point_cloud_chunk_bytes() -> int: + """Read WL_POINT_CLOUD_CHUNK_BYTES; non-positive/invalid falls back to the default.""" + raw = os.getenv("WL_POINT_CLOUD_CHUNK_BYTES") + if raw is None or raw == "": + return _DEFAULT_POINT_CLOUD_CHUNK_BYTES + try: + val = int(raw) + except (TypeError, ValueError): + logger.warning( + "WL_POINT_CLOUD_CHUNK_BYTES=%r is not an integer — using default %d", + raw, _DEFAULT_POINT_CLOUD_CHUNK_BYTES, + ) + return _DEFAULT_POINT_CLOUD_CHUNK_BYTES + if val <= 0: + logger.warning( + "WL_POINT_CLOUD_CHUNK_BYTES=%r must be > 0 — using default %d", + raw, _DEFAULT_POINT_CLOUD_CHUNK_BYTES, + ) + return _DEFAULT_POINT_CLOUD_CHUNK_BYTES + return val + + def normalize_metadata_copy_source_name(source_name: str, experiment_hash: str = None) -> str: """Normalize a source metadata name for deterministic copied-column naming.""" name = str(source_name or "").strip() @@ -231,12 +259,16 @@ def __init__(self, ctx): # rather than queuing to redo the same work. # # Protocol: - # 1. try_acquire(_update_lock, blocking=False) - # → won: clear _update_done, do the update, release, set _update_done - # → lost: _update_done.wait() then return (result already fresh) + # 1. try_acquire(_update_lock, blocking=False) + # → won: clear _update_done, do the update, release, set _update_done + # → lost: _update_done.wait() then return (result already fresh) self._update_lock = threading.Lock() self._update_done = threading.Event() - self._update_done.set() # "done" initially so the very first call proceeds + self._update_done.set() # "done" initially so the very first call proceeds + # Guard so a non-force (reader-triggered) view refresh runs in the BACKGROUND + # at most once at a time — readers never pay the rebuild cost (they read the + # current snapshot; the bg thread swaps in fresh data when ready). 
+ self._refresh_in_flight = threading.Lock() self._df_manager = get_dataframe() # init references to the context components @@ -258,7 +290,7 @@ def __init__(self, ctx): # Check hyperparameters for compute_natural_sort flag (default: False) # Users can enable it by setting compute_natural_sort=True in their hyperparameters. hp = self._ctx.components.get("hyperparams") if self._ctx and self._ctx.components else None - hp_dict = hp.get() if Proxy.is_proxy(hp) else (hp if isinstance(hp, dict) else {}) # is it already a proxy ? + hp_dict = hp.get() if Proxy.is_proxy(hp) else (hp if isinstance(hp, dict) else {}) # is it already a proxy ? self._compute_natural_sort = bool((hp_dict or {}).get("compute_natural_sort", False)) # How per-instance (per-annotation) numeric columns are folded to a single @@ -291,14 +323,14 @@ def __init__(self, ctx): max_workers=8 ) - self._is_filtered = False # Track if the current view is filtered/modified by user + self._is_filtered = False # Track if the current view is filtered/modified by user # logger.info("[DataService] Skipping expensive startup computations (aspect ratio, natural sort, signals).") # These should be triggered on-demand or run in background to avoid blocking training start. 
if self._compute_natural_sort: self._compute_natural_sort_stats() - self._is_filtered = False # Track if the current view is filtered/modified by user + self._is_filtered = False # Track if the current view is filtered/modified by user # ===================================================================== # Preview cache: pre-generate 64×64 or less WebP thumbnails + RLE masks for @@ -321,7 +353,7 @@ def __init__(self, ctx): daemon=True, ).start() else: - self._preview_cache_ready.set() # No preload → mark immediately ready + self._preview_cache_ready.set() # No preload → mark immediately ready logger.info("DataService initialized.") @@ -346,8 +378,8 @@ def _build_preview_cache(self) -> None: """Pre-generate 64×64 or less or less thumbnail + RLE mask for every row in the DF. Each entry is a lightweight ``DataRecord`` containing only: - • raw_data (bytes) — 64×64 or less or less WebP thumbnail - • target (rle_mask) — RLE-encoded GT mask resized to 64×64 or less or less + • raw_data (bytes) — 64×64 or less or less WebP thumbnail + • target (rle_mask) — RLE-encoded GT mask resized to 64×64 or less or less • pred_mask (rle_mask) — RLE-encoded prediction mask resized to 64×64 or less or less • origin, task_type, num_classes, class_names (metadata) Respects ``_preview_cache_max`` to cap memory usage. @@ -363,7 +395,7 @@ def _build_preview_cache(self) -> None: logger.info("[PreviewCache] Building 64×64 or less or less preview cache for %d samples …", total) t0 = time.time() - PREVIEW_SIZE = 64 # fixed low-res dimension + PREVIEW_SIZE = 64 # fixed low-res dimension built = 0 index_names = list(getattr(df.index, "names", []) or []) @@ -762,49 +794,49 @@ def _is_training_active(self) -> bool: return True def _pull_into_all_data_view_df(self): - """Stream stats from the global in-memory dataframe (ledger manager). + """Stream stats from the global in-memory dataframe (ledger manager). - Uses the shared dataframe manager instead of the H5 store and avoids - blocking on IO. 
Falls back to last snapshot if retrieval fails. - """ - try: - # Load dataframe from the shared dataframe manager with arrays autoloaded from h5 storage - df = self._df_manager.get_combined_df() if self._df_manager is not None else pd.DataFrame() - if df.empty: - logger.debug(f"[DataService] Pull returned empty dataframe (manager: {self._df_manager is not None})") - return df - - # The manager now expands samples into one row per (sample_id, annotation_id) - # instance. Collapse back to one row per sample for the sample-centric UI/agent - # view, nesting per-instance signals into a dict column. - df = self._df_manager.get_collapse_annotations_to_samples_df() - - # Ensure sample_id is a column if it was the index - df = safe_reset_index(df) - - # Ensure we have a unique index across all origins by using a MultiIndex (origin, sample_id) - # This is CRITICAL for correctly applying reindex() in _slowUpdateInternals without - # exploding the dataframe size due to duplicate sample_id index labels. - if SampleStatsEx.ORIGIN.value in df.columns: - # Use drop=True to ensure origin is NOT in both index and columns (avoids ambiguity) - # GetDataSamples calls reset_index() before processing rows, which restores them as columns - df = df.set_index([SampleStatsEx.ORIGIN.value, SampleStatsEx.SAMPLE_ID.value], drop=True) - else: - # Fallback to single index if origin is missing, though manager should provide it - df = df.set_index([SampleStatsEx.SAMPLE_ID.value], drop=True) + Uses the shared dataframe manager instead of the H5 store and avoids + blocking on IO. Falls back to last snapshot if retrieval fails. 
+ """ + try: + # Load dataframe from the shared dataframe manager with arrays autoloaded from h5 storage + df = self._df_manager.get_combined_df() if self._df_manager is not None else pd.DataFrame() + if df.empty: + logger.debug(f"[DataService] Pull returned empty dataframe (manager: {self._df_manager is not None})") + return df - # DEDUPLICATE: Ensure index is unique before returning. - # If duplicates exist, reindex() will fail later. - if df.index.has_duplicates: - logger.debug(f"[DataService] Dropping {df.index.duplicated().sum()} duplicate index labels from data view.") - df = df[~df.index.duplicated(keep='last')] + # The manager now expands samples into one row per (sample_id, annotation_id) + # instance. Collapse back to one row per sample for the sample-centric UI/agent + # view, nesting per-instance signals into a dict column. + df = self._df_manager.get_collapse_annotations_to_samples_df() + + # Ensure sample_id is a column if it was the index + df = safe_reset_index(df) + + # Ensure we have a unique index across all origins by using a MultiIndex (origin, sample_id) + # This is CRITICAL for correctly applying reindex() in _slowUpdateInternals without + # exploding the dataframe size due to duplicate sample_id index labels. 
+ if SampleStatsEx.ORIGIN.value in df.columns: + # Use drop=True to ensure origin is NOT in both index and columns (avoids ambiguity) + # GetDataSamples calls reset_index() before processing rows, which restores them as columns + df = df.set_index([SampleStatsEx.ORIGIN.value, SampleStatsEx.SAMPLE_ID.value], drop=True) + else: + # Fallback to single index if origin is missing, though manager should provide it + df = df.set_index([SampleStatsEx.SAMPLE_ID.value], drop=True) - return df - except Exception as e: - logger.debug(f"[DataService] Error pulling data view: {e}") - # Use getattr to safely check for attribute during __init__ - current_df = getattr(self, "_all_datasets_df", None) - return current_df if current_df is not None else pd.DataFrame() + # DEDUPLICATE: Ensure index is unique before returning. + # If duplicates exist, reindex() will fail later. + if df.index.has_duplicates: + logger.debug(f"[DataService] Dropping {df.index.duplicated().sum()} duplicate index labels from data view.") + df = df[~df.index.duplicated(keep='last')] + + return df + except Exception as e: + logger.debug(f"[DataService] Error pulling data view: {e}") + # Use getattr to safely check for attribute during __init__ + current_df = getattr(self, "_all_datasets_df", None) + return current_df if current_df is not None else pd.DataFrame() def _get_origin_filter(self, request): """Extract requested origins if present on request (backward compatible).""" @@ -815,7 +847,7 @@ def _get_origin_filter(self, request): if val: # Normalize to list if isinstance(val, str): - origins = [val] if val.strip() else [] # Filter empty strings + origins = [val] if val.strip() else [] # Filter empty strings else: # Filter out empty strings from list origins = [o for o in list(val) if o and str(o).strip()] @@ -877,7 +909,6 @@ def _is_nan_value(self, value): except (TypeError, ValueError): return False - def _compute_natural_sort_stats(self): """ Compute hardcoded natural sort statistics (brightness, hue, 
saturation, entropy) for all samples @@ -894,33 +925,30 @@ def _compute_natural_sort_stats(self): # 4. "Grouped" (Pseudo-primary key): Brightness=5.0, Entropy=1.0 (Forces clustering by light) SORT_WEIGHTS = { - "brightness": 0.7, # Primary cue: Lighting conditions - "entropy": 0.3, # Secondary cue: Texture/Scene complexity - "hue": 0.0 # Optional: Color tint + "brightness": 0.7, # Primary cue: Lighting conditions + "entropy": 0.3, # Secondary cue: Texture/Scene complexity + "hue": 0.0 # Optional: Color tint } logger.info(f"[DataService] Starting natural sort stats computation with weights: {SORT_WEIGHTS}") - try: - import cv2 - except ImportError: - logger.warning("[DataService] OpenCV not found. Skipping natural sort computation.") - return "OpenCV not installed" - if self._all_datasets_df is None or self._all_datasets_df.empty: return "No data to process" - # Helper: Calculate Shannon Entropy (Complexity) + # Helper: Calculate Shannon Entropy (Complexity). + # 256-bin histogram of the 8-bit grayscale image via numpy (no OpenCV). def calc_entropy(img_gray): try: - # Calculate histogram (256 bins for 8-bit) - hist = cv2.calcHist([img_gray], [0], None, [256], [0, 256]) - # Normalize histogram to get probabilities - p = hist.ravel() / hist.sum() - # Filter out zero probabilities to avoid log(0) + gray_u8 = np.clip(np.rint(img_gray), 0, 255).astype(np.uint8) + counts = np.bincount(gray_u8.ravel(), minlength=256).astype(np.float64) + total = counts.sum() + if total <= 0: + return 0.0 + # Normalize to probabilities, dropping zeros to avoid log(0) + p = counts / total p = p[p > 0] # Shannon Entropy in bits - return -np.sum(p * np.log2(p)) + return float(-np.sum(p * np.log2(p))) except Exception: return 0.0 @@ -961,27 +989,28 @@ def process_sample(args): # Convert to numpy (RGB) img_np = np.array(pil_img) - # Brightness (mean pixel intensity) - # If RGB, convert to Gray, else just mean - if img_np.ndim == 3: - # OpenCV expects BGR usually, but PIL gives RGB. 
- # cvtColor RGB2GRAY is correct. - try: - gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) - except Exception: - gray = img_np + # Brightness (mean pixel intensity). For color images, reduce to + # luma using the ITU-R 601-2 transform — the same weights PIL's + # "L" mode uses — with plain numpy. + if img_np.ndim == 3 and img_np.shape[2] >= 3: + luma = np.array([0.299, 0.587, 0.114], dtype=np.float32) + gray = img_np[..., :3].astype(np.float32) @ luma + elif img_np.ndim == 3: + gray = img_np[..., 0].astype(np.float32) else: - gray = img_np + gray = img_np.astype(np.float32) - brightness = np.mean(gray) + brightness = float(np.mean(gray)) entropy = calc_entropy(gray) - # HSV Stats - if img_np.ndim == 3: + # HSV stats (hue/saturation). Computed with Pillow's "HSV" mode + # (H, S, V each in 0-255) to avoid an OpenCV dependency. + if img_np.ndim == 3 and img_np.shape[2] >= 3: try: - hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV) - hue = np.mean(hsv[:, :, 0]) - saturation = np.mean(hsv[:, :, 1]) + rgb_img = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB") + hsv = np.asarray(rgb_img.convert("HSV")) + hue = float(np.mean(hsv[:, :, 0])) + saturation = float(np.mean(hsv[:, :, 1])) except Exception: hue = 0.0 saturation = 0.0 @@ -997,8 +1026,8 @@ def process_sample(args): # Entropy: 0-8 bits typical for 8-bit image norm_entropy = min(max(entropy / 8.0, 0.0), 1.0) - # Hue: 0-179 in OpenCV - norm_hue = min(max(hue / 179.0, 0.0), 1.0) + # Hue: 0-255 in Pillow's HSV space + norm_hue = min(max(hue / 255.0, 0.0), 1.0) score = ( SORT_WEIGHTS.get("brightness", 0) * norm_brightness + @@ -1121,7 +1150,7 @@ def _process_sample_row(self, args): try: origin = row.get(SampleStatsEx.ORIGIN.value, 'unknown') sample_id = row.get(SampleStatsEx.SAMPLE_ID.value, 0) - # logger.debug(f"Processing sample_id={sample_id} from origin={origin} with request: {request}") + logger.debug(f"Processing sample_id={sample_id} from origin={origin} with request: {request}") # ===== Timing 
accumulators ===== t_image_load = 0.0 @@ -1179,7 +1208,7 @@ def _process_sample_row(self, args): task_type = "unknown" else: # 4. Safe Heuristic evaluation - task_type = "classification" # Default fallback + task_type = "classification" # Default fallback if label is not None: if isinstance(label, dict): if ('boxes' in label or 'bboxes' in label): @@ -1204,58 +1233,14 @@ def _process_sample_row(self, args): except Exception: pass - # ====== Step 5a: Process stats ====== - stats_to_retrieve = list(request.stats_to_retrieve) - - # These columns are handled explicitly later in the pipeline - exclude_cols = { - SampleStatsEx.SAMPLE_ID.value, - SampleStatsEx.ORIGIN.value, - SampleStatsEx.TARGET.value if not skip_label_for_request else None, - SampleStatsEx.PREDICTION.value, - SampleStatsEx.TASK_TYPE.value, - '_instance_signals', # Special handling for multi-instance signals - 'annotation_id', # Internal multi-index tracking - } - - if not stats_to_retrieve: - stats_to_retrieve = [col for col in df_columns if col not in exclude_cols] - - # Optimized bulk processing of stats - for stat_name in stats_to_retrieve: - # Never re-process core fields generically (prevents duplicates/bad db state overwriting calculated state) - if stat_name in exclude_cols: - continue - - value = row.get(stat_name) + # ====== Step 5a: Metadata stats — moved to GetMetaData ====== + # Generic dataframe metadata columns (signals, tags, custom fields, etc.) + # are no longer returned by GetDataSamples; the dedicated GetMetaData RPC + # serves them. GetDataSamples returns only the rendering flags + # origin / task_type / discarded (needed for the split border, overlay + # mode and gray-out) plus image / label / prediction data below. 
- # Skip prediction raw array - if (isinstance(value, np.ndarray) and value.ndim > 1) or (isinstance(value, (list, tuple, np.ndarray)) and len(value) == 0): - continue - elif isinstance(value, float): - value = round(value, 7) - elif isinstance(value, bool): - value = int(value) - - # Check if it s a tag column here and handle it as a string stat with the tag name as value - value_string = str(value) - if stat_name.startswith(f"{SampleStatsEx.TAG.value}"): - tag_name = stat_name[len(f"{SampleStatsEx.TAG.value}:"):] # Remove "tags_" prefix to get tag name - data_stats.append( - create_data_stat( - f"{SampleStatsEx.TAG.value}:{tag_name}", - "string", - shape=[1], - value_string=value_string, - thumbnail=b"" - ) - ) - else: - data_stats.append( - create_data_stat(stat_name, "string", shape=[1], value_string=value_string[:512], thumbnail=b"") - ) - - # ====== Step 6: Add origin and task_type stats ====== + # ====== Step 6: Add origin, task_type and discarded rendering flags ====== data_stats.append( create_data_stat( "origin", 'string', shape=[1], value_string=origin, thumbnail=b"" @@ -1266,6 +1251,18 @@ def _process_sample_row(self, args): "task_type", 'string', shape=[1], value_string=str(task_type), thumbnail=b"" ) ) + # 'discarded' drives the grayed-out cell rendering, so it rides with the + # image data as "1"/"0" (not treated as analytical metadata). This keeps + # the gray-out reliable on every grid (re)fetch / scroll. 
+ try: + _discarded_str = "1" if bool(row.get(SampleStatsEx.DISCARDED.value)) else "0" + except Exception: + _discarded_str = "0" + data_stats.append( + create_data_stat( + SampleStatsEx.DISCARDED.value, 'string', shape=[1], value_string=_discarded_str, thumbnail=b"" + ) + ) target_mask_stat_index = None pred_mask_stat_index = None @@ -1403,7 +1400,7 @@ def _process_sample_row(self, args): max_id = int(label_arr.max()) num_classes = max(1, max_id) + 1 else: - num_classes = 2 # Always at least 2 classes for segmentation (foreground/background) + num_classes = 2 # Always at least 2 classes for segmentation (foreground/background) data_stats.append( create_data_stat( @@ -1468,7 +1465,7 @@ def _process_sample_row(self, args): else: # Check if label is NaN (handle both scalars and arrays) if self._is_nan_value(label): - pass # Skip NaN labels + pass # Skip NaN labels # Handle scalar labels try: @@ -1617,7 +1614,7 @@ def _process_sample_row(self, args): else: # Classification: get prediction from row or dataset if pred is None: - pass # No prediction to process + pass # No prediction to process else: # Handle scalar predictions (int, float, or unwrapped from H5) @@ -1699,7 +1696,7 @@ def _process_sample_row(self, args): target_width = w_limit target_height = int(target_width / aspect_ratio) elif request.resize_width == 0 and request.resize_height == 0: - target_height = int(os.environ.get("WL_DEFAULT_THUMBNAIL_SIZE", 180)) # Default full resolution image is 360p on the longest side, but can be overridden by env var + target_height = int(os.environ.get("WL_DEFAULT_THUMBNAIL_SIZE", 180)) # Default full resolution image is 360p on the longest side, but can be overridden by env var target_width = int(target_height * aspect_ratio) if is_full_resolution: @@ -1877,7 +1874,7 @@ def _build_success_response( """ total_count = len(df) discarded_count = ( - len(df[df.get(SampleStatsEx.DISCARDED.value, False) == True]) # noqa: E712 + len(df[df.get(SampleStatsEx.DISCARDED.value, 
False) == True]) # noqa: E712 if df is not None and SampleStatsEx.DISCARDED.value in df.columns else 0 ) @@ -2222,7 +2219,7 @@ def _apply_agent_operation(self, df, func: str, params: dict) -> str: # We must split the updates by origin and upsert them to the manager if self._df_manager is not None: # Create a minimal update dataframe with just the modified column - update_payload = df[[col]] # .copy() # Remove copy because memory waste and slowdown + update_payload = df[[col]] # .copy() # Remove copy because memory waste and slowdown # Ensure origin is available for grouping if isinstance(df.index, pd.MultiIndex) and "origin" in df.index.names: @@ -2271,7 +2268,7 @@ def _apply_agent_operation(self, df, func: str, params: dict) -> str: if start < len(df): logger.debug(f"[sort_view_slice] Sorting slice {start}:{end}") # Extract and sort slice - sub_df = df.iloc[start:end] # .copy() # Remove copy because memory waste and slowdown + sub_df = df.iloc[start:end] # .copy() # Remove copy because memory waste and slowdown # Apply sort to slice # Filter params for sort_values @@ -2331,12 +2328,12 @@ def _apply_agent_operation(self, df, func: str, params: dict) -> str: # otherwise Sample ID X will point to data from Sample ID Y (corruption). 
try: if isinstance(df.index, pd.MultiIndex): - new_index_values = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown + new_index_values = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown new_index_values[start:end] = sub_df.index.to_numpy() df.index = pd.MultiIndex.from_tuples(new_index_values, names=df.index.names) else: idx_name = df.index.name - new_index = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown + new_index = df.index.to_numpy() # .copy() # Remove copy because memory waste and slowdown new_index[start:end] = sub_df.index.to_numpy() df.index = pd.Index(new_index, name=idx_name) except Exception as e: @@ -2466,7 +2463,7 @@ def _restore_index(): # Lock watchdog helpers # ------------------------------------------------------------------ # ------------------------------------------------------------------ - # Lock watchdog helpers (build on MonitoredRLock from watchdog/) + # Lock watchdog helpers (build on MonitoredRLock from watchdog/) # ------------------------------------------------------------------ @staticmethod def _lock_caller() -> str: @@ -2503,7 +2500,7 @@ def _watched_lock(self, lock_name: str = "_lock"): self._lock.acquire() waited_ms = (time.time() - t0) * 1000 logger.debug( - "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", + "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", lock_name, thread, caller, waited_ms, ) t_held = time.time() @@ -2514,18 +2511,29 @@ def _watched_lock(self, lock_name: str = "_lock"): self._lock.release() if held_ms > 1000: logger.warning( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", lock_name, thread, held_ms, ) else: logger.debug( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", lock_name, thread, held_ms, ) # 
------------------------------------------------------------------ # Main update method # ------------------------------------------------------------------ + def _bg_view_refresh(self) -> None: + """Background view rebuild for reader-triggered (non-force) refreshes. Runs the + real rebuild+swap via force=True OFF the request path, then releases the guard so + a later stale read can trigger another. Never raises into a request.""" + try: + self._slowUpdateInternals(force=True) + except Exception: + logger.exception("[ViewRefresh] background view refresh failed") + finally: + self._refresh_in_flight.release() + def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> None: """Update the internal dataframe view with the latest data from the manager. @@ -2555,12 +2563,29 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> current_time - self._last_internals_update_time <= 10: return - # --- Try to become the single updater --- + # --- Non-force (reader-triggered) refresh: run it in the BACKGROUND --- + # The view is stale, but a reader (grid/histogram/periodic fetch) must NOT block + # on the multi-second collapse+combine rebuild. Kick a single background refresh + # (the WL-ViewRefresh thread calls force=True, which does the real rebuild+atomic + # swap below) and return immediately — the caller reads the current (last-completed) + # snapshot. If a refresh is already running, just return; the next fetch sees the swap. 
+ if not force: + if self._refresh_in_flight.acquire(blocking=False): + try: + threading.Thread( + target=self._bg_view_refresh, name="WL-ViewRefresh", daemon=True + ).start() + except Exception: + self._refresh_in_flight.release() # never leak the guard + logger.exception("[ViewRefresh] failed to start background refresh") + return + + # --- Try to become the single updater (force path: rebuild inline) --- t_wait_start = time.time() acquired = self._update_lock.acquire(blocking=False) if not acquired: - # Another worker is already updating. Wait for it to finish (bounded), + # Another worker is already updating. Wait for it to finish (bounded), # then return — the caller will read the already-refreshed view. thread = threading.current_thread().name logger.debug( @@ -2576,7 +2601,7 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> thread = threading.current_thread().name caller = self._lock_caller() logger.debug( - "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", + "[LockWatchdog] %-36s ACQUIRED by %-30s caller=%s (waited %.1f ms)", "_update_lock[_slowUpdateInternals]", thread, caller, waited_ms, ) # Signal to latecomers that an update is now in progress. @@ -2685,12 +2710,12 @@ def _slowUpdateInternals(self, force: bool = False, reset_view: bool = False) -> self._update_lock.release() if held_ms > 1000: logger.warning( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms ← SLOW", "_update_lock[_slowUpdateInternals]", threading.current_thread().name, held_ms, ) else: logger.debug( - "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", + "[LockWatchdog] %-36s RELEASED by %-30s held %.1f ms", "_update_lock[_slowUpdateInternals]", threading.current_thread().name, held_ms, ) # Unblock all workers that were waiting on this update. 
@@ -2716,12 +2741,13 @@ def _is_metadata_only_request(self, request) -> bool: except Exception: return False - def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): - """Build DataSamplesResponse from dataframe columns only (no dataset/image traversal). + def _build_metadata_only_response(self, df_slice: pd.DataFrame, requested_cols=None): + """Build a DataSamplesResponse of metadata DataRecords from dataframe columns only. - This is a single-job vectorized path: the entire df_slice is processed - at once using pandas operations rather than dispatching per-sample_id - work to the thread pool. + No dataset/image traversal: the entire df_slice is processed at once using + vectorized pandas operations rather than dispatching per-sample_id work to + the thread pool. ``requested_cols`` restricts the columns; when None/empty + all columns are returned except heavy per-sample blobs. Used by GetMetaData. """ if df_slice is None or df_slice.empty: return pb2.DataSamplesResponse( @@ -2730,7 +2756,7 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): data_records=[], ) - requested_cols = list(getattr(request, 'stats_to_retrieve', []) or []) + requested_cols = list(requested_cols or []) # NOTE: ORIGIN is intentionally NOT excluded. The histogram (and any caller # that needs per-sample split coloring) requests 'origin' explicitly and # relies on this fast vectorized path to return it — without this, the client @@ -2769,7 +2795,7 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): # -- Vectorized pre-processing: build string matrices via pandas ------ # Separate tag columns from regular metadata columns for different - # handling. All heavy conversion is done once on the full column + # handling. All heavy conversion is done once on the full column # vectors, not per-row. 
tag_cols = [c for c in metadata_cols if c.startswith(tag_prefix)] meta_cols = [c for c in metadata_cols if not c.startswith(tag_prefix)] @@ -2783,15 +2809,20 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): # -- Column-wise DataStat construction -------------------------------- # Build all DataStat objects for one column at a time using list # comprehensions (CPython fast-path) and inline pb2.DataStat() to - # eliminate the create_data_stat wrapper overhead. Then scatter - # them into the per-row bins. At 1M rows × 10 cols this avoids + # eliminate the create_data_stat wrapper overhead. Then scatter + # them into the per-row bins. At 1M rows × 10 cols this avoids # a 10M-iteration nested Python loop. - _DataStat = pb2.DataStat # local ref – avoids repeated attr lookup + _DataStat = pb2.DataStat # local ref – avoids repeated attr lookup for col in meta_cols: series = df_slice[col] if series.dtype.kind == 'f': str_vals = series.round(7).astype(str).str[:512].tolist() + elif series.dtype.kind == 'b': + # Booleans (e.g. 'discarded') → "1"/"0" so the UI's boolean/discarded + # handling — which expects the legacy per-sample "1"/"0" form — keeps + # working now that metadata is served exclusively by GetMetaData. + str_vals = series.astype(int).astype(str).tolist() else: str_vals = series.astype(str).str[:512].tolist() # NaN → None @@ -2841,6 +2872,132 @@ def _build_metadata_only_response(self, df_slice: pd.DataFrame, request): data_records=data_records, ) + def _get_all_metadata_column_names(self) -> list: + """Return every metadata column name available across the WHOLE dataset. + + Excludes heavy per-sample blob columns (pred/target) and internal + bookkeeping columns, matching the column set _build_metadata_only_response + emits. Order follows the dataframe columns (de-duplicated) so the UI column + picker stays stable across refreshes. 
+ """ + try: + df = self._all_datasets_df + if df is None or df.empty: + return [] + _HEAVY_BLOB_COLS = {"pred", "prediction", "prediction_raw", "target"} + _INTERNAL_COLS = { + SampleStatsEx.SAMPLE_ID.value, + SampleStatsEx.TASK_TYPE.value, + "annotation_id", + "_instance_signals", + } + seen, names = set(), [] + for col in df.columns: + name = str(col) + if col in _HEAVY_BLOB_COLS or col in _INTERNAL_COLS: + continue + if name in seen: + continue + seen.add(name) + names.append(name) + # Include index-level names too (e.g. 'origin' in the (origin, sample_id) + # multi-index), excluding internal levels like sample_id/annotation_id. + if isinstance(df.index, pd.MultiIndex): + index_names = [n for n in (df.index.names or []) if n] + elif df.index.name: + index_names = [df.index.name] + else: + index_names = [] + for n in index_names: + name = str(n) + if n in _HEAVY_BLOB_COLS or n in _INTERNAL_COLS or name in seen: + continue + seen.add(name) + names.append(name) + return names + except Exception as e: + logger.warning("Error enumerating metadata column names: %s", e) + return [] + + def GetMetaData(self, request, context): + """Metadata-only retrieval, separated from GetDataSamples. + + Returns: + - all_metadata_names: every metadata column for the WHOLE dataset + - grid_records: per-sample metadata for the requested grid slice + - modal_record: metadata for the open modal sample (by sample_id), if any + """ + try: + # Read the current view directly (kept fresh by the same mechanisms + # GetDataSamples relies on); no forced refresh on the 15s metadata poll. 
+ all_names = self._get_all_metadata_column_names() + df = self._all_datasets_df + + if df is None or df.empty: + return pb2.GetMetaDataResponse( + success=False, + message="Internal dataframe is empty or not initialized.", + all_metadata_names=all_names, + grid_records=[], + ) + + # ---- Grid slice metadata (current view order) ---- + grid_records = [] + start = max(0, int(getattr(request, "start_index", 0))) + count = int(getattr(request, "records_cnt", 0)) + if count > 0: + try: + df_slice = safe_reset_index(df.iloc[start:start + count]) + except IndexError: + df_slice = None + if df_slice is not None and not df_slice.empty: + df_slice, _ = self._merge_multi_instance_signals(df_slice) + grid_resp = self._build_metadata_only_response(df_slice) + if grid_resp.success: + grid_records = list(grid_resp.data_records) + + # ---- Modal sample metadata (optional, by sample_id) ---- + modal_record = None + modal_id = str(getattr(request, "modal_sample_id", "") or "").strip() + if modal_id: + try: + sid_col = SampleStatsEx.SAMPLE_ID.value + matches = None + if sid_col in df.columns: + matches = df[df[sid_col].astype(str) == modal_id] + elif isinstance(df.index, pd.MultiIndex) and sid_col in (df.index.names or []): + # sample_id is a multi-index level (origin, sample_id). 
+ level_vals = df.index.get_level_values(sid_col).astype(str) + matches = df[level_vals == modal_id] + else: + matches = df[df.index.astype(str) == modal_id] + if matches is not None and not matches.empty: + modal_df = safe_reset_index(matches.iloc[[0]]) + modal_df, _ = self._merge_multi_instance_signals(modal_df) + modal_resp = self._build_metadata_only_response(modal_df) + if modal_resp.success and modal_resp.data_records: + modal_record = modal_resp.data_records[0] + except Exception as e: + logger.warning("GetMetaData modal lookup failed for %s: %s", modal_id, e) + + resp = pb2.GetMetaDataResponse( + success=True, + message=f"Retrieved {len(grid_records)} metadata records, {len(all_names)} columns", + all_metadata_names=all_names, + grid_records=grid_records, + ) + if modal_record is not None: + resp.modal_record.CopyFrom(modal_record) + return resp + except Exception as e: + logger.error("Error in GetMetaData: %s", str(e), exc_info=True) + return pb2.GetMetaDataResponse( + success=False, + message=f"Failed to retrieve metadata: {str(e)}", + all_metadata_names=[], + grid_records=[], + ) + def _merge_multi_instance_signals(self, df_slice): """Merge per-instance signals into dictionaries for multi-index dataframes. @@ -2922,14 +3079,14 @@ def _process_get_data_samples(self, request, context): resolution (both dims ≤ ``_PREVIEW_CACHE_THRESHOLD``), serve from the cache instantly without touching the file system. 2. **Parallel batch processing** – All samples are submitted to the - thread pool at once so all 8 workers stay busy. The chunk-size + thread pool at once so all 8 workers stay busy. The chunk-size env-var ``WL_BATCH_CHUNK_SIZE`` is kept for backward compat but the default is now the full request size (all at once). """ - _PREVIEW_CACHE_THRESHOLD = 80 # max px to consider a "preview" request + _PREVIEW_CACHE_THRESHOLD = 80 # max px to consider a "preview" request # Default: process ALL rows at once in the thread pool (workers = 8). 
# Override with WL_BATCH_CHUNK_SIZE to throttle concurrency. - _BATCH_CHUNK_SIZE = int(os.environ.get("WL_BATCH_CHUNK_SIZE", "0")) # 0 = all at once + _BATCH_CHUNK_SIZE = int(os.environ.get("WL_BATCH_CHUNK_SIZE", "0")) # 0 = all at once try: start_time = time.time() @@ -2988,8 +3145,9 @@ def _process_get_data_samples(self, request, context): logger.debug( "Retrieving samples from %s to %s", request.start_index, request.start_index + request.records_cnt) - if self._is_metadata_only_request(request): - return self._build_metadata_only_response(df_slice, request) + # NOTE: metadata-only requests are no longer served here. GetDataSamples + # returns image / label / prediction data only; metadata columns are + # served by the dedicated GetMetaData RPC. # ---- Preview-cache tolerant path ------------------------------- # For preview-tier requests, serve what is available from cache @@ -3095,7 +3253,7 @@ def _process_get_data_samples(self, request, context): # ---- Parallel batch processing --------------------------------- # Submit ALL rows to the thread pool at once so all 8 workers - # stay busy. This avoids the old sequential-chunk bottleneck + # stay busy. This avoids the old sequential-chunk bottleneck # where each sub-batch had to finish before the next started. 
data_records: list = [] rows_list = list(df_slice.iterrows()) @@ -3161,7 +3319,7 @@ def _calculate_tag_column_updates(self, sample_id: int, origin: str, new_tag_nam new_tag_name = f'{SampleStatsEx.TAG.value}:{stripped_tag_name}' # Get current tags from the in-memory dataframe or df_manager - existing_tag_value = True # Default to True for new tags + existing_tag_value = True # Default to True for new tags try: if self._all_datasets_df is not None: # Read current tag columns from in-memory dataframe @@ -3177,7 +3335,7 @@ def _calculate_tag_column_updates(self, sample_id: int, origin: str, new_tag_nam if row is not None: for col in row.index: - if col == new_tag_name and row[col]: # If existing, revert the value + if col == new_tag_name and row[col]: # If existing, revert the value existing_tag_value = bool(1 - row[col]) except (KeyError, AttributeError) as e: @@ -3185,7 +3343,7 @@ def _calculate_tag_column_updates(self, sample_id: int, origin: str, new_tag_nam # Calculate target tags based on edit type if edit_type == SampleEditType.EDIT_REMOVE: - existing_tag_value = False # For removal, we set the tag to False + existing_tag_value = False # For removal, we set the tag to False target_tags_set = self._parse_tags(new_tag_name) else: # Override: replace all tags with the new value @@ -3227,8 +3385,8 @@ def ApplyDataQuery(self, request, context): Apply a query on the in-memory dataframe. 
Modes: - - request.query == "" -> just return counts, do not modify df - - request.query != "" -> always handled by the agent (natural language path) + - request.query == "" -> just return counts, do not modify df + - request.query != "" -> always handled by the agent (natural language path) Counts returned: - number_of_all_samples: all rows currently in the dataframe @@ -3262,14 +3420,20 @@ def ApplyDataQuery(self, request, context): # Apply operations with lock with self._watched_lock("_lock[ApplyDataQuery/ops]"): - # If this is JUST a view-sort for a specific page, DO NOT force an internal refresh - # as that wipes out the existing slice/sort state before we try to modify the next slice. - is_only_view_sort = len(operations) == 1 and operations[0].get("function") == "df.sort_view_slice" - if not is_only_view_sort: - self._slowUpdateInternals(force=True) # Refresh internals before applying Agent operations + # Skip the forced full-view rebuild for SORT-ONLY operations. Sorting just + # re-orders the existing snapshot, so a fresh collapse+combine (hundreds of + # ms on large views, and — being lock-held — contends with the training + # thread for multi-second stalls) is unnecessary. Filters/edits still refresh + # so they operate on the latest data. The view is frozen on direct queries + # anyway (_is_filtered=True), so it wasn't auto-refreshing mid-sort regardless. 
+ _SORT_FUNCS = {"df.sort_values", "df.sort_index", "df.sort_view_slice"} + is_sort_only = bool(operations) and all( + op.get("function") in _SORT_FUNCS for op in operations) + if not is_sort_only: + self._slowUpdateInternals(force=True) # Refresh internals before applying non-sort operations # Work on a copy to allow concurrent readers to see a consistent state - df = self._all_datasets_df # Remove copy because memory waste and slowdown + df = self._all_datasets_df # Remove copy because memory waste and slowdown messages = [] for op in operations: @@ -3300,7 +3464,7 @@ def ApplyDataQuery(self, request, context): logger.info(f"[ApplyDataQuery] BYPASSING AGENT - Direct DataFrame operation: {request.query[:100]}...") with self._lock: - self._all_datasets_df, message = execute_df_operation(self._all_datasets_df, request.query) # in-place operation, or replace previous dataframe + self._all_datasets_df, message = execute_df_operation(self._all_datasets_df, request.query) # in-place operation, or replace previous dataframe logger.info(f"[ApplyDataQuery] Executed direct DataFrame operation. 
Message: {message}") if operations: @@ -3316,8 +3480,8 @@ def ApplyDataQuery(self, request, context): logger.info(f"[ApplyDataQuery] BYPASSING AGENT - Direct reset/clear operation: {request.query[:100]}...") # Force view reset with self._lock: - self._is_filtered = False # Unfreeze view first - self._slowUpdateInternals(force=True) # Force update to ensure we have the latest data + self._is_filtered = False # Unfreeze view first + self._slowUpdateInternals(force=True) # Force update to ensure we have the latest data logger.info(f"[ApplyDataQuery] Force view reset and unfrozen.") return pb2.DataQueryResponse( @@ -3367,7 +3531,7 @@ def status_cb(msg: str): if self._all_datasets_df is None: self._all_datasets_df = self._pull_into_all_data_view_df() or pd.DataFrame() - df = self._all_datasets_df # .copy() # Remove copy because memory waste and slowdown + df = self._all_datasets_df # .copy() # Remove copy because memory waste and slowdown messages = [] intent_type = pb2.INTENT_FILTER analysis_result = "" @@ -3440,8 +3604,74 @@ def GetDataSamples(self, request, context): data_records=[] ) - # Streamed chunk size for GetPointCloud (raw float32 bytes per message). - _POINT_CLOUD_CHUNK_BYTES = 1 << 20 # 1 MiB + def GetHistogram(self, request, context): + """Server-side histogram binning of one column (typed RPC). + + Bins the current all-data view by ROW ORDER into <= max_bins equal- + population bins; each bin carries {min,max,avg,count} over its finite + values plus a per-(origin,discarded) sub-bar breakdown. Returns typed + HistogramBin messages (no DataStat name-encoding). Empty bins are emitted + with count=0 and NaN stats so the client's positional bars stay aligned. 
+ """ + try: + column = request.column or "" + max_bins = int(request.max_bins) if request.max_bins > 0 else 512 + df = getattr(self, "_all_datasets_df", None) + if df is None or df.empty: + return pb2.HistogramResponse( + success=False, message="empty dataframe view", total_rows=0, bins=[]) + df = safe_reset_index(df) + n = len(df) + if column not in df.columns: + return pb2.HistogramResponse( + success=False, message=f"column '{column}' not in view", + total_rows=n, bins=[]) + + bars = max(1, min(n, max_bins)) + vals = pd.to_numeric(df[column], errors="coerce").to_numpy() + origin = (df["origin"].astype(str).to_numpy() if "origin" in df.columns + else np.full(n, "")) + disc = (df["discarded"].astype(bool).to_numpy() if "discarded" in df.columns + else np.zeros(n, bool)) + edges = (np.arange(bars + 1) * n) // bars + bin_of_row = np.searchsorted(edges, np.arange(n), side="right") - 1 + fin = np.isfinite(vals) + gf = pd.DataFrame({"b": bin_of_row[fin], "v": vals[fin], + "o": origin[fin], "d": disc[fin]}) + stats = gf.groupby("b")["v"].agg(["min", "max", "mean", "count"]) + sub_by_bin = {} + for (b, d, o), c in gf.groupby(["b", "d", "o"]).size().items(): + sub_by_bin.setdefault(int(b), []).append( + pb2.HistogramSubBar(origin=str(o), discarded=bool(d), count=int(c))) + have = stats.index.to_numpy() + mn, mx, av, cn = (stats["min"].to_numpy(), stats["max"].to_numpy(), + stats["mean"].to_numpy(), stats["count"].to_numpy()) + pos = {int(b): i for i, b in enumerate(have)} + _nan = float("nan") + bins = [] + for b in range(bars): + i = pos.get(b) + if i is None: + bins.append(pb2.HistogramBin( + min=_nan, max=_nan, avg=_nan, count=0, sub_bars=[])) + else: + bins.append(pb2.HistogramBin( + min=float(mn[i]), max=float(mx[i]), avg=float(av[i]), + count=int(cn[i]), sub_bars=sub_by_bin.get(b, []))) + logger.info("[HistBin] column=%s rows=%d bins=%d", column, n, len(bins)) + return pb2.HistogramResponse( + success=True, + message=f"histogram {column}: {len(bins)} bins from {n} 
rows", + total_rows=n, bins=bins) + except Exception as e: + logger.error("Error in GetHistogram: %s", str(e), exc_info=True) + return pb2.HistogramResponse( + success=False, message=f"histogram failed: {str(e)}", + total_rows=0, bins=[]) + + # Streamed chunk size for GetPointCloud (raw float32 bytes per message), + # configurable via the WL_POINT_CLOUD_CHUNK_BYTES env var (default 1 MiB). + _POINT_CLOUD_CHUNK_BYTES = _point_cloud_chunk_bytes() def GetPointCloud(self, request, context): """Stream one sample's raw point cloud as binary float32 chunks. diff --git a/weightslab/trainer/services/experiment_service.py b/weightslab/trainer/services/experiment_service.py index 436997b5..1ad86ff4 100644 --- a/weightslab/trainer/services/experiment_service.py +++ b/weightslab/trainer/services/experiment_service.py @@ -14,6 +14,7 @@ from weightslab.backend.ledgers import set_hyperparam, list_hyperparams, resolve_hp_name, get_hyperparams from weightslab.backend import ledgers from weightslab.backend.audit_logger import AuditLogger +from weightslab.backend.explore_mode import is_explore_mode, EXPLORE_BLOCKED_MESSAGE from weightslab.trainer.services.model_service import ModelService from weightslab.trainer.services.data_service import DataService from weightslab.trainer.services.agent_service import AgentService @@ -176,7 +177,7 @@ def _get_latest_logger_data_impl(self, request, context): self._ctx.ensure_components() components = self._ctx.components signal_logger = components.get("signal_logger") - if signal_logger == None: + if signal_logger == None: return pb2.GetLatestLoggerDataResponse(points=[]) # Drop the request early if the client already disconnected @@ -238,7 +239,7 @@ def _get_latest_logger_data_impl(self, request, context): # WL_MAX_POINTS_PER_SAMPLE bounds points per returned curve (endpoints kept). 
max_points = _max_points_per_sample() for exp_hash, series in per_hash.items(): - series.sort(key=lambda sv: sv[0]) # order by step (model age) + series.sort(key=lambda sv: sv[0]) # order by step (model age) series = _downsample_uniform(series, max_points) audit = exp_hash in eval_hashes for step, mean_val in series: @@ -249,7 +250,7 @@ def _get_latest_logger_data_impl(self, request, context): metric_value=mean_val, experiment_hash=exp_hash, timestamp=now, - sample_id="", # aggregated mean curve — not a single sample + sample_id="", # aggregated mean curve — not a single sample audit_mode=audit, ) ) @@ -320,7 +321,7 @@ def _get_latest_logger_data_impl(self, request, context): metric_value=s.get("metric_value", 0.0), experiment_hash=s.get("experiment_hash", "N.A."), timestamp=int(s.get("timestamp", time.time())), - sample_id="", # No sample_id in aggregated mode + sample_id="", # No sample_id in aggregated mode is_evaluation_marker=bool(s.get("is_evaluation_marker", False)), split_name=str(s.get("split_name", "")), evaluation_tags=[str(tag) for tag in s.get("evaluation_tags", []) or []], @@ -346,7 +347,7 @@ def _get_latest_logger_data_impl(self, request, context): metric_value=s.get("metric_value", 0.0), experiment_hash=s.get("experiment_hash", "N.A."), timestamp=int(s.get("timestamp", time.time())), - sample_id="", # No sample_id in queue mode + sample_id="", # No sample_id in queue mode is_evaluation_marker=bool(s.get("is_evaluation_marker", False)), split_name=str(s.get("split_name", "")), evaluation_tags=[str(tag) for tag in s.get("evaluation_tags", []) or []], @@ -364,6 +365,8 @@ def RestoreCheckpoint(self, request, context): - Calls checkpoint manager to load the state - Returns success flag and message """ + if is_explore_mode(): + return pb2.RestoreCheckpointResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) try: raw_experiment_hash = request.experiment_hash experiment_hash = raw_experiment_hash @@ -417,11 +420,11 @@ def RestoreCheckpoint(self, 
request, context): load_weights=True, load_config=True, load_data=True, - load_logger=False, # Don't load logger for weights-only restore to avoid overwriting signals, + load_logger=False, # Don't load logger for weights-only restore to avoid overwriting signals, target_step=target_step, ) else: - success = checkpoint_manager.load_state(experiment_hash, load_logger=False) # Don't load logger for full restore to avoid overwriting signals already in memory + success = checkpoint_manager.load_state(experiment_hash, load_logger=False) # Don't load logger for full restore to avoid overwriting signals already in memory # Reply if success: @@ -477,6 +480,10 @@ def TriggerEvaluation(self, request, context): This stores the evaluation request in the global eval_controller. The actual pass runs in the training thread via ``run_pending_evaluation()``. """ + # No training thread runs in explore mode, so evaluation can't execute. + if is_explore_mode(): + return pb2.TriggerEvaluationResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) + split_name = request.split_name or "" tags = list(request.tags) if request.tags else [] use_full_set = bool(request.use_full_set) @@ -596,13 +603,158 @@ def _get_live_hyper_parameter_descs(self, components): ) return hyper_parameter_descs + def _handle_save_checkpoint(self, save_op, components, context): + """Pause training and force-dump the current model weights. + + Implements the manual "Save weights" button. Training is paused *first* + (so the weights are captured at a consistent point, between training + steps) and only then are the latest weights written to a checkpoint — + optionally with the optimizer state and/or a fresh architecture dump, + per the ``SaveCheckpointOperation`` flags. + + If no model is registered there is nothing to dump: we return a clear + failure *without* disrupting the run (training is left untouched). + + Returns a ``pb2.CommandResponse``. 
+ """ + save_optimizer = bool(getattr(save_op, "save_optimizer", False)) + save_architecture = bool(getattr(save_op, "save_architecture", False)) + audit_details = { + "save_optimizer": save_optimizer, + "save_architecture": save_architecture, + } + + # 1) Resolve the model first. Without a registered model there is nothing + # to dump — fail early so we don't needlessly pause a running experiment. + model = components.get("model") if components else None + if model is None: + model = ledgers.get_model() + if model is None: + msg = ( + "No model registered — nothing to dump. Register one with " + "watch_or_edit(model, flag='model')." + ) + logger.warning("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=msg) + return pb2.CommandResponse(success=False, message=msg) + + # 2) Resolve the checkpoint manager (component cache, then ledger fallback). + checkpoint_manager = components.get("checkpoint_manager") if components else None + if checkpoint_manager is None: + checkpoint_manager = ledgers.get_checkpoint_manager() + if checkpoint_manager is None: + msg = "Checkpoint manager not initialized" + logger.warning("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=msg) + return pb2.CommandResponse(success=False, message=msg) + + # 3) Pause training before dumping. Acquire the global lock so any + # in-flight training step has finished; clearing is_training keeps the + # loop parked at its next pause point, so the subsequent save reads a + # consistent model state. 
+ if not try_acquire_rlock(): + logger.error( + "[SaveCheckpoint] weightslab_rlock timed out after %.0fs", + _GRPC_LOCK_TIMEOUT_S, + ) + if context is not None: + context.abort( + grpc.StatusCode.RESOURCE_EXHAUSTED, + f"Training lock not acquired within {_GRPC_LOCK_TIMEOUT_S:.0f}s", + ) + return pb2.CommandResponse(success=False, message="Lock timeout") + try: + trainer = components.get("trainer") if components else None + if trainer is not None: + logger.info("[SaveCheckpoint] Pausing training before weights dump...") + trainer.pause() + hp = components.get("hyperparams") if components else None + if hp is not None: + try: + hp["is_training"] = False + except Exception: + logger.debug("[SaveCheckpoint] Could not set is_training=False", exc_info=True) + finally: + weightslab_rlock.release() + + # 4) Ensure an experiment hash exists so save_model_checkpoint has a + # target directory (a brand-new experiment may not have one yet). + try: + if getattr(checkpoint_manager, "current_exp_hash", None) is None: + if hasattr(checkpoint_manager, "get_current_experiment_hash"): + checkpoint_manager.get_current_experiment_hash() + if ( + getattr(checkpoint_manager, "current_exp_hash", None) is None + and hasattr(checkpoint_manager, "update_experiment_hash") + ): + checkpoint_manager.update_experiment_hash(first_time=True) + except Exception: + logger.debug("[SaveCheckpoint] Could not ensure experiment hash", exc_info=True) + + # 5) Dump the weights (force_dump_pending flushes any pending changes). 
+ try: + checkpoint_path = checkpoint_manager.save_model_checkpoint( + model=model, + save_optimizer=save_optimizer, + force_dump_pending=True, + ) + except Exception as e: + msg = f"Failed to save model weights: {e}" + logger.error("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=str(e)) + return pb2.CommandResponse(success=False, message=msg) + + if checkpoint_path is None: + msg = "Failed to save model weights (no checkpoint produced)." + logger.warning("[SaveCheckpoint] %s", msg) + self._log_audit("checkpoint_save", "failed", audit_details, error=msg) + return pb2.CommandResponse(success=False, message=msg) + + # 6) Optionally dump the architecture as well. + arch_saved = False + if save_architecture: + try: + arch_path = checkpoint_manager.save_model_architecture(model) + arch_saved = arch_path is not None + except Exception: + logger.warning("[SaveCheckpoint] Could not save model architecture", exc_info=True) + + saved = ["weights"] + if save_optimizer: + saved.append("optimizer") + if arch_saved: + saved.append("architecture") + msg = f"Saved {', '.join(saved)} (training paused)." + logger.info("[SaveCheckpoint] %s -> %s", msg, checkpoint_path) + audit_details["path"] = str(checkpoint_path) + audit_details["architecture_saved"] = arch_saved + self._log_audit("checkpoint_save", "success", audit_details) + return pb2.CommandResponse(success=True, message=msg) + # Training & hyperparameter commands # ------------------------------------------------------------------------- def ExperimentCommand(self, request, context): self._ctx.ensure_components() components = self._ctx.components + # Read-only explore mode: refuse the mutating commands that would change + # the model or the training run. Data management (deny/tag operations, + # plot notes) and all read requests below stay allowed. 
+ if is_explore_mode(): + for forbidden in ( + "hyper_parameter_change", + "save_checkpoint_operation", + "load_checkpoint_operation", + ): + if request.HasField(forbidden): + return pb2.CommandResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) + # Write requests + if request.HasField("save_checkpoint_operation"): + return self._handle_save_checkpoint( + request.save_checkpoint_operation, components, context + ) + if request.HasField("plot_note_operation"): note_op = request.plot_note_operation metric_name = str(note_op.metric_name or "") @@ -664,6 +816,105 @@ def ExperimentCommand(self, request, context): ), ) + if request.HasField("save_checkpoint_operation"): + op = request.save_checkpoint_operation + + checkpoint_manager = components.get("checkpoint_manager") if isinstance(components, dict) else None + if checkpoint_manager is None: + try: + checkpoint_manager = ledgers.get_checkpoint_manager() + except Exception: + checkpoint_manager = None + if checkpoint_manager is None: + return pb2.CommandResponse(success=False, message="Checkpoint manager not initialized") + + # Resolve the live model. The ledger may hand back a proxy, so unwrap it + # the same way checkpoint_manager._save_changes does before dumping. 
+ model = components.get("model") if isinstance(components, dict) else None + if model is None: + try: + model = ledgers.get_model() + except Exception: + model = None + if model is not None and callable(getattr(model, "get", None)): + try: + inner = model.get() + if inner is not None: + model = inner + except Exception: + pass + if model is None: + return pb2.CommandResponse(success=False, message="No model available to checkpoint") + + if getattr(checkpoint_manager, "current_exp_hash", None) is None: + return pb2.CommandResponse( + success=False, + message="No experiment hash set yet; run at least one training step before saving.", + ) + + # Snapshot model state under the training lock so we never persist a + # half-applied optimizer step (mirrors the load-checkpoint hygiene). + if not try_acquire_rlock(): + logger.error( + "[ExperimentCommand] save_checkpoint: weightslab_rlock timed out after %.0fs", + _GRPC_LOCK_TIMEOUT_S, + ) + return pb2.CommandResponse( + success=False, + message=f"Training busy: lock not acquired within {_GRPC_LOCK_TIMEOUT_S:.0f}s. Try again.", + ) + + ckpt_path = None + try: + # save_model_architecture is a no-op when the .pkl already exists, + # so delete it first when a forced architecture re-dump is requested. 
+ if op.save_architecture: + try: + h = checkpoint_manager.current_exp_hash[8:-8] + arch_file = checkpoint_manager.models_dir / h / f"{h}_architecture.pkl" + if arch_file.exists(): + arch_file.unlink() + except Exception as e: + logger.warning(f"Could not remove existing architecture file for forced re-dump: {e}") + checkpoint_manager.save_model_architecture(model) + + ckpt_path = checkpoint_manager.save_model_checkpoint( + model=model, + save_optimizer=bool(op.save_optimizer), + save_model_checkpoint=True, + force_dump_pending=True, + ) + except Exception as e: + logger.error(f"Error during manual checkpoint save: {e}") + self._log_audit( + "checkpoint_save", + "failed", + {"save_architecture": bool(op.save_architecture), "save_optimizer": bool(op.save_optimizer)}, + error=str(e), + ) + return pb2.CommandResponse(success=False, message=f"Failed to save checkpoint: {e}") + finally: + weightslab_rlock.release() + + self._log_audit( + "checkpoint_save", + "success" if ckpt_path is not None else "failed", + { + "experiment_hash": checkpoint_manager.current_exp_hash, + "save_architecture": bool(op.save_architecture), + "save_optimizer": bool(op.save_optimizer), + "checkpoint_file": str(ckpt_path) if ckpt_path else None, + }, + ) + return pb2.CommandResponse( + success=ckpt_path is not None, + message=( + f"Saved model weights{' and architecture' if op.save_architecture else ''} for {checkpoint_manager.current_exp_hash}" + if ckpt_path is not None + else "Checkpoint save produced no weights file (weight dumping may be disabled in config)." + ), + ) + if request.HasField("hyper_parameter_change"): hyper_parameters = request.hyper_parameter_change.hyper_parameters @@ -699,12 +950,12 @@ def ExperimentCommand(self, request, context): ) # TODO (GP): Disabled with modelling for now. 
# if hyper_parameters.HasField("learning_rate"): - # hp_changes["learning_rate"] = hyper_parameters.learning_rate - # set_hyperparam( - # name=hp_name, - # key_path="optimizer.lr", - # value=hyper_parameters.learning_rate - # ) + # hp_changes["learning_rate"] = hyper_parameters.learning_rate + # set_hyperparam( + # name=hp_name, + # key_path="optimizer.lr", + # value=hyper_parameters.learning_rate + # ) if hyper_parameters.HasField("batch_size"): hp_changes["batch_size"] = hyper_parameters.batch_size set_hyperparam( @@ -866,7 +1117,7 @@ def ExperimentCommand(self, request, context): # Set number of steps desired to run before next pause if provided, based on current model age + requested nb_steps if hyper_parameters.HasField("nb_steps"): - m = components.get("model") # Get model + m = components.get("model") # Get model m_age = m.get_age() logger.info(f"\n[WeightsLab] UI Command: Define number of steps at {hyper_parameters.nb_steps}") if hyper_parameters.nb_steps > 0: diff --git a/weightslab/trainer/services/instance_merger.py b/weightslab/trainer/services/instance_merger.py index 66f7efa1..bc582936 100644 --- a/weightslab/trainer/services/instance_merger.py +++ b/weightslab/trainer/services/instance_merger.py @@ -102,7 +102,7 @@ def merge_segmentation_instances(instance_values: List[Any], task_type: str = "s mask1: [[0, 1], [0, 1]] mask2: [[1, 0], [0, 1]] Output: np.max([mask0, mask1, mask2], axis=0) - = [[1, 1], [1, 1]] [MAX aggregated!] + = [[1, 1], [1, 1]] [MAX aggregated!] 
- Input: [mask0] (single mask) Output: mask0 as-is (512, 512) @@ -133,7 +133,7 @@ def merge_segmentation_instances(instance_values: List[Any], task_type: str = "s # Multiple masks: aggregate using max at each pixel # Stack temporarily for max operation, then return result stacked = np.stack(masks_np, axis=0) - return np.max(stacked, axis=0) # Take max across instances → (H, W) + return np.max(stacked, axis=0) # Take max across instances → (H, W) def merge_classification_instances(instance_values: List[Any], task_type: str = "classification") -> Union[list, None]: @@ -152,7 +152,7 @@ def merge_classification_instances(instance_values: List[Any], task_type: str = - All None: Return None Example: - - Input: ['cat', None, None] → Output: ['cat'] [LIST!] + - Input: ['cat', None, None] → Output: ['cat'] [LIST!] - Input: ['cat', 'dog', 'animal'] → Output: ['cat', 'dog', 'animal'] """ labels = [] @@ -212,9 +212,9 @@ def group_instances_by_sample(df_slice, target_column: str, task_type: str): Returns: Dict mapping sample_id to merged value Example: { - 'sample_0': [bbox0, bbox1, bbox2], # Detection: list of bboxes - 'sample_1': [mask0, mask1], # Segmentation: list of masks - 'sample_2': 'cat', # Classification: single label + 'sample_0': [bbox0, bbox1, bbox2], # Detection: list of bboxes + 'sample_1': [mask0, mask1], # Segmentation: list of masks + 'sample_2': 'cat', # Classification: single label } """ if df_slice.empty or target_column not in df_slice.columns: diff --git a/weightslab/trainer/services/model_service.py b/weightslab/trainer/services/model_service.py index bca63e66..5b9aac9c 100644 --- a/weightslab/trainer/services/model_service.py +++ b/weightslab/trainer/services/model_service.py @@ -10,6 +10,7 @@ from weightslab.trainer.trainer_tools import process_sample, _get_input_tensor_for_sample from weightslab.modules.neuron_ops import ArchitectureNeuronsOpType from weightslab.components.global_monitoring import weightslab_rlock, try_acquire_rlock, 
_GRPC_LOCK_TIMEOUT_S +from weightslab.backend.explore_mode import is_explore_mode, EXPLORE_BLOCKED_MESSAGE logger = logging.getLogger(__name__) @@ -311,6 +312,10 @@ def hook(mod, inp, out): # Weight manipulation (architecture operations) # ------------------------------------------------------------------------- def ManipulateWeights(self, request, context): + # Read-only explore mode: architecture/weight edits are disabled. + if is_explore_mode(): + return pb2.WeightsOperationResponse(success=False, message=EXPLORE_BLOCKED_MESSAGE) + self._ctx.ensure_components() components = self._ctx.components diff --git a/weightslab/trainer/services/utils/tools.py b/weightslab/trainer/services/utils/tools.py index 3654fabd..11148009 100644 --- a/weightslab/trainer/services/utils/tools.py +++ b/weightslab/trainer/services/utils/tools.py @@ -3,7 +3,7 @@ ======================= Shared utility helpers for the trainer service layer. -Keep this file free of heavy domain logic. It is the right place for: +Keep this file free of heavy domain logic. It is the right place for: - Small, stateless helper functions used by two or more services. - Shared constants / lookup tables (e.g. provider maps). - Thin wrappers that reduce boilerplate inside service methods. diff --git a/weightslab/trainer/trainer_services.py b/weightslab/trainer/trainer_services.py index dfe1ad42..0400c434 100644 --- a/weightslab/trainer/trainer_services.py +++ b/weightslab/trainer/trainer_services.py @@ -282,7 +282,7 @@ def intercept_service(self, continuation, handler_call_details): # --------------------------------------------------------------------------- # Backward-compat note: RpcWatchdogState, RpcTimingAndWatchdogInterceptor and # GrpcServerManager are now defined in weightslab.watchdog.grpc_watchdog and -# re-exported above. External code that imported them from trainer_services +# re-exported above. External code that imported them from trainer_services # continues to work unchanged. 
# --------------------------------------------------------------------------- @@ -334,6 +334,13 @@ def GetDataSamples(self, request, context): logger.debug(f"\nExperimentServiceServicer.GetDataSamples({request})") return self._exp_service.data_service.GetDataSamples(request, context) + def GetHistogram(self, request, context): + logger.debug(f"\nExperimentServiceServicer.GetHistogram({request})") + return self._exp_service.data_service.GetHistogram(request, context) + def GetMetaData(self, request, context): + logger.debug(f"\nExperimentServiceServicer.GetMetaData({request})") + return self._exp_service.data_service.GetMetaData(request, context) + def GetPointCloud(self, request, context): logger.debug(f"\nExperimentServiceServicer.GetPointCloud({request})") # Server-streaming RPC: delegate the generator directly. @@ -461,12 +468,12 @@ def grpc_serve( grpc_host = os.getenv("GRPC_BACKEND_HOST", "0.0.0.0") if not force_parameters or grpc_host is None else grpc_host grpc_port = int(os.getenv("GRPC_BACKEND_PORT", 50051)) if not force_parameters or grpc_port is None else grpc_port - watchdog_threshold_s = float(os.getenv("GRPC_WATCHDOG_STUCK_SECONDS", "180")) # 3 minutes default stuck threshold + watchdog_threshold_s = float(os.getenv("GRPC_WATCHDOG_STUCK_SECONDS", "180")) # 3 minutes default stuck threshold watchdog_interval_s = float(os.getenv("GRPC_WATCHDOG_INTERVAL_SECONDS", "5")) watchdog_exit_on_stuck = str(os.getenv("GRPC_WATCHDOG_EXIT_ON_STUCK", "0")).strip().lower() in {"1", "true", "yes", "on"} - watchdog_restart_threshold = int(os.getenv("GRPC_WATCHDOG_RESTART_THRESHOLD", "3")) # Restart after 3 unhealthy checks + watchdog_restart_threshold = int(os.getenv("GRPC_WATCHDOG_RESTART_THRESHOLD", "3")) # Restart after 3 unhealthy checks watchdog_details_limit = int(os.getenv("GRPC_WATCHDOG_INFLIGHT_DETAILS_LIMIT", "10")) - watchdog_disabled = str(os.getenv("WEIGHTSLAB_DISABLE_WATCHDOGS", "1")).strip().lower() in {"1", "true", "yes", "on"} # Default state: disabled 
+ watchdog_disabled = str(os.getenv("WEIGHTSLAB_DISABLE_WATCHDOGS", "1")).strip().lower() in {"1", "true", "yes", "on"} # Default state: disabled config = get_hyperparams() grpc_tls_enabled = _resolve_bool_setting(config, "grpc_tls_enabled", "GRPC_TLS_ENABLED", "0") grpc_tls_key_file = _resolve_grpc_tls_path( @@ -521,7 +528,7 @@ def grpc_serve( ) watchdog.register_lock("weightslab_rlock", weightslab_rlock) - # Eval thread monitor — no timeout, just liveness. Lazy imports avoid + # Eval thread monitor — no timeout, just liveness. Lazy imports avoid # circular dependencies since weightslab.src imports trainer code. def _get_eval_controller(): from weightslab.components.evaluation_controller import eval_controller as _ec @@ -535,8 +542,8 @@ def _get_eval_thread(): get_controller=_get_eval_controller, get_thread=_get_eval_thread, ) - watchdog_state = watchdog.rpc_state # shared with RpcTimingAndWatchdogInterceptor - server_manager = watchdog.server_manager # shared with serving_thread_callback + watchdog_state = watchdog.rpc_state # shared with RpcTimingAndWatchdogInterceptor + server_manager = watchdog.server_manager # shared with serving_thread_callback logger.debug( f"grpc_serve called with parameters: n_workers_grpc={n_workers_grpc}, grpc_host={grpc_host}, grpc_port={grpc_port}, " f"watchdog_threshold_s={watchdog_threshold_s}, watchdog_interval_s={watchdog_interval_s}, watchdog_exit_on_stuck={watchdog_exit_on_stuck}, watchdog_restart_threshold={watchdog_restart_threshold}, " @@ -546,13 +553,13 @@ def _get_eval_thread(): def serving_thread_callback(): logger.info("[gRPC] Thread callback started") try: - while True: # Loop to allow restarts + while True: # Loop to allow restarts _effective_workers = n_workers_grpc or min(32, (os.cpu_count() or 1) + 4) logger.info( "[gRPC] Creating ThreadPoolExecutor with %d worker threads (n_workers_grpc=%s, max_concurrent_rpcs=%s)", _effective_workers, n_workers_grpc, max_concurrent_rpcs, ) - _max_msg = 
int(os.getenv("GRPC_MAX_MESSAGE_BYTES", 256 * 1024 * 1024)) # 256 MB + _max_msg = int(os.getenv("GRPC_MAX_MESSAGE_BYTES", 256 * 1024 * 1024)) # 256 MB server = grpc.server( futures.ThreadPoolExecutor( thread_name_prefix="WL-gRPC-Worker", @@ -629,16 +636,16 @@ def serving_thread_callback(): while not server_manager.should_restart(): time.sleep(0.5) - logger.watchdog("[gRPC] Restart requested. Gracefully shutting down (5s grace)...") # type: ignore[attr-defined] + logger.watchdog("[gRPC] Restart requested. Gracefully shutting down (5s grace)...") # type: ignore[attr-defined] stop_event = server.stop(grace=5) stopped = stop_event.wait(timeout=6.0) if not stopped: - logger.watchdog("[gRPC] Graceful stop timed out; forcing immediate stop.") # type: ignore[attr-defined] + logger.watchdog("[gRPC] Graceful stop timed out; forcing immediate stop.") # type: ignore[attr-defined] server.stop(grace=0).wait(timeout=1.0) cleared = watchdog_state.clear_for_restart() if cleared: - logger.watchdog("[gRPC] Cleared %d stale in-flight RPC records after restart.", cleared) # type: ignore[attr-defined] + logger.watchdog("[gRPC] Cleared %d stale in-flight RPC records after restart.", cleared) # type: ignore[attr-defined] server_manager.clear_restart_request() logger.info("[gRPC] Server stopped. 
Restarting in 2s...") time.sleep(2) diff --git a/weightslab/trainer/trainer_tools.py b/weightslab/trainer/trainer_tools.py index e017a279..f81e7f9b 100644 --- a/weightslab/trainer/trainer_tools.py +++ b/weightslab/trainer/trainer_tools.py @@ -59,7 +59,7 @@ def execute_df_operation(df, operation_str): except Exception as e: error_msg = f"Error executing DataFrame operation '{operation_str}': {e}" logger.error(error_msg, exc_info=True) - return df, error_msg # Return original df on error + return df, error_msg # Return original df on error def get_hyper_parameters_pb( @@ -82,7 +82,7 @@ def get_hyper_parameters_pb( # For numerical values, ensure we pass a float to gRPC to avoid "must be real number" errors try: if hasattr(value, "get"): - value = value.get() # unwrap if it's a wrapper object + value = value.get() # unwrap if it's a wrapper object if value is None or value == None: num_val = 'null' type_ = 'string' @@ -241,7 +241,7 @@ def _labels_from_mask_path_histogram(path, num_classes=None, ignore_index=255): with Image.open(path) as im: if im.mode not in ("P", "L"): im = im.convert("L") - hist = im.histogram() # length 256 + hist = im.histogram() # length 256 ub = 256 if num_classes is None else int(num_classes) ids = [i for i, cnt in enumerate(hist[:ub]) if cnt > 0] if ignore_index is not None: @@ -313,7 +313,7 @@ def _safe_dataset_length(ds): sample_stats.task_type = task_type ignore_index = getattr(dataset, "ignore_index", 255) - num_classes = getattr(dataset, "num_classes", getattr(experiment, "num_classes", None)) + num_classes = getattr(dataset, "num_classes", getattr(experiment, "num_classes", None)) # Safely iterate dataset records; if as_records isn't available or dataset is a placeholder # fall back to an empty iterator. 
@@ -343,9 +343,9 @@ def _safe_dataset_length(ds): pred_list = _class_ids(row.get("prediction_raw"), num_classes, ignore_index) else: target = row.get("label", row.get("target", -1)) - pred = row.get("prediction_raw", -1) + pred = row.get("prediction_raw", -1) target_list = [int(target)] if not isinstance(target, (list, np.ndarray)) else [int(np.array(target).item())] - pred_list = [int(pred)] if not isinstance(pred, (list, np.ndarray)) else [int(np.array(pred).item())] + pred_list = [int(pred)] if not isinstance(pred, (list, np.ndarray)) else [int(np.array(pred).item())] record.sample_label.extend(target_list) record.sample_prediction.extend(pred_list) @@ -526,18 +526,18 @@ def encode_image_to_raw_bytes( - All other cases (thumbnails, 2D): PIL image compressed to WebP (JPEG fallback). Args: - np_img: Numpy array of the image (required for the volumetric path). - middle_pil: PIL Image (required for the 2D / thumbnail path). - original_shape: Original tensor shape, used to derive [Z, H, W, C] for volumetric. - is_volumetric: True when the image has a depth (Z) dimension. + np_img: Numpy array of the image (required for the volumetric path). + middle_pil: PIL Image (required for the 2D / thumbnail path). + original_shape: Original tensor shape, used to derive [Z, H, W, C] for volumetric. + is_volumetric: True when the image has a depth (Z) dimension. is_full_resolution: True when sending the full modal view, False for grid thumbnails. - target_width: Width of the (possibly resized) output image. - target_height: Height of the (possibly resized) output image. + target_width: Width of the (possibly resized) output image. + target_height: Height of the (possibly resized) output image. Returns: raw_data_bytes: Encoded bytes ready for gRPC transfer. - raw_shape: [Z, H, W, C] or [H, W, C] shape of the encoded data. - encode_time_s: Seconds spent encoding (0.0 for the raw float32 path). + raw_shape: [Z, H, W, C] or [H, W, C] shape of the encoded data. 
+ encode_time_s: Seconds spent encoding (0.0 for the raw float32 path). """ raw_data_bytes: bytes = b"" raw_shape: list = [] @@ -549,16 +549,16 @@ def encode_image_to_raw_bytes( if not np_img_f32.flags['C_CONTIGUOUS']: np_img_f32 = np.ascontiguousarray(np_img_f32) raw_data_bytes = np_img_f32.tobytes() - del np_img_f32 # release float32 copy immediately + del np_img_f32 # release float32 copy immediately # Normalise shape to [Z, H, W, C] from the original 4-D tensor. if len(original_shape) == 4: if original_shape[1] > original_shape[-1]: - raw_shape = list(original_shape) # already [Z, H, W, C] + raw_shape = list(original_shape) # already [Z, H, W, C] elif original_shape[1] < original_shape[-1]: - raw_shape = [original_shape[0], original_shape[2], original_shape[3], original_shape[1]] # [Z, C, H, W] -> [Z, H, W, C] + raw_shape = [original_shape[0], original_shape[2], original_shape[3], original_shape[1]] # [Z, C, H, W] -> [Z, H, W, C] else: - raw_shape = [original_shape[0], original_shape[1], original_shape[2], 1] # ambiguous: assume single channel + raw_shape = [original_shape[0], original_shape[1], original_shape[2], 1] # ambiguous: assume single channel logger.info( "[Volumetric] Sending full res: np_img.shape=%s, original_shape=%s, raw_shape=%s, bytes=%d", np_img.shape, original_shape, raw_shape, len(raw_data_bytes), @@ -567,7 +567,7 @@ def encode_image_to_raw_bytes( # Thumbnail (grid) or non-volumetric: compress with WebP, fall back to JPEG. # WebP is ~40-50 % smaller than JPEG at equivalent visual quality. 
_quality = 80 if is_full_resolution else 65 - _webp_method = 4 if is_full_resolution else 2 # 0 = fastest … 6 = smallest + _webp_method = 4 if is_full_resolution else 2 # 0 = fastest … 6 = smallest raw_buf = io.BytesIO() t0_enc = time.time() try: diff --git a/weightslab/ui_docker_bridge.py b/weightslab/ui_docker_bridge.py index 343b55c3..028a1d2c 100644 --- a/weightslab/ui_docker_bridge.py +++ b/weightslab/ui_docker_bridge.py @@ -56,7 +56,7 @@ def _persist_certs_dir(certs_dir_str: str) -> None: """Persist WEIGHTSLAB_CERTS_DIR so future terminals and the training backend find it. - Windows — runs `setx` (permanent user env) and prints the PS one-liner for + Windows — runs `setx` (permanent user env) and prints the PS one-liner for the current session. Linux/macOS — appends an export line to ~/.bashrc (idempotent) and prints the source command for the current session. @@ -68,10 +68,10 @@ def _persist_certs_dir(certs_dir_str: str) -> None: stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) if result.returncode == 0: - logger.info("✓ WEIGHTSLAB_CERTS_DIR saved permanently via setx (new terminals will have it)") + logger.info(" WEIGHTSLAB_CERTS_DIR saved permanently via setx (new terminals will have it)") else: logger.warning(f"setx failed — set it manually: setx WEIGHTSLAB_CERTS_DIR \"{certs_dir_str}\"") - logger.info(f" Current terminal (PowerShell): $env:WEIGHTSLAB_CERTS_DIR = \"{certs_dir_str}\"") + logger.info(f" Current terminal (PowerShell): $env:WEIGHTSLAB_CERTS_DIR = \"{certs_dir_str}\"") else: bashrc = Path.home() / ".bashrc" try: @@ -79,13 +79,13 @@ def _persist_certs_dir(certs_dir_str: str) -> None: if export_line not in existing: with open(bashrc, "a", encoding="utf-8") as f: f.write(f"\n# Added by weightslab\n{export_line}\n") - logger.info(f"✓ WEIGHTSLAB_CERTS_DIR appended to {bashrc} (new terminals will have it)") + logger.info(f" WEIGHTSLAB_CERTS_DIR appended to {bashrc} (new terminals will have it)") else: - logger.info(f"✓ WEIGHTSLAB_CERTS_DIR already 
in {bashrc}") + logger.info(f" WEIGHTSLAB_CERTS_DIR already in {bashrc}") except OSError as e: logger.warning(f"Could not write to {bashrc}: {e}") - logger.info(f" Add manually: {export_line}") - logger.info(f" Current terminal: source ~/.bashrc (or open a new terminal)") + logger.info(f" Add manually: {export_line}") + logger.info(f" Current terminal: source ~/.bashrc (or open a new terminal)") def _strip_derived_deploy_env() -> None: @@ -113,40 +113,40 @@ def _banner() -> str: _EPILOG = """\ commands: - se Set up the secure environment: generate TLS + se Set up the secure environment: generate TLS certificates + a gRPC auth token in ~/.weightslab-certs. Then set WEIGHTSLAB_CERTS_DIR (the single source of truth) so the backend + new shells find them. - --force-certs regenerate even if certs exist + --force-certs regenerate even if certs exist - ui launch Purge stale weightslab/weights_studio Docker + ui launch Purge stale weightslab/weights_studio Docker resources, then build & start the UI stack. UNSECURED (HTTP) by default — no certs generated. - --certs generate (if missing) + use TLS + --certs generate (if missing) + use TLS certs + gRPC auth (HTTPS) - start example Run a bundled PyTorch example (foreground; stop with + start example Run a bundled PyTorch example (foreground; stop with Ctrl+C). Installs the example's requirements first, without prompting. 
Defaults to classification: - --cls classification example (default) - --seg segmentation example - --det detection example - --clus clustering example - --gen generation example - --3d_det 3D LiDAR point-cloud detection example - --2d_det 2D LiDAR point-cloud detection example + --cls classification example (default) + --seg segmentation example + --det detection example + --clus clustering example + --gen generation example + --3d_det 3D LiDAR point-cloud detection example + --2d_det 2D LiDAR point-cloud detection example examples: - weightslab se # one-time secure setup (then export WEIGHTSLAB_CERTS_DIR) - weightslab se --force-certs # regenerate the certs - weightslab ui launch # clean + launch (unsecured HTTP, default) - weightslab ui launch --certs # secured launch (HTTPS + gRPC auth) - weightslab start example # run the classification demo (default) - weightslab start example --seg # run the segmentation demo - weightslab start example --det # run the detection demo - weightslab start example --3d_det # run the 3D LiDAR detection demo - weightslab start example --2d_det # run the 2D LiDAR detection demo + weightslab se # one-time secure setup (then export WEIGHTSLAB_CERTS_DIR) + weightslab se --force-certs # regenerate the certs + weightslab ui launch # clean + launch (unsecured HTTP, default) + weightslab ui launch --certs # secured launch (HTTPS + gRPC auth) + weightslab start example # run the classification demo (default) + weightslab start example --seg # run the segmentation demo + weightslab start example --det # run the detection demo + weightslab start example --3d_det # run the 3D LiDAR detection demo + weightslab start example --2d_det # run the 2D LiDAR detection demo """ @@ -329,7 +329,7 @@ def _compose_cmd(compose_file, envoy_config, action): # locally), then `up` without the flag. v2 supports it inline, so leave it. 
if base == ["docker-compose"] and action and action[0] == "up" and "--pull" in action: i = action.index("--pull") - del action[i:i + 2] # drop '--pull' and its policy value (e.g. 'always') + del action[i:i + 2] # drop '--pull' and its policy value (e.g. 'always') logger.info("Docker Compose v1 detected — pulling images before 'up'...") pull_result = subprocess.run( base + ["-f", str(compose_file), "pull"], @@ -452,7 +452,7 @@ def _run_shell_script(script_path: str, args: list = None, env_vars: dict = None # Build bash command - pass Windows path directly, script will handle conversion # # Process path to ensure it's compatible with bash, especially on Windows if _is_windows() and '\\' in script_path: - script_path = script_path.replace("\\", "/") # Ensure path is Unix-style for bash + script_path = script_path.replace("\\", "/") # Ensure path is Unix-style for bash script_path = _convert_to_git_bash_path(script_path) logger.info(f"Converted script path for bash: {script_path}") logger.info(f"Running shell script: {script_path} with args: {args} and env_vars: {env_vars}") @@ -550,8 +550,8 @@ def _install_ca_trust(ca_file: Path) -> None: Idempotent and safe to call on every launch. Platform behavior: * Windows — adds to the CurrentUser\\Root store via the .NET X509Store API (silent, no prompt). - * macOS — adds to the login keychain (may show a one-time auth prompt). - * Linux — installs into the system trust store via sudo (one-time prompt) + * macOS — adds to the login keychain (may show a one-time auth prompt). + * Linux — installs into the system trust store via sudo (one-time prompt) and, best-effort, the user's NSS DB so Chrome/Firefox trust it too. 
A failure here is non-fatal: TLS still works, the browser just shows a @@ -584,7 +584,7 @@ def _install_ca_trust(ca_file: Path) -> None: stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, ) if result.returncode == 0: - logger.info("✓ Dev CA trusted in Windows CurrentUser\\Root store (restart browser to apply)") + logger.info(" Dev CA trusted in Windows CurrentUser\\Root store (restart browser to apply)") else: logger.warning(f"Could not auto-trust dev CA: {result.stderr.strip()}") return @@ -595,7 +595,7 @@ def _install_ca_trust(ca_file: Path) -> None: stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) if check.returncode == 0: - logger.info("✓ Dev CA already trusted (macOS keychain)") + logger.info(" Dev CA already trusted (macOS keychain)") return logger.info("Installing dev CA into macOS login keychain (may prompt)...") subprocess.run( @@ -615,7 +615,7 @@ def _install_ca_trust(ca_file: Path) -> None: subprocess.run(["sudo", "cp", str(ca_file), str(system_ca)]) subprocess.run(["sudo", "update-ca-certificates"]) else: - logger.info("✓ Dev CA already in Linux system trust store") + logger.info(" Dev CA already in Linux system trust store") # Browsers use their own NSS DB; add it there too if certutil is available. if shutil.which("certutil"): @@ -636,7 +636,7 @@ def _ensure_certificates(manager: CertAuthManager, force_certs: bool = False) -> truth). Returns True if certs are present afterwards, False otherwise. """ if manager.has_any_credentials() and not force_certs: - logger.info(f"✓ Using existing credentials in {manager.certs_dir}") + logger.info(f" Using existing credentials in {manager.certs_dir}") manager.get_or_create_auth_token() # Ensure the CA is trusted even when reusing certs from a prior run that # was generated via bash (which does not install OS trust). 
@@ -656,7 +656,7 @@ def _ensure_certificates(manager: CertAuthManager, force_certs: bool = False) -> manager.get_or_create_auth_token() _install_ca_trust(manager.ca_file) - logger.info(f"✓ Certificates ready in {manager.certs_dir}") + logger.info(f" Certificates ready in {manager.certs_dir}") return manager.has_valid_certs() @@ -724,10 +724,10 @@ def ui_launch(args): (file presence is the single source of truth) and are never deleted here. Flags (all optional, read defensively so legacy callers still work): - --certs generate (if missing) and use TLS certs + gRPC auth (HTTPS) - --force-certs with --certs, regenerate certificates even if they exist - --no-clean skip the stale Docker resource cleanup step - --dev use the dev compose overlay + --certs generate (if missing) and use TLS certs + gRPC auth (HTTPS) + --force-certs with --certs, regenerate certificates even if they exist + --no-clean skip the stale Docker resource cleanup step + --dev use the dev compose overlay """ _check_docker() # pip installs the bundled .sh scripts without the execute bit; make them @@ -869,10 +869,10 @@ def ui_launch(args): # The backend and any new shell must point at the same certs dir, or # they'll mismatch the UI's TLS/auth. Keep this the last thing printed. logger.warning("") - logger.warning("⚠ ACTION REQUIRED — TLS is ON. Set WEIGHTSLAB_CERTS_DIR so the " + logger.warning(" ACTION REQUIRED — TLS is ON. Set WEIGHTSLAB_CERTS_DIR so the " "training backend and new terminals use the same certificates:") - logger.warning(f" (bash) export WEIGHTSLAB_CERTS_DIR=\"{certs_dir_str}\"") - logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{certs_dir_str}\"") + logger.warning(f" (bash) export WEIGHTSLAB_CERTS_DIR=\"{certs_dir_str}\"") + logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{certs_dir_str}\"") else: logger.info("UI is running UNSECURED (HTTP, no gRPC auth). 
" "Re-run with `weightslab ui launch --certs` for TLS.") @@ -913,17 +913,17 @@ def ui_secure_environment(args): # Export ONLY the single source of truth for this process. os.environ["WEIGHTSLAB_CERTS_DIR"] = str(manager.certs_dir) - logger.info("✓ Certificates generated successfully") - logger.info("✓ gRPC auth token created") - logger.info(f"✓ Certs and token stored in: {manager.certs_dir}") - logger.info(f"✓ WEIGHTSLAB_CERTS_DIR exported for this process: {manager.certs_dir}") + logger.info(" Certificates generated successfully") + logger.info(" gRPC auth token created") + logger.info(f" Certs and token stored in: {manager.certs_dir}") + logger.info(f" WEIGHTSLAB_CERTS_DIR exported for this process: {manager.certs_dir}") logger.info("Then launch the secured UI with: weightslab ui launch --certs") # Keep this the FINAL output so the user can't miss the action they must take. logger.warning("") - logger.warning("⚠ ACTION REQUIRED — set WEIGHTSLAB_CERTS_DIR globally so new shells " + logger.warning(" ACTION REQUIRED — set WEIGHTSLAB_CERTS_DIR globally so new shells " "and the training backend find these certs (single source of truth):") - logger.warning(f" (bash) echo 'export WEIGHTSLAB_CERTS_DIR=\"{manager.certs_dir}\"' >> ~/.bashrc && source ~/.bashrc") - logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{manager.certs_dir}\"") + logger.warning(f" (bash) echo 'export WEIGHTSLAB_CERTS_DIR=\"{manager.certs_dir}\"' >> ~/.bashrc && source ~/.bashrc") + logger.warning(f" (Windows) setx WEIGHTSLAB_CERTS_DIR \"{manager.certs_dir}\"") # Bundled PyTorch examples, keyed by the CLI flag (e.g. --cls -> ws-classification). @@ -970,7 +970,7 @@ def _install_example_requirements(example_dir: Path) -> None: f"Failed to install requirements ({req}): {exc}. " "Continuing — the example may still run if deps are already installed." 
) - return # only the first matching requirements file is used + return # only the first matching requirements file is used def example_start(args): @@ -994,7 +994,7 @@ def example_start(args): _install_example_requirements(example_dir) logger.info(f"Starting the WeightsLab {label} ({kind}) example...") - logger.info(f" {main_py}") + logger.info(f" {main_py}") logger.info("In another terminal, launch the UI with: weightslab ui launch") logger.info(f"Then open http://localhost:5173 — stop the example with Ctrl+C.") if not _CERTS_DIR_IN_ORIGINAL_ENV: @@ -1017,6 +1017,71 @@ def example_start(args): sys.exit(result.returncode) +def logdir_explore(args): + """`weightslab logdir [--no-ui]`: offline explore mode for a downloaded log dir. + + Loads the experiment from disk into a read-only ledger (no training script, + GPU, or original dataset required), starts the gRPC backend server, and — + unless ``--no-ui`` is given — also brings up the Weights Studio Docker UI + stack so the experiment can be browsed immediately. + + Intended workflow:: + + # On your dev machine, after rsync-ing the cluster run: + weightslab logdir ./root_log_dir + + Once running, open http://localhost:5173 (or the URL printed on startup). + Press Ctrl+C to stop the gRPC server (Docker UI keeps running in the + background; stop it separately with ``weightslab ui launch`` or + ``docker compose down`` if needed). + """ + root_log_dir = args.root_log_dir + + if not getattr(args, "no_ui", False): + # Launch the Weights Studio UI stack (Docker + Envoy) the same way + # `weightslab ui launch` does. Build an args namespace that is + # compatible with ui_launch, inheriting relevant flags. 
+ ui_args = argparse.Namespace( + certs=getattr(args, "certs", False), + force_certs=False, + no_clean=False, + no_auth=False, + dev=False, + certs_dir=getattr(args, "certs_dir", None), + ) + ui_launch(ui_args) + + logger.info("Loading experiment from disk: %s", root_log_dir) + # Lazy import: pulls in torch and the full weightslab stack only when this + # command is actually invoked, keeping other commands fast. + try: + from weightslab.src import load_experiment_for_explore, serve, keep_serving + except ImportError as exc: + logger.error("Failed to import weightslab core: %s", exc) + sys.exit(1) + + try: + summary = load_experiment_for_explore( + root_log_dir, + exp_hash=getattr(args, "exp_hash", None), + ) + except FileNotFoundError as exc: + logger.error(str(exc)) + sys.exit(1) + + logger.info("Experiment loaded: hash=%s, origins=%s", + summary.get("experiment_hash"), summary.get("origins")) + + grpc_port = getattr(args, "grpc_port", None) or int(os.getenv("GRPC_BACKEND_PORT", 50051)) + os.environ["GRPC_BACKEND_PORT"] = str(grpc_port) + logger.info("Starting WeightsLab gRPC server on port %d (read-only explore mode)...", grpc_port) + serve(serving_grpc=True) + + logger.info("WeightsLab is running in read-only explore mode.") + logger.info("Open the UI, then press Ctrl+C to stop.") + keep_serving(release_gpu=False) + + def _add_example_kind_flags(p: argparse.ArgumentParser) -> None: """Attach the mutually-exclusive example-kind flags (default: classification).""" group = p.add_mutually_exclusive_group() @@ -1045,6 +1110,7 @@ def _build_parser() -> argparse.ArgumentParser: weightslab se [--force-certs] weightslab ui launch [--certs] weightslab start example [--cls|--seg|--det|--clus|--gen|--3d_det|--2d_det] + weightslab logdir [--no-ui] [--certs] [--grpc-port PORT] """ parser = argparse.ArgumentParser( prog="weightslab", @@ -1054,7 +1120,7 @@ def _build_parser() -> argparse.ArgumentParser: ) # metavar lists only the documented commands; the `example` alias is 
accepted # but intentionally omitted here (and help=SUPPRESS'd below) so it stays hidden. - sub = parser.add_subparsers(dest="command", metavar="{se,ui,start,help}") + sub = parser.add_subparsers(dest="command", metavar="{se,ui,start,logdir,help}") # weightslab se [--force-certs] [certs_dir] se_parser = sub.add_parser("se", help="Set up the secure environment (TLS certs + gRPC auth token)") @@ -1090,6 +1156,47 @@ def _build_parser() -> argparse.ArgumentParser: sub.add_parser("help", help="Show this help message") + # weightslab logdir [--no-ui] [--certs] [--grpc-port PORT] + logdir_parser = sub.add_parser( + "logdir", + help="Open a finished experiment from disk in read-only explore mode, " + "then serve it through Weights Studio", + ) + logdir_parser.add_argument( + "root_log_dir", + help="Path to the root_log_dir produced by a previous training run", + ) + logdir_parser.add_argument( + "--no-ui", + action="store_true", + help="Skip launching the Weights Studio Docker UI stack " + "(useful when the UI is already running)", + ) + logdir_parser.add_argument( + "--certs", + action="store_true", + help="Generate TLS certs + gRPC auth token if missing, then launch UI secured", + ) + logdir_parser.add_argument( + "--grpc-port", + type=int, + default=None, + metavar="PORT", + help=f"gRPC backend port (default: $GRPC_BACKEND_PORT or 50051)", + ) + logdir_parser.add_argument( + "--exp-hash", + default=None, + metavar="HASH", + help="Specific experiment hash to open (default: latest)", + ) + logdir_parser.add_argument( + "certs_dir", + nargs="?", + default=None, + help="Custom directory for certs/token (default: $WEIGHTSLAB_CERTS_DIR or ~/.weightslab-certs)", + ) + return parser, ui_parser, start_parser @@ -1115,6 +1222,8 @@ def main(): # Alias for `start example` — tolerate the swapped subcommand order # (`weightslab example start [flags]`) and the bare `weightslab example`. 
example_start(args) + elif args.command == "logdir": + logdir_explore(args) else: parser.print_help() diff --git a/weightslab/utils/computational_graph.py b/weightslab/utils/computational_graph.py index f6754658..34d1699f 100644 --- a/weightslab/utils/computational_graph.py +++ b/weightslab/utils/computational_graph.py @@ -144,11 +144,11 @@ def _generate_mappings( # Case 2: Many-to-one (src > dst) # A "batch" of source neurons maps to a single dstination neuron. # if len(src_channels) % len(dst_channels) != 0: - # raise ValueError( - # f"Source channels ({src_channels}) must be perfectly \ - # divisible by dstination channels ({dst_channels}) \ - # for many-to-one mapping." - # ) + # raise ValueError( + # f"Source channels ({src_channels}) must be perfectly \ + # divisible by dstination channels ({dst_channels}) \ + # for many-to-one mapping." + # ) # 1. Calculate the block size. # This determines how many linear layer neurons map to one convolution channel. @@ -184,7 +184,7 @@ def _generate_mappings( # We map the individual code back to the original index dst_to_src_mapping[code] = [index] - else: # src_channels < dst_channels + else: # src_channels < dst_channels # 1. Calculate the block size. # This determines how many linear layer neurons map to one convolution channel. # We use integer division to ensure a clean split. 
@@ -359,7 +359,7 @@ def _propagate_constraints_through_dependencies( propagated_constraints[module_id].get('outgoing', {})[cname] = cval # BFS to propagate OUTGOING constraints downstream - queue = [(module_id, native_constraints.copy(), {module_id})] # (current_id, outgoing_constraints, visited) + queue = [(module_id, native_constraints.copy(), {module_id})] # (current_id, outgoing_constraints, visited) while queue: current_id, current_constraints, visited_set = queue.pop(0) @@ -531,7 +531,7 @@ def _alias_from_tensor_name(tensor_name: str) -> Optional[str]: next_part = module_parts[i + 1] # If next part starts with current part + '.', it's redundant if next_part.startswith(part + '.'): - continue # Skip this redundant part + continue # Skip this redundant part deduplicated.append(part) return '.'.join(deduplicated) @@ -712,23 +712,23 @@ def generate_graph_dependencies_from_torchfx( # # SEED NEURONS: Use FX metadata to seed neurons if possible # for mod in make_safelist(current_module): - # if 'tensor_meta' in node.meta: - # meta = node.meta['tensor_meta'] - # if hasattr(meta, 'shape') and len(meta.shape) >= 2: - # out_ch = meta.shape[1] - # if out_ch is not None and out_ch > 0: - # mod.set_neurons('out_neurons', out_ch) - # if getattr(mod, 'wl_same_flag', False): - # mod.set_neurons('in_neurons', out_ch) - - # # Also check inputs to seed in_neurons - # for arg in node.args: - # if isinstance(arg, th.fx.Node) and 'tensor_meta' in arg.meta: - # meta_in = arg.meta['tensor_meta'] - # if hasattr(meta_in, 'shape') and len(meta_in.shape) >= 2: - # in_ch = meta_in.shape[1] - # if in_ch is not None and in_ch > 0: - # mod.set_neurons('in_neurons', in_ch) + # if 'tensor_meta' in node.meta: + # meta = node.meta['tensor_meta'] + # if hasattr(meta, 'shape') and len(meta.shape) >= 2: + # out_ch = meta.shape[1] + # if out_ch is not None and out_ch > 0: + # mod.set_neurons('out_neurons', out_ch) + # if getattr(mod, 'wl_same_flag', False): + # mod.set_neurons('in_neurons', out_ch) 
+ + # # Also check inputs to seed in_neurons + # for arg in node.args: + # if isinstance(arg, th.fx.Node) and 'tensor_meta' in arg.meta: + # meta_in = arg.meta['tensor_meta'] + # if hasattr(meta_in, 'shape') and len(meta_in.shape) >= 2: + # in_ch = meta_in.shape[1] + # if in_ch is not None and in_ch > 0: + # mod.set_neurons('in_neurons', in_ch) # --- Handle General Merge Operations (Any call_function with multiple # module inputs) --- @@ -744,8 +744,8 @@ def generate_graph_dependencies_from_torchfx( # TODO (GP): cat of cat of cat, should be nested list also ? # TODO (GP): e.g., cat([conv1, conv2, cat([conv3, cat([conv4, # TODO (GP): conv5])])])]) - source_modules_ = [] # Collect modules to check for single input - source_nodes = [] # Collect nodes to check for single input + source_modules_ = [] # Collect modules to check for single input + source_nodes = [] # Collect nodes to check for single input for arg in node.args: if not isinstance(arg, list): arg = make_safelist(arg) @@ -795,7 +795,7 @@ def generate_graph_dependencies_from_torchfx( # dependent on the first module in the merge node_to_module[node] = distinct_source_modules else: - node_to_module[node] = None # Placeholder or constant input + node_to_module[node] = None # Placeholder or constant input # Clean dependencies (remove duplicates and self-loops) dependencies = _clean_dependencies(dependencies) @@ -835,9 +835,9 @@ def generate_layer_dependencies_from_onnx( Returns: [ - ('conv1', 'bn1', DepType.SAME), # Conv output = BN input/output - ('bn1', 'relu', DepType.SAME), # BN output = ReLU input/output - ('relu', 'conv2', DepType.INCOMING), # ReLU output = Conv2 input only + ('conv1', 'bn1', DepType.SAME), # Conv output = BN input/output + ('bn1', 'relu', DepType.SAME), # BN output = ReLU input/output + ('relu', 'conv2', DepType.INCOMING), # ReLU output = Conv2 input only ] Note: @@ -916,7 +916,7 @@ def get_channel_count(tensor_name: str) -> Optional[int]: if onnx_shapes_map: shape = 
onnx_shapes_map.get(tensor_name) if shape and len(shape) >= 2: - return shape[1] # NCHW format, C is dimension 1 + return shape[1] # NCHW format, C is dimension 1 # Fallback to node attributes producer = producer_for_tensor.get(tensor_name) @@ -925,19 +925,19 @@ def get_channel_count(tensor_name: str) -> Optional[int]: weight_name = producer.input[1] for init in graph.initializer: if init.name == weight_name and len(init.dims) >= 1: - return init.dims[0] # out_channels + return init.dims[0] # out_channels elif producer.op_type == 'Gemm' and len(producer.input) >= 2: weight_name = producer.input[1] for init in graph.initializer: if init.name == weight_name and len(init.dims) >= 1: - return init.dims[0] # out_features + return init.dims[0] # out_features elif producer.op_type == 'BatchNormalization' and len(producer.input) >= 2: weight_name = producer.input[1] for init in graph.initializer: if init.name == weight_name and len(init.dims) >= 1: - return init.dims[0] # num_features + return init.dims[0] # num_features return None @@ -1018,8 +1018,8 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: for node in graph.node: logger.debug(f"\nNode: {node.op_type} | name: {node.name}") - logger.debug(f" Inputs: {node.input[:3]}") # Show first 3 inputs - logger.debug(f" Outputs: {node.output}") + logger.debug(f" Inputs: {node.input[:3]}") # Show first 3 inputs + logger.debug(f" Outputs: {node.output}") # Find source modules from inputs src_modules = [] @@ -1028,7 +1028,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: for inp in node.input: # Skip constant/initializer inputs (weights, biases, etc.) 
if any(param in inp for param in ['.weight', '.bias', '.running_mean', '.running_var', '.num_batches_tracked']): - logger.debug(f" Skipping parameter input: {inp[:50]}") + logger.debug(f" Skipping parameter input: {inp[:50]}") continue src_mods = module_for_tensor(inp) @@ -1037,12 +1037,12 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: src_modules.append(src_mod) src_tensors.append(inp) src_name = module_to_name.get(src_mod, "") - logger.debug(f" Found source: {src_name} (from tensor: {inp[:50]})") + logger.debug(f" Found source: {src_name} (from tensor: {inp[:50]})") else: - logger.debug(f" Could not find source module for input: {inp[:50]}") + logger.debug(f" Could not find source module for input: {inp[:50]}") if not src_modules: - logger.debug(f" -> No source modules found, skipping") + logger.debug(f" -> No source modules found, skipping") continue # Handle merge operations (Add, Sub, Sum, Concat, Mul, Div) - create REC dependencies between branches @@ -1050,8 +1050,8 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: is_concat = "Concat" in node.op_type.capitalize() or "Cat" in node.op_type.capitalize() if is_merge: - logger.debug(f" Detected merge operation: {node.op_type} with {len(src_modules)} source modules") - logger.debug(f" Source modules: {[module_to_name.get(m, '?') for m in src_modules]}") + logger.debug(f" Detected merge operation: {node.op_type} with {len(src_modules)} source modules") + logger.debug(f" Source modules: {[module_to_name.get(m, '?') for m in src_modules]}") if is_merge and len(src_modules) >= 2: # Create REC dependencies between all pairs of source modules @@ -1073,15 +1073,15 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: channels_a = get_channel_count(tensor_a) channels_b = get_channel_count(tensor_b) - logger.debug(f" Checking REC: {name_a} (ch={channels_a}) <-> {name_b} (ch={channels_b})") + logger.debug(f" Checking REC: {name_a} (ch={channels_a}) <-> {name_b} (ch={channels_b})") # For 
Add/Sub/Mul/Div, channels must match # For Concat, channels can differ (concatenated along channel dim) create_rec = False if is_concat: - create_rec = True # Always create REC for concat + create_rec = True # Always create REC for concat elif channels_a is not None and channels_b is not None and channels_a == channels_b: - create_rec = True # Channels match for Add/Sub/etc + create_rec = True # Channels match for Add/Sub/etc elif channels_a is None or channels_b is None: # Can't verify channels, but merge operation requires compatibility create_rec = True @@ -1092,16 +1092,16 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: edge_key_ba = (name_b, name_a) if edge_key_ab not in seen_edges: - logger.debug(f" ✓ Adding REC dependency: {name_a} <-> {name_b}") + logger.debug(f" Adding REC dependency: {name_a} <-> {name_b}") dependencies.append((mod_a, mod_b, DepType.REC)) seen_edges.add(edge_key_ab) if edge_key_ba not in seen_edges: - logger.debug(f" ✓ Adding REC dependency: {name_b} <-> {name_a}") + logger.debug(f" Adding REC dependency: {name_b} <-> {name_a}") dependencies.append((mod_b, mod_a, DepType.REC)) seen_edges.add(edge_key_ba) else: - logger.debug(f" ✗ Skipping REC dependency (channel mismatch: {channels_a} vs {channels_b})") + logger.debug(f" Skipping REC dependency (channel mismatch: {channels_a} vs {channels_b})") # Find destination module from outputs # First try: direct alias from output tensor name @@ -1113,7 +1113,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: if alias and alias in name_to_module: dst_mod = name_to_module[alias] dst_tensor = out_name - logger.debug(f" Found dest (method 1 - alias): {alias}") + logger.debug(f" Found dest (method 1 - alias): {alias}") break # Second try: For ops that correspond to nn.Module, extract module from node itself @@ -1125,15 +1125,15 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: if len(parts) >= 2: # The module name is typically everything except the last part (op type) 
potential_names = [ - '.'.join(parts[:-1]), # e.g., 'model.bn1' from '/model/bn1/BatchNormalization' - parts[-2] if len(parts) >= 2 else parts[0], # e.g., 'bn1' from above + '.'.join(parts[:-1]), # e.g., 'model.bn1' from '/model/bn1/BatchNormalization' + parts[-2] if len(parts) >= 2 else parts[0], # e.g., 'bn1' from above ] for pname in potential_names: if pname in name_to_module: dst_mod = name_to_module[pname] if node.output: dst_tensor = node.output[0] - logger.debug(f" Found dest (method 2 - node name): {pname}") + logger.debug(f" Found dest (method 2 - node name): {pname}") break # Third try: For weight-based ops (Conv, Gemm, BatchNorm), check input parameters @@ -1148,7 +1148,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: dst_mod = name_to_module[potential_name] if node.output: dst_tensor = node.output[0] - logger.debug(f" Found dest (method 3 - params): {potential_name} (from {inp})") + logger.debug(f" Found dest (method 3 - params): {potential_name} (from {inp})") break if dst_mod: break @@ -1159,17 +1159,17 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: # if not hasattr(src_mod, 'bypass'): # src_mod.bypass = 0 bypassed.extend(make_safelist(list(node.output))) - logger.debug(f" Setting bypass=0 for module after Concat: {module_to_name.get(src_mod, '?')}") + logger.debug(f" Setting bypass=0 for module after Concat: {module_to_name.get(src_mod, '?')}") for k in make_safelist(list(node.input)): if k in bypassed: if dst_mod is not None: dst_mod.bypass = 0 - logger.debug(f" Setting bypass=0 for destination module: {module_to_name.get(dst_mod, '?')}") + logger.debug(f" Setting bypass=0 for destination module: {module_to_name.get(dst_mod, '?')}") break # If no destination module found, skip as it s the end if dst_mod is None: - logger.debug(f" -> No destination module found, skipping") + logger.debug(f" -> No destination module found, skipping") continue # Determine dependency type for each source -> destination connection @@ -1196,17 
+1196,17 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: # # SEED NEURONS: Use ONNX metadata to seed neurons if possible # if src_channels is not None and src_channels > 0: - # src_mod.set_neurons('out_neurons', src_channels) - # if getattr(src_mod, 'wl_same_flag', False): - # src_mod.set_neurons('in_neurons', src_channels) + # src_mod.set_neurons('out_neurons', src_channels) + # if getattr(src_mod, 'wl_same_flag', False): + # src_mod.set_neurons('in_neurons', src_channels) # if dst_channels is not None and dst_channels > 0: - # dst_mod.set_neurons('in_neurons', dst_channels) - # if getattr(dst_mod, 'wl_same_flag', False): - # dst_mod.set_neurons('out_neurons', dst_channels) + # dst_mod.set_neurons('in_neurons', dst_channels) + # if getattr(dst_mod, 'wl_same_flag', False): + # dst_mod.set_neurons('out_neurons', dst_channels) logger.debug(f"Analyzing dependency {src_name} -> {dst_name}") - logger.debug(f" Source channels: {src_channels}, Destination channels: {dst_channels}") + logger.debug(f" Source channels: {src_channels}, Destination channels: {dst_channels}") # Use helper function to infer dependency type dep_type = _infer_dependency_type(dst_mod) @@ -1215,7 +1215,7 @@ def module_for_tensor(tname: str) -> Optional[nn.Module]: if dep_type == DepType.SAME: dst_mod.wl_same_flag = True - logger.debug(f" ✓ Adding dependency: {src_name} -> {dst_name} [{dep_type.name}]") + logger.debug(f" Adding dependency: {src_name} -> {dst_name} [{dep_type.name}]") dependencies.append((src_mod, dst_mod, dep_type)) seen_edges.add(edge_key) @@ -1247,7 +1247,7 @@ def generate_index_maps( for edge in dependencies: # Get src and dst modules and type src_mod, dst_mod, edge_label = edge[0], edge[1], edge[2] - recursive_dep = edge_label == DepType.REC # A recursive dependency ? + recursive_dep = edge_label == DepType.REC # A recursive dependency ? # 1.1. 
Determine the number of neurons in each direction # # Src - First will always be is not None and int @@ -1264,7 +1264,7 @@ def generate_index_maps( dst_mod.set_neurons( 'in_neurons' if not recursive_dep and not hasattr(dst_mod, 'wl_transposed') else 'out_neurons', dst_nb_neurons - ) # So next will have neurons + ) # So next will have neurons dst_mod_out_neurons = dst_mod.get_neurons( 'in_neurons' if not (not recursive_dep and not hasattr(dst_mod, 'wl_transposed')) else 'out_neurons' ) @@ -1279,7 +1279,7 @@ def generate_index_maps( src_mod.set_neurons( 'out_neurons' if not hasattr(src_mod, 'wl_transposed') else 'in_neurons', src_nb_neurons - ) # So next will have neurons + ) # So next will have neurons src_mod_out_neurons = src_mod.get_neurons( 'out_neurons' if not hasattr(src_mod, 'wl_transposed') else 'in_neurons' ) @@ -1360,7 +1360,7 @@ def extract_group_size(mod: nn.Module, incoming: bool) -> Optional[int]: dst_mod.get_name_wi_id(): deepcopy(src_to_dst_mapping_tnsr) } if not hasattr(dst_mod, 'bypass') else {} - ) # Child equivalent here + ) # Child equivalent here dst_mod.src_to_dst_mapping_tnsrs = normalize_dicts(dst_mod.src_to_dst_mapping_tnsrs) dst_mod.related_dst_to_src_mapping_tnsrs = normalize_dicts(dst_mod.related_dst_to_src_mapping_tnsrs) @@ -1392,7 +1392,7 @@ def extract_group_size(mod: nn.Module, incoming: bool) -> Optional[int]: """ # Enable debug logging logging.basicConfig( - level=logging.DEBUG, # Set to DEBUG to see detailed merge operation detection + level=logging.DEBUG, # Set to DEBUG to see detailed merge operation detection format='%(levelname)s - %(message)s' ) logger.setLevel(logging.DEBUG) @@ -1523,14 +1523,14 @@ def __init__(self, in_channels=3, num_classes=10): self.enc_relu1 = nn.ReLU() self.enc_residual1 = ResidualBlockWithUpsampling(64, 64) - self.enc_pool1 = nn.MaxPool2d(2, 2) # 32x32 -> 16x16 + self.enc_pool1 = nn.MaxPool2d(2, 2) # 32x32 -> 16x16 self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.enc_bn2 = 
nn.BatchNorm2d(128) self.enc_relu2 = nn.ReLU() self.enc_residual2 = ResidualBlockWithUpsampling(128, 128) - self.enc_pool2 = nn.MaxPool2d(2, 2) # 16x16 -> 8x8 + self.enc_pool2 = nn.MaxPool2d(2, 2) # 16x16 -> 8x8 # Bottleneck self.bottleneck_conv = nn.Conv2d(128, 256, kernel_size=3, padding=1) @@ -1539,8 +1539,8 @@ def __init__(self, in_channels=3, num_classes=10): self.bottleneck_residual = ResidualBlockWithUpsampling(256, 256) # Decoder: Upsampling path - self.dec_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 8x8 -> 16x16 - self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features + self.dec_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 8x8 -> 16x16 + self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features self.dec_bn1 = nn.BatchNorm2d(128) self.dec_residual1 = ResidualBlockWithUpsampling(128, 128) @@ -1569,7 +1569,7 @@ def forward(self, x): # Decoder with skip connections dec1 = self.dec_upsample1(bottleneck) - dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation + dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation dec1 = self.dec_conv1(dec1) dec1 = self.dec_bn1(dec1) dec1 = self.dec_residual1(dec1) @@ -1595,14 +1595,14 @@ def __init__(self, in_channels=3, num_classes=10): self.enc_relu1 = nn.ReLU() self.enc_residual1 = ResidualBlockWithUpsampling(64, 64) - self.enc_pool1 = nn.MaxPool2d(3, 3) # 27x27 -> 9x9 + self.enc_pool1 = nn.MaxPool2d(3, 3) # 27x27 -> 9x9 self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.enc_bn2 = nn.BatchNorm2d(128) self.enc_relu2 = nn.ReLU() self.enc_residual2 = ResidualBlockWithUpsampling(128, 128) - self.enc_pool2 = nn.MaxPool2d(3, 3) # 9x9 -> 3x3 + self.enc_pool2 = nn.MaxPool2d(3, 3) # 9x9 -> 3x3 # Bottleneck self.bottleneck_conv = nn.Conv2d(128, 256, kernel_size=3, padding=1) @@ -1611,13 +1611,13 @@ def __init__(self, 
in_channels=3, num_classes=10): self.bottleneck_residual = ResidualBlockWithUpsampling(256, 256) # Decoder: Mixed upsampling (3x and 2x) - self.dec_upsample1 = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False) # 3x3 -> 9x9 - self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features + self.dec_upsample1 = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False) # 3x3 -> 9x9 + self.dec_conv1 = nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1) # Concatenate with encoder features self.dec_bn1 = nn.BatchNorm2d(128) self.dec_residual1 = ResidualBlockWithUpsampling(128, 128) - self.dec_upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 9x9 -> 18x18 - self.dec_conv2 = nn.Conv2d(128 + 64, 64, kernel_size=3, padding=1) # Concatenate with encoder features + self.dec_upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # 9x9 -> 18x18 + self.dec_conv2 = nn.Conv2d(128 + 64, 64, kernel_size=3, padding=1) # Concatenate with encoder features self.dec_bn2 = nn.BatchNorm2d(64) self.dec_residual2 = ResidualBlockWithUpsampling(64, 64) @@ -1646,13 +1646,13 @@ def forward(self, x): # Decoder with skip connections dec1 = self.dec_upsample1(bottleneck) - dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation + dec1 = th.cat([dec1, enc2], dim=1) # Skip connection via concatenation dec1 = self.dec_conv1(dec1) dec1 = self.dec_bn1(dec1) dec1 = self.dec_residual1(dec1) dec2 = self.dec_upsample2(dec1) - dec2 = th.cat([dec2, enc1], dim=1) # Skip connection via concatenation + dec2 = th.cat([dec2, enc1], dim=1) # Skip connection via concatenation dec2 = self.dec_conv2(dec2) dec2 = self.dec_bn2(dec2) dec2 = self.dec_residual2(dec2) @@ -1694,23 +1694,23 @@ def forward(self, x): incoming_deps = [(s, d) for s, d, t in dependencies3 if t == DepType.INCOMING] rec_deps = [(s, d) for s, d, t in dependencies3 if t == DepType.REC] - print(f"\n SAME Dependencies 
({len(same_deps)}):") - for src, dst in same_deps: # Show first 5 + print(f"\n SAME Dependencies ({len(same_deps)}):") + for src, dst in same_deps: # Show first 5 src_name = next((name for name, mod in model.named_modules() if mod is src), "") dst_name = next((name for name, mod in model.named_modules() if mod is dst), "") - print(f" [{src_name:30s}] --SAME-----> [{dst_name:30s}]") + print(f" [{src_name:30s}] --SAME-----> [{dst_name:30s}]") - print(f"\n INCOMING Dependencies ({len(incoming_deps)}):") - for src, dst in incoming_deps: # Show first 5 + print(f"\n INCOMING Dependencies ({len(incoming_deps)}):") + for src, dst in incoming_deps: # Show first 5 src_name = next((name for name, mod in model.named_modules() if mod is src), "") dst_name = next((name for name, mod in model.named_modules() if mod is dst), "") - print(f" [{src_name:30s}] --INCOMING-> [{dst_name:30s}]") + print(f" [{src_name:30s}] --INCOMING-> [{dst_name:30s}]") - print(f"\n REC Dependencies ({len(rec_deps)}):") - for src, dst in rec_deps: # Show first 5 + print(f"\n REC Dependencies ({len(rec_deps)}):") + for src, dst in rec_deps: # Show first 5 src_name = next((name for name, mod in model.named_modules() if mod is src), "") dst_name = next((name for name, mod in model.named_modules() if mod is dst), "") - print(f" [{src_name:30s}] <--REC----> [{dst_name:30s}]") + print(f" [{src_name:30s}] <--REC----> [{dst_name:30s}]") print(""" Model Architecture Notes: diff --git a/weightslab/utils/logs.py b/weightslab/utils/logs.py index c329710e..dd75872f 100644 --- a/weightslab/utils/logs.py +++ b/weightslab/utils/logs.py @@ -140,7 +140,7 @@ def setup_logging(level, log_to_file=True): _LOG_FILE_PATH = os.path.join(log_dir, f'weightslab_{timestamp}.log') _FILE_HANDLER = logging.FileHandler(_LOG_FILE_PATH, mode='w', encoding='utf-8') - _FILE_HANDLER.setLevel(logging.DEBUG) # Always log DEBUG+ to file + _FILE_HANDLER.setLevel(logging.DEBUG) # Always log DEBUG+ to file _FILE_HANDLER.setFormatter(formatter) 
root_logger.addHandler(_FILE_HANDLER) diff --git a/weightslab/utils/tools.py b/weightslab/utils/tools.py index 4e8fa025..8ae77b8f 100644 --- a/weightslab/utils/tools.py +++ b/weightslab/utils/tools.py @@ -29,7 +29,7 @@ def safe_reset_index(df: "pd.DataFrame") -> "pd.DataFrame": Plain ``df.reset_index()`` raises ``ValueError: cannot insert X, already exists`` when a MultiIndex level name (e.g. ``sample_id`` or ``annotation_id``) has already been materialised as a column — which - happens after ``_normalize_for_read`` in the H5 store. This helper only + happens after ``_normalize_for_read`` in the H5 store. This helper only promotes the levels that are actually missing from the column namespace. """ import pandas as _pd @@ -89,7 +89,7 @@ def normalize_config(obj: Any) -> Any: elif isinstance(obj, list): return [normalize_config(v) for v in obj] elif isinstance(obj, torch.device): - return str(obj) # e.g. "cuda" or "cuda:0" + return str(obj) # e.g. "cuda" or "cuda:0" elif isinstance(obj, pathlib.Path): return obj.as_posix() elif isinstance(obj, (bool, int, float, str)) or obj is None: @@ -176,7 +176,7 @@ def restore_rng_state(rng_state): # Restore Python random state if 'python_random' in rng_state: try: - random.setstate(tuple(tuple(i) if i is not None and not isinstance(i, (int, float)) else i for i in rng_state['python_random'])) # Conver to tuple of tuples + random.setstate(tuple(tuple(i) if i is not None and not isinstance(i, (int, float)) else i for i in rng_state['python_random'])) # Conver to tuple of tuples logger.debug("Restored Python random state") except Exception as e: logger.warning(f"Failed to restore Python random state: {e}") @@ -402,7 +402,7 @@ def model_op_neurons(model, layer_id=None, dummy_input=None, op=None, rand=False Test function to iteratively update neurons for each layer, then test inference. Everything match ? 
""" - seed_everything(42) if rand else None # Set seed for reproducibility + seed_everything(42) if rand else None # Set seed for reproducibility n_layers = len(model.layers) for n in range(n_layers-1, 0, -1): if rand and th.rand(1) > 0.5 and layer_id is None and dummy_input is None: @@ -412,7 +412,7 @@ def model_op_neurons(model, layer_id=None, dummy_input=None, op=None, rand=False if n != layer_id: continue else: - if n != n_layers + layer_id: # - -layer_id != + -layer_id + if n != n_layers + layer_id: # - -layer_id != + -layer_id continue logger.debug(f'\nOperate on neurons at layer {n}') if op is None: @@ -631,7 +631,7 @@ def array_id_2bytes( h = xxhash.xxh64() h.update(data) - digest8 = h.digest() # 8 bytes + digest8 = h.digest() # 8 bytes if return_hex: hexs = digest8.hex() @@ -648,10 +648,10 @@ def detach_to_cpu(obj: Any) -> Any: """Recursively detach tensors from the compute graph and move them to CPU. Handles: - - ``torch.Tensor`` → ``.detach().cpu()`` - - ``dict`` → recurse into values, preserve keys - - ``list`` / ``tuple`` → recurse element-wise, preserve type - - anything else → returned as-is + - ``torch.Tensor`` → ``.detach().cpu()`` + - ``dict`` → recurse into values, preserve keys + - ``list`` / ``tuple`` → recurse element-wise, preserve type + - anything else → returned as-is """ if isinstance(obj, th.Tensor): return obj.detach().cpu() @@ -681,7 +681,7 @@ def filter_kwargs_for_callable(func, kwargs): Examples: >>> def my_func(a, b, c=10): - ... return a + b + c + ... return a + b + c >>> all_kwargs = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5} >>> filtered = filter_kwargs_for_callable(my_func, all_kwargs) >>> filtered @@ -743,7 +743,7 @@ def safe_call_with_kwargs(func, *args, **kwargs): Examples: >>> def my_func(a, b, c=10): - ... return a + b + c + ... 
return a + b + c >>> safe_call_with_kwargs(my_func, 1, 2, c=3, d=4, e=5) 6 """ diff --git a/weightslab/watchdog/__init__.py b/weightslab/watchdog/__init__.py index 50c9e849..56e5e5a2 100644 --- a/weightslab/watchdog/__init__.py +++ b/weightslab/watchdog/__init__.py @@ -2,29 +2,29 @@ Public API ---------- -WATCHDOG : int — custom log level (35, between WARNING and ERROR) -MonitoredRLock : class — RLock with holder-thread tracking -raise_in_thread : func — deliver _WatchdogInterrupt to a thread by id -_WatchdogInterrupt : class — BaseException raised in stuck threads -RpcWatchdogState : class — tracks in-flight gRPC RPCs +WATCHDOG : int — custom log level (35, between WARNING and ERROR) +MonitoredRLock : class — RLock with holder-thread tracking +raise_in_thread : func — deliver _WatchdogInterrupt to a thread by id +_WatchdogInterrupt : class — BaseException raised in stuck threads +RpcWatchdogState : class — tracks in-flight gRPC RPCs RpcTimingAndWatchdogInterceptor : class — gRPC ServerInterceptor -GrpcServerManager : class — controls gRPC server lifecycle / restarts -WeighlabsWatchdog : class — unified watchdog (locks + gRPC) +GrpcServerManager : class — controls gRPC server lifecycle / restarts +WeighlabsWatchdog : class — unified watchdog (locks + gRPC) """ # Register WATCHDOG log level and logger.watchdog() method -from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 +from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 -from weightslab.watchdog.lock_monitor import ( # noqa: F401 +from weightslab.watchdog.lock_monitor import ( # noqa: F401 MonitoredRLock, _WatchdogInterrupt, raise_in_thread, ) -from weightslab.watchdog.grpc_watchdog import ( # noqa: F401 +from weightslab.watchdog.grpc_watchdog import ( # noqa: F401 RpcWatchdogState, RpcTimingAndWatchdogInterceptor, GrpcServerManager, ) -from weightslab.watchdog.watchdog import WeighlabsWatchdog # noqa: F401 +from weightslab.watchdog.watchdog import WeighlabsWatchdog # noqa: F401 diff --git 
a/weightslab/watchdog/grpc_watchdog.py b/weightslab/watchdog/grpc_watchdog.py index 94713d6c..979b90ba 100644 --- a/weightslab/watchdog/grpc_watchdog.py +++ b/weightslab/watchdog/grpc_watchdog.py @@ -12,7 +12,7 @@ from threading import Lock, Event -from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level +from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level logger = logging.getLogger(__name__) @@ -149,7 +149,7 @@ def set_server(self, server) -> None: def stop(self, grace: float = 5.0) -> None: with self._lock: if self._server: - logger.watchdog("[gRPC] Requesting graceful shutdown with %.1fs grace", grace) # type: ignore[attr-defined] + logger.watchdog("[gRPC] Requesting graceful shutdown with %.1fs grace", grace) # type: ignore[attr-defined] self._server.stop(grace=grace) self._server = None diff --git a/weightslab/watchdog/lock_monitor.py b/weightslab/watchdog/lock_monitor.py index 299ecc9b..3b00693a 100644 --- a/weightslab/watchdog/lock_monitor.py +++ b/weightslab/watchdog/lock_monitor.py @@ -1,11 +1,11 @@ """Lock monitoring for weightslab watchdog. Provides: - - MonitoredRLock : drop-in RLock replacement that tracks the holder thread + - MonitoredRLock : drop-in RLock replacement that tracks the holder thread and how long it has been held, so the watchdog can detect and recover from stuck locks. - _WatchdogInterrupt : BaseException raised asynchronously in stuck threads. - - raise_in_thread : deliver _WatchdogInterrupt to any thread by id. + - raise_in_thread : deliver _WatchdogInterrupt to any thread by id. When the watchdog raises _WatchdogInterrupt in a thread that holds a MonitoredRLock via ``with`` or a ``try/finally: release()``, Python's @@ -36,7 +36,7 @@ def raise_in_thread(tid: int, exc_type: type = _WatchdogInterrupt) -> bool: """Raise *exc_type* asynchronously in the thread identified by *tid*. 
Uses ``ctypes.pythonapi.PyThreadState_SetAsyncExc`` which delivers the - exception at the next Python bytecode boundary. Any active ``finally:`` + exception at the next Python bytecode boundary. Any active ``finally:`` or ``with`` block in the target thread will execute before the exception propagates, so held locks are released cleanly. @@ -48,7 +48,7 @@ def raise_in_thread(tid: int, exc_type: type = _WatchdogInterrupt) -> bool: ctypes.py_object(exc_type), ) if res == 0: - return False # thread not found + return False # thread not found if res > 1: # More than one state was modified — undo to be safe (shouldn't happen) ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None) @@ -72,17 +72,17 @@ class MonitoredRLock: whether to kill the holder via ``raise_in_thread``. Re-entrancy is fully supported: the same thread can acquire multiple - times. ``_acquired_at`` records the time of the *first* acquisition and + times. ``_acquired_at`` records the time of the *first* acquisition and is cleared only when the lock becomes fully free (count reaches 0). 
""" def __init__(self) -> None: self._lock = threading.RLock() - self._meta = threading.Lock() # guards the three fields below + self._meta = threading.Lock() # guards the three fields below self._holder_tid: Optional[int] = None self._acquired_at: Optional[float] = None self._count: int = 0 - self._timeout: Optional[float] = None # Optional per-lock timeout for watchdog (None to use global default) + self._timeout: Optional[float] = None # Optional per-lock timeout for watchdog (None to use global default) # ------------------------------------------------------------------ # Core acquire / release diff --git a/weightslab/watchdog/log_level.py b/weightslab/watchdog/log_level.py index 540e252f..48f4862f 100644 --- a/weightslab/watchdog/log_level.py +++ b/weightslab/watchdog/log_level.py @@ -22,4 +22,4 @@ def _watchdog(self: logging.Logger, message: str, *args, **kwargs) -> None: # Patch Logger class once so every logger instance gets the method -logging.Logger.watchdog = _watchdog # type: ignore[attr-defined] +logging.Logger.watchdog = _watchdog # type: ignore[attr-defined] diff --git a/weightslab/watchdog/watchdog.py b/weightslab/watchdog/watchdog.py index 136785af..8832471c 100644 --- a/weightslab/watchdog/watchdog.py +++ b/weightslab/watchdog/watchdog.py @@ -1,14 +1,14 @@ """WeighlabsWatchdog — unified watchdog for locks and gRPC threads. Combines: - 1. Lock monitoring — polls MonitoredRLock instances, raises _WatchdogInterrupt + 1. Lock monitoring — polls MonitoredRLock instances, raises _WatchdogInterrupt in the holder thread when the lock is held too long. - 2. gRPC monitoring — detects stuck in-flight RPCs via RpcWatchdogState and + 2. gRPC monitoring — detects stuck in-flight RPCs via RpcWatchdogState and requests a server restart when the threshold is exceeded. 3. Eval thread monitoring — checks that the evaluation worker thread is still alive whenever eval_controller reports is_running() or - is_pending(). 
If the thread is dead the controller is - transitioned to error state automatically. No timeout is + is_pending(). If the thread is dead the controller is + transitioned to error state automatically. No timeout is applied — evaluation may run for an arbitrarily long time. Typical usage (inside grpc_serve): @@ -25,7 +25,7 @@ get_thread=lambda: _EVAL_WORKER_THREAD, ) watchdog.start() - # watchdog.rpc_state → pass to RpcTimingAndWatchdogInterceptor + # watchdog.rpc_state → pass to RpcTimingAndWatchdogInterceptor # watchdog.server_manager → used by serving_thread_callback """ @@ -34,7 +34,7 @@ import threading from typing import Callable, Dict, List, Optional, Tuple -from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level +from weightslab.watchdog.log_level import WATCHDOG # noqa: F401 — registers level from weightslab.watchdog.lock_monitor import MonitoredRLock, raise_in_thread from weightslab.watchdog.grpc_watchdog import RpcWatchdogState, GrpcServerManager @@ -91,12 +91,12 @@ def register_eval_monitor( The watchdog will call ``mark_error()`` on the controller when it reports ``is_running()`` or ``is_pending()`` but the worker thread is no longer - alive. **No timeout is applied** — evaluation is allowed to run for as + alive. **No timeout is applied** — evaluation is allowed to run for as long as needed. Args: get_controller: Zero-arg callable that returns the EvaluationController. - get_thread: Zero-arg callable that returns the current worker + get_thread: Zero-arg callable that returns the current worker ``threading.Thread`` (or ``None`` if not started yet). 
""" self._eval_monitors.append((get_controller, get_thread)) @@ -114,7 +114,7 @@ def start(self) -> None: daemon=True, ) self._thread.start() - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Started (threshold=%.1fs poll=%.1fs restart_after=%d exit_on_stuck=%s locks=%s)", self._stuck_threshold_s, self._poll_interval_s, @@ -152,7 +152,7 @@ def _loop(self) -> None: # Lock monitoring # ------------------------------------------------------------------ - def _check_locks(self) -> None: # noqa: C901 + def _check_locks(self) -> None: # noqa: C901 # Import here to avoid circular imports try: from weightslab.components.global_monitoring import is_in_evaluation @@ -185,19 +185,19 @@ def _check_locks(self) -> None: # noqa: C901 continue if duration >= effective_threshold: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Lock '%s' held for %.1fs by tid=%s — sending interrupt", name, duration, tid, ) if tid is not None: killed = raise_in_thread(tid) if killed: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Interrupt delivered to tid=%s (lock '%s' will be released by finally/with)", tid, name, ) else: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Could not deliver interrupt to tid=%s — thread may have already exited", tid, ) @@ -217,14 +217,14 @@ def _check_eval_threads(self) -> None: continue if not (controller.is_running() or controller.is_pending()): - continue # nothing active — nothing to check + continue # nothing active — nothing to check if thread is not None and thread.is_alive(): - continue # worker is alive — all good + continue # worker is alive — all good # Controller believes eval is active but the thread is dead or missing. 
status = controller.get_status() if hasattr(controller, "get_status") else "unknown" - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Eval controller is '%s' but worker thread is dead — marking error", status, ) @@ -253,7 +253,7 @@ def _check_grpc(self) -> None: self._unhealthy_count += 1 self.rpc_state.record_unhealthy() - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] gRPC unhealthy #%d: in_flight=%d oldest=%.1fs method=%s threshold=%.1fs | %s", self._unhealthy_count, snap["in_flight"], @@ -264,13 +264,13 @@ def _check_grpc(self) -> None: ) if self._exit_on_stuck: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] GRPC_WATCHDOG_EXIT_ON_STUCK=1 — calling os._exit(1)" ) os._exit(1) if self._unhealthy_count >= self._restart_threshold: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] Restart threshold reached (%d/%d) — requesting server restart", self._unhealthy_count, self._restart_threshold, ) @@ -278,7 +278,7 @@ def _check_grpc(self) -> None: else: if self._unhealthy_count > 0: - logger.watchdog( # type: ignore[attr-defined] + logger.watchdog( # type: ignore[attr-defined] "[Watchdog] gRPC recovered after %d unhealthy checks", self._unhealthy_count ) self._unhealthy_count = 0