Подача набора данных тензорного потока в модель - PullRequest
1 голос
/ 24 октября 2019

У меня есть набор входных данных со 102 функциями и соответствующий двоичный выход. Выходной сигнал равен 0 или 1, в зависимости от функций 102.

Вход:

tf.Tensor(
[-1.72999993e-01 -8.20000023e-02  3.38000000e-01  1.35000005e-01
  ...
  0.00000000e+00  2.00000009e-03], shape=(102,), dtype=float64)

Выход:

tf.Tensor([1], shape=(1,), dtype=int32)

Я пытаюсь следовать этому пользовательское учебное пособие и создайте эту модель следующим образом:

train_dataset = tf.data.Dataset.from_tensor_slices((train_x,tf.dtypes.cast(label_x, tf.int32)))
features, labels = next(iter(train_dataset))

model = tf.keras.Sequential([
  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(102,)),  # input shape required
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(1)
])

predictions = model(features)

Однако, когда возникает ошибка при попытке его запустить:

---------------------------------------------------------------------------

InvalidArgumentError                      Traceback (most recent call last)

<ipython-input-12-d7be7f733930> in <module>()
      6 ])
      7 
----> 8 predictions = model(features)

7 frames

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: In[0] is not a matrix. Instead it has shape [102] [Op:MatMul]

1 Ответ

0 голосов
/ 24 октября 2019

Вы должны настроить batch для создания набора данных или input_shape в вашей модели, чтобы соответствовать размерам.

train_x = np.arange(100, dtype=np.int32)
label_x = np.arange(100, dtype=np.int32)

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, label_x)).batch(10)

model = tf.keras.Sequential([
  tf.keras.layers.Dense(10, activation=tf.nn.relu, input_shape=(1,)),  # input shape required
  tf.keras.layers.Dense(10, activation=tf.nn.relu),
  tf.keras.layers.Dense(1)
])

for features, labels in train_dataset:
    pred = model(features[..., tf.newaxis])
print(pred)

#tf.Tensor(
#[[-21.829016]
# [-22.071556]
# [-22.314102]
# [-22.556648]
# [-22.799194]
# [-23.041737]
# [-23.284283]
# [-23.52683 ]
# [-23.76937 ]
# [-24.011917]], shape=(10, 1), dtype=float32)
...