ThunderFX in BF16 mode shows a larger divergence from eager FP64 results compared to the divergence between eager BF16 and eager FP64 #2153

@kiya00

🐛 Bug

While adding the coverage test for HF Diffusers (PR #2141), I noticed that ThunderFX in BF16 diverged significantly from eager BF16. To investigate, I ran the model in several dtypes (FP64, FP32, FP16, BF16) on three backends (eager, ThunderFX, and torch.compile) and compared each output against eager FP64. ThunderFX generally stays within an acceptable range of divergence, except for gradient outputs in FP16 and BF16, where the discrepancies are noticeably larger: in BF16, the greatest absolute gradient difference is about 52.04 for ThunderFX with default executors, versus about 0.24 for eager.
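For context on the scale of error that BF16 alone introduces: bfloat16 keeps 8 significant bits, so a single rounding can already cause a relative error of up to 2^-8 (about 3.9e-3), far above the rtol=1e-5 used in this report. Every BF16 run is therefore expected to report FAILED, and the useful signal is how the backends compare to one another. The sketch below illustrates this with a pure-Python emulation of float32-to-bfloat16 rounding; `round_to_bfloat16` is a hypothetical helper written for this illustration and is not part of PyTorch or Thunder.

```python
import struct


def round_to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper for illustration only: it goes through the
    float32 encoding and keeps the upper 16 bits.
    """
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias so that truncating the low 16 bits rounds to
    # nearest, with ties going to the even bfloat16 mantissa.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


BF16_EPS = 2.0 ** -7    # gap between 1.0 and the next bfloat16 value
RTOL_IN_REPORT = 1e-5   # tolerance passed to assert_close in the script

for value in (1.0, 1.001, 0.1, 3.14159):
    rounded = round_to_bfloat16(value)
    rel_err = abs(rounded - value) / abs(value)
    exceeds = rel_err > RTOL_IN_REPORT
    print(f"{value!r:>8} -> {rounded!r:<14} rel. error {rel_err:.2e} exceeds rtol: {exceeds}")
```

Values that are not exactly representable in bfloat16, such as 0.1 or 1.001, already exceed the tolerance after one rounding step, before any arithmetic is performed.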

Results for runwayml/stable-diffusion-v1-5 (BF16 vs. eager FP64):

ThunderFX with default executors
Forward pass comparison: FAILED
Error: Tensor-likes are not close!

Mismatched elements: 1023 / 1024 (99.9%)
Greatest absolute difference: 0.02505786082606219 at index (0, 3, 9, 3) (up to 1e-05 allowed)
Greatest relative difference: 15.79343976974047 at index (0, 0, 6, 2) (up to 1e-05 allowed)
Computing gradients for torch eager float64...
Computing gradients for thunder torch.bfloat16...
Gradient comparison: FAILED
Error: Tensor-likes are not close!

Mismatched elements: 11520 / 11520 (100.0%)
Greatest absolute difference: **52.038735867250054** at index (145, 0, 1, 0) (up to 1e-05 allowed)
Greatest relative difference: **118164.35679032684** at index (162, 0, 0, 2) (up to 1e-05 allowed)

The failure occurred for item [0]

eager

Mismatched elements: 1022 / 1024 (99.8%)
Greatest absolute difference: 0.031462447813257244 at index (0, 3, 12, 14) (up to 1e-05 allowed)
Greatest relative difference: 5.43559016117814 at index (0, 0, 6, 4) (up to 1e-05 allowed)
Computing gradients for torch eager float64...
Computing gradients for eager torch.bfloat16...
Gradient comparison: FAILED
Error: Tensor-likes are not close!

Mismatched elements: 11518 / 11520 (100.0%)
Greatest absolute difference: **0.23972053767860058** at index (127, 3, 1, 1) (up to 1e-05 allowed)
Greatest relative difference: **532.5203075118578** at index (162, 0, 0, 2) (up to 1e-05 allowed)

torch.compile

Mismatched elements: 1023 / 1024 (99.9%)
Greatest absolute difference: 0.02224970804979243 at index (0, 3, 1, 14) (up to 1e-05 allowed)
Greatest relative difference: 7.853684232346166 at index (0, 0, 6, 2) (up to 1e-05 allowed)
Computing gradients for torch eager float64...
Computing gradients for inductor torch.bfloat16...
Gradient comparison: FAILED
Error: Tensor-likes are not close!

Mismatched elements: 11515 / 11520 (100.0%)
Greatest absolute difference: **0.1986014106462043** at index (202, 0, 0, 1) (up to 1e-05 allowed)
Greatest relative difference: **328.6249694895744** at index (239, 3, 0, 2) (up to 1e-05 allowed)

ThunderFX with executors=[thunder.pytorch_executor]

Mismatched elements: 1024 / 1024 (100.0%)
Greatest absolute difference: 0.02503035742302913 at index (0, 1, 11, 1) (up to 1e-05 allowed)
Greatest relative difference: 14.88029205458141 at index (0, 0, 6, 2) (up to 1e-05 allowed)
Computing gradients for torch eager float64...
Computing gradients for thunder torch.bfloat16...
Gradient comparison: FAILED
Error: Tensor-likes are not close!

Mismatched elements: 11517 / 11520 (100.0%)
Greatest absolute difference: **0.25748365664500517** at index (122, 0, 1, 2) (up to 1e-05 allowed)
Greatest relative difference: **387.0515917837159** at index (162, 0, 0, 2) (up to 1e-05 allowed)
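To make the four gradient results easier to compare, the reported differences can be extracted from the error messages and expressed as multiples of the eager BF16 numbers. The following is a minimal sketch using only the standard library; the log excerpts are copied from the output above, while `parse_differences` and the regular expression are illustrative helpers, not part of `torch.testing`.

```python
import re

# Gradient-comparison lines copied from the logs above.
REPORTED = {
    "ThunderFX with default executors": """
Greatest absolute difference: **52.038735867250054** at index (145, 0, 1, 0) (up to 1e-05 allowed)
Greatest relative difference: **118164.35679032684** at index (162, 0, 0, 2) (up to 1e-05 allowed)
""",
    "eager": """
Greatest absolute difference: **0.23972053767860058** at index (127, 3, 1, 1) (up to 1e-05 allowed)
Greatest relative difference: **532.5203075118578** at index (162, 0, 0, 2) (up to 1e-05 allowed)
""",
    "torch.compile": """
Greatest absolute difference: **0.1986014106462043** at index (202, 0, 0, 1) (up to 1e-05 allowed)
Greatest relative difference: **328.6249694895744** at index (239, 3, 0, 2) (up to 1e-05 allowed)
""",
    "ThunderFX with executors=[thunder.pytorch_executor]": """
Greatest absolute difference: **0.25748365664500517** at index (122, 0, 1, 2) (up to 1e-05 allowed)
Greatest relative difference: **387.0515917837159** at index (162, 0, 0, 2) (up to 1e-05 allowed)
""",
}

# Matches the value after "Greatest absolute/relative difference:",
# tolerating optional bold markers around the number.
PATTERN = re.compile(r"Greatest (absolute|relative) difference: \**([0-9.eE+-]+)\**")


def parse_differences(message: str) -> dict:
    """Return {'absolute': float, 'relative': float} for one error message."""
    return {kind: float(value) for kind, value in PATTERN.findall(message)}


parsed = {name: parse_differences(text) for name, text in REPORTED.items()}
baseline = parsed["eager"]

for name, diffs in parsed.items():
    abs_ratio = diffs["absolute"] / baseline["absolute"]
    rel_ratio = diffs["relative"] / baseline["relative"]
    print(f"{name:<52} abs {diffs['absolute']:>9.4f} ({abs_ratio:7.2f}x eager)"
          f"  rel {diffs['relative']:>12.2f} ({rel_ratio:7.2f}x eager)")
```

With these numbers, the greatest absolute gradient difference for ThunderFX with default executors is roughly 217 times the eager BF16 value, whereas torch.compile and ThunderFX restricted to `thunder.pytorch_executor` stay close to eager. Note that these are maxima over a single parameter tensor, so they indicate an outlier rather than the typical error.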

To Reproduce

  1. Use the container pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.6.3-cudnn-fe1.10.0-py3.10-pt_2.7.0-dev.
  2. Change the bitsandbytes version and add diffusers in requirements/test.txt, then install the requirements:
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -17,7 +17,9 @@ absl-py # thunder/benchmarks/test_benchmark_litgpt.py
 pandas # thunder/benchmarks/test_benchmark_litgpt.py
 xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py
 jsonargparse # thunder/benchmarks/benchmark_litgpt.py
-bitsandbytes==0.42.0  # fixed version!
+# bitsandbytes==0.42.0  # fixed version!
+bitsandbytes>=0.42,<0.43; sys_platform=='darwin'
+bitsandbytes>=0.45.2,<0.45.5; sys_platform!='darwin'
 transformers==4.50.3 # for test_networks.py
 diffusers==0.33.0 # for test_networks.py
  3. Run python compare_acc64.py --backend thunder --dtype bfloat16, where compare_acc64.py is:
import argparse

import torch
from diffusers import UNet2DConditionModel
from thunder.dynamo import thunderfx

# Disable reduced-precision reductions in FP16/BF16 matmuls
# torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.set_float32_matmul_precision('high')
torch.manual_seed(0)

# Parse command line arguments for backend selection
parser = argparse.ArgumentParser(description='UNet2D accuracy test with different backends')
parser.add_argument('--backend', type=str, default='thunder', choices=['thunder', 'inductor', 'eager'],
                    help='Backend to use for model compilation (default: thunder)')
parser.add_argument('--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16', 'float32', 'float64'],
                    help='Data type to use for model (default: bfloat16)')
parser.add_argument('--model-id', type=str, default='runwayml/stable-diffusion-v1-5',
                    help='Model ID to use for testing (default: runwayml/stable-diffusion-v1-5)')
args = parser.parse_args()
backend = args.backend
# argparse `choices` already restricts --dtype to valid torch dtype names
dtype = getattr(torch, args.dtype)
model_id = args.model_id

print(f"Using backend: {backend}")
print(f"Using dtype: {dtype}")
print(f"Using model ID: {model_id}")

# model_id = "runwayml/stable-diffusion-v1-5" #"stabilityai/stable-diffusion-2-1"

try:
    unet_bf16 = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
    unet_f64 = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float64)
except OSError:
    unet_bf16 = UNet2DConditionModel.from_pretrained(
        model_id, subfolder="unet", use_safetensors=False, torch_dtype=dtype
    )
    unet_f64 = UNet2DConditionModel.from_pretrained(
        model_id, subfolder="unet", use_safetensors=False, torch_dtype=torch.float64
    )

config = unet_bf16.config
in_channels = config.in_channels
sample_size = 16 # config.sample_size
cross_attention_dim = config.cross_attention_dim
addition_embed_type = config.addition_embed_type

batch_size = 1
seq_length = 8

if "xl" in model_id:
    time_ids_dim = 6
    text_embeds_dim = 1280
    if "refiner" in model_id:
        time_ids_dim = 2
        text_embeds_dim = 2048
else:
    time_ids_dim = None
    text_embeds_dim = None

input_shape = (batch_size, in_channels, sample_size, sample_size)
hidden_states_shape = (batch_size, seq_length, cross_attention_dim)

# Setup models
unet_f64 = unet_f64.to("cuda", dtype=torch.float64).requires_grad_(True)
unet_bf16 = unet_bf16.to("cuda", dtype=dtype).requires_grad_(True)
if backend == "thunder":
    import thunder
    compiled_model = thunderfx(unet_bf16)
    # compiled_model = thunderfx(unet_bf16, executors=[thunder.pytorch_executor])
elif backend == "inductor":
    compiled_model = torch.compile(unet_bf16)
elif backend == "eager":
    compiled_model = unet_bf16
else:
    raise ValueError(f"Invalid backend: {backend}")

def make_inputs(dtype=torch.bfloat16):
    added_cond_kwargs = {}
    with torch.device("cuda"):
        input = torch.randn(input_shape, dtype=dtype)
        hidden_states = torch.randn(hidden_states_shape, dtype=dtype)
        timestep = torch.ones(batch_size, dtype=torch.long)
        if addition_embed_type is not None:
            assert text_embeds_dim is not None and time_ids_dim is not None
            time_ids_shape = (batch_size, time_ids_dim)
            text_embeds_shape = (batch_size, text_embeds_dim)
            added_cond_kwargs["time_ids"] = torch.randn(time_ids_shape, device="cuda", dtype=dtype)
            added_cond_kwargs["text_embeds"] = torch.randn(text_embeds_shape, device="cuda", dtype=dtype)
    return (input, timestep, hidden_states), {"added_cond_kwargs": added_cond_kwargs}

# Create identical inputs, then convert to appropriate dtypes
base_args, base_kwargs = make_inputs(torch.float64)

# Use the same inputs for both models, just convert dtype
f64_args = base_args
f64_kwargs = base_kwargs

# Convert to the selected dtype for the compiled model (variables keep the bf16 name for any dtype)
bf16_args = tuple(arg.to(dtype) for arg in base_args)
bf16_kwargs = {}
for key, value in base_kwargs.items():
    if key == "added_cond_kwargs":
        bf16_kwargs[key] = {k: v.to(dtype) for k, v in value.items()}
    else:
        bf16_kwargs[key] = value

# Run inference
print("Running torch eager float64...")
ref_output = unet_f64(*f64_args, **f64_kwargs)

print(f"Running {backend} {dtype}...")
compiled_output = compiled_model(*bf16_args, **bf16_kwargs)

ref_output = ref_output.sample
compiled_output = compiled_output.sample

# Convert compiled output to float64 for comparison
compiled_output_f64 = compiled_output.to(torch.float64)

# Compare outputs
try:
    torch.testing.assert_close(compiled_output_f64, ref_output, rtol=1e-5, atol=1e-5)
except Exception as e:
    print("Forward pass comparison: FAILED")
    print(f"Error: {e}")
else:
    print("Forward pass comparison: PASSED")

# Gradient comparison
# loss_grad_f64 = torch.randn_like(ref_output)
# loss_grad_bf16 = loss_grad_f64.to(dtype)

loss_grad_bf16 = torch.randn_like(compiled_output)
loss_grad_f64 = loss_grad_bf16.to(torch.float64)

print("Computing gradients for torch eager float64...")
grads_ref = torch.autograd.grad(ref_output, unet_f64.parameters(), grad_outputs=loss_grad_f64)

print(f"Computing gradients for {backend} {dtype}...")
grads_compiled = torch.autograd.grad(compiled_output, unet_bf16.parameters(), grad_outputs=loss_grad_bf16)

# Convert compiled gradients to float64 for comparison
grads_compiled_f64 = [g.to(torch.float64) for g in grads_compiled]

try:
    # Compare gradients with the same tight tolerance as the forward pass
    torch.testing.assert_close(grads_compiled_f64, grads_ref, rtol=1e-5, atol=1e-5)
except Exception as e:
    print("Gradient comparison: FAILED")
    print(f"Error: {e}")
else:
    print("Gradient comparison: PASSED")
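To repeat the comparison for every backend and dtype mentioned in the description, the script can be driven from a small wrapper. The sketch below only relies on the `--backend`, `--dtype`, and `--model-id` flags defined by the script above; `build_commands` and `run_sweep` are hypothetical helper names, and the wrapper assumes `compare_acc64.py` is in the current directory.

```python
import itertools
import subprocess
import sys

# Choices accepted by compare_acc64.py's argument parser.
BACKENDS = ("eager", "inductor", "thunder")
DTYPES = ("float64", "float32", "float16", "bfloat16")


def build_commands(script="compare_acc64.py", model_id=None):
    """Build one command line per (backend, dtype) combination."""
    commands = []
    for backend, dtype in itertools.product(BACKENDS, DTYPES):
        cmd = [sys.executable, script, "--backend", backend, "--dtype", dtype]
        if model_id is not None:
            cmd += ["--model-id", model_id]
        commands.append(cmd)
    return commands


def run_sweep(commands):
    """Run each command and return {(backend, dtype): captured stdout}."""
    results = {}
    for cmd in commands:
        # check=False: a failed run should not abort the remaining combinations.
        completed = subprocess.run(cmd, capture_output=True, text=True, check=False)
        backend, dtype = cmd[3], cmd[5]
        results[(backend, dtype)] = completed.stdout
    return results


# Dry run: print the commands without executing them.
for cmd in build_commands():
    print(" ".join(cmd))
```

Calling `run_sweep(build_commands())` executes the twelve combinations sequentially. Each run loads the model twice (once in the target dtype and once in FP64), so the sweep needs a GPU with enough memory for both copies.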
