From 081e68f66a004c058434da70ae10a9e67f9057ac Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Mar 2025 10:12:47 +0000 Subject: [PATCH 01/10] =?UTF-8?q?[hybrid=20inference=20=F0=9F=8D=AF?= =?UTF-8?q?=F0=9F=90=9D]=20Add=20VAE=20encode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/en/hybrid_inference/overview.md | 5 +- docs/source/en/hybrid_inference/vae_encode.md | 322 ++++++++++++++++++ src/diffusers/utils/remote_utils.py | 114 ++++++- tests/remote/test_remote_decode.py | 29 +- tests/remote/test_remote_encode.py | 214 ++++++++++++ 5 files changed, 662 insertions(+), 22 deletions(-) create mode 100644 docs/source/en/hybrid_inference/vae_encode.md create mode 100644 tests/remote/test_remote_encode.py diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md index 9bbe245901df..53acfcf6c34f 100644 --- a/docs/source/en/hybrid_inference/overview.md +++ b/docs/source/en/hybrid_inference/overview.md @@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir ## Available Models * **VAE Decode šŸ–¼ļø:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed. -* **VAE Encode šŸ”¢ (coming soon):** Efficiently encode images into latent representations for generation and training. +* **VAE Encode šŸ”¢:** Efficiently encode images into latent representations for generation and training. * **Text Encoders šŸ“ƒ (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow. --- @@ -48,7 +48,8 @@ Hybrid Inference offers a fast and simple way to offload local generation requir ## Contents -The documentation is organized into two sections: +The documentation is organized into three sections: * **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference. +* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference. * **API Reference** Dive into task-specific settings and parameters. diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md new file mode 100644 index 000000000000..d277f260f450 --- /dev/null +++ b/docs/source/en/hybrid_inference/vae_encode.md @@ -0,0 +1,322 @@ +# Getting Started: VAE Encode with Hybrid Inference + +VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations. + +## Memory + +These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs. + +For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality. + +
SD v1.5 + +| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | +| --- | --- | --- | --- | --- | --- | +TODO + +
+ +
SDXL + +| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | +| --- | --- | --- | --- | --- | --- | +TODO + +
+ +## Available VAEs + +| | **Endpoint** | **Model** | +|:-:|:-----------:|:--------:| +| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | +| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | +| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | + + +> [!TIP] +> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). + + +## Code + +> [!TIP] +> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main` + + +A helper method simplifies interacting with Hybrid Inference. + +```python +from diffusers.utils.remote_utils import remote_encode +``` + +### Basic example + +Here, we show how to use the remote VAE on random tensors. + +
Code + +```python +image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16), + scaling_factor=0.18215, +) +``` + +
+ +
+ +
+ +Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`. + +
Code + +```python +image = remote_decode( + endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=torch.randn([1, 4096, 64], dtype=torch.float16), + height=1024, + width=1024, + scaling_factor=0.3611, + shift_factor=0.1159, +) +``` + +
+ +
+ +
+ +Finally, an example for HunyuanVideo. + +
Code + +```python +video = remote_decode( + endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16), + output_type="mp4", +) +with open("video.mp4", "wb") as f: + f.write(video) +``` + +
+ +
+ +
+ + +### Generation + +But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5. + +
Code + +```python +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + torch_dtype=torch.float16, + variant="fp16", + vae=None, +).to("cuda") + +prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" + +latent = pipe( + prompt=prompt, + output_type="latent", +).images +image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + scaling_factor=0.18215, +) +image.save("test.jpg") +``` + +
+ +
+ +
+ +Here’s another example with Flux. + +
Code + +```python +from diffusers import FluxPipeline + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16, + vae=None, +).to("cuda") + +prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" + +latent = pipe( + prompt=prompt, + guidance_scale=0.0, + num_inference_steps=4, + output_type="latent", +).images +image = remote_decode( + endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + height=1024, + width=1024, + scaling_factor=0.3611, + shift_factor=0.1159, +) +image.save("test.jpg") +``` + +
+ +
+ +
+ +Here’s an example with HunyuanVideo. + +
Code + +```python +from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + +model_id = "hunyuanvideo-community/HunyuanVideo" +transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer", torch_dtype=torch.bfloat16 +) +pipe = HunyuanVideoPipeline.from_pretrained( + model_id, transformer=transformer, vae=None, torch_dtype=torch.float16 +).to("cuda") + +latent = pipe( + prompt="A cat walks on the grass, realistic", + height=320, + width=512, + num_frames=61, + num_inference_steps=30, + output_type="latent", +).frames + +video = remote_decode( + endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + output_type="mp4", +) + +if isinstance(video, bytes): + with open("video.mp4", "wb") as f: + f.write(video) +``` + +
+ +
+ +
+ + +### Queueing + +One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency. + + +
Code + +```python +import queue +import threading +from IPython.display import display +from diffusers import StableDiffusionPipeline + +def decode_worker(q: queue.Queue): + while True: + item = q.get() + if item is None: + break + image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=item, + scaling_factor=0.18215, + ) + display(image) + q.task_done() + +q = queue.Queue() +thread = threading.Thread(target=decode_worker, args=(q,), daemon=True) +thread.start() + +def decode(latent: torch.Tensor): + q.put(latent) + +prompts = [ + "Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious", + "Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore", + "Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.", + "Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP", + "A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting", + "Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,", +] + +pipe = StableDiffusionPipeline.from_pretrained( + "Lykon/dreamshaper-8", + torch_dtype=torch.float16, + vae=None, +).to("cuda") + +pipe.unet = pipe.unet.to(memory_format=torch.channels_last) +pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + +_ = pipe( + prompt=prompts[0], + output_type="latent", +) + +for prompt in prompts: + latent = pipe( + prompt=prompt, + output_type="latent", + ).images + decode(latent) + +q.put(None) +thread.join() +``` + +
+ + +
+ +
+ +## Integrations + +* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. +* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py index 12bcc94af74f..68cd39f55265 100644 --- a/src/diffusers/utils/remote_utils.py +++ b/src/diffusers/utils/remote_utils.py @@ -43,6 +43,17 @@ from PIL import Image +DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" + + +ENCODE_ENDPOINT_SD_V1 = "" +ENCODE_ENDPOINT_SD_XL = "" +ENCODE_ENDPOINT_FLUX = "" + + def detect_image_type(data: bytes) -> str: if data.startswith(b"\xff\xd8"): return "jpeg" @@ -55,7 +66,7 @@ def detect_image_type(data: bytes) -> str: return "unknown" -def check_inputs( +def check_inputs_decode( endpoint: str, tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, @@ -89,7 +100,7 @@ def check_inputs( ) -def postprocess( +def postprocess_decode( response: requests.Response, processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, output_type: Literal["mp4", "pil", "pt"] = "pil", @@ -142,7 +153,7 @@ def postprocess( return output -def prepare( +def prepare_decode( tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, do_scaling: bool = True, @@ -293,7 +304,7 @@ def remote_decode( standard_warn=False, ) output_tensor_type = "binary" - check_inputs( + check_inputs_decode( endpoint, tensor, processor, @@ -309,7 +320,7 @@ def remote_decode( height, width, ) - kwargs = prepare( + kwargs = prepare_decode( tensor=tensor, processor=processor, do_scaling=do_scaling, @@ -324,7 +335,7 @@ def remote_decode( response = requests.post(endpoint, **kwargs) if not response.ok: raise RuntimeError(response.json()) - output = postprocess( + output = postprocess_decode( response=response, processor=processor, output_type=output_type, @@ -332,3 +343,94 @@ def remote_decode( partial_postprocess=partial_postprocess, ) return output + + +def check_inputs_encode( + endpoint: str, + image: Union["torch.Tensor", Image.Image], + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, +): + pass + + +def postprocess_encode( + response: requests.Response, +): + output_tensor = response.content + parameters = response.headers + shape = json.loads(parameters["shape"]) + dtype = parameters["dtype"] + torch_dtype = DTYPE_MAP[dtype] + output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape) + return output_tensor + + +def prepare_encode( + image: Union["torch.Tensor", Image.Image], + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, +): + headers = {} + parameters = {} + if scaling_factor is not None: + parameters["scaling_factor"] = scaling_factor + if shift_factor is not None: + parameters["shift_factor"] = shift_factor + if isinstance(image, torch.Tensor): + data = safetensors.torch._tobytes(image, "tensor") + parameters["shape"] = list(image.shape) + parameters["dtype"] = str(image.dtype).split(".")[-1] + else: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + data = buffer.getvalue() + return {"data": data, "params": parameters, "headers": headers} + + +def remote_encode( + endpoint: str, + image: Union["torch.Tensor", Image.Image], + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, +) -> "torch.Tensor": + """ + Hugging Face Hybrid Inference that allow running VAE encode remotely. + + Args: + endpoint (`str`): + Endpoint for Remote Decode. + image (`torch.Tensor` or `PIL.Image.Image`): + Image to be encoded. + scaling_factor (`float`, *optional*): + Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`]. + - SD v1: 0.18215 + - SD XL: 0.13025 + - Flux: 0.3611 + If `None`, input must be passed with scaling applied. + shift_factor (`float`, *optional*): + Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`. + - Flux: 0.1159 + If `None`, input must be passed with scaling applied. + + Returns: + output (`torch.Tensor`). + """ + check_inputs_encode( + endpoint, + image, + scaling_factor, + shift_factor, + ) + kwargs = prepare_encode( + image=image, + scaling_factor=scaling_factor, + shift_factor=shift_factor, + ) + response = requests.post(endpoint, **kwargs) + if not response.ok: + raise RuntimeError(response.json()) + output = postprocess_encode( + response=response, + ) + return output diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py index 11f9c24d16f6..7308b1121991 100644 --- a/tests/remote/test_remote_decode.py +++ b/tests/remote/test_remote_decode.py @@ -21,7 +21,13 @@ import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.utils.remote_utils import remote_decode +from diffusers.utils.remote_utils import ( + DECODE_ENDPOINT_FLUX, + DECODE_ENDPOINT_HUNYUAN_VIDEO, + DECODE_ENDPOINT_SD_V1, + DECODE_ENDPOINT_SD_XL, + remote_decode, +) from diffusers.utils.testing_utils import ( enable_full_determinism, slow, @@ -33,11 +39,6 @@ enable_full_determinism() -ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" -ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" -ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" -ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" - class RemoteAutoencoderKLMixin: shape: Tuple[int, ...] = None @@ -350,7 +351,7 @@ class RemoteAutoencoderKLSDv1Tests( 512, 512, ) - endpoint = ENDPOINT_SD_V1 + endpoint = DECODE_ENDPOINT_SD_V1 dtype = torch.float16 scaling_factor = 0.18215 shift_factor = None @@ -374,7 +375,7 @@ class RemoteAutoencoderKLSDXLTests( 1024, 1024, ) - endpoint = ENDPOINT_SD_XL + endpoint = DECODE_ENDPOINT_SD_XL dtype = torch.float16 scaling_factor = 0.13025 shift_factor = None @@ -398,7 +399,7 @@ class RemoteAutoencoderKLFluxTests( 1024, 1024, ) - endpoint = ENDPOINT_FLUX + endpoint = DECODE_ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -425,7 +426,7 @@ class RemoteAutoencoderKLFluxPackedTests( ) height = 1024 width = 1024 - endpoint = ENDPOINT_FLUX + endpoint = DECODE_ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -453,7 +454,7 @@ class RemoteAutoencoderKLHunyuanVideoTests( 320, 512, ) - endpoint = ENDPOINT_HUNYUAN_VIDEO + endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO dtype = torch.float16 scaling_factor = 0.476986 processor_cls = VideoProcessor @@ -504,7 +505,7 @@ class RemoteAutoencoderKLSDv1SlowTests( RemoteAutoencoderKLSlowTestMixin, unittest.TestCase, ): - endpoint = ENDPOINT_SD_V1 + endpoint = DECODE_ENDPOINT_SD_V1 dtype = torch.float16 scaling_factor = 0.18215 shift_factor = None @@ -515,7 +516,7 @@ class RemoteAutoencoderKLSDXLSlowTests( RemoteAutoencoderKLSlowTestMixin, unittest.TestCase, ): - endpoint = ENDPOINT_SD_XL + endpoint = DECODE_ENDPOINT_SD_XL dtype = torch.float16 scaling_factor = 0.13025 shift_factor = None @@ -527,7 +528,7 @@ class RemoteAutoencoderKLFluxSlowTests( unittest.TestCase, ): channels = 16 - endpoint = ENDPOINT_FLUX + endpoint = DECODE_ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py new file mode 100644 index 000000000000..28e71d0b141e --- /dev/null +++ b/tests/remote/test_remote_encode.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import PIL.Image +import torch + +from diffusers.utils import load_image +from diffusers.utils.remote_utils import ( + remote_decode, + remote_encode, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, +) + + +enable_full_determinism() + +IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true" + + +class RemoteAutoencoderKLEncodeMixin: + channels: int = None + endpoint: str = None + decode_endpoint: str = None + dtype: torch.dtype = None + scaling_factor: float = None + shift_factor: float = None + image: PIL.Image.Image = None + + def get_dummy_inputs(self): + if self.image is None: + self.image = load_image(IMAGE) + inputs = { + "endpoint": self.endpoint, + "image": self.image, + "scaling_factor": self.scaling_factor, + "shift_factor": self.shift_factor, + } + return inputs + + def test_image_input(self): + inputs = self.get_dummy_inputs() + height, width = inputs["image"].height, inputs["image"].width + output = remote_encode(**inputs) + self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8]) + decoded = remote_decode( + tensor=output, + endpoint=self.decode_endpoint, + scaling_factor=self.scaling_factor, + shift_factor=self.shift_factor, + image_format="png", + ) + self.assertEqual(decoded.height, height) + self.assertEqual(decoded.width, width) + # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten()) + # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten()) + # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent? + + +# class RemoteAutoencoderKLSDv1Tests( +# RemoteAutoencoderKLEncodeMixin, +# unittest.TestCase, +# ): +# channels = 4 +# endpoint = ENCODE_ENDPOINT_SD_V1 +# decode_endpoint = DECODE_ENDPOINT_SD_V1 +# dtype = torch.float16 +# scaling_factor = 0.18215 +# shift_factor = None + + +# class RemoteAutoencoderKLSDXLTests( +# RemoteAutoencoderKLEncodeMixin, +# unittest.TestCase, +# ): +# channels = 4 +# endpoint = ENCODE_ENDPOINT_SD_XL +# decode_endpoint = DECODE_ENDPOINT_SD_XL +# dtype = torch.float16 +# scaling_factor = 0.13025 +# shift_factor = None + + +# class RemoteAutoencoderKLFluxTests( +# RemoteAutoencoderKLEncodeMixin, +# unittest.TestCase, +# ): +# channels = 16 +# endpoint = DECODE_ENDPOINT_FLUX +# decode_endpoint = ENCODE_ENDPOINT_FLUX +# dtype = torch.bfloat16 +# scaling_factor = 0.3611 +# shift_factor = 0.1159 + + +class RemoteAutoencoderKLEncodeSlowTestMixin: + channels: int = 4 + endpoint: str = None + decode_endpoint: str = None + dtype: torch.dtype = None + scaling_factor: float = None + shift_factor: float = None + image: PIL.Image.Image = None + + def get_dummy_inputs(self): + if self.image is None: + self.image = load_image(IMAGE) + inputs = { + "endpoint": self.endpoint, + "image": self.image, + "scaling_factor": self.scaling_factor, + "shift_factor": self.shift_factor, + } + return inputs + + def test_multi_res(self): + inputs = self.get_dummy_inputs() + for height in { + 320, + 512, + 640, + 704, + 896, + 1024, + 1208, + 1384, + 1536, + 1608, + 1864, + 2048, + }: + for width in { + 320, + 512, + 640, + 704, + 896, + 1024, + 1208, + 1384, + 1536, + 1608, + 1864, + 2048, + }: + inputs["image"] = inputs["image"].resize( + ( + width, + height, + ) + ) + output = remote_encode(**inputs) + self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8]) + decoded = remote_decode( + tensor=output, + endpoint=self.decode_endpoint, + scaling_factor=self.scaling_factor, + shift_factor=self.shift_factor, + image_format="png", + ) + self.assertEqual(decoded.height, height) + self.assertEqual(decoded.width, width) + decoded.save(f"test_multi_res_{height}_{width}.png") + + +# @slow +# class RemoteAutoencoderKLSDv1SlowTests( +# RemoteAutoencoderKLEncodeSlowTestMixin, +# unittest.TestCase, +# ): +# endpoint = ENCODE_ENDPOINT_SD_V1 +# decode_endpoint = DECODE_ENDPOINT_SD_V1 +# dtype = torch.float16 +# scaling_factor = 0.18215 +# shift_factor = None + + +# @slow +# class RemoteAutoencoderKLSDXLSlowTests( +# RemoteAutoencoderKLEncodeSlowTestMixin, +# unittest.TestCase, +# ): +# endpoint = ENCODE_ENDPOINT_SD_XL +# decode_endpoint = DECODE_ENDPOINT_SD_XL +# dtype = torch.float16 +# scaling_factor = 0.13025 +# shift_factor = None + + +# @slow +# class RemoteAutoencoderKLFluxSlowTests( +# RemoteAutoencoderKLEncodeSlowTestMixin, +# unittest.TestCase, +# ): +# channels = 16 +# endpoint = ENCODE_ENDPOINT_FLUX +# decode_endpoint = DECODE_ENDPOINT_FLUX +# dtype = torch.bfloat16 +# scaling_factor = 0.3611 +# shift_factor = 0.1159 From 140e0c21efc664157b9e244de153a01397b95eb1 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Mar 2025 10:25:18 +0000 Subject: [PATCH 02/10] _toctree: add vae encode --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8811fca5f5a2..d1805ff605d8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -81,6 +81,8 @@ title: Overview - local: hybrid_inference/vae_decode title: VAE Decode + - local: hybrid_inference/vae_encode + title: VAE Encode - local: hybrid_inference/api_reference title: API Reference title: Hybrid Inference From e70bdb2234d5acaf40b7712e4df7c49a2b83bcff Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Mar 2025 12:46:56 +0000 Subject: [PATCH 03/10] Add endpoints, tests --- src/diffusers/utils/remote_utils.py | 6 +- tests/remote/test_remote_encode.py | 146 +++++++++++++++------------- 2 files changed, 80 insertions(+), 72 deletions(-) diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py index 68cd39f55265..bee641ce5592 100644 --- a/src/diffusers/utils/remote_utils.py +++ b/src/diffusers/utils/remote_utils.py @@ -49,9 +49,9 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" -ENCODE_ENDPOINT_SD_V1 = "" -ENCODE_ENDPOINT_SD_XL = "" -ENCODE_ENDPOINT_FLUX = "" +ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" +ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" +ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" def detect_image_type(data: bytes) -> str: diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py index 28e71d0b141e..39bea125337e 100644 --- a/tests/remote/test_remote_encode.py +++ b/tests/remote/test_remote_encode.py @@ -13,17 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest import PIL.Image import torch from diffusers.utils import load_image from diffusers.utils.remote_utils import ( + DECODE_ENDPOINT_FLUX, + DECODE_ENDPOINT_SD_V1, + DECODE_ENDPOINT_SD_XL, + ENCODE_ENDPOINT_FLUX, + ENCODE_ENDPOINT_SD_V1, + ENCODE_ENDPOINT_SD_XL, remote_decode, remote_encode, ) from diffusers.utils.testing_utils import ( enable_full_determinism, + slow, ) @@ -71,40 +79,40 @@ def test_image_input(self): # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent? -# class RemoteAutoencoderKLSDv1Tests( -# RemoteAutoencoderKLEncodeMixin, -# unittest.TestCase, -# ): -# channels = 4 -# endpoint = ENCODE_ENDPOINT_SD_V1 -# decode_endpoint = DECODE_ENDPOINT_SD_V1 -# dtype = torch.float16 -# scaling_factor = 0.18215 -# shift_factor = None - - -# class RemoteAutoencoderKLSDXLTests( -# RemoteAutoencoderKLEncodeMixin, -# unittest.TestCase, -# ): -# channels = 4 -# endpoint = ENCODE_ENDPOINT_SD_XL -# decode_endpoint = DECODE_ENDPOINT_SD_XL -# dtype = torch.float16 -# scaling_factor = 0.13025 -# shift_factor = None - - -# class RemoteAutoencoderKLFluxTests( -# RemoteAutoencoderKLEncodeMixin, -# unittest.TestCase, -# ): -# channels = 16 -# endpoint = DECODE_ENDPOINT_FLUX -# decode_endpoint = ENCODE_ENDPOINT_FLUX -# dtype = torch.bfloat16 -# scaling_factor = 0.3611 -# shift_factor = 0.1159 +class RemoteAutoencoderKLSDv1Tests( + RemoteAutoencoderKLEncodeMixin, + unittest.TestCase, +): + channels = 4 + endpoint = ENCODE_ENDPOINT_SD_V1 + decode_endpoint = DECODE_ENDPOINT_SD_V1 + dtype = torch.float16 + scaling_factor = 0.18215 + shift_factor = None + + +class RemoteAutoencoderKLSDXLTests( + RemoteAutoencoderKLEncodeMixin, + unittest.TestCase, +): + channels = 4 + endpoint = ENCODE_ENDPOINT_SD_XL + decode_endpoint = DECODE_ENDPOINT_SD_XL + dtype = torch.float16 + scaling_factor = 0.13025 + shift_factor = None + + +class RemoteAutoencoderKLFluxTests( + RemoteAutoencoderKLEncodeMixin, + unittest.TestCase, +): + channels = 16 + endpoint = ENCODE_ENDPOINT_FLUX + decode_endpoint = DECODE_ENDPOINT_FLUX + dtype = torch.bfloat16 + scaling_factor = 0.3611 + shift_factor = 0.1159 class RemoteAutoencoderKLEncodeSlowTestMixin: @@ -177,38 +185,38 @@ def test_multi_res(self): decoded.save(f"test_multi_res_{height}_{width}.png") -# @slow -# class RemoteAutoencoderKLSDv1SlowTests( -# RemoteAutoencoderKLEncodeSlowTestMixin, -# unittest.TestCase, -# ): -# endpoint = ENCODE_ENDPOINT_SD_V1 -# decode_endpoint = DECODE_ENDPOINT_SD_V1 -# dtype = torch.float16 -# scaling_factor = 0.18215 -# shift_factor = None - - -# @slow -# class RemoteAutoencoderKLSDXLSlowTests( -# RemoteAutoencoderKLEncodeSlowTestMixin, -# unittest.TestCase, -# ): -# endpoint = ENCODE_ENDPOINT_SD_XL -# decode_endpoint = DECODE_ENDPOINT_SD_XL -# dtype = torch.float16 -# scaling_factor = 0.13025 -# shift_factor = None - - -# @slow -# class RemoteAutoencoderKLFluxSlowTests( -# RemoteAutoencoderKLEncodeSlowTestMixin, -# unittest.TestCase, -# ): -# channels = 16 -# endpoint = ENCODE_ENDPOINT_FLUX -# decode_endpoint = DECODE_ENDPOINT_FLUX -# dtype = torch.bfloat16 -# scaling_factor = 0.3611 -# shift_factor = 0.1159 +@slow +class RemoteAutoencoderKLSDv1SlowTests( + RemoteAutoencoderKLEncodeSlowTestMixin, + unittest.TestCase, +): + endpoint = ENCODE_ENDPOINT_SD_V1 + decode_endpoint = DECODE_ENDPOINT_SD_V1 + dtype = torch.float16 + scaling_factor = 0.18215 + shift_factor = None + + +@slow +class RemoteAutoencoderKLSDXLSlowTests( + RemoteAutoencoderKLEncodeSlowTestMixin, + unittest.TestCase, +): + endpoint = ENCODE_ENDPOINT_SD_XL + decode_endpoint = DECODE_ENDPOINT_SD_XL + dtype = torch.float16 + scaling_factor = 0.13025 + shift_factor = None + + +@slow +class RemoteAutoencoderKLFluxSlowTests( + RemoteAutoencoderKLEncodeSlowTestMixin, + unittest.TestCase, +): + channels = 16 + endpoint = ENCODE_ENDPOINT_FLUX + decode_endpoint = DECODE_ENDPOINT_FLUX + dtype = torch.bfloat16 + scaling_factor = 0.3611 + shift_factor = 0.1159 From e5448f2dd7dc5f5f1761e446415d2e994a58aee4 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Mar 2025 12:47:05 +0000 Subject: [PATCH 04/10] vae_encode docs --- docs/source/en/hybrid_inference/vae_encode.md | 258 +++--------------- 1 file changed, 40 insertions(+), 218 deletions(-) diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md index d277f260f450..844f96d4af0e 100644 --- a/docs/source/en/hybrid_inference/vae_encode.md +++ b/docs/source/en/hybrid_inference/vae_encode.md @@ -6,7 +6,7 @@ VAE encode is used for training, image-to-image and image-to-video - turning int These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs. -For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality. +For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
SD v1.5 @@ -28,9 +28,9 @@ TODO | | **Endpoint** | **Model** | |:-:|:-----------:|:--------:| -| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | -| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | -| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | +| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | +| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | +| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | > [!TIP] @@ -51,269 +51,91 @@ from diffusers.utils.remote_utils import remote_encode ### Basic example -Here, we show how to use the remote VAE on random tensors. - -
Code - -```python -image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16), - scaling_factor=0.18215, -) -``` - -
+Let's encode an image, then decode it to demonstrate.
- +
-Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`. -
Code ```python -image = remote_decode( - endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=torch.randn([1, 4096, 64], dtype=torch.float16), - height=1024, - width=1024, +from diffusers.utils import load_image +from diffusers.utils.remote_utils import remote_decode + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true") + +latent = remote_encode( + endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/", scaling_factor=0.3611, shift_factor=0.1159, ) -``` -
- -
- -
- -Finally, an example for HunyuanVideo. - -
Code - -```python -video = remote_decode( - endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16), - output_type="mp4", +decoded = remote_decode( + endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + scaling_factor=0.3611, + shift_factor=0.1159, ) -with open("video.mp4", "wb") as f: - f.write(video) ```
- +
### Generation -But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5. +Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
Code ```python -from diffusers import StableDiffusionPipeline +import torch +from diffusers import StableDiffusionImg2ImgPipeline +from diffusers.utils import load_image +from diffusers.utils.remote_utils import remote_decode, remote_encode -pipe = StableDiffusionPipeline.from_pretrained( +pipe = StableDiffusionImg2ImgPipeline.from_pretrained( "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", vae=None, ).to("cuda") -prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" +init_image = load_image( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +) +init_image = init_image.resize((768, 512)) -latent = pipe( - prompt=prompt, - output_type="latent", -).images -image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, +init_latent = remote_encode( + endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/", + image=init_image, scaling_factor=0.18215, ) -image.save("test.jpg") -``` - -
- -
- -
- -Here’s another example with Flux. - -
Code - -```python -from diffusers import FluxPipeline - -pipe = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-schnell", - torch_dtype=torch.bfloat16, - vae=None, -).to("cuda") - -prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" +prompt = "A fantasy landscape, trending on artstation" latent = pipe( prompt=prompt, - guidance_scale=0.0, - num_inference_steps=4, + image=init_latent, + strength=0.75, output_type="latent", ).images -image = remote_decode( - endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - height=1024, - width=1024, - scaling_factor=0.3611, - shift_factor=0.1159, -) -image.save("test.jpg") -``` - -
- -
- -
-Here’s an example with HunyuanVideo. - -
Code - -```python -from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel - -model_id = "hunyuanvideo-community/HunyuanVideo" -transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=torch.bfloat16 -) -pipe = HunyuanVideoPipeline.from_pretrained( - model_id, transformer=transformer, vae=None, torch_dtype=torch.float16 -).to("cuda") - -latent = pipe( - prompt="A cat walks on the grass, realistic", - height=320, - width=512, - num_frames=61, - num_inference_steps=30, - output_type="latent", -).frames - -video = remote_decode( - endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", +image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", tensor=latent, - output_type="mp4", -) - -if isinstance(video, bytes): - with open("video.mp4", "wb") as f: - f.write(video) -``` - -
- -
- -
- - -### Queueing - -One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency. - - -
Code - -```python -import queue -import threading -from IPython.display import display -from diffusers import StableDiffusionPipeline - -def decode_worker(q: queue.Queue): - while True: - item = q.get() - if item is None: - break - image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=item, - scaling_factor=0.18215, - ) - display(image) - q.task_done() - -q = queue.Queue() -thread = threading.Thread(target=decode_worker, args=(q,), daemon=True) -thread.start() - -def decode(latent: torch.Tensor): - q.put(latent) - -prompts = [ - "Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious", - "Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore", - "Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.", - "Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP", - "A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting", - "Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,", -] - -pipe = StableDiffusionPipeline.from_pretrained( - "Lykon/dreamshaper-8", - torch_dtype=torch.float16, - vae=None, -).to("cuda") - -pipe.unet = pipe.unet.to(memory_format=torch.channels_last) -pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - -_ = pipe( - prompt=prompts[0], - output_type="latent", + scaling_factor=0.18215, ) - -for prompt in prompts: - latent = pipe( - prompt=prompt, - output_type="latent", - ).images - decode(latent) - -q.put(None) -thread.join() +image.save("fantasy_landscape.jpg") ```
-
- +
## Integrations From 15914a9967e13f863fb29f4c6b771674f67719b5 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Mar 2025 14:51:49 +0000 Subject: [PATCH 05/10] vae encode benchmarks --- docs/source/en/hybrid_inference/vae_encode.md | 51 ++++++++++++++++--- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md index 844f96d4af0e..dd285fa25c03 100644 --- a/docs/source/en/hybrid_inference/vae_encode.md +++ b/docs/source/en/hybrid_inference/vae_encode.md @@ -10,17 +10,56 @@ For the majority of these GPUs the memory usage % dictates other models (text en
SD v1.5 -| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | -| --- | --- | --- | --- | --- | --- | -TODO +| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | +|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:| +| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 | +| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 | +| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 | +| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 | +| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 | +| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 | +| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 | +| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 | +| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 | +| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 | +| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 | +| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 | +| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 | +
SDXL -| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | -| --- | --- | --- | --- | --- | --- | -TODO +| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | +|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:| +| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 | +| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 | +| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 | +| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 | +| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 | +| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 | +| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 | +| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 | +| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 | +| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 | +| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 | +| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 | +| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
From 0a2231a7a3e3c41a8c98c6b33ee4b90ca787a494 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Mar 2025 14:54:00 +0000 Subject: [PATCH 06/10] api reference --- docs/source/en/hybrid_inference/api_reference.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md index aa0a5e5ae58f..865aaba5ebb6 100644 --- a/docs/source/en/hybrid_inference/api_reference.md +++ b/docs/source/en/hybrid_inference/api_reference.md @@ -3,3 +3,7 @@ ## Remote Decode [[autodoc]] utils.remote_utils.remote_decode + +## Remote Encode + +[[autodoc]] utils.remote_utils.remote_encode From 0f5705bd40a9cb818b39e8b62d6d6380eb1c9697 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Mar 2025 14:55:39 +0000 Subject: [PATCH 07/10] changelog --- docs/source/en/hybrid_inference/overview.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md index 53acfcf6c34f..498e22fad50e 100644 --- a/docs/source/en/hybrid_inference/overview.md +++ b/docs/source/en/hybrid_inference/overview.md @@ -46,6 +46,11 @@ Hybrid Inference offers a fast and simple way to offload local generation requir * **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. * **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. +## Changelog + +- March 10 2025: Added VAE encode +- March 2 2025: Initial release. + ## Contents The documentation is organized into three sections: From c6ac397ebacfda65f2ede6a8f0291382613a6467 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 11 Mar 2025 06:13:56 +0000 Subject: [PATCH 08/10] Update docs/source/en/hybrid_inference/overview.md Co-authored-by: Sayak Paul --- docs/source/en/hybrid_inference/overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md index 498e22fad50e..b44393c77cbd 100644 --- a/docs/source/en/hybrid_inference/overview.md +++ b/docs/source/en/hybrid_inference/overview.md @@ -49,7 +49,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir ## Changelog - March 10 2025: Added VAE encode -- March 2 2025: Initial release. +- March 2 2025: Initial release with VAE decoding ## Contents From abb3e3bde8674873cd8029bc65b14ddbafc92e68 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 11 Mar 2025 08:23:30 +0000 Subject: [PATCH 09/10] update --- src/diffusers/utils/constants.py | 11 +++++++++++ src/diffusers/utils/remote_utils.py | 11 ----------- tests/remote/test_remote_decode.py | 4 +++- tests/remote/test_remote_encode.py | 4 +++- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 3f88f347710f..fa12318f4714 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -56,3 +56,14 @@ if USE_PEFT_BACKEND and _CHECK_PEFT: dep_version_check("peft") + + +DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" + + +ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" +ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" +ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py index bee641ce5592..fbce33d97f54 100644 --- a/src/diffusers/utils/remote_utils.py +++ b/src/diffusers/utils/remote_utils.py @@ -43,17 +43,6 @@ from PIL import Image -DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" -DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" -DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" -DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" - - -ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" -ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" -ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" - - def detect_image_type(data: bytes) -> str: if data.startswith(b"\xff\xd8"): return "jpeg" diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py index 7308b1121991..cec96e729a48 100644 --- a/tests/remote/test_remote_decode.py +++ b/tests/remote/test_remote_decode.py @@ -21,11 +21,13 @@ import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.utils.remote_utils import ( +from diffusers.utils.constants import ( DECODE_ENDPOINT_FLUX, DECODE_ENDPOINT_HUNYUAN_VIDEO, DECODE_ENDPOINT_SD_V1, DECODE_ENDPOINT_SD_XL, +) +from diffusers.utils.remote_utils import ( remote_decode, ) from diffusers.utils.testing_utils import ( diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py index 39bea125337e..62ed97ee8f49 100644 --- a/tests/remote/test_remote_encode.py +++ b/tests/remote/test_remote_encode.py @@ -19,13 +19,15 @@ import torch from diffusers.utils import load_image -from diffusers.utils.remote_utils import ( +from diffusers.utils.constants import ( DECODE_ENDPOINT_FLUX, DECODE_ENDPOINT_SD_V1, DECODE_ENDPOINT_SD_XL, ENCODE_ENDPOINT_FLUX, ENCODE_ENDPOINT_SD_V1, ENCODE_ENDPOINT_SD_XL, +) +from diffusers.utils.remote_utils import ( remote_decode, remote_encode, ) From 79620bf8f5eddcec827c946830ee298f336a1533 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 11 Mar 2025 13:34:38 +0000 Subject: [PATCH 10/10] =?UTF-8?q?[hybrid=20inference=20=F0=9F=8D=AF?= =?UTF-8?q?=F0=9F=90=9D]=20Wan=202.1=20decode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/en/hybrid_inference/overview.md | 1 + docs/source/en/hybrid_inference/vae_decode.md | 1 + src/diffusers/utils/constants.py | 2 +- src/diffusers/utils/remote_utils.py | 7 --- tests/remote/test_remote_decode.py | 45 ++++++++++++------- 5 files changed, 33 insertions(+), 23 deletions(-) diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md index b44393c77cbd..6a5c3313070a 100644 --- a/docs/source/en/hybrid_inference/overview.md +++ b/docs/source/en/hybrid_inference/overview.md @@ -48,6 +48,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir ## Changelog +- March 11 2025: Added Wan 2.1 VAE decode - March 10 2025: Added VAE encode - March 2 2025: Initial release with VAE decoding diff --git a/docs/source/en/hybrid_inference/vae_decode.md b/docs/source/en/hybrid_inference/vae_decode.md index 1457090550c7..5320608ddc2d 100644 --- a/docs/source/en/hybrid_inference/vae_decode.md +++ b/docs/source/en/hybrid_inference/vae_decode.md @@ -54,6 +54,7 @@ For the majority of these GPUs the memory usage % dictates other models (text en | **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | | **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | | **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) | +| **Wan2.1** | [https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud](https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud) | [`Wan-AI/Wan2.1-T2V-1.3B-Diffusers`](https://hf.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) | > [!TIP] diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index fa12318f4714..638678ef78d7 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -62,7 +62,7 @@ DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" - +DECODE_ENDPOINT_WAN_2_1 = "https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py index fbce33d97f54..2df7a19f68b4 100644 --- a/src/diffusers/utils/remote_utils.py +++ b/src/diffusers/utils/remote_utils.py @@ -80,13 +80,6 @@ def check_inputs_decode( and not isinstance(processor, (VaeImageProcessor, VideoProcessor)) ): raise ValueError("`processor` is required.") - if do_scaling and scaling_factor is None: - deprecate( - "do_scaling", - "1.0.0", - "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.", - standard_warn=False, - ) def postprocess_decode( diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py index cec96e729a48..e1f3435d33ef 100644 --- a/tests/remote/test_remote_decode.py +++ b/tests/remote/test_remote_decode.py @@ -26,6 +26,7 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO, DECODE_ENDPOINT_SD_V1, DECODE_ENDPOINT_SD_XL, + DECODE_ENDPOINT_WAN_2_1, ) from diffusers.utils.remote_utils import ( remote_decode, @@ -176,18 +177,6 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self): f"{output_slice}", ) - def test_do_scaling_deprecation(self): - inputs = self.get_dummy_inputs() - inputs.pop("scaling_factor", None) - inputs.pop("shift_factor", None) - with self.assertWarns(FutureWarning) as warning: - _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs) - self.assertEqual( - str(warning.warnings[0].message), - "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.", - str(warning.warnings[0].message), - ) - def test_input_tensor_type_base64_deprecation(self): inputs = self.get_dummy_inputs() with self.assertWarns(FutureWarning) as warning: @@ -209,7 +198,7 @@ def test_output_tensor_type_base64_deprecation(self): ) -class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin): +class RemoteAutoencoderKLVideoMixin(RemoteAutoencoderKLMixin): def test_no_scaling(self): inputs = self.get_dummy_inputs() if inputs["scaling_factor"] is not None: @@ -221,7 +210,6 @@ def test_no_scaling(self): processor = self.processor_cls() output = remote_decode( output_type="pt", - # required for now, will be removed in next update do_scaling=False, processor=processor, **inputs, @@ -337,6 +325,8 @@ def test_output_type_mp4(self): inputs = self.get_dummy_inputs() output = remote_decode(output_type="mp4", return_type="mp4", **inputs) self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}") + with open("test.mp4", "wb") as f: + f.write(output) class RemoteAutoencoderKLSDv1Tests( @@ -442,7 +432,7 @@ class RemoteAutoencoderKLFluxPackedTests( class RemoteAutoencoderKLHunyuanVideoTests( - RemoteAutoencoderKLHunyuanVideoMixin, + RemoteAutoencoderKLVideoMixin, unittest.TestCase, ): shape = ( @@ -467,6 +457,31 @@ class RemoteAutoencoderKLHunyuanVideoTests( return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708]) +class RemoteAutoencoderKLWanTests( + RemoteAutoencoderKLVideoMixin, + unittest.TestCase, +): + shape = ( + 1, + 16, + 3, + 40, + 64, + ) + out_hw = ( + 320, + 512, + ) + endpoint = DECODE_ENDPOINT_WAN_2_1 + dtype = torch.float16 + processor_cls = VideoProcessor + output_pt_slice = torch.tensor([203, 174, 178, 204, 171, 177, 209, 183, 182], dtype=torch.uint8) + partial_postprocess_return_pt_slice = torch.tensor( + [206, 209, 221, 202, 199, 222, 207, 210, 217], dtype=torch.uint8 + ) + return_pt_slice = torch.tensor([0.6196, 0.6382, 0.7310, 0.5869, 0.5625, 0.7373, 0.6240, 0.6465, 0.7002]) + + class RemoteAutoencoderKLSlowTestMixin: channels: int = 4 endpoint: str = None