Page MenuHomePhabricator
Paste P7468

Run image quality TF model on spark

Authored by EBernhardson on Aug 20 2018, 10:24 PM.
"""Image quality classification using Inception trained on the ImageNet 2012
Challenge data set and fine-tuned on quality data.

This program runs inference on input JPEG images in a folder.
Change the --image_dir argument to any jpg image folder to compute a
classification for the images in that folder.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
from collections import defaultdict
from concurrent import futures
import csv
import glob
import json
import os.path
import random
import re
import sys
import tarfile
import time
from os import listdir
from os.path import isfile, join

try:
    import pyspark
except ImportError:
    # pyspark is not directly importable: let findspark locate the Spark
    # installation and put it on sys.path before retrying the import.
    import findspark
    findspark.init()
    import pyspark

import numpy as np
import pyspark.sql
from pyspark.sql import types as T
import requests
from six.moves import urllib
import tensorflow as tf
# Parsed command-line flags. Starts as None and is assigned from
# parser.parse_known_args() in the __main__ block; create_graph() and main()
# read attributes from it (model_path, image_titles, outfile).
FLAGS = None
def proxies():
    """Return a requests-style proxy mapping for one randomly chosen cluster.

    Each call picks either the 'eqiad' or the 'codfw' web proxy (roughly even
    odds) so that outgoing requests are spread across both.

    Returns:
        dict with 'http' and 'https' keys, both pointing at the chosen
        cluster's proxy on port 8080.
    """
    cluster = 'eqiad' if random.random() > 0.5 else 'codfw'
    return {
        'http': 'http://webproxy.{}.wmnet:8080/'.format(cluster),
        'https': 'https://webproxy.{}.wmnet:8080/'.format(cluster),
    }
def create_graph():
    """Load the serialized GraphDef named by FLAGS.model_path into the default graph.

    The path is resolved through pyspark.SparkFiles.get(), then the file is
    parsed and imported with an empty name scope, so tensors keep the names
    stored in the file. Returns nothing.
    """
    local_model_path = pyspark.SparkFiles.get(FLAGS.model_path)
    with tf.gfile.FastGFile(local_model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        # Without this parse step the GraphDef stays empty and nothing is
        # imported.
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
def batch(x, n):
    """Yield successive lists of up to ``n`` items taken from iterable ``x``.

    Every yielded list has exactly ``n`` items except possibly the last one,
    which holds the remainder. An empty input yields nothing.
    """
    current = []
    for item in x:
        current.append(item)
        if len(current) == n:
            yield current
            # Start a fresh list so the one already handed out is not mutated.
            current = []
    if current:
        yield current
def fetch_url_batch(session, titles, retries=0,
                    api_url='https://commons.wikimedia.org/w/api.php'):
    """Query the MediaWiki API for 600px-wide image URLs of a batch of titles.

    Args:
        session: object with a requests-compatible ``post`` method.
        titles: file titles without the ``File:`` prefix.
        retries: number of attempts already spent on this batch.
        api_url: API endpoint to post to.
            NOTE(review): the endpoint literal was lost from SOURCE; this
            default is an assumption and must be confirmed.

    Returns:
        The decoded JSON response, which always contains a 'query' key.

    Raises:
        Exception: carrying the raw response body when the response is not
            JSON, or when the API still reports an error after the retries
            are used up.
    """
    if not titles:
        # Every title in the batch was rejected; an empty query would only
        # produce an API error, so answer with an empty page list instead.
        return {'query': {'pages': []}}

    response = session.post(api_url, timeout=5, data={
        'format': 'json',
        'formatversion': 2,
        'action': 'query',
        'prop': 'imageinfo',
        'iiprop': 'url',
        'iiurlwidth': 600,
        'titles': 'File:' + '|File:'.join(titles),
    })
    try:
        res = response.json()
    except ValueError as err:
        # JSON decode errors are ValueError subclasses.
        raise Exception(response.content) from err

    if 'error' not in res and 'query' in res:
        return res

    if 'error' in res and res['error']['code'] == 'urlparamnormal':
        # We have a bad title. Try and extract it from the message and filter
        # it out of the batch before asking again.
        prefix = 'Could not normalize image parameters for '
        bad_title = res['error']['info'][len(prefix):-1]
        remaining = [t for t in titles if t != bad_title]
        if len(remaining) != len(titles):
            return fetch_url_batch(session, remaining, retries + 1, api_url)

    # Some sort of api error, we should retry
    if retries < 3:
        return fetch_url_batch(session, titles, retries + 1, api_url)
    raise Exception(response.content)
def fetch_urls(titles, batch_size=50):
    """Resolve file titles to image URLs through the MediaWiki API.

    Args:
        titles: iterable of file titles (without the ``File:`` prefix).
        batch_size: number of titles per API request; at most 50.

    Yields:
        ``(page_id, title, url, error)`` tuples. Invalid or missing pages are
        reported with ``page_id == -1``, ``url is None`` and an error string;
        every other page yields one tuple per imageinfo entry with
        ``error is None``.

    Raises:
        Exception: if ``batch_size`` exceeds 50, or (carrying the full
            response) if a page lacks an expected key.
    """
    # NOTE(review): SOURCE carries a comment about spreading out the requests
    # when spark first starts up, but no delay statement is visible there;
    # none is added here.
    if batch_size > 50:
        raise Exception('Mediawiki api will only resize 50 images at a time')
    with requests.Session() as session:
        for batch_titles in batch(titles, batch_size):
            res = fetch_url_batch(session, batch_titles)
            # Map the API's normalized titles back to what was requested.
            normalized = {
                norm['to']: norm['from']
                for norm in res['query'].get('normalized', ())
            }
            try:
                for page in res['query']['pages']:
                    title = page.get('title', '***MISSING TITLE***')
                    title = normalized.get(title, title)
                    if page.get('invalid'):
                        yield -1, title, None, page['invalidreason']
                    elif page.get('missing'):
                        yield -1, title, None, 'missing'
                    else:
                        for info in page['imageinfo']:
                            # Prefer the resized thumbnail when one exists.
                            url = info['thumburl'] if 'thumburl' in info else info['url']
                            yield page['pageid'], title, url, None
            except KeyError:
                # Surface the whole response so the unexpected shape can be
                # diagnosed.
                raise Exception(res)
def buffer_images(image_infos):
    """Download images concurrently, keeping at most 10 requests in flight.

    Args:
        image_infos: iterable of ``(page_id, title, url, error)`` tuples.
            Entries whose ``error`` is not None are passed through without a
            download.

    Yields:
        ``(page_id, title, content, error)`` tuples: ``content`` is the
        response body on HTTP 200 (with ``error`` None), otherwise ``content``
        is None and ``error`` describes the failure. HTTP 429 responses are
        retried up to 3 times with exponential back-off.
    """
    # NOTE(review): HTTPAdapter and Retry are imported but unused in the
    # visible code; presumably an adapter was mounted on the session in lines
    # missing from SOURCE. Confirm before removing.
    from requests.adapters import HTTPAdapter
    from requests.exceptions import ConnectionError, RetryError
    from requests.packages.urllib3.util.retry import Retry
    from requests_futures.sessions import FuturesSession

    max_in_flight = 10
    max_retries = 3

    with FuturesSession(max_workers=max_in_flight) as session:
        retries = defaultdict(int)
        # Outstanding future -> (page_id, title, url)
        fs = {}

        def on_complete(future, page_id, title, url):
            # Turn one finished request into a result row, or requeue it.
            try:
                res = future.result()
                if res.status_code == 200:
                    yield page_id, title, res.content, None
                elif res.status_code == 429 and retries[future] < max_retries:
                    # We can't really pause the in-progress requests but we
                    # can at least stop adding new ones for a bit.
                    # Sleep for 10, 20, 40 seconds
                    time.sleep(10 * (2 ** retries[future]))
                    next_future = session.get(url, timeout=120, proxies=proxies())
                    retries[next_future] = retries[future] + 1
                    fs[next_future] = (page_id, title, url)
                else:
                    yield page_id, title, None, 'Received http status code {}'.format(res.status_code)
            except (ConnectionError, RetryError) as e:
                # Python 3 exceptions have no .message attribute.
                yield page_id, title, None, str(e)
            finally:
                retries.pop(future, None)

        def drain(threshold):
            # Process completions until fewer than `threshold` requests remain.
            while fs and len(fs) >= threshold:
                finished = futures.wait(
                    list(fs), return_when=futures.FIRST_COMPLETED).done
                for future in finished:
                    yield from on_complete(future, *fs.pop(future))

        for page_id, title, url, error in image_infos:
            if error is not None:
                yield page_id, title, None, error
                continue
            future = session.get(url, timeout=120, proxies=proxies())
            fs[future] = (page_id, title, url)
            yield from drain(max_in_flight)

        # Drain everything left, including 429 retries queued while draining.
        yield from drain(1)
def run_inference_on_images(image_infos):
    """Score every downloadable image with the loaded TensorFlow graph.

    Args:
        image_infos: iterable of ``(page_id, title, url, error)`` tuples; it
            is handed to buffer_images() for downloading.

    Yields:
        ``(page_id, title, score, error)`` tuples. ``score`` is NaN and
        ``error`` is a message whenever the image could not be fetched or
        scored; otherwise ``error`` is None.
    """
    # The graph must be in place before tensors can be looked up by name.
    create_graph()
    session_conf = tf.ConfigProto(
        intra_op_parallelism_threads=10,
        inter_op_parallelism_threads=10)
    with tf.Session(config=session_conf) as sess:
        # 'final_result:0' is the output tensor that gets evaluated.
        # 'DecodeJpeg/contents:0' is fed the raw downloaded image bytes.
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
        nan = float('nan')
        for page_id, title, image_data, error in buffer_images(image_infos):
            if error is not None:
                yield page_id, title, nan, error
                continue
            if image_data is None:
                # Only the title is available here; the url is not part of
                # the tuples buffer_images() yields.
                yield page_id, title, nan, 'Failed to load image data from {}'.format(title)
                continue
            try:
                predictions = sess.run(
                    softmax_tensor,
                    {'DecodeJpeg/contents:0': image_data})
            except Exception as e:
                # One undecodable image must not abort the whole partition.
                yield page_id, title, nan, '{}: {}'.format(type(e).__name__, e)
            else:
                # NOTE(review): index 1 of the first row is reported as the
                # score; presumably that is the positive class -- confirm
                # against the model's label order.
                yield page_id, title, float(predictions[0][1]), None
def main(_):
    """Score all listed titles on Spark and write the results to a CSV file.

    Reads one title per line from every file matching the FLAGS.image_titles
    glob, resolves and scores them in 200 Spark partitions, and writes the
    resulting ``(page_id, title, score, error)`` rows to FLAGS.outfile.

    Args:
        _: unused positional argument supplied by the caller.
    """
    conf = pyspark.SparkConf()
    sc = pyspark.SparkContext(appName="classify_image_quality", conf=conf)

    # Accumulate across every matching file rather than keeping only the last.
    titles = []
    for path in glob.glob(FLAGS.image_titles):
        with open(path, 'r') as f:
            titles.extend(line.strip() for line in f if line.strip())

    rdd = sc.parallelize(titles, 200) \
        .mapPartitions(fetch_urls) \
        .mapPartitions(run_inference_on_images)

    with open(FLAGS.outfile, 'w', newline='') as f:
        writer = csv.writer(f)
        # toLocalIterator would make sense, but it does silly things
        # (running 200 tasks as 200 jobs end to end). Our
        # datasets are only ~100k items simply load it in memory.
        writer.writerows(rdd.collect())
if __name__ == '__main__':
    # Parse this script's own flags and hand the remaining arguments to
    # TensorFlow's app runner, which calls main().
    # NOTE(review): the add_argument calls were lost from SOURCE. The flag
    # names below come from the FLAGS attributes the code reads; the original
    # defaults are unknown, so none are set -- confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        help='Local path to output_graph_new.pb'
    )
    parser.add_argument(
        '--image_titles',
        type=str,
        help='local path to list of images'
    )
    parser.add_argument(
        '--outfile',
        type=str,
        help='HDFS path to write output to'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)