Согласно официальному документу по tf.keras.layers.Conv2D
,
При использовании этого слоя в качестве первого слоя в модели укажите ключевое слово аргумент input_shape (кортеж целых чисел, не включая ось образца), например, input_shape = (128, 128, 3) для изображений RGB 128x128 в data_format = "channel_last".
но на самом деле без input_shape он работает как в графическом режиме, так и в среде активного исполнения.
В графическом исполнении,
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten, Dense
class CNN(tf.keras.Model):
def __init__(self):
super(CNN, self).__init__()
self.conv = Conv2D(1, 3, padding='same', data_format='channels_first')
self.flatten = Flatten()
self.dense = Dense(1)
def call(self, inputs):
x = self.conv(inputs)
x = self.flatten(x)
return self.dense(x)
cnn = CNN()
inputs = tf.random_uniform([2, 3, 16, 16])
outputs = cnn(inputs)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
outputs = sess.run(outputs)
print(outputs)
работает без ошибок и в стремительном исполнении,
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten, Dense
tf.enable_eager_execution()
class CNN(tf.keras.Model):
def __init__(self):
super(CNN, self).__init__()
self.conv = Conv2D(1, 3, padding='same', data_format='channels_first')
self.flatten = Flatten()
self.dense = Dense(1)
def call(self, inputs):
x = self.conv(inputs)
x = self.flatten(x)
return self.dense(x)
cnn = CNN()
inputs = tf.random_uniform([2, 3, 16, 16])
outputs = cnn(inputs)
print(outputs)
также делает.
Q1: действительно ли tf.keras.layers.Conv2D
как первый слой в модели необходимо указать input_shape
?
Q2: Если нет, то когда это необходимо и почему это упоминается в официальном документе?
Update1:
Учебник по tf.keras говорит
Количество входных измерений часто не нужно, так как оно может быть выведено
первый раз, когда слой используется, но он может быть предоставлен, если вы хотите
укажите его вручную, что полезно в некоторых сложных моделях.
UPDATE2:
git blame
строки документации в источнике TensorFlow показали, что этот документ скопирован из Keras API (который не является TensorFlow keras API).