Пользовательская функция потери Keras с Keras дискриминатором с нейронной сетью Гана - PullRequest
1 голос
/ 10 февраля 2020

GAN Discriminator

Я использую этот код ниже, чтобы получить дискриминатор нейронной сети GAN:

import tensorflow as tf
import numpy as np
from IPython.display import display, Audio

tf.reset_default_graph()
saver = tf.train.import_meta_graph('./infer/infer.meta')
graph = tf.get_default_graph()
sess = tf.InteractiveSession()
saver.restore(sess, tf.train.latest_checkpoint('model/'))

# here is z with underline, it doesn't showing ceractly in stack.
# I use random data to test this function.
_z = np.random.uniform(-1., 1., size=[5, 257])
x = graph.get_tensor_by_name('x:0')
D_z = graph.get_tensor_by_name('D_z:0')
D_z = sess.run(D_z, {x: _z})
print(D_z)

Пользовательская функция потери Keras

Я хочу создать функцию для Пользовательские функции потери keras:

# Load the graph
tf.reset_default_graph()
saver = tf.train.import_meta_graph('./infer/infer.meta')
graph = tf.get_default_graph()
sess = tf.InteractiveSession()
saver.restore(sess, tf.train.latest_checkpoint('model/'))

def gan_loss(y_true, y_pred):
    _z = y_pred
    x = graph.get_tensor_by_name('x:0')
    D_z = graph.get_tensor_by_name('D_z:0')
    D_z = sess.run(D_z, {x: _z})

    return D_z

Проблема, с которой я столкнулся

У меня возникла проблема, которая показывает мне: Не удается накормить тестер, вы должны накормить его numpy или другим типом данных.

TypeError: Значение фида не может быть объектом tf.Tensor. Допустимые значения подачи включают Python скаляры, строки, списки или numpy ndarrays.

Я обнаружил связанную проблему в Stak: Преобразование Tensor в np.array с помощью K.eval () в Keras возвращает InvalidArgumentError

Tensorflow: как передать переменную-заполнитель с помощью тензора?

нейронную сеть GAN, как получить дискриминатор

X = tf.placeholder(tf.float32, [None, 257], name='x')
D_z, h3 = discriminator(X)
D_z = tf.identity(D_z, name='D_z')

D_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='GAN/Discriminator')
# global_step = tf.train.get_or_create_global_step()
saver = tf.train.Saver(D_vars)
infer_dir = './infer/'
tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')
infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
tf.train.export_meta_graph(
    filename=infer_metagraph_fp,
    clear_devices=True,
    saver_def=saver.as_saver_def())
tf.reset_default_graph()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...