aboutsummaryrefslogtreecommitdiff
path: root/file-tagger.py
diff options
context:
space:
mode:
Diffstat (limited to 'file-tagger.py')
-rw-r--r--file-tagger.py60
1 files changed, 48 insertions, 12 deletions
diff --git a/file-tagger.py b/file-tagger.py
index 82aacf7..92eb023 100644
--- a/file-tagger.py
+++ b/file-tagger.py
@@ -8,6 +8,23 @@ import magic
from tmsu import *
from util import *
+MODEL_DIMENSIONS = 224
+
+def predict_image(model, img, top):
+ from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
+ logger = logging.getLogger(__name__)
+ array = np.expand_dims(img, axis=0)
+ array = preprocess_input(array)
+ predictions = model.predict(array)
+ classes = decode_predictions(predictions, top=top)
+ logger.debug("Predicted image classes: {}".format(classes[0]))
+ return set([(name, prob) for _, name, prob in classes[0]])
+
+def predict_partial(tags, model, img, x, y, top):
+ #cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)])
+ #cv2.waitKey(0)
+ tags.update(predict_image(model, img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], top))
+
'''
Walk over all files for the given base directory and all subdirectories recursively.
@@ -50,24 +67,37 @@ def walk(args):
if mime_type.split("/")[0] == "image":
logger.debug("File is image")
img = cv2.imread(file_path)
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- img = cv2.resize(img, dsize=(800, 800), interpolation=cv2.INTER_CUBIC)
+ #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if args["predict_images"]:
logger.info("Predicting image tags ...")
- array_pre = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
+ tags_predict = set()
for _ in range(4):
- array = np.expand_dims(array_pre, axis=0)
- array = preprocess_input(array)
- predictions = model.predict(array)
- classes = decode_predictions(predictions, top=args["predict_images_top"])
- logger.debug("Predicted image classes: {}".format(classes[0]))
- tags.update([name for _, name, _ in classes[0]])
- array_pre = cv2.rotate(array_pre, cv2.ROTATE_90_CLOCKWISE)
- logger.info("Predicted tags: {}".format(tags))
+ logger.debug("Raw scan")
+ raw = cv2.resize(img.copy(), dsize=(MODEL_DIMENSIONS, MODEL_DIMENSIONS), interpolation=cv2.INTER_CUBIC)
+ tags_predict.update(predict_image(model, raw, args["predict_images_top"]))
+ img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+ if not args["predict_images_skip_detail"]:
+ pool = ThreadPool(max(1, os.cpu_count() - 2), 1000)
+ for _ in range(4):
+ if img.shape[0] > img.shape[1]:
+ detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
+ else:
+ detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
+ for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)):
+ for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)):
+ pool.add_task(predict_partial, tags_predict, model, detail.copy(), x, y, args["predict_images_top"])
+ img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+ pool.wait_completion()
+ tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)]
+ tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]])
+ logger.info("Predicted tags: {}".format(tags_predict))
+ tags.update(tags_predict)
if args["gui_tag"]:
while(True): # For GUI inputs (rotate, ...)
logger.debug("Showing image GUI ...")
- ret = GuiImage(i, file_path, img, tags).loop()
+ img_show = image_resize(img, width=args["gui_image_length"]) if img.shape[1] > img.shape[0] else image_resize(img, height=args["gui_image_length"])
+ img_show = cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB)
+ ret = GuiImage(i, file_path, img_show, tags).loop()
tags = set(ret[1]).difference({''})
if ret[0] == GuiImage.RETURN_ROTATE_90_CLOCKWISE:
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
@@ -102,7 +132,10 @@ if __name__ == "__main__":
parser.add_argument('-g', '--gui', nargs='?', const=1, default=False, type=bool, help='Show main GUI (default: %(default)s)')
parser.add_argument('--predict-images', nargs='?', const=1, default=False, type=bool, help='Use prediction for image tagging (default: %(default)s)')
parser.add_argument('--predict-images-top', nargs='?', const=1, default=10, type=int, help='Defines how many top prediction keywords should be used (default: %(default)s)')
+ parser.add_argument('--predict-images-detail-factor', nargs='?', const=1, default=2, type=int, help='Width factor for detail scan, multiplied by 224 for ResNet50 (default: %(default)s)')
+ parser.add_argument('--predict-images-skip-detail', nargs='?', const=1, default=False, type=bool, help='Skip detail scan in image prediction (default: %(default)s)')
parser.add_argument('--gui-tag', nargs='?', const=1, default=False, type=bool, help='Show GUI for tagging (default: %(default)s)')
+ parser.add_argument('--gui-image-length', nargs='?', const=1, default=800, type=int, help='Length of longest side for preview (default: %(default)s)')
parser.add_argument('--open-system', nargs='?', const=1, default=False, type=bool, help='Open all files with system default (default: %(default)s)')
parser.add_argument('-s', '--skip-prompt', nargs='?', const=1, default=False, type=bool, help='Skip prompt for file tags (default: %(default)s)')
parser.add_argument('-i', '--index', nargs='?', const=1, default=0, type=int, help='Start tagging at the given file index (default: %(default)s)')
@@ -125,7 +158,10 @@ if __name__ == "__main__":
"gui": args.gui,
"predict_images": args.predict_images,
"predict_images_top": args.predict_images_top,
+ "predict_images_detail_factor": args.predict_images_detail_factor,
+ "predict_images_skip_detail": args.predict_images_skip_detail,
"gui_tag": args.gui_tag,
+ "gui_image_length": args.gui_image_length,
"open_system": args.open_system,
"skip_prompt": args.skip_prompt,
"index": args.index,