В TensorFlow.js я создал последовательную нейронную сеть с 3-мя плотными слоями, которая работает, когда я устанавливаю функцию активации на «relu», но когда я пытаюсь «tanh» или «sigmoid», она выдает ошибку «Ошибка:Тензор распорядился ".
Я сделал сводку модели, чтобы убедиться, что изменение функции активации не изменило структуру сети или номера параметров. Я также попытался закомментировать tf.tidy
, который я использовал.
Вот моя модель:
const myModel = tf.sequential();
myModel.add(tf.layers.dense({ units: 64, inputShape: [1], activation: 'tanh' }));
myModel.add(tf.layers.dense({ units: 64, inputShape: [1], activation: 'relu' }));
myModel.add(tf.layers.dense({ units: 1 }));
Переключение 'tanh'
на 'relu'
решает проблему, но я не знаю почему.
Вот мойкод тренировки:
optimizer.minimize(() => {
let inputs = tf.tensor2d(x_vals);
let predictions = myModel.predictOnBatch(inputs);
let totalLoss = tf.losses.meanSquaredError(tf.tensor2d(y_vals), predictions);
return totalLoss;
});
Фрагмент полного кода (для запуска требуется секунда):
x_vals = [
[1],
[2],
[3],
[4],
[5]
];
y_vals = [
[1],
[2],
[3],
[4],
[5]
];
const optimizer = tf.train.adam(.005);
const myModel = tf.sequential();
myModel.add(tf.layers.dense({ units: 64, inputShape: [1], activation: 'tanh' }));
myModel.add(tf.layers.dense({ units: 64, activation: 'relu' }));
myModel.add(tf.layers.dense({ units: 1 }));
myModel.summary();
optimizer.minimize(() => {
let inputs = tf.tensor2d(x_vals);
let predictions = myModel.predictOnBatch(inputs);
let totalLoss = tf.losses.meanSquaredError(tf.tensor2d(y_vals), predictions);
return totalLoss;
});
curveY = [];
for (let i = 0; i < x_vals.length; i++) {
curveY.push(myModel.predict(tf.tensor([
x_vals[i]
])).dataSync());
}
console.log(curveY);
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1">
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.js">
</script>
</head>
<body>
</body>
</html>