aboutsummaryrefslogtreecommitdiff
path: root/ldm/data/base.py
diff options
context:
space:
mode:
authorzhaohu xing <920232796@qq.com>2022-11-29 10:28:41 +0800
committerzhaohu xing <920232796@qq.com>2022-11-29 10:28:41 +0800
commit75c4511e6b81ae8fb0dbd932043e8eb35cd09f72 (patch)
tree6f4662507be1d532a4e992f54f82d905fc450f3a /ldm/data/base.py
parent828438b4a190759807f9054932cae3a8b880ddf1 (diff)
add AltDiffusion to webui
Signed-off-by: zhaohu xing <920232796@qq.com>
Diffstat (limited to 'ldm/data/base.py')
-rw-r--r--ldm/data/base.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/ldm/data/base.py b/ldm/data/base.py
new file mode 100644
index 00000000..b196c2f7
--- /dev/null
+++ b/ldm/data/base.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod
+from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+ '''
+ Define an interface to make the IterableDatasets for text2img data chainable
+ '''
+ def __init__(self, num_records=0, valid_ids=None, size=256):
+ super().__init__()
+ self.num_records = num_records
+ self.valid_ids = valid_ids
+ self.sample_ids = valid_ids
+ self.size = size
+
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+ def __len__(self):
+ return self.num_records
+
+ @abstractmethod
+ def __iter__(self):
+ pass \ No newline at end of file