# Source code for BioExp.spatial.ablation

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import os
import keras
import numpy as np
import tensorflow as tf
from glob import glob
import pandas as pd
from tqdm import tqdm
import keras.backend as K
from keras.utils import np_utils
from keras.models import load_model


class Ablate():
    """
    A class for conducting an ablation study on a trained keras model instance.

    Filters of a chosen layer are zeroed out and the change in a segmentation
    metric (e.g. dice) between the unablated and ablated predictions is
    recorded per class.
    """

    def __init__(self, model, weights_pth, metric, layer_name, test_image, gt,
                 classes, nclasses=4, image_name=None):
        """
        model       : keras model architecture (keras.models.Model)
        weights_pth : saved weights path (str)
        metric      : callable metric(gt, prediction, class_values) used to
                      compare prediction with gt, for example dice, CE
        layer_name  : name of the layer which needs to be ablated
        test_image  : test image used for ablation
        gt          : ground truth for comparison
        classes     : class information which needs to be considered; class
                      label as key and corresponding required values in a
                      tuple: {'class1': (1,), 'whole': (1, 2, 3)}
        nclasses    : number of unique classes in gt
        image_name  : optional identifier embedded in saved figure filenames
        """
        self.model = model
        self.weights = weights_pth
        self.metric = metric
        self.test_image = test_image
        self.layer = layer_name
        self.gt = gt
        self.classinfo = classes
        self.nclasses = nclasses
        self.layer_idx = 0
        self.image_name = image_name

        # Locate the index of the layer to ablate by name; if the name is not
        # found, layer_idx silently stays 0 (first layer) — matches the
        # original behavior.
        for idx, layer in enumerate(self.model.layers):
            if layer.name == self.layer:
                self.layer_idx = idx

        self.noutputs = len(self.model.outputs)
        self.model.load_weights(self.weights, by_name=True)

        # Keep the layer weights as the plain list returned by get_weights().
        # Wrapping them in np.array() (as before) creates a ragged object
        # array (kernel and bias have different shapes), which fails on modern
        # NumPy and whose .copy() is shallow — so ablation would have mutated
        # these stored originals in place.
        self.layer_weights = self.model.layers[self.layer_idx].get_weights()
        self.filter_shape = self.layer_weights[0].shape

    def ablate_filters(self, filters_to_ablate=None, concept='random',
                       step=None, save_path=None, verbose=1):
        """
        Zeroes out the selected filters of the target layer (cumulatively),
        predicts on the test image, and records the evaluation metric for the
        unablated and ablated cases per class.

        Arguments:
            filters_to_ablate: explicit iterable of filter indices to zero;
                               if None, every `step`-th filter is selected
            concept          : label used in the output rows / figure name
            step             : interval at which to drop filters
                               (required when filters_to_ablate is None)
            save_path        : if given, a comparison figure is saved there
            verbose          : verbosity passed to the unablated predict call

        Returns:
            A pandas DataFrame with one 'actual_<concept>' row and one
            'ablated_<concept>' row, and one metric column per class.

        Raises:
            ValueError: if neither step nor filters_to_ablate is provided.
        """
        if (step is None) and (filters_to_ablate is None):
            raise ValueError("step or filters_to_ablate is required")

        if filters_to_ablate is None:
            filters_to_ablate = np.arange(0, self.filter_shape[-1], step)

        # Unablated prediction.
        # NOTE(review): this two-value unpacking assumes the model has (at
        # least) two outputs; for a single-output model it would unpack along
        # the batch axis — preserved from the original, confirm with callers.
        prediction_unshaped, og_rec = self.model.predict(
            self.test_image, batch_size=1, verbose=verbose)

        dice_json = {}
        dice_json['concept'] = []
        for class_ in self.classinfo.keys():
            dice_json[class_] = []

        # Deep-copy each weight tensor so zeroing filters does not clobber
        # the stored originals (the old object-array .copy() was shallow).
        occluded_weights = [w.copy() for w in self.layer_weights]
        for j in filters_to_ablate:
            occluded_weights[0][:, :, :, j] = 0   # kernel slices
            occluded_weights[1][j] = 0            # matching biases
        self.model.layers[self.layer_idx].set_weights(occluded_weights)

        # Ablated prediction with all selected filters zeroed at once.
        prediction_unshaped_occluded, ab_rec = self.model.predict(
            self.test_image, batch_size=1, verbose=0)

        dice_json['concept'].append('actual_' + str(concept))
        dice_json['concept'].append('ablated_' + str(concept))

        for class_ in self.classinfo.keys():
            if self.noutputs == 1:
                dice_json[class_].append(self.metric(
                    self.gt, prediction_unshaped.argmax(axis=-1),
                    self.classinfo[class_]))
                dice_json[class_].append(self.metric(
                    self.gt, prediction_unshaped_occluded.argmax(axis=-1),
                    self.classinfo[class_]))
            else:
                # Pick the output head whose channel count matches nclasses;
                # fall back to the first output so idx is never undefined
                # (the original raised NameError when nothing matched).
                idx = 0
                for ii in range(self.noutputs):
                    if prediction_unshaped[ii].shape[-1] == self.nclasses:
                        idx = ii
                dice_json[class_].append(self.metric(
                    self.gt, prediction_unshaped[idx].argmax(axis=-1),
                    self.classinfo[class_]))
                dice_json[class_].append(self.metric(
                    self.gt, prediction_unshaped_occluded[idx].argmax(axis=-1),
                    self.classinfo[class_]))

        if save_path is not None:
            if self.noutputs > 1:
                # Layout: input | gt | each unoccluded head | each occluded head
                plt.subplot(1, 2 * self.noutputs + 2, 1)
                plt.imshow(np.squeeze(self.test_image), alpha=1)
                plt.subplot(1, 2 * self.noutputs + 2, 2)
                plt.imshow(np.squeeze(self.gt), alpha=1)
                for ii in range(self.noutputs):
                    plt.subplot(1, 2 * self.noutputs + 2, ii + 3)
                    img = prediction_unshaped[ii]
                    if img.shape[-1] == self.nclasses:
                        img = img.argmax(axis=-1)
                    plt.imshow(np.squeeze(img))
                    plt.title('unoccluded')
                for ii in range(self.noutputs):
                    plt.subplot(1, 2 * self.noutputs + 2,
                                self.noutputs + ii + 3)
                    img = prediction_unshaped_occluded[ii]
                    if img.shape[-1] == self.nclasses:
                        img = img.argmax(axis=-1)
                    plt.imshow(np.squeeze(img))
                    plt.title('occluded')
            else:
                # Layout: input | gt | unoccluded argmax | occluded argmax
                plt.subplot(1, 4, 1)
                plt.imshow(np.squeeze(self.test_image), alpha=1)
                plt.subplot(1, 4, 2)
                plt.imshow(np.squeeze(self.gt), alpha=1)
                plt.subplot(1, 4, 3)
                plt.imshow(np.squeeze(prediction_unshaped.argmax(axis=-1)),
                           alpha=1)
                plt.subplot(1, 4, 4)
                plt.imshow(
                    np.squeeze(prediction_unshaped_occluded.argmax(axis=-1)),
                    alpha=1)
            plt.savefig(os.path.join(save_path,
                        'image_{}_{}_concept_{}.png'.format(
                            self.image_name, self.layer, concept)))

        df = pd.DataFrame(dice_json)
        return df