diff --git a/deepbrain/extractor.py b/deepbrain/extractor.py index fa380bc..ae77d41 100644 --- a/deepbrain/extractor.py +++ b/deepbrain/extractor.py @@ -3,9 +3,11 @@ from skimage.transform import resize import os -PB_FILE = os.path.join(os.path.dirname(__file__), "models", "extractor", "graph_v2.pb") -CHECKPOINT_DIR = os.path.join(os.path.dirname(__file__), "models", "extractor", "v2") +# Enable TF1 compatibility mode +tf.compat.v1.disable_eager_execution() +PB_FILE = os.path.join(os.path.dirname(__file__), "models", "graph.pb") +CHECKPOINT_DIR = os.path.join(os.path.dirname(__file__), "models") class Extractor: @@ -15,26 +17,26 @@ def __init__(self): def load_pb(self): graph = tf.Graph() - self.sess = tf.Session(graph=graph) - with tf.gfile.FastGFile(PB_FILE, 'rb') as f: - graph_def = tf.GraphDef() + self.sess = tf.compat.v1.Session(graph=graph) + with tf.io.gfile.GFile(PB_FILE, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) with self.sess.graph.as_default(): - tf.import_graph_def(graph_def) + tf.import_graph_def(graph_def, name='') - self.img = graph.get_tensor_by_name("import/img:0") - self.training = graph.get_tensor_by_name("import/training:0") - self.dim = graph.get_tensor_by_name("import/dim:0") - self.prob = graph.get_tensor_by_name("import/prob:0") - self.pred = graph.get_tensor_by_name("import/pred:0") + self.img = graph.get_tensor_by_name("img:0") + self.training = graph.get_tensor_by_name("training:0") + self.dim = graph.get_tensor_by_name("dim:0") + self.prob = graph.get_tensor_by_name("prob:0") + self.pred = graph.get_tensor_by_name("pred:0") def load_ckpt(self): - self.sess = tf.Session() + self.sess = tf.compat.v1.Session() ckpt_path = tf.train.latest_checkpoint(CHECKPOINT_DIR) - saver = tf.train.import_meta_graph('{}.meta'.format(ckpt_path)) + saver = tf.compat.v1.train.import_meta_graph('{}.meta'.format(ckpt_path)) saver.restore(self.sess, ckpt_path) - g = tf.get_default_graph() + g = tf.compat.v1.get_default_graph() self.img = g.get_tensor_by_name("img:0") self.training = g.get_tensor_by_name("training:0") @@ -51,5 +53,3 @@ def run(self, image): prob = self.sess.run(self.prob, feed_dict={self.training: False, self.img: img}).squeeze() prob = resize(prob, (shape), mode='constant', anti_aliasing=True) return prob - -