Из документации , похоже, stft
принимает только (..., length)
входы, но не принимает (..., length, channels)
.
Таким образом, первое предложение состоит в том, чтобы сначала переместить каналы в другое измерение, чтобы сохранить длину в последнем индексе и заставить функцию работать.
Теперь, конечно, вам понадобятся совпадающие длины, вы не можете сопоставить 76000 с 75901. Таким образом, второе предложение заключается в использовании padding='same'
в 1-мерных свертках для сохранения равных длин.
И, наконец, поскольку в результате stft
у вас уже будет 10 каналов, вам не нужно расширять затемнения в последней лямбде.
Суммируя:
1D деталь
inputs = Input((76000,)) #(batch, 76000)
c1Out = Lambda(lambda x: K.expand_dims(x, axis=-1))(inputs) #(batch, 76000, 1)
c1Out = Conv1D(10, 100, activation = 'relu', padding='same')(c1Out) #(batch, 76000, 10)
#permute for putting length last, apply stft, put the channels back to their position
c1Stft = Permute((2,1))(c1Out) #(batch, 10, 76000)
c1Stft = x = Lambda(lambda v: tf.abs(tf.signal.stft(v,
frame_length=frame_length,
frame_step=frame_step)
)
)(c1Stft) #(batch, 10, probably 751, probably 513)
c1Stft = Permute((2,3,1))(c1Stft) #(batch, 751, 513, 10)
2D деталь , ваш код выглядит нормально:
c2Out = Lambda(lambda v: tf.expand_dims(tf.abs(tf.signal.stft(v,
frame_length=frame_length,
frame_step=frame_step)
),
-1))(inputs) #(batch, 751, 513, 1)
Теперь, когда все совместиморазмеры
#maybe
#c2Out = Conv2D(10, ..., padding='same')(c2Out)
joined = Concatenate()([c1Stft, c2Out]) #(batch, 751, 513, 11) #maybe (batch, 751, 513, 20)
further = BatchNormalization()(joined)
further = Conv2D(...)(further)
Предупреждение: я не знаю, сделали ли они stft
дифференцируемыми или нет, часть Conv1D
будет работать только если определены градиенты.