From 16451ca573220e49f2eaaab97580b6b91287c8c4 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Fri, 28 Oct 2022 17:16:23 +0700 Subject: Learning rate sched syntax support for grad clipping --- modules/textual_inversion/textual_inversion.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7bad73a6..6b00c6a1 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -255,9 +255,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ititial_step = embedding.step or 0 if ititial_step > steps: return embedding, filename - + clip_grad_mode_value = clip_grad_mode == "value" clip_grad_mode_norm = clip_grad_mode == "norm" + clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm + if clip_grad_enabled: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) @@ -273,6 +276,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if shared.state.interrupted: break + if clip_grad_enabled: + clip_grad_sched.step(embedding.step) + with torch.autocast("cuda"): c = cond_model([entry.cond_text for entry in entries]) x = torch.stack([entry.latent for entry in entries]).to(devices.device) @@ -285,9 +291,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc loss.backward() if clip_grad_mode_value: - torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value) + torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate) elif clip_grad_mode_norm: - torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value) + torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate) optimizer.step() -- cgit v1.2.1