From 03ee297aa22296ea12b965fc1cb11aa46375d372 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 27 Nov 2023 17:26:16 +0900 Subject: fix Auto focal point crop for opencv >= 4.8.x autocrop.download_and_cache_models in opencv >= 4.8 the face detection model was updated download the base on opencv version returns the model path or raise exception --- modules/textual_inversion/autocrop.py | 29 ++++++++++++++++------------- modules/textual_inversion/preprocess.py | 4 ++-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 1675e39a..051be118 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -3,6 +3,8 @@ import requests import os import numpy as np from PIL import ImageDraw +from modules import paths_internal +from pkg_resources import parse_version GREEN = "#0F0" BLUE = "#00F" @@ -294,22 +296,23 @@ def is_square(w, h): return w == h -def download_and_cache_models(dirname): - download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - model_file_name = 'face_detection_yunet.onnx' +model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv') +if parse_version(cv2.__version__) >= parse_version('4.8'): + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true' +else: + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - os.makedirs(dirname, exist_ok=True) - cache_file = os.path.join(dirname, model_file_name) - if not os.path.exists(cache_file): - print(f"downloading face detection model from '{download_url}' to '{cache_file}'") - response = requests.get(download_url) - with open(cache_file, "wb") as f: +def download_and_cache_models(): + if not os.path.exists(model_file_path): + os.makedirs(model_dir_opencv, exist_ok=True) + print(f"downloading face detection model from '{model_url}' to '{model_file_path}'") + response = requests.get(model_url) + with open(model_file_path, "wb") as f: f.write(response.content) - - if os.path.exists(cache_file): - return cache_file - return None + return model_file_path class PointOfInterest: diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index dbd856bd..789fa083 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -3,7 +3,7 @@ from PIL import Image, ImageOps import math import tqdm -from modules import paths, shared, images, deepbooru +from modules import shared, images, deepbooru from modules.textual_inversion import autocrop @@ -196,7 +196,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre dnn_model_path = None try: - dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) + dnn_model_path = autocrop.download_and_cache_models() except Exception as e: print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) -- cgit v1.2.1