тензор потока js как грузить гири из другой модели - PullRequest
0 голосов
/ 26 марта 2020

У меня есть две модели m1 и m2,

. Я бы хотел обновить вес модели m1, чтобы он соответствовал m2,

В python с PyTorch это можно сделать с помощью этой строки кода:

m1.load_state_dict(m2.state_dict())

, но я не смог найти никакой информации о ней rnet.

Единственное, что я нашел в соответствии с этой документацией: https://www.tensorflow.org/js/guide/save_load

- это, например, сохранить m2 через локальное хранилище, а затем полностью загрузить его в m1, но это мне не имеет смысла скачивать и сохранять его снова, чтобы я мог обновить вес.

Ответы [ 2 ]

0 голосов
/ 26 марта 2020

Так что после прочтения документации лучше,

Я нашел это:

m1.setWeights(m2.getWeights());

Я также попытался fit один из них, чтобы увидеть, что он не изучит другой и у него не было проблем.

Обратите внимание , что они оба должны иметь одинаковую структуру, полный пример:

const model = tf.sequential();
model.add(tf.layers.dense({ units: 4, inputShape: [8] }));
model.add(tf.layers.dense({ units: 4 }));
model.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });

const model2 = tf.sequential();
model2.add(tf.layers.dense({ units: 4, inputShape: [8] }));
model2.add(tf.layers.dense({ units: 4 }));
model2.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });

model2.setWeights(model.getWeights());

console.log(model.getWeights()[0].dataSync());
console.log(model2.getWeights()[0].dataSync());
0 голосов
/ 26 марта 2020

Загрузка весов другой модели.

Как указано в вопросе, это можно сделать, сохранив первую модель, а затем загрузив ее как другую модель.

мне не имеет смысла загружать и сохранять его снова, чтобы я мог обновить вес.

Нет смысла полностью обновлять модель 2 по весу модели 1, если оба не идентичны, то есть имеют одинаковую топологию. Нет способа напрямую клонировать модель и присвоить ее другой переменной. Для этого модель должна быть загружена как другая модель или скопированы ее веса и назначены другой модели с той же топологией.

model.getWeight и model.setWeights можно использовать

model2.setWeights(model1.getWeights());

Если модель 2 подлежит частичному обновлению, ie обновляет веса некоторых слоев, это обсуждалось в этих ответах здесь и там

...