Cleaner support for loading different diffusion model types. · NeoWorldTeam/ComfyUI@5e1fced

2 files changed

lines changed

Original file line numberDiff line numberDiff line change

@@ -105,6 +105,9 @@ def detect_unet_config(state_dict, key_prefix):

105105

unet_config["audio_model"] = "dit1.0"

106106

return unet_config

107107
108+

if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:

109+

return None

110+
108111

unet_config = {

109112

"use_checkpoint": False,

110113

"image_size": 32,

@@ -239,6 +242,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):

239242
240243

def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):

241244

unet_config = detect_unet_config(state_dict, unet_key_prefix)

245+

if unet_config is None:

246+

return None

242247

model_config = model_config_from_unet_config(unet_config, state_dict)

243248

if model_config is None and use_base_if_no_match:

244249

return comfy.supported_models_base.BASE(unet_config)

Original file line numberDiff line numberDiff line change

@@ -546,21 +546,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

546546

def load_unet_state_dict(sd): #load unet in diffusers or regular format

547547
548548

#Allow loading unets from checkpoint files

549-

checkpoint = False

550549

diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)

551550

temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)

552551

if len(temp_sd) > 0:

553552

sd = temp_sd

554-

checkpoint = True

555553
556554

parameters = comfy.utils.calculate_parameters(sd)

557555

unet_dtype = model_management.unet_dtype(model_params=parameters)

558556

load_device = model_management.get_torch_device()

557+

model_config = model_detection.model_config_from_unet(sd, "")

559558
560-

if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade

561-

model_config = model_detection.model_config_from_unet(sd, "")

562-

if model_config is None:

563-

return None

559+

if model_config is not None:

564560

new_sd = sd

565561

elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3

566562

new_sd = model_detection.convert_diffusers_mmdit(sd, "")