Я конвертировал эту реализацию Keet Unet из Github в 1D версию с надеждой на обработку монофонических аудио файлов.Я также добавил tf.signal.rfft
через лямбду и tf.signal.irfft
в конце.Это было основано на некотором гугле и в основном в этом посте .
Я получаю сообщение об ошибке из-за лямбды irfft из-за того, что форма вывода неожиданно равна (NUM_SAMPLES, 0)
.Хотя входной слой и ожидаемая форма имеют вид `(NUM_SAMPLES, 1).
Ошибка:
ValueError: Error when checking target: expected lambda_5 to have shape (1024, 0) but got array with shape (1024, 1)
Когда я проверяю слои модели, я вижу выходные данныеЯ ожидаю, что форма окончательных слоев слоев будет (NUM_SAMPLES, 1)
, но форма лямбда-вывода равна (NUM_SAMPLES, 0)
, что имеет смысл, основываясь на ошибке.Я попытался внести изменение формы, а также попытался сгладить и плотно без блокировки.
Код:
def unet(input_shape=(1024, 1), fft=True):
wave_input = Input(shape=input_shape, name='main_input')
x = Lambda(lambda v: tf.to_float(tf.spectral.rfft(v)))(wave_input)
conv1 = Conv1D(
64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(x)
conv1 = Conv1D(
64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv1)
pool1 = MaxPooling1D(2)(conv1)
conv2 = Conv1D(
128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool1)
conv2 = Conv1D(
128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv2)
pool2 = MaxPooling1D(2)(conv2)
conv3 = Conv1D(
256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool2)
conv3 = Conv1D(
256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv3)
pool3 = MaxPooling1D(2)(conv3)
conv4 = Conv1D(
512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool3)
conv4 = Conv1D(
512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv4)
drop4 = Dropout(0.5)(conv4)
pool4 = MaxPooling1D(2)(drop4)
conv5 = Conv1D(
1024,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(pool4)
conv5 = Conv1D(
1024,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv5)
drop5 = Dropout(0.5)(conv5)
up6 = Conv1D(
512,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(UpSampling1D(2)(drop5))
merge6 = concatenate([drop4, up6], axis=2)
conv6 = Conv1D(
512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge6)
conv6 = Conv1D(
512,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv6)
up7 = Conv1D(
256,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(UpSampling1D(2)(conv6))
merge7 = concatenate([conv3, up7], axis=2)
conv7 = Conv1D(
256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge7)
conv7 = Conv1D(
256,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv7)
up8 = Conv1D(
128,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(UpSampling1D(2)(conv7))
merge8 = concatenate([conv2, up8], axis=2)
conv8 = Conv1D(
128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge8)
conv8 = Conv1D(
128,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv8)
up9 = Conv1D(
64,
2,
activation='relu',
padding='same',
kernel_initializer='he_normal')(UpSampling1D(2)(conv8))
merge9 = concatenate([conv1, up9], axis=2)
conv9 = Conv1D(
64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(merge9)
conv9 = Conv1D(
64,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv9)
conv9 = Conv1D(
2,
3,
activation='relu',
padding='same',
kernel_initializer='he_normal')(conv9)
conv10 = Conv1D(1, 1, activation='sigmoid')(conv9)
x = tf.keras.layers.Lambda(lambda v: tf.to_float(tf.spectral.irfft(tf.cast(v, dtype=tf.complex64))))(conv10)
# x = tf.keras.layers.Reshape(input_shape)(x)
# output = tf.keras.layers.Dense(input_shape[0])(x)
model = tf.keras.models.Model(inputs=[wave_input], outputs=[output])
model.compile(
optimizer=Adam(lr=1e-4),
loss='binary_crossentropy',
metrics=['accuracy'])
return model
Спасибо за любую помощь