(WIP) Cleaned Flux2 Klein Implementation, with benchmarking done on v6 TPU#434

Draft
amepas wants to merge 9 commits into main from flux2klein-onboarding

Conversation

amepas commented Jun 29, 2026

Collaborator

(WIP) - this will be updated with multi-chip latency and support for Flux2 Klein 9B!

Draft PR for the Flux2 Klein model. It includes a custom implementation of the Qwen3-4B model, used to compute text embeddings. The VAE decoder, RoPE positional embedder, and flow-matching step schedule are all reused; the transformer/attention blocks are lightly modified.

Latency for batch size 4 at 1024×1024 resolution (bfloat16):

  • Prompt Encoding (Qwen3): 57.67 ms (1.58% of total)
  • Denoising Loop (Flux 4 steps): 3,181.20 ms (87.09% of total)
    • Per-Step Transformer Time: 795.30 ms
  • VAE Decoding (VAE): 413.77 ms (11.33% of total)
  • Total: 3.65 seconds
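
The percentages above follow directly from the stage timings. A minimal sketch that recomputes them; the dictionary keys and the `breakdown` helper are illustrative names, not part of the PR:

```python
# Stage latencies in milliseconds, copied from the list above.
STAGE_MS = {
    "prompt_encoding": 57.67,
    "denoising_loop": 3181.20,
    "vae_decoding": 413.77,
}
NUM_DENOISING_STEPS = 4


def breakdown(stage_ms):
  """Return (total_ms, {stage: percent_of_total}). Hypothetical helper."""
  total = sum(stage_ms.values())
  shares = {name: 100.0 * ms / total for name, ms in stage_ms.items()}
  return total, shares


total_ms, shares = breakdown(STAGE_MS)
per_step_ms = STAGE_MS["denoising_loop"] / NUM_DENOISING_STEPS

for name, ms in STAGE_MS.items():
  print(f"{name}: {ms:.2f} ms ({shares[name]:.2f}% of total)")
print(f"per-step transformer time: {per_step_ms:.2f} ms")
print(f"total: {total_ms / 1000:.2f} s")
```

Keeping the raw per-stage numbers in one place like this makes it easy to refresh the table once the multi-chip measurements are available.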

The PR includes code for verifying the accuracy of the implementation. Model sharding is implemented but not yet tested.

Only image generation is supported so far.

amepas requested review from chandrasekhard2 and eltsai June 29, 2026 20:49
amepas requested a review from entrpn as a code owner June 29, 2026 20:49
amepas marked this pull request as draft June 29, 2026 20:53

class GenerateFlux2KleinE2ETest(unittest.TestCase):

def test_end_to_end_parity_and_offloading(self):

Collaborator

This test is very likely to fail in the GitHub runner because of the hardcoded values. We usually don't run e2e tests on the GitHub runner; you can mark it so that it is skipped there.
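
One way to do that, as a sketch: GitHub Actions sets the `GITHUB_ACTIONS` environment variable to `true`, so a `unittest.skipIf` guard on that variable skips the test in the runner while leaving local runs unaffected. The class below is a stand-in with a placeholder body, and the repository may already have its own convention for excluding tests.

```python
import os
import unittest

# GitHub Actions exports GITHUB_ACTIONS=true in every workflow run.
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


class GenerateFlux2KleinE2ETest(unittest.TestCase):

  @unittest.skipIf(IN_GITHUB_ACTIONS, "e2e test is not run on the GitHub runner")
  def test_end_to_end_parity_and_offloading(self):
    # Placeholder body; the real test compares each stage to the PyTorch reference.
    self.assertTrue(True)
```

On a TPU VM or a developer machine the variable is normally unset, so the test still runs there.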

every single stage against the golden PyTorch reference.
"""
# Set highest precision for strict mathematical parity checks
jax.config.update("jax_default_matmul_precision", "highest")

Collaborator

What is the reason for using "highest" here?

amepas changed the title Cleaned Flux2 Klein Implementation, with benchmarking done on v6 TPU Jun 30, 2026

Perseus14 left a comment

Collaborator

Hey @amepas, I see this is WIP; I added a few comments to help polish the PR.

Collaborator

Do we need a separate generate file for the 9B model? Can this be combined with generate_flux2klein.py, or even generate_flux.py?

Collaborator

We don't need benchmark code in the tests/ folder. We can run the benchmark and report the results in the PR or in a doc, but there is no need to add the code to the repo.

extra_one_step: bool = False,
reverse_sigmas: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: str = "linear",

Collaborator

What are the possible values of time_shift_type here? Would moving to an Enum be better?
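
A sketch of what the Enum option could look like. The member set is an assumption: "linear" is the default shown in the diff, and "exponential" is included because, to my knowledge, the diffusers flow-matching scheduler offers those two choices. The names `TimeShiftType` and `parse_time_shift_type` are hypothetical, and the real list should come from the scheduler code in this PR.

```python
from enum import Enum


class TimeShiftType(str, Enum):
  """Allowed time-shift modes (member set is assumed, not taken from the PR)."""

  LINEAR = "linear"
  EXPONENTIAL = "exponential"


def parse_time_shift_type(value):
  """Accept an enum member or its string value; reject anything else."""
  try:
    return TimeShiftType(value)
  except ValueError:
    allowed = ", ".join(member.value for member in TimeShiftType)
    raise ValueError(f"time_shift_type must be one of: {allowed}; got {value!r}") from None
```

Mixing in `str` means existing comparisons such as `time_shift_type == "linear"` and string values in config files keep working, while a typo fails at construction time rather than deep inside the scheduler.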

Collaborator

I suspect you can refactor this file and move some of the functions into files under other folders, such as pipeline, models/flux, or max_utils.py.

Refer to the WAN model related files.

3 participants