Flux with Remote Encode #11091

Merged: 4 commits, Mar 20, 2025

This PR lets the Flux img2img and inpaint pipelines accept pre-encoded latents (for example, from a remote VAE encode endpoint) in place of pixel images. Concretely: `VaeImageProcessor` is now constructed with `vae_latent_channels` so preprocessing can pass latent tensors through unchanged, `prepare_latents` skips `_encode_vae_image` when the input already has the latent channel count, and `remote_utils.prepare_encode` makes tensors contiguous before serializing them.
Three pipeline files (names not preserved in this view) drop their `# Copied from` markers, presumably because the upstream `prepare_latents`/`prepare_mask_latents` implementations change in this PR and the copy-consistency tooling (`make fix-copies`) would otherwise rewrite these methods.

```diff
@@ -533,7 +533,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
 
         return latents
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
         image,
```
```diff
@@ -533,7 +533,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
 
         return latents
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
         image,
```
```diff
@@ -561,7 +561,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
 
         return latents
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
     def prepare_latents(
         self,
         image,
@@ -614,7 +613,6 @@ def prepare_latents(
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, noise, image_latents, latent_image_ids
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
     def prepare_mask_latents(
         self,
         mask,
```
src/diffusers/pipelines/flux/pipeline_flux_img2img.py (8 additions, 2 deletions)

```diff
@@ -225,7 +225,10 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
@@ -634,7 +637,10 @@ def prepare_latents(
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
         image = image.to(device=device, dtype=dtype)
-        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
             additional_image_per_prompt = batch_size // image_latents.shape[0]
```
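With the check above, `FluxImg2ImgPipeline` accepts a pre-encoded latent tensor wherever it previously required a pixel image. A minimal usage sketch, not taken from the PR itself: it assumes `remote_encode` from `diffusers.utils.remote_utils`, a placeholder endpoint URL, and the FLUX.1 VAE's `scaling_factor`/`shift_factor` values quoted from memory, so verify them against the model config.

```python
# Sketch: remote VAE encode feeding FluxImg2ImgPipeline directly with latents.
# ENCODE_ENDPOINT is a placeholder; substitute a real VAE-encode endpoint.
import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_encode

ENCODE_ENDPOINT = "https://<your-vae-encode-endpoint>/"  # placeholder

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

init_image = load_image("init.png")  # any RGB init image

# Latents come back already normalized via (x - shift) * scale, shaped
# (1, 16, H // 8, W // 8). Scaling/shift values assumed from the FLUX.1 VAE config.
init_latents = remote_encode(
    endpoint=ENCODE_ENDPOINT,
    image=init_image,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)

# Because init_latents has 16 channels (pipe.latent_channels, not RGB's 3),
# prepare_latents now skips _encode_vae_image and uses them as-is.
result = pipe(
    prompt="a cat wizard, detailed fantasy art",
    image=init_latents.to("cuda", torch.bfloat16),
    strength=0.75,
).images[0]
```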
src/diffusers/pipelines/flux/pipeline_flux_inpaint.py (12 additions, 5 deletions)

```diff
@@ -222,11 +222,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
@@ -653,7 +655,10 @@ def prepare_latents(
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         image = image.to(device=device, dtype=dtype)
-        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
 
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -710,7 +715,9 @@ def prepare_mask_latents(
         else:
             masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
 
-        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+        masked_image_latents = (
+            masked_image_latents - self.vae.config.shift_factor
+        ) * self.vae.config.scaling_factor
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
```
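The rewrapped expression in the last hunk is the Flux VAE latent normalization. A small standalone sketch of the forward and inverse mappings; the constants are the commonly published FLUX.1 VAE config values, quoted from memory, so verify them against `vae.config`:

```python
import torch

# Assumed FLUX.1 VAE config values; check vae.config.scaling_factor / shift_factor.
SCALING_FACTOR = 0.3611
SHIFT_FACTOR = 0.1159

def normalize(latents: torch.Tensor) -> torch.Tensor:
    # Applied after vae.encode(); matches the masked_image_latents line above.
    return (latents - SHIFT_FACTOR) * SCALING_FACTOR

def denormalize(latents: torch.Tensor) -> torch.Tensor:
    # Inverse, applied before vae.decode().
    return latents / SCALING_FACTOR + SHIFT_FACTOR

x = torch.randn(1, 16, 64, 64)
assert torch.allclose(denormalize(normalize(x)), x, atol=1e-5)
```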
src/diffusers/utils/remote_utils.py (1 addition, 1 deletion)

```diff
@@ -367,7 +367,7 @@ def prepare_encode(
     if shift_factor is not None:
         parameters["shift_factor"] = shift_factor
     if isinstance(image, torch.Tensor):
-        data = safetensors.torch._tobytes(image, "tensor")
+        data = safetensors.torch._tobytes(image.contiguous(), "tensor")
         parameters["shape"] = list(image.shape)
         parameters["dtype"] = str(image.dtype).split(".")[-1]
     else:
```
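The `.contiguous()` call matters because `safetensors.torch._tobytes` serializes the tensor's underlying storage and, to my understanding, rejects non-contiguous tensors, so a view produced by `permute` or slicing would otherwise fail. A quick illustration:

```python
import torch

latents = torch.randn(1, 16, 64, 64)

# Channels-last view: same data, but strides no longer match a packed layout.
view = latents.permute(0, 2, 3, 1)
assert not view.is_contiguous()

# .contiguous() copies into fresh, densely packed storage so the raw bytes
# read by safetensors.torch._tobytes follow the logical element order.
packed = view.contiguous()
assert packed.is_contiguous()
assert torch.equal(packed, view)  # values unchanged, only memory layout differs
```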