diff options
author | arcticfaded <jbelt021@fiu.edu> | 2022-10-19 05:19:01 +0000 |
---|---|---|
committer | arcticfaded <jbelt021@fiu.edu> | 2022-10-19 05:19:01 +0000 |
commit | 0f0d6ab8e06898ce066251fc769fe14e77e98ced (patch) | |
tree | a8587d440fce92fc427128baee8aa645f63f687b /modules/api/api.py | |
parent | e7f4808505f7a6339927c32b9a0c01bc9134bdeb (diff) |
call sampler by name
Diffstat (limited to 'modules/api/api.py')
-rw-r--r-- | modules/api/api.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index ff9df0d1..5b0c934e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,6 +1,7 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images -from modules.sd_samplers import samplers_k_diffusion +from modules.sd_samplers import all_samplers +from modules.extras import run_pnginfo import modules.shared as shared import uvicorn from fastapi import Body, APIRouter, HTTPException @@ -10,7 +11,7 @@ import json import io import base64 -sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None) +sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") @@ -53,13 +54,13 @@ class Api: - def img2imgendoint(self): + def img2imgapi(self): raise NotImplementedError - def extrasendoint(self): + def extrasapi(self): raise NotImplementedError - def pnginfoendoint(self): + def pnginfoapi(self): raise NotImplementedError def launch(self, server_name, port): |