Keras Sub классифицируя стиль API - PullRequest
0 голосов
/ 15 мая 2018

Я застрял, чтобы сделать модель с методом подкласса. Вопрос в том, что в этом методе подкласса, где наш метод формы ввода и где наш шаг компиляции?

Пожалуйста, помогите мне выполнить мои задания.

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

model = MyModel()

Вот ссылка

1 Ответ

0 голосов
/ 16 октября 2018

Надеюсь, этот код, взятый из https://www.tensorflow.org/guide/keras, может помочь:

class MyModel(keras.Model):

  def __init__(self, num_classes=10):
    super(MyModel, self).__init__(name='my_model')
    self.num_classes = num_classes
    # Define your layers here.
    self.dense_1 = keras.layers.Dense(32, activation='relu')
    self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid')

  def call(self, inputs):
    # Define your forward pass here,
    # using layers you previously defined (in `__init__`).
    x = self.dense_1(inputs)
    return self.dense_2(x)

  def compute_output_shape(self, input_shape):
    # You need to override this function if you want to use the subclassed model
    # as part of a functional-style model.
    # Otherwise, this method is optional.
    shape = tf.TensorShape(input_shape).as_list()
    shape[-1] = self.num_classes
    return tf.TensorShape(shape)


# Instantiates the subclassed model.
model = MyModel(num_classes=10)

# The compile step specifies the training configuration.
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Trains for 5 epochs.
model.fit(data, labels, batch_size=32, epochs=5)

Вы можете увидеть вызов "model.compile" и в фазе подгонки вы передадите свои входные данные в модель. Как потоки данных внутри модели определяются внутри метода вызова, поэтому, если вы хотите выполнить некоторую проверку размера входного файла, вы можете также поместить его туда.

Seba

...