Параллельные слои в Keras перезаписываются - PullRequest
0 голосов
/ 15 января 2020

Я пытаюсь применить БПФ по одной оси 2D-изображения. Моя модель делает это по одной строке за a для l oop, прежде чем объединять результаты для вывода БПФ изображения.

Моя проблема в том, что в нем хранится только результат последней строки изображения (результат каждой новой строки перезаписывается последним).

Я решил, что мог бы решить эту проблему, взяв .copy () для каждой строки, но метод не существует.

FFT of incrementing frequencies

Вот пример кода, который вы можете запустить:

import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Input, Lambda, Concatenate
from keras.models import Model
from keras.layers import Reshape
import tensorflow as tf


def ParallelFFT(input_layer, input_shape):
    subIns = []
    subOuts = []

    for i in range(input_shape[1]):
        # select one line at index i
        subIn = Lambda(lambda w: w[:, :, i, 0], name="sub_input_%d" % (i + 1))(
            input_layer)  # w.shape: (?, 2048, 10, 1) --> (?, 2048)
        # Compute 1D FFT
        subOut = Lambda(lambda v: tf.cast(
            tf.slice(tf.fft(tf.cast(v, dtype=tf.complex64)), [0, input_shape[0] // 2], [1, -1]), tf.float32),
                        name="sub_FFT_%d" % (i + 1))(subIn)  # -> (?, 1024)

        # Add empty dimension for concatenation
        subOut = Reshape((input_shape[0] // 2, 1))(subOut)  # -> (?, 1024, 1)

        subIns.append(subIn)
        subOuts.append(subOut)

    out = Concatenate(axis=-1)(subOuts)  # -> (?, 1024, 10)

    return Model(inputs=input_layer, outputs=[out], name="Parallel_1D_FFT")


def fftModel2D(input_shape):
    x_input = Input(input_shape)
    p = ParallelFFT(x_input, input_shape)
    x = p(x_input)
    return Model(inputs=x_input, outputs=[x])


model = fftModel2D((2048, 10, 1))

testData = []
for i in range(10):
    testData.append(np.sin((i+1)**3.5*np.linspace(0, 1, 2048)))
testData = np.reshape(np.moveaxis(np.asarray(testData), 0, 1), (1, 2048, 10, 1))

pred = model.predict(testData, batch_size=1)[0]

fig, axes = plt.subplots(1, 3)
axes[0].imshow(np.squeeze(testData), aspect="auto")
axes[0].set_title("Input Signal")
axes[1].imshow(np.abs(np.fft.fft(np.squeeze(testData), axis=0)[1024:]), aspect="auto")
axes[1].set_title("Expected Output")
axes[2].imshow(np.squeeze(np.abs(pred)), aspect="auto")
axes[2].set_title("Output")
plt.show()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...