aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--file-tagger.py30
1 files changed, 19 insertions, 11 deletions
diff --git a/file-tagger.py b/file-tagger.py
index a8bf3ef..ca4b8f5 100644
--- a/file-tagger.py
+++ b/file-tagger.py
@@ -13,6 +13,9 @@ MODEL_DIMENSIONS = 224
def predict_image(model, img, top):
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
logger = logging.getLogger(__name__)
+ #cv2.imshow("test", img)
+ #cv2.waitKey(0)
+ #cv2.destroyAllWindows()
array = np.expand_dims(img, axis=0)
array = preprocess_input(array)
predictions = model.predict(array)
@@ -20,10 +23,14 @@ def predict_image(model, img, top):
logger.debug("Predicted image classes: {}".format(classes[0]))
return set([(name, prob) for _, name, prob in classes[0]])
-def predict_partial(tags, model, img, x, y, top):
+def predict_partial(tags, model, img, x, y, rot, top):
#cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)])
#cv2.waitKey(0)
- tags.update(predict_image(model, img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], top))
+ if rot is None:
+ tmp = img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)]
+ else:
+ tmp = cv2.rotate(img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], rot)
+ tags.update(predict_image(model, tmp, top))
'''
Walk over all files for the given base directory and all subdirectories recursively.
@@ -78,15 +85,16 @@ def walk(args):
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
if not args["predict_images_skip_detail"]:
pool = ThreadPool(max(1, os.cpu_count() - 2), 10000)
- for _ in range(4):
- if img.shape[0] > img.shape[1]:
- detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
- else:
- detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
- for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)):
- for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)):
- pool.add_task(predict_partial, tags_predict, model, detail, x, y, args["predict_images_top"])
- img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+ if img.shape[0] > img.shape[1]:
+ detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
+ else:
+ detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
+ for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)):
+ for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)):
+ pool.add_task(predict_partial, tags_predict, model, detail, x, y, None, args["predict_images_top"])
+ pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_CLOCKWISE, args["predict_images_top"])
+ pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_180, args["predict_images_top"])
+ pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_COUNTERCLOCKWISE, args["predict_images_top"])
pool.wait_completion()
tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)]
tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]])