(WIP) Cleaned Flux2 Klein Implementation, with benchmarking done on v6 TPU#434
(WIP) Cleaned Flux2 Klein Implementation, with benchmarking done on v6 TPU#434amepas wants to merge 9 commits into
Conversation
|
|
||
| class GenerateFlux2KleinE2ETest(unittest.TestCase): | ||
|
|
||
| def test_end_to_end_parity_and_offloading(self): |
There was a problem hiding this comment.
this test is very likely to fail in the github runner with the hardcoded values. We usually don't run e2e tests on the github runner, you can mark it so it doesn't run in the runner.
| every single stage against the golden PyTorch reference. | ||
| """ | ||
| # Set highest precision for strict mathematical parity checks | ||
| jax.config.update("jax_default_matmul_precision", "highest") |
There was a problem hiding this comment.
what is the reason for using highest here?
There was a problem hiding this comment.
Do we need a separate generate file for 9B model? Can this be combined with the generate_flux2klein.py or even the generate_flux.py?
There was a problem hiding this comment.
We don't need a benchmark code to be added to tests/ folder. We could benchmark the results and mention it in the PR or a doc but no need to add it to the repo
| extra_one_step: bool = False, | ||
| reverse_sigmas: bool = False, | ||
| use_dynamic_shifting: bool = False, | ||
| time_shift_type: str = "linear", |
There was a problem hiding this comment.
What are the possible time_shift_types here? Is moving to a Enum data type better?
There was a problem hiding this comment.
I suspect you can refactor this file and move some of the functions to files under different folders like pipeline, models/flux, max_utils.py
Refer to WAN model related files
(WIP) - this will be updated with multi-chip latency and support for Flux2 Klein 9B!
Draft PR for the Flux2 Klein model. Includes a custom implementation of the Qwen3-4B model for getting text embeddings. VAE Decoder, RoPE positional embedder, flow-matching step schedule are all re-used. Light modifications to transformer/attention blocks are used.
Latency for batch-size 4 of 1024 by 1024 images (bfloat16):
57.67 ms(1.58% of total)3,181.20 ms(87.09% of total)795.30 ms413.77 ms(11.33% of total)PR includes code for verifying accuracy of implementation. Sharding model is implemented but not tested.
Image generation is only supported so far.