Add CUDA kernel support for 4-bit quantization with blocksize=32 by Abdennacer-Badaoui · Pull Request #1854 · bitsandbytes-foundation/bitsandbytes

Description

Implements a specialized CUDA kernel to support blocksize=32 for 4-bit quantization (FP4/NF4), addressing the feature request in #986.
Smaller block sizes provide better quantization accuracy by computing separate scaling factors for smaller groups of values, reducing quantization error at the cost of slightly increased metadata overhead.
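The accuracy/overhead trade-off can be made concrete with a back-of-the-envelope sketch, assuming each block's absmax scale is stored as a plain fp32 value (bitsandbytes can optionally compress the scales further, which shrinks the overhead):

```python
def effective_bits_per_weight(blocksize, absmax_bits=32):
    """Storage cost of 4-bit blockwise quantization: 4 bits per
    weight plus one scale factor (assumed fp32 here) per block."""
    return 4 + absmax_bits / blocksize

# Halving the block size doubles the scale-factor overhead:
assert effective_bits_per_weight(64) == 4.5  # 4 + 32/64 bits/weight
assert effective_bits_per_weight(32) == 5.0  # 4 + 32/32 bits/weight
```

So moving from blocksize=64 to blocksize=32 costs roughly half a bit per weight in exchange for the lower quantization error shown below.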

Key Changes

New quantization kernel (kQuantizeBlockwise32):

  • Optimized for blocksize=32, processes 2 blocks per warp (32 threads)
  • Threads 0-15 handle block 0, threads 16-31 handle block 1
  • Each block computes an independent scale factor for finer granularity
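A rough Python model of that warp layout may help. The split of lanes 0-15 and 16-31 across the two blocks is from the description above; the intra-block element assignment (two consecutive values per lane) is an illustrative assumption, not necessarily the kernel's exact access pattern:

```python
BLOCKSIZE = 32
THREADS_PER_BLOCK = 16  # half a warp per quantization block

def block_for_lane(lane):
    """Which of the warp's two blocks a lane (0-31) works on:
    lanes 0-15 -> block 0, lanes 16-31 -> block 1."""
    return lane // THREADS_PER_BLOCK

def elements_for_lane(lane):
    """Hypothetical per-lane assignment: each of the 16 lanes in a
    half-warp covers 2 consecutive values of its 32-value block."""
    block = block_for_lane(lane)
    local = lane % THREADS_PER_BLOCK
    base = block * BLOCKSIZE + local * 2
    return [base, base + 1]

# Together the 32 lanes cover the warp's 64 values exactly once.
covered = sorted(i for lane in range(32) for i in elements_for_lane(lane))
assert covered == list(range(2 * BLOCKSIZE))
```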

Dequantization: Reuses the existing generic kernel with the proper dual-scale lookup
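The lookup the generic path performs can be sketched as follows. The toy 4-entry codebook stands in for the 16-entry FP4/NF4 tables, and `dequantize_blockwise` is a hypothetical reference helper, not the library API; the key point is that element `i` is rescaled by the absmax of block `i // 32`:

```python
BLOCKSIZE = 32

def dequantize_blockwise(codes, absmax, codebook):
    """Generic blockwise dequantization: each code indexes the
    codebook, and the result is rescaled by the absmax of the
    block the element belongs to (element i -> absmax[i // 32])."""
    return [codebook[c] * absmax[i // BLOCKSIZE]
            for i, c in enumerate(codes)]

# Two 32-value blocks with different scales: the same code (index 1,
# value 0.5) dequantizes to different magnitudes in each block.
toy_codebook = [0.0, 0.5, 1.0, -0.5]
out = dequantize_blockwise([1] * 32 + [2] * 32, [2.0, 4.0], toy_codebook)
assert out[0] == 1.0   # 0.5 * absmax[0]
assert out[32] == 4.0  # 1.0 * absmax[1]
```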

Testing: Extended the test suites in tests/test_functional.py, tests/test_linear4bit.py, and tests/test_ops.py

Quick comparison

Test configuration: torch.float16, CUDA, averaged over 1000 runs per shape

FP4 Quantization Error Comparison

| Shape | Blocksize=64 | Blocksize=32 | Improvement |
|---|---|---|---|
| 1K×1K | 0.096540 | 0.088918 | +7.9% |
| 2K×2K | 0.096548 | 0.088919 | +7.9% |
| 4K×4K | 0.096545 | 0.088919 | +7.9% |
| 8K×4K | 0.096545 | 0.088919 | +7.9% |
| 1K×768 (LLaMA-like) | 0.096547 | 0.088918 | +7.9% |
| 4K×11K (LLaMA FFN) | 0.096546 | 0.088920 | +7.9% |

NF4 Quantization Error Comparison

| Shape | Blocksize=64 | Blocksize=32 | Improvement |
|---|---|---|---|
| 1K×1K | 0.072798 | 0.067750 | +6.9% |
| 2K×2K | 0.072795 | 0.067748 | +6.9% |
| 4K×4K | 0.072795 | 0.067747 | +6.9% |
| 8K×4K | 0.072795 | 0.067748 | +6.9% |
| 1K×768 (LLaMA-like) | 0.072793 | 0.067749 | +6.9% |
| 4K×11K (LLaMA FFN) | 0.072795 | 0.067748 | +6.9% |
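The direction of these results can be reproduced with a pure-Python reference of the measurement. The code values below are the published 16-entry NF4 codebook; the exact error numbers depend on the data and metric, but the smaller block size consistently yields a lower round-trip error on normally distributed inputs:

```python
import random

# 16 NF4 code values (normal-float quantiles) as published for
# bitsandbytes/QLoRA; listed here for a self-contained check.
NF4 = [-1.0, -0.6961928009986877, -0.5250730514526367,
       -0.39491748809814453, -0.28444138169288635, -0.18477343022823334,
       -0.09105003625154495, 0.0, 0.07958029955625534,
       0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
       0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]

def nf4_roundtrip_err(values, blocksize):
    """Mean abs error of an NF4 blockwise quantize -> dequantize
    round trip with one absmax scale per block."""
    total = 0.0
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        absmax = max(abs(v) for v in block) or 1.0
        for v in block:
            code = min(NF4, key=lambda c: abs(c - v / absmax))
            total += abs(code * absmax - v)
    return total / len(values)

random.seed(0)
data = [random.gauss(0.0, 1.0) for _ in range(1 << 14)]
err64 = nf4_roundtrip_err(data, 64)
err32 = nf4_roundtrip_err(data, 32)
assert err32 < err64  # blocksize=32 gives a lower error
```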