Skip to content
This repository was archived by the owner on Jan 5, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions model_tools/activations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from model_tools.activations.keras import KerasWrapper, preprocess as preprocess_keras
from model_tools.activations.pytorch import PytorchWrapper, preprocess_images as preprocess_pytorch
from model_tools.activations.tensorflow import TensorflowWrapper, TensorflowSlimWrapper
from model_tools.activations.tfkeras import TFKerasWrapper, resnet_preprocessing
119 changes: 119 additions & 0 deletions model_tools/activations/tfkeras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from collections import OrderedDict
import numpy as np
from model_tools.activations.core import ActivationsExtractorHelper
import tensorflow as tf


class TFKerasWrapper:
"""
A wrapper for the Model class created from tensorflow(>=2.0.0).keras
"""
def __init__(self,
model,
preprocessing,
identifier=None,
*args, **kwargs):
"""
:param model: a tf.keras model with a function `preprocess_input`
that will later be called on the loaded numpy image
"""
self._model = model
identifier = identifier or model.name
self._extractor = ActivationsExtractorHelper(
identifier=identifier, get_activations=self.get_activations, preprocessing=preprocessing,
*args, **kwargs)
self._extractor.insert_attrs(self)

def __call__(self, *args, **kwargs): # cannot assign __call__ as attribute due to Python convention
return self._extractor(*args, **kwargs)

@property
def identifier(self):
return self._extractor.identifier

@identifier.setter
def identifier(self, value):
self._extractor.identifier = value

def get_activations(self, images, layer_names):

"""
param images: a list of image paths
param layer_names: a list of layer names
"""
from tensorflow.keras import backend as K

input_tensor = self._model.input
layers = [layer for layer in self._model.layers if layer.name in layer_names]
layers = sorted(layers, key=lambda layer: layer_names.index(layer.name))

if 'logits' in layer_names:
layers.insert(layer_names.index('logits'), self._model.layers[-1])

assert len(layers) == len(layer_names)
layer_out_tensors = [layer.output for layer in layers]
functor = K.function([input_tensor], layer_out_tensors) # evaluate all tensors at once
K.set_learning_phase(0) # 0 to signal testing phase
layer_outputs = functor([images])
return OrderedDict([(layer_name, layer_output) for layer_name, layer_output in zip(layer_names, layer_outputs)])

def __repr__(self):
return repr(self._model)

def graph(self):
import networkx as nx
g = nx.DiGraph()
for layer in self._model.layers:
g.add_node(layer.name, object=layer, type=type(layer))
for outbound_node in layer._outbound_nodes:
g.add_edge(layer.name, outbound_node.outbound_layer.name)
return g



def tfkeras_load_images(image_paths):
"""
:param image_paths: list of strings of len B
return tf.Tensor of [B, H, W, 3] of dtype tf.uint8
"""

def load_image(path):
"""
param path: tf.Tensor (1,)
"""
blob = tf.io.read_file(path[0])
im = tf.image.decode_png(blob, channels=3)
return im

def load_images(paths):
images = tf.map_fn(load_image, elems=paths, dtype=tf.uint8)
return images

images = list(map(lambda s: [s], image_paths))
images = tf.constant(images)
images = load_images(images) # [b, h, w, 3], dtype tf.uint8

return images


def resnet_preprocessing(image_paths,
image_size=224):
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94
CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]

images = tfkeras_load_images(image_paths)

# Execute resizing
images = tf.image.resize(
images,
(image_size, image_size),
method=tf.image.ResizeMethod.BILINEAR,
preserve_aspect_ratio=False,
antialias=False,
)

images = images - CHANNEL_MEANS

return images