Как обновить данные в наборе данных тензорного потока tf.data.Dataset? - PullRequest
0 голосов
/ 29 апреля 2020

У меня есть некоторые данные (x, y, m), где x и y - тензоры с соответствующими размерами (nxm) и (nx 1) (x - данные, а y - метка). Данные x бывают двух типов, а двоичный тензор m (nx 1) указывает, к какому типу относится каждая точка данных.

При обучении моей модели я буду sh для случайного чередования между партиями типа 0 или Тип 1 данных. Для этого я разбил набор данных на две части:

#create initial dataset
init_data = tf.data.Dataset.from_tensor_slices((x,y,m))
#split dataset into two (based on m)
m0 = init_data.filter(lambda x,y,m : tf.math.equal(m,0) )
m1 = init_data.filter(lambda x,y,m : tf.math.equal(m,1) )

Иногда количество точек типа 1 или 0 не очень много. Для моего обучения не имеет значения, содержит ли партия один и тот же пункт данных несколько раз (только то, что партии выбираются случайным образом в каждую эпоху обучения). Чтобы решить эту проблему, я запускаю:

#batch size for training epochs
batch_size = 100 
#large buffer size for shuffling
n = 1000

#shuffle and batch the dataset (allow repeats of datapoints)
m0 = m0.repeat().shuffle(n).batch(batch_size)
m1 = m1.repeat().shuffle(n).batch(batch_size)

Теперь я могу произвольно выбирать, из какого из двух наборов данных я получаю свою партию в каждой эпохе обучения, используя следующее:

#dataset to sample from at each training iteration
traindat = tf.data.experimental.sample_from_datasets([m0,m1], [0.5, 0.5])

In psuedo -код моего обучения l oop выглядит следующим образом:

#create an iterator over the dataset
it = iter(traindat)

#train for t iterations
for t in range(T):
    #sample batch
    mysample,labels = next(it)
    #forward pass
    pred = my_NN(my_sample)
    #loss
    loss = my_loss(pred,labels)
    #gradient step
    my_NN.update_via_autograd

Моя проблема заключается в следующем:

Предположим, что каждый раз, когда я заканчиваю 10% своего обучения, l oop Я хочу заменить мои данные (x, y, m) прогнозом нейронных сетей на этой итерации, т.е. (pred, y, m)?

Для этого я в настоящее время просто перезаписываю набор данных init_data, добавляя строку:

if t%(t//10) == 0: 
    init_data =tf.data.Dataset.from_tensor_slices((pred,y,m))

, а затем снова запускаю все предыдущие строки кода, чтобы восстановить traindat и его. Однако это кажется крайне неэффективным. Есть лучший способ сделать это? Большое спасибо за любую помощь.

...