Переменная batch_size в функции вызова - PullRequest
0 голосов
/ 20 марта 2020

Я пытаюсь реализовать сеть внимания с помощью TensorFlow 2. Таким образом, для каждого изображения я хочу взять лишь несколько проблесков, то есть небольшую часть изображения. Для этого я реализовал подкласс из tenorflow.keras.models.Model, вот фрагмент из него.

class RecurrentAttentionModel(models.Model):
# ...

def call(self, inputs):

    l = tf.random.uniform((40,2,), minval=0, maxval=1)

    for _ in range(0, self.glimpses):
        glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True)

        # some other code...
        # update l to take a glimpse somewhere else


    return result           

Теперь приведенный выше код работает и работает отлично, но моя проблема в том, что я в нем есть жестко заданный код 40, размер batch_size, который я определил в своем наборе данных. Я не могу прочитать / получить batch_size в методе вызова, так как переменная "input" имеет форму Tensor("input_1_77:0", shape=(None, 250, 500, 1), dtype=float32), где None для batch_size, похоже, является ожидаемым поведением. Когда я просто инициализирую l следующим кодом (без batch_size)

l = tf.random.uniform((2,), minval=0, maxval=1)

, он выдает эту ошибку

ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]

, что я полностью понимаю, но я понятия не имею, как я мог бы реализовать начальные значения согласно batch_size.

1 Ответ

1 голос
/ 20 марта 2020

Вы можете динамически извлечь размер размера партии, используя tf.shape.

l = tf.random.normal(tf.stack([tf.shape(inputs)[0], 2]), minval=0, maxval=1))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...