[Kernels][GPU] Fix CPU eager fallback for flash_attention_gpu no-cache path by bro4all · Pull Request #6317 · modular/modular

BEGIN_PUBLIC
[Kernel][GPU] Fix flash attention CPU fallback

Fix flash_attention_gpu so eager CPU tensors use the CPU no-cache path
instead of dispatching into the GPU-only custom op.

Keep the public Python API stable, add same-device validation, reject
unsupported CPU valid_length usage explicitly, and add focused CPU eager
regression tests. Preserve the existing padded GPU path by keeping the
valid_length fast path unchanged.

Fixes modular#6287
END_PUBLIC
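The dispatch logic described above can be sketched roughly as follows. This is a minimal illustration, not the actual Modular implementation: the `Tensor` stand-in, the device strings, and the helper names `_flash_attention_cpu_no_cache` and `_flash_attention_gpu_custom_op` are all hypothetical; only the routing rules (same-device validation, CPU fallback, explicit `valid_length` rejection on CPU, untouched GPU fast path) come from the commit message.

```python
from dataclasses import dataclass

# Hypothetical stand-in for an eager tensor; the real type lives in the
# MAX Python API and is richer than this.
@dataclass
class Tensor:
    device: str  # e.g. "cpu" or "gpu:0" (illustrative convention)
    data: list

def _flash_attention_cpu_no_cache(q, k, v):
    # Placeholder for the CPU no-cache attention path.
    return Tensor(device="cpu", data=q.data)

def _flash_attention_gpu_custom_op(q, k, v, valid_length=None):
    # Placeholder for the GPU-only custom op (padded path and the
    # valid_length fast path both live here, unchanged by the fix).
    return Tensor(device=q.device, data=q.data)

def flash_attention_gpu(q, k, v, valid_length=None):
    # Same-device validation: mixing devices is rejected up front.
    if not (q.device == k.device == v.device):
        raise ValueError("q, k, and v must be on the same device")
    if q.device == "cpu":
        # Eager CPU tensors take the CPU no-cache path instead of
        # dispatching into the GPU-only custom op.
        if valid_length is not None:
            # valid_length is unsupported on the CPU path; fail loudly
            # rather than silently computing something wrong.
            raise ValueError("valid_length is not supported on the CPU path")
        return _flash_attention_cpu_no_cache(q, k, v)
    # GPU tensors keep the existing behavior, public API unchanged.
    return _flash_attention_gpu_custom_op(q, k, v, valid_length=valid_length)
```

The explicit `ValueError` branches mirror the regression tests the PR adds: one asserting the CPU fallback produces a CPU result, and ones asserting that mixed devices and CPU `valid_length` usage are rejected.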

Assisted-by: AI
Signed-off-by: Omar Habra <omarbro4all@gmail.com>

BEGIN_PUBLIC
[Kernel][GPU] Fix flash attention test typing

Adjust the new flash attention eager regression helpers to use graph-side
TensorValue annotations inside @functional wrappers so mypy accepts the
calls to flash_attention_gpu.
END_PUBLIC
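The typing fix above follows a common pattern: annotate helpers with the graph-side value type rather than an eager tensor type so mypy can check calls into graph ops. The sketch below uses stand-ins for `TensorValue` and the `@functional` decorator; in the real codebase these come from the MAX Python API, and their actual signatures may differ.

```python
from typing import Callable, TypeVar

# Hypothetical stand-in for a graph-side tensor handle
# (the real one would be max.graph.TensorValue or similar).
class TensorValue:
    def __init__(self, name: str) -> None:
        self.name = name

F = TypeVar("F", bound=Callable)

def functional(fn: F) -> F:
    # Stand-in decorator: the real @functional stages the wrapped body
    # into a graph; here it is an identity so the example runs eagerly.
    return fn

def flash_attention_gpu(q: TensorValue, k: TensorValue, v: TensorValue) -> TensorValue:
    # Stand-in for the op under test; only the annotations matter here.
    return q

# Because the helper's parameters are annotated as graph-side
# TensorValue (not an eager Tensor type), mypy accepts the call
# to flash_attention_gpu without casts or type: ignore comments.
@functional
def attention_helper(q: TensorValue, k: TensorValue, v: TensorValue) -> TensorValue:
    return flash_attention_gpu(q, k, v)
```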

Assisted-by: AI
Signed-off-by: Omar Habra <omarbro4all@gmail.com>

bro4all changed the title from "Fix CPU eager fallback for flash_attention_gpu no-cache path" to "[Kernels][GPU] Fix CPU eager fallback for flash_attention_gpu no-cache path" on Mar 31, 2026.

bro4all marked this pull request as ready for review on March 31, 2026 at 18:38.