author    | AUTOMATIC <16777216c@gmail.com> | 2022-10-11 22:03:05 +0300
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-11 22:03:05 +0300
commit    | d6fcc6b87bc00fcdecea276fe5b7c7945f7a8b14 (patch)
tree      | e9968d4fffc6d755e7f437f5042c09ae5f7b3895 /modules/hypernetworks/hypernetwork.py
parent    | 12f4f4761bc6114e6cce6972409488e26c4c8ab4 (diff)
apply lr schedule to hypernets
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 19
1 file changed, 15 insertions, 4 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5608e799..470659df 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -14,6 +14,7 @@ import torch
 from torch import einsum
 from einops import rearrange, repeat
 
 import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
 
 class HypernetworkModule(torch.nn.Module):
@@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     for weight in weights:
         weight.requires_grad = True
 
-    optimizer = torch.optim.AdamW(weights, lr=learn_rate)
-
     losses = torch.zeros((32,))
 
     last_saved_file = "<none>"
@@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     if ititial_step > steps:
         return hypernetwork, filename
 
+    schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+    (learn_rate, end_step) = next(schedules)
+    print(f'Training at rate of {learn_rate} until step {end_step}')
+
+    optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, (x, text, cond) in pbar:
         hypernetwork.step = i + ititial_step
 
-        if hypernetwork.step > steps:
-            break
+        if hypernetwork.step > end_step:
+            try:
+                (learn_rate, end_step) = next(schedules)
+            except Exception:
+                break
+            tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+            for pg in optimizer.param_groups:
+                pg['lr'] = learn_rate
 
         if shared.state.interrupted:
             break
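For context: the patch consumes LearnSchedule (imported from modules/textual_inversion/learn_schedule.py) as a plain iterator that yields (learn_rate, end_step) pairs, advancing to the next pair once hypernetwork.step passes end_step and pushing the new rate into the optimizer's param groups. Below is a minimal, hypothetical sketch of such a schedule iterator, assuming a schedule string of the form "5e-3:100, 1e-3:1000, 1e-5" (rate applied until the given step, with a bare final rate running to the end). The class name LearnRateSchedule and the parsing details are illustrative only, not the repository's actual implementation.

# Hypothetical sketch of the schedule iterator the patch consumes; the real
# parser lives in modules/textual_inversion/learn_schedule.py and may differ.
class LearnRateSchedule:
    """Parse "5e-3:100, 1e-3:1000, 1e-5" into (rate, end_step) pairs.
    An entry without ":step" runs until the final training step."""

    def __init__(self, schedule, steps, initial_step=0):
        self.pairs = []
        for chunk in schedule.split(","):
            if ":" in chunk:
                rate, end = chunk.split(":")
                end = min(int(end), steps)
            else:
                rate, end = chunk, steps
            if end > initial_step:  # skip phases finished before a resumed run
                self.pairs.append((float(rate), end))

    def __iter__(self):
        return iter(self.pairs)


# Usage mirroring the patched training loop:
schedules = iter(LearnRateSchedule("5e-3:100, 1e-3:1000, 1e-5", steps=2000))
learn_rate, end_step = next(schedules)
print(f"Training at rate of {learn_rate} until step {end_step}")
# Once hypernetwork.step passes end_step, the loop pulls the next pair and
# writes the new rate into every optimizer param group, as the hunk above does;
# exhaustion raises StopIteration, which the patch's bare `except Exception`
# catches to end training.
try:
    learn_rate, end_step = next(schedules)
except StopIteration:
    pass

One design note: creating the AdamW optimizer is moved below the first next(schedules) call so that it is constructed with the first scheduled rate rather than the raw learn_rate string, and later phases only mutate param_groups['lr'] instead of rebuilding the optimizer, which preserves its moment estimates across rate changes.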