Алгоритм EM всегда рушится: все средние / стандартные значения становятся одинаковыми - PullRequest
0 голосов
/ 29 мая 2020

От скуки я решил создать простую и быструю реализацию EM. Код ниже.

Однако по какой-то причине мои оценочные средние / стандартные значения всегда сходятся к одним и тем же значениям для всех классов. Я не получаю спецификацию кластера c означает / stds, как можно было бы ожидать от алгоритма EM. Я часами пытаюсь отладить код. Я искал в Google, но не смог найти никаких статей о разрушении EM, что заставило меня поверить, что это просто моя глупость. даже тогда они соскакивают, по-видимому, сходясь к распределению, которое имеет среднее / стандартное значение всех точек, а не конкретных c кластеров. Меня это очень смущает. Я сделал ошибку обновления, которую просто не вижу?

import numpy as np
import matplotlib.pyplot as plt

def normal(x, mu, sig):
    return np.exp(-0.5*np.power((x-mu)/sig,2.))/(np.sqrt(2.*np.pi*(sig**2)))

def em(X, num_cl, num_it):

    # Initial mean/std values
    mu = np.full(num_cl, np.mean(X)) + np.random.rand(num_cl)*5
    sig = np.full(num_cl, np.std(X)) + np.abs(np.random.rand(num_cl))*5

    for i in range(num_it):
        # Temp variables to hold new means/stds
        new_mu = np.zeros_like(mu)
        new_sig = np.zeros_like(sig)

        y = np.sum([normal(X, mu[k], sig[k]) for k in range(num_cl)], axis=0)

        for j in range(num_cl):
            w = normal(X, mu[j],sig[j])/y

            # Calculate updated mean/std
            new_mu[j] = np.sum(w*X)/np.sum(w)
            new_sig[j] = np.sum(w*(X-new_mu[j])**2)/np.sum(w)

        # Set updated means/stds
        mu = new_mu
        sig = new_sig

    # Returns list of mu, sig
    return mu, sig

num_cl = 2 # Number of classes
num_pts = 10 # Number of points per class
num_it = 10 # Number of iterations

X_mu, X_sig = [1,6], [1,1.5] # mean/std per class

# Data points generated using above true means/stds
X = np.hstack([np.random.normal(X_mu[i], X_sig[i], num_pts) for i in range(len(X_mu))])

# EM execution
X_mu_p, X_sig_p = em(X, num_cl, num_it)

print('real', X_mu, X_sig)
print('pred', X_mu_p, X_sig_p)

# Plot
ls = np.linspace(-10,10,1000)
for j in range(num_cl):

    # Real points
    x_real = X[j*num_pts:j*num_pts+num_pts]
    plt.scatter(x_real,np.zeros_like(x_real), alpha=0.2)

    # real distributions
    #x_real = normal(ls, X_mu[j], X_sig[j])
    #plt.plot(ls, x_real)

    # Estimated distribution
    x_pred = normal(ls, X_mu_p[j], X_sig_p[j])
    plt.plot(ls, x_pred, alpha=0.6)

plt.show()

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...