Ответ заключается в том, чтобы подать модель точно так же, как мы тренировались, за исключением того, что установите метки на -1.Выходными данными будет INDarray, содержащий вероятность 0 в одном массиве и вероятность 1 в другом массиве, отображаемую в последней строке последовательности.
Вот код:
public void getOutputsForTheseInputsUsingThisNet(String netFilePath,String inputFileDir) throws Exception {
//open the network file
File locationToSave = new File(netFilePath);
MultiLayerNetwork nNet = null;
logger.info("Trying to open the model");
try {
nNet = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
logger.info("Success: Model opened");
} catch (IOException e) {
throw new Exception(String.format("Unable to open model from %s because of error %s", locationToSave.getAbsolutePath(),e.getMessage()));
}
logger.info("Loading test data");
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0, ","); //skip no lines at the top - i.e. no header
try {
testFeatures.initialize(new NumberedFileInputSplit(inputFileDir + "/features/s_4180%d.csv", 0, 4));
} catch (InterruptedException e) {
testFeatures.close();
throw new Exception(String.format("IO error %s. during testFeatures", e.getMessage()));
}
logger.info("Loading label data");
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
try {
testLabels.initialize(new NumberedFileInputSplit(inputFileDir + "/labels/s_4180%d.csv", 0,4));
} catch (InterruptedException e) {
testLabels.close();
testFeatures.close();
throw new IOException(String.format("Interrupted exception error %s. during testLabels initialise", e.getMessage()));
}
//DataSetIterator inputData = new Seque
logger.info("creating iterator");
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
this._modelDefinition.getBATCH_SIZE(),this._modelDefinition.getNUM_LABEL_CLASSES(), false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
//now use it to classify some data
logger.info("classifying examples");
INDArray output = nNet.output(testData);
logger.info("outputing the classifications");
if(output==null||output.isEmpty())
throw new Exception("There is no output");
System.out.println(output);
//sample output
// [[[0, 0, 0, 0, 0.9882, 0, 0, 0, 0], // [0, 0, 0, 0, 0.0118, 0, 0, 0, 0]], // // [[0, 0.1443, 0, 0, 0, 0, 0, 0, 0], // [0, 0.8557, 0, 0, 0, 0, 0, 0, 0]], // // [[0, 0, 0, 0, 0, 0, 0, 0, 0.9975], // [0, 0, 0, 0, 0, 0, 0, 0, 0.0025]], // // [[0, 0, 0, 0, 0, 0, 0.8482, 0, 0], // [0, 0, 0, 0, 0, 0, 0.1518, 0, 0]], // // [[0, 0, 0, 0.8760, 0, 0, 0, 0, 0], // [0, 0, 0, 0.1240, 0, 0, 0, 0, 0]]]
}