Я получаю сообщение об ошибке:
Uncaught (в обещании) TypeError: model.predict не является функцией
Однако код в функции useModel, если я переместу его туда, где я вызываю функцию useModel, работает. Я не могу понять почему. Но это не помогает мне, так как мне нужно будет иметь возможность прогнозировать за пределами функции настройки в моих собственных функциях.
Я предполагаю, что это как-то связано с обещаниями, и я попытался поставить async перед функцией useModel. Но не знаю, почему это поможет.
Может быть, использовать .then каким-нибудь умным способом?
let data;
let xs;
let ys;
function preload(){
data = loadJSON('gridson.json');
}
function setup() {
createCanvas(40, 40);
// prepare data for tensor
let board = [];
for (let i =0; i < data.in.length; i++){
let norm = [];
for (let j =0; j < 200; j++){
norm.push(data['in'][i]['arr'][j] / 2);
}
board.push(norm);
}
xs = tf.tensor2d(board);
let labelList = ['left', 'right', 'rotate', 'fall'];
let label = [];
for (let record of data.in){
label.push(labelList.indexOf(record.move));
}
let labelTensor = tf.tensor1d(label, 'int32');
ys = tf.oneHot(labelTensor, 4).cast('float32');
labelTensor.dispose();
// create the model
let model = tf.sequential();
let hidden = tf.layers.dense({
units: 16,
inputShape: [200],
activation: 'sigmoid'
});
let output = tf.layers.dense({
units: 4,
activation: 'softmax'
});
model.add(hidden);
model.add(output);
// create an optimizer
const lr = 0.1;
const optimizer = tf.train.sgd(lr);
model.compile({
optimizer: optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
// train model
model.fit(xs, ys, {
shuffle: true,
validationSplit: 0.1,
epochs: 1,
callbacks: {
onEpochEnd: (epoch, logs) => {
console.log(epoch);
},
onBatchEnd: async (batch, logs) => {
await tf.nextFrame();
},
onTrainEnd: () => {
console.log('finished');
// use the model
useModel();
},
},
});
}
function useModel(){
tf.tidy(() => {
let grid = [];
for (let h =0; h < 200; h++){
grid.push(0); // create junk test data
}
const input = tf.tensor2d([grid]);
let results = model.predict(input);
let argMax = results.argMax(1);
let index = argMax.dataSync()[0];
let label = labelList[index];
console.log(label);
});
}
function draw() {
background(150);
}