function blr_demo()
% BLR_DEMO Bayesian linear regression demo
%
% This demo uses RBF basis functions. We assume the only unknowns are the
% weights and the width of the basis functions. To keep the discussion
% simpler, we assume everything else is known, even though in reality
% there are many other "hyper-parameters" -- model choices that I have
% plucked out of thin air, and which could be questioned.
%
% The demo takes no arguments and returns nothing. It draws two figures
% (figure 1: samples from the prior; figure 2: posterior samples under
% each candidate bandwidth) and, when VERBOSE is true, prints a
% commentary on them.
%
% I've written this quickly as a demonstration to accompany a lecture.
% It is not production code. I have also not checked it in any way. It
% is highly likely to contain mistakes!
%
% Iain Murray, November 2017

VERBOSE = true;

x_grid = (0:0.001:1)';     % inputs used for plotting functions "everywhere"
x_train = (0.1:0.1:0.7)';  % not modelled, assumed fixed and known
sigma_y = 0.1;  % standard deviation of Gaussian observation noise
sigma_w = 1;    % standard deviation of zero-mean Gaussian prior on weights
K = 100;        % number of RBF basis functions

% Decide which bandwidths to consider, assign them colors for the second
% plot, and pre-compute input features for each bandwidth.
bandwidths = [0.01, 1];
colors = {[0.7 0.7 0.7], [0.7, 0, 0]};
Phi_h = cell(size(bandwidths));       % features at the training inputs
Phi_grid_h = cell(size(bandwidths));  % features at the plotting grid
for ii = 1:numel(bandwidths)
    hh = bandwidths(ii);
    Phi_h{ii} = rbfs(x_train, K, hh);
    Phi_grid_h{ii} = rbfs(x_grid, K, hh);
end

% FIRST PLOT
if VERBOSE
    fprintf('Figure 1: each subplot gives a sample function from the\n');
    fprintf('prior model, with some synthetic noisy observations.\n');
    fprintf('This figure is meant to give you some idea of what the\n');
    fprintf('different bandwidths imply.\n\n');
end
figure(1); clf; hold on;
num_rows = 3;
num_cols = 4;
num_subplots = num_rows * num_cols;
for ii = 1:num_subplots
    % Pick one of the candidate bandwidths uniformly at random.
    h_idx = ceil(rand() * numel(bandwidths));
    hh = bandwidths(h_idx);
    subplot(num_rows, num_cols, ii); hold on;
    title(sprintf('Sample from prior, basis function width %g', hh));
    % Draw weights from the prior, then the function and noisy data.
    ww = sigma_w * randn(K, 1);
    f_grid = Phi_grid_h{h_idx} * ww;
    f_train = Phi_h{h_idx} * ww;
    y_train = f_train + sigma_y*randn(size(f_train));
    % The colour is given as a 1x3 row RGB triplet.
    plot(x_grid, f_grid, '-', 'LineWidth', 2, 'color', [0.7 0.7 0.7]);
    plot(x_train, y_train, '+', 'MarkerSize', 20, 'LineWidth', 2);
end

% SECOND PLOT
if VERBOSE
    fprintf('Figure 2: each subplot gives some data from one of the\n');
    fprintf('models. In gray are some posterior samples -- plausible\n');
    fprintf('explanations of the data -- assuming the short\n');
    fprintf('bandwidth. In red are some posterior samples assuming\n');
    fprintf('the long bandwidth. The posterior probability of the\n');
    fprintf('short-bandwidth is shown for each dataset.\n\n');
end
figure(2); clf; hold on;
for ii = 1:num_subplots
    subplot(num_rows, num_cols, ii); hold on;

    % Generate synthetic data from a randomly chosen bandwidth. The index
    % has its own name so that it is not confused with the loop variable
    % of the inference loop below.
    true_h_idx = ceil(rand() * numel(bandwidths));
    w_true = sigma_w * randn(K, 1);
    f_train = Phi_h{true_h_idx} * w_true;
    y_train = f_train + sigma_y*randn(size(f_train));

    % Fit the model under every candidate bandwidth, recording the log
    % marginal likelihood and plotting a few posterior function samples.
    Lmls = zeros(size(bandwidths));
    num_samples = 6;
    for h_idx = 1:numel(bandwidths)
        [w_N, chol_inv_V_N, Lml] = blr(Phi_h{h_idx}, y_train, sigma_w, sigma_y);
        Lmls(h_idx) = Lml;
        % With U = chol(inv(V_N)), U\z has covariance inv(U'*U) = V_N
        % when z is standard normal, so these are posterior samples.
        w_samples = bsxfun(@plus, w_N, chol_inv_V_N\randn(K, num_samples));
        f_samples = Phi_grid_h{h_idx} * w_samples;
        for ss = 1:num_samples
            plot(x_grid, f_samples(:,ss), '-', 'LineWidth', 1, 'color', colors{h_idx});
        end
    end

    % Plot data last, so not obscured by lines:
    plot(x_train, y_train, '+', 'MarkerSize', 20, 'LineWidth', 2);

    % Infer which bandwidth was used:
    if numel(bandwidths) == 2
        % Only implemented the special case of two bandwidths. The
        % posterior is a logistic sigmoid of the difference in log
        % marginal likelihoods (no prior term appears, i.e. the two
        % bandwidths are treated as equally probable a priori).
        posterior_h_1 = 1 ./ (1 + exp(-(Lmls(1) - Lmls(2))));
        title(sprintf('p(h = %g) = %0.3e', bandwidths(1), posterior_h_1));
    end
end

if VERBOSE
    fprintf('For data generated with the short bandwidth, the long\n');
    fprintf('bandwidth samples are bunched up despite not matching\n');
    fprintf('the data -- we are too confident. This problem is\n');
    fprintf('partly caused by fixing the noise level sigma_y.\n');
    fprintf('\n');
    fprintf('For data generated with the long bandwidth, the short\n');
    fprintf('bandwidth model ignores the smooth trend, and makes\n');
    fprintf('rapidly fluctuating predictions. Fortunately this model\n');
    fprintf('is deemed to be improbable, so we can ignore it.\n');
    fprintf('\n');
    fprintf('You may have noticed that some of the plots show the\n');
    fprintf('short bandwidth posterior samples reaching out to the\n');
    fprintf('data with a series of spikes. Because the narrow basis\n');
    fprintf('functions don''t overlap as much, larger weights are\n');
    fprintf('needed to get large function values. We could alter the\n');
    fprintf('width of the prior, sigma_w, so that the narrow\n');
    fprintf('bandwidth model is also happy with large function\n');
    fprintf('values.\n');
end


function X = solve_chol(chol_A, B)
% SOLVE_CHOL Solve A*X = B for X (i.e. find inv(A)*B) given chol(A) and B.
%
%   chol_A  upper-triangular Cholesky factor U, with A = U'*U
%   B       right-hand side(s)
%
% As is Matlab/Octave's default, we use upper-triangular Cholesky factors.

% A = U'*U, so inv(A)*B = U \ (U' \ B): two triangular solves.
X = chol_A \ (chol_A' \ B);


function [w_N, chol_inv_V_N, Lml] = blr(X, yy, sigma_w, sigma_y)
% BLR Bayesian linear regression with a spherical zero-mean Gaussian prior.
%
%   X        NxD matrix of input features
%   yy       Nx1 vector of observed outputs
%   sigma_w  standard deviation (width) of the prior on the weights
%   sigma_y  standard deviation of the Gaussian observation noise
%
% Returns:
%   w_N           Dx1 posterior mean of the weights
%   chol_inv_V_N  chol(inv(V_N)), where V_N is the posterior covariance
%   Lml           log marginal likelihood, log p(y|X)
%
% As is Matlab/Octave's default, we use upper-triangular Cholesky factors.

[N, D] = size(X);
% Posterior precision: inv(V_N) = eye(D)/sigma_w^2 + X'*X/sigma_y^2
inv_V_N = ((sigma_y/sigma_w)^2 * eye(D) + X'*X) / sigma_y^2;
chol_inv_V_N = chol(inv_V_N);
% Posterior mean: w_N = V_N * X'*yy / sigma_y^2
w_N = solve_chol(chol_inv_V_N, X'*yy) / sigma_y^2;

% Evaluate p(w), p(y|X,w), p(w|y,X) at w=0. Hence get p(y|X) from
% Bayes' rule: p(y|X) = p(w) p(y|X,w) / p(w|y,X), for any w.
hl2pi = 0.5*log(2*pi);
Lp_w0 = -D*(hl2pi + log(sigma_w));
Lp_y_w0 = -0.5*(yy'*yy)/sigma_y^2 - N*(hl2pi + log(sigma_y));
% With U = chol(inv(V_N)): (U*w_N)'*(U*w_N) = w_N'*inv(V_N)*w_N, and
% sum(log(diag(U))) = 0.5*log(det(inv(V_N))).
U_w_N = chol_inv_V_N * w_N;
Lp_w0_yX = -0.5*(U_w_N'*U_w_N) - D*hl2pi + sum(log(diag(chol_inv_V_N)));
Lml = Lp_w0 + Lp_y_w0 - Lp_w0_yX;


function Phi = rbfs(X, K, hh)
% RBFS Evaluate K radial basis functions of width hh at the inputs X.
%
%   X   Nx1 inputs
%   K   1x1 number of RBFs
%   hh  1x1 width of every RBF
%
% Returns Phi, NxK, with Phi(n,k) = exp(-(X(n) - c_k)^2 / hh^2), where the
% centres c_k are equally spaced between 0 and 1.

assert(size(X, 2) == 1); % Assuming one dimensional inputs.
if K == 1
    % A single basis function sits at 0 (linspace(0, 1, 1) would give 1).
    cc = 0;
else
    % linspace returns exactly K points including both endpoints. The
    % length of a colon range with a fractional step depends on
    % floating-point rounding, and callers need exactly K columns.
    cc = linspace(0, 1, K);
end
Phi = exp(-bsxfun(@minus, X(:), cc).^2 / hh^2);