Prerequisites: install TensorFlow 1.0 and scikit-image.
Clone this fork of tf-slim somewhere, download the pretrained model, and put it in tf-models/slim/pretrained/.
import tensorflow as tf
slim = tf.contrib.slim
import sys
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from nets import inception
from preprocessing import inception_preprocessing
# Input resolution (height == width) of the Inception v3 network.
image_size = inception.inception_v3.default_image_size
# TensorFlow session created on the default graph; ops added to that graph
# later in the notebook can still be run through it.
session = tf.Session()
def transform_img_fn(path_list):
    """Load JPEG files and preprocess them for Inception v3 inference.

    Args:
        path_list: iterable of paths to JPEG files.

    Returns:
        A list of numpy arrays, one per path and in the same order. Each is
        the output of ``inception_preprocessing.preprocess_image`` (3 color
        channels, requested size ``image_size`` x ``image_size``). Returns an
        empty list when ``path_list`` is empty.
    """
    out = []
    for f in path_list:
        # Read raw bytes: binary mode avoids newline translation corrupting
        # the JPEG, and `with` closes the handle instead of leaking it.
        with open(f, 'rb') as fh:
            image_raw = tf.image.decode_jpeg(fh.read(), channels=3)
        image = inception_preprocessing.preprocess_image(
            image_raw, image_size, image_size, is_training=False)
        out.append(image)
    # NOTE(review): each call adds new decode/preprocess ops to the default
    # graph; fine for a notebook, but the graph grows on repeated calls.
    # Evaluate all preprocessing ops in one run call.
    return session.run(out)
from datasets import imagenet
import os

# Mapping from the model's output class index to a human-readable label.
names = imagenet.create_readable_names_for_imagenet_labels()

# Batch of preprocessed RGB images fed to the network. Uses image_size rather
# than a duplicated literal so it always matches the preprocessing size.
processed_images = tf.placeholder(tf.float32, shape=(None, image_size, image_size, 3))

with slim.arg_scope(inception.inception_v3_arg_scope()):
    logits, _ = inception.inception_v3(processed_images, num_classes=1001, is_training=False)
probabilities = tf.nn.softmax(logits)

# Directory containing inception_v3.ckpt. It can be overridden with the
# INCEPTION_CHECKPOINTS_DIR environment variable; the default is the original
# author's path.
checkpoints_dir = os.environ.get(
    'INCEPTION_CHECKPOINTS_DIR', '/Users/marcotcr/phd/tf-models/slim/pretrained')
init_fn = slim.assign_from_checkpoint_fn(
    os.path.join(checkpoints_dir, 'inception_v3.ckpt'),
    slim.get_model_variables('InceptionV3'))
# Restore the pretrained weights into the session. Without this, running
# `probabilities` fails on uninitialized variables.
init_fn(session)
def predict_fn(images):
    """Return softmax class probabilities for a batch of preprocessed images.

    Args:
        images: batch compatible with the ``processed_images`` placeholder
            (a sequence/array of preprocessed RGB images).

    Returns:
        The evaluated ``probabilities`` tensor: one row per input image, one
        column per class (1001 classes, as configured for the network).
    """
    return session.run(probabilities, feed_dict={processed_images: images})
images = transform_img_fn(['dogs.jpg'])
# I'm dividing by 2 and adding 0.5 because of how this Inception represents images:
# x / 2 + 0.5 maps the range [-1, 1] onto [0, 1], which plt.imshow accepts for floats.
plt.imshow(images[0] / 2 + 0.5)
preds = predict_fn(images)
# Top-5 class indices for the first image, printed most probable first.
# A single pre-formatted string prints the same on Python 2 and 3.
for x in preds.argsort()[0][-5:][::-1]:
    print('%s %s %s' % (x, names[x], preds[0, x]))
image = images[0]
## Now let's get an explanation
import time
from lime import lime_image

# LIME explainer for image inputs; explain_instance is called on it below.
explainer = lime_image.LimeImageExplainer()
`hide_color` is the color for a superpixel turned OFF. Alternatively, if it is `None`, the superpixel will be replaced by the average of its pixels. Here, we set it to 0 (in the representation used by the Inception model, 0 means gray).
# Time how long the explanation takes; it can be slow for image classifiers.
tmp = time.time()
# Hide color is the color for a superpixel turned OFF. Alternatively, if it is None, the superpixel will be replaced by the average of its pixels
explanation = explainer.explain_instance(image, predict_fn, top_labels=5, hide_color=0, num_samples=1000)
# Elapsed wall-clock seconds. A single parenthesized argument prints the same on Python 2 and 3.
print(time.time() - tmp)
Image classifiers are a bit slow. Notice that an explanation on my MacBook Pro took 7.65 minutes.
We can see the top 5 superpixels that are most positive towards the class, with the rest of the image hidden:
from skimage.segmentation import mark_boundaries

# Label 240: keep only its 5 most positive superpixels and hide the rest.
temp, mask = explanation.get_image_and_mask(
    240,
    hide_rest=True,
    num_features=5,
    positive_only=True,
)
# Same x / 2 + 0.5 display rescaling as the input image, with the mask's
# region boundaries drawn on top.
plt.imshow(mark_boundaries(0.5 + temp / 2, mask))
Or with the rest of the image present:
# Label 240: the same 5 most positive superpixels, outlined on the full image.
temp, mask = explanation.get_image_and_mask(
    240,
    hide_rest=False,
    num_features=5,
    positive_only=True,
)
plt.imshow(mark_boundaries(0.5 + temp / 2, mask))
We can also see the 'pros and cons' (pros in green, cons in red):
# Label 240: top 10 superpixels, both supporting and opposing, on the full image.
temp, mask = explanation.get_image_and_mask(
    240,
    hide_rest=False,
    num_features=10,
    positive_only=False,
)
plt.imshow(mark_boundaries(0.5 + temp / 2, mask))
Or only the pros and cons whose weight is at least 0.1:
# Label 240: a large num_features cap, so min_weight=0.1 is the effective
# filter on which pros/cons are shown.
temp, mask = explanation.get_image_and_mask(
    240,
    min_weight=0.1,
    hide_rest=False,
    num_features=1000,
    positive_only=False,
)
plt.imshow(mark_boundaries(0.5 + temp / 2, mask))
Most positive towards Egyptian cat:
# Label 286 (the Egyptian cat class): 5 most positive superpixels, rest hidden.
temp, mask = explanation.get_image_and_mask(
    286,
    hide_rest=True,
    num_features=5,
    positive_only=True,
)
plt.imshow(mark_boundaries(0.5 + temp / 2, mask))
Pros and cons:
# Label 286: top 10 supporting and opposing superpixels on the full image.
temp, mask = explanation.get_image_and_mask(
    286,
    hide_rest=False,
    num_features=10,
    positive_only=False,
)
plt.imshow(mark_boundaries(0.5 + temp / 2, mask))