График линейного участка и логарифмический участок рядом друг с другом в Python. Аналогично mfrow = c (2,1) в R - PullRequest
1 голос
/ 31 марта 2020

Я пытаюсь построить два графика рядом друг с другом в python, один - линейный результат эксперимента, а другой - логарифмическое c преобразование. Цель состояла бы в том, чтобы расположить графики рядом друг с другом, как par(mfrow=c(1,2)) в R.

if __name__ == '__main__':
  c_1 = run_experiment(1.0, 2.0, 3.0, 0.1, 100000) # run_experiment(m1, m2, m3, eps, N):
  c_05 = run_experiment(1.0, 2.0, 3.0, 0.05, 100000)
  c_01 = run_experiment(1.0, 2.0, 3.0, 0.01, 100000)

  # log scale plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.xscale('log')
  plt.title(label="Log Multi-Arm Bandit")
  plt.show()


  # linear plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.show()

Я пробовал много методов, но, похоже, продолжает получать ошибку. Может ли кто-то реализовать это с помощью кода. Я относительно новичок в Python и в основном имею опыт работы с R, но любая помощь будет означать для меня мир. Ниже я приведу некоторый код для изменения, которое я предпринял.

  # log scale plot
  fig, axes = plt.subplots(122)
  ax1, ax2 = axes[0], axes[1]
  ax1.plot(c_1, label='eps = 0.10')
  ax1.plot(c_05,label='eps = 0.05')
  ax1.plot(c_01,label='eps = 0.01')
  ax1.legend()
  ax1.xscale('log')
  #plt.title(label="Log Multi-Arm Bandit")
  #plt.show()

  # linear plot
  ax2.plot(c_1, label='eps = 0.10')
  ax2.plot(c_05,label='eps = 0.05')
  ax2.plot(c_01, label='eps = 0.01')
  ax2.legend()
  plt.show()

Но я получил ошибку.

Полный код, который я использую для запуска эксперимента, приведен ниже от начала до конец. Это проблема многорукого бандита из обучения подкреплению.

# Premable
from __future__ import print_function, division
from builtins import range
import numpy as np
import matplotlib.pyplot as plt


class Bandit:
  def __init__(self, m):  # m is the true mean
    self.m = m
    self.mean = 0
    self.N = 0

  def pull(self): # simulated pulling bandits arm
    return np.random.randn() + self.m

  def update(self, x):
    self.N += 1
    # look at the derivation above of the mean
    self.mean = (1 - 1.0/self.N)*self.mean + 1.0/self.N*x  


def run_experiment(m1, m2, m3, eps, N):
  bandits = [Bandit(m1), Bandit(m2), Bandit(m3)]

  data = np.empty(N)

  for i in range(N): # Implement epsilon greedy shown above
    # epsilon greedy
    p = np.random.random()
    if p < eps:
      j = np.random.choice(3) # Explore
    else:
      j = np.argmax([b.mean for b in bandits]) # Exploit
    x = bandits[j].pull()  # Pull and update
    bandits[j].update(x)

    # Results for the plot
    data[i] = x  # Store the results in an array called data of size N
    # Calculate cumulative average
  cumulative_average = np.cumsum(data) / (np.arange(N) + 1)


  # plot moving average ctr
  plt.plot(cumulative_average) # plot cumulative average
  # Plot bars with each of the means so we can see where are 
  # cumulative averages relative to means
  plt.plot(np.ones(N)*m1) 
  plt.title('Slot Machine ')
  plt.plot(np.ones(N)*m2)
  plt.title('Slot Machine ')
  plt.plot(np.ones(N)*m3)
  plt.title('Slot Machine ')
  # We do this on a log scale so that you can see the 
  # fluctuations in earlier rounds more clearly
  plt.xscale('log') 
  plt.show()

  for b in bandits:
    print(b.mean)

  return cumulative_average

if __name__ == '__main__':
  c_1 = run_experiment(1.0, 2.0, 3.0, 0.1, 100000) # run_experiment(m1, m2, m3, eps, N):
  c_05 = run_experiment(1.0, 2.0, 3.0, 0.05, 100000)
  c_01 = run_experiment(1.0, 2.0, 3.0, 0.01, 100000)

  # log scale plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.xscale('log')
  plt.title(label="Log Multi-Arm Bandit")
  plt.show()


  # linear plot
  plt.plot(c_1, label='eps = 0.10')
  plt.plot(c_05, label='eps = 0.05')
  plt.plot(c_01, label='eps = 0.01')
  plt.legend()
  plt.show()

1 Ответ

0 голосов
/ 01 апреля 2020

Как отметили @ Йохан C в комментариях, вы путаете синтаксис plt.subplots() и plt.subplot(). Строка

fig, axes = plt.subplots(122)

Создает 122 подзаговора в одном столбце. Это должно быть

fig, axes = plt.subplots(1, 2)
...