diff options
Diffstat (limited to 'file-tagger.py')
-rw-r--r-- | file-tagger.py | 60 |
1 files changed, 48 insertions, 12 deletions
diff --git a/file-tagger.py b/file-tagger.py index 82aacf7..92eb023 100644 --- a/file-tagger.py +++ b/file-tagger.py @@ -8,6 +8,23 @@ import magic from tmsu import * from util import * +MODEL_DIMENSIONS = 224 + +def predict_image(model, img, top): + from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions + logger = logging.getLogger(__name__) + array = np.expand_dims(img, axis=0) + array = preprocess_input(array) + predictions = model.predict(array) + classes = decode_predictions(predictions, top=top) + logger.debug("Predicted image classes: {}".format(classes[0])) + return set([(name, prob) for _, name, prob in classes[0]]) + +def predict_partial(tags, model, img, x, y, top): + #cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)]) + #cv2.waitKey(0) + tags.update(predict_image(model, img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], top)) + ''' Walk over all files for the given base directory and all subdirectories recursively. @@ -50,24 +67,37 @@ def walk(args): if mime_type.split("/")[0] == "image": logger.debug("File is image") img = cv2.imread(file_path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = cv2.resize(img, dsize=(800, 800), interpolation=cv2.INTER_CUBIC) + #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if args["predict_images"]: logger.info("Predicting image tags ...") - array_pre = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_CUBIC) + tags_predict = set() for _ in range(4): - array = np.expand_dims(array_pre, axis=0) - array = preprocess_input(array) - predictions = model.predict(array) - classes = decode_predictions(predictions, top=args["predict_images_top"]) - logger.debug("Predicted image classes: {}".format(classes[0])) - tags.update([name for _, name, _ in classes[0]]) - array_pre = cv2.rotate(array_pre, cv2.ROTATE_90_CLOCKWISE) - logger.info("Predicted tags: {}".format(tags)) + logger.debug("Raw scan") + raw = cv2.resize(img.copy(), dsize=(MODEL_DIMENSIONS, MODEL_DIMENSIONS), interpolation=cv2.INTER_CUBIC) + tags_predict.update(predict_image(model, raw, args["predict_images_top"])) + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + if not args["predict_images_skip_detail"]: + pool = ThreadPool(max(1, os.cpu_count() - 2), 1000) + for _ in range(4): + if img.shape[0] > img.shape[1]: + detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) + else: + detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) + for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)): + for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)): + pool.add_task(predict_partial, tags_predict, model, detail.copy(), x, y, args["predict_images_top"]) + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + pool.wait_completion() + tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)] + tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]]) + logger.info("Predicted tags: {}".format(tags_predict)) + tags.update(tags_predict) if args["gui_tag"]: while(True): # For GUI inputs (rotate, ...) logger.debug("Showing image GUI ...") - ret = GuiImage(i, file_path, img, tags).loop() + img_show = image_resize(img, width=args["gui_image_length"]) if img.shape[1] > img.shape[0] else image_resize(img, height=args["gui_image_length"]) + img_show = cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB) + ret = GuiImage(i, file_path, img_show, tags).loop() tags = set(ret[1]).difference({''}) if ret[0] == GuiImage.RETURN_ROTATE_90_CLOCKWISE: img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) @@ -102,7 +132,10 @@ if __name__ == "__main__": parser.add_argument('-g', '--gui', nargs='?', const=1, default=False, type=bool, help='Show main GUI (default: %(default)s)') parser.add_argument('--predict-images', nargs='?', const=1, default=False, type=bool, help='Use prediction for image tagging (default: %(default)s)') parser.add_argument('--predict-images-top', nargs='?', const=1, default=10, type=int, help='Defines how many top prediction keywords should be used (default: %(default)s)') + parser.add_argument('--predict-images-detail-factor', nargs='?', const=1, default=2, type=int, help='Width factor for detail scan, multiplied by 224 for ResNet50 (default: %(default)s)') + parser.add_argument('--predict-images-skip-detail', nargs='?', const=1, default=False, type=bool, help='Skip detail scan in image prediction (default: %(default)s)') parser.add_argument('--gui-tag', nargs='?', const=1, default=False, type=bool, help='Show GUI for tagging (default: %(default)s)') + parser.add_argument('--gui-image-length', nargs='?', const=1, default=800, type=int, help='Length of longest side for preview (default: %(default)s)') parser.add_argument('--open-system', nargs='?', const=1, default=False, type=bool, help='Open all files with system default (default: %(default)s)') parser.add_argument('-s', '--skip-prompt', nargs='?', const=1, default=False, type=bool, help='Skip prompt for file tags (default: %(default)s)') parser.add_argument('-i', '--index', nargs='?', const=1, default=0, type=int, help='Start tagging at the given file index (default: %(default)s)') @@ -125,7 +158,10 @@ if __name__ == "__main__": "gui": args.gui, "predict_images": args.predict_images, "predict_images_top": args.predict_images_top, + "predict_images_detail_factor": args.predict_images_detail_factor, + "predict_images_skip_detail": args.predict_images_skip_detail, "gui_tag": args.gui_tag, + "gui_image_length": args.gui_image_length, "open_system": args.open_system, "skip_prompt": args.skip_prompt, "index": args.index, |