Source code for BioExp.helpers.pb_file_generation

import tensorflow as tf
import keras.backend as K
from keras.models import load_model
from .losses import *

def generate_pb(model_path, layer_name, pb_path, wts_path):
    """
    Freeze the model weights and write the whole graph to a .pb file.

    The model is loaded from ``model_path``, its weights are overwritten
    with those stored in ``wts_path``, every variable reachable from the
    output node ``<layer_name>/convolution`` is converted to a constant,
    and the resulting GraphDef is serialized to ``pb_path``.

    model_path: saved model path (model architecture) (str)
    layer_name: name of output layer (str); the graph node named
                ``layer_name + '/convolution'`` is used as the output node
    pb_path   : path to save pb file
    wts_path  : saved model weights
    """
    # Resolve the conversion function once. Only a missing attribute
    # (presumably a TensorFlow build without the ``tf.compat.v1`` namespace)
    # should trigger the fallback; real conversion errors must propagate
    # instead of being hidden by a blanket ``except``.
    try:
        convert_variables_to_constants = (
            tf.compat.v1.graph_util.convert_variables_to_constants
        )
    except AttributeError:
        convert_variables_to_constants = (
            tf.graph_util.convert_variables_to_constants
        )

    custom_objects = {
        'gen_dice_loss': gen_dice_loss,
        'dice_whole_metric': dice_whole_metric,
        'dice_core_metric': dice_core_metric,
        'dice_en_metric': dice_en_metric,
    }

    with tf.Session(graph=K.get_session().graph) as session:
        session.run(tf.global_variables_initializer())

        model = load_model(model_path, custom_objects=custom_objects)
        model.load_weights(wts_path)

        # ``summary()`` prints the table itself and returns None, so wrapping
        # it in ``print`` only added a stray "None" line to the output.
        model.summary()

        output_graph_def = convert_variables_to_constants(
            session,
            K.get_session().graph.as_graph_def(),
            [layer_name + '/convolution'],
        )

        with open(pb_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())