@@ -546,21 +546,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
546 | 546 | def load_unet_state_dict(sd): #load unet in diffusers or regular format |
547 | 547 | |
548 | 548 | #Allow loading unets from checkpoint files |
549 | | -checkpoint = False |
550 | 549 | diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) |
551 | 550 | temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) |
552 | 551 | if len(temp_sd) > 0: |
553 | 552 | sd = temp_sd |
554 | | -checkpoint = True |
555 | 553 | |
556 | 554 | parameters = comfy.utils.calculate_parameters(sd) |
557 | 555 | unet_dtype = model_management.unet_dtype(model_params=parameters) |
558 | 556 | load_device = model_management.get_torch_device() |
| 557 | +model_config = model_detection.model_config_from_unet(sd, "") |
559 | 558 | |
560 | | -if checkpoint or "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade |
561 | | -model_config = model_detection.model_config_from_unet(sd, "") |
562 | | -if model_config is None: |
563 | | -return None |
| 559 | +if model_config is not None: |
564 | 560 | new_sd = sd |
565 | 561 | elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3 |
566 | 562 | new_sd = model_detection.convert_diffusers_mmdit(sd, "") |
|