Как я могу передать функцию Keras (тензорный потоковый график) моих опорных данных в качестве цели (y) для model.fit? - PullRequest
0 голосов
/ 31 мая 2019

(полный скрипт Python добавлен ниже описания.)

У меня есть два набора сигналов: шумный и чистый.Я хочу разработать шумоподавитель (назовите его f ()), который делает шумовые сигналы похожими на чистые.(Они в паре.) Это относительно хорошо понимаемая проблема.(См., Например: https://ieeexplore.ieee.org/document/8502864)

Моя проблема заключается в том, что я хочу обучить эту сеть функции, скажем, g (), примененной к зашумленным данным, и сопоставить ее с той же функцией g (), примененной кчистые данные. Проблема в том, что я не знаю, как передать функцию "target" в функцию Keras model.fit.

Единственный способ, который я разработал, - это предварительная обработкаочистите данные, сохраните их, а затем передайте предварительно обработанные данные как y: target. Это не будет работать, скажем, если я попытаюсь использовать функции расширения класса ImageGenerator.

Что я не могу сделать, так этопросто передайте model.fit () функцию Keras, примененную к моим чистым (целевым) данным. Первые две строки ниже работают. Третья строка не работает.

input = Input(shape=(1024,1024,1,))
combo_model = Model(inputs=input, outputs=f(g(input)))
combo_model.fit(noisedata, g(targdata))

Возвращаемая ошибка:

ValueError: Слой Average_pooling2d_9 был вызван со входом, который не является символическим тензором. Полученный тип:. Полный ввод: [массив ([[[[6.53835247e-01], ...

Полный скрипт Python:

# -*- coding: utf-8 -*-
"""Toy_f_of_g.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1MAZ8yVI6lo5thnV6ZhbgO23xhYQOjKbA
"""

from keras.models import Model
from keras.layers import Input, AveragePooling2D, UpSampling2D
from keras.layers import Conv2D
from keras.optimizers import Adam
import numpy as np

"""###Abstract input tensor, shape 1024x1024x1"""

input = Input(shape=(1024,1024,1,))

"""##Construct cross-encoder
This is a simple model that does one level of encoding followed by one level of decoding

f(_) is a trainable flow graph
"""

def f(x):
  ycode = Conv2D(32, (3,3), activation='relu', padding='same')(x)
  ydec = Conv2D(1, (3,3), activation='sigmoid', padding='same')(ycode)
  return ydec

model = Model(inputs=input, outputs=f(input))
model.compile(loss='mean_squared_error', optimizer = Adam())
model.summary()

"""###Create (fake) noise and target images: They're all 1024x1024x1 float from 0 to 1.0"""

noisedata = np.random.random((30,1024,1024,1))
targdata = np.random.random((30,1024,1024,1))

"""###Fit Model to original data"""

model_train = model.fit(noisedata, targdata, batch_size=2, epochs=2)

"""##But, I want to train on a transformed version of the noisy and target data

###Construct transformation g(_) as a Keras flow graph. 
This is a transformation.
It has no trainable parameters
"""

# This just averages 2x2 regions, effectively low-pass filtering
def g(x):
    xl = AveragePooling2D()(x)
    xlpf = UpSampling2D()(xl)
    return xlpf

"""###Use the transformation g(_) to explicitly construct the transformed noise and training data
Note: no trainable parameters: It's just a specific function.
"""

# xfrm = g(input)
xfrmmodel = Model(inputs=input,outputs=g(input))
xfrmmodel.summary()

"""Use model.predict to execute the transformation"""

noisexfrm = xfrmmodel.predict(noisedata, batch_size=1)
targxfrm = xfrmmodel.predict(targdata, batch_size=1)

"""###Fit Model to transformed data"""

## Fit Model ##
model_train = model.fit(noisexfrm, targxfrm, batch_size=2, epochs=2)

"""##I would prefer not to have to transform the data by hand.
I would rather embed the transformation into the flow graph. I can do that on the input side, but not on the "target" side

What I would like is:
"""

combo_model = Model(inputs=input, outputs=f(g(input)))
combo_model.compile(loss='mean_squared_error', optimizer = Adam())
combo_model.summary()

"""But, the following is illegal"""

combo_model_train = combo_model.fit(noisedata, g(targdata), batch_size=2, epochs=2)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...