Я работаю над проблемой RL и создал класс для инициализации модели и других параметров. Код выглядит следующим образом:
class Agent:
def __init__(self, state_size, is_eval=False, model_name=""):
self.state_size = state_size
self.action_size = 20 # measurement, CNOT, bit-flip
self.memory = deque(maxlen=1000)
self.inventory = []
self.model_name = model_name
self.is_eval = is_eval
self.done = False
self.gamma = 0.95
self.epsilon = 1.0
self.epsilon_min = 0.01
self.epsilon_decay = 0.995
def model(self):
model = Sequential()
model.add(Dense(units=16, input_dim=self.state_size, activation="relu"))
model.add(Dense(units=32, activation="relu"))
model.add(Dense(units=8, activation="relu"))
model.add(Dense(self.action_size, activation="softmax"))
model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=0.003))
return model
def act(self, state):
options = self.model.predict(state)
return np.argmax(options[0]), options
Я хочу запустить его только для одной итерации, поэтому я создаю объект и передаю вектор длины 16
следующим образом:
agent = Agent(density.flatten().shape)
state = density.flatten()
action, probs = agent.act(state)
Однако я получаю следующую ошибку:
AttributeError Traceback (most recent call last) <ipython-input-14-4f0ff0c40f49> in <module>
----> 1 action, probs = agent.act(state)
<ipython-input-10-562aaf040521> in act(self, state)
39 # return random.randrange(self.action_size)
40 # model = self.model()
---> 41 options = self.model.predict(state)
42 return np.argmax(options[0]), options
43
AttributeError: 'function' object has no attribute 'predict'
В чем проблема? Я также проверил коды других людей, например this , и я думаю, что мой тоже очень похож.
Дайте мне знать.
РЕДАКТИРОВАТЬ:
Я изменил аргумент в Dense
с input_dim
на input_shape
и self.model.predict(state)
на self.model().predict(state)
.
Теперь, когда я запускаю NN для одних входных данных формы (16,1)
, я получаю следующую ошибку:
ValueError: Ошибка при проверке входных данных: ожидается, что плотность_данных_3 имеет 3 измерения, но получил массив с формой (16, 1)
И когда я запускаю его с формой (1,16)
, я получаю следующую ошибку:
ValueError: Ошибка при проверке ввода: ожидается, что dens_1_input имеет 3 измерения, но получил массив с формой (1, 16)
Что мне делать в этом случае?