Using BitBLAS from DSL
```python
import tvm
from tvm.script import tir as T

from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.base.utils import apply_and_build

# Example problem size and data types (assumed here; set them to your workload).
M = N = K = 1024
in_dtype, out_dtype = "float16", "float16"


@tvm.script.ir_module
class MatmulNT:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, K], dtype=in_dtype)
        B = T.match_buffer(b, [N, K], dtype=in_dtype)
        C = T.match_buffer(c, [M, N], dtype=out_dtype)

        for i, j, k in T.grid(M, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = tvm.tir.const(0, out_dtype)
                C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[
                    vj, vk
                ].astype(out_dtype)


ir_module = MatmulNT
func = ir_module["main"]
target = tvm.target.Target("nvidia/nvidia-a100")
arch = CUDA(target)
```
Get tuning policy and candidates:
```python
# Tune with the SIMT CUDA Core policy by default.
policy = DefaultPolicy(func=func, arch=arch)
try:
    tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
except Exception:
    tags = None
# Tune with Tensor Core if possible.
if tags:
    policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk=20)
'''
[BitBLAS] Evaluation with config {'block': [64, 64], 'warp': [32, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.032 ms
[BitBLAS] Evaluation with config {'block': [32, 128], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.021 ms
[BitBLAS] Evaluation with config {'block': [128, 32], 'warp': [64, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.023 ms
[BitBLAS] Evaluation with config {'block': [32, 32], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.023 ms
[BitBLAS] Evaluation with config {'block': [32, 64], 'warp': [16, 32], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.027 ms
[BitBLAS] Evaluation with config {'block': [64, 32], 'warp': [32, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.025 ms
[BitBLAS] Evaluation with config {'block': [64, 128], 'warp': [32, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.023 ms
[BitBLAS] Evaluation with config {'block': [128, 64], 'warp': [64, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.025 ms
[BitBLAS] Evaluation with config {'block': [16, 64], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.037 ms
[BitBLAS] Evaluation with config {'block': [64, 16], 'warp': [16, 16], 'rstep': [128], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.037 ms
[BitBLAS] Evaluation with config {'block': [128, 128], 'warp': [64, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.026 ms
[BitBLAS] Evaluation with config {'block': [16, 128], 'warp': [16, 32], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.043 ms
[BitBLAS] Evaluation with config {'block': [128, 16], 'warp': [32, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.042 ms
[BitBLAS] Evaluation with config {'block': [32, 256], 'warp': [16, 128], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.025 ms
[BitBLAS] Evaluation with config {'block': [256, 32], 'warp': [128, 16], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.029 ms
[BitBLAS] Evaluation with config {'block': [64, 256], 'warp': [32, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.028 ms
[BitBLAS] Evaluation with config {'block': [256, 64], 'warp': [128, 32], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.027 ms
[BitBLAS] Evaluation with config {'block': [128, 256], 'warp': [64, 128], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.044 ms
[BitBLAS] Evaluation with config {'block': [256, 128], 'warp': [128, 64], 'rstep': [32], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.040 ms
[BitBLAS] Evaluation with config {'block': [16, 256], 'warp': [16, 64], 'rstep': [64], 'use_tc': True, 'vectorize': {'A_reindex': 8, 'B_reindex': 8}}
[BitBLAS] Time cost of this config: 0.047 ms
'''
```
Apply and build, and get the best code generation result:
```python
cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)
# Get the best code generation result.
print(best.code)
'''
extern "C" __global__ void __launch_bounds__(128) default_function_kernel(half* __restrict__ A, half* __restrict__ B, half* __restrict__ C) {
    ...
}
'''
```
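To keep the generated kernel around for inspection or offline reuse, the source string can be written to a file; a minimal sketch using only the `best.code` string shown above (the output path is hypothetical):

```python
# Save the best generated CUDA kernel source to a file for later inspection.
with open("matmul_nt_best.cu", "w") as f:
    f.write(best.code)
```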
Beyond this basic flow, the DSL also enables a few more advanced features.
Auto Tensorization
We currently have two policies: one for SIMT CUDA Cores and one for Tensor Cores. The choice of the Tensor Core policy over the SIMT CUDA Core policy is driven by an auto-tensorization strategy, which lets BitBLAS automatically detect whether a DSL expression can utilize Tensor Cores.
```python
# Assume func is a conv2d expression. After this invocation, tensorized_func is the
# tensorized version of the conv2d; otherwise tags is None.
tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
```
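The returned tags can then drive policy selection, mirroring the matmul tuning flow above; a minimal sketch:

```python
# Prefer the Tensor Core policy when auto-tensorization succeeded,
# otherwise fall back to the SIMT CUDA Core policy.
if tags:
    policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
else:
    policy = DefaultPolicy(func=func, arch=arch)
configs = policy.emit_config(topk=20)
```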
Tune with dynamic symbolic
In workloads such as LLM serving, the input shape is dynamic, so we can use a dynamic symbolic dimension to generate high-performance kernels for dynamic shapes.
```python
@tvm.script.ir_module
class MatmulNT:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        m = T.int32()
        A = T.match_buffer(a, [m, K], dtype=in_dtype)
        B = T.match_buffer(b, [N, K], dtype=in_dtype)
        C = T.match_buffer(c, [m, N], dtype=out_dtype)

        for i, j, k in T.grid(m, N, K):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = tvm.tir.const(0, out_dtype)
                C[vi, vj] = C[vi, vj] + A[vi, vk].astype(out_dtype) * B[
                    vj, vk
                ].astype(out_dtype)


from bitblas import fast_tune_with_dynamic_range

# Tune with a dynamic symbolic dimension.
func = MatmulNT["main"]
scheduled_ir_module = fast_tune_with_dynamic_range(
    func,
    target,
    topk=20,  # number of candidate configs to evaluate
    parallel_build=True,
    dynamic_range={"M": [1, 1024]},
)

# Finally, a dispatch function is generated to dispatch the kernel on the dynamic symbolic.
'''
@IRModule
class MatmulNT:
    def matmul_nt_opt_m_1(A: Tensor, T_reshape: Tensor, m: int):
        ...

    def matmul_nt_opt_m_256(A: Tensor, T_reshape: Tensor, m: int):
        ...

    def dispatcher(args):
        if m <= 1:
            matmul_nt_opt_m_1(A.data, T_reshape.data, m)
        if m > 1 and m <= 256:
            matmul_nt_opt_m_256(A.data, T_reshape.data, m)
        if m > 256:
            matmul_nt_m_256(A.data, T_reshape.data, m)
'''
```
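The scheduled IRModule can be printed to inspect the specialized kernels and the generated dispatcher; the sketch below also builds it with `tvm.build` for the A100 target defined earlier, which we assume works for the dispatching module (an assumption, not something guaranteed above):

```python
# Inspect the specialized kernels and the generated dispatcher.
print(scheduled_ir_module)

# Assumption: the dispatching IRModule produced by fast_tune_with_dynamic_range
# can be built directly into a runtime module for the CUDA target.
rt_mod = tvm.build(scheduled_ir_module, target=target)
```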
You can find example DSL implementations in python/bitblas/ops/impl and benchmark/dsl, and more examples and tutorials in apache/tvm.
