Как заставить итератор Tensorflow пропустить определенные индексы - PullRequest
0 голосов
/ 16 июня 2019

Сводка

У меня есть набор данных, содержащий n объектов и k строк.Это коллекция данных спутникового моделирования, охватывающая n симуляций.Я обучаю сеть LSTM, желая предоставить ей образцы (n_steps, n_features) смоделированных данных.Однако, поскольку это комбинация нескольких симуляций, некоторые образцы являются недействительными.Например, когда они охватывают конец симуляции и начало новой.Я проиллюстрировал это на изображении ниже.Зеленые поля представляют действительные образцы, а красные - недопустимые образцы.

The dataset with valid samples in green and invalid samples in red.

Ранее я использовал пользовательский генератор данных (tf.keras.utils.Sequence).класс (включен код в конце).Однако, начиная с TensorFlow 1.14.0, я больше не могу это использовать.Таким образом, я пытаюсь создать его с помощью API tf.data, но застрял в нем.

Проблема

Самый простой способ - создать каждый образец и затем создать набор данных tf.Тем не менее, есть более миллиона сэмплов, которые переполнили бы мою память.Он также не выглядит очень элегантным, так как он будет содержать в основном дубликаты данных.Если бы я создал итератор из смоделированных данных, он будет повторяться по каждому начальному индексу и, таким образом, также включать все недопустимые выборки.Поэтому я думаю, что мне нужно создать итератор, который будет использовать индексы для извлечения действительных выборок из моих смоделированных данных.

Итог

Я ищу создание итератора тензорного потока, который будет извлекать выборки из моих смоделированных данных в соответствии с массивом индекса, содержащим все допустимые начальные индексы, идлина последовательностиТаким образом, первый образец будет

sample0 = data[indexes[0]:indexes[0]+seq_len]
label0 = labels[indexes[0]]

, а общее количество образцов и партий будет

n_samples = len(indexes)
n_batches = np.floor(len(indexes) / batch_size)

** Старый функционал DataGenerator **

class DataGenerator(tf.keras.utils.Sequence):
  def __init__(self, data, indexes, labels, seq_len, batch_size, shuffle=True):
    self.data = data
    self.indexes = indexes
    self.labels = labels
    self.seq_len = seq_len
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.on_epoch_end()

  def on_epoch_end(self):
    if self.shuffle:
      np.random.shuffle(self.indexes)

  def __len__(self):
    return int(np.floor(len(self.indexes) / self.batch_size))

  def __generate_sample(self, temp_idxs):
    X = np.empty((self.batch_size, self.seq_len, self.n_features))
    y = np.empty(self.batch_size, dtype=np.int)

    for i, idx in enumerate(temp_idxs):
        X[i, ] = self.data[idx, idxs+self.seq_len,:]
        y[i, ] = self.labels[idx]
    return X, tf.keras.utils.to_categorical(y)

  def __getitem__(self, index)
    temp_idxs = self.indexes[index*self.batch_size: (index+1) * self.batch_size]
    return self.__generate_sample(temp_idxs)

** РЕДАКТИРОВАТЬ **

Возможно, я нашел решение с помощью функции tf.numpy_function ()

class PyFun:
  def __init__(self, data, labels, seq_len, n_classes):
    self.data = data
    self.labels = labels
    self.seq_len = seq_len
    self.n_classes = n_classes

  def get_item(self,index):
    X = self.data[index:index+self.seq_len,3:]
    y = self.labels[index]
    return X, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)

def in_func(fun_obj, indexes, labels, n_features, seq_len, n_classes, batch_size, epochs):
  source = tf.data.Dataset.from_tensor_slices(indexes)
  dataset = source.map(lambda index: tuple(tf.numpy_function(fun_obj.get_item, [index], [tf.double, tf.float32])))
  dataset = dataset.map(lambda x, y: (tf.cast(x, tf.float32), y))
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(lambda x,y: (tf.reshape(x, [batch_size, seq_len, n_features]), 
                                     tf.reshape(y, [batch_size,n_features])))
  dataset = dataset.shuffle(10000)

  return dataset.repeat(epochs)
fun_obj = PyFun(np_train, labels, params['seq_len'], params['n_classes'])

Выборка образцов с использованием следующих операций выглядит как талисман

ds = in_func(...)
iterator = ds.make_one_shot_iterator()
sess = tf.Session()
print(sess.run(iterator.get_next()))

Но выдает ошибку при попытке:

tf.keras.backend.clear_session()
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

with strategy.scope():
  training_model = lstm_model(**params)
  adams = tf.train.AdamOptimizer(learning_rate=0.0006)
  training_model.compile(optimizer=adams,
               loss='categorical_crossentropy',
               metrics=['categorical_accuracy'])
fun_obj = PyFun(np_train, labels, params['seq_len'], params['n_classes'])
results = training_model.fit(in_func(fun_obj, training_gen, **params), epochs=params['epochs'], verbose=1)
W0616 15:51:30.669703 140504606967680 tpu_strategy_util.py:56] TPU system %s has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.

---------------------------------------------------------------------------

AbortedError                              Traceback (most recent call last)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1355     try:
-> 1356       return fn(*args)
   1357     except errors.OpError as e:

11 frames

AbortedError: Session ccd188ff0343641e is not found.


During handling of the above exception, another exception occurred:

AbortedError                              Traceback (most recent call last)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1368           pass
   1369       message = error_interpolation.interpolate(message, self._graph)
-> 1370       raise type(e)(node_def, op, message)
   1371 
   1372   def _extend_graph(self):

AbortedError: Session ccd188ff0343641e is not found.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...