# Source code for BioExp.graphs.significance

import keras
import numpy as np
import tensorflow as tf
import os
import pdb
import cv2 
import pickle


import pandas as pd
from ..helpers.utils import *

from keras.models import Model
from keras.utils import np_utils
from tqdm import tqdm


class SignificanceTester():
    """
    Tests the significance of each concept (a cluster of filters) found in a
    trained keras model.

    Significance is estimated by Monte-Carlo occlusion: random filters
    *outside* the concept are zeroed and the resulting change in a prediction
    metric (e.g. dice) and in prediction entropy (information gain) is
    measured against the unoccluded model.
    """

    def __init__(self, model, weights_pth, metric, classinfo=None):
        """
        model       : keras model architecture (keras.models.Model)
        weights_pth : saved weights path (str)
        metric      : metric to compare prediction with gt, for example dice,
                      CE; called as metric(ground_truth, prediction, class_label)
        classinfo   : dict mapping class name -> label value passed to `metric`
        """
        self.model = model
        # Separate copy receives the occluded weights so `self.model`
        # always keeps the original (unoccluded) weights for comparison.
        self.modelcopy = keras.models.clone_model(self.model)
        self.weights = weights_pth
        self.metric = metric
        self.classinfo = classinfo
        self.noutputs = len(self.model.outputs)
        # Number of classes, used to pick the relevant head of multi-output
        # models.  BUGFIX: the original code referenced an undefined
        # `self.nclasses` in node_significance, which raised AttributeError
        # whenever the model had more than one output.
        self.nclasses = len(classinfo) if classinfo is not None else None

    def get_layer_idx(self, layer_name):
        """Return the index of the layer named `layer_name` (None if absent)."""
        for idx, layer in enumerate(self.model.layers):
            if layer.name == layer_name:
                return idx

    def node_significance(self, concept_info, dataset_path, loader,
                          nmontecarlo=10, max_samples=1):
        """
        Test the significance of one concept.

        concept_info : {'layer_name': str, 'filter_idxs': array-like}
        dataset_path : directory with the evaluation images
        loader       : callable(image_path, label_path) -> (image, label)
        nmontecarlo  : number of random-occlusion trials
        max_samples  : maximum number of images evaluated per trial

        Returns {class_name: mean metric difference, 'IG': mean information
        gain}, or False when the concept spans more than half of the layer's
        filters (too few filters left to build a null occlusion of equal size).
        """
        self.model.load_weights(self.weights, by_name=True)

        node_idx = self.get_layer_idx(concept_info['layer_name'])
        node_idxs = concept_info['filter_idxs']

        # Total filter count = last axis of the layer's kernel.
        # (Original used np.array(get_weights()) which builds a deprecated
        # ragged object array; index the kernel directly instead.)
        nfilters_total = self.model.layers[node_idx].get_weights()[0].shape[-1]
        total_filters = np.arange(nfilters_total)
        nfilters = len(node_idxs)
        # Filters NOT in the concept: candidates for random (null) occlusion.
        test_filters = np.delete(total_filters, node_idxs)

        if len(test_filters) < nfilters:
            print("Huge cluster size, may not be significant, cluster size: {}, total data size: {}".format(nfilters, len(test_filters)))
            return False

        input_paths = os.listdir(dataset_path)

        dice_json = {class_: [] for class_ in self.classinfo.keys()}
        dice_json['IG'] = []  # information gain

        for _ in range(nmontecarlo):
            # Occlude `nfilters` randomly chosen non-concept filters.
            np.random.shuffle(test_filters)
            self.modelcopy.load_weights(self.weights, by_name=True)
            layer_weights = self.modelcopy.layers[node_idx].get_weights()
            occluded_weights = [w.copy() for w in layer_weights]
            for j in test_filters[:nfilters]:
                occluded_weights[0][:, :, :, j] = 0  # kernel slice
                occluded_weights[1][j] = 0           # bias
            self.modelcopy.layers[node_idx].set_weights(occluded_weights)

            for i in range(min(len(input_paths), max_samples)):
                image_path = os.path.join(dataset_path, input_paths[i])
                label_path = image_path.replace('mask', 'label').replace('labels', 'masks')
                input_, label_ = loader(image_path, label_path)

                prediction_occluded = np.squeeze(self.modelcopy.predict(input_[None, ...]))
                prediction = np.squeeze(self.model.predict(input_[None, ...]))

                # Multi-output models: select the relevant output head.
                # NOTE(review): this head-selection heuristic (comparing the
                # prediction against the class count) is preserved from the
                # original code — confirm its intent against callers.
                idx = 0
                if self.noutputs > 1:
                    for ii in range(self.noutputs):
                        if prediction[ii] == self.nclasses:
                            idx = ii
                            break

                # Metric difference: unoccluded minus occluded score.
                for class_ in self.classinfo.keys():
                    if self.noutputs > 1:
                        dice_json[class_].append(
                            self.metric(label_, prediction[idx].argmax(axis=-1), self.classinfo[class_])
                            - self.metric(label_, prediction_occluded[idx].argmax(axis=-1), self.classinfo[class_]))
                    else:
                        dice_json[class_].append(
                            self.metric(label_, prediction.argmax(axis=-1), self.classinfo[class_])
                            - self.metric(label_, prediction_occluded.argmax(axis=-1), self.classinfo[class_]))

                # Information gain: entropy difference H(occluded) - H(original)
                # in bits.  BUGFIX: the original mixed np.log2 (occluded term)
                # with np.log (original term); log2 is now used consistently.
                if self.noutputs > 1:
                    dice_json['IG'].append(np.mean(
                        -prediction_occluded[idx] * np.log2(prediction_occluded[idx])
                        + prediction[idx] * np.log2(prediction[idx])))
                else:
                    dice_json['IG'].append(np.mean(
                        -prediction_occluded * np.log2(prediction_occluded)
                        + prediction * np.log2(prediction)))

        # Reduce each list of per-trial differences to its mean.
        for class_ in self.classinfo.keys():
            dice_json[class_] = np.mean(dice_json[class_])
        dice_json['IG'] = np.mean(dice_json['IG'])
        return dice_json

    def graph_significance(self, graph_info, dataset_path=None, loader=None,
                           save_path=None, max_samples=1, nmontecarlo=10):
        """
        Compute (or load from cache) significance scores for every concept
        node of the computation graph.

        graph_info : {'concept_name', 'layer_name', 'feature_map_idxs'},
                     parallel sequences indexed per concept
        save_path  : directory containing / receiving significance_info.pickle

        Returns {concept_name: node_significance result}.
        """
        cache_path = os.path.join(save_path, 'significance_info.pickle')
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                significance = pickle.load(f)
        else:
            nodes = graph_info['concept_name']
            significance = {}
            for i, node in enumerate(nodes):
                node_info = {'layer_name': graph_info['layer_name'][i],
                             'filter_idxs': graph_info['feature_map_idxs'][i]}
                significance[node] = self.node_significance(
                    node_info, dataset_path, loader,
                    nmontecarlo=nmontecarlo, max_samples=max_samples)
            with open(cache_path, 'wb') as f:
                pickle.dump(significance, f)
        return significance