InvalidArgumentError при попытке сделать регрессию и классификацию с помощью функционального API Keras и SparseCategoricalCrossentropy Loss - PullRequest
0 голосов
/ 20 сентября 2019

Я хочу сделать регрессию и классификацию в одной модели.Причина в том, что я хочу перенести это на TFLite.В классификации мне нужно использовать SparseCategoricalcrossEntropy, чтобы можно было прогнозировать непосредственно индекс класса.Чтобы построить его в той же модели, я использую Functional API от Keras.

Поэтому я хотел бы получить прогноз class_index и регрессию в модели sam.

Упрощенная версиякод такой:

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model 

X=np.random.random(size=(100,5)) 
y=np.random.randint(0,100,size=(100,2)).astype(float)   #Regression
class_index=np.random.randint(0,2,size=(100,1))         #Classification

input1 = Input(shape=(5,))
t = Dense(10, activation='relu')(input1)
out1 = Dense(2, activation='softmax')(t)

t2 = Dense(50, activation='relu')(t) 
t3 = Dense(50, activation='relu')(t2) 
out2 = Dense(2)(t3)

model = Model(inputs=input1, outputs=[out1, out2])
model.compile(
    optimizer='adam',
    loss=['sparse_categorical_crossentropy', 'mean_squared_error'],
    metrics=['sparse_categorical_accuracy','accuracy']
    )

history = model.fit(X, [ class_index, y ], epochs=10, batch_size=64)

Я получаю ошибку:

  File "C:/Users/.../Sin título5.py", line 31, in <module>
    history = model.fit(X, [ class_index, y ], epochs=10, batch_size=64)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\engine\training.py", line 1178, in fit
    validation_freq=validation_freq)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\engine\training_arrays.py", line 204, in fit_loop
    outs = fit_function(ins_batch)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\backend\tensorflow_backend.py", line 2979, in __call__
    return self._call(inputs)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\backend\tensorflow_backend.py", line 2937, in _call
    fetched = self._callable_fn(*array_vals)

  File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\client\session.py", line 1472, in __call__
    run_metadata_ptr)

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Can not squeeze dim[1], expected a dimension of 1, got 2
     [[{{node metrics_35/sparse_categorical_accuracy_1/Squeeze}}]]
     [[metrics_35/acc_1/Mean/_567]]
  (1) Invalid argument: Can not squeeze dim[1], expected a dimension of 1, got 2
     [[{{node metrics_35/sparse_categorical_accuracy_1/Squeeze}}]]
0 successful operations.
0 derived errors ignored.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...