У меня есть класс keras.utils.sequence, и как часть процесса генерации данных я хочу оцифровать данные, чтобы все объекты имели диапазон 0-1. В настоящее время у меня есть:
class gen_sequence(keras.utils.Sequence):
def __init__(self, var_list, steps, csv_file_list, shuffle=True):
self.dim = len(var_list)
self.shuffle = shuffle
self.steps = steps
fp = open(csv_file_list)
self.all_files = fp.readlines()
self.var_list = var_list
self.on_epoch_end()
def __len__(self):
'Denotes the number of batches per epoch'
return self.steps
def on_epoch_end(self):
self.index = list(range(0,self.steps))
if self.shuffle == True:
random.shuffle(self.index)
def __data_generation(self, filepath):
df = pd.read_csv(filepath)
rescale_dict = { #varname: (min, max)
'part_pt': (0,200),'part_phi': (-3.5,3.5),'part_eta': (-2.5,2.5),
'el_e255': (0,300000),
'el_weta2': (-0.01,0.04),
'el_pos7': (-6,6),'el_barys1': (-3,3),
'el_ecore': (0,300000),'el_DeltaE': (0,2000),
'calo_energyBE0': (0,40000),'calo_energyBE1': (0,125000),'calo_energyBE2': (0,200000),'calo_energyBE3': (0,25000),
'calo_rawE': (0,400000),'calo_E': (0,400000),
'el_deltaEta0': (-0.1,0.1),'el_deltaEta1': (-0.1,0.1),'el_deltaEta2': (-0.1,0.1),'el_deltaEta3': (-0.1,0.1),
'el_deltaPhi0': (-0.75,0.25),'el_deltaPhi1': (-0.75,0.25),'el_deltaPhi2': (-0.75,0.25),'el_deltaPhi3': (-0.75,0.25),
'el_deltaPhiRescaled0': (-0.25,0.25),'el_deltaPhiRescaled1': (-0.25,0.25),
'el_deltaPhiRescaled2': (-0.25,0.25),'el_deltaPhiRescaled3': (-0.25,0.25),
'el_deltaPhiFromLastMeasurement': (-0.75,0.25),
'iso_ptcone20': (0,300000),'iso_ptcone30': (0,300000),'iso_ptcone40': (0,300000),
'iso_ptvarcone20': (0,300000),'iso_ptvarcone30': (0,300000),'iso_ptvarcone40': (0,300000),
'iso_etcone20': (0,300000),'iso_etcone30': (0,300000),'iso_etcone40': (0,300000),
'iso_ettopocone20': (0,300000),'iso_ettopocone30': (0,300000),'iso_ettopocone40': (0,300000),
'LJ_pT': (0,200000)
}
for var, (min,max) in rescale_dict.items():
# Want to normalize each variable
data = df[var]
normalized = (data - min)/(max-min)
df[var]=normalized
x = df[self.var_list]
y = df['targets_LJ']
return (np.array(x), np.array(y))
def __getitem__(self, index):
filepath = self.all_files[index]
filepathnew = filepath[0:-1]
x, y = self.__data_generation(filepathnew)
return x, y
И это работает.
Однако я хотел бы определить словарь rescale_dict
в основном коде (вне класса последовательности), но когда я пытаюсь это сделать:
class gen_sequence(keras.utils.Sequence):
def __init__(self, var_list, steps, csv_file_list, rescale_dict, shuffle=True):
self.dim = len(var_list)
self.shuffle = shuffle
self.steps = steps
fp = open(csv_file_list)
self.all_files = fp.readlines()
self.var_list = var_list
slef.rescale_dict = rescale_dict
self.on_epoch_end()
def __len__(self):
'Denotes the number of batches per epoch'
return self.steps
def on_epoch_end(self):
self.index = list(range(0,self.steps))
if self.shuffle == True:
random.shuffle(self.index)
def __data_generation(self, filepath):
df = pd.read_csv(filepath)
for var, (min,max) in self.rescale_dict.items():
# Want to normalize each variable
data = df[var]
normalized = (data - min)/(max-min)
df[var]=normalized
x = df[self.var_list]
y = df['targets_LJ']
return (np.array(x), np.array(y))
def __getitem__(self, index):
filepath = self.all_files[index]
filepathnew = filepath[0:-1]
x, y = self.__data_generation(filepathnew)
return x, y
Я получаю ошибку:
Traceback (most recent call last):
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 2657, in get_loc
return self._engine.get_loc(key)
File "pandas/_libs/index.pyx", line 108, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 132, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1601, in pandas._libs.hashtable.PyObjectHashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1608, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: ('part_pt', (0, 200))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "./train_nn_scaleddata.py", line 225, in <module>
run()
File "./train_nn_scaleddata.py", line 182, in run
steps_per_epoch=steps, epochs=args.epochs, validation_steps=steps_val, callbacks=[early_stop], class_weight=class_weight)
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/engine/training.py", line 1418, in fit_generator
initial_epoch=initial_epoch)
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/engine/training_generator.py", line 181, in fit_generator
generator_output = next(output_generator)
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/utils/data_utils.py", line 601, in get
six.reraise(*sys.exc_info())
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/six.py", line 693, in reraise
raise value
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/utils/data_utils.py", line 595, in get
inputs = self.queue.get(block=True).get()
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/multiprocessing/pool.py", line 657, in get
raise self._value
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/multiprocessing/pool.py", line 121, in worker
result = (True, func(*args, **kwds))
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/keras/utils/data_utils.py", line 401, in get_index
return _SHARED_SEQUENCES[uid][i]
File "./train_nn_scaleddata.py", line 67, in __getitem__
x, y = self.__data_generation(filepathnew)
File "./train_nn_scaleddata.py", line 57, in __data_generation
data = df[var]
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/pandas/core/frame.py", line 2927, in __getitem__
indexer = self.columns.get_loc(key)
File "/opt/ohpc/pub/packages/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 2659, in get_loc
return self._engine.get_loc(self._maybe_cast_indexer(key))
File "pandas/_libs/index.pyx", line 108, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 132, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi", line 1601, in pandas._libs.hashtable.PyObjectHashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi", line 1608, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: ('part_pt', (0, 200))
Есть ли способ определить словарь в основной части кода, ввести его в класс keras.utils.sequence и l oop поверх него, чтобы нормализовать функции?
Спасибо, Сара