DeepLearning4J - обновить обученную модель после того, как был сделан прогноз - PullRequest
0 голосов
/ 13 мая 2019

Я построил две нейронные сети с DeepLearning4J, одну прямую и одну рекуррентную сеть.Для обучения, тестирования и оценки обоих типов сетей я использую следующий фрагмент кода:

protected Evaluation trainAndTestNetwork(MultiLayerNetwork model, DataSetIterator trainData, DataSetIterator testData) {
    String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
    Evaluation evaluation = new Evaluation();
    for (int i = 0; i < nEpochs; i++) {
        model.fit(trainData);
        evaluation = model.evaluate(testData); // Evaluate on test set
        log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()));
        testData.reset();
        trainData.reset();
    }
    return evaluation;
}

Для сохранения окончательной сети на диске я использовал следующий метод:

public static void saveNeuralNet(String networkPath, Model model, DataNormalization normalizer) {
    try {
        File file = new File(networkPath + MODEL_NAME + ".zip"); // MODEL_NAME id model
        if (file.exists()) file.delete();
        file.createNewFile();
        ModelSerializer.writeModel(model, file, true, normalizer);
    } catch (IOException e) { e.printStackTrace(); }
}

После того, как я восстановил модель, используя:

public static MultiLayerNetwork restoreMultiLayerNetwork(String networkDirectory) throws IOException {
    File file = new File(networkDirectory + "/" + MODEL_NAME + ".zip"); // MODEL_NAME id model
    return ModelSerializer.restoreMultiLayerNetwork(file, true);
}

Я хочу сделать прогноз на никогда не замеченном наборе данных.Поэтому я восстанавливаю модель, получаю DataSetIterator и передаю ее методу output модели.

MultiLayerNetwork model = Classifier.restoreMultiLayerNetwork(configFile.getNetworkPath() + configFile.getNetworkName());
dataSetIterator = getDataSetIterator(configFile, predictionData); // Method follows below
INDArray networkPrediction = model.output(dataSetIterator);

Чтобы передать dataSetIterator в метод output модели, мне нужно загрузитьданные в CSVRecordReader для прямой передачи:

public static DataSetIterator getDataSetIterator(ConfigFile configFile, Double[][] predictionData) {
    try {
        DataFileWriter.writeInputLayerToCSVFile(Predictor.tempFilepath, new Double[][] { Aggregator.mergeAggregation(predictionData) });
        CSVRecordReader dataSet = new CSVRecordReader();
        dataSet.initialize(new FileSplit(new File(DataFileWriter.getRootDirectory() + Predictor.tempFilepath + DataFileWriter.FILE_FORMAT)));
        return new RecordReaderDataSetIterator(dataSet, configFile.getBatchSize());
    } catch (IOException | InterruptedException e) { e.printStackTrace(); }
    return null;
}

или в SequenceRecordReader для текущей сети:

public static DataSetIterator getDataSetIterator(ConfigFile configFile, Double[][] predictionData) {
    try {

        DataFileWriter.writeInputLayerToCSVFile(Predictor.tempFilepath + "/0", predictionData);
        SequenceRecordReader dataSetFeatures = new CSVSequenceRecordReader();
        dataSetFeatures.initialize(new NumberedFileInputSplit(DataFileWriter.getRootDirectory() + Predictor.tempFilepath + "/%d.csv", 0, 0));
        return new SequenceRecordReaderDataSetIterator(dataSetFeatures, configFile.getBatchSize(), configFile.getNumLabelClasses(), 0);
    } catch (IOException | InterruptedException e) { e.printStackTrace(); }
    return null;
}

У меня есть четыре вопроса:

  1. Первый вопрос касается не самой проблемы, а общего вопроса, касающегося процесса тестирования и оценки обученной сети с использованием набора тестов.Сохраняются ли тестовые данные в модели автоматически или мне нужно обновить модель, используя набор тестовых данных?

  2. Когда я использую INDArray networkPrediction = model.output(dataSetIterator);, как мне впоследствии обновить модель?Когда я получаю DataSetIterator, я, очевидно, не устанавливаю выходные метки, но когда результат, наконец, появляется в реальном мире, я должен обновить свою текущую модель.Я просто делаю то же самое, что и в getDataSetIterator(ConfigFile configFile, Double[][] predictionData) раньше, но на этот раз с выводом, записанным в набор данных, и настройкой индекса метки и количества возможных результатов, как new RecordReaderDataSetIterator(trainFeatures, batchSize, 0, numLabelClasses);?Или мне нужно запустить всю сеть с нуля, переобучить и протестировать все заново с добавлением новых данных?

  3. Третий вопрос о том, как я использую SequenceRecordReaderDataSetIterator при прогнозировании,Хотя все работает для метода getDataSetIterator(ConfigFile configFile, Double[][] predictionData) прямой связи, я борюсь с тем же методом рекуррентной сети.Конструктор RecordReaderDataSetIterator позволяет передавать только входной набор данных и размер пакета, что имеет смысл, если вы хотите сделать прогноз.Однако для конструктора SequenceRecordReaderDataSetIterator требуется количество возможных выходных меток, а также индекс меток.Это не имеет смысла, потому что пока нет выходных данных, поэтому я всегда получаю исключение Exception in thread "main" java.lang.NumberFormatException: For input string: "0.0", которое ссылается на недопустимый вывод.Есть ли другой способ получить правильный DataSetIterator для прогнозирования с использованием периодических сетей или я пропускаю что-то еще?

  4. Поскольку я также пробовал другие подходы для получения прогноза из моей сети,Я наткнулся на метод модели predict(DataSet dataset).Итак, последний вопрос: что такое List<String> networkPrediction = model.predict(dataSet); по сравнению с INDArray networkPrediction = model.output(dataSetIterator);, в чем разница?

...