From 96d6ca4199e7c5eee8d451618de5161cea317c40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:25:25 +0300 Subject: manual fixes for ruff --- modules/codeformer/codeformer_arch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/codeformer/codeformer_arch.py') diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index 11dcc3ee..f1a7cf09 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -7,7 +7,7 @@ from torch import nn, Tensor import torch.nn.functional as F from typing import Optional, List -from modules.codeformer.vqgan_arch import * +from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock from basicsr.utils import get_root_logger from basicsr.utils.registry import ARCH_REGISTRY -- cgit v1.2.1 From f741a98baccae100fcfb40c017b5c35c5cba1b0c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:43:42 +0300 Subject: imports cleanup for ruff --- modules/codeformer/codeformer_arch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'modules/codeformer/codeformer_arch.py') diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index f1a7cf09..00c407de 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -1,14 +1,12 @@ # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py import math -import numpy as np import torch from torch import nn, Tensor import torch.nn.functional as F -from typing import Optional, List +from typing import Optional from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock -from basicsr.utils import get_root_logger from basicsr.utils.registry import ARCH_REGISTRY def calc_mean_std(feat, eps=1e-5): -- cgit v1.2.1 From 550256db1ce18778a9d56ff343d844c61b9f9b83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:19:16 +0300 Subject: ruff manual fixes --- modules/codeformer/codeformer_arch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/codeformer/codeformer_arch.py') diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index 00c407de..ff1c0b4b 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -161,10 +161,13 @@ class Fuse_sft_block(nn.Module): class CodeFormer(VQAutoEncoder): def __init__(self, dim_embd=512, n_head=8, n_layers=9, codebook_size=1024, latent_size=256, - connect_list=['32', '64', '128', '256'], - fix_modules=['quantize','generator']): + connect_list=None, + fix_modules=None): super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + connect_list = connect_list or ['32', '64', '128', '256'] + fix_modules = fix_modules or ['quantize', 'generator'] + if fix_modules is not None: for module in fix_modules: for param in getattr(self, module).parameters(): -- cgit v1.2.1 From 3ec7b705c78b7aca9569c92a419837352c7a4ec6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 21:21:32 +0300 Subject: suggestions and fixes from the PR --- modules/codeformer/codeformer_arch.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'modules/codeformer/codeformer_arch.py') diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py index ff1c0b4b..45c70f84 100644 --- a/modules/codeformer/codeformer_arch.py +++ b/modules/codeformer/codeformer_arch.py @@ -161,13 +161,10 @@ class Fuse_sft_block(nn.Module): class CodeFormer(VQAutoEncoder): def __init__(self, dim_embd=512, n_head=8, n_layers=9, codebook_size=1024, latent_size=256, - connect_list=None, - fix_modules=None): + connect_list=('32', '64', '128', '256'), + fix_modules=('quantize', 'generator')): super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) - connect_list = connect_list or ['32', '64', '128', '256'] - fix_modules = fix_modules or ['quantize', 'generator'] - if fix_modules is not None: for module in fix_modules: for param in getattr(self, module).parameters(): -- cgit v1.2.1