Добавление структурированных данных к атрибуту класса в Python - PullRequest
0 голосов
/ 05 апреля 2020

У меня есть пара объектов, которые я использую для численного моделирования. Ниже показан минимальный пример, где есть два объекта: 1) объект Environment, который имеет два состояния (x и y), которые он симулирует стохастически во времени; и 2) объект Simulation, который управляет симуляцией и сохраняет состояние окружающей среды в течение симуляции.

В объекте Simulation я хочу сохранить состояние окружающей среды как 1) во времени, так и 2) для нескольких симуляций. Со временем я могу использовать defaultdict, чтобы сохранить переменные состояния в пределах одной симуляции, но в разных симуляциях мне не совсем понятен лучший способ сохранить созданные defaultdicts. Если я добавляю в список (без использования копирования), то список возвращает все идентичные defaultdicts из-за изменчивости списков. В приведенном ниже примере я использую copy.copy, в качестве ответа здесь предлагает.

Есть ли подходы, которые более "Pythoni c"? Было бы лучше использовать неизменяемый тип для хранения defaultdicts для каждого моделирования?

import copy
from collections import defaultdict
import numpy as np, pandas as pd
from matplotlib import pyplot as plt


class Environment(object):
    """
    Class representing a random walk of two variables x and y

    Methods
    -------
    start_simulation:   draw values from state variables from priors
    step:               add random noise to state variables
    current_state:      return current state of x and y in a dict

    """
    def __init__(self, mu1, sigma1, mu2, sigma2):
        self.mu1 = mu1
        self.mu2 = mu2
        self.sigma1 = sigma1
        self.sigma2 = sigma2

    def start_simulation(self):
        self.x = self.mu1 + self.sigma1 * np.random.randn()
        self.y = self.mu2 + self.sigma2 * np.random.randn()

    def step(self):
        self.x += self.sigma1 * np.random.randn()
        self.y += self.sigma2 * np.random.randn()

    def current_state(self):
        return({"x": self.x, "y": self.y})


class Simulation(object):
    """
    Class representing a simulation object for handling the Environment object
     and storing data

    Methods
    -------

    start_simulation:   start the simulation; initialise state of the environment
    simulate:           generate n_simulations simulations of n_timesteps time steps each
    save_state:          
    """
    def __init__(self, env, n_timesteps):
        self.env = env
        self.n_timesteps = n_timesteps

        self.data_all = []
        self.data_states = defaultdict(list)

    def start_simulation(self):
        self.timestep = 0
        self.env.start_simulation()

        # Append current data (if non empty)
        if self.data_states:
            self.data_all.append(copy.copy(self.data_states)) # <---------- this step
            # without copy.copy this will return all elements of the list data_all to be the 
            # same default dict at the end of all simulations - lists are mutable

        # Reset data_current
        self.data_states = defaultdict(list)

    def simulate(self, n_simulations):
        """
        Run simulation for n_simulations and n_timesteps timesteps
        """
        self.start_simulation()

        for self.simulation in range(n_simulations):

            self.timestep = 0

            while(self.timestep < self.n_timesteps):
                self.env.step()
                self.save_state(self.env.current_state())
                self.timestep += 1

            self.start_simulation()


    def save_state(self, state):
        """
        Save results to a default dict
        """
        for key, value in state.items():
            self.data_states[key].append(value)


if __name__ == "__main__":

    # Run 7 simulations, each for for 20 time steps
    N_TIME = 20
    N_SIM = 7

    e = Environment(
        mu1 = 1.4, sigma1 = 0.1, 
        mu2 = 2.6, sigma2 = 0.05)

    s = Simulation(env = e, n_timesteps = N_TIME)
    s.simulate(N_SIM)

    # Plot output
    fig, ax = plt.subplots()
    for var, c in zip(["x", "y"], ["#D55E00", "#009E73"]):
        [ax.plot(pd.DataFrame(d)[var], label = var, color = c) for d in s.data_all]
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    plt.show()

random walkspython

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...