Как использовать существующую обученную модель DL4J для классификации нового ввода - PullRequest
0 голосов
/ 14 февраля 2019

У меня есть модель DL4J LSTM, которая генерирует двоичную классификацию последовательного ввода.Я обучил и проверил модель и доволен точностью / отзывом.Теперь я хочу использовать эту модель для прогнозирования бинарной классификации новых входных данных.Как мне это сделать?то есть как я могу дать обученной нейронной сети один вход (файл, содержащий последовательность строк признаков) и получить двоичную классификацию этого входного файла.

Вот мой исходный итератор набора обучающих данных:

        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(0, ",");  //skip no header lines
    try {
        trainFeatures.initialize( new NumberedFileInputSplit(featureBaseDir + "/s_%d.csv", 0,this._modelDefinition.getNB_TRAIN_EXAMPLES()-1));
    } catch (IOException e) {
        trainFeatures.close();
        throw new IOException(String.format("IO error %s. during trainFeatures", e.getMessage()));
    } catch (InterruptedException e) {
        trainFeatures.close();
        throw new IOException(String.format("Interrupted exception error %s. during trainFeatures", e.getMessage()));
    }

    SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
    try {
        trainLabels.initialize(new NumberedFileInputSplit(labelBaseDir + "/s_%d.csv", 0,this._modelDefinition.getNB_TRAIN_EXAMPLES()-1));
    } catch (InterruptedException e) {
        trainLabels.close();
        trainFeatures.close();
        throw new IOException(String.format("Interrupted exception error %s. during trainLabels initialise", e.getMessage()));
    }



    DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
            this._modelDefinition.getBATCH_SIZE(),this._modelDefinition.getNUM_LABEL_CLASSES(), false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

Вот моя модель:

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(this._modelDefinition.getRANDOM_SEED())    //Random number generator seed for improved repeatability. Optional.
            .weightInit(WeightInit.XAVIER)
            .updater(new Nesterovs(this._modelDefinition.getLEARNING_RATE()))
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)  //Not always required, but helps with this data set
            .gradientNormalizationThreshold(0.5)
            .list()
            .layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(this._modelDefinition.getNB_INPUTS()).nOut(this._modelDefinition.getLSTM_LAYER_SIZE()).build())
            .layer(1, new LSTM.Builder().activation(Activation.TANH).nIn(this._modelDefinition.getLSTM_LAYER_SIZE()).nOut(this._modelDefinition.getLSTM_LAYER_SIZE()).build())
            .layer(2,new DenseLayer.Builder().nIn(this._modelDefinition.getLSTM_LAYER_SIZE()).nOut(this._modelDefinition.getLSTM_LAYER_SIZE())
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(this._modelDefinition.getLSTM_LAYER_SIZE()).nOut(this._modelDefinition.getNUM_LABEL_CLASSES()).build())
            .pretrain(false).backprop(true).build();

Я тренирую модель в течение N эпох, чтобы получить оптимальные результаты.Я сохраняю модель, теперь я хочу открыть модель и получить классификации для новых файлов последовательных объектов.

Если есть пример этого - пожалуйста, дайте мне знать, где.

спасибо

anton

1 Ответ

0 голосов
/ 21 февраля 2019

Ответ заключается в том, чтобы подать модель точно так же, как мы тренировались, за исключением того, что установите метки на -1.Выходными данными будет INDarray, содержащий вероятность 0 в одном массиве и вероятность 1 в другом массиве, отображаемую в последней строке последовательности.

Вот код:

public void getOutputsForTheseInputsUsingThisNet(String netFilePath,String inputFileDir) throws Exception {

    //open the network file
    File locationToSave = new File(netFilePath);
    MultiLayerNetwork nNet = null;
    logger.info("Trying to open the model");
    try {
        nNet = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
        logger.info("Success: Model opened");
    } catch (IOException e) {
        throw new Exception(String.format("Unable to open model from %s because of error %s", locationToSave.getAbsolutePath(),e.getMessage()));
    }

    logger.info("Loading test data");
    SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0, ",");  //skip no lines at the top - i.e. no header
    try {
        testFeatures.initialize(new NumberedFileInputSplit(inputFileDir + "/features/s_4180%d.csv", 0, 4));
    } catch (InterruptedException e) {
        testFeatures.close();
        throw new Exception(String.format("IO error %s. during testFeatures", e.getMessage()));
    }
    logger.info("Loading label data");
    SequenceRecordReader testLabels = new CSVSequenceRecordReader();
    try {
        testLabels.initialize(new NumberedFileInputSplit(inputFileDir + "/labels/s_4180%d.csv", 0,4));
    } catch (InterruptedException e) {
        testLabels.close();
        testFeatures.close();
        throw new IOException(String.format("Interrupted exception error %s. during testLabels initialise", e.getMessage()));
    }


    //DataSetIterator inputData = new Seque
    logger.info("creating iterator");

    DataSetIterator testData =  new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
            this._modelDefinition.getBATCH_SIZE(),this._modelDefinition.getNUM_LABEL_CLASSES(), false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);


    //now use it to classify some data
    logger.info("classifying examples");

    INDArray output = nNet.output(testData);
    logger.info("outputing the classifications");
    if(output==null||output.isEmpty())
        throw new Exception("There is no output");
    System.out.println(output);

    //sample output

// [[[0, 0, 0, 0, 0.9882, 0, 0, 0, 0], // [0, 0, 0, 0, 0.0118, 0, 0, 0, 0]], // // [[0, 0.1443, 0, 0, 0, 0, 0, 0, 0], // [0, 0.8557, 0, 0, 0, 0, 0, 0, 0]], // // [[0, 0, 0, 0, 0, 0, 0, 0, 0.9975], // [0, 0, 0, 0, 0, 0, 0, 0, 0.0025]], // // [[0, 0, 0, 0, 0, 0, 0.8482, 0, 0], // [0, 0, 0, 0, 0, 0, 0.1518, 0, 0]], // // [[0, 0, 0, 0.8760, 0, 0, 0, 0, 0], // [0, 0, 0, 0.1240, 0, 0, 0, 0, 0]]]

}
...