Я хочу сделать регрессию и классификацию в одной модели.Причина в том, что я хочу перенести это на TFLite.В классификации мне нужно использовать SparseCategoricalcrossEntropy, чтобы можно было прогнозировать непосредственно индекс класса.Чтобы построить его в той же модели, я использую Functional API от Keras.
Поэтому я хотел бы получить прогноз class_index и регрессию в модели sam.
Упрощенная версиякод такой:
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
X=np.random.random(size=(100,5))
y=np.random.randint(0,100,size=(100,2)).astype(float) #Regression
class_index=np.random.randint(0,2,size=(100,1)) #Classification
input1 = Input(shape=(5,))
t = Dense(10, activation='relu')(input1)
out1 = Dense(2, activation='softmax')(t)
t2 = Dense(50, activation='relu')(t)
t3 = Dense(50, activation='relu')(t2)
out2 = Dense(2)(t3)
model = Model(inputs=input1, outputs=[out1, out2])
model.compile(
optimizer='adam',
loss=['sparse_categorical_crossentropy', 'mean_squared_error'],
metrics=['sparse_categorical_accuracy','accuracy']
)
history = model.fit(X, [ class_index, y ], epochs=10, batch_size=64)
Я получаю ошибку:
File "C:/Users/.../Sin título5.py", line 31, in <module>
history = model.fit(X, [ class_index, y ], epochs=10, batch_size=64)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\engine\training.py", line 1178, in fit
validation_freq=validation_freq)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\engine\training_arrays.py", line 204, in fit_loop
outs = fit_function(ins_batch)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\backend\tensorflow_backend.py", line 2979, in __call__
return self._call(inputs)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\keras\backend\tensorflow_backend.py", line 2937, in _call
fetched = self._callable_fn(*array_vals)
File "C:\Users\AIRFI\Anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\client\session.py", line 1472, in __call__
run_metadata_ptr)
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Can not squeeze dim[1], expected a dimension of 1, got 2
[[{{node metrics_35/sparse_categorical_accuracy_1/Squeeze}}]]
[[metrics_35/acc_1/Mean/_567]]
(1) Invalid argument: Can not squeeze dim[1], expected a dimension of 1, got 2
[[{{node metrics_35/sparse_categorical_accuracy_1/Squeeze}}]]
0 successful operations.
0 derived errors ignored.