Summary of Changes

This pull request introduces the Qwen Image VAE, a crucial component for single-image generation within the MAX framework. It adapts the Qwen Image VAE's 3D causal convolution architecture to 2D for efficient T=1 image generation, providing both encoder and decoder modules. The changes also include the necessary configuration and a robust weight-transformation mechanism to correctly load and use pre-trained 3D weights in a 2D context.
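The summary above mentions adapting the 3D causal convolutions to 2D for T=1 generation. A minimal numpy sketch of that idea, assuming an `[O, I, T, kH, kW]` weight layout (illustrative only, not the PR's actual code):

```python
import numpy as np

def collapse_temporal(weight_5d: np.ndarray) -> np.ndarray:
    """Collapse a 3D conv weight [O, I, T, kH, kW] to 2D [O, I, kH, kW].

    For a causal temporal convolution, the last temporal tap is the one
    applied to the current frame, so for T=1 inference we keep
    weight[:, :, -1, :, :] (falling back to index 0 when T == 1).
    """
    assert weight_5d.ndim == 5
    t = weight_5d.shape[2]
    return weight_5d[:, :, -1 if t > 1 else 0, :, :]

w = np.random.randn(8, 4, 3, 3, 3).astype(np.float32)  # [O, I, T, kH, kW]
print(collapse_temporal(w).shape)  # (8, 4, 3, 3)
```

This mirrors the `data[:, :, -1, :, :]` slice used in the weight-transformation code reviewed below.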
Code Review
This pull request adds the VAE implementation for Qwen-Image. The implementation looks solid overall. I've left a few comments for minor improvements: one flags a typo in a config field name, and the other two suggest ways to improve clarity and reduce duplication in the weight-transformation logic. Please take a look.
```python
def _transform_decoder_weights(
    raw_weights: dict[str, Any],
    target_dtype: DType,
) -> dict[str, Any]:
```
The function _transform_decoder_weights is also used to transform encoder weights in load_model (lines 848-852). The name is a bit misleading. Consider renaming it to something more generic, like _transform_3d_vae_weights, to better reflect its usage for both encoder and decoder components. You will need to update the call sites as well.
```diff
-def _transform_decoder_weights(
+def _transform_3d_vae_weights(
     raw_weights: dict[str, Any],
     target_dtype: DType,
 ) -> dict[str, Any]:
```
```python
if ".to_qkv.weight" in key:
    if data.ndim == 5:
        data = (
            data[:, :, -1, :, :]
            if data.shape[2] > 1
            else data[:, :, 0, :, :]
        )
    C = data.shape[0] // 3
    prefix = key.replace(".to_qkv.weight", "")
    result[f"{prefix}.to_q.weight"] = _to_weight_data(
        data[:C], f"{prefix}.to_q.weight", target_dtype
    )
    result[f"{prefix}.to_k.weight"] = _to_weight_data(
        data[C : 2 * C], f"{prefix}.to_k.weight", target_dtype
    )
    result[f"{prefix}.to_v.weight"] = _to_weight_data(
        data[2 * C :], f"{prefix}.to_v.weight", target_dtype
    )
    continue
if ".to_qkv.bias" in key:
    C = data.shape[0] // 3
    prefix = key.replace(".to_qkv.bias", "")
    result[f"{prefix}.to_q.bias"] = _to_weight_data(
        data[:C], f"{prefix}.to_q.bias", target_dtype
    )
    result[f"{prefix}.to_k.bias"] = _to_weight_data(
        data[C : 2 * C], f"{prefix}.to_k.bias", target_dtype
    )
    result[f"{prefix}.to_v.bias"] = _to_weight_data(
        data[2 * C :], f"{prefix}.to_v.bias", target_dtype
    )
    continue
```
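For reference, the fused-QKV split performed above can be exercised in isolation with a small numpy sketch (hypothetical `split_qkv` helper, not part of the PR):

```python
import numpy as np

def split_qkv(fused: np.ndarray) -> dict[str, np.ndarray]:
    """Split a fused QKV weight or bias along dim 0 into equal Q, K, V chunks."""
    C = fused.shape[0] // 3
    return {"q": fused[:C], "k": fused[C : 2 * C], "v": fused[2 * C :]}

# A fused weight with 3C = 6 output rows and 2 input columns (C = 2).
qkv_w = np.arange(12, dtype=np.float32).reshape(6, 2)
parts = split_qkv(qkv_w)
print(parts["q"].shape)  # (2, 2)
```

Concatenating `q`, `k`, and `v` back along dim 0 reproduces the fused tensor, which is a handy invariant to assert in a unit test.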
The logic for splitting fused QKV weights and biases is very similar. You can refactor this to combine the handling for weights and biases, which would reduce code duplication and improve maintainability.
```python
if ".to_qkv." in key:
    is_weight = ".weight" in key
    suffix = "weight" if is_weight else "bias"
    if is_weight and data.ndim == 5:
        data = (
            data[:, :, -1, :, :]
            if data.shape[2] > 1
            else data[:, :, 0, :, :]
        )
    C = data.shape[0] // 3
    prefix = key.replace(f".to_qkv.{suffix}", "")
    result[f"{prefix}.to_q.{suffix}"] = _to_weight_data(
        data[:C], f"{prefix}.to_q.{suffix}", target_dtype
    )
    result[f"{prefix}.to_k.{suffix}"] = _to_weight_data(
        data[C : 2 * C], f"{prefix}.to_k.{suffix}", target_dtype
    )
    result[f"{prefix}.to_v.{suffix}"] = _to_weight_data(
        data[2 * C :], f"{prefix}.to_v.{suffix}", target_dtype
    )
    continue
```

```python
temperal_downsample: list[bool] = Field(
    default_factory=lambda: [False, True, True]
)
```

The field name `temperal_downsample` looks like a typo of `temporal_downsample`; consider renaming it, or documenting that the misspelling is kept intentionally to match the pre-trained checkpoint's config keys.
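As context for the config snippet above, `default_factory` exists to avoid sharing one mutable list across config instances. A minimal sketch using stdlib dataclasses (the PR uses pydantic's `Field`, which behaves analogously; `VAEConfig` is a hypothetical name for illustration):

```python
from dataclasses import dataclass, field

@dataclass
class VAEConfig:  # hypothetical class name, not the PR's actual config
    # default_factory builds a fresh list per instance, so mutating one
    # config's list does not affect any other instance.
    temperal_downsample: list[bool] = field(
        default_factory=lambda: [False, True, True]
    )

a, b = VAEConfig(), VAEConfig()
a.temperal_downsample[0] = True
print(b.temperal_downsample)  # [False, True, True]  (b is unaffected)
```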
Summary

Testing

```
./bazelw run format
./bazelw run lint
```

Checklist