Skip to content

[WIP] Modular Diffusers support custom code/pipeline blocks #11539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: modular-refactor
Choose a base branch
from

Conversation

DN6
Copy link
Collaborator

@DN6 DN6 commented May 12, 2025

What does this PR do?

Add support for loading custom pipeline blocks with Modular Diffusers. PR is still in very rough shape, but is functional.

Snippet to test

from diffusers.pipelines.modular_pipeline import PipelineBlock

block = PipelineBlock.from_pretrained(
        "diffusers-internal-dev/modular-depth-block", trust_remote_code=True
)

Note I think the formatting changes might have been because of a difference in ruff versions.

TODOs:

  • Possibly move logic to fetch custom modules into AutoPipelineBlock rather than have it in ModularPipelineMixin
  • Clean up custom code fetching in get_class_from_dynamic_module. Probably don't need the is_modular argument.
  • Remove get_class_in_modular_module
  • trust_remote_code=None is currently broken. Need to fix.
  • Add logic to save custom pipeline blocks
  • Add support for local custom code.
  • Add tests

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Dhruv!

I couldn't run the code as there seems to be some problems with undefined variables. I can give the saving logic a look after that is fixed.

@@ -154,15 +157,132 @@ def check_imports(filename):
return get_relative_imports(filename)


def get_class_in_module(class_name, module_path):
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit): Would prefer this to be a private method.

Comment on lines +161 to +168
if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variables like signal seem to be undefined. After-effects of merge-conflict resolves?

Comment on lines +171 to +180
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's just about passing trust_remote_code=Yes, would it be too much to just enforce that and avoid signal altogether?

@@ -37,6 +39,7 @@

# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
_HF_REMOTE_CODE_LOCK = threading.Lock()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to acquire a lock?

Comment on lines +237 to +239
# Hash the module file and all its relative imports to check if we need to reload it
module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, if the imports are too many could it cause any side-effects? But I guess since it's custom blocks, the users have some form of awareness already. But just wanted to flag.

Comment on lines +248 to +252
# reload in both cases, unless the module is already imported and the hash hits
if getattr(module, "__transformers_module_hash__", "") != module_hash:
module_spec.loader.exec_module(module)
module.__transformers_module_hash__ = module_hash

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it only transformers?

"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
name = os.path.normpath(module_path)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any common bits shared by get_class_in_module and get_class_in_modular_module that we could wrap in a method?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants