Выход из обученной нейронной сети регрессии одинаков для всех записей - PullRequest
0 голосов
/ 26 апреля 2019

Я проверил и обучил нейронную сеть регрессии.Когда я запускаю сеть с новыми данными, я просто получаю один и тот же номер для каждой записи.Я адаптировал это из системы классификации, и это сработало.

код:

private void writeINDArray(INDArray output, PrintWriter writer, Iterator<String> identifierIterator) {
    int rows = output.rows();
    int coluumns = output.columns();

    for (int i = 0; i < rows; i++) {
        INDArray row = output.getRow(i);
        StringJoiner stringJoiner = new StringJoiner("\t");

        for (int j = 0; j < coluumns; j++) {
            stringJoiner.add(Float.toString(row.getFloat(j)));
        }

        if (identifierIterator.hasNext()) {
            stringJoiner.add(identifierIterator.next());
        }
        else {
            throw new RuntimeException("identifier list is empty!");
        }
        writer.println(stringJoiner.toString());
        log.info(stringJoiner);
    }
}

@Override
public void run(File neuralNetworkZipFile, File fingerPrintFile, List<String> identifiers) {
    log.info(String.format("running %s on %s", neuralNetworkZipFile.getAbsolutePath(), fingerPrintFile.getAbsolutePath()));

    Iterator<String> identifierIterator = identifiers.iterator();

    runResultFile = new File("run_results_" + Utility.timeDate() + ".txt");

    try (RecordReader recordReader = new CSVRecordReader(0, ','); PrintWriter writer = new PrintWriter(runResultFile)) {
        recordReader.initialize(new FileSplit(fingerPrintFile));

        DataSetIterator iterator = neuralNetworkSupporter.getDataSetIterator(recordReader);
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(neuralNetworkZipFile);

        while (iterator.hasNext()) {
            DataSet fingerPrint = iterator.next();
            INDArray output = model.output(fingerPrint.getFeatures(), false);

            writeINDArray(output, writer, identifierIterator);
        }
    }
    catch (IOException | InterruptedException e) {
        e.printStackTrace();
    }
}

Есть предложения, что я делаю неправильно?Я прочитал JavaDoc для MultiLayerNetwork и INDArray, но, кажется, ничто не вызывает эту проблему.У меня действительно была проблема с загрузкой данных без данных, и мне пришлось делать отвратительный взлом.Чтобы это работало.

private void outputBitSet(MolecularProperties molecularProperties, PrintWriter writer) {
    StringBuilder builder = new StringBuilder();
    BitSet fingerprintBitSet = molecularProperties.bitSet;

    if (useStructuralFingerprint) {
        for (int i = 0; i < fingerprintBitSet.size(); i++) {
            double bit = fingerprintBitSet.get(i) ? VALUE2 : VALUE1;

            appendComma(builder);
            builder.append(bit);
        }
    }
    if (useMolecularProperties) {
        addProperties(builder, molecularProperties);
    }

    if (! action.equals(Action.RUN)) {
        if (isRegression) {
            log.debug(String.format("%8.6f", molecularProperties.regressionValue) + " " + molecularProperties.id);

            appendComma(builder);
            builder.append(String.format("%8.6f", molecularProperties.regressionValue));
        }
        else {
            appendComma(builder);
            builder.append(molecularProperties.classification);
        }
    }
    else { // TODO This is need to fix an issue with CSVRecordReader expecting there to be one or more regression values. 
        if (isRegression) {
            appendComma(builder);
            builder.append(String.format("%8.6f", VALUE1));
        }
        else {
            appendComma(builder);
            builder.append(CLASS1);
        }
    }
    writer.println(builder.toString());
}

private void outputBitSet(List<MolecularProperties> molecularPropertiesList, PrintWriter writer) {
    if (!action.equals(Action.RUN)) {
        Collections.shuffle(molecularPropertiesList);
    }
    molecularPropertiesList.forEach(m -> outputBitSet(m, writer));
}

Буду очень признателен за любые предложения?

Барт

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...