Ходовая модель Frozen Tensorflow на NodeJS - PullRequest
1 голос
/ 05 марта 2020

Я новичок в tenorflow js (и js в целом), однако мне нужно запустить на нем обученную модель. В настоящее время я преобразовал модель в формат json, но изо всех сил пытаюсь передать в нее данные:

const tf = require('@tensorflow/tfjs')
const tfn = require('@tensorflow/tfjs-node-gpu')

async function start() { 
    const handler = tfn.io.fileSystem("./model/model.json"); 
    const model = await tf.loadGraphModel(handler); 
    let latents = tf.randomNormal([1,512], 'float32'); 
    let labels = tf.zeros([1, 0]); 
    model.predict([latents, labels]);
}
start();

Но я получаю сообщение об ошибке, говорящее The Conv2D op currently supports NHWC tensor format on the CPU. The op was given the format: NCHW

Так что, как я понял, это проблема tf js, поэтому я попытался создать массив float32 и передать его модели следующим образом:

var f32array = new Float32Array(512);
model.predict([f32array, labels]);

Но затем я вижу сообщение об ошибке the dtype of dict['Gs/latents_in'] provided in model.execute(dict) must be float32, but was undefined

с помощью python, я делаю вывод, используя этот код:

graph = load_graph("dash/frozen_model.pb")

    x = graph.get_tensor_by_name('prefix/Gs/latents_in:0')
    x2 = graph.get_tensor_by_name('prefix/Gs/labels_in:0')
    y = graph.get_tensor_by_name('prefix/Gs/images_out:0')


    with tf.Session(graph=graph, config = config) as sess:
        while True:
            start_time = time.time()

            latents = np.random.randn(1, 512).astype(np.float32)
            labels = np.zeros([latents.shape[0], 0], np.float32)
            y_out = sess.run(y, feed_dict = { x: latents, x2: labels})

Буду признателен за любую помощь

1 Ответ

1 голос
/ 06 марта 2020

Передача данных как Float32Array не будет работать, поскольку model.predict ожидает либо тензор, либо массив тензоров.

Как указано в ошибке:

Conv2D op в настоящее время поддерживает тензорный формат NHW C на процессоре. Операция получила формат: NCHW

conv2D с версии 1.6 в js поддерживает только формат NHW C. Единственное, что вы можете сделать, это изменить модель в python, чтобы использовать только формат NHW C.

...