Keras: ввод с размером x * x генерирует нежелательный вывод y * x - PullRequest
0 голосов
/ 21 ноября 2018

У меня есть следующая нейронная сеть в Керасе:

inp = layers.Input((3,))
#Middle layers omitted
out_prop = layers.Dense(units=3, activation='softmax')(inp)
out_value = layers.Dense(units=1, activation = 'linear')(inp)

Затем я подготовил псевдо-ввод для проверки моей сети:

inpu = np.array([[1,2,3],[4,5,6],[7,8,9]])

Когда я пытаюсь предсказать, это происходит:

In [45]:nn.network.predict(inpu)
Out[45]: 
[array([[0.257513  , 0.41672954, 0.32575747],
    [0.20175152, 0.4763418 , 0.32190666],
    [0.15986516, 0.53449154, 0.30564335]], dtype=float32),
array([[-0.24281949],
    [-0.10461146],
    [ 0.11201331]], dtype=float32)]

Итак, как вы можете видеть выше, я хотел получить два вывода: один должен был быть массивом с размером 3, другой должен был иметь нормальное значение.Вместо этого я получаю матрицу 3х3 и массив с 3 элементами.Что я делаю не так?

1 Ответ

0 голосов
/ 21 ноября 2018

Вы передаете три входных сэмпла в сеть:

>>> inpu.shape
(3,3)  # three samples of size 3

И у вас есть два выходных слоя: один из них выводит вектор размером 3 для каждого сэмпла , а другойвыводит вектор размера один (т.е. скалярный), опять же для каждого образца .В результате выходные формы будут иметь значения (3, 3) и (3, 1).

Обновление: Если вы хотите, чтобы ваша сеть принимала входной образец формы (3,3) и выводила векторы размера3 и 1, и вы хотите использовать только плотные слои в своей сети, тогда вы должны использовать слой Flatten где-нибудь в модели.Один из возможных вариантов - использовать его сразу после входного слоя:

inp = layers.Input((3,3))  # don't forget to set the correct input shape
x = Flatten()(inp)
# pass x to other Dense layers

В качестве альтернативы, вы можете сгладить ваши данные, чтобы иметь форму (num_samples, 9), а затем передать их в вашу сеть, не используя Flattenlayer.

Update 2: Как правильно заметил @Mete в комментариях, убедитесь, что входной массив имеет форму (num_samples, 3, 3), если каждая входная выборка имеет форму (3,3).

...