Gaussian Process Regression using GPyTorch

A Gaussian process, or GP for short, is an underappreciated yet powerful tool for machine learning. It is a non-parametric, Bayesian approach that can be applied to supervised learning problems such as regression and classification. Compared to many other supervised learning algorithms, GPs have two practical advantages: they work well on small datasets, and they provide uncertainty estimates along with their predictions.

In this tutorial, I am going to demonstrate how to perform GP regression using GPyTorch. GPyTorch is a Gaussian process library implemented using PyTorch that is designed for creating scalable and flexible GP models. You can learn more about GPyTorch on their official website.

Note: This tutorial is not intended to teach the mathematical background of GPs, but rather how to build one using GPyTorch. For a thorough introduction to GP regression, I highly recommend reading Chapter 2 of Gaussian Processes for Machine Learning by Rasmussen and Williams.

Setup

Before we start, we need to install the gpytorch library. You can do this with either pip or conda using one of the following commands:

```bash
# Install using pip
pip install gpytorch
# Install using conda
conda install gpytorch -c gpytorch
```

You can also check the requirements and installation instructions on their website here.

Note: If you want to follow along with this tutorial, you can find the accompanying notebook here.

Generating the data

Next, we need to generate training data for our model. We will model the following function:

$$y = \sin{(2 \pi x)} + \epsilon, \enspace \epsilon \sim \mathcal{N}(0, 0.04)$$

The above equation defines our data-generating process. The true function is a sine wave, sin(2πx), and each observation is corrupted by zero-mean Gaussian noise with variance 0.04 (i.e. a standard deviation of 0.2). We will evaluate this function at 15 equally-spaced points in [0, 1]. The generated training data is depicted in the following plot:
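The code for this step is not shown in the post. Here is a minimal sketch, assuming the names x_train and y_train that the later code uses. The fixed random seed is my own addition for reproducibility, so your noisy samples may differ from the plotted ones:

```python
import math
import torch

# Fixed seed so the noisy samples are reproducible (an assumption, not from the original)
torch.manual_seed(0)

# 15 equally-spaced inputs on [0, 1]
x_train = torch.linspace(0, 1, 15)

# True function sin(2*pi*x) plus Gaussian noise with variance 0.04 (std = 0.2)
noise_std = math.sqrt(0.04)
y_train = torch.sin(2 * math.pi * x_train) + noise_std * torch.randn(x_train.size())
```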

Building the model

Now that we have generated our training data, we can start building our GP model. GPyTorch lets us build GP models flexibly by assembling the model's components ourselves, much like building a neural network from modules in standard PyTorch. For most GP regression models, you will need to construct the following components:

  1. A GP Model: For exact (i.e. non-variational) GP models we will use gpytorch.models.ExactGP.
  2. A likelihood function: For GP regression, we commonly use gpytorch.likelihoods.GaussianLikelihood.
  3. A mean function: The prior mean of the GP. If you don’t know which mean function to use, gpytorch.means.ConstantMean() is usually a good place to start.
  4. A kernel function: The prior covariance of the GP. We’ll use the Spectral Mixture (SM) kernel (gpytorch.kernels.SpectralMixtureKernel) for this tutorial.
  5. A multivariate normal distribution: The distribution over function values that the model returns, built from the evaluated mean and kernel (gpytorch.distributions.MultivariateNormal).

We can build our GP model by constructing the above components as follows:

```python
import torch
import gpytorch


class SpectralMixtureGP(gpytorch.models.ExactGP):
    def __init__(self, x_train, y_train, likelihood):
        super().__init__(x_train, y_train, likelihood)
        self.mean = gpytorch.means.ConstantMean()  # Construct the mean function
        self.cov = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=4)  # Construct the kernel function
        self.cov.initialize_from_data(x_train, y_train)  # Initialize the kernel hyperparameters from data

    def forward(self, x):
        # Evaluate the mean and kernel function at x
        mean_x = self.mean(x)
        cov_x = self.cov(x)
        # Return the multivariate normal distribution using the evaluated mean and kernel function
        return gpytorch.distributions.MultivariateNormal(mean_x, cov_x)


# Initialize the likelihood and model (x_train and y_train hold the training data)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SpectralMixtureGP(x_train, y_train, likelihood)
```

Let me break down the above code:

  • The GP model class has two main components: the __init__ and forward methods.
  • The __init__ method takes the training data and a likelihood as the inputs and constructs whatever objects are necessary for the model’s forward method. This will most commonly include objects like a mean function and a kernel function.
  • The forward method takes in the data x and returns a multivariate normal distribution with the prior mean and covariance evaluated at x.
  • Finally, we initialize the likelihood and the model. Here, we use the Gaussian likelihood, which is the simplest likelihood function for regression and assumes a homoskedastic noise model (i.e. every input has the same noise variance).

Training the model

Now that we have built the model, we can train it to find the hyperparameters (e.g. the kernel's mixture weights, means, and scales, and the noise level) that maximize the marginal likelihood of the training data. Training a GP model in GPyTorch is also analogous to training a neural network in standard PyTorch. The training loop mainly consists of the following steps:

  1. Setting all the parameter gradients to zero
  2. Calling the model and computing the loss
  3. Calling backward on the loss to fill in gradients
  4. Taking a step on the optimizer

Note: By defining our custom training loop, we can have greater flexibility in training our model. For example, it is easy to save the parameters at each step of training or use different learning rates for different parameters.

The code for the training loop is given below:

```python
# Put the model into training mode
model.train()
likelihood.train()

# Use the Adam optimizer, with learning rate set to 0.1
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Use the negative marginal log-likelihood as the loss function
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

# Set the number of training iterations
n_iter = 50

for i in range(n_iter):
    # Set the gradients from previous iteration to zero
    optimizer.zero_grad()
    # Output from model
    output = model(x_train)
    # Compute loss and backprop gradients
    loss = -mll(output, y_train)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, n_iter, loss.item()))
    optimizer.step()
```

In the above code, we first put our model into training mode by calling model.train() and likelihood.train(). Then, we define the optimizer and the loss function used in training: Adam with a learning rate of 0.1, and the negative marginal log-likelihood computed by gpytorch.mlls.ExactMarginalLogLikelihood. Finally, we set the number of training iterations (here 50) and run the loop, printing the loss at every iteration.

Making predictions with the model

Finally, we can make predictions using the trained model. The basic routine of evaluating the model and making predictions is given in the following code:

```python
import matplotlib.pyplot as plt

# The test data is 50 equally-spaced points from [0, 5]
x_test = torch.linspace(0, 5, 50)

# Put the model into evaluation mode
model.eval()
likelihood.eval()

# The gpytorch.settings.fast_pred_var flag activates LOVE (for fast variances)
# See https://arxiv.org/abs/1803.06058
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    # Obtain the predictive mean and covariance matrix of the latent function
    f_preds = model(x_test)
    f_mean = f_preds.mean
    f_cov = f_preds.covariance_matrix

    # Predict noisy observations by passing the latent predictions through the likelihood
    observed_pred = likelihood(f_preds)

    # Initialize plot
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    # Get upper and lower confidence bounds
    lower, upper = observed_pred.confidence_region()
    # Plot training data as black stars
    ax.plot(x_train.numpy(), y_train.numpy(), 'k*')
    # Plot predictive means as blue line
    ax.plot(x_test.numpy(), observed_pred.mean.numpy(), 'b')
    # Shade between the lower and upper confidence bounds
    ax.fill_between(x_test.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)
    ax.set_ylim([-3, 3])
    ax.legend(['Observed Data', 'Mean', 'Confidence'])
    plt.show()
```

There are a few things going on in the above code:

  • We first generate the test data using 50 equally-spaced points from [0, 5].
  • We put the model into evaluation mode by calling model.eval() and likelihood.eval().
  • We use gpytorch.settings.fast_pred_var() to compute predictive variances faster using LOVE, and torch.no_grad() because we do not need gradients at prediction time.
  • In eval mode, calling the trained GP model on test inputs returns a MultivariateNormal over the latent function values containing the posterior mean and covariance, from which we read f_mean and f_cov. Passing that distribution through the likelihood adds the learned observation noise and gives observed_pred, the predictive distribution of noisy observations, which is what we plot.
  • Finally, we plot the mean and confidence region of the fitted GP model. The confidence_region() helper returns the bounds two standard deviations below and above the mean.

The resulting plot is depicted below:

The black stars in the above plot represent the training (observed) data, while the blue line and the shaded area represent the predictive mean and the confidence region, respectively. Notice how the uncertainty shrinks close to the observed points. If more data points were added, the mean function would adjust itself to pass close to these points, and the uncertainty near those new observations would shrink as well.
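The shrinking uncertainty near the data follows from the exact GP posterior variance, Var[f(x*)] = k(x*, x*) − k*ᵀ(K + σ²I)⁻¹k* (see Chapter 2 of Gaussian Processes for Machine Learning). The sketch below evaluates this formula directly in PyTorch. To keep it short and deterministic, it swaps in a plain RBF kernel with fixed, hand-picked hyperparameters instead of the trained spectral mixture kernel, so the numbers are illustrative only. Note that the posterior variance depends only on where the inputs are, not on the observed y values:

```python
import torch

# Illustrative assumptions: RBF kernel with lengthscale 0.2 and unit outputscale,
# noise variance 0.04, and the same 15 training inputs on [0, 1]
lengthscale, noise = 0.2, 0.04
x_train = torch.linspace(0, 1, 15, dtype=torch.float64)


def rbf(a, b):
    # k(a, b) = exp(-(a - b)^2 / (2 * lengthscale^2)), for all pairs
    return torch.exp(-(a[:, None] - b[None, :]) ** 2 / (2 * lengthscale**2))


def posterior_variance(x_test):
    # Var[f*] = k(x*, x*) - k*^T (K + noise * I)^{-1} k*
    K = rbf(x_train, x_train) + noise * torch.eye(len(x_train), dtype=torch.float64)
    k_star = rbf(x_train, x_test)             # shape (n_train, n_test)
    solved = torch.linalg.solve(K, k_star)    # (K + noise * I)^{-1} k*
    prior_var = torch.ones(len(x_test), dtype=torch.float64)  # k(x*, x*) = 1 for this RBF
    return prior_var - (k_star * solved).sum(dim=0)


# One test point inside the data, one far outside it
x_test = torch.tensor([0.5, 5.0], dtype=torch.float64)
var_near, var_far = posterior_variance(x_test).tolist()
print(f"variance at x=0.5: {var_near:.4f}, at x=5.0: {var_far:.4f}")
```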

Takeaways

In this tutorial, we have learned how to build a scalable and flexible GP model using GPyTorch. As with any library, there is much more you can do with GPyTorch than I covered here. For example, you can use a GPU to accelerate model training when implementing state-of-the-art models such as deep GPs or stochastic variational deep kernel learning.

Richard Cornelius Suwandi

PhD Student at CUHK-Shenzhen
