
[hybrid inference 🍯🐝] Wan 2.1 decode #11031


Closed
wants to merge 15 commits
1 change: 1 addition & 0 deletions docs/source/en/hybrid_inference/overview.md
@@ -48,6 +48,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir

## Changelog

- March 11 2025: Added Wan 2.1 VAE decode
- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding

1 change: 1 addition & 0 deletions docs/source/en/hybrid_inference/vae_decode.md
@@ -54,6 +54,7 @@ For the majority of these GPUs the memory usage % dictates other models (text en
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
| **Wan2.1** | [https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud](https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud) | [`Wan-AI/Wan2.1-T2V-1.3B-Diffusers`](https://hf.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) |
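
For context, a minimal sketch of exercising the new Wan2.1 endpoint through `remote_decode` (the latent shape and the `output_type="mp4"` path mirror the tests added in this PR; the latent here is a random placeholder, where a real workflow would pass latents produced by a Wan pipeline):

```python
import torch
from diffusers.utils.remote_utils import remote_decode

# Random stand-in latent in the Wan2.1 layout used by the tests:
# (batch, channels, frames, latent_height, latent_width).
latent = torch.randn(1, 16, 3, 40, 64, dtype=torch.float16)

# Wan's VAE config has no `scaling_factor` (see the discussion below),
# so no scaling arguments are passed here.
video = remote_decode(
    endpoint="https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    output_type="mp4",
    return_type="mp4",
)
with open("output.mp4", "wb") as f:
    f.write(video)
```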


> [!TIP]
2 changes: 1 addition & 1 deletion src/diffusers/utils/constants.py
@@ -62,7 +62,7 @@
DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"

DECODE_ENDPOINT_WAN_2_1 = "https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud/"

ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
7 changes: 0 additions & 7 deletions src/diffusers/utils/remote_utils.py
@@ -80,13 +80,6 @@ def check_inputs_decode(
and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
):
raise ValueError("`processor` is required.")
if do_scaling and scaling_factor is None:
deprecate(
"do_scaling",
"1.0.0",
"`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
standard_warn=False,
)


def postprocess_decode(
45 changes: 30 additions & 15 deletions tests/remote/test_remote_decode.py
@@ -26,6 +26,7 @@
DECODE_ENDPOINT_HUNYUAN_VIDEO,
DECODE_ENDPOINT_SD_V1,
DECODE_ENDPOINT_SD_XL,
DECODE_ENDPOINT_WAN_2_1,
)
from diffusers.utils.remote_utils import (
remote_decode,
@@ -176,18 +177,6 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self):
f"{output_slice}",
)

def test_do_scaling_deprecation(self):
Contributor Author:
Wan VAE doesn't have a `scaling_factor` that we could pass, so this deprecation/removal of `do_scaling` doesn't work; we will keep it.

inputs = self.get_dummy_inputs()
inputs.pop("scaling_factor", None)
inputs.pop("shift_factor", None)
with self.assertWarns(FutureWarning) as warning:
_ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
self.assertEqual(
str(warning.warnings[0].message),
"`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
str(warning.warnings[0].message),
)

def test_input_tensor_type_base64_deprecation(self):
inputs = self.get_dummy_inputs()
with self.assertWarns(FutureWarning) as warning:
@@ -209,7 +198,7 @@ def test_output_tensor_type_base64_deprecation(self):
)


class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin):
class RemoteAutoencoderKLVideoMixin(RemoteAutoencoderKLMixin):
def test_no_scaling(self):
inputs = self.get_dummy_inputs()
if inputs["scaling_factor"] is not None:
@@ -221,7 +210,6 @@ def test_no_scaling(self):
processor = self.processor_cls()
output = remote_decode(
output_type="pt",
# required for now, will be removed in next update
do_scaling=False,
processor=processor,
**inputs,
@@ -337,6 +325,8 @@ def test_output_type_mp4(self):
inputs = self.get_dummy_inputs()
output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")
with open("test.mp4", "wb") as f:
f.write(output)


class RemoteAutoencoderKLSDv1Tests(
@@ -442,7 +432,7 @@ class RemoteAutoencoderKLFluxPackedTests(


class RemoteAutoencoderKLHunyuanVideoTests(
RemoteAutoencoderKLHunyuanVideoMixin,
RemoteAutoencoderKLVideoMixin,
unittest.TestCase,
):
shape = (
@@ -467,6 +457,31 @@ class RemoteAutoencoderKLHunyuanVideoTests(
return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])


class RemoteAutoencoderKLWanTests(
RemoteAutoencoderKLVideoMixin,
unittest.TestCase,
):
shape = (
1,
16,
3,
40,
64,
)
out_hw = (
320,
512,
)
endpoint = DECODE_ENDPOINT_WAN_2_1
dtype = torch.float16
Contributor Author:
cc @yiyixuxu Currently the endpoint is running in float16 and the output seems OK (on a random latent). The examples use float32, but we noticed that in the original code everything runs under a bfloat16 autocast context. Can we check with the authors regarding the use of float32 for the VAE?

Collaborator:
I haven't heard back from them. Maybe let's generate some sample comparisons and ask @asomoza to help decide whether there is a quality difference.
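
One way to produce that comparison locally, as a sketch: decode the same latent with the local `AutoencoderKLWan` in both dtypes and inspect the drift (a random latent is used for illustration; frames from a real generation would be more representative, and a real latent would also need the `latents_mean`/`latents_std` denormalization the pipeline applies):

```python
import torch
from diffusers import AutoencoderKLWan

latent = torch.randn(1, 16, 3, 40, 64)

# Decode the same latent in float32 and float16 and measure the difference.
outputs = {}
for dtype in (torch.float32, torch.float16):
    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=dtype
    ).to("cuda")
    with torch.no_grad():
        outputs[dtype] = vae.decode(latent.to("cuda", dtype)).sample.float().cpu()

diff = (outputs[torch.float32] - outputs[torch.float16]).abs()
print(f"max abs diff: {diff.max():.4f}, mean abs diff: {diff.mean():.6f}")
```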

processor_cls = VideoProcessor
output_pt_slice = torch.tensor([203, 174, 178, 204, 171, 177, 209, 183, 182], dtype=torch.uint8)
partial_postprocess_return_pt_slice = torch.tensor(
[206, 209, 221, 202, 199, 222, 207, 210, 217], dtype=torch.uint8
)
return_pt_slice = torch.tensor([0.6196, 0.6382, 0.7310, 0.5869, 0.5625, 0.7373, 0.6240, 0.6465, 0.7002])
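
The expected output size follows from the Wan2.1 VAE compression factors (8x spatial, 4x temporal, assuming the same factors as the reference implementation); a hypothetical helper to check the arithmetic:

```python
# Hypothetical helper mirroring the shapes in RemoteAutoencoderKLWanTests.
def wan_output_shape(latent_shape, spatial=8, temporal=4):
    batch, channels, frames, height, width = latent_shape
    return ((frames - 1) * temporal + 1, height * spatial, width * spatial)

# A (1, 16, 3, 40, 64) latent decodes to 9 frames at 320x512, matching `out_hw`.
assert wan_output_shape((1, 16, 3, 40, 64)) == (9, 320, 512)
```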


class RemoteAutoencoderKLSlowTestMixin:
channels: int = 4
endpoint: str = None