[Flux2] Align prompt embedding prep with attention mask and padded sequence length by pei0033 · Pull Request #6013 · modular/modular
…0028) [External] [MAX] Align FLUX.2-Klein prompt masking with diffusers

## Summary

This PR adds a real tokenizer-mask path for the FLUX.2-Klein text encoder.

## Problems

MAX prompt embedding behavior currently diverges from diffusers across the FLUX family, and the mismatch is especially visible for FLUX.2-Klein.

Current behavior by model family:

1. `flux1` (`t5`, `clip`)
   - Tokens are padded to `max_length` with pad tokens.
   - The text encoders do not consume an attention mask.
2. `flux2-dev` (`mistral`)
   - Tokens are left-padded to `max_length`.
   - Diffusers passes both the padded tokens and the corresponding attention mask to the text encoder.
   - Because text-encoder masking was not implemented, PR #6013 worked around this by compacting tokens before encoding and then left-zero-padding the returned hidden states.
3. `flux2-klein` (`qwen3`)
   - Tokens are right-padded to `max_length`.
   - Diffusers passes both the padded tokens and the corresponding attention mask to the text encoder.
   - The `flux2-dev` workaround does not apply here: with right padding, the hidden states at pad-token positions are not zero, so compacting tokens and then padding the hidden states back does not reproduce diffusers behavior.

<img width="2462" height="932" alt="image" src="https://github.com/user-attachments/assets/33748187-14d2-45ee-a748-6f6eefcc1f96" />

As a result, the current MAX behavior for FLUX.2-Klein does not match diffusers prompt-conditioning semantics and produces visibly different image quality.

### What Changed

Before this change, MAX had no end-to-end path for arbitrary token masks here.
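For intuition, the masked-attention semantics that diffusers relies on can be expressed as an additive attention bias combining a causal mask with the tokenizer's padding mask. The sketch below is a minimal NumPy illustration of that idea, not the actual helper this PR adds to `max.pipelines.dataprocessing`; the function name and shapes are hypothetical.

```python
import numpy as np

def build_additive_bias(attention_mask: np.ndarray, neg: float = -1e9) -> np.ndarray:
    """Toy sketch: fold a causal mask and a per-token padding mask into one
    additive bias (0.0 where attention is allowed, a large negative value
    where it is blocked). attention_mask has shape (batch, seq_len)."""
    batch, seq_len = attention_mask.shape
    # Causal part: query position i may only attend to key positions j <= i.
    causal_blocked = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    # Padding part: no query may attend to a padded key position.
    pad_blocked = attention_mask[:, None, :] == 0          # (batch, 1, seq_len)
    blocked = causal_blocked[None, :, :] | pad_blocked     # (batch, seq_len, seq_len)
    return np.where(blocked, neg, 0.0)

# Right-padded prompt: 3 real tokens in a length-5 sequence, as in flux2-klein.
mask = np.array([[1, 1, 1, 0, 0]])
bias = build_additive_bias(mask)
# The last real token (query 2) can attend to keys 0..2 but not to the pads,
# so a causal-only path (which would leave keys 3 and 4 unblocked for later
# queries) produces nonzero hidden states at the pad positions.
print(bias[0, 2])
```

With a causal-only compiled attention path, the pad columns are never blocked, which is exactly why the compact-then-pad trick from `flux2-dev` cannot reproduce diffusers outputs under right padding.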
In practice, the behavior was effectively pad-token-only:

- `PixelGenerationTokenizer` can produce an `attention_mask`,
- `PixelContext` can carry that mask,
- but `Flux2KleinPipeline` does not pass it into prompt encoding,
- `Qwen3TextEncoderModel.__call__()` rejects `attention_mask`,
- and the compiled Qwen3 text attention path is causal-only.

The implementation in this PR:

- adds a shared causal + token-mask helper in `max.pipelines.dataprocessing`, and
- builds the additive attention bias from that mask on the host before the compiled Qwen3 text-encoder call.

## Testing

### Latency on H200

Text-encoder latency increases (**54 -> 72 ms**).

Ran:

```bash
./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "black-forest-labs/FLUX.2-klein-4B" --prompt 'Four elements split into quadrants: top-left shows splashing water forming the word "WATER" on a water background, top-right shows soil forming "EARTH" with planet earth behind, bottom-left shows colorful clouds forming "AIR" at sunset, bottom-right shows fiery lava forming "FIRE" against the sun' --num-inference-steps 4 --guidance-scale 1.0 --seed 42 --profile-timings --num-profile-iterations 5 --num-warmups 1
```

Baseline on main (with PR #6150):

```
==================== PROFILING REPORT ==
Component Timings:
components               calls   total      avg (ms)
component/transformer    20      1388.121   69.406
component/text_encoder   5       273.029    54.606

Method Timings:
methods                  calls   total      avg (ms)
E2E execute              5       2152.649   430.530
component/transformer    20      1388.121   69.406
decode_latents           5       478.375    95.675
prepare_embeddings       5       273.368    54.674
component/text_encoder   5       273.029    54.606
scheduler_step           20      9.151      0.458
preprocess_latents       5       0.858      0.172
prepare_scheduler        5       0.395      0.079
==========================================
Generation complete!
```

PR result:

```
==================== PROFILING REPORT ==
Component Timings:
components               calls   total      avg (ms)
component/transformer    20      1389.771   69.489
component/text_encoder   5       360.039    72.008

Method Timings:
methods                  calls   total      avg (ms)
E2E execute              5       2239.572   447.914
component/transformer    20      1389.771   69.489
decode_latents           5       480.221    96.044
prepare_embeddings       5       360.398    72.080
component/text_encoder   5       360.039    72.008
scheduler_step           20      5.576      0.279
preprocess_latents       5       0.763      0.153
prepare_scheduler        5       0.416      0.083
==========================================
Generation complete!
```

### Quality

**Prompt**

```
Four elements split into quadrants: top-left shows splashing water forming the word "WATER" on a water background, top-right shows soil forming "EARTH" with planet earth behind, bottom-left shows colorful clouds forming "AIR" at sunset, bottom-right shows fiery lava forming "FIRE" against the sun
```

<table>
<tr>
<td align="center">
<img width="1024" height="1024" alt="diffusers" src="https://github.com/user-attachments/assets/0649854f-bedd-4d9d-94bf-03f74938e55f" />
<br/>diffusers with same latent
</td>
<td align="center">
<img width="1024" height="1024" alt="max_main" src="https://github.com/user-attachments/assets/828800db-28b3-4421-99e2-2f2a312f7348" />
<br/>max main branch
</td>
<td align="center">
<img width="1024" height="1024" alt="max_pr" src="https://github.com/user-attachments/assets/4ef02185-4e3f-4a8f-b828-5cf727f931e2" />
<br/>max current PR
</td>
</tr>
</table>

## Checklist

- [x] PR is small and focused — consider splitting larger changes into a sequence of smaller PRs
- [x] I ran `./bazelw run format` to format my changes
- [x] I added or updated tests to cover my changes
- [x] If AI tools assisted with this contribution, I have included an `Assisted-by:` trailer in my commit message or this PR description (see [AI Tool Use Policy](../AI_TOOL_POLICY.md))

```
Assisted-by: Codex
```

ORIGINAL_AUTHOR=pei0033 <parkeunik@naver.com>
ORIGINAL_USER=@pei0033
Co-authored-by: pei0033 <parkeunik@naver.com>

Closes #6153

MODULAR_ORIG_COMMIT_REV_ID: 1dab4a13d1c2dd8d4c73d00026a82c10f4b8b1e7