Моя цель - создать скрипт node.js для обучения модели в папке, содержащей изображения в формате jpeg / png, и использовать его для классификации изображений. Я сгенерировал тестовую модель для набора изображений в пределах https://teachablemachine.withgoogle.com/train/image и успешно загрузил модель в nodejs - прогнозирование работает хорошо.
Проблема в том, что я сейчас пытаюсь воспроизвести логику в браузере c, используемую на обучаемой машине, в nodejs для локального обучения модели. Я не могу понять, как собрать входной тензор из объектов изображения tenor3d (readImage(path)
вывод в коде). Я использую tf.stack
, и он, кажется, дает правильные размеры, однако пригонка никогда не начинает сходиться даже после 50 эпох. это на все выходные.
Вот код, который я использовал:
const tf = require('@tensorflow/tfjs');
const tfnode = require('@tensorflow/tfjs-node');
const fs = require('fs');
const util = require('util');
const readdir = util.promisify(fs.readdir);
const IMAGE_HEIGHT = 224;
const IMAGE_WIDTH = 224;
const LABEL_FLAT_SIZE = 2;
function readImage(path) {
const imageBuffer = fs.readFileSync(path);
const tfimage = tfnode.node.decodeImage(imageBuffer, 3);
const smallimg = tf.image.resizeBilinear(tfimage, [IMAGE_HEIGHT, IMAGE_WIDTH])
const floatimg = tf.cast(smallimg, 'float32');
return floatimg;
}
async function loadImagesDir({images}, dir) {
let file_count = 0;
try {
files = await readdir(dir);
files.forEach(function (file) {
if (/^(?!\._).*\.((png)|(jpg)|(jpeg)|(gif))$/gi.test(file)) {
console.log(file);
let this_image = readImage(dir + '/' + file);
images.push(this_image);
file_count++;
}
});
} catch (err) {
console.error('Unable to scan directory: ' + err);
}
return file_count;
}
function getLabelsSubArray(image_count, classId) {
return Array(image_count).fill(classId);
}
async function loadImages() {
const images = [];
let labels = [];
let imgcount;
image_count = await loadImagesDir({images}, 'train_data/squares');
labels = labels.concat(getLabelsSubArray(image_count, 1));
image_count = await loadImagesDir({images}, 'train_data/circles');
labels = labels.concat(getLabelsSubArray(image_count, 2));
return {images: images, labels:labels};
}
async function loadTestImages() {
const images = [];
let labels = [];
let imgcount;
image_count = await loadImagesDir({images}, 'test_data/squares');
labels = labels.concat(getLabelsSubArray(image_count, 1));
image_count = await loadImagesDir({images}, 'test_data/circles');
labels = labels.concat(getLabelsSubArray(image_count, 2));
return {images: images, labels:labels};
}
/** Helper class to handle loading training and test data. */
class MnistDataset {
constructor() {
this.dataset = null;
this.trainSize = 0;
this.testSize = 0;
this.trainBatchIndex = 0;
this.testBatchIndex = 0;
}
/** Loads training and test data. */
async loadData() {
this.dataset = await Promise.all([
loadImages(), loadTestImages()
]);
this.trainSize = this.dataset[0].images.length;
this.testSize = this.dataset[1].images.length;
console.log("LEN:", this.trainSize, this.testSize);
}
getTrainData() {
return this.getData_(true);
}
getTestData() {
return this.getData_(false);
}
getData_(isTrainingData) {
let imagesIndex;
let labelsIndex;
if (isTrainingData) {
imagesIndex = 0;
} else {
imagesIndex = 1;
}
return {
images: tf.stack(this.dataset[imagesIndex].images),
labels: tf.oneHot(tf.tensor1d(labels, 'int32'), LABEL_FLAT_SIZE)
};
}
}
module.exports = new MnistDataset();
Моя модель:
const tf = require('@tensorflow/tfjs');
const IMAGE_HEIGHT = 224;
const IMAGE_WIDTH = 224;
const NUM_CLASSES = 2;
const LEARN_RATE = 0.0001;
const DENSE_UNITS = 64;
let model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [IMAGE_HEIGHT, IMAGE_WIDTH , 3],
filters: 8,
kernelSize: 5,
strides: 1,
activation: 'relu',
}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({
kernelInitializer: 'VarianceScaling',
useBias: false,
activation: 'softmax',
units: NUM_CLASSES
}));
const optimizer = tf.train.adam(LEARN_RATE);
model.compile({
optimizer: optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
module.exports = model;
А сценарий отслеживания вызывается
const {images: trainImages, labels: trainLabels} = data.getTrainData();
model.summary();
let epochBeginTime;
let millisPerStep;
const validationSplit = 0.15;
const numTrainExamplesPerEpoch =
trainImages.shape[0] * (1 - validationSplit);
const numTrainBatchesPerEpoch =
Math.ceil(numTrainExamplesPerEpoch / batchSize);
await model.fit(trainImages, trainLabels, {
epochs,
batchSize,
validationSplit
});
const {images: testImages, labels: testLabels} = data.getTestData();
const evalOutput = model.evaluate(testImages, testLabels);