Как понять этот пример LSTM? - PullRequest
0 голосов
/ 28 июня 2018

Я пытался согласиться с этим уроком . И я застрял на этой части кода, сопровождаемой этой картинкой:

class DataGeneratorSeq(object):

    def __init__(self,prices,batch_size,num_unroll):
        self._prices = prices
        self._prices_length = len(self._prices) - num_unroll
        self._batch_size = batch_size
        self._num_unroll = num_unroll
        self._segments = self._prices_length //self._batch_size
        self._cursor = [offset * self._segments for offset in range(self._batch_size)]

enter image description here

Я смущен тем, что там происходит. Почему существуют партии других партий? И по какой логике они генерируются?

Я даже попытался сделать небольшой пример, чтобы попытаться понять это, но я все еще не знаю, что происходит:

batch_size = 6
prices = [10,11,12,13,14,15,16,17,18,19]
num_unroll = 4
prices_length = len(prices) - num_unroll
print('prices_length =', prices_length)
segments = prices_length // batch_size
print('segments =', segments)
cursor = [offset * segments for offset in range(batch_size)]
print('cursor :', cursor, '\n')

dg2 = DataGeneratorSeq(prices, batch_size, num_unroll)
u_data2, u_labels2 = dg2.unroll_batches()

for ui,(dat,lbl) in enumerate(zip(u_data2,u_labels2)):   
    print('\n\nUnrolled index %d'%ui)
    dat_ind = dat
    lbl_ind = lbl
    print('\tInputs: ',dat )
    print('\n\tOutput:',lbl)

Выход:

prices_length = 6
segments = 1
cursor : [0, 1, 2, 3, 4, 5] 

b:  0 ,  [10.  0.  0.  0.  0.  0.]
b:  1 ,  [10. 11.  0.  0.  0.  0.]
b:  2 ,  [10. 11. 12.  0.  0.  0.]
b:  3 ,  [10. 11. 12. 13.  0.  0.]
b:  4 ,  [10. 11. 12. 13. 14.  0.]
b:  5 ,  [10. 11. 12. 13. 14. 14.]
ui:  0 ,  [array([10., 11., 12., 13., 14., 14.], dtype=float32)]
ui:  0 ,  [array([14., 13., 15., 16., 16., 17.], dtype=float32)]
b:  0 ,  [11.  0.  0.  0.  0.  0.]
b:  1 ,  [11. 12.  0.  0.  0.  0.]
b:  2 ,  [11. 12. 13.  0.  0.  0.]
b:  3 ,  [11. 12. 13. 14.  0.  0.]
b:  4 ,  [11. 12. 13. 14. 12.  0.]
b:  5 ,  [11. 12. 13. 14. 12. 15.]
ui:  1 ,  [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32)]
ui:  1 ,  [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32)]
b:  0 ,  [12.  0.  0.  0.  0.  0.]
b:  1 ,  [12. 13.  0.  0.  0.  0.]
b:  2 ,  [12. 13. 14.  0.  0.  0.]
b:  3 ,  [12. 13. 14. 13.  0.  0.]
b:  4 ,  [12. 13. 14. 13. 13.  0.]
b:  5 ,  [12. 13. 14. 13. 13. 10.]
ui:  2 ,  [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32)]
ui:  2 ,  [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32)]
b:  0 ,  [13.  0.  0.  0.  0.  0.]
b:  1 ,  [13. 14.  0.  0.  0.  0.]
b:  2 ,  [13. 14. 12.  0.  0.  0.]
b:  3 ,  [13. 14. 12. 14.  0.  0.]
b:  4 ,  [13. 14. 12. 14. 14.  0.]
b:  5 ,  [13. 14. 12. 14. 14. 11.]
ui:  3 ,  [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)]
ui:  3 ,  [array([14., 13., 15., 16., 16., 17.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32), array([15., 18., 15., 18., 15., 12.], dtype=float32)]
 ----------  [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)] --------------
[(array([10., 11., 12., 13., 14., 14.], dtype=float32), array([14., 13., 15., 16., 16., 17.], dtype=float32)), (array([11., 12., 13., 14., 12., 15.], dtype=float32), array([14., 15., 16., 16., 16., 18.], dtype=float32)), (array([12., 13., 14., 13., 13., 10.], dtype=float32), array([16., 15., 16., 17., 16., 11.], dtype=float32)), (array([13., 14., 12., 14., 14., 11.], dtype=float32), array([15., 18., 15., 18., 15., 12.], dtype=float32))]
<enumerate object at 0x7fb354a12a68>
 ----------  [array([10., 11., 12., 13., 14., 14.], dtype=float32), array([11., 12., 13., 14., 12., 15.], dtype=float32), array([12., 13., 14., 13., 13., 10.], dtype=float32), array([13., 14., 12., 14., 14., 11.], dtype=float32)] --------------

Почему prices_length рассчитывается как общая длина массива цен минус это таинственное число развертывания? Кажется, num_unroll - это количество партий, но картина меня смущает. И что это за segments переменная?

И я прочитал целый ряд других учебных пособий по LSTM, но ни один из них не является настолько углубленным. И я чувствую, что понимаю, как это должно работать в теории, только сам по себе LSTM, но я не могу понять этот код.

Вот весь код класса:

class DataGeneratorSeq(object):

    def __init__(self,prices,batch_size,num_unroll):
        self._prices = prices
        self._prices_length = len(self._prices) - num_unroll
        self._batch_size = batch_size
        self._num_unroll = num_unroll
        self._segments = self._prices_length //self._batch_size
        self._cursor = [offset * self._segments for offset in range(self._batch_size)]

    def next_batch(self):

        batch_data = np.zeros((self._batch_size),dtype=np.float32)
        batch_labels = np.zeros((self._batch_size),dtype=np.float32)

        for b in range(self._batch_size):
            if self._cursor[b]+1>=self._prices_length:
                #self._cursor[b] = b * self._segments
                self._cursor[b] = np.random.randint(0,(b+1)*self._segments)

            batch_data[b] = self._prices[self._cursor[b]]
            batch_labels[b]= self._prices[self._cursor[b]+np.random.randint(0,5)]

            self._cursor[b] = (self._cursor[b]+1)%self._prices_length

        return batch_data,batch_labels

    def unroll_batches(self):

        unroll_data,unroll_labels = [],[]
        init_data, init_label = None,None
        for ui in range(self._num_unroll):

            data, labels = self.next_batch()    

            unroll_data.append(data)
            unroll_labels.append(labels)

        return unroll_data, unroll_labels

    def reset_indices(self):
        for b in range(self._batch_size):
            self._cursor[b] = np.random.randint(0,min((b+1)*self._segments,self._prices_length-1))



dg = DataGeneratorSeq(train_data,5,5)
u_data, u_labels = dg.unroll_batches()

for ui,(dat,lbl) in enumerate(zip(u_data,u_labels)):   
    print('\n\nUnrolled index %d'%ui)
    dat_ind = dat
    lbl_ind = lbl
    print('\tInputs: ',dat )
    print('\n\tOutput:',lbl)

РЕДАКТИРОВАТЬ: хорошо, поэтому я прочитал больше о развертывании сети LSTM, и теперь имеет смысл использовать переменную num_unroll. Так что это просто итерация по всем точкам данных, которые мы хотим оглянуться назад и подача их в сеть при обновлении на каждом этапе.

Но все еще не уверен, что такое "сегменты" и почему вводится некоторая случайность. Разве это не побеждает точку предсказания, основанную на предыдущих n пунктах?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...