diff --git a/AdEMAMix-Shampoo.py b/AdEMAMix-Shampoo.py index 2a73931..d2e5dfb 100644 --- a/AdEMAMix-Shampoo.py +++ b/AdEMAMix-Shampoo.py @@ -284,7 +284,7 @@ def _update_adamemix_distributed_shampoo( denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) # Compute step size - step_size = lr / bias_correction1 + step_size = lr / (bias_correction1 if bias_correction1 > 0 else 0.01) # Apply weight decay if weight_decay != 0: