Модель поезда на пакете изображений с использованием tenorflow. js с использованием node.js - PullRequest
0 голосов
/ 06 апреля 2020

Моя цель - создать скрипт node.js для обучения модели в папке, содержащей изображения в формате jpeg / png, и использовать его для классификации изображений. Я сгенерировал тестовую модель для набора изображений в пределах https://teachablemachine.withgoogle.com/train/image и успешно загрузил модель в nodejs - прогнозирование работает хорошо.

Проблема в том, что я сейчас пытаюсь воспроизвести логику в браузере c, используемую на обучаемой машине, в nodejs для локального обучения модели. Я не могу понять, как собрать входной тензор из объектов изображения tenor3d (readImage(path) вывод в коде). Я использую tf.stack, и он, кажется, дает правильные размеры, однако пригонка никогда не начинает сходиться даже после 50 эпох. это на все выходные.

Вот код, который я использовал:

const tf = require('@tensorflow/tfjs');
const tfnode = require('@tensorflow/tfjs-node');
const fs = require('fs');
const util = require('util');
const readdir = util.promisify(fs.readdir);

const IMAGE_HEIGHT = 224;
const IMAGE_WIDTH = 224;
const LABEL_FLAT_SIZE = 2;


function readImage(path) {
    const imageBuffer = fs.readFileSync(path);
    const tfimage = tfnode.node.decodeImage(imageBuffer, 3);
    const smallimg = tf.image.resizeBilinear(tfimage, [IMAGE_HEIGHT, IMAGE_WIDTH])
    const floatimg = tf.cast(smallimg, 'float32');

    return floatimg;
}

async function loadImagesDir({images}, dir) {
    let file_count = 0;
    try {
        files = await readdir(dir);
        files.forEach(function (file) {
            if (/^(?!\._).*\.((png)|(jpg)|(jpeg)|(gif))$/gi.test(file)) {
                console.log(file);
                let this_image = readImage(dir + '/' + file);
                images.push(this_image);
                file_count++;
            }
        });
    } catch (err) {
        console.error('Unable to scan directory: ' + err);
    }
    return file_count;
}

function getLabelsSubArray(image_count, classId) {
    return Array(image_count).fill(classId);
}

async function loadImages() {
    const images = [];
    let labels = [];
    let imgcount;

    image_count = await loadImagesDir({images}, 'train_data/squares');
    labels = labels.concat(getLabelsSubArray(image_count, 1));

    image_count = await loadImagesDir({images}, 'train_data/circles');
    labels = labels.concat(getLabelsSubArray(image_count, 2));
    
    return {images: images, labels:labels};
}

async function loadTestImages() {
    const images = [];
    let labels = [];
    let imgcount;

    image_count = await loadImagesDir({images}, 'test_data/squares');
    labels = labels.concat(getLabelsSubArray(image_count, 1));

    image_count = await loadImagesDir({images}, 'test_data/circles');
    labels = labels.concat(getLabelsSubArray(image_count, 2));

    return {images: images, labels:labels};
}

/** Helper class to handle loading training and test data. */
class MnistDataset {
  constructor() {
    this.dataset = null;
    this.trainSize = 0;
    this.testSize = 0;
    this.trainBatchIndex = 0;
    this.testBatchIndex = 0;
  }

  /** Loads training and test data. */
  async loadData() {
    this.dataset = await Promise.all([
      loadImages(), loadTestImages()
    ]);
    this.trainSize = this.dataset[0].images.length;
    this.testSize = this.dataset[1].images.length;

    console.log("LEN:", this.trainSize, this.testSize);
  }

  getTrainData() {
    return this.getData_(true);
  }

  getTestData() {
    return this.getData_(false);
  }

  getData_(isTrainingData) {
    let imagesIndex;
    let labelsIndex;
    if (isTrainingData) {
      imagesIndex = 0;
    } else {
      imagesIndex = 1;
    }

    return {
      images: tf.stack(this.dataset[imagesIndex].images),
      labels: tf.oneHot(tf.tensor1d(labels, 'int32'), LABEL_FLAT_SIZE)
    };
  }
}

module.exports = new MnistDataset();

Моя модель:

const tf = require('@tensorflow/tfjs');

const IMAGE_HEIGHT = 224;
const IMAGE_WIDTH = 224;
const NUM_CLASSES = 2;
const LEARN_RATE = 0.0001;
const DENSE_UNITS = 64;

let model = tf.sequential();

model.add(tf.layers.conv2d({
  inputShape: [IMAGE_HEIGHT, IMAGE_WIDTH , 3],
  filters: 8,
  kernelSize: 5,
  strides: 1,
  activation: 'relu',
}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
model.add(tf.layers.flatten());


model.add(tf.layers.dense({
    kernelInitializer: 'VarianceScaling',
    useBias: false,
    activation: 'softmax',
    units: NUM_CLASSES
}));

const optimizer = tf.train.adam(LEARN_RATE);
model.compile({
  optimizer: optimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});

module.exports = model;

А сценарий отслеживания вызывается

  const {images: trainImages, labels: trainLabels} = data.getTrainData();
  model.summary();

  let epochBeginTime;
  let millisPerStep;
  const validationSplit = 0.15;
  const numTrainExamplesPerEpoch =
      trainImages.shape[0] * (1 - validationSplit);
  const numTrainBatchesPerEpoch =
      Math.ceil(numTrainExamplesPerEpoch / batchSize);

  await model.fit(trainImages, trainLabels, {
    epochs,
    batchSize,
    validationSplit
  });

  const {images: testImages, labels: testLabels} = data.getTestData();
  const evalOutput = model.evaluate(testImages, testLabels);
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...