У меня есть все эти классы, которые очень похожи, за исключением covar_module. Мне было интересно, могу ли я создать общий суперкласс, который имеет все атрибуты и функции, а подклассы имеют только covar_module. Я новичок в python, поэтому не знаю, каким будет синтаксис.
class RbfGP(ExactGP, GPyTorchModel):
_num_outputs = 1 # to inform GPyTorchModel API
def __init__(self, train_X, train_Y):
# squeeze output dim before passing train_Y to ExactGP
super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
self.mean_module = ConstantMean()
self.covar_module = gpytorch.kernels.ScaleKernel(
base_kernel= gpytorch.kernels.RBFKernel(ard_num_dims=train_X.shape[-1]),
)
self.to(train_X) # make sure we're on the right device/dtype
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)
class Matern12GP(ExactGP, GPyTorchModel):
_num_outputs = 1 # to inform GPyTorchModel API
def __init__(self, train_X, train_Y):
# squeeze output dim before passing train_Y to ExactGP
super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
self.mean_module = ConstantMean()
self.covar_module = gpytorch.kernels.ScaleKernel(
base_kernel=gpytorch.kernels.MaternKernel(nu=0.5, ard_num_dims=train_X.shape[-1]),
)
self.to(train_X) # make sure we're on the right device/dtype
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)
class Matern32GP(ExactGP, GPyTorchModel):
_num_outputs = 1 # to inform GPyTorchModel API
def __init__(self, train_X, train_Y):
# squeeze output dim before passing train_Y to ExactGP
super().__init__(train_X, train_Y.squeeze(-1), GaussianLikelihood())
self.mean_module = ConstantMean()
self.covar_module = gpytorch.kernels.ScaleKernel(
base_kernel=gpytorch.kernels.MaternKernel(nu=1.5, ard_num_dims=train_X.shape[-1]),
)
self.to(train_X) # make sure we're on the right device/dtype
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return MultivariateNormal(mean_x, covar_x)