From 41375d321d124f18de134c2c393029b8b00a152d Mon Sep 17 00:00:00 2001
From: hlky
Date: Fri, 7 Mar 2025 21:38:32 +0000
Subject: [PATCH 1/2] Wan Pipeline scaling fix, type hint warning, multi generator fix

---
 .../pipelines/wan/pipeline_wan_i2v.py | 44 +++++++++++++------
 1 file changed, 30 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index 863178e7c434..dca8cd5b2fd8 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -49,11 +49,11 @@
         >>> import numpy as np
         >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
         >>> from diffusers.utils import export_to_video, load_image
-        >>> from transformers import CLIPVisionModel
+        >>> from transformers import CLIPVisionModelWithProjection
 
         >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-        >>> image_encoder = CLIPVisionModel.from_pretrained(
+        >>> image_encoder = CLIPVisionModelWithProjection.from_pretrained(
         ...     model_id, subfolder="image_encoder", torch_dtype=torch.float32
         ... )
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
@@ -109,14 +109,30 @@ def prompt_clean(text):
 
 
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    latents_mean: torch.Tensor,
+    latents_std: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
+        encoder_output.latent_dist.logvar = torch.clamp(
+            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
+        )
+        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
+        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+        return (encoder_output.latents - latents_mean) * latents_std
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
@@ -155,7 +171,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
+        image_encoder: CLIPVisionModelWithProjection,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
@@ -385,13 +401,6 @@ def prepare_latents(
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
 
-        if isinstance(generator, list):
-            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = latent_condition = torch.cat(latent_condition)
-        else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
-            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-
         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
             .view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@ def prepare_latents(
             latents.device, latents.dtype
         )
 
-        latent_condition = (latent_condition - latents_mean) * latents_std
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0

From e461b61d4134b2fd8b39bea8aab11d48da24ee15 Mon Sep 17 00:00:00 2001
From: hlky
Date: Fri, 7 Mar 2025 22:50:48 +0000
Subject: [PATCH 2/2] Apply suggestions from code review

---
 src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index dca8cd5b2fd8..102f1a5002e1 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -19,7 +19,7 @@
 import PIL
 import regex as re
 import torch
-from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
@@ -49,11 +49,11 @@
         >>> import numpy as np
         >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
         >>> from diffusers.utils import export_to_video, load_image
-        >>> from transformers import CLIPVisionModelWithProjection
+        >>> from transformers import CLIPVisionModel
 
         >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
         >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-        >>> image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+        >>> image_encoder = CLIPVisionModel.from_pretrained(
         ...     model_id, subfolder="image_encoder", torch_dtype=torch.float32
         ... )
         >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
@@ -171,7 +171,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModelWithProjection,
+        image_encoder: CLIPVisionModel,
         image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
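
The sketch below is a minimal, self-contained illustration of the behaviour this patch folds into retrieve_latents: each per-generator VAE sample is normalised with latents_mean/latents_std before the per-generator results are concatenated, so single- and multi-generator calls end up scaled identically. The toy shapes, seeds, and the fake_encode_sample/retrieve_and_scale helpers are assumptions for illustration only, not part of the pipeline code.

import torch

# Toy dimensions and normalisation statistics; stand-ins for self.vae.config.latents_mean/std.
z_dim, frames, height, width = 16, 2, 8, 8
latents_mean = torch.randn(1, z_dim, 1, 1, 1)
latents_std = 1.0 / torch.rand(1, z_dim, 1, 1, 1).clamp(min=0.1)


def fake_encode_sample(generator: torch.Generator) -> torch.Tensor:
    # Stand-in for self.vae.encode(video_condition).latent_dist.sample(generator).
    return torch.randn(1, z_dim, frames, height, width, generator=generator)


def retrieve_and_scale(generator: torch.Generator) -> torch.Tensor:
    # Mirrors the new retrieve_latents contract: normalise before returning,
    # so the scaling happens per generator rather than once on the concatenated batch.
    return (fake_encode_sample(generator) - latents_mean) * latents_std


generators = [torch.Generator().manual_seed(seed) for seed in (0, 1)]
latent_condition = torch.cat([retrieve_and_scale(g) for g in generators])
print(latent_condition.shape)  # torch.Size([2, 16, 2, 8, 8])

Before this change, the scaling was applied once after torch.cat, and the list-of-generators branch also overwrote latents via the stray "latents = latent_condition = torch.cat(...)" assignment; moving the normalisation into retrieve_latents addresses both.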