diff options
-rw-r--r-- | file-tagger.py | 60 | ||||
-rw-r--r-- | gui.py | 2 | ||||
-rw-r--r-- | util.py | 61 |
3 files changed, 109 insertions, 14 deletions
diff --git a/file-tagger.py b/file-tagger.py index 82aacf7..92eb023 100644 --- a/file-tagger.py +++ b/file-tagger.py @@ -8,6 +8,23 @@ import magic from tmsu import * from util import * +MODEL_DIMENSIONS = 224 + +def predict_image(model, img, top): + from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions + logger = logging.getLogger(__name__) + array = np.expand_dims(img, axis=0) + array = preprocess_input(array) + predictions = model.predict(array) + classes = decode_predictions(predictions, top=top) + logger.debug("Predicted image classes: {}".format(classes[0])) + return set([(name, prob) for _, name, prob in classes[0]]) + +def predict_partial(tags, model, img, x, y, top): + #cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)]) + #cv2.waitKey(0) + tags.update(predict_image(model, img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], top)) + ''' Walk over all files for the given base directory and all subdirectories recursively. @@ -50,24 +67,37 @@ def walk(args): if mime_type.split("/")[0] == "image": logger.debug("File is image") img = cv2.imread(file_path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = cv2.resize(img, dsize=(800, 800), interpolation=cv2.INTER_CUBIC) + #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if args["predict_images"]: logger.info("Predicting image tags ...") - array_pre = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_CUBIC) + tags_predict = set() for _ in range(4): - array = np.expand_dims(array_pre, axis=0) - array = preprocess_input(array) - predictions = model.predict(array) - classes = decode_predictions(predictions, top=args["predict_images_top"]) - logger.debug("Predicted image classes: {}".format(classes[0])) - tags.update([name for _, name, _ in classes[0]]) - array_pre = cv2.rotate(array_pre, cv2.ROTATE_90_CLOCKWISE) - logger.info("Predicted tags: {}".format(tags)) + logger.debug("Raw scan") + raw = cv2.resize(img.copy(), dsize=(MODEL_DIMENSIONS, MODEL_DIMENSIONS), interpolation=cv2.INTER_CUBIC) + tags_predict.update(predict_image(model, raw, args["predict_images_top"])) + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + if not args["predict_images_skip_detail"]: + pool = ThreadPool(max(1, os.cpu_count() - 2), 1000) + for _ in range(4): + if img.shape[0] > img.shape[1]: + detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) + else: + detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) + for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)): + for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)): + pool.add_task(predict_partial, tags_predict, model, detail.copy(), x, y, args["predict_images_top"]) + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + pool.wait_completion() + tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)] + tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]]) + logger.info("Predicted tags: {}".format(tags_predict)) + tags.update(tags_predict) if args["gui_tag"]: while(True): # For GUI inputs (rotate, ...) logger.debug("Showing image GUI ...") - ret = GuiImage(i, file_path, img, tags).loop() + img_show = image_resize(img, width=args["gui_image_length"]) if img.shape[1] > img.shape[0] else image_resize(img, height=args["gui_image_length"]) + img_show = cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB) + ret = GuiImage(i, file_path, img_show, tags).loop() tags = set(ret[1]).difference({''}) if ret[0] == GuiImage.RETURN_ROTATE_90_CLOCKWISE: img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) @@ -102,7 +132,10 @@ if __name__ == "__main__": parser.add_argument('-g', '--gui', nargs='?', const=1, default=False, type=bool, help='Show main GUI (default: %(default)s)') parser.add_argument('--predict-images', nargs='?', const=1, default=False, type=bool, help='Use prediction for image tagging (default: %(default)s)') parser.add_argument('--predict-images-top', nargs='?', const=1, default=10, type=int, help='Defines how many top prediction keywords should be used (default: %(default)s)') + parser.add_argument('--predict-images-detail-factor', nargs='?', const=1, default=2, type=int, help='Width factor for detail scan, multiplied by 224 for ResNet50 (default: %(default)s)') + parser.add_argument('--predict-images-skip-detail', nargs='?', const=1, default=False, type=bool, help='Skip detail scan in image prediction (default: %(default)s)') parser.add_argument('--gui-tag', nargs='?', const=1, default=False, type=bool, help='Show GUI for tagging (default: %(default)s)') + parser.add_argument('--gui-image-length', nargs='?', const=1, default=800, type=int, help='Length of longest side for preview (default: %(default)s)') parser.add_argument('--open-system', nargs='?', const=1, default=False, type=bool, help='Open all files with system default (default: %(default)s)') parser.add_argument('-s', '--skip-prompt', nargs='?', const=1, default=False, type=bool, help='Skip prompt for file tags (default: %(default)s)') parser.add_argument('-i', '--index', nargs='?', const=1, default=0, type=int, help='Start tagging at the given file index (default: %(default)s)') @@ -125,7 +158,10 @@ if __name__ == "__main__": "gui": args.gui, "predict_images": args.predict_images, "predict_images_top": args.predict_images_top, + "predict_images_detail_factor": args.predict_images_detail_factor, + "predict_images_skip_detail": args.predict_images_skip_detail, "gui_tag": args.gui_tag, + "gui_image_length": args.gui_image_length, "open_system": args.open_system, "skip_prompt": args.skip_prompt, "index": args.index, @@ -102,7 +102,7 @@ class GuiImage(object): self.__image = ImageTk.PhotoImage(image=self.__image_pil) Label(self.__master, text="Index: {}".format(index)).grid(row=0, column=0, columnspan=4) Label(self.__master, text="File: {}".format(file)).grid(row=1, column=0, columnspan=4) - self.__label = Label(self.__master, width=800, height=800, image=self.__image) + self.__label = Label(self.__master, width=img.shape[1], height=img.shape[0], image=self.__image) self.__label.grid(row=2, column=0, columnspan=4) Entry(self.__master, textvariable=self.__tags).grid(row=3, column=0, columnspan=4, sticky="we") Button(self.__master, text="↺", command=self.__handle_rotate_90_counterclockwise).grid(row=4, column=0) @@ -4,6 +4,9 @@ import cv2 import platform import readline import os +import numpy as np +from queue import Queue +from threading import Thread, Lock def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA): # initialize the dimensions of the image to be resized and @@ -36,6 +39,11 @@ def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA): # return the resized image return resized +def image_embed(img, dimensions): + ret = np.zeros((dimensions[0], dimensions[1], 3), np.uint8) + ret[0:img.shape[0], 0:img.shape[1]] = img + return ret + ''' Fetch input prompt with prefilled text. @@ -76,4 +84,55 @@ def open_system(file): elif platform.system() == 'Windows': # Windows os.startfile(file) else: # linux variants - subprocess.call(('xdg-open', file))
\ No newline at end of file + subprocess.call(('xdg-open', file)) + +class Worker(Thread): + def __init__(self, tasks): + Thread.__init__(self) + self.tasks = tasks + self.daemon = True + self.lock = Lock() + self.start() + + def run(self): + while True: + func, args, kargs = self.tasks.get() + try: + if func.lower() == "terminate": + break + except: + try: + with self.lock: + func(*args, **kargs) + except Exception as exception: + print(exception) + self.tasks.task_done() + +class ThreadPool: + def __init__(self, num_threads, num_queue=None): + if num_queue is None or num_queue < num_threads: + num_queue = num_threads + self.tasks = Queue(num_queue) + self.threads = num_threads + for _ in range(num_threads): Worker(self.tasks) + + # This function can be called to terminate all the worker threads of the queue + def terminate(self): + self.wait_completion() + for _ in range(self.threads): self.add_task("terminate") + return None + + # This function can be called to add new work to the queue + def add_task(self, func, *args, **kargs): + self.tasks.put((func, args, kargs)) + + # This function can be called to wait till all the workers are done processing the pending works. If this function is called, the main will not process any new lines unless all the workers are done with the pending works. + def wait_completion(self): + self.tasks.join() + + # This function can be called to check if there are any pending/running works in the queue. If there are any works pending, the call will return Boolean True or else it will return Boolean False + def is_alive(self): + if self.tasks.unfinished_tasks == 0: + return False + else: + return True
\ No newline at end of file |