От скуки я решил создать простую и быструю реализацию EM. Код ниже.
Однако по какой-то причине мои оценочные средние / стандартные значения всегда сходятся к одним и тем же значениям для всех классов. Я не получаю спецификацию кластера c означает / stds, как можно было бы ожидать от алгоритма EM. Я часами пытаюсь отладить код. Я искал в Google, но не смог найти никаких статей о разрушении EM, что заставило меня поверить, что это просто моя глупость. даже тогда они соскакивают, по-видимому, сходясь к распределению, которое имеет среднее / стандартное значение всех точек, а не конкретных c кластеров. Меня это очень смущает. Я сделал ошибку обновления, которую просто не вижу?
import numpy as np
import matplotlib.pyplot as plt
def normal(x, mu, sig):
return np.exp(-0.5*np.power((x-mu)/sig,2.))/(np.sqrt(2.*np.pi*(sig**2)))
def em(X, num_cl, num_it):
# Initial mean/std values
mu = np.full(num_cl, np.mean(X)) + np.random.rand(num_cl)*5
sig = np.full(num_cl, np.std(X)) + np.abs(np.random.rand(num_cl))*5
for i in range(num_it):
# Temp variables to hold new means/stds
new_mu = np.zeros_like(mu)
new_sig = np.zeros_like(sig)
y = np.sum([normal(X, mu[k], sig[k]) for k in range(num_cl)], axis=0)
for j in range(num_cl):
w = normal(X, mu[j],sig[j])/y
# Calculate updated mean/std
new_mu[j] = np.sum(w*X)/np.sum(w)
new_sig[j] = np.sum(w*(X-new_mu[j])**2)/np.sum(w)
# Set updated means/stds
mu = new_mu
sig = new_sig
# Returns list of mu, sig
return mu, sig
num_cl = 2 # Number of classes
num_pts = 10 # Number of points per class
num_it = 10 # Number of iterations
X_mu, X_sig = [1,6], [1,1.5] # mean/std per class
# Data points generated using above true means/stds
X = np.hstack([np.random.normal(X_mu[i], X_sig[i], num_pts) for i in range(len(X_mu))])
# EM execution
X_mu_p, X_sig_p = em(X, num_cl, num_it)
print('real', X_mu, X_sig)
print('pred', X_mu_p, X_sig_p)
# Plot
ls = np.linspace(-10,10,1000)
for j in range(num_cl):
# Real points
x_real = X[j*num_pts:j*num_pts+num_pts]
plt.scatter(x_real,np.zeros_like(x_real), alpha=0.2)
# real distributions
#x_real = normal(ls, X_mu[j], X_sig[j])
#plt.plot(ls, x_real)
# Estimated distribution
x_pred = normal(ls, X_mu_p[j], X_sig_p[j])
plt.plot(ls, x_pred, alpha=0.6)
plt.show()