Я пытаюсь реализовать сеть внимания с помощью TensorFlow 2. Таким образом, для каждого изображения я хочу взять лишь несколько проблесков, то есть небольшую часть изображения. Для этого я реализовал подкласс из tenorflow.keras.models.Model, вот фрагмент из него.
class RecurrentAttentionModel(models.Model):
# ...
def call(self, inputs):
l = tf.random.uniform((40,2,), minval=0, maxval=1)
for _ in range(0, self.glimpses):
glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True)
# some other code...
# update l to take a glimpse somewhere else
return result
Теперь приведенный выше код работает и работает отлично, но моя проблема в том, что я в нем есть жестко заданный код 40, размер batch_size, который я определил в своем наборе данных. Я не могу прочитать / получить batch_size в методе вызова, так как переменная "input" имеет форму Tensor("input_1_77:0", shape=(None, 250, 500, 1), dtype=float32)
, где None
для batch_size, похоже, является ожидаемым поведением. Когда я просто инициализирую l следующим кодом (без batch_size)
l = tf.random.uniform((2,), minval=0, maxval=1)
, он выдает эту ошибку
ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]
, что я полностью понимаю, но я понятия не имею, как я мог бы реализовать начальные значения согласно batch_size.