Use output_size in repeat_interleave #11030

Merged (1 commit, Mar 12, 2025)
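
This PR passes the optional output_size argument to torch.repeat_interleave throughout the codebase. output_size tells PyTorch the size of the repeated dimension up front; per the PyTorch docs, supplying it can skip the stream synchronization otherwise needed to compute the output shape, which can also matter when tracing or compiling these models. A minimal sketch of the before/after pattern (the tensor shapes below are made up for illustration, not taken from the diff):

import torch

attention_mask = torch.ones(2, 1, 77)  # illustrative shape
head_size = 8

# Without output_size, PyTorch derives the size of the repeated dimension
# itself, which can require a stream synchronization on the GPU.
a = attention_mask.repeat_interleave(head_size, dim=0)

# With output_size, the already-known result size is passed in explicitly.
b = attention_mask.repeat_interleave(
    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
)

assert a.shape == b.shape == (2 * head_size, 1, 77)
assert torch.equal(a, b)

Both calls return the same tensor; only the shape bookkeeping changes.
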
14 changes: 10 additions & 4 deletions src/diffusers/models/attention_processor.py
@@ -741,10 +741,14 @@ def prepare_attention_mask(

if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
attention_mask = attention_mask.repeat_interleave(
head_size, dim=0, output_size=attention_mask.shape[0] * head_size
)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
attention_mask = attention_mask.repeat_interleave(
head_size, dim=1, output_size=attention_mask.shape[1] * head_size
)

return attention_mask

@@ -3704,8 +3708,10 @@ def __call__(
if kv_heads != attn.heads:
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
heads_per_kv_head = attn.heads // kv_heads
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
value = torch.repeat_interleave(
value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
)

if attn.norm_q is not None:
query = attn.norm_q(query)
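
The second hunk above covers the GQA/MQA path, where each key/value head is repeated so that the key and value tensors reach the number of query heads. A standalone sketch of that expansion, with made-up head counts:

import torch

batch, kv_heads, seq_len, head_dim = 2, 2, 16, 64
attn_heads = 8  # assumed number of query heads for this sketch
heads_per_kv_head = attn_heads // kv_heads

key = torch.randn(batch, kv_heads, seq_len, head_dim)
# Repeat each kv head heads_per_kv_head times along dim=1, passing the known
# output size so the result shape does not have to be recomputed.
key = torch.repeat_interleave(
    key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head
)

assert key.shape == (batch, attn_heads, seq_len, head_dim)
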
6 changes: 4 additions & 2 deletions src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -190,7 +190,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = F.pixel_shuffle(x, self.factor)

if self.shortcut:
y = hidden_states.repeat_interleave(self.repeats, dim=1)
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
@@ -361,7 +361,9 @@ def __init__(

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.in_shortcut:
x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
x = hidden_states.repeat_interleave(
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
)
hidden_states = self.conv_in(hidden_states) + x
else:
hidden_states = self.conv_in(hidden_states)
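
The autoencoder_dc.py hunks duplicate channels with repeat_interleave so a shortcut branch matches the shape of the main branch; in the upsampling block the duplicated tensor is then pixel-shuffled. A small sketch of that pattern (shapes and repeat count are illustrative):

import torch
import torch.nn.functional as F

factor = 2
x = torch.randn(1, 4, 8, 8)  # [B, C, H, W], illustrative
repeats = 2  # chosen so C * repeats is divisible by factor**2

# Duplicate each channel `repeats` times, then pixel-shuffle: channels shrink
# by factor**2 while height and width grow by factor.
y = x.repeat_interleave(repeats, dim=1, output_size=x.shape[1] * repeats)  # [1, 8, 8, 8]
y = F.pixel_shuffle(y, factor)  # [1, 2, 16, 16]

assert y.shape == (1, x.shape[1] * repeats // factor**2, 8 * factor, 8 * factor)
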
@@ -103,7 +103,7 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
if self.down_sample:
identity = hidden_states[:, :, ::2]
elif self.up_sample:
identity = hidden_states.repeat_interleave(2, dim=2)
identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
else:
identity = hidden_states

4 changes: 3 additions & 1 deletion src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -426,7 +426,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]

# Interleaved repeat of input channels to match w
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
h = inputs.repeat_interleave(
num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
) # [B, C * num_freqs, T, H, W]
# Scale channels by frequency.
h = w * h

2 changes: 1 addition & 1 deletion src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -687,7 +687,7 @@ def forward(
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(sample_num_frames, dim=0)
emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)

# 2. pre-process
batch_size, channels, num_frames, height, width = sample.shape
8 changes: 5 additions & 3 deletions src/diffusers/models/embeddings.py
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(

# 3. Concat
pos_embed_spatial = pos_embed_spatial[None, :, :]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(
temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
) # [T, H*W, D // 4 * 3]

pos_embed_temporal = pos_embed_temporal[:, None, :]
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
@@ -1154,8 +1156,8 @@ def get_1d_rotary_pos_embed(
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
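
The rotary-embedding hunk duplicates each cosine/sine entry along the last dimension, turning a [S, D/2] table into the interleaved [S, D] layout used by flux, hunyuan-dit, and cogvideox. A tiny sketch with made-up values:

import torch

freqs = torch.tensor([[0.1, 0.2, 0.3]])  # [S, D/2] with S=1, D=6, illustrative values
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()

# Each entry appears twice in a row: cos(0.1), cos(0.1), cos(0.2), cos(0.2), ...
assert freqs_cos.shape == (1, 6)
assert torch.allclose(freqs_cos[0, 0], freqs_cos[0, 1])
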
18 changes: 12 additions & 6 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -227,13 +227,17 @@ def forward(
# Prepare text embeddings for spatial block
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
)
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])

# Prepare timesteps for spatial and temporal block
timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
timestep_spatial = timestep.repeat_interleave(
num_frame, dim=0, output_size=timestep.shape[0] * num_frame
).view(-1, timestep.shape[-1])
timestep_temp = timestep.repeat_interleave(
num_patches, dim=0, output_size=timestep.shape[0] * num_patches
).view(-1, timestep.shape[-1])

# Spatial and temporal transformer blocks
for i, (spatial_block, temp_block) in enumerate(
@@ -299,7 +303,9 @@ def forward(
).permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
embedded_timestep = embedded_timestep.repeat_interleave(
num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
).view(-1, embedded_timestep.shape[-1])
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
6 changes: 5 additions & 1 deletion src/diffusers/models/transformers/prior_transformer.py
@@ -353,7 +353,11 @@ def forward(
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
attention_mask = attention_mask.repeat_interleave(
self.config.num_attention_heads,
dim=0,
output_size=attention_mask.shape[0] * self.config.num_attention_heads,
)

if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
6 changes: 4 additions & 2 deletions src/diffusers/models/unets/unet_3d_condition.py
@@ -638,8 +638,10 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
)

# 2. pre-process
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
4 changes: 2 additions & 2 deletions src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -592,7 +592,7 @@ def forward(

# 3. time + FPS embeddings.
emb = t_emb + fps_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

# 4. context embeddings.
# The context embeddings consist of both text embeddings from the input prompt
@@ -620,7 +620,7 @@ def forward(
image_emb = self.context_embedding(image_embeddings)
image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
context_emb = torch.cat([context_emb, image_emb], dim=1)
context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)

image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
image_latents.shape[0] * image_latents.shape[2],
7 changes: 5 additions & 2 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -2059,7 +2059,7 @@ def forward(
aug_emb = self.add_embedding(add_embeds)

emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
@@ -2068,7 +2068,10 @@ def forward(
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
image_embeds = [
image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
for image_embed in image_embeds
]
encoder_hidden_states = (encoder_hidden_states, image_embeds)

# 2. pre-process
6 changes: 4 additions & 2 deletions src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -431,9 +431,11 @@ def forward(
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
)

# 2. pre-process
sample = self.conv_in(sample)