From d6fcc6b87bc00fcdecea276fe5b7c7945f7a8b14 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 22:03:05 +0300 Subject: apply lr schedule to hypernets --- modules/textual_inversion/textual_inversion.py | 44 +++----------------------- 1 file changed, 4 insertions(+), 40 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 47a27faf..7717837d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -10,6 +10,7 @@ import datetime from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +from modules.textual_inversion.learn_schedule import LearnSchedule class Embedding: @@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) - epoch_len = (tr_img_len * num_repeats) + tr_img_len - - scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(scheduleIter) + schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) + (learn_rate, end_step) = next(schedules) print(f'Training at rate of {learn_rate} until step {end_step}') optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) @@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > end_step: try: - (learn_rate, end_step) = next(scheduleIter) + (learn_rate, end_step) = next(schedules) except: break tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') @@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}
embedding.save(filename) return embedding, filename - -class LearnSchedule: - def __init__(self, learn_rate, max_steps, cur_step=0): - pairs = learn_rate.split(',') - self.rates = [] - self.it = 0 - self.maxit = 0 - for i, pair in enumerate(pairs): - tmp = pair.split(':') - if len(tmp) == 2: - step = int(tmp[1]) - if step > cur_step: - self.rates.append((float(tmp[0]), min(step, max_steps))) - self.maxit += 1 - if step > max_steps: - return - elif step == -1: - self.rates.append((float(tmp[0]), max_steps)) - self.maxit += 1 - return - else: - self.rates.append((float(tmp[0]), max_steps)) - self.maxit += 1 - return - - def __iter__(self): - return self - - def __next__(self): - if self.it < self.maxit: - self.it += 1 - return self.rates[self.it - 1] - else: - raise StopIteration -- cgit v1.2.1