I am studying reinforcement learning and writing code that approximates the value function with a continuous (parameterised) function and then extracts a policy from it.
What is this code supposed to do, and why does it not work?
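To state what I think the main loop is supposed to do: as far as I understand, it is semi-gradient TD(0) with the linear features (1, row, col), i.e. V(s) = w0 + w1*i + w2*j is nudged toward reward + gamma*V(s'). A minimal standalone sketch of just that update (td0_linear_update is a name I made up here, it is not in my code):

import numpy as np

# Semi-gradient TD(0) with the linear features phi(s) = (1, row, col):
# V(s) = w0 + w1*row + w2*col is nudged toward reward + gamma * V(s').
def td0_linear_update(w, pos, next_pos, reward, gamma=0.9, alpha=0.01):
    phi_now = np.array([1.0, pos[0], pos[1]])             # features of the current state
    phi_next = np.array([1.0, next_pos[0], next_pos[1]])  # features of the next state
    td_error = reward + gamma * np.dot(w, phi_next) - np.dot(w, phi_now)
    return w + alpha * td_error * phi_now                 # step the weights along phi(s)

w = np.zeros(3)
w = td0_linear_update(w, pos=(0, 0), next_pos=(0, 1), reward=-1)
print(w)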
import numpy as np
import matplotlib.pyplot as plt
import time
import copy
Setting up a class that holds the environment information:
class Environment:
    # Reward constants and the goal cell of the 3x4 grid
    cliff = -3; road = -1; sink = -2; goal = 2
    goal_position = [2, 3]
    reward_list = [[road, road, road, road],
                   [road, road, sink, road],
                   [road, road, road, goal]]
    reward_list1 = [['road', 'road', 'road', 'road'],
                    ['road', 'road', 'sink', 'road'],
                    ['road', 'road', 'road', 'goal']]

    def __init__(self):
        self.reward = np.asarray(self.reward_list)

    def move(self, agent, action):
        done = False
        new_pos = agent.pos + agent.action[action]
        # Already standing on the goal cell: episode ends
        if self.reward_list1[agent.pos[0]][agent.pos[1]] == 'goal':
            reward = self.goal
            observation = agent.set_pos(agent.pos)
            done = True
        # Stepping off the grid: penalty, stay in place, episode ends
        elif new_pos[0] < 0 or new_pos[0] >= self.reward.shape[0] or new_pos[1] < 0 or new_pos[1] >= self.reward.shape[1]:
            reward = self.cliff
            observation = agent.set_pos(agent.pos)
            done = True
        else:
            observation = agent.set_pos(new_pos)
            reward = self.reward[observation[0], observation[1]]
        return observation, reward, done
A class for the agent that selects and applies actions:
class Agent:
    # Four actions encoded as row/column offsets: 0=up, 1=right, 2=down, 3=left
    action = np.array([[-1.0], [0, 1], [1, 0], [0, -1]])
    select_action_pr = np.array([0.25, 0.25, 0.25, 0.25])

    def __init__(self, initial_position):
        self.pos = initial_position

    def set_pos(self, position):
        self.pos = position
        return self.pos

    def get_pos(self):
        return self.pos
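Just to show how I intend these two classes to work together (positions here are only for illustration):

env = Environment()
agent = Agent(np.array([0, 0]))          # start in the top-left corner
obs, reward, done = env.move(agent, 1)   # action index 1 is meant to be "move right"
print(obs, reward, done)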
A function that prints the extracted policy, one arrow per action:
def show_policy(policy, env):
    for i in range(env.reward.shape[0]):
        print('+----------' * env.reward.shape[1] + '+')
        print('|', end='')
        for j in range(env.reward.shape[1]):
            # Print an arrow for the chosen action, or * on the goal cell
            if env.reward_list1[i][j] != 'goal':
                if policy[i, j] == 0:
                    print('    ↑     |', end='')
                elif policy[i, j] == 1:
                    print('    →     |', end='')
                elif policy[i, j] == 2:
                    print('    ↓     |', end='')
                elif policy[i, j] == 3:
                    print('    ←     |', end='')
            else:
                print('    *     |', end='')
        print()
    print('+----------' * env.reward.shape[1] + '+')
def policy_extraction(env, agent, v_table, optimal_policy):
    # Greedy policy: in every state pick the action that maximises
    # reward + gamma * V(next state)
    gamma = 0.9
    for i in range(env.reward.shape[0]):
        for j in range(env.reward.shape[1]):
            temp = -1e+10
            for action in range(len(agent.action)):
                agent.set_pos([i, j])
                observation, reward, done = env.move(agent, action)
                if temp < reward + gamma * v_table[observation[0], observation[1]]:
                    optimal_policy[i, j] = action
                    temp = reward + gamma * v_table[observation[0], observation[1]]
    return optimal_policy
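The extraction step is, as I understand it, just the greedy rule pi(i, j) = argmax_a [ r(s, a) + gamma * V(s') ]. A tiny standalone illustration of that argmax with made-up numbers (not taken from an actual run):

import numpy as np

gamma = 0.9
rewards = np.array([-3.0, -1.0, -1.0, -3.0])   # hypothetical one-step rewards for the 4 actions
next_values = np.array([0.0, 1.5, 0.7, 0.0])   # hypothetical V(s') for each resulting state
best_action = int(np.argmax(rewards + gamma * next_values))
print(best_action)                             # 1, i.e. "right" in my action encoding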
The code below is the part that uses the value function in continuous (parameterised) form to obtain the approximated value function and the policy:
np.random.seed(0)
env = Environment()
initial_position = np.array([0, 0])
agent = Agent(initial_position)
gamma = 0.9

w = np.random.rand(3)   # parameters of the approximation function
w -= 0.5
v_table = np.zeros((env.reward.shape[0], env.reward.shape[1]))
for i in range(env.reward.shape[0]):
    for j in range(env.reward.shape[1]):
        v_table[i, j] = w[0] + w[1]*i + w[2]*j   # linear function approximating the value function

max_episode = 10000; max_step = 100; alpha = 0.01; epsilon = 0.3
for epi in range(max_episode):
    delta = 0; i = 0; j = 0
    agent.set_pos([i, j]); temp = 0
    for k in range(max_step):
        pos = agent.get_pos()
        action = np.random.randint(0, len(agent.action))
        observation, reward, done = env.move(agent, action)
        now_v = w[0] + np.dot(w[1:], pos)
        next_v = w[0] + np.dot(w[1:], observation)
        # gradient-descent-style update of the parameters
        w[0] += alpha*(reward + gamma*next_v - now_v)
        w[1] += alpha*(reward + gamma*next_v - now_v)*pos[0]
        w[2] += alpha*(reward + gamma*next_v - now_v)*pos[1]
        if done == True:
            break

for i in range(env.reward.shape[0]):
    for j in range(env.reward.shape[1]):
        v_table[i, j] = w[0] + w[1]*i + w[2]*j

print('Approximation of Functions: V(s)')
show_v_table(np.round(v_table, 2), env)
policy = np.zeros((env.reward.shape[0], env.reward.shape[1]))
policy = policy_extraction(env, agent, v_table, policy)
show_policy(policy, env)
When I run this code, I get this error:
ValueError: shapes (2,) and (4,) not aligned: 2 (dim 0) != 4 (dim 0)
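My best guess at where the (2,) and (4,) shapes come from: the first row of the action array is [-1.0] (one element) instead of two numbers, so np.array builds a dtype=object array whose elements are plain Python lists, and since the position is also stored as a plain list via agent.set_pos([i, j]), the + inside env.move concatenates the two lists instead of adding them, so observation ends up with four elements by the time it reaches np.dot(w[1:], observation). Here is a small standalone sketch of that suspicion (variable names are made up; I have not verified this is the only problem):

import numpy as np

# Ragged nested list: on my NumPy this becomes a dtype=object array of Python lists
# (newer NumPy versions refuse to build it at all).
action = np.array([[-1.0], [0, 1], [1, 0], [0, -1]])
pos = [0, 0]                        # agent.set_pos([i, j]) stores a plain list
new_pos = pos + action[1]           # list + list concatenates -> [0, 0, 0, 1]
print(action.dtype, new_pos)

w = np.random.rand(3) - 0.5
try:
    np.dot(w[1:], new_pos)          # (2,) dotted with a length-4 "position"
except ValueError as e:
    print(e)

# What I believe was intended: a 4x2 numeric array of (row, col) moves,
# i.e. [-1, 0] for "up" rather than [-1.0].
action_fixed = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
new_pos_fixed = pos + action_fixed[1]            # elementwise: array([0, 1])
print(new_pos_fixed, np.dot(w[1:], new_pos_fixed))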