diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index 9b97480aa..43efd8609 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from math import prod from typing import Optional import torch @@ -96,7 +95,7 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): @register_fake("bitsandbytes::int8_vectorwise_quant") def _(A: torch.Tensor, threshold=0.0): out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32) + row_stats = torch.empty(A.numel() // A.shape[-1], device=A.device, dtype=torch.float32) if threshold == 0.0: return out_row, row_stats, None @@ -153,7 +152,7 @@ def _( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: out_row = torch.empty_like(A, dtype=torch.int8) out_col = torch.empty_like(A, dtype=torch.int8) - row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32) + row_stats = torch.empty(A.numel() // A.shape[-1], device=A.device, dtype=torch.float32) col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32) outlier_n = torch.library.get_ctx().new_dynamic_size() outlier_cols = A.new_empty(outlier_n, dtype=torch.int64) @@ -175,7 +174,13 @@ def _( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") + torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}") + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}") + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check( + dtype in (torch.float16, torch.bfloat16, torch.float32), + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) return torch.empty(shape, dtype=dtype, device=A.device) @@ -195,7 +200,13 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") + torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}") + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}") + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check( + dtype in (torch.float16, torch.bfloat16, torch.float32), + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") @@ -211,7 +222,12 @@ def _( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") + torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}") + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}") + torch._check( + A.dtype in (torch.float16, torch.bfloat16, torch.float32), + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) n = A.numel() blocks = -(n // -blocksize) @@ -250,7 +266,7 @@ def _( B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32), lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", ) - torch._check(blocksize in [32, 64, 128, 256, 512, 1024, 2048, 4096], lambda: f"invalid blocksize {blocksize}") + torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}") torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}") torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") if absmax_8bit is not None: @@ -287,8 +303,12 @@ def _( @register_fake("bitsandbytes::dequantize_blockwise") def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") + torch._check(blocksize > 0, lambda: f"blocksize must be positive, got {blocksize}") torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in (torch.float16, torch.bfloat16, torch.float32), + lambda: f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}", + ) return torch.empty_like(A, dtype=dtype) @@ -302,8 +322,12 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, def _( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ): - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") + torch._check(blocksize > 0, lambda: f"blocksize must be positive, got {blocksize}") torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in (torch.float16, torch.bfloat16, torch.float32), + lambda: f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}", + ) torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") @@ -314,7 +338,11 @@ def _( @register_fake("bitsandbytes::quantize_blockwise") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") + torch._check(blocksize > 0, lambda: f"blocksize must be positive, got {blocksize}") + torch._check( + A.dtype in (torch.float16, torch.bfloat16, torch.float32), + lambda: f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}", + ) n = A.numel() blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) @@ -332,14 +360,13 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor def _( A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}") + torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}") torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], + A.dtype in (torch.float16, torch.bfloat16, torch.float32), lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", ) torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32), lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", ) shape = (*A.shape[:-1], shapeB[0]) @@ -362,14 +389,13 @@ def _( blocksize: int, out: torch.Tensor, ) -> None: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}") + torch._check(blocksize in (32, 64, 128, 256, 512, 1024, 2048, 4096), lambda: f"invalid blocksize {blocksize}") torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], + A.dtype in (torch.float16, torch.bfloat16, torch.float32), lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", ) torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + B.dtype in (torch.uint8, torch.bfloat16, torch.float16, torch.float32), lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", ) torch._check( diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e254f63df..8a069bd10 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -304,7 +304,7 @@ class MatMul4Bit(torch.autograd.Function): def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None): # default of pytorch behavior if inputs are empty ctx.is_empty = False - if prod(A.shape) == 0: + if A.numel() == 0: ctx.is_empty = True ctx.A = A ctx.B = B diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index a6277e5cf..ed6803eda 100755 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -35,8 +35,6 @@ def _(A: torch.Tensor, B: torch.Tensor): @register_kernel("bitsandbytes::quantize_blockwise", "cpu") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - n = A.numel() blocks = -(n // -blocksize) @@ -94,9 +92,6 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor def _( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - out = torch.empty_like(A, dtype=dtype) if dtype == torch.float32: lib.cdequantize_blockwise_cpu_fp32( @@ -146,13 +141,6 @@ def _( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - # Fallback as AVX512 implementation has accuracy issues with blocksize >= 2048. # Note: this is not a common use case. avx512_fallback = _has_avx512 and blocksize >= 2048 diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 835febd99..7825d6585 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -7,7 +7,7 @@ import torch -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, get_ptr from ..._ops import register_kernel from ...cextension import lib @@ -15,6 +15,62 @@ from ..utils import _get_4bit_code +def _setup_ctypes(names, argtypes, restype=None): + for name in names: + fn = getattr(lib, name) + fn.argtypes = argtypes + fn.restype = restype + + +# 4-bit/8-bit dequantize: (code, A, absmax, out, blocksize, numel, stream) +_setup_ctypes( + [f"cdequantize_blockwise_{d}_{q}" for d in ("fp32", "bf16", "fp16") for q in ("nf4", "fp4")] + + [f"cdequantize_blockwise_{d}" for d in ("fp32", "bf16", "fp16")], + [ct.c_void_p] * 4 + [ct.c_int32, ct.c_int32, ct.c_void_p], +) + +# 4-bit GEMM: (A, B, absmax, absmax_8bit, absmax_code, absmax_offset, out, bias, M, N, K, blocksize, quant_type, stream) +_setup_ctypes( + [f"cgemm_4bit_{d}" for d in ("bf16", "fp16", "fp32")], + [ct.c_void_p] * 8 + [ct.c_int32, ct.c_int32, ct.c_int32, ct.c_int32, ct.c_int32, ct.c_void_p], +) + +# 4-bit GEMV: (m, n, k, A, B, absmax, code, out, lda, ldb, ldc, blocksize, stream) +_setup_ctypes( + [f"cgemm_4bit_inference_naive_{d}" for d in ("bf16", "fp16", "fp32")], + [ct.c_int32] * 3 + [ct.c_void_p] * 5 + [ct.c_int32] * 3 + [ct.c_int32, ct.c_void_p], +) + +# int8 igemm: (ctx, m, n, k, A, B, C, rowscale, lda, ldb, ldc, stream) -> int32 +_setup_ctypes( + ["cigemmlt_32"], + [ct.c_void_p] + [ct.c_int32] * 3 + [ct.c_void_p] * 4 + [ct.c_int32] * 3 + [ct.c_void_p], + restype=ct.c_int32, +) + +# int8 mm dequant: (A, row_stats, col_stats, out, bias, numRows, numCols, stream) +_setup_ctypes( + ["cdequant_mm_int32_fp16"], + [ct.c_void_p] * 5 + [ct.c_int32, ct.c_int32, ct.c_void_p], +) + +# int8 vectorwise quant: (A, out, row_stats, threshold, rows, cols, stream) +_setup_ctypes( + ["cint8_vector_quant"], + [ct.c_void_p] * 3 + [ct.c_float, ct.c_int32, ct.c_int32, ct.c_void_p], +) + +# 4-bit/8-bit blockwise quantize: (code, A, absmax, out, blocksize, n) +_setup_ctypes( + [f"cquantize_blockwise_{d}_{q}" for d in ("fp32", "bf16", "fp16") for q in ("nf4", "fp4")] + + [f"cquantize_blockwise_{d}" for d in ("fp32", "bf16", "fp16")], + [ct.c_void_p] * 4 + [ct.c_int32, ct.c_int32], +) + + +_get_raw_stream = torch._C._cuda_getCurrentRawStream + + @functools.cache def _gpu_dispatch_props(device_index): props = torch.cuda.get_device_properties(device_index) @@ -38,15 +94,22 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor shapeA = A.shape shapeB = B.shape - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) + if A.dtype != torch.int8: + raise ValueError("B must be int8") + if B.dtype != torch.int8: + raise ValueError("A must be int8") + if A.ndim != 2: + raise ValueError("Only two dimensional matrices are supported for argument B") + if B.ndim not in (2, 3): + raise ValueError("Only two or three dimensional matrices are supported for argument A") + if prod(shapeB) <= 0: + raise ValueError(f"Input tensor dimensions need to be > 0: {shapeB}") + if out.dtype != torch.int32: + raise ValueError(f"out must be int32, got {out.dtype}") shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + if out.shape != shapeC: + raise ValueError(f"Output shape {out.shape} does not match expected shape {shapeC}") k, m = shapeA n = prod(shapeB[:-1]) @@ -54,10 +117,10 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor ldb = shapeB[-1] # Activations (batch, tokens, inputs) ldc = shapeC[-1] # Output (batch, tokens, outputs) - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) + if lda != ldb: + raise ValueError( + f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" + ) # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. # We'll fall back to a slower fp32 calculation in this circumstance. @@ -68,19 +131,20 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor with _cuda_device_of(A): ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + has_error = lib.cigemmlt_32( + ctx, + m, + n, + k, + A.data_ptr(), + B.data_ptr(), + out.data_ptr(), + None, + lda, + ldb, + ldc, + _get_raw_stream(A.device.index), + ) if has_error: if has_error == 100: @@ -114,28 +178,31 @@ def _( dtype: Optional[torch.dtype] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + if A.dtype != torch.int32: + raise ValueError(f"A must be int32, got {A.dtype}") + if row_stats.dtype != torch.float32: + raise ValueError(f"row_stats must be float32, got {row_stats.dtype}") + if col_stats.dtype != torch.float32: + raise ValueError(f"col_stats must be float32, got {col_stats.dtype}") # Note: cuda kernel only currently supports fp16 output. # We'll later cast to desired dtype if needed. out = torch.empty_like(A, dtype=torch.float16) - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - # Note: fused bias in the kernel is only supported for fp16 # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + bias_ptr = bias.data_ptr() if bias is not None and bias.dtype == torch.float16 else None with _cuda_device_of(A): lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + A.data_ptr(), + row_stats.data_ptr(), + col_stats.data_ptr(), + out.data_ptr(), + bias_ptr, + A.numel() // A.shape[-1], + A.shape[-1], + _get_raw_stream(A.device.index), ) # Add bias separately if not fused in kernel @@ -147,10 +214,12 @@ def _( @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + if A.dtype != torch.float16: + raise ValueError(f"A must be float16, got {A.dtype}") + if threshold < 0.0: + raise ValueError("threshold must be non-negative") - rows = prod(A.shape[:-1]) + rows = A.numel() // A.shape[-1] cols = A.shape[-1] row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) @@ -170,13 +239,13 @@ def _(A: torch.Tensor, threshold=0.0): with _cuda_device_of(A): lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), + A.data_ptr(), + out_row.data_ptr(), + row_stats.data_ptr(), + threshold, + rows, + cols, + _get_raw_stream(A.device.index), ) # Zero out values from outlier columns across all rows. @@ -211,7 +280,8 @@ def _get_col_absmax( A: torch.Tensor, threshold=0.0, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) + if not A.is_floating_point(): + raise ValueError(f"A must be a floating point tensor, got {A.dtype}") outlier_mask = None @@ -231,36 +301,36 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: A = A.contiguous() - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + if code.dtype != torch.float32: + raise ValueError(f"code must be float32, got {code.dtype}") + if blocksize not in (64, 128, 256, 512, 1024, 2048, 4096): + raise ValueError(f"invalid blocksize {blocksize}") n = A.numel() blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty_like(A, dtype=torch.uint8) + if A.dtype == torch.float32: + fn = lib.cquantize_blockwise_fp32 + elif A.dtype == torch.float16: + fn = lib.cquantize_blockwise_fp16 + elif A.dtype == torch.bfloat16: + fn = lib.cquantize_blockwise_bf16 + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), + fn( + code.data_ptr(), + A.data_ptr(), + absmax.data_ptr(), + out.data_ptr(), + blocksize, + n, ) - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - return out, absmax @@ -280,8 +350,10 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + if out.dtype != dtype: + raise ValueError(f"Expected out.dtype == {dtype}, got {out.dtype}") + if out.shape != A.shape: + raise ValueError(f"Expected out.shape == {A.shape}, got {out.shape}") _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) @@ -289,77 +361,64 @@ def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: A = A.contiguous() - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) + if dtype == torch.float32: + fn = lib.cdequantize_blockwise_fp32 + elif dtype == torch.float16: + fn = lib.cdequantize_blockwise_fp16 + elif dtype == torch.bfloat16: + fn = lib.cdequantize_blockwise_bf16 + else: + raise ValueError(f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}") with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), + fn( + code.data_ptr(), + A.data_ptr(), + absmax.data_ptr(), + out.data_ptr(), + blocksize, + A.numel(), + _get_raw_stream(A.device.index), ) - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - @register_kernel("bitsandbytes::quantize_4bit", "cuda") def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: A = A.contiguous() - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - n = A.numel() blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + fn = lib.cquantize_blockwise_bf16_fp4 + else: + fn = lib.cquantize_blockwise_bf16_nf4 + elif A.dtype == torch.float16: + if quant_type == "fp4": + fn = lib.cquantize_blockwise_fp16_fp4 + else: + fn = lib.cquantize_blockwise_fp16_nf4 + elif A.dtype == torch.float32: + if quant_type == "fp4": + fn = lib.cquantize_blockwise_fp32_fp4 + else: + fn = lib.cquantize_blockwise_fp32_nf4 + with _cuda_device_of(A): - args = ( + fn( None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int32(n), + A.data_ptr(), + absmax.data_ptr(), + out.data_ptr(), + blocksize, + n, ) - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - return out, absmax @@ -387,8 +446,10 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + if out.shape != tuple(shape): + raise ValueError(f"Expected out.shape == {shape}, got {out.shape}") + if out.dtype != dtype: + raise ValueError(f"Expected out.dtype == {dtype}, got {out.dtype}") _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) @@ -401,41 +462,36 @@ def _dequantize_4bit_impl( out: torch.Tensor, ) -> None: A = A.contiguous() - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) + if dtype == torch.bfloat16: + if quant_type == "fp4": + fn = lib.cdequantize_blockwise_bf16_fp4 + else: + fn = lib.cdequantize_blockwise_bf16_nf4 + elif dtype == torch.float16: + if quant_type == "fp4": + fn = lib.cdequantize_blockwise_fp16_fp4 + else: + fn = lib.cdequantize_blockwise_fp16_nf4 + elif dtype == torch.float32: + if quant_type == "fp4": + fn = lib.cdequantize_blockwise_fp32_fp4 + else: + fn = lib.cdequantize_blockwise_fp32_nf4 + else: + raise ValueError(f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}") with _cuda_device_of(A): - args = ( + fn( None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int32(out.numel()), - _get_tensor_stream(A), + A.data_ptr(), + absmax.data_ptr(), + out.data_ptr(), + blocksize, + out.numel(), + _get_raw_stream(A.device.index), ) - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - @register_kernel("bitsandbytes::gemv_4bit", "cuda") def _( @@ -457,11 +513,11 @@ def _( blocksize: int, out: torch.Tensor, ) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + expected_shape = (*A.shape[:-1], shapeB[0]) + if out.shape != expected_shape: + raise ValueError(f"Expected out.shape == {expected_shape}, got {out.shape}") + if out.dtype != A.dtype: + raise ValueError(f"Expected out.dtype == {A.dtype}, got {out.dtype}") _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) @@ -474,7 +530,8 @@ def _gemv_4bit_impl( blocksize: int, out: torch.Tensor, ) -> None: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") + if blocksize not in (32, 64, 128, 256, 512, 1024, 2048, 4096): + raise ValueError(f"invalid blocksize {blocksize}") # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now. # torch._check( @@ -492,65 +549,37 @@ def _gemv_4bit_impl( # torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) + m = shapeB[0] + n = 1 + k = shapeB[1] lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldb = (A.shape[-1] + 1) // 2 ldc = m - stream = _get_tensor_stream(A) + if A.dtype == torch.float16: + fn = lib.cgemm_4bit_inference_naive_fp16 + elif A.dtype == torch.bfloat16: + fn = lib.cgemm_4bit_inference_naive_bf16 + elif A.dtype == torch.float32: + fn = lib.cgemm_4bit_inference_naive_fp32 with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) + fn( + m, + n, + k, + A.data_ptr(), + B.data_ptr(), + absmax.data_ptr(), + code.data_ptr(), + out.data_ptr(), + lda, + ldb, + ldc, + blocksize, + _get_raw_stream(A.device.index), + ) @functools.cache @@ -806,9 +835,13 @@ def _( use_custom = M <= 4 or _gemm_4bit_use_custom(A.device.index, A.dtype, M, N, K) if not use_custom: - return _gemm_4bit_default_impl( - A, B, shapeB, absmax, blocksize, quant_type, bias, absmax_8bit, absmax_code, absmax_offset - ) + if absmax_8bit is not None: + absmax_dq = torch.empty_like(absmax_8bit, dtype=torch.float32) + _dequantize_blockwise_impl(absmax_8bit, absmax, absmax_code, 256, torch.float32, out=absmax_dq) + absmax = absmax_dq + absmax_offset + B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device) + _dequantize_4bit_impl(B, absmax, blocksize, quant_type, A.dtype, out=B_dq) + return torch.nn.functional.linear(A, B_dq, bias) if K != shapeB[1]: raise RuntimeError(f"A inner dim ({K}) does not match weight ({shapeB[1]})") @@ -823,7 +856,7 @@ def _( quant_type_int = 1 if quant_type == "fp4" else 2 out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device) - stream = torch._C._cuda_getCurrentRawStream(A.device.index) + stream = _get_raw_stream(A.device.index) if A.dtype == torch.bfloat16: fn = lib.cgemm_4bit_bf16 diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index 2f276edc6..80d86321f 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from functools import wraps -from math import prod, sqrt +from math import sqrt from typing import Optional import torch @@ -43,9 +43,12 @@ def _( dtype: Optional[torch.dtype] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + if A.dtype != torch.int32: + raise ValueError(f"A must be int32, got {A.dtype}") + if row_stats.dtype != torch.float32: + raise ValueError(f"row_stats must be float32, got {row_stats.dtype}") + if col_stats.dtype != torch.float32: + raise ValueError(f"col_stats must be float32, got {col_stats.dtype}") A_calc = A.view(-1, A.shape[-1]) row_stats = row_stats.reshape(-1).unsqueeze(-1) @@ -123,7 +126,8 @@ def _(A: torch.Tensor, B: torch.Tensor): @register_kernel("bitsandbytes::int8_linear_matmul.out", "default") def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - torch._check(out.dtype == torch.int32) + if out.dtype != torch.int32: + raise ValueError(f"out must be int32, got {out.dtype}") _int8_linear_matmul_impl(A, B, out) @@ -137,7 +141,7 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor @register_kernel("bitsandbytes::int8_vectorwise_quant", "default") def _(A: torch.Tensor, threshold=0.0): - rows = prod(A.shape[:-1]) + rows = A.numel() // A.shape[-1] outlier_cols = None outlier_restore = None @@ -175,8 +179,6 @@ def _(A: torch.Tensor, threshold=0.0): @register_kernel("bitsandbytes::quantize_blockwise", "default") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - n = A.numel() rem = n % blocksize has_rem = rem > 0 @@ -201,9 +203,6 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor @register_kernel("bitsandbytes::dequantize_blockwise", "default") def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - out = code[A.reshape(-1).int()] blocks = out.shape[-1] // blocksize res = out.shape[-1] % blocksize @@ -220,13 +219,6 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - n = A.numel() full_blocks = n // blocksize rem = n % blocksize @@ -317,13 +309,6 @@ def _( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py index 2844df731..645687598 100644 --- a/bitsandbytes/backends/hpu/ops.py +++ b/bitsandbytes/backends/hpu/ops.py @@ -25,12 +25,10 @@ def _( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}") - torch._check( - A.dtype in [torch.bfloat16, torch.uint8], - lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}", - ) + if quant_type != "nf4": + raise ValueError(f"HPU backend only supports quant_type 'nf4', got {quant_type!r}") + if A.dtype not in (torch.bfloat16, torch.uint8): + raise ValueError(f"HPU backend only supports uint8 or bfloat16 storage, got {A.dtype}") # Enable non uint8 dtype if A.dtype != torch.uint8: diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py index bbf2fed5e..0e3186663 100644 --- a/bitsandbytes/backends/mps/ops.py +++ b/bitsandbytes/backends/mps/ops.py @@ -42,8 +42,8 @@ def _( quant_type: str, quant_storage: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [64, 128, 256, 512]) - torch._check(quant_type in ("fp4", "nf4")) + if blocksize not in (64, 128, 256, 512): + raise ValueError(f"MPS backend only supports blocksize in (64, 128, 256, 512), got {blocksize}") k = _get_kernel() packed, absmax = k.quantize_4bit(A.contiguous(), blocksize, _QUANT_MAP[quant_type]) @@ -82,8 +82,8 @@ def _( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - torch._check(blocksize in [64, 128, 256, 512]) - torch._check(quant_type in ("fp4", "nf4")) + if blocksize not in (64, 128, 256, 512): + raise ValueError(f"MPS backend only supports blocksize in (64, 128, 256, 512), got {blocksize}") return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 6b1a2904b..2dfe7e758 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -15,7 +15,6 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}") with torch_accelerator_module.device(A.device): out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A.contiguous(), code, blocksize) @@ -25,8 +24,8 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t def dequantize_blockwise( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + if A.dtype != torch.uint8: + raise ValueError(f"A must be uint8, got {A.dtype}") # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}") with torch_accelerator_module.device(A.device): out = kernels_8bit_quant.dequant_8bit_blockwise( @@ -47,11 +46,14 @@ def dequantize_blockwise_inplace( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + if A.dtype != torch.uint8: + raise ValueError(f"A must be uint8, got {A.dtype}") + if out.shape != A.shape: + raise ValueError(f"Expected out.shape == {A.shape}, got {out.shape}") + if out.device != A.device: + raise ValueError(f"Expected out.device == {A.device}, got {out.device}") + if out.dtype != dtype: + raise ValueError(f"Expected out.dtype == {dtype}, got {out.dtype}") with torch_accelerator_module.device(A.device): kernels_8bit_quant.dequant_8bit_blockwise( @@ -67,12 +69,9 @@ def dequantize_blockwise_inplace( def quantize_4bit( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}") - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) + if A.dtype not in (torch.bfloat16, torch.float16, torch.float32): + raise ValueError(f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}") n = A.numel() @@ -109,12 +108,9 @@ def dequantize_4bit( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}") # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}") - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) + if dtype not in (torch.bfloat16, torch.float16, torch.float32): + raise ValueError(f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}") # torch._check( # A.dtype == torch.uint8, # lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}", @@ -139,8 +135,10 @@ def dequantize_4bit_inplace( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + if out.shape != tuple(shape): + raise ValueError(f"Expected out.shape == {shape}, got {out.shape}") + if out.dtype != dtype: + raise ValueError(f"Expected out.dtype == {dtype}, got {out.dtype}") with torch_accelerator_module.device(A.device): kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index 58629f2a8..731200c53 100644 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -251,8 +251,10 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + if out.dtype != dtype: + raise ValueError(f"Expected out.dtype == {dtype}, got {out.dtype}") + if out.shape != A.shape: + raise ValueError(f"Expected out.shape == {A.shape}, got {out.shape}") _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) @register_kernel("bitsandbytes::gemv_4bit", "xpu") @@ -279,11 +281,11 @@ def _( blocksize: int, out: torch.Tensor, ) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + expected_shape = (*A.shape[:-1], shapeB[0]) + if out.shape != expected_shape: + raise ValueError(f"Expected out.shape == {expected_shape}, got {out.shape}") + if out.dtype != A.dtype: + raise ValueError(f"Expected out.dtype == {A.dtype}, got {out.dtype}") _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) elif triton_available: diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 70df070d7..7796a8e84 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -114,29 +114,6 @@ def __init__(self, lib: ct.CDLL): lib.get_context.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p - # argtypes for the 4-bit GEMM entry points. - _gemm4bit_argtypes = [ - ct.c_void_p, # A - ct.c_void_p, # B - ct.c_void_p, # absmax - ct.c_void_p, # absmax_8bit - ct.c_void_p, # absmax_code - ct.c_void_p, # absmax_offset - ct.c_void_p, # out - ct.c_void_p, # bias - ct.c_int32, # M - ct.c_int32, # N - ct.c_int32, # K - ct.c_int32, # blocksize - ct.c_int32, # quant_type - ct.c_void_p, # stream - ] - for _fn_name in ("cgemm_4bit_bf16", "cgemm_4bit_fp16", "cgemm_4bit_fp32"): - _fn = getattr(lib, _fn_name, None) - if _fn is not None: - _fn.argtypes = _gemm4bit_argtypes - _fn.restype = None - class XpuBNBNativeLibrary(BNBNativeLibrary): """XPU native library with SYCL USM paged memory support.""" diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d4ee98652..bb56e9d50 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -391,10 +391,10 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: # We use the raw stream for performance reasons. - if tensor.device.type == "xpu": - return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index)) if tensor.device.type == "cuda": return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) + if tensor.device.type == "xpu": + return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index)) # For CPU tensors (e.g. paged optimizer states), use current device's stream. if hasattr(torch, "xpu") and torch.xpu.is_available(): return ct.c_void_p(torch._C._xpu_getCurrentRawStream(torch.xpu.current_device())) @@ -644,6 +644,11 @@ def quantize_blockwise( - [`QuantState`]: The state object used to undo the quantization. """ + if blocksize <= 0: + raise ValueError(f"blocksize must be positive, got {blocksize}") + if A.dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -721,7 +726,10 @@ def dequantize_blockwise( The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. """ - assert quant_state is not None or absmax is not None + if quant_state is None and absmax is None: + raise ValueError("dequantize_blockwise requires either quant_state or absmax") + if A.dtype != torch.uint8: + raise ValueError(f"A must be uint8, got {A.dtype}") if code is None and quant_state is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -730,6 +738,9 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + if quant_state.blocksize <= 0: + raise ValueError(f"blocksize must be positive, got {quant_state.blocksize}") + absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) @@ -905,6 +916,13 @@ def quantize_4bit( if blocksize is None: blocksize = 64 + if blocksize not in (32, 64, 128, 256, 512, 1024, 2048, 4096): + raise ValueError(f"invalid blocksize {blocksize}") + if quant_type not in ("nf4", "fp4"): + raise ValueError(f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}") + if A.dtype not in (torch.bfloat16, torch.float16, torch.float32): + raise ValueError(f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}") + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1009,7 +1027,8 @@ def dequantize_4bit( blocksize = 64 if quant_state is None: - assert absmax is not None and out is not None + if absmax is None or out is None: + raise ValueError("dequantize_4bit requires both absmax and out when quant_state is not provided") quant_state = QuantState( absmax=absmax, @@ -1022,6 +1041,13 @@ def dequantize_4bit( else: absmax = quant_state.absmax + if quant_state.blocksize not in (32, 64, 128, 256, 512, 1024, 2048, 4096): + raise ValueError(f"invalid blocksize {quant_state.blocksize}") + if quant_state.quant_type not in ("nf4", "fp4"): + raise ValueError(f"quant_type must be 'nf4' or 'fp4', got {quant_state.quant_type!r}") + if quant_state.dtype not in (torch.bfloat16, torch.float16, torch.float32): + raise ValueError(f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {quant_state.dtype}") + if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset