Я построил две нейронные сети с DeepLearning4J, одну прямую и одну рекуррентную сеть.Для обучения, тестирования и оценки обоих типов сетей я использую следующий фрагмент кода:
protected Evaluation trainAndTestNetwork(MultiLayerNetwork model, DataSetIterator trainData, DataSetIterator testData) {
String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
Evaluation evaluation = new Evaluation();
for (int i = 0; i < nEpochs; i++) {
model.fit(trainData);
evaluation = model.evaluate(testData); // Evaluate on test set
log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()));
testData.reset();
trainData.reset();
}
return evaluation;
}
Для сохранения окончательной сети на диске я использовал следующий метод:
public static void saveNeuralNet(String networkPath, Model model, DataNormalization normalizer) {
try {
File file = new File(networkPath + MODEL_NAME + ".zip"); // MODEL_NAME id model
if (file.exists()) file.delete();
file.createNewFile();
ModelSerializer.writeModel(model, file, true, normalizer);
} catch (IOException e) { e.printStackTrace(); }
}
После того, как я восстановил модель, используя:
public static MultiLayerNetwork restoreMultiLayerNetwork(String networkDirectory) throws IOException {
File file = new File(networkDirectory + "/" + MODEL_NAME + ".zip"); // MODEL_NAME id model
return ModelSerializer.restoreMultiLayerNetwork(file, true);
}
Я хочу сделать прогноз на никогда не замеченном наборе данных.Поэтому я восстанавливаю модель, получаю DataSetIterator
и передаю ее методу output
модели.
MultiLayerNetwork model = Classifier.restoreMultiLayerNetwork(configFile.getNetworkPath() + configFile.getNetworkName());
dataSetIterator = getDataSetIterator(configFile, predictionData); // Method follows below
INDArray networkPrediction = model.output(dataSetIterator);
Чтобы передать dataSetIterator
в метод output
модели, мне нужно загрузитьданные в CSVRecordReader
для прямой передачи:
public static DataSetIterator getDataSetIterator(ConfigFile configFile, Double[][] predictionData) {
try {
DataFileWriter.writeInputLayerToCSVFile(Predictor.tempFilepath, new Double[][] { Aggregator.mergeAggregation(predictionData) });
CSVRecordReader dataSet = new CSVRecordReader();
dataSet.initialize(new FileSplit(new File(DataFileWriter.getRootDirectory() + Predictor.tempFilepath + DataFileWriter.FILE_FORMAT)));
return new RecordReaderDataSetIterator(dataSet, configFile.getBatchSize());
} catch (IOException | InterruptedException e) { e.printStackTrace(); }
return null;
}
или в SequenceRecordReader
для текущей сети:
public static DataSetIterator getDataSetIterator(ConfigFile configFile, Double[][] predictionData) {
try {
DataFileWriter.writeInputLayerToCSVFile(Predictor.tempFilepath + "/0", predictionData);
SequenceRecordReader dataSetFeatures = new CSVSequenceRecordReader();
dataSetFeatures.initialize(new NumberedFileInputSplit(DataFileWriter.getRootDirectory() + Predictor.tempFilepath + "/%d.csv", 0, 0));
return new SequenceRecordReaderDataSetIterator(dataSetFeatures, configFile.getBatchSize(), configFile.getNumLabelClasses(), 0);
} catch (IOException | InterruptedException e) { e.printStackTrace(); }
return null;
}
У меня есть четыре вопроса:
Первый вопрос касается не самой проблемы, а общего вопроса, касающегося процесса тестирования и оценки обученной сети с использованием набора тестов.Сохраняются ли тестовые данные в модели автоматически или мне нужно обновить модель, используя набор тестовых данных?
Когда я использую INDArray networkPrediction = model.output(dataSetIterator);
, как мне впоследствии обновить модель?Когда я получаю DataSetIterator
, я, очевидно, не устанавливаю выходные метки, но когда результат, наконец, появляется в реальном мире, я должен обновить свою текущую модель.Я просто делаю то же самое, что и в getDataSetIterator(ConfigFile configFile, Double[][] predictionData)
раньше, но на этот раз с выводом, записанным в набор данных, и настройкой индекса метки и количества возможных результатов, как new RecordReaderDataSetIterator(trainFeatures, batchSize, 0, numLabelClasses);
?Или мне нужно запустить всю сеть с нуля, переобучить и протестировать все заново с добавлением новых данных?
Третий вопрос о том, как я использую SequenceRecordReaderDataSetIterator
при прогнозировании,Хотя все работает для метода getDataSetIterator(ConfigFile configFile, Double[][] predictionData)
прямой связи, я борюсь с тем же методом рекуррентной сети.Конструктор RecordReaderDataSetIterator
позволяет передавать только входной набор данных и размер пакета, что имеет смысл, если вы хотите сделать прогноз.Однако для конструктора SequenceRecordReaderDataSetIterator
требуется количество возможных выходных меток, а также индекс меток.Это не имеет смысла, потому что пока нет выходных данных, поэтому я всегда получаю исключение Exception in thread "main" java.lang.NumberFormatException: For input string: "0.0"
, которое ссылается на недопустимый вывод.Есть ли другой способ получить правильный DataSetIterator для прогнозирования с использованием периодических сетей или я пропускаю что-то еще?
Поскольку я также пробовал другие подходы для получения прогноза из моей сети,Я наткнулся на метод модели predict(DataSet dataset)
.Итак, последний вопрос: что такое List<String> networkPrediction = model.predict(dataSet);
по сравнению с INDArray networkPrediction = model.output(dataSetIterator);
, в чем разница?