class CNN(tf.keras.Model):
    """Convolutional network mapping image observations to per-action outputs.

    The network is three VALID-padded ReLU convolutions (16 filters 8x8
    stride 4, 32 filters 4x4 stride 2, 32 filters 3x3 stride 1), a flatten,
    a 128-unit ReLU dense layer and a ``num_action``-unit output layer.

    Args:
        num_state: Either the ``(height, width, channels)`` shape of a single
            observation, or a non-iterable value (e.g. a flat size). When a
            shape is given, inputs to ``call`` are reshaped to
            ``(-1, height, width, channels)``, so unbatched or flattened
            observations are accepted. Otherwise inputs must already be
            4-D ``(batch, height, width, channels)``.
        num_action: Number of units in the output layer.

    Raises:
        ValueError: If ``num_state`` is iterable but is not three positive
            integers.
    """

    def __init__(self, num_state, num_action):
        super().__init__()

        # Distinguish "shape triple" from "scalar" without swallowing errors
        # raised while validating the individual dimensions.
        try:
            dims = tuple(num_state)
        except TypeError:
            dims = None
        if dims is not None:
            if len(dims) != 3:
                raise ValueError(
                    "num_state must be (height, width, channels), got "
                    f"{num_state!r}"
                )
            dims = tuple(int(dim) for dim in dims)
            if any(dim <= 0 for dim in dims):
                raise ValueError(
                    f"num_state dimensions must be positive, got {num_state!r}"
                )
        self._state_shape = dims

        # No tf.keras.layers.Input here: it returns a symbolic tensor rather
        # than a callable layer, and subclassed models do not need one.
        self.conv1 = tf.keras.layers.Conv2D(
            filters=16, kernel_size=8, strides=4, padding="valid",
            activation="relu",
        )
        self.conv2 = tf.keras.layers.Conv2D(
            filters=32, kernel_size=4, strides=2, padding="valid",
            activation="relu",
        )
        self.conv3 = tf.keras.layers.Conv2D(
            filters=32, kernel_size=3, strides=1, padding="valid",
            activation="relu",
        )
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(128, activation="relu")
        # NOTE(review): ReLU on the output layer forces every output to be
        # non-negative. If these outputs are Q-values or logits a linear
        # output is probably intended -- confirm before changing.
        self.fc2 = tf.keras.layers.Dense(num_action, activation="relu")

    @tf.function
    def call(self, num_state):
        """Run the forward pass.

        Args:
            num_state: The observation(s) to evaluate. The parameter name is
                kept for backward compatibility; despite the name it is the
                input data, not a size. Any value convertible to a tensor is
                accepted and cast to ``float32``.

        Returns:
            A ``float32`` tensor of shape ``(batch, num_action)``.
        """
        # Cast rather than convert with a dtype so integer image data
        # (e.g. uint8 frames) is accepted.
        x = tf.cast(num_state, tf.float32)
        if self._state_shape is not None:
            # Restores the batch axis for a single observation and the
            # spatial axes for flattened input; a no-op for 4-D batches.
            x = tf.reshape(x, [-1, *self._state_shape])
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return self.fc2(x)