Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Add ‘config_name' as a supplement to the 'model_setting' #2027

Merged
merged 4 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions demo/mmagic_inference_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def parse_args():
type=int,
default=None,
help='Pretrained mmagic algorithm setting')
parser.add_argument(
'--config-name',
type=str,
default=None,
help='Pretrained mmagic algorithm config name')
parser.add_argument(
'--model-config',
type=str,
Expand Down
6 changes: 5 additions & 1 deletion demo/mmagic_inference_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,13 @@
"\n",
"There are some different configs and checkpoints for one model.\n",
"\n",
"You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0.\n",
"\n",
"Take conditional GAN model 'biggan' as an example. We have pretrained model for Cifar and Imagenet, and all pretrained models of 'biggan' are listed in its [metafile.yaml](../configs/biggan/metafile.yml)\n",
"\n",
"You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0."
"There are six settings in this metafile. If you choose setting 1, then the config 'configs/biggan/biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128.py' will be used. If 'model_setting' is not passed to 'MMagicInferencer', the config ‘configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py’ will be used by default.\n",
"\n",
"And you could also use 'config_name' to replace 'model_setting'. For example, you can init a MMagicInferencer with 'MMagicInferencer('biggan', config_name='biggan_2xb25-500kiters_cifar10-32x32')', which is the same with 'MMagicInferencer('biggan', model_setting=0)'."
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer):
postprocess=[])

def preprocess(self,
text: InputsType,
text: InputsType = None,
negative_prompt: InputsType = None,
num_inference_steps: int = 20,
height=None,
Expand All @@ -37,7 +37,8 @@ def preprocess(self,
result(Dict): Results of preprocess.
"""
result = self.extra_parameters
result['prompt'] = text
if text:
result['prompt'] = text
if negative_prompt:
result['negative_prompt'] = negative_prompt
if num_inference_steps:
Expand Down
9 changes: 8 additions & 1 deletion mmagic/apis/mmagic_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class MMagicInferencer:
def __init__(self,
model_name: str = None,
model_setting: int = None,
config_name: int = None,
model_config: str = None,
model_ckpt: str = None,
device: torch.device = None,
Expand All @@ -140,14 +141,15 @@ def __init__(self,
MMagicInferencer.init_inference_supported_models_cfg()
inferencer_kwargs = {}
inferencer_kwargs.update(
self._get_inferencer_kwargs(model_name, model_setting,
self._get_inferencer_kwargs(model_name, model_setting, config_name,
model_config, model_ckpt,
extra_parameters))
self.inferencer = Inferencers(
device=device, seed=seed, **inferencer_kwargs)

def _get_inferencer_kwargs(self, model_name: Optional[str],
model_setting: Optional[int],
config_name: Optional[int],
model_config: Optional[str],
model_ckpt: Optional[str],
extra_parameters: Optional[Dict]) -> Dict:
Expand All @@ -161,6 +163,11 @@ def _get_inferencer_kwargs(self, model_name: Optional[str],
if model_setting:
setting_to_use = model_setting
config_dir = cfgs['settings'][setting_to_use]['Config']
if config_name:
for setting in cfgs['settings']:
if setting['Name'] == config_name:
config_dir = setting['Config']
break
config_dir = config_dir[config_dir.find('configs'):]
if osp.exists(
osp.join(osp.dirname(__file__), '..', '..', config_dir)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import platform

import pytest
import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

Expand All @@ -21,24 +20,11 @@
def test_diffusers_pipeline_inferencer():
cfg = dict(
model=dict(
type='DiffusionPipeline',
from_pretrained='runwayml/stable-diffusion-v1-5'))
type='DiffusionPipeline', from_pretrained='google/ddpm-cat-256'))

inferencer_instance = DiffusersPipelineInferencer(cfg, None)

def mock_infer(*args, **kwargs):
return dict(samples=torch.randn(1, 3, 64, 64))

inferencer_instance.model.infer = mock_infer

text_prompts = 'Japanese anime style, girl'
negative_prompt = 'bad face, bad hands'
result = inferencer_instance(
text=text_prompts,
negative_prompt=negative_prompt,
height=64,
width=64)
assert result[1][0].size == (64, 64)
result = inferencer_instance()
assert result[1][0].size == (256, 256)


def teardown_module():
Expand Down