tf.placeholder внутри класса в TF 2.0 - PullRequest
1 голос
/ 04 мая 2020

Я пытаюсь изменить код, написанный в TF 1.0, на TF 2.0, и у меня возникают трудности с заменой tf.placeholder внутри функции класса. Мой код следующий

class User:
    x = tf.placeholder(tf.float32,shape=[None,784])
    y_true = tf.placeholder(tf.float32, [None, 10])

    W1 = tf.Variable(tf.random.truncated_normal([7840,1], stddev=0.1))
    lambda_W = tf.Variable(tf.zeros([7840,1]))    
    W = tf.reshape(W1,[784, 10])

    ylogits = W*x
    y = tf.nn.softmax(ylogits)
    def __init__(self):
        pass

Есть ли способ заменить tf.placeholder внутри класса, чтобы сделать код работающим в TF 2.0?

1 Ответ

0 голосов
/ 04 мая 2020

Во-первых, я думаю, что вы намеревались создать каждый из этих объектов для каждого экземпляра класса, а не один для всего класса, как сейчас. Я также думаю, что ваш продукт между W и x должен был быть матричным продуктом, а не поэлементным продуктом, который не будет работать с данными формами:

class User:
    def __init__(self):
        self.x = tf.placeholder(tf.float32,shape=[None,784])
        self.y_true = tf.placeholder(tf.float32, [None, 10])
        self.W1 = tf.Variable(tf.random.truncated_normal([7840,1], stddev=0.1))
        self.lambda_W = tf.Variable(tf.zeros([7840,1]))    
        self.W = tf.reshape(W1,[784, 10])
        self.ylogits = self.x @ self.W
        self.y = tf.nn.softmax(ylogits)

Чтобы использовать его в TensorFlow 2.x, вы бы удалили заполнители и просто выполняли операции каждый раз с каждым новым вводом, например, с новой функцией call:

class User:
    def __init__(self):
        self.W1 = tf.Variable(tf.random.truncated_normal([7840,1], stddev=0.1))
        self.lambda_W = tf.Variable(tf.zeros([7840,1]))
        self.W = tf.reshape(W1,[784, 10])

    def call(self, x):
        ylogits = self.x @ self.W
        return tf.nn.softmax(ylogits)

Вы можете использовать это как:

user1 = User()
x = ...  # Get some data
y = user1.call(x)

Или, если вы хотите быть более «идиоматическим» c, вы можете использовать __call__:

class User:
    def __init__(self):
        self.W1 = tf.Variable(tf.random.truncated_normal([7840,1], stddev=0.1))
        self.lambda_W = tf.Variable(tf.zeros([7840,1]))
        self.W = tf.reshape(W1,[784, 10])

    def __call__(self, x):
        ylogits = x @ W
        return tf.nn.softmax(ylogits)

И тогда вы сделаете:

user1 = User()
x = ...  # Get some data
y = user1(x)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...