Как я могу l oop над словарем внутри keras.utils.sequence - PullRequest
0 голосов
/ 21 апреля 2020

У меня есть класс keras.utils.sequence, и как часть процесса генерации данных я хочу оцифровать данные, чтобы все объекты имели диапазон 0-1. В настоящее время у меня есть:

class gen_sequence(keras.utils.Sequence):

    def __init__(self, var_list, steps, csv_file_list, shuffle=True):
        self.dim = len(var_list)
        self.shuffle = shuffle
        self.steps = steps
        fp = open(csv_file_list)
        self.all_files = fp.readlines()
        self.var_list = var_list
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return self.steps

    def on_epoch_end(self):
        self.index = list(range(0,self.steps))
        if self.shuffle == True:
            random.shuffle(self.index)

    def __data_generation(self, filepath):
        df = pd.read_csv(filepath)
        rescale_dict = { #varname: (min, max)
            'part_pt': (0,200),'part_phi': (-3.5,3.5),'part_eta': (-2.5,2.5),
            'el_e255': (0,300000),
            'el_weta2': (-0.01,0.04),
            'el_pos7': (-6,6),'el_barys1': (-3,3),
            'el_ecore': (0,300000),'el_DeltaE': (0,2000),
            'calo_energyBE0': (0,40000),'calo_energyBE1': (0,125000),'calo_energyBE2': (0,200000),'calo_energyBE3': (0,25000),
            'calo_rawE': (0,400000),'calo_E': (0,400000),
            'el_deltaEta0': (-0.1,0.1),'el_deltaEta1': (-0.1,0.1),'el_deltaEta2': (-0.1,0.1),'el_deltaEta3': (-0.1,0.1),
            'el_deltaPhi0': (-0.75,0.25),'el_deltaPhi1': (-0.75,0.25),'el_deltaPhi2': (-0.75,0.25),'el_deltaPhi3': (-0.75,0.25),
            'el_deltaPhiRescaled0': (-0.25,0.25),'el_deltaPhiRescaled1': (-0.25,0.25),
            'el_deltaPhiRescaled2': (-0.25,0.25),'el_deltaPhiRescaled3': (-0.25,0.25),
            'el_deltaPhiFromLastMeasurement': (-0.75,0.25),
            'iso_ptcone20': (0,300000),'iso_ptcone30': (0,300000),'iso_ptcone40': (0,300000),
            'iso_ptvarcone20': (0,300000),'iso_ptvarcone30': (0,300000),'iso_ptvarcone40': (0,300000),
            'iso_etcone20': (0,300000),'iso_etcone30': (0,300000),'iso_etcone40': (0,300000),
            'iso_ettopocone20': (0,300000),'iso_ettopocone30': (0,300000),'iso_ettopocone40': (0,300000),
            'LJ_pT': (0,200000)
            }
    for var, (min,max) in rescale_dict.items():
            # Want to normalize each variable
            data = df[var]
            normalized = (data - min)/(max-min)
            df[var]=normalized
        x = df[self.var_list]
        y = df['targets_LJ']
        return (np.array(x), np.array(y))

    def __getitem__(self, index):
        filepath = self.all_files[index]
        filepathnew = filepath[0:-1]
        x, y = self.__data_generation(filepathnew)
        return x, y

И это работает.

Однако я хотел бы определить словарь rescale_dict в основном коде (вне класса последовательности), но когда я пытаюсь это сделать:

class gen_sequence(keras.utils.Sequence):

    def __init__(self, var_list, steps, csv_file_list, rescale_dict, shuffle=True):
        self.dim = len(var_list)
        self.shuffle = shuffle
        self.steps = steps
        fp = open(csv_file_list)
        self.all_files = fp.readlines()
        self.var_list = var_list
        slef.rescale_dict = rescale_dict
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return self.steps

    def on_epoch_end(self):
        self.index = list(range(0,self.steps))
        if self.shuffle == True:
            random.shuffle(self.index)

    def __data_generation(self, filepath):
        df = pd.read_csv(filepath)
        for var, (min,max) in self.rescale_dict.items():
            # Want to normalize each variable
            data = df[var]
            normalized = (data - min)/(max-min)
            df[var]=normalized
        x = df[self.var_list]
        y = df['targets_LJ']
        return (np.array(x), np.array(y))

    def __getitem__(self, index):
        filepath = self.all_files[index]
        filepathnew = filepath[0:-1]
        x, y = self.__data_generation(filepathnew)
        return x, y

Я получаю ошибку:

Traceback (most recent call last):
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 2657, in get_loc
    return self._engine.get_loc(key)
  File "pandas/_libs/index.pyx", line 108, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 132, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1601, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1608, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: ('part_pt', (0, 200))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./train_nn_scaleddata.py", line 225, in <module>
    run()
  File "./train_nn_scaleddata.py", line 182, in run
    steps_per_epoch=steps, epochs=args.epochs, validation_steps=steps_val, callbacks=[early_stop], class_weight=class_weight)
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/engine/training.py", line 1418, in fit_generator
    initial_epoch=initial_epoch)
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/engine/training_generator.py", line 181, in fit_generator
    generator_output = next(output_generator)
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/utils/data_utils.py", line 601, in get
    six.reraise(*sys.exc_info())
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/six.py", line 693, in reraise
    raise value
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/utils/data_utils.py", line 595, in get
    inputs = self.queue.get(block=True).get()
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/multiprocessing/pool.py", line 657, in get
    raise self._value
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/utils/data_utils.py", line 401, in get_index
    return _SHARED_SEQUENCES[uid][i]
  File "./train_nn_scaleddata.py", line 67, in __getitem__
    x, y = self.__data_generation(filepathnew)
  File "./train_nn_scaleddata.py", line 57, in __data_generation
    data = df[var] 
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/pandas/core/frame.py", line 2927, in __getitem__
    indexer = self.columns.get_loc(key)
  File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 2659, in get_loc
    return self._engine.get_loc(self._maybe_cast_indexer(key))
  File "pandas/_libs/index.pyx", line 108, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 132, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1601, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1608, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: ('part_pt', (0, 200))

Есть ли способ определить словарь в основной части кода, ввести его в класс keras.utils.sequence и l oop поверх него, чтобы нормализовать функции?

Спасибо, Сара

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...