function l9_least_squares()
% L9_LEAST_SQUARES  Demo of fitting a line y = m*x + b by least squares.
%   Generates n noisy samples of the line y = 0.5*x + 2 and shows:
%     Figure 1 (top):    a hand-picked fit theta = [0.75; 1.9];
%     Figure 1 (bottom): the analytic least-squares fit obtained from the
%                        normal equations (X'X) theta = X'Y;
%     Figure 2:          gradient descent on the squared error, redrawn
%                        every iteration, with the path of iterates drawn
%                        over contours of the error surface.
%   Throughout, theta(1) is the slope m and theta(2) the intercept b.
%   Sample positions and noise come from rand, so each run differs.

% Data points: x uniform on [-1, 1], additive noise uniform on [-0.25, 0.25].
n = 25;

x = -1 + 2*rand(n,1);
m = 0.5;
b = 2;
Y = m*x + b;
X = [x ones(n,1)];      % design matrix: X*theta = theta(1)*x + theta(2)

noise = -0.25 + 0.5*rand(n,1);
Y = Y + noise;

xs = -1:0.01:1;         % abscissae used to draw fitted lines

% Bad solution
f = figure;
f.Position = [100 100 600 900];
subplot(2,1,1)
plot_fit(x, Y, xs, [0.75; 1.9]);

% Analytic optimal solution of the normal equations (X'X) theta = X'Y.
% Parenthesised so the 2x2 system is solved directly instead of first
% forming the 2-by-n matrix (X'*X)\X'. Reused below as the gradient
% descent target.
theta_opt = (X'*X) \ (X'*Y);
subplot(2,1,2)
plot_fit(x, Y, xs, theta_opt);

% Gradient descent on J(theta) = ||Y - X*theta||^2, whose gradient is
% 2*X'*X*theta - 2*X'*Y.
f2 = figure;
f2.Position = [100 100 1200 500];
theta_last = [1; -2];   % starting point
thetas = theta_last;    % history of iterates for the trajectory plot
dtheta = inf;           % guarantees the first step is taken
maxIter = 25;
tol = 0.0001;           % stop once a step is no longer than this

% Squared-error surface over a grid of candidate (m, b) values, for the
% contour plot. ndgrid layout: ms(i,j) = Ms(i), bs(i,j) = Bs(j).
Ms = linspace(-5, 5, 100);
Bs = linspace(-5, 5, 100);
[ms, bs] = ndgrid(Ms, Bs);
R = Y - X*[ms(:).'; bs(:).'];           % residuals, one column per grid point
err = reshape(sum(R.^2, 1), size(ms));  % err(i,j) = J([Ms(i); Bs(j)])
levels = logspace(-3, 3, 25);

% The gradient only depends on these two products; compute them once.
XtX = X'*X;
XtY = X'*Y;

for k = 1:maxIter
    if norm(dtheta) <= tol
        break;
    end

    % Diminishing step size (an alternative is alpha = 0.05/k^2).
    alpha = 0.05/k;
    theta = theta_last - alpha*(2*XtX*theta_last - 2*XtY);
    dtheta = theta_last - theta;
    theta_last = theta;

    % Left panel: current fitted line and its residuals.
    subplot(1,2,1)
    plot_fit(x, Y, xs, theta);
    hold off

    % Right panel: path of iterates over the error contours, with the
    % analytic optimum marked by a red cross.
    subplot(1,2,2)
    % At most maxIter+1 columns, so growing the array is harmless.
    thetas = [thetas theta]; %#ok<AGROW>
    plot(thetas(1,:), thetas(2,:), 'k.-', 'markersize', 15)
    hold on
    plot(theta_opt(1), theta_opt(2), 'rx', 'markersize', 10)
    contour(ms, bs, err, levels)
    grid on
    xlabel('m')
    ylabel('b')
    xlim([-5 5])
    ylim([-5 5])
    drawnow;

    hold off
    % export_fig(sprintf("LS_demo_fig/%d", k), '-png')
end

end

function plot_fit(x, Y, xs, theta)
% PLOT_FIT  In the current axes, draw the data (black dots), the line
%   theta(1)*xs + theta(2) (blue) and the residual segments, add a grid,
%   and title the axes with the two parameter values. Leaves hold on.
plot(x, Y, 'k.', 'markersize', 20);
hold on
plot(xs, theta(1)*xs + theta(2), 'b-', 'linewidth', 2.5);
plot_err(x, Y, theta);
grid on
title(sprintf("theta(1) = %.2f, theta(2) = %.2f", theta(1), theta(2)))
end

function he = plot_err(x,Y,theta)
% PLOT_ERR  Draw the residual of each data point against a fitted line.
%   he = plot_err(x, Y, theta) draws, in the current axes, a red vertical
%   segment from each point (x(i), Y(i)) to the line value
%   theta(1)*x(i) + theta(2).
%     x, Y   - data coordinates, vectors of equal length (row or column)
%     theta  - [slope; intercept] of the line
%     he     - handles of the drawn line objects, one per data point
%              (empty when there are no points)
%   All segments are drawn in a single plot call. Set "hold on" beforehand
%   if the existing contents of the axes should be kept.
xr = x(:).';
if isempty(xr)
    he = [];
    return
end
yhat = theta(1)*xr + theta(2);
% Each column of the 2-by-n coordinate matrices is one segment, so plot
% creates one line object per data point.
he = plot([xr; xr], [Y(:).'; yhat], 'r-');
end