Исключение, сгенерированное при обучении модели CNN Conv1D Keras, импортированной в deeplearning4j - PullRequest
0 голосов
/ 23 октября 2019

У меня довольно простой CNN, определенный в Jupyter как

model = Sequential()

model.add(Conv1D(32, 12, activation='relu', padding='same', input_shape=(X_train.shape[1],X_train.shape[2])))
model.add(Conv1D(64, 12, activation='relu', padding='same'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
​
model.summary()

model.compile(loss='mse', optimizer='sgd')

_________________________________________________________________ Layer (type) Параметр выходной формы #
================================================================= conv1d_7 (Conv1D) (Нет, 675, 32) 416
_________________________________________________________________ conv1d_8 (Conv1D) (Нет, 675, 64) 24640
_________________________________________________________________ dropout_4 (Выпадение) (Нет, 675, 64) 0
_________________________________________________________________dens_4 (плотный) (None, 675, 1) 65
================================================================ Всего параметров: 25 121 Обучаемые параметры: 25 121 Необучаемые параметры: 0


Я сохранил модель (и веса) в виде файла .h5 и могу импортировать его в мое Java-приложение. Это прекрасно работает, и я могу генерировать прогнозы, используя эту модель. Однако я также хотел бы переобучить модель на Java.

, используя этот фрагмент кода

    MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(modelConf.getAbsolutePath());
    int nSamp = normFbe.length; // 3545
    int nChan = normFbe[0].length; // 675;
    INDArray X = Nd4j.create(normFbe);

    // Create the ground truth
    INDArray y_truth = generateTruthData(nSamp, nChan);

    for (int i = 0; i < nSamp; i++) {
        INDArray X_train = X.getRow(i);
        INDArray y_train = y_truth.getRow(i);
        X_train = X_train.reshape(1, 1, 675);
        y_train = y_train.reshape(1, 1, 675);
        model.fit(X_train, y_train);
    }

, но выдает исключение и генерирует сообщение об ошибке:

ОШИБКА 23/10/19 11: 55: 25,412 [DxS Worker-0] org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner - Не удалось выполнить операцию умножения. Попытка выполнить с 2 входами, 1 выходом, 0 targs, 0 bargs и 0 iargs. Входы: [(FLOAT, [1,32,675,1], c), (FLOAT, [1,32,675], c)]. Выходы: [(FLOAT, [1,32,675,1], c)]. Таргс:. iArgs: -. Баргс: -. Оп собственное имя: "d4e50f04-f0b2-4eb4-9276-439b871087ad" - см. Сообщение выше (распечатано из c ++) для возможной причины ошибки. ОШИБКА 23/10/19 11: 55: 27,568 [DxS Worker-0] DxS - инструмент выполнения Машинное обучение java.lang.RuntimeException: Операция [multiply] не выполнена

Есть идеи? Мне кажется, что форма данных не совпадает на входе или между слоями, но у меня закончились идеи, как это исправить. Любая помощь высоко ценится.

Дэвид Робб

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...