diff options
-rw-r--r-- | modules/modelloader.py | 3 | ||||
-rw-r--r-- | modules/sd_disable_initialization.py | 17 | ||||
-rw-r--r-- | modules/sd_models.py | 6 |
3 files changed, 18 insertions, 8 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index e9aa514e..fc3f6249 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -45,6 +45,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None full_path = file if os.path.isdir(full_path): continue + if os.path.islink(full_path) and not os.path.exists(full_path): + print(f"Skipping broken symlink: {full_path}") + continue if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]): continue if len(ext_filter) != 0: diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index e90aa9fe..c4a09d15 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,8 +20,9 @@ class DisableInitialization: ```
"""
- def __init__(self):
+ def __init__(self, disable_clip=True):
self.replaced = []
+ self.disable_clip = disable_clip
def replace(self, obj, field, func):
original = getattr(obj, field, None)
@@ -75,12 +76,14 @@ class DisableInitialization: self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
- self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
- self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
- self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
- self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
- self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
- self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
+
+ if self.disable_clip:
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced:
diff --git a/modules/sd_models.py b/modules/sd_models.py index af1731e5..d847d358 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -354,6 +354,9 @@ def repair_config(sd_config): sd_config.model.params.unet_config.params.use_fp16 = True
+sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
+sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -374,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+ clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
timer.record("find config")
@@ -386,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ sd_model = None
try:
- with sd_disable_initialization.DisableInitialization():
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
pass
|