[MAX] Align FLUX.2-Klein prompt masking with diffusers by pei0033 · Pull Request #6153 · modular/modular

Summary

This PR adds an actual tokenizer-mask path for the FLUX.2-Klein text encoder.

Problems

MAX prompt embedding behavior is currently not aligned with diffusers across the FLUX family, and the mismatch is especially visible for FLUX.2-Klein.

Current behavior by model family:

  1. flux1 (t5, clip)

    • Tokens are padded to max_length with pad tokens.
    • The text encoders do not consume an attention mask.
  2. flux2-dev (mistral)

  3. flux2-klein (qwen3)

    • Tokens are right-padded to max_length.
    • Diffusers passes both padded tokens and the corresponding attention mask to the text encoder.
    • The flux2-dev workaround does not apply here: with right padding, the hidden states at pad-token positions are not zero, so compacting and then padding hidden states back does not reproduce diffusers behavior.
    image

As a result, the current MAX behavior for FLUX.2-Klein does not match diffusers prompt conditioning semantics and produces visibly different image quality.

What Chaged

Before this change, MAX did not have an end-to-end path for arbitrary token masks here. In practice, the behavior was effectively pad-token-only:

  • PixelGenerationTokenizer can produce an attention_mask,
  • PixelContext can carry that mask,
  • but Flux2KleinPipeline does not pass it into prompt encoding,
  • Qwen3TextEncoderModel.__call__() rejects attention_mask,
  • and the compiled Qwen3 text attention path is causal-only.

The implementation in this PR:

  • adds a shared causal + token-mask helper inmax.pipelines.dataprocessing,
  • builds the additive attention bias from that mask on the host before the
    compiled Qwen3 text-encoder call,

Testing

Latency

on H200 Text Encoder latency slow down (54 -> 72 ms)
Ran:

./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "black-forest-labs/FLUX.2-klein-4B" --prompt 'Four elements split into quadrants: top-left shows splashing water forming the word "WATER" on a water background, top-right shows soil forming "EARTH" with planet earth behind, bottom-left shows colorful clouds forming "AIR" at sunset, bottom-right shows fiery lava forming "FIRE" against the sun' --num-inference-steps 4 --guidance-scale 1.0 --seed 42 --profile-timings --num-profile-iterations 5 --num-warmups 1

baseline on main (with PR #6150):

==================== PROFILING REPORT ==
Component Timings:
components                       calls        total          avg (ms)
component/transformer               20     1388.121       69.406
component/text_encoder               5      273.029       54.606

Method Timings:
methods                          calls        total          avg (ms)
E2E execute                          5     2152.649      430.530
component/transformer               20     1388.121       69.406
decode_latents                       5      478.375       95.675
prepare_embeddings                   5      273.368       54.674
component/text_encoder               5      273.029       54.606
scheduler_step                      20        9.151        0.458
preprocess_latents                   5        0.858        0.172
prepare_scheduler                    5        0.395        0.079
==========================================
Generation complete!

PR result:

==================== PROFILING REPORT ==
Component Timings:
components                       calls        total          avg (ms)
component/transformer               20     1389.771       69.489
component/text_encoder               5      360.039       72.008

Method Timings:
methods                          calls        total          avg (ms)
E2E execute                          5     2239.572      447.914
component/transformer               20     1389.771       69.489
decode_latents                       5      480.221       96.044
prepare_embeddings                   5      360.398       72.080
component/text_encoder               5      360.039       72.008
scheduler_step                      20        5.576        0.279
preprocess_latents                   5        0.763        0.153
prepare_scheduler                    5        0.416        0.083
==========================================
Generation complete!

Quality

Prompt

Four elements split into quadrants: top-left shows splashing water forming the word \"WATER\" on a water background, top-right shows soil forming \"EARTH\" with planet earth behind, bottom-left shows colorful clouds forming \"AIR\" at sunset, bottom-right shows fiery lava forming \"FIRE\" against the sun

Checklist

  • PR is small and focused — consider splitting larger changes into a
    sequence of smaller PRs
  • I ran ./bazelw run format to format my changes
  • I added or updated tests to cover my changes
  • If AI tools assisted with this contribution, I have included an
    Assisted-by: trailer in my commit message or this PR description
    (see AI Tool Use Policy)