Welcome to GPEX’s documentation!

Introduction

GPEX is a method and tool for performing knowledge distillation between Gaussian processes (GPs) and artificial neural networks (ANNs). It takes in an arbitrary PyTorch module and replaces one of its neural-network submodules with GPs.

_images/tgpframeworkv.png
The pytorch module can be quite general (as depicted above), with a few requirements:
  • It has to have one ANN submodule.

  • The ANN submodule has to take in one input tensor and return one output tensor. The input has to be of shape [Nx*], where * means any number of dimensions, but the output has to be of shape [NxD].
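
As an illustration of these requirements, here is a minimal sketch of a qualifying pipeline. The names `Backbone`, `Classifier`, and `get_ann_submodule` are hypothetical and not part of GPEX; the only facts taken from the text are that the pipeline contains one ANN submodule and that this submodule maps an input of shape [Nx*] to an output of shape [NxD].

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical ANN submodule: input [N x 1 x 8 x 8], output [N x D]."""

    def __init__(self, num_outputs: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.head = nn.Linear(4 * 8 * 8, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = torch.relu(self.conv(x))
        return self.head(h.flatten(start_dim=1))  # shape [N x D]


class Classifier(nn.Module):
    """Hypothetical general pipeline containing exactly one ANN submodule."""

    def __init__(self):
        super().__init__()
        self.backbone = Backbone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.backbone(x), dim=1)


def get_ann_submodule(model: nn.Module) -> nn.Module:
    # Returns the submodule that is meant to be replaced by GPs.
    return model.backbone


model = Classifier()
x = torch.randn(5, 1, 8, 8)          # N = 5, any number of trailing dimensions
out = get_ann_submodule(model)(x)    # ANN submodule output, shape [5 x 3]
```

A function shaped like `get_ann_submodule` is the kind of callable that the `func_mainmodule_to_moduletobecomeGP` argument described under Base Modules asks for.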

Usage (Distill from ANN to GP)

Typical usage is as follows:

>>> .
>>> .
>>> .
>>> model = SomeModel() #your general pytorch module, it can be for example a classifier
>>> gpexmodule = gpex.GPEXModule(model, ...) #the main module that takes in your general pipeline and asks for different information (like the ANN submodule to be replaced by GP, etc.).
>>> .
>>> .
>>> gpexmodule.init_UV() #this method has to be called to initialize internal parameters, i.e., the matrices U and V in the paper.
>>> for itr in range(num_iters):
>>>    .
>>>    loss = gpexmodule.getcost_GPmatchNN() #This is the loss that matches the GP kernel to the ANN submodule.
>>>    loss.backward()
>>>    .
>>>    .
>>>    for _ in range(count_updateU):
>>>        gpexmodule.update_U()
>>>        #The U matrices in the paper have to be updated frequently.
>>>        #Ideally `count_updateU` should be very large, so that U is updated by going through the dataset once.
>>>        #But in practice, even when `count_updateU` is set to 1, convergence happens.
>>>
>>>    if(itr%interval_renewXTX == 0):
>>>        gpexmodule.renew_precomputed_XTX()
>>>        #The function `renew_precomputed_XTX` should be called
>>>        #   every 1000 iterations or so, to avoid the propagation of numerical error in internal computations.

Base Modules

class gpex.GPEXModule(module_rawmodule, size_recurringdataset, device, func_mainmodule_to_moduletobecomeGP, func_feed_noise_minibatch, func_feed_inducing_minibatch, func_feed_nonrecurring_minibatch, func_feed_test_minibatch, func_get_indices_lastrecurringinstances, func_get_modulef1, flag_efficient=True, flag_detachcovpvn=True, flag_setcovtoOne=False, flag_controlvariate=True, int_mode_controlvariate=2, flag_train_memefficient=False, memefficeint_heads_in_compgraph=None)

The main module to be created in order to use GPEX.

Inputs.
  • module_rawmodule: the raw module in which a module is to be replaced by GP.

  • size_recurringdataset: the size of the inducing dataset (i.e. the variable M in the paper).

  • device: torch device to be used.

  • func_mainmodule_to_moduletobecomeGP: a function that takes in your pytorch module, and returns the ANN submodule to be replaced by GP.

  • func_feed_inducing_minibatch: in this function you should implement how a mini-batch from the inducing dataset is fed to your pytorch module. This function has to have 0 input arguments.

  • func_feed_noise_minibatch: in this function you should implement how a mini-batch of instances, over which the GP is matched to the ANN, is fed to your pytorch module. As explained in the paper and implemented in the sample notebook, a proper way is to feed a mini-batch of samples like lambda*x + (1.0-lambda)*(1-x). This function has to have 0 input arguments.

  • func_feed_nonrecurring_minibatch: in this function you should implement how a mini-batch from the training dataset is fed to your pytorch module. This function has to have 0 input arguments.

  • func_feed_test_minibatch: in this function you should implement how a mini-batch from the testing dataset is fed to your pytorch module. This function has to have 0 input arguments.

  • func_get_indices_lastrecurringinstances: a function that returns the indices of the inducing instances which were last fed to the module. In other words, when implementing func_feed_xxx_minibatch you should store the indices of the most recently fed inducing instances, in a list or so, so that you can return them later on in this function. Importantly, you have to update the list “before” calling any forward functions (as done in the sample notebook). Otherwise, it may lead to unwanted behaviour.

  • func_get_modulef1: This function has to have 0 input arguments, and returns the kernel module. In the notation of the paper, let’s say the ANN has L output heads, so there will be L kernel functions. If each kernel space is considered D-dimensional, the output of the kernel module has to be D*L-dimensional, where each group of D dimensions should be L2-normalized. In the paper’s experiments, the kernel module ends with a leaky ReLU activation.
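
The feed functions and the index bookkeeping described above can be sketched roughly as follows. All names here (`model`, `inducing_x`, `last_indices`, `feed_inducing_minibatch`, `feed_noise_minibatch`, `get_indices_lastrecurringinstances`) are hypothetical, and the small linear model merely stands in for your pipeline. Several details are assumptions rather than documented behaviour: that inputs are scaled to [0, 1], that one mixing weight lambda is drawn uniformly per instance, and that the feed functions return the model output. The sample notebook remains the authoritative reference.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(4, 2)            # stand-in for your PyTorch pipeline
inducing_x = torch.rand(10, 4)     # stand-in inducing dataset, values in [0, 1]
batch_size = 3
last_indices = []                  # indices of the inducing instances fed last


def feed_inducing_minibatch():
    # Takes no arguments; draws a mini-batch from the inducing dataset.
    idx = torch.randperm(inducing_x.shape[0])[:batch_size]
    # Record the indices BEFORE the forward call, as the text requires.
    last_indices[:] = idx.tolist()
    return model(inducing_x[idx])


def get_indices_lastrecurringinstances():
    # Returns the indices stored by the most recent inducing feed call.
    return list(last_indices)


def feed_noise_minibatch():
    # Feeds samples of the form lambda*x + (1 - lambda)*(1 - x).
    # This sketch does not record indices here; check the sample notebook
    # for the behaviour GPEX actually expects.
    idx = torch.randperm(inducing_x.shape[0])[:batch_size]
    x = inducing_x[idx]
    lam = torch.rand(x.shape[0], 1)  # one mixing weight per instance (assumption)
    return model(lam * x + (1.0 - lam) * (1.0 - x))
```

Keeping the index list as shared state that the feed function overwrites before the forward pass means the getter always reflects the batch currently being processed.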

init_UV()

Initializes U (the kernel-space representations of the inducing points) and V (the GP posterior values at the inducing points). This function must be called before calling the function getcost_explainANN.

getparams_explainANN()

Returns the parameters to be optimized for explaining the ANN. When optimizing the cost returned by getcost_explainANN, the optimizer has to operate on the parameters returned by this function.
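
In practice this means building the optimizer from the output of `getparams_explainANN`. The sketch below uses a hypothetical stand-in class, `StubGPEX`, so that it runs without GPEX installed. It assumes that `getparams_explainANN` returns an iterable of tensors accepted by a `torch.optim` optimizer; the choice of Adam and the learning rate are placeholders, and the cost method name follows the Usage section. The stub's cost is a dummy quadratic, not the GPEX loss.

```python
import torch
import torch.nn as nn


class StubGPEX:
    """Hypothetical stand-in exposing only the two methods used below."""

    def __init__(self):
        self.kernel = nn.Linear(4, 2)

    def getparams_explainANN(self):
        return list(self.kernel.parameters())

    def getcost_GPmatchNN(self):
        # Dummy cost, present only so the loop has something to differentiate.
        return (self.kernel(torch.ones(1, 4)) ** 2).sum()


gpexmodule = StubGPEX()  # replace with a real gpex.GPEXModule instance
optimizer = torch.optim.Adam(gpexmodule.getparams_explainANN(), lr=1e-3)

losses = []
for itr in range(5):
    optimizer.zero_grad()
    loss = gpexmodule.getcost_GPmatchNN()
    loss.backward()
    optimizer.step()       # updates only the parameters handed to the optimizer
    losses.append(loss.item())
```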

split_vectors_inkernelspace(input_x)

Splits the Du dimensions into Dv groups, the same way they are split to build different kernels.
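
To make the grouping concrete, the following sketch splits a [N x Du] tensor into Dv equal groups and L2-normalizes each group, which is the layout the kernel module's output is described as having. The helper `split_and_normalize` is hypothetical and is not GPEX's implementation of `split_vectors_inkernelspace`; it assumes that Du is divisible by Dv and that the groups are contiguous along the last dimension.

```python
import torch
import torch.nn.functional as F


def split_and_normalize(x: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Hypothetical helper: [N x Du] -> [N x Dv x (Du / Dv)], unit norm per group."""
    n, du = x.shape
    if du % num_groups != 0:
        raise ValueError("Du must be divisible by the number of groups")
    # Contiguous groups along the last dimension (assumption).
    groups = x.reshape(n, num_groups, du // num_groups)
    return F.normalize(groups, p=2, dim=-1)


x = torch.randn(6, 12)                      # N = 6, Du = 12
z = split_and_normalize(x, num_groups=3)    # Dv = 3 groups of 4 dimensions each
norms = z.norm(dim=-1)                      # shape [6 x 3], each entry close to 1
```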

get_W_of_kernelspace()

Returns the weights of the linear transform on the kernel space.

forward_GPwithoutuncertainty()

Forwards a non-recurring mini-batch as if the GP is trained without uncertainty.

update_U()

Updates some elements of GPX based on the current value of the function (i.e. the module) f1(.).

getcost_GPmatchNN()

Computes cost w.r.t. model params.

get_costQvn()

Computes the KL-divergence part of the ELBO w.r.t. vn.

get_costQvhatm()

Computes the cost for Qvhatm parameters.

check_GPmatchANN_on_aDataloader(func_feed_dlinstances, func_get_lastidx_fedinstances, list_allidx)

Given a dataloader, checks whether the GP and the ANN match on the dataloader.

Inputs.

  • func_feed_dlinstances: a function. This function feeds some instances in the dataloader to the raw module.

  • func_get_lastidx_fedinstances: a function. This function returns the indices of the last fed instances.

  • list_allidx: a list. The list of all indices of the dataloader’s instances that have to be fed to the raw module.

initV_from_theannitself()

This function initializes the GP_Y based on g(.) values at recurring points.

initU_from_kernelmappings()

Initializes GPX from module f1.