function blr_demo()

% Bayesian linear regression demo
%
% This demo uses RBF basis functions. To keep the discussion simple, we
% assume the only unknowns are the weights and the width of the basis
% functions, and that everything else is known. In reality there are many
% other "hyper-parameters" -- model choices that I have plucked out of
% thin air and that could also be questioned.
%
% I've written this quickly as a demonstration to accompany a lecture.
% It is not production code. I have also not checked it in any way. It
% is highly likely to contain mistakes!
%
% Iain Murray, November 2017

VERBOSE = true;
x_grid = (0:0.001:1)';    % inputs used for plotting functions "everywhere"
x_train = (0.1:0.1:0.7)'; % training inputs: not modelled, assumed fixed and known
sigma_y = 0.1;            % standard deviation of Gaussian observation noise
sigma_w = 1;              % standard deviation of zero-mean Gaussian prior on weights
K = 100;                  % number of RBF basis functions
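
% For reference, the model assumed throughout (matching the code below):
%   f(x)  = sum_k w_k phi_k(x),  with phi_k(x) = exp(-(x - c_k)^2 / h^2)
%   w_k   ~ N(0, sigma_w^2)      independently (spherical Gaussian prior)
%   y_n   = f(x_n) + noise_n,    with noise_n ~ N(0, sigma_y^2)
% The centres c_k are evenly spaced on [0, 1] (see rbfs at the bottom).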

% Decide which bandwidths to consider, assign them colors for the second plot,
% and pre-compute input features for each bandwidth.
bandwidths = [0.01, 1];
colors = {[0.7 0.7 0.7], [0.7, 0, 0]};
Phi_h = cell(size(bandwidths));
Phi_grid_h = cell(size(bandwidths));
for ii = 1:numel(bandwidths)
    hh = bandwidths(ii);
    Phi_h{ii} = rbfs(x_train, K, hh);
    Phi_grid_h{ii} = rbfs(x_grid, K, hh);
end

% FIRST PLOT
if VERBOSE
    fprintf('Figure 1: each subplot gives a sample function from the\n');
    fprintf('prior model, with some synthetic noisy observations.\n');
    fprintf('This figure is meant to give you some idea of what the\n');
    fprintf('different bandwidths imply.\n\n');
end
figure(1); clf; hold on;
num_rows = 3;
num_cols = 4;
num_subplots = num_rows * num_cols;
for ii = 1:num_subplots
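    % Pick one of the bandwidths uniformly at random: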
    h_idx = ceil(rand() * numel(bandwidths));
    hh = bandwidths(h_idx);
    subplot(num_rows, num_cols, ii); hold on;
    title(sprintf('Sample from prior, basis function width %g', hh));
    ww = sigma_w * randn(K, 1);
    f_grid = Phi_grid_h{h_idx} * ww;
    f_train = Phi_h{h_idx} * ww;
    y_train = f_train + sigma_y*randn(size(f_train));
    plot(x_grid, f_grid, '-', 'LineWidth', 2, 'color', 0.7*ones(3,1));
    plot(x_train, y_train, '+', 'MarkerSize', 20, 'LineWidth', 2);
end

% SECOND PLOT
if VERBOSE
    fprintf('Figure 2: each subplot gives some data from one of the\n');
    fprintf('models. In gray are some posterior samples -- plausible\n');
    fprintf('explanations of the data -- assuming the short\n');
    fprintf('bandwidth. In red are some posterior samples assuming\n');
    fprintf('the long bandwidth. The posterior probability of the\n');
    fprintf('short-bandwidth model is shown for each dataset.\n\n');
end
figure(2); clf; hold on;
for ii = 1:num_subplots
    subplot(num_rows, num_cols, ii); hold on;

    % Generate synthetic data
    h_idx = ceil(rand() * numel(bandwidths));
    w_true = sigma_w * randn(K, 1);
    f_train = Phi_h{h_idx} * w_true;
    y_train = f_train + sigma_y*randn(size(f_train));

    Lmls = zeros(size(bandwidths));
    for h_idx = 1:numel(bandwidths)
        [w_N, chol_inv_V_N, Lml] = blr(Phi_h{h_idx}, y_train, sigma_w, sigma_y);
        Lmls(h_idx) = Lml;
        num_samples = 6;
        w_samples = bsxfun(@plus, w_N, chol_inv_V_N\randn(K, num_samples));
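        % With R = chol(inv(V_N)) upper-triangular, R'*R = inv(V_N), so R\z
        % for z ~ N(0, I) has covariance inv(R)*inv(R') = V_N. Adding w_N
        % gives draws from the posterior N(w_N, V_N) over the weights.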
        f_samples = Phi_grid_h{h_idx} * w_samples;
        for ss = 1:num_samples
            plot(x_grid, f_samples(:,ss), '-', 'LineWidth', 1, 'color', colors{h_idx});
        end
    end

    % Plot data last, so not obscured by lines:
    plot(x_train, y_train, '+', 'MarkerSize', 20, 'LineWidth', 2);

    % Infer which bandwidth used:
    if numel(bandwidths) == 2
        % Only the special case of two bandwidths is implemented:
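        % With equal prior probability on each bandwidth (implicitly assumed),
        % p(h1 | data) = exp(Lml1) / (exp(Lml1) + exp(Lml2))
        %              = 1 / (1 + exp(-(Lml1 - Lml2))), a logistic sigmoid.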
        posterior_h_1 = 1 ./ (1 + exp(-(Lmls(1) - Lmls(2))));
        title(sprintf('p(h = %g) = %0.3e', bandwidths(1), posterior_h_1));
    end
end

if VERBOSE
    fprintf('For data generated with the short bandwidth, the long\n');
    fprintf('bandwidth samples are bunched up despite not matching\n');
    fprintf('the data -- we are too confident. This problem is\n');
    fprintf('partly caused by fixing the noise level sigma_y.\n');
    fprintf('\n');
    fprintf('For data generated with the long bandwidth, the short\n');
    fprintf('bandwidth model ignores the smooth trend, and makes\n');
    fprintf('rapidly fluctuating predictions. Fortunately this model\n');
    fprintf('is deemed to be improbable, so we can ignore it.\n');
    fprintf('\n');
    fprintf('You may have noticed that some of the plots show the\n');
    fprintf('short bandwidth posterior samples reaching out to the\n');
    fprintf('data with a series of spikes. Because the narrow basis\n');
    fprintf('functions don''t overlap as much, larger weights are\n');
    fprintf('needed to get large function values. We could alter the\n');
    fprintf('width of the prior, sigma_w, so that the narrow\n');
    fprintf('bandwidth model is also happy with large function\n');
    fprintf('values.\n');
end


function X = solve_chol(chol_A, B)
% Solve A*X = B for X, or find inv(A)*B, given chol(A) and B.
% As is Matlab/Octave's default, we use upper-triangular Cholesky factors.
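% chol(A) gives upper-triangular R with R'*R = A, so inv(A)*B can be found
% as inv(R)*(inv(R')*B) using two triangular solves.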
X = chol_A \ (chol_A' \ B);

function [w_N, chol_inv_V_N, Lml] = blr(X, yy, sigma_w, sigma_y)
% Assume spherical zero-mean prior, with width sigma_w
% Return posterior mean w_N, chol(inv(V_N)) where V_N is posterior
% covariance, and log-marginal-likelihood.
% As is Matlab/Octave's default, we use upper-triangular Cholesky factors.
[N, D] = size(X);
inv_V_N = ((sigma_y/sigma_w)^2 * eye(D) + X'*X) / sigma_y^2;
chol_inv_V_N = chol(inv_V_N);
w_N = solve_chol(chol_inv_V_N, X'*yy) / sigma_y^2;
% Evaluate p(w), p(y|X,w), p(w|y,X) at w=0. Hence get p(y|X)
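% This uses Bayes' rule rearranged: p(y|X) = p(w) p(y|X,w) / p(w|y,X),
% which holds at any w; w=0 is just a convenient point to evaluate.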
hl2pi = 0.5*log(2*pi);
Lp_w0 = -D*(hl2pi + log(sigma_w));
Lp_y_w0 = -0.5*(yy'*yy)/sigma_y^2 - N*(hl2pi + log(sigma_y));
U_w_N = chol_inv_V_N * w_N;
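% (The sum(log(diag(chol_inv_V_N))) term below equals 0.5*log(det(inv(V_N)))
%  = -0.5*log(det(V_N)), the log-determinant part of the Gaussian normaliser.)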
Lp_w0_yX = -0.5*(U_w_N'*U_w_N) - D*hl2pi + sum(log(diag(chol_inv_V_N)));
Lml = Lp_w0 + Lp_y_w0 - Lp_w0_yX;

function Phi = rbfs(X, K, hh)
%  X  Nx1 inputs
%  K  1x1 number of RBFs
%  hh 1x1 width (bandwidth) of each RBF
% Returns NxK matrix Phi(X), with K RBFs of width hh, centres evenly spaced between 0 and 1
assert(size(X, 2) == 1); % Assuming one dimensional inputs.
cc = 0:1/(K-1):1;
Phi = exp(-bsxfun(@minus, X(:), cc).^2 / hh^2);
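
function [f_mu, f_sigma] = blr_predict(Phi_star, w_N, chol_inv_V_N, sigma_y)
% Not called by the demo: a minimal sketch of how blr()'s outputs could give
% predictions at new inputs. blr_predict and this interface are an
% illustrative addition, not something the demo above relies on.
%   Phi_star  MxK features at test inputs, e.g. rbfs(x_star, K, hh)
%   f_mu      Mx1 predictive mean of the observations, Phi_star*w_N
%   f_sigma   Mx1 predictive standard deviation, including observation noise
f_mu = Phi_star * w_N;
% diag(Phi_star*V_N*Phi_star') = sum(A.^2, 2) with A = Phi_star*inv(R), where
% R = chol(inv(V_N)), because V_N = inv(R)*inv(R').
A = Phi_star / chol_inv_V_N;
f_sigma = sqrt(sum(A.^2, 2) + sigma_y^2);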