Я проверил и обучил нейронную сеть регрессии.Когда я запускаю сеть с новыми данными, я просто получаю один и тот же номер для каждой записи.Я адаптировал это из системы классификации, и это сработало.
код:
private void writeINDArray(INDArray output, PrintWriter writer, Iterator<String> identifierIterator) {
int rows = output.rows();
int coluumns = output.columns();
for (int i = 0; i < rows; i++) {
INDArray row = output.getRow(i);
StringJoiner stringJoiner = new StringJoiner("\t");
for (int j = 0; j < coluumns; j++) {
stringJoiner.add(Float.toString(row.getFloat(j)));
}
if (identifierIterator.hasNext()) {
stringJoiner.add(identifierIterator.next());
}
else {
throw new RuntimeException("identifier list is empty!");
}
writer.println(stringJoiner.toString());
log.info(stringJoiner);
}
}
@Override
public void run(File neuralNetworkZipFile, File fingerPrintFile, List<String> identifiers) {
log.info(String.format("running %s on %s", neuralNetworkZipFile.getAbsolutePath(), fingerPrintFile.getAbsolutePath()));
Iterator<String> identifierIterator = identifiers.iterator();
runResultFile = new File("run_results_" + Utility.timeDate() + ".txt");
try (RecordReader recordReader = new CSVRecordReader(0, ','); PrintWriter writer = new PrintWriter(runResultFile)) {
recordReader.initialize(new FileSplit(fingerPrintFile));
DataSetIterator iterator = neuralNetworkSupporter.getDataSetIterator(recordReader);
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(neuralNetworkZipFile);
while (iterator.hasNext()) {
DataSet fingerPrint = iterator.next();
INDArray output = model.output(fingerPrint.getFeatures(), false);
writeINDArray(output, writer, identifierIterator);
}
}
catch (IOException | InterruptedException e) {
e.printStackTrace();
}
}
Есть предложения, что я делаю неправильно?Я прочитал JavaDoc для MultiLayerNetwork и INDArray, но, кажется, ничто не вызывает эту проблему.У меня действительно была проблема с загрузкой данных без данных, и мне пришлось делать отвратительный взлом.Чтобы это работало.
private void outputBitSet(MolecularProperties molecularProperties, PrintWriter writer) {
StringBuilder builder = new StringBuilder();
BitSet fingerprintBitSet = molecularProperties.bitSet;
if (useStructuralFingerprint) {
for (int i = 0; i < fingerprintBitSet.size(); i++) {
double bit = fingerprintBitSet.get(i) ? VALUE2 : VALUE1;
appendComma(builder);
builder.append(bit);
}
}
if (useMolecularProperties) {
addProperties(builder, molecularProperties);
}
if (! action.equals(Action.RUN)) {
if (isRegression) {
log.debug(String.format("%8.6f", molecularProperties.regressionValue) + " " + molecularProperties.id);
appendComma(builder);
builder.append(String.format("%8.6f", molecularProperties.regressionValue));
}
else {
appendComma(builder);
builder.append(molecularProperties.classification);
}
}
else { // TODO This is need to fix an issue with CSVRecordReader expecting there to be one or more regression values.
if (isRegression) {
appendComma(builder);
builder.append(String.format("%8.6f", VALUE1));
}
else {
appendComma(builder);
builder.append(CLASS1);
}
}
writer.println(builder.toString());
}
private void outputBitSet(List<MolecularProperties> molecularPropertiesList, PrintWriter writer) {
if (!action.equals(Action.RUN)) {
Collections.shuffle(molecularPropertiesList);
}
molecularPropertiesList.forEach(m -> outputBitSet(m, writer));
}
Буду очень признателен за любые предложения?
Барт