% This demonstration gives a slightly different view of the two-stage sampling
% process demonstrated in gp_minimal.m. See that file for more details. Here we
% reinforce that sampling from the posterior really is just continuing the
% prior sampling process, conditioned on the values that we've already seen. I
% like this version of the demo because there's less linear algebra than in
% gp_minimal.m, although this presentation is less standard.
%
% Iain Murray, November 2016


%% The kernel function (as in gp_minimal.m)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% rbf_fn computes exp(-||x1 - x2||^2) for every pair of rows of X1 and X2,
% using the expansion ||x1 - x2||^2 = ||x1||^2 - 2*x1'*x2 + ||x2||^2:
rbf_fn = @(X1, X2) exp(bsxfun(@minus, bsxfun(@minus, X1*(2*X2'), sum(X1.*X1,2)), sum(X2.*X2,2)'));
% Rescaling the inputs by 1/(sqrt(2)*ell) and scaling by sigma_f^2 gives the
% Gaussian (squared-exponential) kernel sigma_f^2 * exp(-||x1-x2||^2/(2*ell^2)).
% ell can be a scalar or a vector of per-dimension lengthscales:
gauss_kernel_fn = @(X1, X2, ell, sigma_f) sigma_f^2 * rbf_fn(bsxfun(@rdivide, X1, sqrt(2)*ell(:)'), bsxfun(@rdivide, X2, sqrt(2)*ell(:)'));
% Fix the hyperparameters for this demo: lengthscale ell = 3.0, sigma_f = 10.0.
k_fn = @(X1, X2) gauss_kernel_fn(X1, X2, 3.0, 10.0);
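
% A quick sanity check on the kernel (an illustrative addition to the demo):
% the prior variance at any single input is sigma_f^2 = 100, and covariances
% decay with distance.
assert(abs(k_fn(0, 0) - 100) < 1e-9)
assert(k_fn(0, 0) > k_fn(0, 3))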

%% Sampling from the prior
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Pick the input locations where we want to see the function. The small offset
% keeps the training inputs off the test grid: duplicate inputs would make
% K_all nearly singular.
X_train = [2,4,6,8]' + 0.01;
X_test = (0:0.02:10)';
X_all = [X_train; X_test];
N_train = size(X_train, 1);
N_all = size(X_all, 1);

% The joint distribution over function values has zero mean and covariance
% K_all, which we factor as K_all = L_all*L_all'. The tiny jitter added to the
% diagonal keeps the Cholesky factorization numerically stable.
K_all = k_fn(X_all, X_all) + 1e-9*eye(N_all);
L_all = chol(K_all, 'lower');
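
% Optional check (an added aside): the Cholesky factor reconstructs the
% covariance up to floating-point error, so L_all*randn(N_all,1) really does
% have covariance K_all.
assert(max(max(abs(L_all*L_all' - K_all))) < 1e-6)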

% Function values can be sampled with: L_all*nu, where nu = randn(N_all,1).
% Because L_all is lower-triangular, the first N_train function values depend
% only on the first N_train values of nu. We pick those first:
nu1 = randn(N_train, 1);
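
% We can check the triangularity claim directly (another added check): the
% upper-right block of L_all is exactly zero, so the first N_train function
% values cannot depend on nu2.
assert(~any(any(L_all(1:N_train, N_train+1:end))))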
figure(1);
clf; hold all;
for ii = 1:3
    % Then we consider different samples from the prior that complete those
    % first N_train values in different ways:
    nu2 = randn(N_all - N_train, 1);
    nu = [nu1; nu2];
    f_all = L_all * nu;
    % These x's fall on top of each other on every loop iteration, because
    % nu1 is shared:
    plot(X_train, f_all(1:N_train), 'x', 'MarkerSize', 20, 'LineWidth', 2);
    % But we'll get different completions for different nu2. These are
    % samples from the posterior given the 'x' observations.
    plot(X_test, f_all(N_train+1:end), '-', 'LineWidth', 2);
end

legend({'train points', 'completions / posterior samples'});
xlabel('x');
ylabel('f');


% Want to see samples from the posterior given noisy observations? You could
% insert the following two lines beneath the definition of K_all. The x's then
% become noisy observations, and the completions become posterior samples of
% the underlying noise-free function given those observations:
%noise_var = 1.0;
%K_all(1:N_train, 1:N_train) = K_all(1:N_train, 1:N_train) + noise_var*eye(N_train);

% You could extend the demo to plot mean and error bars like in gp_minimal.m
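
% A minimal commented-out sketch of that extension, reusing the blocks of
% L_all (the variable names here are illustrative, not from gp_minimal.m).
% The posterior mean of the test values is L21*nu1, and the posterior
% covariance is L22*L22', so the marginal standard deviations are the row
% norms of L22:
%L21 = L_all(N_train+1:end, 1:N_train);
%L22 = L_all(N_train+1:end, N_train+1:end);
%post_mean = L21 * nu1;
%post_std = sqrt(sum(L22.^2, 2));
%plot(X_test, post_mean, 'k-', 'LineWidth', 2);
%plot(X_test, post_mean + 2*post_std, 'k--');
%plot(X_test, post_mean - 2*post_std, 'k--');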

% Of course we don't see the random numbers nu1 directly when we observe data.
% However, they're not hidden from us either: because the relevant block of
% L_all is triangular, we can solve for nu1 from the observed function values:
nu1_from_obs = L_all(1:N_train, 1:N_train) \ f_all(1:N_train);
assert(max(abs(nu1_from_obs - nu1)) < 1e-9)
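
% A further cross-check (an added aside): the implicit posterior mean above,
% L21*nu1, matches the textbook GP predictive mean K21*inv(K11)*f1, because
% K21*inv(K11) = L21*inv(L11) when K = L*L'.
L21 = L_all(N_train+1:end, 1:N_train);
K11 = K_all(1:N_train, 1:N_train);
K21 = K_all(N_train+1:end, 1:N_train);
f1 = f_all(1:N_train);
assert(max(abs(L21*nu1_from_obs - K21*(K11\f1))) < 1e-6)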


% Notice how almost all of the code above is comments, plotting, and tracking
% which data points are which. Little maths is required to sample realizations
% of complex functions given data.