diff options
-rw-r--r-- | modules/devices.py | 44 | ||||
-rw-r--r-- | modules/rng_philox.py | 10 | ||||
-rw-r--r-- | modules/sd_samplers_common.py | 12 |
3 files changed, 50 insertions, 16 deletions
diff --git a/modules/devices.py b/modules/devices.py index b58776d8..00a00b18 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -71,14 +71,17 @@ def enable_tf32(): torch.backends.cudnn.allow_tf32 = True - errors.run(enable_tf32, "Enabling TF32") -cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None -dtype = torch.float16 -dtype_vae = torch.float16 -dtype_unet = torch.float16 +cpu: torch.device = torch.device("cpu") +device: torch.device = None +device_interrogate: torch.device = None +device_gfpgan: torch.device = None +device_esrgan: torch.device = None +device_codeformer: torch.device = None +dtype: torch.dtype = torch.float16 +dtype_vae: torch.dtype = torch.float16 +dtype_unet: torch.dtype = torch.float16 unet_needs_upcast = False @@ -94,6 +97,10 @@ nv_rng = None def randn(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" + from modules.shared import opts manual_seed(seed) @@ -107,7 +114,27 @@ def randn(seed, shape): return torch.randn(shape, device=device) +def randn_local(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" + + from modules.shared import opts + + if opts.randn_source == "NV": + rng = rng_philox.Generator(seed) + return torch.asarray(rng.randn(shape), device=device) + + local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device + local_generator = torch.Generator(local_device).manual_seed(int(seed)) + return torch.randn(shape, device=local_device, generator=local_generator).to(device) + + def randn_like(x): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + from modules.shared import opts if opts.randn_source == "NV": @@ -120,6 +147,10 @@ def randn_like(x): def randn_without_seed(shape): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + from modules.shared import opts if opts.randn_source == "NV": @@ -132,6 +163,7 @@ def randn_without_seed(shape): def manual_seed(seed): + """Set up a global random number generator using the specified seed.""" from modules.shared import opts if opts.randn_source == "NV": diff --git a/modules/rng_philox.py b/modules/rng_philox.py index b5c02483..5532cf9d 100644 --- a/modules/rng_philox.py +++ b/modules/rng_philox.py @@ -26,7 +26,7 @@ two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32) def uint32(x):
"""Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
- return np.moveaxis(x.view(np.uint32).reshape(-1, 2), 0, 1)
+ return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
def philox4_round(counter, key):
@@ -65,8 +65,8 @@ def philox4_32(counter, key, rounds=10): def box_muller(x, y):
"""Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
- u = x.astype(np.float32) * two_pow32_inv + two_pow32_inv / 2
- v = y.astype(np.float32) * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
+ u = x * two_pow32_inv + two_pow32_inv / 2
+ v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
s = np.sqrt(-2.0 * np.log(u))
@@ -93,7 +93,9 @@ class Generator: counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
self.offset += 1
- key = uint32(np.array([[self.seed] * n], dtype=np.uint64))
+ key = np.empty(n, dtype=np.uint64)
+ key.fill(self.seed)
+ key = uint32(key)
g = philox4_32(counter, key)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 763829f1..5deda761 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,10 +2,8 @@ from collections import namedtuple import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
-
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
-import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -85,11 +83,13 @@ class InterruptedException(BaseException): pass
-if opts.randn_source == "CPU":
+def replace_torchsde_browinan():
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
- generator = torch.Generator(devices.cpu).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ return devices.randn_local(seed, size).to(device=device, dtype=dtype)
torchsde._brownian.brownian_interval._randn = torchsde_randn
+
+
+replace_torchsde_browinan()
|