Я пытаюсь реализовать STFT
с Pytorch
.Но выход из реализации Pytorch
немного отличается по сравнению с реализацией из Librosa
.
версии Librosa
import numpy as np
from librosa.core import stft
import matplotlib.pyplot as plt
np.random.seed(3)
y = np.sin(2*np.pi*50*np.linspace(0,10,2048))+np.sin(2*np.pi*20*np.linspace(0,10,2048)) + np.random.normal(scale=1,size=2048)
S_stft = np.abs(stft(y, hop_length=512, n_fft=2048,center=False))
plt.plot(S_stft)
![enter image description here](https://i.stack.imgur.com/deAH6.png)
Версия Pytorch
import torch
from torch.autograd import Variable
from torch.nn.functional import conv1d
from scipy.signal.windows import hann
stride = 512
def create_filters(d,k,low=50,high=6000):
x = np.arange(0, d, 1)
wsin = np.empty((k,1,d), dtype=np.float32)
wcos = np.empty((k,1,d), dtype=np.float32)
start_freq = low
end_freq = high
# num_cycles = start_freq*d/44000.
# scaling_ind = np.log(end_freq/start_freq)/k
window_mask = hann(2048, sym=False) # same as 0.5-0.5*np.cos(2*np.pi*x/(k))
for ind in range(k):
wsin[ind,0,:] = window_mask*np.sin(2*np.pi*ind/k*x)
wcos[ind,0,:] = window_mask*np.cos(2*np.pi*ind/k*x)
return wsin,wcos
wsin, wcos = create_filters(2048,2048)
wsin_var = Variable(torch.from_numpy(wsin), requires_grad=False)
wcos_var = Variable(torch.from_numpy(wcos),requires_grad=False)
network_input = torch.from_numpy(y).float()
network_input = network_input.reshape(1,-1)
zx = np.sqrt(conv1d(network_input[:,None,:], wsin_var, stride=stride).pow(2)+conv1d(network_input[:,None,:], wcos_var, stride=stride).pow(2))
pytorch_Xs = zx.cpu().numpy()
plt.plot(pytorch_Xs[0,:1025,0])
![enter image description here](https://i.stack.imgur.com/m5Cit.png)
Мой вопрос
Два графика могут выглядеть одинаково, но если япроверьте два выхода с помощью np.allclose
, мы видим, что они немного отличаются.
np.allclose(S_stft, pytorch_Xs[0,:1025,0].reshape(1025,1))
output >>> False
Только когда я настраиваю допуск на 1e-5
, он дает мне True
результат
np.allclose(S_stft, pytorch_Xs[0,:1025,0].reshape(1025,1),atol=1e-5)
output >>> True
Что вызывает разницу в значениях?Это из-за преобразования данных с использованием torch.from_numpy(y).float()
?
Я хотел бы иметь разницу в значениях меньше, чем 1e-7
, 1e-8
даже лучше.