From 8d026f6a9eba0bc3905b945543f0e88a19c5f5cc Mon Sep 17 00:00:00 2001 From: Leonard Kugis Date: Fri, 31 Mar 2023 02:50:27 +0200 Subject: Predictor: Moved out of main script --- file-tagger.py | 54 +++++------------------------------------------------- 1 file changed, 5 insertions(+), 49 deletions(-) (limited to 'file-tagger.py') diff --git a/file-tagger.py b/file-tagger.py index ca4b8f5..70909c8 100644 --- a/file-tagger.py +++ b/file-tagger.py @@ -7,30 +7,8 @@ import logging import magic from tmsu import * from util import * - -MODEL_DIMENSIONS = 224 - -def predict_image(model, img, top): - from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions - logger = logging.getLogger(__name__) - #cv2.imshow("test", img) - #cv2.waitKey(0) - #cv2.destroyAllWindows() - array = np.expand_dims(img, axis=0) - array = preprocess_input(array) - predictions = model.predict(array) - classes = decode_predictions(predictions, top=top) - logger.debug("Predicted image classes: {}".format(classes[0])) - return set([(name, prob) for _, name, prob in classes[0]]) - -def predict_partial(tags, model, img, x, y, rot, top): - #cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)]) - #cv2.waitKey(0) - if rot is None: - tmp = img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)] - else: - tmp = cv2.rotate(img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], rot) - tags.update(predict_image(model, tmp, top)) +from predictor import * +from PIL import Image ''' Walk over all files for the given base directory and all subdirectories recursively. @@ -52,10 +30,8 @@ def walk(args): return if args["predict_images"]: - from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions - from tensorflow.keras.preprocessing import image - from tensorflow.keras.models import Model - model = ResNet50(weights="imagenet") + #predictor = Predictor(Predictor.BackendTorch(top=args["predict_images_top"])) + predictor = Predictor(Predictor.BackendTensorflow(top=args["predict_images_top"], detail=(not args["predict_images_skip_detail"]), detail_factor=args["predict_images_detail_factor"])) for i in range(args["index"], len(files)): file_path = files[i] @@ -77,27 +53,7 @@ def walk(args): img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if args["predict_images"]: logger.info("Predicting image tags ...") - tags_predict = set() - for _ in range(4): - logger.debug("Raw scan") - raw = cv2.resize(img.copy(), dsize=(MODEL_DIMENSIONS, MODEL_DIMENSIONS), interpolation=cv2.INTER_CUBIC) - tags_predict.update(predict_image(model, raw, args["predict_images_top"])) - img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) - if not args["predict_images_skip_detail"]: - pool = ThreadPool(max(1, os.cpu_count() - 2), 10000) - if img.shape[0] > img.shape[1]: - detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) - else: - detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) - for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)): - for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)): - pool.add_task(predict_partial, tags_predict, model, detail, x, y, None, args["predict_images_top"]) - pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_CLOCKWISE, args["predict_images_top"]) - pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_180, args["predict_images_top"]) - pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_COUNTERCLOCKWISE, args["predict_images_top"]) - pool.wait_completion() - tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)] - tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]]) + tags_predict = predictor.predict(img) logger.info("Predicted tags: {}".format(tags_predict)) tags.update(tags_predict) if args["gui_tag"]: -- cgit v1.2.1