[Flux2] Align prompt embedding prep with attention mask and padded sequence length by pei0033 · Pull Request #6013 · modular/modular



[External] [MAX] Align FLUX.2-Klein prompt masking with diffusers


## Summary

This PR adds an actual tokenizer-mask path for the FLUX.2-Klein text
encoder.

## Problems

MAX prompt embedding behavior is currently not aligned with diffusers
across the FLUX family, and the mismatch is especially visible for
FLUX.2-Klein.

Current behavior by model family:

1. `flux1` (`t5`, `clip`)
   - Tokens are padded to `max_length` with pad tokens.
   - The text encoders do not consume an attention mask.

2. `flux2-dev` (`mistral`)
   - Tokens are left-padded to `max_length`.
   - Diffusers passes both the padded tokens and the corresponding
     attention mask to the text encoder.
   - Because text-encoder masking was not implemented, PR #6013 worked
     around this by compacting tokens before encoding and then
     left-zero-padding the returned hidden states.

3. `flux2-klein` (`qwen3`)
   - Tokens are right-padded to `max_length`.
   - Diffusers passes both the padded tokens and the corresponding
     attention mask to the text encoder.
   - The `flux2-dev` workaround does not apply here: with right padding,
     the hidden states at pad-token positions are not zero, so compacting
     and then re-padding the hidden states does not reproduce the
     diffusers behavior.
<img width="2462" height="932" alt="image"
src="https://github.com/user-attachments/assets/33748187-14d2-45ee-a748-6f6eefcc1f96"
/>
As a result, MAX's current FLUX.2-Klein behavior does not match
diffusers' prompt-conditioning semantics and produces visibly different
image quality.
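The padding-side differences above can be made concrete with a toy
sketch (plain Python, not MAX code; the token IDs are made up for
illustration):

```python
# Toy illustration of the two padding styles described above.
# 1 in the mask marks a real token, 0 marks a pad position.
seq, max_len, pad = [101, 7, 8, 102], 8, 0
n_pad = max_len - len(seq)

# flux1 / flux2-klein style: right padding.
right_padded = seq + [pad] * n_pad
mask_right = [1] * len(seq) + [0] * n_pad

# flux2-dev style: left padding.
left_padded = [pad] * n_pad + seq
mask_left = [0] * n_pad + [1] * len(seq)

print(right_padded)  # [101, 7, 8, 102, 0, 0, 0, 0]
print(left_padded)   # [0, 0, 0, 0, 101, 7, 8, 102]
```

With left padding, the real tokens occupy a contiguous suffix, so
compacting before encoding and left-zero-padding the outputs can line
positions back up; with right padding, a causal encoder still produces
non-zero hidden states at the trailing pad positions, which is why that
workaround breaks for FLUX.2-Klein.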

### What Changed

Before this change, MAX had no end-to-end path for arbitrary token
masks here; in practice, the behavior was effectively pad-token-only:

- `PixelGenerationTokenizer` can produce an `attention_mask`,
- `PixelContext` can carry that mask,
- but `Flux2KleinPipeline` does not pass it into prompt encoding,
- `Qwen3TextEncoderModel.__call__()` rejects `attention_mask`,
- and the compiled Qwen3 text-attention path is causal-only.

The implementation in this PR:

- adds a shared causal + token-mask helper in
  `max.pipelines.dataprocessing`,
- builds the additive attention bias from that mask on the host before
  the compiled Qwen3 text-encoder call.
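The combined causal + token-mask bias can be sketched as follows. This
is a minimal, dependency-free illustration of the general technique, not
the actual helper added in `max.pipelines.dataprocessing`; the function
name, shapes, and the `-1e9` sentinel are assumptions for illustration:

```python
NEG = -1e9  # large negative value; softmax drives these scores to ~0

def causal_plus_token_mask_bias(attention_mask):
    # attention_mask: list of 1/0 per position (1 = real token, 0 = pad).
    # Returns a seq x seq additive bias: 0.0 where query q may attend
    # key k (k <= q and key k is a real token), NEG otherwise.
    s = len(attention_mask)
    return [
        [0.0 if (k <= q and attention_mask[k]) else NEG for k in range(s)]
        for q in range(s)
    ]

# Right padding: the last position is a pad token.
bias = causal_plus_token_mask_bias([1, 1, 1, 0])
```

Adding this bias to the raw attention scores before the softmax masks
out both future positions (causal) and pad-token keys, which is what
aligns the Qwen3 text encoder with diffusers' masked encoding.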

## Testing

### Latency
On an H200, text-encoder latency slows from **54 ms to 72 ms**.
Ran:
```bash
./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "black-forest-labs/FLUX.2-klein-4B" --prompt 'Four elements split into quadrants: top-left shows splashing water forming the word "WATER" on a water background, top-right shows soil forming "EARTH" with planet earth behind, bottom-left shows colorful clouds forming "AIR" at sunset, bottom-right shows fiery lava forming "FIRE" against the sun' --num-inference-steps 4 --guidance-scale 1.0 --seed 42 --profile-timings --num-profile-iterations 5 --num-warmups 1
```

Baseline on `main` (with PR #6150):
```
==================== PROFILING REPORT ==
Component Timings:
components                       calls        total          avg (ms)
component/transformer               20     1388.121       69.406
component/text_encoder               5      273.029       54.606

Method Timings:
methods                          calls        total          avg (ms)
E2E execute                          5     2152.649      430.530
component/transformer               20     1388.121       69.406
decode_latents                       5      478.375       95.675
prepare_embeddings                   5      273.368       54.674
component/text_encoder               5      273.029       54.606
scheduler_step                      20        9.151        0.458
preprocess_latents                   5        0.858        0.172
prepare_scheduler                    5        0.395        0.079
==========================================
Generation complete!
```

With this PR:
```
==================== PROFILING REPORT ==
Component Timings:
components                       calls        total          avg (ms)
component/transformer               20     1389.771       69.489
component/text_encoder               5      360.039       72.008

Method Timings:
methods                          calls        total          avg (ms)
E2E execute                          5     2239.572      447.914
component/transformer               20     1389.771       69.489
decode_latents                       5      480.221       96.044
prepare_embeddings                   5      360.398       72.080
component/text_encoder               5      360.039       72.008
scheduler_step                      20        5.576        0.279
preprocess_latents                   5        0.763        0.153
prepare_scheduler                    5        0.416        0.083
==========================================
Generation complete!
```

### Quality
**Prompt**
```
Four elements split into quadrants: top-left shows splashing water forming the word "WATER" on a water background, top-right shows soil forming "EARTH" with planet earth behind, bottom-left shows colorful clouds forming "AIR" at sunset, bottom-right shows fiery lava forming "FIRE" against the sun
```
<table>
  <tr>
      <td align="center">
<img width="1024" height="1024" alt="diffusers"
src="https://github.com/user-attachments/assets/0649854f-bedd-4d9d-94bf-03f74938e55f"
/>
      <br/>diffusers with same latent
    </td>
    <td align="center">
<img width="1024" height="1024" alt="max_main"
src="https://github.com/user-attachments/assets/828800db-28b3-4421-99e2-2f2a312f7348"
/>
      <br/>max main branch
    </td>
    <td align="center">
<img width="1024" height="1024" alt="max_pr"
src="https://github.com/user-attachments/assets/4ef02185-4e3f-4a8f-b828-5cf727f931e2"
/>
    <br/>max current PR
    </td>

  </tr>
</table>

## Checklist

- [x] PR is small and focused — consider splitting larger changes into a
      sequence of smaller PRs
- [x] I ran `./bazelw run format` to format my changes
- [x] I added or updated tests to cover my changes
- [x] If AI tools assisted with this contribution, I have included an
      `Assisted-by:` trailer in my commit message or this PR description
      (see [AI Tool Use Policy](../AI_TOOL_POLICY.md))
```
Assisted-by: Codex
```
ORIGINAL_AUTHOR=pei0033 <parkeunik@naver.com>
ORIGINAL_USER=@pei0033

Co-authored-by: pei0033 <parkeunik@naver.com>
Closes #6153
MODULAR_ORIG_COMMIT_REV_ID: 1dab4a13d1c2dd8d4c73d00026a82c10f4b8b1e7