Ошибки с простой нейронной сетью в Keras / Tensorflow - PullRequest
0 голосов
/ 14 апреля 2019

Я строю простую нейронную сеть. Данные представляют собой вектор длиной 231, который кодируется в горячем виде. Каждому 231 длинному вектору присваивается 8 длинных меток с горячим кодированием.

Пока мой код:

ssdf = pd.read_csv("/some/path/to/1AMX_one_hot.csv", sep=',')

ss = ssdf.iloc[:,3:11] # slice the df for the ss
labels = ss.values # vector of all ss's
labels = labels.astype('int32')
# data
onehot = ssdf.iloc[:,11:260]
data = onehot.values
data = data.astype('int32')

model = tf.keras.Sequential()
# Adds a densely-connected layer with 64 units to the model:
model.add(layers.Dense(64, activation='relu'))

# Add another:
model.add(layers.Dense(64, activation='relu'))

# Add a softmax layer with 8 output units:
model.add(layers.Dense(8, activation='softmax'))


model.compile(Adam(lr=.0001), 
          loss='sparse_categorical_crossentropy', 
          metrics=['accuracy']
)

## fit the model
model.fit(data, labels, epochs=10, batch_size=32)

Проблема в том, что выходной слой состоит из 8 единиц, однако мои метки не являются единичными единицами, они представляют собой 8 длинных векторов, которые имеют одно горячее кодирование. Как мне представить это как вывод?

Сообщение об ошибке:

TypeError: Unable to build 'Dense' layer with non-floating point dtype <dtype: 'int32'>

Полная трассировка:

Traceback (most recent call last):
  File "/some/path/to/file/main.py", line 36, in <module>
    model.fit(data, labels, epochs=10, batch_size=32)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/keras/engine/training.py", line 806, in fit
    shuffle=shuffle)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/keras/engine/training.py", line 2503, in     _standardize_user_data
    self._set_inputs(cast_inputs)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/training/tracking/base.py", line 456, in     _method_wrapper
    result = method(self, *args, **kwargs)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/keras/engine/training.py", line 2773, in     _set_inputs
    outputs = self.call(inputs, training=training)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/keras/engine/sequential.py", line 256, in call
outputs = layer(inputs, **kwargs)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/keras/engine/base_layer.py", line 594, in     __call__
    self._maybe_build(inputs)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/keras/engine/base_layer.py", line 1713, in     _maybe_build
    self.build(input_shapes)
  File "/anaconda3/lib/python3.7/site-    packages/tensorflow/python/keras/layers/core.py", line 963, in build
    'dtype %s' % (dtype,))

1 Ответ

0 голосов
/ 14 апреля 2019

В вашем примере кода есть несколько проблем:

  1. Вам необходим входной слой или форма ввода для вашей сети.
  2. Подайте данные и меткикак: astype(np.float32)

Если ваши метки имеют форму (150, 8), то установите последний слой с 8 нейронами.

model.add(layers.Dense(8, activation='softmax'))
model.compile(Adam(lr=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

ОБНОВЛЕНИЕ:

ssdf = pd.read_csv("/some/path/to/1AMX_one_hot.csv", sep=',')

ss = ssdf.iloc[:,3:11] # slice the df for the ss
labels = ss.values # vector of all ss's
labels = labels.astype('float32')                     # changed this
# data
onehot = ssdf.iloc[:,11:260]
data = onehot.values
data = data.astype('float32')                         # changed this

model = tf.keras.Sequential()
# Adds a densely-connected layer with 64 units to the model:
model.add(layers.Dense(64, activation='relu'))

# Add another:
model.add(layers.Dense(64, activation='relu'))

# Add a softmax layer with 8 output units:
model.add(layers.Dense(8, activation='softmax'))


model.compile(Adam(lr=.0001), 
          loss='categorical_crossentropy',            # changed this
          metrics=['accuracy']
)

## fit the model
model.fit(data, labels, epochs=10, batch_size=32)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...