-
Notifications
You must be signed in to change notification settings - Fork 6k
[hybrid inference 🍯🐝] Wan 2.1 decode #11031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
081e68f
140e0c2
e70bdb2
e5448f2
15914a9
0a2231a
0f5705b
998c3c6
b2756ad
c6ac397
abb3e3b
73adcd8
79620bf
a1eacb3
3f44fa1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
DECODE_ENDPOINT_HUNYUAN_VIDEO, | ||
DECODE_ENDPOINT_SD_V1, | ||
DECODE_ENDPOINT_SD_XL, | ||
DECODE_ENDPOINT_WAN_2_1, | ||
) | ||
from diffusers.utils.remote_utils import ( | ||
remote_decode, | ||
|
@@ -176,18 +177,6 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self): | |
f"{output_slice}", | ||
) | ||
|
||
def test_do_scaling_deprecation(self): | ||
inputs = self.get_dummy_inputs() | ||
inputs.pop("scaling_factor", None) | ||
inputs.pop("shift_factor", None) | ||
with self.assertWarns(FutureWarning) as warning: | ||
_ = remote_decode(output_type="pt", partial_postprocess=True, **inputs) | ||
self.assertEqual( | ||
str(warning.warnings[0].message), | ||
"`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.", | ||
str(warning.warnings[0].message), | ||
) | ||
|
||
def test_input_tensor_type_base64_deprecation(self): | ||
inputs = self.get_dummy_inputs() | ||
with self.assertWarns(FutureWarning) as warning: | ||
|
@@ -209,7 +198,7 @@ def test_output_tensor_type_base64_deprecation(self): | |
) | ||
|
||
|
||
class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin): | ||
class RemoteAutoencoderKLVideoMixin(RemoteAutoencoderKLMixin): | ||
def test_no_scaling(self): | ||
inputs = self.get_dummy_inputs() | ||
if inputs["scaling_factor"] is not None: | ||
|
@@ -221,7 +210,6 @@ def test_no_scaling(self): | |
processor = self.processor_cls() | ||
output = remote_decode( | ||
output_type="pt", | ||
# required for now, will be removed in next update | ||
do_scaling=False, | ||
processor=processor, | ||
**inputs, | ||
|
@@ -337,6 +325,8 @@ def test_output_type_mp4(self): | |
inputs = self.get_dummy_inputs() | ||
output = remote_decode(output_type="mp4", return_type="mp4", **inputs) | ||
self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}") | ||
with open("test.mp4", "wb") as f: | ||
f.write(output) | ||
|
||
|
||
class RemoteAutoencoderKLSDv1Tests( | ||
|
@@ -442,7 +432,7 @@ class RemoteAutoencoderKLFluxPackedTests( | |
|
||
|
||
class RemoteAutoencoderKLHunyuanVideoTests( | ||
RemoteAutoencoderKLHunyuanVideoMixin, | ||
RemoteAutoencoderKLVideoMixin, | ||
unittest.TestCase, | ||
): | ||
shape = ( | ||
|
@@ -467,6 +457,31 @@ class RemoteAutoencoderKLHunyuanVideoTests( | |
return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708]) | ||
|
||
|
||
class RemoteAutoencoderKLWanTests( | ||
RemoteAutoencoderKLVideoMixin, | ||
unittest.TestCase, | ||
): | ||
shape = ( | ||
1, | ||
16, | ||
3, | ||
40, | ||
64, | ||
) | ||
out_hw = ( | ||
320, | ||
512, | ||
) | ||
endpoint = DECODE_ENDPOINT_WAN_2_1 | ||
dtype = torch.float16 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @yiyixuxu Currently the endpoint is running in float16 and output seems ok (on random latent), the examples use float32 but we noticed in the original code that everything is under bfloat16 autocast context. Can we check with the authors regarding the use of float32 for VAE? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i haven't heard back from them, |
||
processor_cls = VideoProcessor | ||
output_pt_slice = torch.tensor([203, 174, 178, 204, 171, 177, 209, 183, 182], dtype=torch.uint8) | ||
partial_postprocess_return_pt_slice = torch.tensor( | ||
[206, 209, 221, 202, 199, 222, 207, 210, 217], dtype=torch.uint8 | ||
) | ||
return_pt_slice = torch.tensor([0.6196, 0.6382, 0.7310, 0.5869, 0.5625, 0.7373, 0.6240, 0.6465, 0.7002]) | ||
|
||
|
||
class RemoteAutoencoderKLSlowTestMixin: | ||
channels: int = 4 | ||
endpoint: str = None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
The Wan VAE doesn't have a `scaling_factor` that we could pass, so this deprecation/removal of `do_scaling` doesn't work for it; we will keep `do_scaling`.