Построение в режиме реального времени модели прогнозов во время обучения - PullRequest
0 голосов
/ 22 февраля 2019

Я пытаюсь увидеть, как моя модель тренируется сверхурочно, видя ее прогнозы относительно реальных значений y на графике, который обновляется в каждом пакете, я гуглил, как это было сделано, и это очень сбивает с толку, ближе всего я получил его к работе, этоздесь добавлен код:

def cb(x, y_true):
    def _(batch, logs):
        s,e=batch*batch_size,(batch+1)*batch_size
        y_pred = model.predict(
            x[s:e],
            batch_size=batch_size
        )
        plt.clf()
        plt.plot(y_true[s:e], label='true')
        plt.plot(y_pred, label='pred')
        plt.legend()
        plt.show()
    return _
cb_plot=keras.callbacks.LambdaCallback(on_batch_end=cb(train_X,train_y))

Проблема в том, что мне нужно каждый раз закрывать фигуру вручную, чтобы продолжить обучение, потому что show блокирует.я попытался использовать block=False, и я попытался включить интерактив, используя ion, но это привело меня к пустому белому окну без ответа.У кого-нибудь есть идея, что нужно изменить здесь, чтобы она работала?

Кстати, я не видел, чтобы TensorBoard мог показать вам прогнозы, которые модель сделала в ходе обучения, по сравнению с реальными значениями y, возможно ливидите что в тензорной доске вместо того, чтобы вручную ее реализовывать?

Спасибо!

Ответы [ 2 ]

0 голосов
/ 22 февраля 2019

это сработало для меня:

plt.show(block=False)

, а затем

def cb(x, y_true):
    def _(batch, logs):
        s,e=batch*batch_size,(batch+1)*batch_size
        y_pred = model.predict(
            x[s:e],
            batch_size=batch_size
        )
        plt.clf()
        plt.plot(y_true[s:e], label='true')
        plt.plot(y_pred, label='pred')
        plt.axis([0, batch_size, -1, 1])
        plt.legend()
        plt.draw()
        plt.pause(0.0001)
    return _
cb_plot=keras.callbacks.LambdaCallback(on_batch_end=cb(train_X,train_y))
0 голосов
/ 22 февраля 2019

Попробуйте использовать matplotlib qt backend с:

%matplotlib qt

или, если вы запускаете .py файл

from IPython import get_ipython
get_ipython().run_line_magic('matplotlib', 'qt')

, затем создайте глобальный axe объект с ax = plt.axes() и, наконец, используйтечтобы нарисовать график:

def plot_stuff():
   ax.clear()
   x = np.linspace(-10, 10, 50)
   ax.plot(x, np.sin(x))

И если вы хотите построить прогнозируемые значения, вы можете создать несколько пользовательских метрических функций, которые просто возвращали бы значения y_true или y_pred.И используйте TensorBoard обратный вызов, чтобы построить его.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...