Гауссовская регрессия процесса: среднее значение - PullRequest
1 голос
/ 02 февраля 2020

, как показано в приведенном ниже коде, я хотел бы обучить GP на моих данных x и z. Однако, как показывает график, среднее значение GP равно 0, а дисперсия ядра равна 0. Я надеюсь, что смогу найти некоторые подсказки об этом ...

import gpflow
import numpy as np
import matplotlib
from gpflow.utilities import print_summary
import math
import matplotlib.pyplot as plt
import select

from mpl_toolkits.mplot3d import Axes3D
def objective_closure():
    return - m.log_marginal_likelihood()

from pathlib import Path
# Receving points


matplotlib.rcParams['figure.figsize'] = (12, 6)
plt = matplotlib.pyplot


x= [637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 642.0, 646.8, 651.3, 655.6, 659.3, 662.5, 665.1, 667.0, 668.1, 668.5, 668.1, 667.0, 665.1, 662.5, 659.3, 655.6, 651.3, 646.8, 642.0, 637.1, 637.1, 646.4, 655.6, 664.2, 672.2, 679.4, 685.5, 690.4, 694.0, 696.2, 696.9, 696.2, 694.0, 690.4, 685.5, 679.4, 672.2, 664.2, 655.6, 646.4, 637.1, 637.1, 649.9, 662.5, 674.5, 685.5, 695.3, 703.7, 710.4, 715.4, 718.4, 719.4, 718.4, 715.4, 710.4, 703.7, 695.3, 685.5, 674.5, 662.5, 649.9, 637.1, 637.1, 652.2, 667.0, 681.0, 694.0, 705.5, 715.4, 723.3, 729.1, 732.7, 733.9, 732.7, 729.1, 723.3, 715.4, 705.5, 694.0, 681.0, 667.0, 652.2, 637.1, 637.1, 653.0, 668.5, 683.3, 696.9, 709.0, 719.4, 727.8, 733.9, 737.6, 738.9, 737.6, 733.9, 727.8, 719.4, 709.0, 696.9, 683.3, 668.5, 653.0, 637.1, 637.1, 652.2, 667.0, 681.0, 694.0, 705.5, 715.4, 723.3, 729.1, 732.7, 733.9, 732.7, 729.1, 723.3, 715.4, 705.5, 694.0, 681.0, 667.0, 652.2, 637.1, 637.1, 649.9, 662.5, 674.5, 685.5, 695.3, 703.7, 710.4, 715.4, 718.4, 719.4, 718.4, 715.4, 710.4, 703.7, 695.3, 685.5, 674.5, 662.5, 649.9, 637.1, 637.1, 646.4, 655.6, 664.2, 672.2, 679.4, 685.5, 690.4, 694.0, 696.2, 696.9, 696.2, 694.0, 690.4, 685.5, 679.4, 672.2, 664.2, 655.6, 646.4, 637.1, 637.1, 642.0, 646.8, 651.3, 655.6, 659.3, 662.5, 665.1, 667.0, 668.1, 668.5, 668.1, 667.0, 665.1, 662.5, 659.3, 655.6, 651.3, 646.8, 642.0, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 632.1, 627.3, 622.8, 618.6, 614.8, 611.6, 609.0, 607.2, 606.0, 605.6, 606.0, 607.2, 609.0, 611.6, 614.8, 618.6, 622.8, 627.3, 632.1, 637.1, 637.1, 627.7, 618.6, 609.9, 601.9, 594.8, 588.7, 583.8, 580.2, 578.0, 577.2, 578.0, 580.2, 583.8, 588.7, 594.8, 601.9, 609.9, 618.6, 627.7, 637.1, 637.1, 624.2, 611.6, 599.7, 588.7, 578.8, 570.4, 563.7, 558.7, 555.7, 554.7, 555.7, 558.7, 563.7, 570.4, 578.8, 588.7, 599.7, 611.6, 624.2, 637.1, 637.1, 621.9, 607.2, 593.1, 580.2, 568.6, 558.7, 550.8, 545.0, 541.5, 540.3, 541.5, 545.0, 550.8, 558.7, 568.6, 580.2, 593.1, 607.2, 621.9, 637.1, 637.1, 621.1, 605.6, 590.9, 577.2, 565.1, 554.7, 546.4, 540.3, 536.5, 535.3, 536.5, 540.3, 546.4, 554.7, 565.1, 577.2, 590.9, 605.6, 621.1, 637.1, 637.1, 621.9, 607.2, 593.1, 580.2, 568.6, 558.7, 550.8, 545.0, 541.5, 540.3, 541.5, 545.0, 550.8, 558.7, 568.6, 580.2, 593.1, 607.2, 621.9, 637.1, 637.1, 624.2, 611.6, 599.7, 588.7, 578.8, 570.4, 563.7, 558.7, 555.7, 554.7, 555.7, 558.7, 563.7, 570.4, 578.8, 588.7, 599.7, 611.6, 624.2, 637.1, 637.1, 627.7, 618.6, 609.9, 601.9, 594.8, 588.7, 583.8, 580.2, 578.0, 577.2, 578.0, 580.2, 583.8, 588.7, 594.8, 601.9, 609.9, 618.6, 627.7, 637.1, 637.1, 632.1, 627.3, 622.8, 618.6, 614.8, 611.6, 609.0, 607.2, 606.0, 605.6, 606.0, 607.2, 609.0, 611.6, 614.8, 618.6, 622.8, 627.3, 632.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1, 637.1]
z= [-13.2663, -28.7976, -43.5543, -57.1732, -69.3188, -79.6922, -88.0378, -94.1503, -97.879, -99.1322, -97.879, -94.1503, -88.0378, -79.6922, -69.3188, -57.1732, -43.5543, -28.7976, -13.2663, 2.657, 2.657, -12.487, -27.2581, -41.2926, -54.2449, -65.7961, -75.6617, -83.5989, -89.4122, -92.9584, -94.1503, -92.9584, -89.4122, -83.5989, -75.6617, -65.7961, -54.2449, -41.2926, -27.2581, -12.487, 2.657, 2.657, -10.2252, -22.7903, -34.7287, -45.7466, -55.5727, -63.9649, -70.7167, -75.6617, -78.6783, -79.6922, -78.6783, -75.6617, -70.7167, -63.9649, -55.5727, -45.7466, -34.7287, -22.7903, -10.2252, 2.657, 2.657, -6.7025, -15.8315, -24.5053, -32.5103, -39.6493, -45.7466, -50.6521, -54.2449, -56.4366, -57.1732, -56.4366, -54.2449, -50.6521, -45.7466, -39.6493, -32.5103, -24.5053, -15.8315, -6.7025, 2.657, 2.657, -2.2636, -7.063, -11.6231, -15.8315, -19.5847, -22.7903, -25.3692, -27.2581, -28.4103, -28.7976, -28.4103, -27.2581, -25.3692, -22.7903, -19.5847, -15.8315, -11.6231, -7.063, -2.2636, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 7.5776, 12.377, 16.9371, 21.1456, 24.8988, 28.1043, 30.6833, 32.5721, 33.7244, 34.1116, 33.7244, 32.5721, 30.6833, 28.1043, 24.8988, 21.1456, 16.9371, 12.377, 7.5776, 2.657, 2.657, 12.0165, 21.1456, 29.8194, 37.8243, 44.9634, 51.0607, 55.9661, 59.5589, 61.7506, 62.4872, 61.7506, 59.5589, 55.9661, 51.0607, 44.9634, 37.8243, 29.8194, 21.1456, 12.0165, 2.657, 2.657, 15.5393, 28.1043, 40.0428, 51.0607, 60.8867, 69.2789, 76.0307, 80.9758, 83.9924, 85.0062, 83.9924, 80.9758, 76.0307, 69.2789, 60.8867, 51.0607, 40.0428, 28.1043, 15.5393, 2.657, 2.657, 17.801, 32.5721, 46.6066, 59.5589, 71.1101, 80.9758, 88.913, 94.7262, 98.2725, 99.4643, 98.2725, 94.7262, 88.913, 80.9758, 71.1101, 59.5589, 46.6066, 32.5721, 17.801, 2.657, 2.657, 18.5804, 34.1116, 48.8684, 62.4872, 74.6329, 85.0062, 93.3519, 99.4643, 103.193, 104.4462, 103.193, 99.4643, 93.3519, 85.0062, 74.6329, 62.4872, 48.8684, 34.1116, 18.5804, 2.657, 2.657, 17.801, 32.5721, 46.6066, 59.5589, 71.1101, 80.9758, 88.913, 94.7262, 98.2725, 99.4643, 98.2725, 94.7262, 88.913, 80.9758, 71.1101, 59.5589, 46.6066, 32.5721, 17.801, 2.657, 2.657, 15.5393, 28.1043, 40.0428, 51.0607, 60.8867, 69.2789, 76.0307, 80.9758, 83.9924, 85.0062, 83.9924, 80.9758, 76.0307, 69.2789, 60.8867, 51.0607, 40.0428, 28.1043, 15.5393, 2.657, 2.657, 12.0165, 21.1456, 29.8194, 37.8243, 44.9634, 51.0607, 55.9661, 59.5589, 61.7506, 62.4872, 61.7506, 59.5589, 55.9661, 51.0607, 44.9634, 37.8243, 29.8194, 21.1456, 12.0165, 2.657, 2.657, 7.5776, 12.377, 16.9371, 21.1456, 24.8988, 28.1043, 30.6833, 32.5721, 33.7244, 34.1116, 33.7244, 32.5721, 30.6833, 28.1043, 24.8988, 21.1456, 16.9371, 12.377, 7.5776, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, 2.657, -2.2636, -7.063, -11.6231, -15.8315, -19.5847, -22.7903, -25.3692, -27.2581, -28.4103, -28.7976, -28.4103, -27.2581, -25.3692, -22.7903, -19.5847, -15.8315, -11.6231, -7.063, -2.2636, 2.657, 2.657, -6.7025, -15.8315, -24.5053, -32.5103, -39.6493, -45.7466, -50.6521, -54.2449, -56.4366, -57.1732, -56.4366, -54.2449, -50.6521, -45.7466, -39.6493, -32.5103, -24.5053, -15.8315, -6.7025, 2.657, 2.657, -10.2252, -22.7903, -34.7287, -45.7466, -55.5727, -63.9649, -70.7167, -75.6617, -78.6783, -79.6922, -78.6783, -75.6617, -70.7167, -63.9649, -55.5727, -45.7466, -34.7287, -22.7903, -10.2252, 2.657, 2.657, -12.487, -27.2581, -41.2926, -54.2449, -65.7961, -75.6617, -83.5989, -89.4122, -92.9584, -94.1503, -92.9584, -89.4122, -83.5989, -75.6617, -65.7961, -54.2449, -41.2926, -27.2581, -12.487, 2.657, 2.657, -13.2663, -28.7976, -43.5543, -57.1732, -69.3188, -79.6922, -88.0378, -94.1503, -97.879, -99.1322, -97.879, -94.1503, -88.0378, -79.6922, -69.3188, -57.1732, -43.5543, -28.7976, -13.2663, 2.657, 2.657]

x_np= np.reshape(x,(-1,1))
z_np= np.reshape(z,(-1,1))


k = gpflow.kernels.SquaredExponential()
m = gpflow.models.GPR(data=(x_np, z_np), kernel=k, mean_function=None)

m.likelihood.variance.assign(0.01)
m.kernel.lengthscale.assign(0.2)

print_summary(m)
opt = gpflow.optimizers.Scipy()
opt_logs = opt.minimize(objective_closure,
                        m.trainable_variables,
                        options=dict(maxiter=100))
print_summary(m)
################################################################## Done with generating GPR
numb_of_points= 441

## generate sample points
xx = np.linspace(x_np.min()-20, x_np.max()+20, numb_of_points).reshape(numb_of_points, 1)  

## predict mean and variance of latent GP at test points
mean, var = m.predict_y(xx)

## generate 10 samples from posterior
samples = m.predict_f_samples(xx, 10)  # shape (10, 100, 1 )


##plot

fig = plt.figure(figsize=(10,8))


plt.plot(x_np, z_np, 'kx', mew=2)
plt.plot(xx, mean, 'C0', lw=2)
plt.fill_between(xx[:,0],
                 mean[:,0] - 1.96 * np.sqrt(var[:,0]),
                 mean[:,0] + 1.96 * np.sqrt(var[:,0]),
                 color='C0', alpha=0.2)

plt.plot(xx, samples[:, :,0].numpy().T, 'C0', linewidth=.5)



plt.show()

enter image description here


Редактировать: enter image description here

...