Я пытаюсь создать 1D сверточную сеть с KERAS со следующим кодом /
Цель - мультикласс (0,1,2)
Длина вектора элемента - 36
Запустите следующий код, но получите несоответствие в формах
import tensorflow
from tensorflow.keras.layers import Dropout, Conv1D
from tensorflow.keras import optimizers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import losses
print ('xtrain shape = ',X_train.shape) # output: (11744, 36)
X = np.expand_dims(X_train, axis=2)
print ('X shape = ',X.shape) # output: (11744, 36, 1)
print ('yshape = ',y_train.shape)#output: (11744, 1)
y = to_categorical(y_train, num_classes=3)
print(to_categorical(y).shape)# output: (11744, 3, 2)
maxlen = num_features = NUM_OF_FEATURES # 36
input_dim = 1
def baseline_model(opt):
# create model
model = Sequential()
model.add(Conv1D(2,2,activation='relu',input_shape=(maxlen, input_dim)))
model.add(Dense(14, input_dim=NUM_OF_FEATURES, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(8, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(3, activation='softmax'))
loss1 = 'softmax_cross_entropy_with_logits'
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
return model
sgd = tensorflow.keras.optimizers.SGD(lr=0.0000001, decay=1e-6, momentum=0.9, nesterov=True)
model = baseline_model(sgd)
monitor = EarlyStopping(monitor = 'val_loss', min_delta = 1e-5, patience=50)
r=model.fit(X, y_train, epochs = 300, shuffle = True, verbose=1, batch_size = 1)
fp = "../output/testi4/epochs:{epoch:03d}.hdf5"
modelcheckpoint = ModelCheckpoint(filepath = fp, verbose = 1, save_best_only= True)
n_epochs = len(r.history['loss'])
score = model.evaluate(X_val, y, verbose=0)
model.save(fp)
Получите следующую ошибку
ValueError: Целевой массив с формой (11744, 3) был передан для вывода формы (Нет, 35, 3) при использовании в качестве потери categorical_crossentropy
. Эта потеря предполагает, что цели будут иметь ту же форму, что и выходные данные.
Буду признателен, если кто-нибудь сможет мне помочь с этим
Спасибо