Как перенести модель, загруженную tf.loadFrozenModel в основном потоке, на веб-сайта - PullRequest
0 голосов
/ 29 января 2019

Теперь я использую функцию tf.loadFrozenModel () для загрузки модели в основной поток, затем я хочу клонировать или перенести загруженную модель веб-работнику.Как я могу это сделать?
код в моем github: https://github.com/yiifanLu/webWorker

1 Ответ

0 голосов
/ 29 января 2019

Лучше скачать замороженную модель прямо в рабочем.Причина в том, что в версиях 10 и 11 нет tf.models.modelFromJSON для загрузки зашифрованной модели, которую можно передать работнику с помощью model.toJson.

. Ниже определяется модель в главном потоке.,Эта модель сохраняется в файле, который обслуживается локальным сервером.Работник может загрузить и использовать его для прогнозов

<head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.2/dist/tf.min.js"></script>
    <script>
        const worker_function = () => {

            onmessage =  async (event) => {
                console.log('from web worker')
                    this.window = this
                    importScripts('https://cdn.jsdelivr.net/npm/setimmediate@1.0.5/setImmediate.min.js')
                    importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2')
                    tf.setBackend('cpu')
                    const model = await tf.loadModel('http://localhost:8080/model.json')
                    model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
                    
                    // Generate some synthetic data for training.
                    const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
                    const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

                    // Train the model inside the worker
                     await model.fit(xs, ys, {epochs: 10})
                     const res = model.predict(tf.tensor2d([5], [1, 1]));
                    // send response to main thread
                    
                    postMessage({res: res.dataSync(), shape: res.shape})
            };
        }
        if (window != self)
            worker_function();
    </script>
    <script>
    
        const model = tf.sequential();
        model.add(tf.layers.dense({units: 1, inputShape: [1]}));
        
        const worker = new Worker(URL.createObjectURL(new Blob(["(" + worker_function.toString() + ")()"], { type: 'text/javascript' })));
        (async() => {
            model.save('downloads://model')
        })()
       
        worker.postMessage({model : 'model'});
        worker.onmessage = (message) => {
            console.log('from main thread')
            const {data} = message
            tf.tensor(data.res, data.shape).print()
        }
    </script>
</head>
...