Я пытался согласиться с этим уроком . И я застрял на этой части кода, сопровождаемой этой картинкой:
class DataGeneratorSeq(object):
def __init__(self,prices,batch_size,num_unroll):
self._prices = prices
self._prices_length = len(self._prices) - num_unroll
self._batch_size = batch_size
self._num_unroll = num_unroll
self._segments = self._prices_length //self._batch_size
self._cursor = [offset * self._segments for offset in range(self._batch_size)]
Я смущен тем, что там происходит. Почему существуют партии других партий? И по какой логике они генерируются?
Я даже попытался сделать небольшой пример, чтобы попытаться понять это, но я все еще не знаю, что происходит:
batch_size = 6
prices = [10,11,12,13,14,15,16,17,18,19]
num_unroll = 4
prices_length = len(prices) - num_unroll
print('prices_length =', prices_length)
segments = prices_length // batch_size
print('segments =', segments)
cursor = [offset * segments for offset in range(batch_size)]
print('cursor :', cursor, '\n')
dg2 = DataGeneratorSeq(prices, batch_size, num_unroll)
u_data2, u_labels2 = dg2.unroll_batches()
for ui,(dat,lbl) in enumerate(zip(u_data2,u_labels2)):
print('\n\nUnrolled index %d'%ui)
dat_ind = dat
lbl_ind = lbl
print('\tInputs: ',dat )
print('\n\tOutput:',lbl)
Выход:
prices_length = 6
segments = 1
cursor : [0, 1, 2, 3, 4, 5]
b: 0 , [10. 0. 0. 0. 0. 0.]
b: 1 , [10. 11. 0. 0. 0. 0.]
b: 2 , [10. 11. 12. 0. 0. 0.]
b: 3 , [10. 11. 12. 13. 0. 0.]
b: 4 , [10. 11. 12. 13. 14. 0.]
b: 5 , [10. 11. 12. 13. 14. 14.]
ui: 0 , [array([10., 11., 12., 13., 14., 14.], dtype=float32)]
ui: 0 , [array([14., 13., 15., 16., 16., 17.], dtype=float32)]
b: 0 , [11. 0. 0. 0. 0. 0.]
b: 1 , [11. 12. 0. 0. 0. 0.]
b: 2 , [11. 12. 13. 0. 0. 0.]
b: 3 , [11. 12. 13. 14. 0. 0.]
b: 4 , [11. 12. 13. 14. 12. 0.]
b: 5 , [11. 12. 13. 14. 12. 15.]
ui: 1 , [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32)]
ui: 1 , [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32)]
b: 0 , [12. 0. 0. 0. 0. 0.]
b: 1 , [12. 13. 0. 0. 0. 0.]
b: 2 , [12. 13. 14. 0. 0. 0.]
b: 3 , [12. 13. 14. 13. 0. 0.]
b: 4 , [12. 13. 14. 13. 13. 0.]
b: 5 , [12. 13. 14. 13. 13. 10.]
ui: 2 , [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32)]
ui: 2 , [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32)]
b: 0 , [13. 0. 0. 0. 0. 0.]
b: 1 , [13. 14. 0. 0. 0. 0.]
b: 2 , [13. 14. 12. 0. 0. 0.]
b: 3 , [13. 14. 12. 14. 0. 0.]
b: 4 , [13. 14. 12. 14. 14. 0.]
b: 5 , [13. 14. 12. 14. 14. 11.]
ui: 3 , [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)]
ui: 3 , [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32), array([15., 18., 15., 18., 15., 12.], dtype=float32)]
---------- [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)] --------------
[(array([10., 11., 12., 13., 14., 14.], dtype=float32), array([14., 13., 15., 16., 16., 17.], dtype=float32)), (array([11., 12., 13., 14., 12., 15.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32)), (array([12., 13., 14., 13., 13., 10.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32)), (array([13., 14., 12., 14., 14., 11.], dtype=float32), array([15., 18., 15., 18., 15., 12.], dtype=float32))]
<enumerate object at 0x7fb354a12a68>
---------- [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)] --------------
Почему prices_length
рассчитывается как общая длина массива цен минус это таинственное число развертывания? Кажется, num_unroll
- это количество партий, но картина меня смущает. И что это за segments
переменная?
И я прочитал целый ряд других учебных пособий по LSTM, но ни один из них не является настолько углубленным. И я чувствую, что понимаю, как это должно работать в теории, только сам по себе LSTM, но я не могу понять этот код.
Вот весь код класса:
class DataGeneratorSeq(object):
def __init__(self,prices,batch_size,num_unroll):
self._prices = prices
self._prices_length = len(self._prices) - num_unroll
self._batch_size = batch_size
self._num_unroll = num_unroll
self._segments = self._prices_length //self._batch_size
self._cursor = [offset * self._segments for offset in range(self._batch_size)]
def next_batch(self):
batch_data = np.zeros((self._batch_size),dtype=np.float32)
batch_labels = np.zeros((self._batch_size),dtype=np.float32)
for b in range(self._batch_size):
if self._cursor[b]+1>=self._prices_length:
#self._cursor[b] = b * self._segments
self._cursor[b] = np.random.randint(0,(b+1)*self._segments)
batch_data[b] = self._prices[self._cursor[b]]
batch_labels[b]= self._prices[self._cursor[b]+np.random.randint(0,5)]
self._cursor[b] = (self._cursor[b]+1)%self._prices_length
return batch_data,batch_labels
def unroll_batches(self):
unroll_data,unroll_labels = [],[]
init_data, init_label = None,None
for ui in range(self._num_unroll):
data, labels = self.next_batch()
unroll_data.append(data)
unroll_labels.append(labels)
return unroll_data, unroll_labels
def reset_indices(self):
for b in range(self._batch_size):
self._cursor[b] = np.random.randint(0,min((b+1)*self._segments,self._prices_length-1))
dg = DataGeneratorSeq(train_data,5,5)
u_data, u_labels = dg.unroll_batches()
for ui,(dat,lbl) in enumerate(zip(u_data,u_labels)):
print('\n\nUnrolled index %d'%ui)
dat_ind = dat
lbl_ind = lbl
print('\tInputs: ',dat )
print('\n\tOutput:',lbl)
РЕДАКТИРОВАТЬ: хорошо, поэтому я прочитал больше о развертывании сети LSTM, и теперь имеет смысл использовать переменную num_unroll. Так что это просто итерация по всем точкам данных, которые мы хотим оглянуться назад и подача их в сеть при обновлении на каждом этапе.
Но все еще не уверен, что такое "сегменты" и почему вводится некоторая случайность. Разве это не побеждает точку предсказания, основанную на предыдущих n пунктах?