Предсказание модели gpflow говорит, что ожидалось, что вход будет двойным тензором, но это тензор с плавающей точкой - PullRequest
1 голос
/ 24 февраля 2020

Я пытаюсь запустить код из учебника gpflow: https://gpflow.readthedocs.io/en/stable/notebooks/regression.html Однако он не работает.

Следующий код:

N = 12
X = np.random.rand(N,1)
Y = np.sin(12*X) + 0.66*np.cos(25*X) + np.random.randn(N,1)*0.1 + 3
plt.plot(X, Y, 'kx', mew=2)

k = gpflow.kernels.Matern52(variance=1.0, lengthscale=1.0)
m = gpflow.models.GPR((X, Y), k, mean_function=None, noise_variance=1.0)
m.likelihood.variance = 0.01

def plot(m):
    xx = np.linspace(-0.2, 1.2, 141)[:,None]
    xx=tf.convert_to_tensor(xx,dtype=tf.float64)
    mean, var = m.predict_y(xx)
    plt.figure(figsize=(12, 6))
    plt.plot(X, Y, 'kx', mew=2)
    plt.plot(xx, mean, 'b', lw=2)
    plt.fill_between(xx[:,0], mean[:,0] - 2*np.sqrt(var[:,0]), mean[:,0] + 2*np.sqrt(var[:,0]), color='blue', alpha=0.2)
    plt.xlim(-0.1, 1.1)
plot(m)

возвращает следующую ошибку:

InvalidArgumentError: невозможно вычислить AddV2, так как ожидалось, что вход # 1 (на основе нуля) будет двойным тензором, но является тензором с плавающей точкой [Op: AddV2] name: add /

У меня есть windows 10, python 3.6, тензор потока 2.0, вероятность тензор потока 0,9, и gpflow был установлен с помощью команды pip install -e. команда от 21 февраля 2020 года.

Не могли бы вы помочь мне с этим? Я преобразовываю ввод в double, так что я думаю, что gpflow обновил код, но не учебник.

Ответы [ 2 ]

4 голосов
/ 24 февраля 2020

Проблема, с которой вы сталкиваетесь, связана с новым способом обновления значений параметров в GPflow. Вместо model.parameter = value вы должны использовать assign:

m.likelihood.variance.assign(0.01) 

. Это гарантирует, что тип параметров не изменится.

Мне удалось получить следующий график после того, как установка lengthscale ядра на 0.25.

enter image description here

0 голосов
/ 03 марта 2020

Я задавал (и отвечал сам) здесь тот же вопрос: GPFlow-2.0 - проблема с default_float и дисперсией вероятности

В общем, я бы предложил НЕ использовать примеры на readthedocs если вы пытаетесь использовать gpflow 2.0. Вместо этого, клонируйте репозиторий github и используйте примеры ноутбуков оттуда - они в основном были обновлены для корректной работы с gpflow 2.

Например, обновленная версия вашего примера регрессии доступна здесь: https://github.com/GPflow/GPflow/blob/develop/doc/source/notebooks/basics/regression.pct.py

Эта версия работает с gpflow 2.

...