Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions bitsandbytes/_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Sequence
from math import prod
from typing import Optional

import torch
Expand Down Expand Up @@ -96,7 +95,7 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
row_stats = torch.empty(A.numel() // A.shape[-1], device=A.device, dtype=torch.float32)

if threshold == 0.0:
return out_row, row_stats, None
Expand Down Expand Up @@ -153,7 +152,7 @@ def _(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
out_row = torch.empty_like(A, dtype=torch.int8)
out_col = torch.empty_like(A, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
row_stats = torch.empty(A.numel() // A.shape[-1], device=A.device, dtype=torch.float32)
col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
outlier_n = torch.library.get_ctx().new_dynamic_size()
outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
Expand All @@ -175,7 +174,13 @@ def _(
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}")
torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
torch._check(
dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
return torch.empty(shape, dtype=dtype, device=A.device)


Expand All @@ -195,7 +200,13 @@ def _(
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}")
torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
torch._check(
dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
Expand All @@ -211,7 +222,12 @@ def _(
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}")
torch._check(
A.dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)

n = A.numel()
blocks = -(n // -blocksize)
Expand Down Expand Up @@ -250,7 +266,7 @@ def _(
B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32),
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(blocksize in [32, 64, 128, 256, 512, 1024, 2048, 4096], lambda: f"invalid blocksize {blocksize}")
torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}")
torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
if absmax_8bit is not None:
Expand Down Expand Up @@ -287,8 +303,12 @@ def _(

@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(blocksize > 0, lambda: f"blocksize must be positive, got {blocksize}")
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(
dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}",
)
return torch.empty_like(A, dtype=dtype)


Expand All @@ -302,8 +322,12 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
):
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(blocksize > 0, lambda: f"blocksize must be positive, got {blocksize}")
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(
dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
Expand All @@ -314,7 +338,11 @@ def _(

@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(blocksize > 0, lambda: f"blocksize must be positive, got {blocksize}")
torch._check(
A.dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}",
)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
Expand All @@ -332,14 +360,13 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
def _(
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
A.dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32),
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
shape = (*A.shape[:-1], shapeB[0])
Expand All @@ -362,14 +389,13 @@ def _(
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
A.dtype in (torch.float16, torch.bfloat16, torch.float32),
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32),
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class MatMul4Bit(torch.autograd.Function):
def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
if A.numel() == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
Expand Down
12 changes: 0 additions & 12 deletions bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def _(A: torch.Tensor, B: torch.Tensor):

@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")

n = A.numel()
blocks = -(n // -blocksize)

Expand Down Expand Up @@ -94,9 +92,6 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

out = torch.empty_like(A, dtype=dtype)
if dtype == torch.float32:
lib.cdequantize_blockwise_cpu_fp32(
Expand Down Expand Up @@ -146,13 +141,6 @@ def _(
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)

# Fallback as AVX512 implementation has accuracy issues with blocksize >= 2048.
# Note: this is not a common use case.
avx512_fallback = _has_avx512 and blocksize >= 2048
Expand Down
Loading
Loading