Skip to content

Bug fix in jvp updater#94

Merged
anagainaru merged 6 commits intomainfrom
jvp-bug
Apr 7, 2026
Merged

Bug fix in jvp updater#94
anagainaru merged 6 commits intomainfrom
jvp-bug

Conversation

@anagainaru
Copy link
Copy Markdown
Collaborator

This is the error happening without this fix:

2026/03/06 22:35:50 INFO:0 mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2026/03/06 22:35:50 INFO:0 mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
INFO:0 | 22:35:51 | step=0 | continuous_monitor | ==== ContinuousMonitor initialized ====
INFO:1 | 22:35:51 | step=0 | continuous_monitor | 	Detector: ADWINDetector
INFO:1 | 22:35:51 | step=0 | continuous_monitor | 	Monitoring metric index: 0
INFO:1 | 22:35:51 | step=0 | continuous_monitor | 	Detection interval: 1 batches
INFO:1 | 22:35:51 | step=0 | continuous_monitor | 	Aggregation method: mean
INFO:1 | 22:35:51 | step=0 | continuous_monitor | 	Max stream updates: 33
INFO:0 | 22:35:51 | step=0 | continuous_monitor | ==== Starting Continuous Monitoring ====
INFO:1 | 22:35:51 | step=0 | continuous_monitor | 	Initializing first data stream...
INFO:1 | 22:35:56 | step=18 | continuous_monitor | 	Stream exhausted. Loading next data buffer. 1/33
INFO:1 | 22:36:02 | step=38 | continuous_monitor | 	Stream exhausted. Loading next data buffer. 2/33
INFO:1 | 22:36:07 | step=58 | continuous_monitor | 	Stream exhausted. Loading next data buffer. 3/33
Processing batches: 1it [00:04,  4.38s/it]INFO:0 | 22:36:13 | step=62 | continuous_monitor | ==== DRIFT DETECTED (Event #1)! ====
INFO:1 | 22:36:13 | step=62 | continuous_monitor | 	Regime: continual_learning
INFO:1 | 22:36:13 | step=62 | continuous_monitor | 	Drift Score: 0.0312
INFO:1 | 22:36:13 | step=62 | continuous_monitor | 	Confidence: 0.8
INFO:2 | 22:36:13 | step=62 | count_flops | ---------------------------------------------------------------------------
INFO:2 | 22:36:13 | step=62 | count_flops | Compute Performance Metrics (Averaged per Update)
INFO:2 | 22:36:13 | step=62 | count_flops | ---------------------------------------------------------------------------
INFO:2 | 22:36:13 | step=62 | count_flops | Operation       FLOPs              Time            Throughput
INFO:2 | 22:36:13 | step=62 | count_flops | ---------------------------------------------------------------------------
INFO:2 | 22:36:13 | step=62 | count_flops | detector        0 FLOPs            41.10 μs        0 FLOP/s
INFO:2 | 22:36:13 | step=62 | count_flops | infer           22.91 MFLOPs       14.68 ms        1.56 GFLOP/s
INFO:2 | 22:36:13 | step=62 | count_flops | ---------------------------------------------------------------------------
INFO:2 | 22:36:13 | step=62 | count_flops | TOTAL           22.91 MFLOPs       14.72 ms        1.56 GFLOP/s
INFO:2 | 22:36:13 | step=62 | count_flops | ---------------------------------------------------------------------------
INFO:0 | 22:36:13 | step=62 | continuous_monitor | -> Dispatching continual learning module...
INFO:0 | 22:36:24 | step=62 | continuous_trainer | ==== Continual Learning ====
INFO:1 | 22:36:24 | step=62 | continuous_trainer | 	Initial test acc: 2.4631996154785156
INFO:1 | 22:36:24 | step=62 | continuous_trainer | 	Initial historical test acc: 0.8635233116149902
CL Updates (drift_event_id=1):   0%|                                                                                                     | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):                                                                                                       | 0/100 [00:00<?, ?it/s]
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/95j/_software/BaseSim_Framework/src/main.py", line 47, in <module>
    sys.exit(main())
             ~~~~^^
  File "/Users/95j/_software/BaseSim_Framework/src/main.py", line 37, in main
    monitor.run()
    ~~~~~~~~~~~^^
  File "/Users/95j/_software/BaseSim_Framework/src/driver/continuous_monitor.py", line 116, in run
    self._process_stream()
    ~~~~~~~~~~~~~~~~~~~~^^
  File "/Users/95j/_software/BaseSim_Framework/src/driver/continuous_monitor.py", line 155, in _process_stream
    self._handle_drift(drift_signal)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/Users/95j/_software/BaseSim_Framework/src/driver/continuous_monitor.py", line 290, in _handle_drift
    self.trainer.outer_cl_training_loop(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        drift_event_id=self.drift_event_count,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/95j/_software/BaseSim_Framework/src/training/continuous_trainer.py", line 102, in outer_cl_training_loop
    generation_loss, forgetting_loss = self.inner_cl_training_loop(
                                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        iter_count=iter_count,
        ^^^^^^^^^^^^^^^^^^^^^^
    ...<3 lines>...
        hist_train_iter=hist_train_iter,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/95j/_software/BaseSim_Framework/src/training/continuous_trainer.py", line 211, in inner_cl_training_loop
    loss += self.cl_updater.fwd_bwd(train_batch_tuple, hist_batch_tuple)
            ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/95j/_software/BaseSim_Framework/src/training/updater/jvp_reg.py", line 81, in fwd_bwd
    grad_dict, loss_curr, loss_mem = self._compute_jvp_gradients(
                                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        self._params, x_curr, y_curr, x_mem, y_mem, deltax
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/95j/_software/BaseSim_Framework/src/training/updater/jvp_reg.py", line 116, in _compute_jvp_gradients
    grad_curr = grad_fn(params, x_curr, y_curr)
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/_functorch/apis.py", line 406, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/_functorch/eager_transforms.py", line 1406, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/_functorch/eager_transforms.py", line 1364, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/Users/95j/_software/BaseSim_Framework/src/training/updater/jvp_reg.py", line 109, in loss_fn
    pred = functional_call(self.model, p, (x,))
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/_functorch/functional_call.py", line 153, in functional_call
    return nn.utils.stateless._functional_call(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        module,
        ^^^^^^^
    ...<4 lines>...
        strict=strict,
        ^^^^^^^^^^^^^^





    )
    ^
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/utils/stateless.py", line 282, in _functional_call
    return module(*args, **kwargs)
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/95j/_software/BaseSim_Framework/examples/aeris/model.py", line 42, in forward
    return self.layers(x)
           ~~~~~~~~~~~^^^
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/95j/_penv/torch-penv/lib/python3.13/site-packages/torch/nn/modules/batchnorm.py", line 173, in forward
    self.num_batches_tracked.add_(1)  # type: ignore[has-type]
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.

@anagainaru anagainaru merged commit 36b2a42 into main Apr 7, 2026
3 checks passed
@anagainaru anagainaru deleted the jvp-bug branch April 7, 2026 18:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants