Ошибка типа: Ошибка преобразования фигуры в TensorShape, PongNoFrameskip-v0 env - PullRequest
0 голосов
/ 11 февраля 2020
class CNN(tf.keras.Model):
    def __init__(self,num_state,num_action):
        super().__init__()
        self.input_layer = tf.keras.layers.Input(shape=(num_state,))
        self.conv1 = tf.keras.layers.Conv2D(16,8,4,'VALID',activation='relu')
        self.conv2 = tf.keras.layers.Conv2D(32,4,2,'VALID',activation='relu')
        self.conv3 = tf.keras.layers.Conv2D(32,3,1,'VALID',activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(128,activation='relu')
        self.fc2 = tf.keras.layers.Dense(num_action,activation='relu')

    @tf.function
    def call(self,num_state):
        x = tf.keras.layers.Input(shape=(None, state[0], state[1],state[2]), dtype=tf.float32)
        x = tf.convert_to_tensor(x)
        x = tf.reshape(x,(num_state[1]*num_state[1]*num_state[2]))
        x = self.input_layer(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
...