summaryrefslogtreecommitdiff
path: root/nsfw_detect.py
diff options
context:
space:
mode:
Diffstat (limited to 'nsfw_detect.py')
-rwxr-xr-xnsfw_detect.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/nsfw_detect.py b/nsfw_detect.py
new file mode 100755
index 0000000..fc9a7dd
--- /dev/null
+++ b/nsfw_detect.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+
+import numpy as np
+import os
+import sys
+from io import BytesIO
+from subprocess import run, PIPE, DEVNULL
+
+os.environ["GLOG_minloglevel"] = "2" # seriously :|
+import caffe
+
+class NSFWDetector:
+ def __init__(self):
+
+ npath = os.path.join(os.path.dirname(__file__), "nsfw_model")
+ self.nsfw_net = caffe.Net(os.path.join(npath, "deploy.prototxt"),
+ os.path.join(npath, "resnet_50_1by2_nsfw.caffemodel"),
+ caffe.TEST)
+ self.caffe_transformer = caffe.io.Transformer({'data': self.nsfw_net.blobs['data'].data.shape})
+ self.caffe_transformer.set_transpose('data', (2, 0, 1)) # move image channels to outermost
+ self.caffe_transformer.set_mean('data', np.array([104, 117, 123])) # subtract the dataset-mean value in each channel
+ self.caffe_transformer.set_raw_scale('data', 255) # rescale from [0, 1] to [0, 255]
+ self.caffe_transformer.set_channel_swap('data', (2, 1, 0)) # swap channels from RGB to BGR
+
+ def _compute(self, img):
+ image = caffe.io.load_image(BytesIO(img))
+
+ H, W, _ = image.shape
+ _, _, h, w = self.nsfw_net.blobs["data"].data.shape
+ h_off = int(max((H - h) / 2, 0))
+ w_off = int(max((W - w) / 2, 0))
+ crop = image[h_off:h_off + h, w_off:w_off + w, :]
+
+ transformed_image = self.caffe_transformer.preprocess('data', crop)
+ transformed_image.shape = (1,) + transformed_image.shape
+
+ input_name = self.nsfw_net.inputs[0]
+ output_layers = ["prob"]
+ all_outputs = self.nsfw_net.forward_all(blobs=output_layers,
+ **{input_name: transformed_image})
+
+ outputs = all_outputs[output_layers[0]][0].astype(float)
+
+ return outputs
+
+ def detect(self, fpath):
+ try:
+ ff = run(["ffmpegthumbnailer", "-m", "-o-", "-s256", "-t50%", "-a", "-cpng", "-i", fpath], stdout=PIPE, stderr=DEVNULL, check=True)
+ image_data = ff.stdout
+ except:
+ return -1.0
+
+ scores = self._compute(image_data)
+
+ return scores[1]
+
+if __name__ == "__main__":
+ n = NSFWDetector()
+
+ for inf in sys.argv[1:]:
+ score = n.detect(inf)
+ print(inf, score)