diff --git a/deepbrain/extractor.py b/deepbrain/extractor.py index fa380bc..217f596 100644 --- a/deepbrain/extractor.py +++ b/deepbrain/extractor.py @@ -3,8 +3,8 @@ from skimage.transform import resize import os -PB_FILE = os.path.join(os.path.dirname(__file__), "models", "extractor", "graph_v2.pb") -CHECKPOINT_DIR = os.path.join(os.path.dirname(__file__), "models", "extractor", "v2") +PB_FILE = os.path.join(os.path.dirname(__file__), "models", "graph.pb") +CHECKPOINT_DIR = os.path.join(os.path.dirname(__file__), "models") class Extractor: @@ -15,9 +15,9 @@ def __init__(self): def load_pb(self): graph = tf.Graph() - self.sess = tf.Session(graph=graph) - with tf.gfile.FastGFile(PB_FILE, 'rb') as f: - graph_def = tf.GraphDef() + self.sess = tf.compat.v1.Session(graph=graph) + with tf.io.gfile.GFile(PB_FILE, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) with self.sess.graph.as_default(): tf.import_graph_def(graph_def)