Рассмотрим гипотетическую, но очень практичную ситуацию, описанную ниже:
I have to create a network which uses two pre-trained networks `A`
and `B`. I would like to concatenate the output of `Layer-L` of
`A` with the output of `Layer-M` of `B` and do further computations
on the result.
Теперь в TensorFlow 1.x я мог бы использовать библиотеку tf-slim
(конечно, я предполагаю, что предварительно обученные модели и код A
и B
доступны в tf-slim
). Как мы знаем, tf-slim
предоставляет end_points
словарь. Я мог бы использовать end_points
для создания произвольных соединений при создании собственной сети.
При переходе на TensorFlow 2.x tf-slim
отсутствует. У меня вопрос о том, является ли следующая практика хорошей реализацией, пытаясь портировать tf-slim
специальные коды на TensorFlow 2.x
Например, если я попытаюсь перенести код tf-slim
для сети VGG16
в TensorFlow 2.x, я реализую его как подкласс класса tensorflow.keras.Model
. Пример приведен ниже:
import tensorflow as tf
"""
The following class RepeatLayer is a basic implementation of
slim.repeat class in TensorFlow-slim.
"""
class RepeatLayer(tf.keras.layers.Layer):
def __init__(self, layerobj, count, layernames, **kwargs):
"""
class instantiator
:param layerobj: A tf.keras.layers object (e.g:- tf.keras.layers.Conv2D)
:param count: (int) Number of times layerobj should be repeated.
:param layernames: (List of length count) Name of each repeatition of layerobj
:param kwargs: layerobj specific named arguments.
"""
super(RepeatLayer, self).__init__()
if not isinstance(count, int):
raise TypeError('The argument "count" must be a positive integer.')
if count <= 0:
raise ValueError('The argument "count" is provided as {}. It must be a'
'positive integer.'.format(count))
if not isinstance(layernames, list):
raise TypeError('The argument "layernames" must be a list of strings.')
if not len(layernames) == count:
raise ValueError('The length of "layernames" must be the value of "count".')
for name, value in kwargs.items():
if not isinstance(value, list):
value = [value] * count
kwargs[name] = value
self._layernames = layernames
self._count = count
self._end_points = dict()
self._outputs = []
self._kwargs = kwargs
for layernum in range(self._count):
args = dict(
map(
lambda x: (x[0], x[1][layernum]),
self._kwargs.items()
)
)
output = layerobj(**args)
self._outputs.append(output)
def call(self, input_tensor):
out = input_tensor
for node, layername in zip(self._outputs, self._layernames):
out = node(out)
self._end_points[layername] = out
return out
@property
def end_points(self):
return self._end_points
Затем я пытаюсь проверить приведенный выше код, используя следующий фрагмент:
from utils import RepeatLayer
import tensorflow as tf
layer = RepeatLayer(tf.keras.layers.Conv2D, 3, ['c1', 'c2', 'c3'],
filters=64, kernel_size=7, strides=1, padding='same',
activation='relu', data_format='channels_last',
use_bias=True)
input_tensor = tf.keras.backend.random_uniform(shape=(3,512,512,3))
print(layer.end_points) # should print an empty dictionary
output_tensor = layer(input_tensor)
print(layer.end_points) # Prints a dictionary in which values are output tensors
Таким образом, к словарю end_points
можно получить доступ только после того, как хотя бы один __call__
был сделан для модели. Я считаю, что из-за нетерпеливого исполнения по умолчанию это ожидаемое поведение.
Вопросы, с которыми я сейчас сталкиваюсь, таковы:
- Это правильный подход с точки зрения реализации? То есть, пытаясь портировать коды структуры модели с
tf-slim
на TensorFlow 2.x?
- Если я использую этот подход и хотел бы обернуть все в
@tf.function
для статического выполнения для эффективности, будет ли это работать? Я не пробовал этого, но без __call__
разве TensorFlow не нашел бы словарь end_points
пустым и выдал ошибку?
- Как я могу восстановить предварительно обученные модели в
tf-slim
для этих новых структур кода?
Я уверен, что у них должна быть очень веская причина не портить tf-slim
, а скорее отказаться от него полностью. Таким образом, я озадачен тем, является ли вышеуказанный подход правильным или нет.