Broaden GEMV split-K tuning paths by prsabahrami · Pull Request #6337 · modular/modular
def _k_scalar_iter():
var kc = iteration * tile_k + Int(thread_idx.x) * simd_width
if kc < k:
var valid = min(simd_width, k - kc)
comptime for i in range(tile_n): comptime if check_bounds: if i + tile_id_n >= n: continue var bv = SIMD[b_type, simd_width](0) var wr = block_idx.y * tile_n + Int(i) var w_base = wr * k + kc comptime for el in range(simd_width): if Int(el) < valid: bv[el] = weight.ptr[w_base + Int(el)] tile_w.store(i, 0, bv)
comptime for i in range(tile_m): comptime if check_bounds: if i + tile_id_m >= m: continue comptime NativeVecType = SIMD[a_type, simd_width] var act_native = SIMD[a_type, simd_width](0) var a_base = (block_idx.x * tile_m + Int(i)) * k + kc comptime for el in range(simd_width): if Int(el) < valid: act_native[el] = act.ptr[a_base + Int(el)] comptime for j in range(tile_n): var weight_native = rebind[NativeVecType]( tile_w.vectorize[1, simd_width]()[j, 0] )
comptime for i in range(tile_n): comptime if check_bounds: if i + tile_id_n >= n: continue var bv = SIMD[b_type, simd_width](0) var wr = block_idx.y * tile_n + Int(i) var w_base = wr * k + kc comptime for el in range(simd_width): if Int(el) < valid: bv[el] = weight.ptr[w_base + Int(el)] tile_w.store(i, 0, bv)
comptime for i in range(tile_m): comptime if check_bounds: if i + tile_id_m >= m: continue comptime NativeVecType = SIMD[a_type, simd_width] var act_native = SIMD[a_type, simd_width](0) var a_base = (block_idx.x * tile_m + Int(i)) * k + kc comptime for el in range(simd_width): if Int(el) < valid: act_native[el] = act.ptr[a_base + Int(el)] comptime for j in range(tile_n): var weight_native = rebind[NativeVecType]( tile_w.vectorize[1, simd_width]()[j, 0] )