[Kernels][GPU] Fix CPU eager fallback for flash_attention_gpu no-cache path by bro4all · Pull Request #6317 · modular/modular
BEGIN_PUBLIC
[Kernel][GPU] Fix flash attention CPU fallback

Fix flash_attention_gpu so that eager CPU tensors use the CPU no-cache path instead of dispatching into the GPU-only custom op. Keep the public Python API stable, add same-device validation, reject unsupported CPU valid_length usage explicitly, and add focused CPU eager regression tests. Preserve the existing padded GPU path by leaving the valid_length fast path unchanged.

Fixes modular#6287
END_PUBLIC

Assisted-by: AI
Signed-off-by: Omar Habra <omarbro4all@gmail.com>
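The dispatch behavior described above (same-device validation, explicit rejection of valid_length on CPU, and a CPU no-cache fallback with the GPU path left untouched) can be sketched as follows. This is a minimal illustrative sketch only: the Tensor type and the two backend helpers are hypothetical stand-ins, not the actual modular/MAX API.

```python
from dataclasses import dataclass


@dataclass
class Tensor:
    # Hypothetical stand-in for an eager tensor; only the device tag matters here.
    device: str  # "cpu" or "gpu"
    data: list


def _flash_attention_cpu_no_cache(q, k, v):
    # Placeholder for the CPU no-cache reference path (hypothetical).
    return Tensor(device="cpu", data=[0.0] * len(q.data))


def _flash_attention_gpu_custom_op(q, k, v, valid_length=None):
    # Placeholder for the GPU-only custom op (hypothetical).
    return Tensor(device="gpu", data=[0.0] * len(q.data))


def flash_attention(q, k, v, valid_length=None):
    """Device-based dispatch in the shape the commit message describes."""
    # Same-device validation: q, k, v must all live on one device.
    devices = {q.device, k.device, v.device}
    if len(devices) != 1:
        raise ValueError(f"q, k, v must share one device, got {devices}")
    (device,) = devices
    if device == "cpu":
        # valid_length is only supported on the GPU fast path, so reject
        # it explicitly rather than silently ignoring it.
        if valid_length is not None:
            raise ValueError(
                "valid_length is not supported on the CPU fallback path"
            )
        return _flash_attention_cpu_no_cache(q, k, v)
    # GPU path unchanged, including the padded valid_length fast path.
    return _flash_attention_gpu_custom_op(q, k, v, valid_length)
```

The key design point is that eager CPU inputs never reach the GPU-only custom op: they either take the no-cache fallback or fail fast with a clear error.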
BEGIN_PUBLIC
[Kernel][GPU] Fix flash attention test typing

Adjust the new flash attention eager regression helpers to use graph-side TensorValue annotations inside @functional wrappers so that mypy accepts the calls to flash_attention_gpu.
END_PUBLIC

Assisted-by: AI
Signed-off-by: Omar Habra <omarbro4all@gmail.com>
bro4all changed the title from "Fix CPU eager fallback for flash_attention_gpu no-cache path" to "[Kernels][GPU] Fix CPU eager fallback for flash_attention_gpu no-cache path"
bro4all marked this pull request as ready for review