Реализация интерфейса, подобного tf-slim, в TensorFlow 2.x - PullRequest
0 голосов
/ 02 июня 2019

Рассмотрим гипотетическую, но очень практичную ситуацию, описанную ниже:

I have to create a network which uses two pre-trained networks `A`
 and `B`. I would like to concatenate the output of  `Layer-L` of 
`A` with the output of `Layer-M` of `B` and do further computations
 on the result.

Теперь в TensorFlow 1.x я мог бы использовать библиотеку tf-slim (конечно, я предполагаю, что предварительно обученные модели и код A и B доступны в tf-slim). Как мы знаем, tf-slim предоставляет end_points словарь. Я мог бы использовать end_points для создания произвольных соединений при создании собственной сети.

При переходе на TensorFlow 2.x tf-slim отсутствует. У меня вопрос о том, является ли следующая практика хорошей реализацией, пытаясь портировать tf-slim специальные коды на TensorFlow 2.x

Например, если я попытаюсь перенести код tf-slim для сети VGG16 в TensorFlow 2.x, я реализую его как подкласс класса tensorflow.keras.Model. Пример приведен ниже:

import tensorflow as tf
"""
The following class RepeatLayer is a basic implementation of 
slim.repeat class in TensorFlow-slim.
"""

class RepeatLayer(tf.keras.layers.Layer):
    def __init__(self, layerobj, count, layernames, **kwargs):
        """
        class instantiator
        :param layerobj: A tf.keras.layers object (e.g:- tf.keras.layers.Conv2D)
        :param count: (int) Number of times layerobj should be repeated.
        :param layernames: (List of length count) Name of each repeatition of layerobj
        :param kwargs: layerobj specific named arguments.
        """
        super(RepeatLayer, self).__init__()
        if not isinstance(count, int):
            raise TypeError('The argument "count" must be a positive integer.')

        if count <= 0:
            raise ValueError('The argument "count" is provided as {}. It must be a'
                             'positive integer.'.format(count))

        if not isinstance(layernames, list):
            raise TypeError('The argument "layernames" must be a list of strings.')

        if not len(layernames) == count:
            raise ValueError('The length of "layernames" must be the value of "count".')

        for name, value in kwargs.items():
            if not isinstance(value, list):
                value = [value] * count
                kwargs[name] = value

        self._layernames = layernames
        self._count = count
        self._end_points = dict()
        self._outputs = []
        self._kwargs = kwargs
        for layernum in range(self._count):
            args = dict(
                map(
                    lambda x: (x[0], x[1][layernum]),
                    self._kwargs.items()
                )
            )
            output = layerobj(**args)
            self._outputs.append(output)

    def call(self, input_tensor):
        out = input_tensor
        for node, layername in zip(self._outputs, self._layernames):
            out = node(out)
            self._end_points[layername] = out
        return out

    @property
    def end_points(self):
        return self._end_points

Затем я пытаюсь проверить приведенный выше код, используя следующий фрагмент:

from utils import RepeatLayer
import tensorflow as tf

layer = RepeatLayer(tf.keras.layers.Conv2D, 3, ['c1', 'c2', 'c3'],
                    filters=64, kernel_size=7, strides=1, padding='same',
                    activation='relu', data_format='channels_last',
                    use_bias=True)

input_tensor = tf.keras.backend.random_uniform(shape=(3,512,512,3))

print(layer.end_points) # should print an empty dictionary

output_tensor = layer(input_tensor)

print(layer.end_points) # Prints a dictionary in which values are output tensors

Таким образом, к словарю end_points можно получить доступ только после того, как хотя бы один __call__ был сделан для модели. Я считаю, что из-за нетерпеливого исполнения по умолчанию это ожидаемое поведение.

Вопросы, с которыми я сейчас сталкиваюсь, таковы:

  1. Это правильный подход с точки зрения реализации? То есть, пытаясь портировать коды структуры модели с tf-slim на TensorFlow 2.x?
  2. Если я использую этот подход и хотел бы обернуть все в @tf.function для статического выполнения для эффективности, будет ли это работать? Я не пробовал этого, но без __call__ разве TensorFlow не нашел бы словарь end_points пустым и выдал ошибку?
  3. Как я могу восстановить предварительно обученные модели в tf-slim для этих новых структур кода?

Я уверен, что у них должна быть очень веская причина не портить tf-slim, а скорее отказаться от него полностью. Таким образом, я озадачен тем, является ли вышеуказанный подход правильным или нет.

...