From 2904639ea5567e44e92f36fd59e093e6eb057eee Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 12 May 2025 19:41:39 +0530 Subject: [PATCH] update --- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/components_manager.py | 154 +++-- src/diffusers/pipelines/modular_pipeline.py | 456 +++++++------- .../pipelines/modular_pipeline_utils.py | 157 ++--- .../pipelines/pipeline_loading_utils.py | 2 + src/diffusers/pipelines/pipeline_utils.py | 2 +- .../pipelines/stable_diffusion_xl/__init__.py | 2 +- .../pipeline_stable_diffusion_xl_modular.py | 572 +++++++++--------- src/diffusers/utils/dynamic_modules_utils.py | 154 ++++- 9 files changed, 831 insertions(+), 670 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0567eb687c62..4ef7faf54a4f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -703,12 +703,12 @@ from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_sag import StableDiffusionSAGPipeline from .stable_diffusion_xl import ( + StableDiffusionXLAutoPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLModularLoader, StableDiffusionXLPipeline, - StableDiffusionXLAutoPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( diff --git a/src/diffusers/pipelines/components_manager.py b/src/diffusers/pipelines/components_manager.py index bdff133e22d9..bd126ee2c7f5 100644 --- a/src/diffusers/pipelines/components_manager.py +++ b/src/diffusers/pipelines/components_manager.py @@ -12,21 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import time from collections import OrderedDict from itertools import combinations -from typing import List, Optional, Union, Dict, Any -import copy +from typing import Any, Dict, List, Optional, Union import torch -import time -from dataclasses import dataclass from ..utils import ( is_accelerate_available, logging, ) -from ..models.modeling_utils import ModelMixin -from .modular_pipeline_utils import ComponentSpec if is_accelerate_available(): @@ -231,17 +228,18 @@ def search_best_candidate(module_sizes, min_memory_offload): -from .modular_pipeline_utils import ComponentSpec import uuid + + class ComponentsManager: def __init__(self): self.components = OrderedDict() - self.added_time = OrderedDict() # Store when components were added + self.added_time = OrderedDict() # Store when components were added self.collections = OrderedDict() # collection_name -> set of component_names self.model_hooks = None self._auto_offload_enabled = False - + def _get_by_collection(self, collection: str): """ Select components by collection name. @@ -252,8 +250,8 @@ def _get_by_collection(self, collection: str): for component_id in component_ids: selected_components[component_id] = self.components[component_id] return selected_components - - + + def _get_by_load_id(self, load_id: str): """ Select components by its load_id. 
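For orientation, here is a minimal usage sketch of the ComponentsManager API introduced in the hunks above (add/remove, collections, and the lookup helpers shown further below). This is hypothetical, editor-added code against this PR branch, not part of the patch: the import path follows the file location, and the "unet"/"sdxl" names and the repo id are placeholders.

```py
import torch
from diffusers import UNet2DConditionModel
from diffusers.pipelines.components_manager import ComponentsManager  # path as in this PR

manager = ComponentsManager()

# Load a model the usual way, then register it; `add` returns the generated
# component id (base name plus a uuid suffix) and optionally files it under a collection.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
unet_id = manager.add("unet", unet, collection="sdxl")

# Lookup: exact base name via get_one, or a whole collection via get.
same_unet = manager.get_one("unet")               # raises if zero or multiple matches
sdxl_components = manager.get(collection="sdxl")  # dict of component_id -> component

# Removal is keyed by the component id returned from `add`.
manager.remove(unet_id)
```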
@@ -263,8 +261,8 @@ def _get_by_load_id(self, load_id: str): if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: selected_components[name] = component return selected_components - - + + def add(self, name, component, collection: Optional[str] = None): for comp_id, comp in self.components.items(): @@ -282,7 +280,7 @@ def add(self, name, component, collection: Optional[str] = None): f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. " f"To remove a duplicate, call `components_manager.remove('')`." ) - + # add component to components manager self.components[component_id] = component @@ -293,8 +291,8 @@ def add(self, name, component, collection: Optional[str] = None): self.collections[collection].add(component_id) if self._auto_offload_enabled: - self.enable_auto_cpu_offload(self._auto_offload_device) - + self.enable_auto_cpu_offload(self._auto_offload_device) + logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'") return component_id @@ -304,14 +302,14 @@ def remove(self, name: Union[str, List[str]]): if name not in self.components: logger.warning(f"Component '{name}' not found in ComponentsManager") return - + self.components.pop(name) self.added_time.pop(name) for collection in self.collections: if name in self.collections[collection]: self.collections[collection].remove(name) - + if self._auto_offload_enabled: self.enable_auto_cpu_offload(self._auto_offload_device) @@ -341,7 +339,7 @@ def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = N Dictionary mapping component IDs to components, or list of (base_name, component) tuples if as_name_component_tuples=True """ - + if collection: if collection not in self.collections: logger.warning(f"Collection '{collection}' not found in ComponentsManager") @@ -360,16 +358,16 @@ def get_base_name(component_id): if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: return '_'.join(parts[:-1]) return component_id - + if names is None: if as_name_component_tuples: return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()] else: return components - + # Create mapping from component_id to base_name for all components base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()} - + def matches_pattern(component_id, pattern, exact_match=False): """ Helper function to check if a component matches a pattern based on its base name. @@ -380,124 +378,124 @@ def matches_pattern(component_id, pattern, exact_match=False): exact_match: If True, only exact matches to base_name are considered """ base_name = base_names[component_id] - + # Exact match with base name if exact_match: return pattern == base_name - + # Prefix match (ends with *) elif pattern.endswith('*'): prefix = pattern[:-1] return base_name.startswith(prefix) - + # Contains match (starts with *) elif pattern.startswith('*'): search = pattern[1:-1] if pattern.endswith('*') else pattern[1:] return search in base_name - + # Exact match (no wildcards) else: return pattern == base_name - + if isinstance(names, str): # Check if this is a "not" pattern is_not_pattern = names.startswith('!') if is_not_pattern: names = names[1:] # Remove the ! 
prefix - + # Handle OR patterns (containing |) if '|' in names: terms = names.split('|') matches = {} - + for comp_id, comp in components.items(): # For OR patterns with exact names (no wildcards), we do exact matching on base names exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms) - + # Check if any of the terms match this component should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms) - + # Flip the decision if this is a NOT pattern if is_not_pattern: should_include = not should_include - + if should_include: matches[comp_id] = comp - + log_msg = "NOT " if is_not_pattern else "" match_type = "exactly matching" if exact_match else "matching any of patterns" logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}") - + # Try exact match with a base name elif any(names == base_name for base_name in base_names.values()): # Find all components with this base name matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (base_names[comp_id] == names) != is_not_pattern } - + if is_not_pattern: logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}") else: logger.info(f"Getting components with base name '{names}': {list(matches.keys())}") - + # Prefix match (ends with *) elif names.endswith('*'): prefix = names[:-1] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if base_names[comp_id].startswith(prefix) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}") else: logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}") - + # Contains match (starts with *) elif names.startswith('*'): search = names[1:-1] if names.endswith('*') else names[1:] matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (search in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{search}': {list(matches.keys())}") - + # Substring match (no wildcards, but not an exact component name) elif any(names in base_name for base_name in base_names.values()): matches = { - comp_id: comp for comp_id, comp in components.items() + comp_id: comp for comp_id, comp in components.items() if (names in base_names[comp_id]) != is_not_pattern } if is_not_pattern: logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}") else: logger.info(f"Getting components containing '{names}': {list(matches.keys())}") - + else: raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager") - + if not matches: raise ValueError(f"No components found matching pattern '{names}'") - + if as_name_component_tuples: return [(base_names[comp_id], comp) for comp_id, comp in matches.items()] else: return matches - + elif isinstance(names, list): results = {} for name in names: result = self.get(name, collection, load_id, as_name_component_tuples=False) results.update(result) - + if as_name_component_tuples: return [(base_names[comp_id], comp) for comp_id, comp in results.items()] else: return results - + else: raise ValueError(f"Invalid type for names: {type(names)}") @@ -558,14 +556,14 @@ def get_model_info(self, name: str, 
fields: Optional[Union[str, List[str]]] = No raise ValueError(f"Component '{name}' not found in ComponentsManager") component = self.components[name] - + # Build complete info dict first info = { "model_id": name, "added_time": self.added_time[name], "collection": next((coll for coll, comps in self.collections.items() if name in comps), None), } - + # Additional info for torch.nn.Module components if isinstance(component, torch.nn.Module): # Check for hook information @@ -573,7 +571,7 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No execution_device = None if has_hook and hasattr(component._hf_hook, "execution_device"): execution_device = component._hf_hook.execution_device - + info.update({ "class_name": component.__class__.__name__, "size_gb": get_memory_footprint(component) / (1024**3), @@ -594,8 +592,8 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No if any("IPAdapter" in ptype for ptype in processor_types): # Then get scales only from IP-Adapter processors scales = { - k: v.scale - for k, v in processors.items() + k: v.scale + for k, v in processors.items() if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__ } if scales: @@ -609,7 +607,7 @@ def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = No else: # List of fields requested, return dict with just those fields return {k: v for k, v in info.items() if k in fields} - + return info def __repr__(self): @@ -622,13 +620,13 @@ def get_simple_name(name): if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]: return '_'.join(parts[:-1]) return name - + # Extract load_id if available def get_load_id(component): if hasattr(component, "_diffusers_load_id"): return component._diffusers_load_id return "N/A" - + # Format device info compactly def format_device(component, info): if not info["has_hook"]: @@ -637,24 +635,24 @@ def format_device(component, info): device = str(getattr(component, 'device', 'N/A')) exec_device = str(info['execution_device'] or 'N/A') return f"{device}({exec_device})" - + # Get all simple names to calculate width simple_names = [get_simple_name(id) for id in self.components.keys()] - + # Get max length of load_ids for models load_ids = [ - get_load_id(component) - for component in self.components.values() + get_load_id(component) + for component in self.components.values() if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id") ] max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15 - + # Collection names collection_names = [ next((coll for coll, comps in self.collections.items() if name in comps), "N/A") for name in self.components.keys() ] - + col_widths = { "name": max(15, max(len(name) for name in simple_names)), "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())), @@ -692,7 +690,7 @@ def format_device(component, info): dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A" load_id = get_load_id(component) collection = info["collection"] or "N/A" - + output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | " output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | " output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n" @@ -712,7 +710,7 @@ def format_device(component, info): info = self.get_model_info(name) simple_name = get_simple_name(name) collection = 
info["collection"] or "N/A" - + output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n" output += dash_line @@ -726,9 +724,9 @@ def format_device(component, info): if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): - output += f" IP-Adapter: Enabled\n" + output += " IP-Adapter: Enabled\n" output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n" - + return output def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs): @@ -759,13 +757,13 @@ def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = from ..pipelines.pipeline_utils import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs) for name, component in pipe.components.items(): - + if component is None: continue - + # Add prefix if specified component_name = f"{prefix}_{name}" if prefix else name - + if component_name not in self.components: self.add(component_name, component) else: @@ -791,13 +789,13 @@ def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, ValueError: If no components match or multiple components match """ results = self.get(name, collection, load_id) - + if not results: raise ValueError(f"No components found matching '{name}'") - + if len(results) > 1: raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") - + return next(iter(results.values())) def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: @@ -823,17 +821,17 @@ def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]: if value_tuple not in value_to_keys: value_to_keys[value_tuple] = [] value_to_keys[value_tuple].append(key) - + def find_common_prefix(keys: List[str]) -> str: """Find the shortest common prefix among a list of dot-separated keys.""" if not keys: return "" if len(keys) == 1: return keys[0] - + # Split all keys into parts key_parts = [k.split('.') for k in keys] - + # Find how many initial parts are common common_length = 0 for parts in zip(*key_parts): @@ -841,10 +839,10 @@ def find_common_prefix(keys: List[str]) -> str: common_length += 1 else: break - + if common_length == 0: return "" - + # Return the common prefix return '.'.join(key_parts[0][:common_length]) @@ -858,5 +856,5 @@ def find_common_prefix(keys: List[str]) -> str: summary[prefix] = value else: summary[""] = value # Use empty string if no common prefix - + return summary diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py index 636b543395df..a90623d7217d 100644 --- a/src/diffusers/pipelines/modular_pipeline.py +++ b/src/diffusers/pipelines/modular_pipeline.py @@ -12,29 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib +import os import traceback import warnings from collections import OrderedDict +from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Union, Optional, Type - +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from tqdm.auto import tqdm -import re -import os -import importlib - from huggingface_hub.utils import validate_hf_hub_args +from tqdm.auto import tqdm from ..configuration_utils import ConfigMixin, FrozenDict from ..utils import ( + PushToHubMixin, is_accelerate_available, - is_accelerate_version, logging, - PushToHubMixin, ) -from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple +from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from .components_manager import ComponentsManager from .modular_pipeline_utils import ( ComponentSpec, ConfigSpec, @@ -42,18 +40,15 @@ OutputParam, format_components, format_configs, - format_input_params, format_inputs_short, format_intermediates_short, - format_output_params, - format_params, make_doc_string, ) -from .components_manager import ComponentsManager +from .pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj + -from copy import deepcopy if is_accelerate_available(): - import accelerate + pass logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -108,18 +103,16 @@ def format_value(v): intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }}\n" - f")" + f"PipelineState(\n" f" inputs={{\n{inputs}\n }},\n" f" intermediates={{\n{intermediates}\n }}\n" f")" ) -@dataclass +@dataclass class BlockState: """ Container for block state data with attribute access and formatted representation. 
""" + def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) @@ -129,28 +122,28 @@ def format_value(v): # Handle tensors directly if hasattr(v, "shape") and hasattr(v, "dtype"): return f"Tensor(dtype={v.dtype}, shape={v.shape})" - + # Handle lists of tensors elif isinstance(v, list): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"List[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle tuples of tensors elif isinstance(v, tuple): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle dicts with tensor values elif isinstance(v, dict): if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} return f"Dict of Tensors with shapes {shapes}" return repr(v) - + # Default case return repr(v) @@ -158,31 +151,78 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" - -class ModularPipelineMixin: +class ModularPipelineMixin(ConfigMixin): """ Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + config_name = "config.json" + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + + config = cls.load_config(pretrained_model_name_or_path) + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, False, has_remote_code + ) + if not (has_remote_code and trust_remote_code): + raise ValueError("") + + class_ref = config["auto_map"][cls.__name__] + module_file, class_name = class_ref.split(".") + module_file = module_file + ".py" + block_cls = get_class_from_dynamic_module( + pretrained_model_name_or_path, + module_file=module_file, + class_name=class_name, + is_modular=True, + **hub_kwargs, + **kwargs, + ) + return block_cls() + + def setup_loader( + self, + modular_repo: Optional[Union[str, os.PathLike]] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + ): """ - create a mouldar loader, optionally accept modular_repo to load from hub. + create a ModularLoader, optionally accept modular_repo to load from hub. 
""" # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) + diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) - + # Create deep copies to avoid modifying the original specs component_specs = deepcopy(self.expected_components) config_specs = deepcopy(self.expected_configs) # Create the loader with the updated specs specs = component_specs + config_specs - - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + self.loader = loader_class( + specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection + ) @property def default_call_parameters(self) -> Dict[str, Any]: @@ -238,7 +278,6 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, if output is None: return state - elif isinstance(output, str): return state.get_intermediate(output) @@ -268,9 +307,8 @@ def set_progress_bar_config(self, **kwargs): class PipelineBlock(ModularPipelineMixin): - model_name = None - + @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" @@ -279,12 +317,11 @@ def description(self) -> str: @property def expected_components(self) -> List[ComponentSpec]: return [] - + @property def expected_configs(self) -> List[ConfigSpec]: return [] - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property def inputs(self) -> List[InputParam]: @@ -322,7 +359,6 @@ def required_intermediates_inputs(self) -> List[str]: input_names.append(input_param.name) return input_names - def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -331,14 +367,14 @@ def __repr__(self): base_class = self.__class__.__bases__[0].__name__ # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) @@ -355,7 +391,9 @@ def __repr__(self): inputs = "Inputs:\n " + inputs_str # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = format_intermediates_short( + self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs + ) intermediates = f"Intermediates:\n{intermediates_str}" return ( @@ -369,24 +407,22 @@ def __repr__(self): f")" ) - @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) - def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one 
dictionary""" data = {} - + # Check inputs for input_param in self.inputs: value = state.get_input(input_param.name) @@ -402,7 +438,7 @@ def get_block_state(self, state: PipelineState) -> dict: data[input_param.name] = value return BlockState(**data) - + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): @@ -412,26 +448,28 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values exist for the same input. Args: named_input_lists: List of tuples containing (block_name, input_param_list) pairs - + Returns: List[InputParam]: Combined list of unique InputParam objects """ combined_dict = {} # name -> InputParam value_sources = {} # name -> block_name - + for block_name, inputs in named_input_lists: for input_param in inputs: if input_param.name in combined_dict: current_param = combined_dict[input_param.name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): + if ( + current_param.default is not None + and input_param.default is not None + and current_param.default != input_param.default + ): warnings.warn( f"Multiple different default values found for input '{input_param.name}': " f"{current_param.default} (from block '{value_sources[input_param.name]}') and " @@ -443,9 +481,10 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li else: combined_dict[input_param.name] = input_param value_sources[input_param.name] = block_name - + return list(combined_dict.values()) + def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: """ Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, @@ -453,17 +492,17 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> Args: named_output_lists: List of tuples containing (block_name, output_param_list) pairs - + Returns: List[OutputParam]: Combined list of unique OutputParam objects """ combined_dict = {} # name -> OutputParam - + for block_name, outputs in named_output_lists: for output_param in outputs: if output_param.name not in combined_dict: combined_dict[output_param.name] = output_param - + return list(combined_dict.values()) @@ -487,15 +526,15 @@ def __init__(self): blocks[block_name] = block_cls() self.blocks = blocks if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + raise ValueError( + f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." 
+ ) default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last + # can only have 1 or 0 default block, and has to put in the last # the order of blocksmatters here because the first block with matching trigger will be dispatched # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or ( - len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None - ): + if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None): raise ValueError( f"In {self.__class__.__name__}, exactly one None must be specified as the last element " "in block_trigger_inputs." @@ -509,7 +548,7 @@ def __init__(self): @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def description(self): return "" @@ -532,7 +571,6 @@ def expected_configs(self): expected_configs.append(config) return expected_configs - @property def required_inputs(self) -> List[str]: first_block = next(iter(self.blocks.values())) @@ -557,7 +595,6 @@ def required_intermediates_inputs(self) -> List[str]: return list(required_by_all) - # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: @@ -571,7 +608,6 @@ def inputs(self) -> List[Tuple[str, Any]]: input_param.required = False return combined_inputs - @property def intermediates_inputs(self) -> List[str]: named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()] @@ -589,7 +625,7 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + @property def outputs(self) -> List[str]: named_outputs = [(name, block.outputs) for name, block in self.blocks.items()] @@ -630,26 +666,27 @@ def _get_trigger_inputs(self): Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique block_trigger_inputs values """ + def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. 
auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them - if hasattr(block, 'blocks'): + if hasattr(block, "blocks"): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + trigger_inputs = set(self.block_trigger_inputs) trigger_inputs.update(fn_recursive_get_trigger(self.blocks)) - + return trigger_inputs @property @@ -660,12 +697,9 @@ def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" ) - if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -677,19 +711,19 @@ def __repr__(self): header += " " + "=" * 100 + "\n\n" # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -699,7 +733,7 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block trigger = None - if hasattr(self, 'block_to_trigger_map'): + if hasattr(self, "block_to_trigger_map"): trigger = self.block_to_trigger_map.get(name) # Format the trigger info if trigger is None: @@ -713,47 +747,41 @@ def __repr__(self): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description - desc_lines = block.description.split('\n') + desc_lines = block.description.split("\n") indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - + return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")" @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) + class SequentialPipelineBlocks(ModularPipelineMixin): 
""" A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ + block_classes = [] block_names = [] @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def description(self): return "" @@ -779,10 +807,10 @@ def expected_configs(self): @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. - + Args: blocks_dict: Dictionary mapping block names to block instances - + Returns: A new SequentialPipelineBlocks instance """ @@ -791,14 +819,13 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo instance.block_names = list(blocks_dict.keys()) instance.blocks = blocks_dict return instance - + def __init__(self): blocks = OrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks - @property def required_inputs(self) -> List[str]: # Get the first block from the dictionary @@ -809,9 +836,9 @@ def required_inputs(self) -> List[str]: for block in list(self.blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) - + return list(required_by_any) - + @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -847,7 +874,7 @@ def intermediates_inputs(self) -> List[str]: should_add_outputs = True if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: should_add_outputs = False - + if should_add_outputs: # Add this block's outputs block_intermediates_outputs = [out.name for out in block.intermediates_outputs] @@ -859,11 +886,11 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + @property def outputs(self) -> List[str]: return next(reversed(self.blocks.values())).intermediates_outputs - + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): @@ -878,29 +905,30 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: logger.error(error_msg) raise return pipeline, state - + def _get_trigger_inputs(self): """ Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique block_trigger_inputs values """ + def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. 
auto block) - if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None: + if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) - + # If block has blocks, recursively check them - if hasattr(block, 'blocks'): + if hasattr(block, "blocks"): nested_triggers = fn_recursive_get_trigger(block.blocks) trigger_values.update(nested_triggers) - + return trigger_values - + return fn_recursive_get_trigger(self.blocks) @property @@ -913,10 +941,10 @@ def _traverse_trigger_blocks(self, trigger_inputs): def fn_recursive_traverse(block, block_name, active_triggers): result_blocks = OrderedDict() - + # sequential or PipelineBlock - if not hasattr(block, 'block_trigger_inputs'): - if hasattr(block, 'blocks'): + if not hasattr(block, "block_trigger_inputs"): + if hasattr(block, "blocks"): # sequential for block_name, block in block.blocks.items(): blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) @@ -925,10 +953,10 @@ def fn_recursive_traverse(block, block_name, active_triggers): # PipelineBlock result_blocks[block_name] = block # Add this block's output names to active triggers if defined - if hasattr(block, 'outputs'): + if hasattr(block, "outputs"): active_triggers.update(out.name for out in block.outputs) return result_blocks - + # auto else: # Find first block_trigger_input that matches any value in our active_triggers @@ -939,36 +967,35 @@ def fn_recursive_traverse(block, block_name, active_triggers): this_block = block.trigger_to_block_map[trigger_input] matching_trigger = trigger_input break - + # If no matches found, try to get the default (None) block if this_block is None and None in block.block_trigger_inputs: this_block = block.trigger_to_block_map[None] matching_trigger = None - + if this_block is not None: # sequential/auto - if hasattr(this_block, 'blocks'): + if hasattr(this_block, "blocks"): result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) else: # PipelineBlock result_blocks[block_name] = this_block # Add this block's output names to active triggers if defined - if hasattr(this_block, 'outputs'): + if hasattr(this_block, "outputs"): active_triggers.update(out.name for out in this_block.outputs) return result_blocks - + all_blocks = OrderedDict() for block_name, block in self.blocks.items(): blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) all_blocks.update(blocks_to_update) return all_blocks - + def get_execution_blocks(self, *trigger_inputs): trigger_inputs_all = self.trigger_inputs if trigger_inputs is not None: - if not isinstance(trigger_inputs, (list, tuple, set)): trigger_inputs = [trigger_inputs] invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] @@ -977,7 +1004,7 @@ def get_execution_blocks(self, *trigger_inputs): f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" ) trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] - + if trigger_inputs is None: if None in trigger_inputs_all: trigger_inputs = [None] @@ -985,17 +1012,14 @@ def get_execution_blocks(self, *trigger_inputs): trigger_inputs = [trigger_inputs_all[0]] blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) - + def __repr__(self): class_name = self.__class__.__name__ base_class = 
self.__class__.__bases__[0].__name__ header = ( - f"{class_name}(\n Class: {base_class}\n" - if base_class and base_class != "object" - else f"{class_name}(\n" + f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n" ) - if self.trigger_inputs: header += "\n" header += " " + "=" * 100 + "\n" @@ -1007,19 +1031,19 @@ def __repr__(self): header += " " + "=" * 100 + "\n\n" # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - focus only on expected components expected_components = getattr(self, "expected_components", []) components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - + # Configs section - use format_configs with add_empty_lines=False expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) @@ -1029,7 +1053,7 @@ def __repr__(self): for i, (name, block) in enumerate(self.blocks.items()): # Get trigger input for this block trigger = None - if hasattr(self, 'block_to_trigger_map'): + if hasattr(self, "block_to_trigger_map"): trigger = self.block_to_trigger_map.get(name) # Format the trigger info if trigger is None: @@ -1043,39 +1067,30 @@ def __repr__(self): else: # For SequentialPipelineBlocks, show execution order blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" - + # Add block description - desc_lines = block.description.split('\n') + desc_lines = block.description.split("\n") indented_desc = desc_lines[0] if len(desc_lines) > 1: - indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) + indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" - return ( - f"{header}\n" - f"{desc}\n\n" - f"{components_str}\n\n" - f"{configs_str}\n\n" - f"{blocks_str}" - f")" - ) - + return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")" @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) - -# YiYi TODO: +# YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader # 3. add validator for methods where we accpet kwargs to be passed to from_pretrained() @@ -1084,30 +1099,29 @@ class ModularLoader(ConfigMixin, PushToHubMixin): Base class for all Modular pipelines loaders. """ - config_name = "modular_model_index.json" + config_name = "modular_model_index.json" def register_components(self, **kwargs): """ - Register components with their corresponding specs. + Register components with their corresponding specs. This method is called when component changed or __init__ is called. 
Args: **kwargs: Keyword arguments where keys are component names and values are component objects. - + """ for name, module in kwargs.items(): - # current component spec component_spec = self._component_specs.get(name) if component_spec is None: logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") continue - + is_registered = hasattr(self, name) if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") # actual library and class name of the module @@ -1115,10 +1129,10 @@ def register_components(self, **kwargs): library, class_name = _fetch_class_library_tuple(module) new_component_spec = ComponentSpec.from_component(name, module) component_spec_dict = self._component_spec_to_dict(new_component_spec) - + else: library, class_name = None, None - # if module is None, we do not update the spec, + # if module is None, we do not update the spec, # but we still need to update the config to make sure it's synced with the component spec # (in the case of the first time registration, we initilize the object with component spec, and then we call register_components() to register it to config) new_component_spec = component_spec @@ -1139,16 +1153,24 @@ def register_components(self, **kwargs): if module is not None and self._component_manager is not None: self._component_manager.add(name, module, self._collection) continue - + current_module = getattr(self, name, None) # skip if the component is already registered with the same object if current_module is module: - logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping") + logger.info( + f"ModularLoader.register_components: {name} is already registered with same object, skipping" + ) continue - + # it module is not an instance of the expected type, still register it but with a warning - if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint): - logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}") + if ( + module is not None + and component_spec.type_hint is not None + and not isinstance(module, component_spec.type_hint) + ): + logger.warning( + f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}" + ) # warn if unregister if current_module is not None and module is None: @@ -1157,10 +1179,12 @@ def register_components(self, **kwargs): f"(was {current_module.__class__.__name__})" ) # same type, new instance → debug - elif current_module is not None \ - and module is not None \ - and isinstance(module, current_module.__class__) \ - and current_module != module: + elif ( + current_module is not None + and module is not None + and isinstance(module, current_module.__class__) + and current_module != module + ): logger.debug( f"ModularLoader.register_components: replacing existing '{name}' " f"(same type {type(current_module).__name__}, new instance)" @@ -1175,46 +1199,51 @@ def register_components(self, **kwargs): if module is not None and self._component_manager is not None: self._component_manager.add(name, module, self._collection) - - # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the 
same name - def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs): + def __init__( + self, + specs: List[Union[ComponentSpec, ConfigSpec]], + modular_repo: Optional[str] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): """ Initialize the loader with a list of component specs and config specs. """ self._component_manager = component_manager self._collection = collection - self._component_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec) - } - self._config_specs = { - spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec) - } + self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)} + self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)} # update component_specs and config_specs from modular_repo if modular_repo is not None: config_dict = self.load_config(modular_repo, **kwargs) for name, value in config_dict.items(): - if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3: + if ( + name in self._component_specs + and self._component_specs[name].default_creation_method == "from_pretrained" + and isinstance(value, (tuple, list)) + and len(value) == 3 + ): library, class_name, component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) self._component_specs[name] = component_spec elif name in self._config_specs: self._config_specs[name].default = value - + register_components_dict = {} for name, component_spec in self._component_specs.items(): register_components_dict[name] = None self.register_components(**register_components_dict) - + default_configs = {} for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default self.register_to_config(**default_configs) - @property def device(self) -> torch.device: r""" @@ -1251,7 +1280,7 @@ def _execution_device(self): ): return torch.device(module._hf_hook.execution_device) return self.device - + @property def device(self) -> torch.device: r""" @@ -1280,23 +1309,18 @@ def dtype(self) -> torch.dtype: return torch.float32 - @property def components(self) -> Dict[str, Any]: # return only components we've actually set as attributes on self - return { - name: getattr(self, name) - for name in self._component_specs.keys() - if hasattr(self, name) - } + return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} def update(self, **kwargs): """ Update components and configs after instance creation. - + Args: - """ + """ """ Update components and configuration values after the loader has been instantiated. 
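As a brief, hypothetical sketch of how the spec-driven loader above is meant to be used end to end: construct it from ComponentSpec/ConfigSpec entries, then materialize selected components with load(), whose extra kwargs are forwarded to from_pretrained. The import paths, repo, and subfolder values below are placeholders inferred from this patch rather than an established diffusers API.

```py
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.modular_pipeline import ModularLoader
from diffusers.pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec

specs = [
    ComponentSpec(
        name="unet",
        type_hint=UNet2DConditionModel,
        repo="stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="unet",
    ),
    ComponentSpec(
        name="vae",
        type_hint=AutoencoderKL,
        repo="stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
    ),
    ConfigSpec(name="force_zeros_for_empty_prompt", default=True),
]

# __init__ records the specs, registers every component slot (initially empty),
# and pushes the config defaults into the ConfigMixin config.
loader = ModularLoader(specs)

# load() creates only the requested components from their specs; extra kwargs
# such as torch_dtype are forwarded to from_pretrained.
loader.load(["unet"], torch_dtype=torch.float16)

print(loader.components)  # components registered on the loader so far
print(loader.dtype)       # derived from the registered torch modules
```

Components created through ComponentSpec can later be swapped in via update(), as the docstring example continued in the next hunk shows.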
@@ -1332,7 +1356,7 @@ def update(self, **kwargs): requires_safety_checker=False ) ``` - """ + """ # extract component_specs_updates & config_specs_updates from `specs` passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs} @@ -1340,29 +1364,25 @@ def update(self, **kwargs): for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.") - + raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") + if len(kwargs) > 0: logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}") - self.register_components(**passed_components) - config_to_register = {} for name, new_value in passed_config_values.items(): - # e.g. requires_aesthetics_score = False self._config_specs[name].default = new_value config_to_register[name] = new_value self.register_to_config(**config_to_register) - # YiYi TODO: support map for additional from_pretrained kwargs def load(self, component_names: Optional[List[str]] = None, **kwargs): """ Load selectedcomponents from specs. - + Args: component_names: List of component names to load **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: @@ -1379,7 +1399,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): unknown_component_names = set([name for name in component_names if name not in self._component_specs]) if len(unknown_component_names) > 0: logger.warning(f"Unknown components will be ignored: {unknown_component_names}") - + components_to_register = {} for name in components_to_load: spec = self._component_specs[name] @@ -1399,7 +1419,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): components_to_register[name] = spec.create(**component_load_kwargs) except Exception as e: logger.warning(f"Failed to create component '{name}': {e}") - + # Register all components at once self.register_components(**components_to_register) @@ -1407,11 +1427,12 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs): def to(self, *args, **kwargs): pass - # YiYi TODO: + # YiYi TODO: # 1. should support save some components too! currently only modular_model_index.json is saved # 2. 
maybe order the json file to make it more readable: configs first, then components - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs): - + def save_pretrained( + self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs + ): component_names = list(self._component_specs.keys()) config_names = list(self._config_specs.keys()) self.register_to_config(_components_names=component_names, _configs_names=config_names) @@ -1421,11 +1442,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: config.pop("_configs_names", None) self._internal_dict = FrozenDict(config) - @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs): - + def from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs + ): config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) expected_component = set(config_dict.pop("_components_names")) expected_config = set(config_dict.pop("_configs_names")) @@ -1440,7 +1461,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P elif name in expected_config: config_specs.append(ConfigSpec(name=name, default=value)) - + for name in expected_component: for spec in component_specs: if spec.name == name: @@ -1450,7 +1471,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P component_specs.append(ComponentSpec(name=name, default_creation_method="from_config")) return cls(component_specs + config_specs) - @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ @@ -1533,4 +1553,4 @@ def _dict_to_component_spec( name=name, type_hint=type_hint, **spec_dict, - ) \ No newline at end of file + ) diff --git a/src/diffusers/pipelines/modular_pipeline_utils.py b/src/diffusers/pipelines/modular_pipeline_utils.py index c8064a5215aa..482c209726d3 100644 --- a/src/diffusers/pipelines/modular_pipeline_utils.py +++ b/src/diffusers/pipelines/modular_pipeline_utils.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import re import inspect -from dataclasses import dataclass, asdict, field, fields -from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal +import re +from dataclasses import dataclass, field, fields +from typing import Any, Dict, List, Literal, Optional, Type, Union +from ..configuration_utils import ConfigMixin, FrozenDict from ..utils.import_utils import is_torch_available -from ..configuration_utils import FrozenDict, ConfigMixin + if is_torch_available(): import torch @@ -56,50 +57,50 @@ class ComponentSpec: variant: Optional[str] = field(default=None, metadata={"loading": True}) revision: Optional[str] = field(default=None, metadata={"loading": True}) default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained" - - + + def __hash__(self): """Make ComponentSpec hashable, using load_id as the hash value.""" return hash((self.name, self.load_id, self.default_creation_method)) - + def __eq__(self, other): """Compare ComponentSpec objects based on name and load_id.""" if not isinstance(other, ComponentSpec): return False - return (self.name == other.name and - self.load_id == other.load_id and + return (self.name == other.name and + self.load_id == other.load_id and self.default_creation_method == other.default_creation_method) - + @classmethod def from_component(cls, name: str, component: torch.nn.Module) -> Any: """Create a ComponentSpec from a Component created by `create` method.""" - + if not hasattr(component, "_diffusers_load_id"): raise ValueError("Component is not created by `create` method") - + type_hint = component.__class__ - + if component._diffusers_load_id == "null" and isinstance(component, ConfigMixin): config = component.config else: config = None - + load_spec = cls.decode_load_id(component._diffusers_load_id) - + return cls(name=name, type_hint=type_hint, config=config, **load_spec) - + @classmethod def from_load_id(cls, load_id: str, name: Optional[str] = None) -> Any: """Create a ComponentSpec from a load_id string.""" if load_id == "null": raise ValueError("Cannot create ComponentSpec from null load_id") - + # Decode the load_id into a dictionary of loading fields load_fields = cls.decode_load_id(load_id) - + # Create a new ComponentSpec instance with the decoded fields return cls(name=name, **load_fields) - + @classmethod def loading_fields(cls) -> List[str]: """ @@ -107,8 +108,8 @@ def loading_fields(cls) -> List[str]: (i.e. those whose field.metadata["loading"] is True). """ return [f.name for f in fields(cls) if f.metadata.get("loading", False)] - - + + @property def load_id(self) -> str: """ @@ -118,7 +119,7 @@ def load_id(self) -> str: parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] return "|".join(p for p in parts if p) - + @classmethod def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: """ @@ -139,29 +140,29 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating component not loaded from pretrained). 
""" - + # Get all loading fields in order loading_fields = cls.loading_fields() result = {f: None for f in loading_fields} if load_id == "null": return result - + # Split the load_id parts = load_id.split("|") - + # Map parts to loading fields by position for i, part in enumerate(parts): if i < len(loading_fields): # Convert "null" string back to None result[loading_fields[i]] = None if part == "null" else part - + return result - + # YiYi TODO: add validator def create(self, **kwargs) -> Any: """Create the component using the preferred creation method.""" - + # from_pretrained creation if self.default_creation_method == "from_pretrained": return self.create_from_pretrained(**kwargs) @@ -170,17 +171,17 @@ def create(self, **kwargs) -> Any: return self.create_from_config(**kwargs) else: raise ValueError(f"Invalid creation method: {self.default_creation_method}") - + def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" if self.type_hint is None or not isinstance(self.type_hint, type): raise ValueError( - f"`type_hint` is required when using from_config creation method." + "`type_hint` is required when using from_config creation method." ) - + config = config or self.config or {} - + if issubclass(self.type_hint, ConfigMixin): component = self.type_hint.from_config(config, **kwargs) else: @@ -193,24 +194,24 @@ def create_from_config(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] if k in signature_params: init_kwargs[k] = v component = self.type_hint(**init_kwargs) - + component._diffusers_load_id = "null" if hasattr(component, "config"): self.config = component.config - + return component - + # YiYi TODO: add guard for type of model, if it is supported by from_pretrained def create_from_pretrained(self, **kwargs) -> Any: """Create component using from_pretrained.""" - + passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs} load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()} # repo is a required argument for from_pretrained, a.k.a. 
pretrained_model_name_or_path repo = load_kwargs.pop("repo", None) if repo is None: - raise ValueError(f"`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") - + raise ValueError("`repo` info is required when using from_pretrained creation method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)") + if self.type_hint is None: try: from diffusers import AutoModel @@ -223,19 +224,19 @@ def create_from_pretrained(self, **kwargs) -> Any: component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs) except Exception as e: raise ValueError(f"Error creating {self.name}[{self.type_hint.__name__}] from pretrained: {e}") - + if repo != self.repo: self.repo = repo for k, v in passed_loading_kwargs.items(): if v is not None: setattr(self, k, v) component._diffusers_load_id = self.load_id - + return component - -@dataclass + +@dataclass class ConfigSpec: """Specification for a pipeline configuration parameter.""" name: str @@ -254,7 +255,7 @@ def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" -@dataclass +@dataclass class OutputParam: """Specification for an output parameter.""" name: str @@ -287,14 +288,14 @@ def format_inputs_short(inputs): """ required_inputs = [param for param in inputs if param.required] optional_inputs = [param for param in inputs if not param.required] - + required_str = ", ".join(param.name for param in required_inputs) optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs) - + inputs_str = required_str if optional_str: inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str - + return inputs_str @@ -321,18 +322,18 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu input_parts.append(f"Required({inp.name})") else: input_parts.append(inp.name) - + # Handle modified variables (appear in both inputs and outputs) inputs_set = {inp.name for inp in intermediates_inputs} modified_parts = [] new_output_parts = [] - + for out in intermediates_outputs: if out.name in inputs_set: modified_parts.append(out.name) else: new_output_parts.append(out.name) - + result = [] if input_parts: result.append(f" - inputs: {', '.join(input_parts)}") @@ -340,7 +341,7 @@ def format_intermediates_short(intermediates_inputs, required_intermediates_inpu result.append(f" - modified: {', '.join(modified_parts)}") if new_output_parts: result.append(f" - outputs: {', '.join(new_output_parts)}") - + return "\n".join(result) if result else " (none)" @@ -358,18 +359,18 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115): """ if not params: return "" - + base_indent = " " * indent_level param_indent = " " * (indent_level + 4) desc_indent = " " * (indent_level + 8) formatted_params = [] - + def get_type_str(type_hint): if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union: types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__] return f"Union[{', '.join(types)}]" return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint) - + def wrap_text(text, indent, max_length): """Wrap text while preserving markdown links and maintaining indentation.""" words = text.split() @@ -379,7 +380,7 @@ def wrap_text(text, indent, max_length): for word in words: word_length = len(word) + (1 if current_line else 0) - + if current_line and 
current_length + word_length > max_length: lines.append(" ".join(current_line)) current_line = [word] @@ -387,20 +388,20 @@ def wrap_text(text, indent, max_length): else: current_line.append(word) current_length += word_length - + if current_line: lines.append(" ".join(current_line)) - + return f"\n{indent}".join(lines) - + # Add the header formatted_params.append(f"{base_indent}{header}:") - + for param in params: # Format parameter name and type type_str = get_type_str(param.type_hint) if param.type_hint != Any else "" param_str = f"{param_indent}{param.name} (`{type_str}`" - + # Add optional tag and default value if parameter is an InputParam and optional if hasattr(param, "required"): if not param.required: @@ -408,7 +409,7 @@ def wrap_text(text, indent, max_length): if param.default is not None: param_str += f", defaults to {param.default}" param_str += "):" - + # Add description on a new line with additional indentation and wrapping if param.description: desc = re.sub( @@ -418,9 +419,9 @@ def wrap_text(text, indent, max_length): ) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" - + formatted_params.append(param_str) - + return "\n\n".join(formatted_params) @@ -466,42 +467,42 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty """ if not components: return "" - + base_indent = " " * indent_level component_indent = " " * (indent_level + 4) formatted_components = [] - + # Add the header formatted_components.append(f"{base_indent}Components:") if add_empty_lines: formatted_components.append("") - + # Add each component with optional empty lines between them for i, component in enumerate(components): # Get type name, handling special cases type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint) - + component_desc = f"{component_indent}{component.name} (`{type_name}`)" if component.description: component_desc += f": {component.description}" - + # Get the loading fields dynamically loading_field_values = [] for field_name in component.loading_fields(): field_value = getattr(component, field_name) if field_value is not None: loading_field_values.append(f"{field_name}={field_value}") - + # Add loading field information if available if loading_field_values: component_desc += f" [{', '.join(loading_field_values)}]" - + formatted_components.append(component_desc) - + # Add an empty line after each component except the last one if add_empty_lines and i < len(components) - 1: formatted_components.append("") - + return "\n".join(formatted_components) @@ -519,27 +520,27 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines """ if not configs: return "" - + base_indent = " " * indent_level config_indent = " " * (indent_level + 4) formatted_configs = [] - + # Add the header formatted_configs.append(f"{base_indent}Configs:") if add_empty_lines: formatted_configs.append("") - + # Add each config with optional empty lines between them for i, config in enumerate(configs): config_desc = f"{config_indent}{config.name} (default: {config.default})" if config.description: config_desc += f": {config.description}" formatted_configs.append(config_desc) - + # Add an empty line after each config except the last one if add_empty_lines and i < len(configs) - 1: formatted_configs.append("") - + return "\n".join(formatted_configs) @@ -584,9 +585,9 @@ def make_doc_string(inputs, intermediates_inputs, outputs, description="", class # Add inputs 
section output += format_input_params(inputs + intermediates_inputs, indent_level=2) - + # Add outputs section output += "\n\n" output += format_output_params(outputs, indent_level=2) - return output \ No newline at end of file + return output diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 8b422798713f..b5600e466725 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -334,6 +334,7 @@ def maybe_raise_or_warn( # a simpler version of get_class_obj_and_candidates, it won't work with custom code def simple_get_class_obj(library_name, class_name): from diffusers import pipelines + is_pipeline_module = hasattr(pipelines, library_name) if is_pipeline_module: @@ -345,6 +346,7 @@ def simple_get_class_obj(library_name, class_name): return class_obj + def get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None ): diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 49575e99763a..e8950cfbcec5 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1120,7 +1120,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will automatically detect the available accelerator and use. """ - + self._maybe_raise_error_if_group_offload_active(raise_error=True) is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1 diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index 006836fe30d4..49563dd7ccd7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -61,6 +61,7 @@ from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline from .pipeline_stable_diffusion_xl_modular import ( + StableDiffusionXLAutoPipeline, StableDiffusionXLControlNetDenoiseStep, StableDiffusionXLDecodeLatentsStep, StableDiffusionXLDenoiseStep, @@ -70,7 +71,6 @@ StableDiffusionXLPrepareLatentsStep, StableDiffusionXLSetTimestepsStep, StableDiffusionXLTextEncoderStep, - StableDiffusionXLAutoPipeline, ) try: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py index 5ae9e63851db..6a8f7f98dd85 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py @@ -13,17 +13,28 @@ # limitations under the License. 
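# Illustrative sketch (annotation, not part of the patch): how the ComponentSpec load_id helpers
# defined earlier in this patch fit together. Import paths follow the patched module layout, the
# repo/variant values are placeholders, and the exact loading-field order is an assumption taken
# from the decode_load_id docstring.
from diffusers import UNet2DConditionModel
from diffusers.pipelines.modular_pipeline_utils import ComponentSpec

spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
)
print(spec.load_id)                                  # "|"-joined loading fields, "null" standing in for unset ones
print(ComponentSpec.decode_load_id(spec.load_id))    # maps each segment back to its loading field, "null" -> None
# spec.create() would dispatch to create_from_pretrained, i.e. UNet2DConditionModel.from_pretrained(repo, variant=...)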
import inspect -from typing import Any, List, Optional, Tuple, Union, Dict +from collections import OrderedDict +from typing import Any, List, Optional, Tuple, Union +import numpy as np import PIL import torch -from collections import OrderedDict +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin -from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import EulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, logging, @@ -34,33 +45,20 @@ from ..controlnet.multicontrolnet import MultiControlNetModel from ..modular_pipeline import ( AutoPipelineBlocks, + ComponentSpec, + ConfigSpec, + InputParam, ModularLoader, + OutputParam, PipelineBlock, PipelineState, - InputParam, - OutputParam, SequentialPipelineBlocks, - ComponentSpec, - ConfigSpec, ) from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import ( StableDiffusionXLPipelineOutput, ) -from transformers import ( - CLIPTextModel, - CLIPImageProcessor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...schedulers import EulerDiscreteScheduler -from ...guiders import ClassifierFreeGuidance -from ...configuration_utils import FrozenDict - -import numpy as np logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -152,7 +150,7 @@ def description(self) -> str: " See [StableDiffusionXLLoraLoaderMixin](https://huggingface.co/docs/diffusers/api/loaders/lora#diffusers.loaders.StableDiffusionXLLoraLoaderMixin)" " for more details" ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -160,8 +158,8 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), ComponentSpec("unet", UNet2DConditionModel), ] - - + + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise EnvironmentError("StableDiffusionXLLoraStep is desgined to be used to load lora weights, __call__ is not implemented") @@ -170,7 +168,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLIPAdapterStep(PipelineBlock, ModularIPAdapterMixin): model_name = "stable-diffusion-xl" - + @property def description(self) -> str: return ( @@ -178,7 +176,7 @@ def description(self) -> str: " See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" " for more details" ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -186,8 +184,8 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("feature_extractor", CLIPImageProcessor, 
config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"), ComponentSpec("unet", UNet2DConditionModel), ComponentSpec( - "guider", - ClassifierFreeGuidance, + "guider", + ClassifierFreeGuidance, config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ] @@ -196,8 +194,8 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam( - "ip_adapter_image", - PipelineImageInput, + "ip_adapter_image", + PipelineImageInput, required=True, description="The image(s) to be used as ip adapter" ) @@ -210,7 +208,7 @@ def intermediates_outputs(self) -> List[OutputParam]: OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings") ] - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components def encode_image(self, components, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(components.image_encoder.parameters()).dtype @@ -235,7 +233,7 @@ def encode_image(self, components, image, device, num_images_per_prompt, output_ uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds - + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds @@ -317,7 +315,7 @@ def description(self) -> str: return( "Text Encoder step that generate text_embeddings to guide the image generation" ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -326,9 +324,9 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("tokenizer_2", CLIPTokenizer), ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ] @@ -643,7 +641,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def description(self) -> str: return ( @@ -655,9 +653,9 @@ def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ] @@ -673,7 +671,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")] @property @@ -683,13 +681,13 @@ def intermediates_outputs(self) -> List[OutputParam]: # Modified from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - + latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - + dtype = image.dtype if components.vae.config.force_upcast: image = image.float() @@ -715,8 +713,8 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge else: image_latents = components.vae.config.scaling_factor * image_latents - return image_latents - + return image_latents + @torch.no_grad() @@ -725,7 +723,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.preprocess_kwargs = data.preprocess_kwargs or {} data.device = pipeline._execution_device data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype - + data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, **data.preprocess_kwargs) data.image = data.image.to(device=data.device, dtype=data.dtype) @@ -748,23 +746,23 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ComponentSpec( - "mask_processor", - VaeImageProcessor, + "mask_processor", + VaeImageProcessor, config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}), default_creation_method="from_config"), ] - + @property def description(self) -> str: @@ -789,21 +787,21 @@ def intermediates_inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), + return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"), OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")] # Modified from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components # YiYi TODO: update the _encode_vae_image so that we can use #Coped from def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): - + latents_mean = latents_std = None if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) - + dtype = image.dtype if components.vae.config.force_upcast: image = image.float() @@ -829,7 +827,7 @@ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Ge else: image_latents = components.vae.config.scaling_factor * image_latents - return image_latents + return image_latents # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents # do not accept do_classifier_free_guidance @@ -879,8 +877,8 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - - + + @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: @@ -896,7 +894,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin else: data.crops_coords = None data.resize_mode = "default" - + data.image = pipeline.image_processor.preprocess(data.image, height=data.height, width=data.width, crops_coords=data.crops_coords, resize_mode=data.resize_mode) data.image = data.image.to(dtype=torch.float32) @@ -956,7 +954,7 @@ def intermediates_inputs(self) -> List[str]: InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step."), InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step."), ] - + @property def intermediates_outputs(self) -> List[str]: return [ @@ -969,7 +967,7 @@ def intermediates_outputs(self) -> List[str]: OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="image embeddings for IP-Adapter"), OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="negative image embeddings for IP-Adapter"), ] - + def check_inputs(self, pipeline, data): if data.prompt_embeds is not None and data.negative_prompt_embeds is not None: @@ -989,13 +987,13 @@ def check_inputs(self, pipeline, data): raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
) - + if data.ip_adapter_embeds is not None and not isinstance(data.ip_adapter_embeds, list): raise ValueError("`ip_adapter_embeds` must be a list") - + if data.negative_ip_adapter_embeds is not None and not isinstance(data.negative_ip_adapter_embeds, list): raise ValueError("`negative_ip_adapter_embeds` must be a list") - + if data.ip_adapter_embeds is not None and data.negative_ip_adapter_embeds is not None: for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): if ip_adapter_embed.shape != data.negative_ip_adapter_embeds[i].shape: @@ -1017,19 +1015,19 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # duplicate text embeddings for each generation per prompt, using mps friendly method data.prompt_embeds = data.prompt_embeds.repeat(1, data.num_images_per_prompt, 1) data.prompt_embeds = data.prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - + if data.negative_prompt_embeds is not None: _, seq_len, _ = data.negative_prompt_embeds.shape data.negative_prompt_embeds = data.negative_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) data.negative_prompt_embeds = data.negative_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, seq_len, -1) - + data.pooled_prompt_embeds = data.pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) data.pooled_prompt_embeds = data.pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - + if data.negative_pooled_prompt_embeds is not None: data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.repeat(1, data.num_images_per_prompt, 1) data.negative_pooled_prompt_embeds = data.negative_pooled_prompt_embeds.view(data.batch_size * data.num_images_per_prompt, -1) - + if data.ip_adapter_embeds is not None: for i, ip_adapter_embed in enumerate(data.ip_adapter_embeds): data.ip_adapter_embeds[i] = torch.cat([ip_adapter_embed] * data.num_images_per_prompt, dim=0) @@ -1037,7 +1035,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if data.negative_ip_adapter_embeds is not None: for i, negative_ip_adapter_embed in enumerate(data.negative_ip_adapter_embeds): data.negative_ip_adapter_embeds[i] = torch.cat([negative_ip_adapter_embed] * data.num_images_per_prompt, dim=0) - + self.add_block_state(state, data) return pipeline, state @@ -1075,14 +1073,14 @@ def inputs(self) -> List[InputParam]: @property def intermediates_inputs(self) -> List[str]: return [ - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt"), ] @property def intermediates_outputs(self) -> List[str]: return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time"), OutputParam("latent_timestep", type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image generation") ] @@ -1174,7 +1172,7 @@ def expected_components(self) -> List[ComponentSpec]: 
return [ ComponentSpec("scheduler", EulerDiscreteScheduler), ] - + @property def description(self) -> str: return ( @@ -1192,7 +1190,7 @@ def inputs(self) -> List[InputParam]: @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")] @@ -1244,7 +1242,7 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), InputParam( - "strength", + "strength", default=0.9999, description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " @@ -1259,46 +1257,46 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), + ), InputParam( - "latent_timestep", - required=True, - type_hint=torch.Tensor, + "latent_timestep", + required=True, + type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step." - ), + ), InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, + "image_latents", + required=True, + type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step." - ), + ), InputParam( - "mask", - required=True, - type_hint=torch.Tensor, + "mask", + required=True, + type_hint=torch.Tensor, description="The mask for the inpainting generation. Can be generated in vae_encode step." - ), + ), InputParam( - "masked_image_latents", - type_hint=torch.Tensor, + "masked_image_latents", + type_hint=torch.Tensor, description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step." 
), InputParam( - "dtype", - type_hint=torch.dtype, + "dtype", + type_hint=torch.dtype, description="The dtype of the model inputs" ) ] @property def intermediates_outputs(self) -> List[str]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), - OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), - OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), + return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for inpainting generation"), + OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting generation (only for inpainting-specific unet)"), OutputParam("noise", type_hint=torch.Tensor, description="The noise added to the image latents, used for inpainting generation")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents with self -> components @@ -1417,15 +1415,15 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - - + + @torch.no_grad() def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype data.device = pipeline._execution_device - + data.is_strength_max = data.strength == 1.0 # for non-inpainting specific unet, we do not need masked_image_latents @@ -1502,9 +1500,9 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), - InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), - InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), + InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step."), + InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), + InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step."), InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")] @property @@ -1652,14 +1650,14 @@ def inputs(self) -> List[InputParam]: def intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." - ), + ), InputParam( - "dtype", - type_hint=torch.dtype, + "dtype", + type_hint=torch.dtype, description="The dtype of the model inputs" ) ] @@ -1668,8 +1666,8 @@ def intermediates_inputs(self) -> List[InputParam]: def intermediates_outputs(self) -> List[OutputParam]: return [ OutputParam( - "latents", - type_hint=torch.Tensor, + "latents", + type_hint=torch.Tensor, description="The initial latents to use for the denoising process" ) ] @@ -1745,7 +1743,7 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> Pipelin class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_configs(self) -> List[ConfigSpec]: return [ConfigSpec("requires_aesthetics_score", False),] @@ -1773,15 +1771,15 @@ def inputs(self) -> List[Tuple[str, Any]]: @property def intermediates_inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."), InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step."), InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self -> components @@ -1947,29 +1945,29 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[InputParam]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." 
- ), + ), InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), ] @property def intermediates_outputs(self) -> List[OutputParam]: - return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), - OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), + return [OutputParam("add_time_ids", type_hint=torch.Tensor, description="The time ids to condition the denoising process"), + OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="The negative time ids to condition the denoising process"), OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM")] # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self -> components @@ -2084,9 +2082,9 @@ class StableDiffusionXLDenoiseStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), @@ -2111,95 +2109,95 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." ), InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], + "negative_pooled_prompt_embeds", + type_hint=Optional[torch.Tensor], description="The negative pooled prompt embeddings to use to condition the denoising process. 
Can be generated in text_encoder step. " ), InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, + "add_time_ids", + required=True, + type_hint=torch.Tensor, description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." ), InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], + "negative_add_time_ids", + type_hint=Optional[torch.Tensor], description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step." ), InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, + "prompt_embeds", + required=True, + type_hint=torch.Tensor, description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." ), InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], + "negative_prompt_embeds", + type_hint=Optional[torch.Tensor], description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step. " ), InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], + "timestep_cond", + type_hint=Optional[torch.Tensor], description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step." ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], + "masked_image_latents", + type_hint=Optional[torch.Tensor], description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "noise", - type_hint=Optional[torch.Tensor], + "noise", + type_hint=Optional[torch.Tensor], description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." ), InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], + "image_latents", + type_hint=Optional[torch.Tensor], description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." 
), ] @@ -2245,7 +2243,7 @@ def prepare_extra_step_kwargs(self, components, generator, eta): if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs - + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: @@ -2276,14 +2274,14 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: guider_data = pipeline.guider.prepare_inputs(data) data.scaled_latents = pipeline.scheduler.scale_model_input(data.latents, t) - + # Prepare for inpainting if data.num_channels_unet == 9: data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) - + for batch in guider_data: pipeline.guider.prepare_models(pipeline.unet) - + # Prepare additional conditionings batch.added_cond_kwargs = { "text_embeds": batch.pooled_prompt_embeds, @@ -2291,7 +2289,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } if batch.ip_adapter_embeds is not None: batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - + # Predict the noise residual batch.noise_pred = pipeline.unet( data.scaled_latents, @@ -2306,7 +2304,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Perform guidance data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) - + # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] @@ -2315,7 +2313,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 data.latents = data.latents.to(data.latents_dtype) - + if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: data.init_latents_proper = data.image_latents if i < len(data.timesteps) - 1: @@ -2342,9 +2340,9 @@ class StableDiffusionXLControlNetDenoiseStep(PipelineBlock): def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec("unet", UNet2DConditionModel), @@ -2374,100 +2372,100 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", - required=True, - type_hint=torch.Tensor, + "latents", + required=True, + type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), InputParam( - "batch_size", - required=True, - type_hint=int, + "batch_size", + required=True, + type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, + "timesteps", + required=True, + type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "num_inference_steps", - required=True, - type_hint=int, + "num_inference_steps", + required=True, + type_hint=int, description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step." ), InputParam( - "prompt_embeds", - required=True, - type_hint=torch.Tensor, + "prompt_embeds", + required=True, + type_hint=torch.Tensor, description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." ), InputParam( - "negative_prompt_embeds", - type_hint=Optional[torch.Tensor], + "negative_prompt_embeds", + type_hint=Optional[torch.Tensor], description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." ), InputParam( - "add_time_ids", - required=True, - type_hint=torch.Tensor, + "add_time_ids", + required=True, + type_hint=torch.Tensor, description="The time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." ), InputParam( - "negative_add_time_ids", - type_hint=Optional[torch.Tensor], + "negative_add_time_ids", + type_hint=Optional[torch.Tensor], description="The negative time ids used to condition the denoising process. Can be generated in parepare_additional_conditioning step." ), InputParam( - "pooled_prompt_embeds", - required=True, - type_hint=torch.Tensor, + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." ), InputParam( - "negative_pooled_prompt_embeds", - type_hint=Optional[torch.Tensor], + "negative_pooled_prompt_embeds", + type_hint=Optional[torch.Tensor], description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step." ), InputParam( - "timestep_cond", - type_hint=Optional[torch.Tensor], + "timestep_cond", + type_hint=Optional[torch.Tensor], description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step" ), InputParam( - "mask", - type_hint=Optional[torch.Tensor], + "mask", + type_hint=Optional[torch.Tensor], description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "masked_image_latents", - type_hint=Optional[torch.Tensor], + "masked_image_latents", + type_hint=Optional[torch.Tensor], description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "noise", - type_hint=Optional[torch.Tensor], + "noise", + type_hint=Optional[torch.Tensor], description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step." ), InputParam( - "image_latents", - type_hint=Optional[torch.Tensor], + "image_latents", + type_hint=Optional[torch.Tensor], description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step." ), InputParam( - "crops_coords", - type_hint=Optional[Tuple[int]], + "crops_coords", + type_hint=Optional[Tuple[int]], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step." ), InputParam( - "ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. 
Can be generated in ip_adapter step." ), InputParam( - "negative_ip_adapter_embeds", - type_hint=Optional[torch.Tensor], + "negative_ip_adapter_embeds", + type_hint=Optional[torch.Tensor], description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step." ), ] @@ -2509,12 +2507,12 @@ def prepare_control_image( device, dtype, crops_coords=None, - ): + ): if crops_coords is not None: image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: image = components.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - + image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size @@ -2547,12 +2545,12 @@ def prepare_extra_step_kwargs(self, components, generator, eta): @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: - + data = self.get_block_state(state) self.check_inputs(pipeline, data) - + data.num_channels_unet = pipeline.unet.config.in_channels - + # (1) prepare controlnet inputs data.device = pipeline._execution_device data.height, data.width = data.latents.shape[-2:] @@ -2580,14 +2578,14 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.controlnet_conditioning_scale = [data.controlnet_conditioning_scale] * len(controlnet.nets) # (1.3) - # global_pool_conditions + # global_pool_conditions data.global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) # (1.4) - # guess_mode + # guess_mode data.guess_mode = data.guess_mode or data.global_pool_conditions # (1.5) @@ -2669,10 +2667,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if isinstance(data.controlnet_cond_scale, list): data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - + for batch in guider_data: pipeline.guider.prepare_models(pipeline.unet) - + # Prepare additional conditionings batch.added_cond_kwargs = { "text_embeds": batch.pooled_prompt_embeds, @@ -2680,7 +2678,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } if batch.ip_adapter_embeds is not None: batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - + # Prepare controlnet additional conditionings batch.controlnet_added_cond_kwargs = { "text_embeds": batch.pooled_prompt_embeds, @@ -2699,14 +2697,14 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: added_cond_kwargs=batch.controlnet_added_cond_kwargs, return_dict=False, ) - + batch.down_block_res_samples = data.down_block_res_samples batch.mid_block_res_sample = data.mid_block_res_sample - + if pipeline.guider.is_unconditional and data.guess_mode: batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) - + # Prepare for inpainting if data.num_channels_unet == 9: data.scaled_latents = torch.cat([data.scaled_latents, data.mask, data.masked_image_latents], dim=1) @@ -2723,19 +2721,19 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return_dict=False, )[0] pipeline.guider.cleanup_models(pipeline.unet) - + # Perform guidance data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) # Perform 
scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 data.latents = data.latents.to(data.latents_dtype) - + if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None: data.init_latents_proper = data.image_latents if i < len(data.timesteps) - 1: @@ -2748,7 +2746,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - + self.add_block_state(state, data) return pipeline, state @@ -2756,7 +2754,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLControlNetUnionDenoiseStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -2764,14 +2762,14 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("controlnet", ControlNetUnionModel), ComponentSpec("scheduler", EulerDiscreteScheduler), ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config"), ComponentSpec( - "control_image_processor", - VaeImageProcessor, - config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), default_creation_method="from_config"), ] @@ -2797,31 +2795,31 @@ def inputs(self) -> List[Tuple[str, Any]]: def intermediates_inputs(self) -> List[str]: return [ InputParam( - "latents", + "latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step." ), InputParam( - "batch_size", + "batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step." ), InputParam( - "timesteps", + "timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "num_inference_steps", + "num_inference_steps", required=True, type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step." ), InputParam( - "prompt_embeds", + "prompt_embeds", required=True, type_hint=torch.Tensor, description="The prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." @@ -2832,7 +2830,7 @@ def intermediates_inputs(self) -> List[str]: description="The negative prompt embeddings used to condition the denoising process. Can be generated in text_encoder step. See: https://github.com/huggingface/diffusers/issues/4208" ), InputParam( - "add_time_ids", + "add_time_ids", required=True, type_hint=torch.Tensor, description="The time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step." 
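# Illustrative sketch (annotation, not part of the patch): the recurring ComponentSpec pattern in the
# denoise blocks above, where lightweight helpers such as the guider are declared with a frozen default
# config and created via from_config instead of being downloaded. Import paths mirror the ones used in
# the patched file; the standalone spec below only restates that pattern.
from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.pipelines.modular_pipeline import ComponentSpec

guider_spec = ComponentSpec(
    "guider",
    ClassifierFreeGuidance,
    config=FrozenDict({"guidance_scale": 7.5}),
    default_creation_method="from_config",
)
guider = guider_spec.create()   # routes through create_from_config: from_config for ConfigMixin subclasses,
                                # otherwise kwargs are filtered against the class __init__ signature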
@@ -2843,7 +2841,7 @@ def intermediates_inputs(self) -> List[str]: description="The negative time ids used to condition the denoising process. Can be generated in prepare_additional_conditioning step. " ), InputParam( - "pooled_prompt_embeds", + "pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="The pooled prompt embeddings used to condition the denoising process. Can be generated in text_encoder step." @@ -2918,7 +2916,7 @@ def check_inputs(self, pipeline, data): " `pipeline.unet` or your `mask_image` or `image` input." ) - + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image # 1. return image without apply any guidance # 2. add crops_coords and resize_mode to preprocess() @@ -2933,7 +2931,7 @@ def prepare_control_image( device, dtype, crops_coords=None, - ): + ): if crops_coords is not None: image = components.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32) else: @@ -2968,8 +2966,8 @@ def prepare_extra_step_kwargs(self, components, generator, eta): accepts_generator = "generator" in set(inspect.signature(components.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator - return extra_step_kwargs - + return extra_step_kwargs + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -2978,7 +2976,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.num_channels_unet = pipeline.unet.config.in_channels # (1) prepare controlnet inputs - data.device = pipeline._execution_device + data.device = pipeline._execution_device data.height, data.width = data.latents.shape[-2:] data.height = data.height * pipeline.vae_scale_factor data.width = data.width * pipeline.vae_scale_factor @@ -2998,7 +2996,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: data.guess_mode = data.guess_mode or data.global_pool_conditions # (1.3) - # control_type + # control_type data.num_control_type = controlnet.config.num_control_type # (1.4) @@ -3033,7 +3031,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: crops_coords=data.crops_coords, ) data.height, data.width = data.control_image[idx].shape[-2:] - + # (1.6) # controlnet_keep data.controlnet_keep = [] @@ -3080,10 +3078,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if isinstance(data.controlnet_cond_scale, list): data.controlnet_cond_scale = data.controlnet_cond_scale[0] data.cond_scale = data.controlnet_cond_scale * data.controlnet_keep[i] - + for batch in guider_data: pipeline.guider.prepare_models(pipeline.unet) - + # Prepare additional conditionings batch.added_cond_kwargs = { "text_embeds": batch.pooled_prompt_embeds, @@ -3091,7 +3089,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: } if batch.ip_adapter_embeds is not None: batch.added_cond_kwargs["image_embeds"] = batch.ip_adapter_embeds - + # Prepare controlnet additional conditionings batch.controlnet_added_cond_kwargs = { "text_embeds": batch.pooled_prompt_embeds, @@ -3112,10 +3110,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: added_cond_kwargs=batch.controlnet_added_cond_kwargs, return_dict=False, ) - + batch.down_block_res_samples = data.down_block_res_samples batch.mid_block_res_sample = data.mid_block_res_sample - + if pipeline.guider.is_unconditional and 
data.guess_mode: batch.down_block_res_samples = [torch.zeros_like(d) for d in data.down_block_res_samples] batch.mid_block_res_sample = torch.zeros_like(data.mid_block_res_sample) @@ -3135,14 +3133,14 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: return_dict=False, )[0] pipeline.guider.cleanup_models(pipeline.unet) - + # Perform guidance data.noise_pred, scheduler_step_kwargs = pipeline.guider(guider_data) # Perform scheduler step using the predicted output data.latents_dtype = data.latents.dtype data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, **scheduler_step_kwargs, return_dict=False)[0] - + if data.latents.dtype != data.latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 @@ -3159,7 +3157,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): progress_bar.update() - + self.add_block_state(state, data) return pipeline, state @@ -3168,15 +3166,15 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: class StableDiffusionXLDecodeLatentsStep(PipelineBlock): model_name = "stable-diffusion-xl" - + @property def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKL), ComponentSpec( - "image_processor", - VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 8}), + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), ] @@ -3282,10 +3280,10 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("image", required=True), - InputParam("mask_image", required=True), + InputParam("mask_image", required=True), InputParam("padding_mask_crop"), ] - + @property def intermediates_inputs(self) -> List[str]: return [ @@ -3318,17 +3316,17 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("return_dict", default=True)] + return [InputParam("return_dict", default=True)] @property def intermediates_inputs(self) -> List[str]: return [InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step.")] - + @property def intermediates_outputs(self) -> List[str]: return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")] - - + + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: data = self.get_block_state(state) @@ -3342,7 +3340,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: # Encode -class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] block_names = ["inpaint", "img2img"] block_trigger_inputs = ["mask_image", "image"] @@ -3494,7 +3492,7 @@ def description(self): # always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the # configuration of guider is. 
-# block mapping +# block mapping TEXT2IMAGE_BLOCKS = OrderedDict([ ("text_encoder", StableDiffusionXLTextEncoderStep), ("ip_adapter", StableDiffusionXLAutoIPAdapterStep), diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 5d0752af8983..ad30bfa2f32c 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -14,6 +14,7 @@ # limitations under the License. """Utilities to dynamically load objects from the Hub.""" +import hashlib import importlib import inspect import json @@ -21,8 +22,9 @@ import re import shutil import sys +import threading from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, ModuleType, Optional, Union from urllib import request from huggingface_hub import hf_hub_download, model_info @@ -37,6 +39,7 @@ # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror" +_HF_REMOTE_CODE_LOCK = threading.Lock() def get_diffusers_versions(): @@ -154,15 +157,132 @@ def check_imports(filename): return get_relative_imports(filename) -def get_class_in_module(class_name, module_path): +def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code): + if trust_remote_code is None: + if has_local_code: + trust_remote_code = False + elif has_remote_code and TIME_OUT_REMOTE_CODE > 0: + prev_sig_handler = None + try: + prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) + signal.alarm(TIME_OUT_REMOTE_CODE) + while trust_remote_code is None: + answer = input( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" + f"Do you wish to run the custom code? [y/N] " + ) + if answer.lower() in ["yes", "y", "1"]: + trust_remote_code = True + elif answer.lower() in ["no", "n", "0", ""]: + trust_remote_code = False + signal.alarm(0) + except Exception: + # OS which does not support signal.SIGALRM + raise ValueError( + f"The repository for {model_name} contains custom code which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + finally: + if prev_sig_handler is not None: + signal.signal(signal.SIGALRM, prev_sig_handler) + signal.alarm(0) + elif has_remote_code: + # For the CI which puts the timeout at 0 + _raise_timeout_error(None, None) + + if has_remote_code and not has_local_code and not trust_remote_code: + raise ValueError( + f"Loading {model_name} requires you to execute the configuration file in that" + " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" + " set the option `trust_remote_code=True` to remove this error." + ) + + return trust_remote_code + + +def get_class_in_modular_module( + class_name: str, + module_path: Union[str, os.PathLike], + *, + force_reload: bool = False, +) -> type: + """ + Import a module on the cache directory for modules and extract a class from it. + + Args: + class_name (`str`): The name of the class to import. + module_path (`str` or `os.PathLike`): The path to the module to import. 
+ force_reload (`bool`, *optional*, defaults to `False`): + Whether to reload the dynamic module from file if it already exists in `sys.modules`. + Otherwise, the module is only reloaded if the file has changed. + + Returns: + `typing.Type`: The class looked for. + """ + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + # Hash the module file and all its relative imports to check if we need to reload it + module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file))) + module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest() + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + # reload in both cases, unless the module is already imported and the hash hits + if getattr(module, "__transformers_module_hash__", "") != module_hash: + module_spec.loader.exec_module(module) + module.__transformers_module_hash__ = module_hash + + return getattr(module, class_name) + + +def get_class_in_module(class_name, module_path, force_reload=False): """ Import a module on the cache directory for modules and extract a class from it. """ - module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_file: Path = Path(HF_MODULES_CACHE) / module_path + + with _HF_REMOTE_CODE_LOCK: + if force_reload: + sys.modules.pop(name, None) + importlib.invalidate_caches() + cached_module: Optional[ModuleType] = sys.modules.get(name) + module_spec = importlib.util.spec_from_file_location(name, location=module_file) + + module: ModuleType + if cached_module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + else: + module = cached_module + + module_spec.loader.exec_module(module) if class_name is None: return find_pipeline_class(module) + return getattr(module, class_name) @@ -203,6 +323,7 @@ def get_cached_module_file( token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, + is_modular: bool = False, ): """ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached @@ -257,7 +378,7 @@ def get_cached_module_file( if os.path.isfile(module_file_or_url): resolved_module_file = module_file_or_url submodule = "local" - elif pretrained_model_name_or_path.count("/") == 0: + elif pretrained_model_name_or_path.count("/") == 0 and not is_modular: available_versions = get_diffusers_versions() # cut ".dev0" latest_version = "v" + ".".join(__version__.split(".")[:3]) @@ -297,6 +418,24 @@ def get_cached_module_file( except EnvironmentError: logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") raise + + elif is_modular: + try: + # Load from URL or cache if already cached + resolved_module_file = 
hf_hub_download( + pretrained_model_name_or_path, + module_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + ) + submodule = pretrained_model_name_or_path.replace("/", os.path.sep) + except EnvironmentError: + logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") + raise + + else: + try: + # Load from URL or cache if already cached @@ -381,6 +520,7 @@ def get_class_from_dynamic_module( token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, + is_modular: bool = False, **kwargs, ): """ @@ -453,5 +593,6 @@ def get_class_from_dynamic_module( token=token, revision=revision, local_files_only=local_files_only, + is_modular=is_modular, ) - return get_class_in_module(class_name, final_module.replace(".py", "")) + return get_class_in_module(class_name, final_module)
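
The core of the new `get_class_in_modular_module` / `get_class_in_module` logic is a hash-gated reload: the module is registered in `sys.modules` before execution and only re-executed when the file contents change. Below is a minimal, self-contained sketch of that pattern for reference; the function name, the `__module_hash__` attribute, and the example file are illustrative only and not part of this patch.

import hashlib
import importlib.util
import sys
from pathlib import Path
from types import ModuleType  # ModuleType is provided by the `types` module
from typing import Optional


def load_class_with_hash_gate(class_name: str, module_file: Path) -> type:
    """Import `module_file` and return `class_name`, re-executing the module only when the file changes."""
    name = module_file.stem
    # Hash the file so repeated calls can skip re-execution when nothing changed.
    module_hash = hashlib.sha256(module_file.read_bytes()).hexdigest()

    cached: Optional[ModuleType] = sys.modules.get(name)
    spec = importlib.util.spec_from_file_location(name, location=module_file)
    if cached is None:
        module = importlib.util.module_from_spec(spec)
        # Register before executing so imports inside the module can resolve it.
        sys.modules[name] = module
    else:
        module = cached

    # Re-execute only if the stored hash is stale or missing.
    if getattr(module, "__module_hash__", "") != module_hash:
        spec.loader.exec_module(module)
        module.__module_hash__ = module_hash

    return getattr(module, class_name)


# Hypothetical usage against a local file:
# MyBlock = load_class_with_hash_gate("MyBlock", Path("my_custom_block.py"))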
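
For completeness, a hypothetical call through the new `is_modular` branch of `get_cached_module_file` / `get_class_from_dynamic_module`: the repo id, module file, and class name below are placeholders, and the leading positional parameters are assumed to keep their existing names.

from diffusers.utils.dynamic_modules_utils import get_class_from_dynamic_module

# Placeholders: substitute a real Hub repo that ships a modular pipeline block file.
block_cls = get_class_from_dynamic_module(
    "your-namespace/your-modular-repo",  # hypothetical repo id
    module_file="modular_blocks.py",     # hypothetical module file in that repo
    class_name="MyCustomBlock",          # hypothetical class defined in that file
    is_modular=True,                     # routes resolution through hf_hub_download
)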