Moved int8_mm_dequant from CPU to default backend by Egor-Krivov · Pull Request #1626 · bitsandbytes-foundation/bitsandbytes
@@ -1,6 +1,5 @@
from collections.abc import Sequence
import ctypes as ct
from typing import Optional
import torch
Expand All @@ -24,29 +23,6 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0])
@register_kernel("bitsandbytes::int8_mm_dequant", "cpu") def _( A: torch.Tensor, row_stats: torch.Tensor, col_stats: torch.Tensor, dtype: Optional[torch.dtype] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
A_calc = A.view(-1, A.shape[-1]) row_stats = row_stats.reshape(-1).unsqueeze(-1) col_stats = col_stats.reshape(-1).unsqueeze(0)
out = A_calc * (row_stats * col_stats) * 6.200124e-05 if bias is not None: out += bias
return out.to(dtype or torch.float16)
@register_kernel("bitsandbytes::quantize_blockwise", "cpu") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) Expand Down
import torch
Expand All @@ -24,29 +23,6 @@ def _(A: torch.Tensor, B: torch.Tensor): ).reshape(*A.shape[:-1], B.shape[0])
@register_kernel("bitsandbytes::int8_mm_dequant", "cpu") def _( A: torch.Tensor, row_stats: torch.Tensor, col_stats: torch.Tensor, dtype: Optional[torch.dtype] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
A_calc = A.view(-1, A.shape[-1]) row_stats = row_stats.reshape(-1).unsqueeze(-1) col_stats = col_stats.reshape(-1).unsqueeze(0)
out = A_calc * (row_stats * col_stats) * 6.200124e-05 if bias is not None: out += bias
return out.to(dtype or torch.float16)
@register_kernel("bitsandbytes::quantize_blockwise", "cpu") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) Expand Down