Нейронная сеть MNIST распознавания номеров - PullRequest
0 голосов
/ 08 апреля 2019

Недавно я попытался создать свой первый проект с помощью Neural Networks, и вот что я придумал.Я хотел, чтобы он распознавал рукописные числа MNIST.Проблема в том, что когда я запускаю этот код и заставляю его тренироваться примерно 400 раз, я получаю ~ 28% точности с тестовыми данными.Это должно быть так?400k слишком мало, чтобы получить лучшие результаты, или это потому, что моя нейронная сеть может иметь только один скрытый слой?

Подводя итог короткому вопросу, вещи должны выглядеть так, или я сделал что-то не так?Ниже много избыточного кода и тому подобное, я просто хотел, чтобы он работал.

Все, если предположить, что моя нейронная сеть работает, очевидно.

public static void main(String[] args) {

  List<Data> trainData = new ArrayList<>();
  List<Data> testData = new ArrayList<>();

  byte[] trainLabels;
  byte[] trainImages;
  byte[] testLabels;
  byte[] testImages;

  try {

     Path tempPath1 = Paths.get("res/train-labels-idx1-ubyte");
     trainLabels = Files.readAllBytes(tempPath1);
     ByteBuffer bufferLabels = ByteBuffer.wrap(trainLabels);
     int magicLabels = bufferLabels.getInt();
     int numberOfItems = bufferLabels.getInt();

     Path tempPath = Paths.get("res/train-images-idx3-ubyte");
     trainImages = Files.readAllBytes(tempPath);
     ByteBuffer bufferImages = ByteBuffer.wrap(trainImages);
     int magicImages = bufferImages.getInt();
     int numberOfImageItems = bufferImages.getInt();
     int rows = bufferImages.getInt();
     int cols = bufferImages.getInt();

     for(int i = 0; i < numberOfItems; i++) {
        int t = bufferLabels.get();
        double[] target = createTargets(t);
        double[] inputs = new double[rows*cols];
        for(int j = 0; j < inputs.length; j++) {
           inputs[j] = bufferImages.get();
           }
         Data tobj = new Data(inputs, target);
         trainData.add(tobj);
       }

      tempPath = Paths.get("res/t10k-labels-idx1-ubyte");
      testLabels = Files.readAllBytes(tempPath);
      ByteBuffer testLabelBuffer = ByteBuffer.wrap(testLabels);
      int testMagicLabels = testLabelBuffer.getInt();
      int numberOfTestLabels = testLabelBuffer.getInt();

      tempPath = Paths.get("res/t10k-images-idx3-ubyte");
      testImages = Files.readAllBytes(tempPath);
      ByteBuffer testImageBuffer = ByteBuffer.wrap(testImages);
      int testMagicImages = testImageBuffer.getInt();
      int numberOfTestImages = testImageBuffer.getInt();
      int testRows = testImageBuffer.getInt();
      int testCols = testImageBuffer.getInt();

      for(int i = 0; i < numberOfTestImages; i++) {
          double[] target = new double[]{testLabelBuffer.get()};
          double[] inputs = new double[testRows*testCols];
          for(int j = 0; j < inputs.length; j++) {
              inputs[j] = testImageBuffer.get();
             }
          Data tobj = new Data(inputs, target);
          testData.add(tobj);
         }

       NeuralNetwork neuralNetwork = new NeuralNetwork(784,64,10);

       int len = trainData.size();
       Random randomGenerator = new Random();
       for(int i = 0; i < 400000; i++) {
           int randomInt = randomGenerator.nextInt(len);
           neuralNetwork.train(trainData.get(randomInt).getInputs(), trainData.get(randomInt).getTargets());
          }

        float rightAnswers = 0;

        for(Data testObj : testData) {
           double[] output = neuralNetwork.feedforward(testObj.getInputs());
           double[] answer = testObj.getTargets(); 
         }
            System.out.println(percentage);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        public static double[] createTargets(int number) {
            double[] result = new double[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
            result[number] = 1;
            return  result;

        }

1 Ответ

0 голосов
/ 08 апреля 2019

Если кому-то интересно, с моей стороны был баг. При регистрации всего я заметил, что значения входного пикселя варьировались от -255 до 255, а из документации MNIST они должны быть 0-255. Кроме того, мои входные данные не были нормализованы, поэтому некоторые из них были 0, когда другие 255. Это то, что я добавил. Надеюсь, я ничего не пропустил. Теперь я получаю ~ 90% точности.

for(int i = 0; i < numberOfTestImages; i++) {

   double[] target = new double[]{testLabelBuffer.get()& 0xFF};
   double[] inputs = new double[testRows*testCols];
   or(int j = 0; j < inputs.length; j++) {
   // Normalize input from 0-255 to 0-1
   double temp = (testImageBuffer.get() & 0xFF) / 255f;
   inputs[j] = temp;
 }
 Data tobj = new Data(inputs, target);
 testData.add(tobj);
}
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...