Python проблема с подклассами и кератами - PullRequest
1 голос
/ 01 апреля 2020

Я работаю с лабораторной записной книжкой для некоторых открытых заметок курса. Одним из упражнений является создание нового класса IdentityModel, который наследуется от tenorflow.keras.Model и имеет собственный метод "call (input, isidentity = False)". Это должно быть легким упражнением. Вот перефразированный код: импорт скопирован из их ячеек.

# Import relevant packages
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
import numpy as np
import matplotlib.pyplot as plt

class IdentityModel(tf.keras.Model):

  # As before, in __init__ we define the Model's layers
  # Since our desired behavior involves the forward pass, this part is unchanged
  def __init__(self, n_output_nodes):
    super(IdentityModel, self).__init__()
    self.dense_layer = tf.keras.layers.Dense(n_output_nodes, activation='sigmoid')


  def call(self, inputs, isidentity=False):
    x = self.dense_layer(inputs)
    if isidentity:
      return inputs
    else:
      return x

n_output_nodes = 3
model = IdentityModel(n_output_nodes)
x_input = tf.constant([[1,2.]], shape=(1,2))

Я должен вызывать метод вызова IndentityModel. Вот что идет не так.

IdentityModel.call(x_input, False)

вызывает tf.keras.Model.call вместо

IdentityModel.call(x_input, isidentity=False) 

с ошибкой TypeError: call (), пропускающей 1 обязательный позиционный аргумент: 'input'

IdentityModel.call(input=x_input, isidentity=False) 

имеет ошибку TypeError: call () отсутствует 1 обязательный позиционный аргумент: 'self'

Что здесь происходит? Ранее я использовал подобный код без этих проблем.

1 Ответ

0 голосов
/ 01 апреля 2020

Это методы экземпляра, поэтому вам нужно вызывать его из сгенерированного вами экземпляра, а не из класса

n_output_nodes = 3
x_input = tf.constant([[1,2.]], shape=(1,2))

instance = IdentityModel(n_output_nodes)

# Call any of the following
instance.call(x_input)
instance.call(x_input, False)
instance.call(x_input, isidentity=False) 
instance.call(input=x_input, isidentity=False) 

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...