[MAX] Align FLUX.2-Klein prompt masking with diffusers by pei0033 · Pull Request #6153 · modular/modular
Summary
This PR adds a real tokenizer attention-mask path for the FLUX.2-Klein text encoder.
Problems
MAX prompt embedding behavior is currently not aligned with diffusers across the FLUX family, and the mismatch is especially visible for FLUX.2-Klein.
Current behavior by model family:
- `flux1` (t5, clip)
  - Tokens are padded to `max_length` with pad tokens.
  - The text encoders do not consume an attention mask.
- `flux2-dev` (mistral)
  - Tokens are left-padded to `max_length`.
  - Diffusers passes both the padded tokens and the corresponding attention mask to the text encoder.
  - Because text-encoder masking was not implemented, PR #6013 ([Flux2] Align prompt embedding prep with attention mask and padded sequence length) worked around this by compacting tokens before encoding and then left-zero-padding the returned hidden states.
- `flux2-klein` (qwen3)
  - Tokens are right-padded to `max_length`.
  - Diffusers passes both the padded tokens and the corresponding attention mask to the text encoder.
  - The `flux2-dev` workaround does not apply here: with right padding, the hidden states at pad-token positions are not zero, so compacting and then padding the hidden states back does not reproduce diffusers behavior.
As a result, the current MAX behavior for FLUX.2-Klein does not match diffusers prompt conditioning semantics and produces visibly different image quality.
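To make the padding asymmetry concrete, here is a minimal NumPy sketch (a toy single-head attention with no projections, not the actual MAX or diffusers code): with right padding and a causal + token mask, the real-token hidden states match those of a compacted run, but the pad-position hidden states are nonzero, so zero-padding compacted outputs cannot reproduce them.

```python
import numpy as np

def masked_causal_attention(x, token_mask):
    # Toy single-head self-attention (no projections) that combines a
    # causal mask with the tokenizer's per-token mask.
    L, d = x.shape
    scores = (x @ x.T) / np.sqrt(d)
    causal = np.tril(np.ones((L, L), dtype=bool))        # j <= i
    allowed = causal & token_mask.astype(bool)[None, :]  # AND token mask
    scores = np.where(allowed, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w = w / w.sum(axis=1, keepdims=True)
    return w @ x

rng = np.random.default_rng(0)
real = rng.normal(size=(3, 4))   # 3 real prompt-token embeddings
pads = np.zeros((2, 4))          # 2 right-padded positions (flux2-klein layout)
x_padded = np.concatenate([real, pads])
mask = np.array([1, 1, 1, 0, 0])

out_padded = masked_causal_attention(x_padded, mask)
out_compact = masked_causal_attention(real, np.ones(3))

# Real-token hidden states match the compacted run exactly...
print(np.allclose(out_padded[:3], out_compact))   # True
# ...but pad positions attend to the real tokens before them, so their
# hidden states are NOT zero — zero-padding compacted outputs differs.
print(np.linalg.norm(out_padded[3:]) > 0)         # True
```

This is why the `flux2-dev` compact-then-pad trick is sound for left padding (pad positions are discarded or zeroed up front) but wrong for right padding.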
What Changed
Before this change, MAX did not have an end-to-end path for arbitrary token masks here. In practice, the behavior was effectively pad-token-only:
- `PixelGenerationTokenizer` can produce an `attention_mask`,
- `PixelContext` can carry that mask,
- but `Flux2KleinPipeline` does not pass it into prompt encoding,
- `Qwen3TextEncoderModel.__call__()` rejects `attention_mask`,
- and the compiled Qwen3 text attention path is causal-only.
The implementation in this PR:
- adds a shared causal + token-mask helper in `max.pipelines.dataprocessing`,
- builds the additive attention bias from that mask on the host before the compiled Qwen3 text-encoder call,
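As a rough sketch of what such a helper computes (the function name and signature below are illustrative, not the actual `max.pipelines.dataprocessing` API): the causal mask and the tokenizer mask are ANDed, then converted to an additive bias that is 0 where attention is allowed and a large negative value where it is blocked.

```python
import numpy as np

def build_text_attention_bias(attention_mask, dtype=np.float32):
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    # Returns an additive bias of shape (batch, 1, seq_len, seq_len):
    # 0 where query position i may attend to key position j, and a large
    # negative value where either the causal mask or the token mask blocks it.
    batch, L = attention_mask.shape
    causal = np.tril(np.ones((L, L), dtype=bool))      # j <= i
    token = attention_mask.astype(bool)[:, None, :]    # broadcast over query rows
    allowed = causal[None, :, :] & token
    neg = np.finfo(dtype).min
    bias = np.where(allowed, 0.0, neg).astype(dtype)
    return bias[:, None, :, :]

mask = np.array([[1, 1, 1, 0, 0]])   # right-padded, flux2-klein-style mask
bias = build_text_attention_bias(mask)
```

Using the dtype's most negative finite value (rather than `-inf`) is a common choice for additive biases, since it avoids NaNs in compiled softmax kernels when a row happens to be fully masked.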
Testing
Latency
On an H200, text-encoder latency slows down from 54 ms to 72 ms.
Ran:

```shell
./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "black-forest-labs/FLUX.2-klein-4B" --prompt 'Four elements split into quadrants: top-left shows splashing water forming the word "WATER" on a water background, top-right shows soil forming "EARTH" with planet earth behind, bottom-left shows colorful clouds forming "AIR" at sunset, bottom-right shows fiery lava forming "FIRE" against the sun' --num-inference-steps 4 --guidance-scale 1.0 --seed 42 --profile-timings --num-profile-iterations 5 --num-warmups 1
```
baseline on main (with PR #6150):
```
==================== PROFILING REPORT ==
Component Timings:
components               calls    total    avg (ms)
component/transformer       20  1388.121     69.406
component/text_encoder       5   273.029     54.606
Method Timings:
methods                  calls    total    avg (ms)
E2E execute                  5  2152.649    430.530
component/transformer       20  1388.121     69.406
decode_latents               5   478.375     95.675
prepare_embeddings           5   273.368     54.674
component/text_encoder       5   273.029     54.606
scheduler_step              20     9.151      0.458
preprocess_latents           5     0.858      0.172
prepare_scheduler            5     0.395      0.079
==========================================
Generation complete!
```
PR result:
```
==================== PROFILING REPORT ==
Component Timings:
components               calls    total    avg (ms)
component/transformer       20  1389.771     69.489
component/text_encoder       5   360.039     72.008
Method Timings:
methods                  calls    total    avg (ms)
E2E execute                  5  2239.572    447.914
component/transformer       20  1389.771     69.489
decode_latents               5   480.221     96.044
prepare_embeddings           5   360.398     72.080
component/text_encoder       5   360.039     72.008
scheduler_step              20     5.576      0.279
preprocess_latents           5     0.763      0.153
prepare_scheduler            5     0.416      0.083
==========================================
Generation complete!
```
Quality
Prompt:

```
Four elements split into quadrants: top-left shows splashing water forming the word "WATER" on a water background, top-right shows soil forming "EARTH" with planet earth behind, bottom-left shows colorful clouds forming "AIR" at sunset, bottom-right shows fiery lava forming "FIRE" against the sun
```
Checklist
- PR is small and focused — consider splitting larger changes into a sequence of smaller PRs
- I ran `./bazelw run format` to format my changes
- I added or updated tests to cover my changes
- If AI tools assisted with this contribution, I have included an `Assisted-by:` trailer in my commit message or this PR description (see AI Tool Use Policy)