Доступ к обучающей операции в tf.keras.Model - PullRequest
1 голос
/ 15 апреля 2019

Как получить доступ к обучающей операции с tf.keras.models.Model? Учтите следующее:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras.models import Model
import numpy as np
from sys import exit as xit
# Make some dummy data
dummy_data_shape=(5,5)
def batch_generator(size):
    """ Makes some random data """
    def _gen():
        y_batch=np.random.randint(0,2, size=size)
        y_batch=np.expand_dims(y_batch,-1)
        y_expanded=np.expand_dims(y_batch,-1)
        x_batch=np.ones((size,*dummy_data_shape))*y_expanded
        yield x_batch,y_batch
    return _gen()

# Make some simple model
Y=tf.placeholder(tf.float32,[None,1])
X = Input(shape=dummy_data_shape)
layer_mod = Flatten()(X)
layer_mod = Dense(1)(layer_mod)

# Tie it all together and compile
out_model = Model(inputs=[X], outputs=[layer_mod])
out_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.metrics.binary_crossentropy
)

### How can I access a train_op from the out_model?
with tf.Session() as sess:
    data_iter=batch_generator(10)
    sess.run(tf.global_variables_initializer())
    x,y=next(data_iter)
    ## Here: How to access the operation that trains the model?
    train_op=out_model.train_op #<-- ?
    sess.run(train_op, feed_dict={X:x,Y:y})

Какой должна быть вторая до последней строки в приведенном выше коде, чтобы модель тренировалась?

...