% Adapted from an implementation of direct single shooting by J. Andersson, 2016

import casadi.*

%% Preliminaries
% Discretization of the optimal control horizon
N = 100;  % how many control intervals the horizon is split into
T = 10;   % length of the time horizon

% Symbolic model variables: two scalar states stacked into x, one control u
x1 = SX.sym('x1');
x2 = SX.sym('x2');
x  = [x1; x2];
u  = SX.sym('u');

% Right-hand side of the ODE dx/dt = xdot(x, u)
% (NOTE(review): looks like a controlled Van der Pol oscillator — confirm)
xdot = [(1-x2^2)*x1 - x2 + u; ...
        x1];

% Integrand of the running cost: quadratic penalty on both states and u
Jt = x1^2 + x2^2 + u^2;

%% Integrate dynamics
% Discretize with a single explicit (forward) Euler step per control interval
dt = T/N;

% Continuous-time model wrapped as a CasADi Function: (x, u) -> (xdot, Jt)
f = Function('f', {x, u}, {xdot, Jt});

% Symbolic arguments of the discrete-time map: the state at the start of an
% interval and the control value applied over that interval
X0 = MX.sym('X0', 2);
U  = MX.sym('U');

% Euler update of the state; the running cost is sampled at the interval
% start and multiplied by the step length
[Xdot, Jk] = f(X0, U);
X = X0 + dt*Xdot;
Q = dt*Jk;

% Discrete-time map with named inputs (x0, p) and named outputs (xf, qf)
F = Function('F', {X0, U}, {X, Q}, {'x0','p'}, {'xf', 'qf'});

% Sanity check: evaluate the map at a fixed test point and print the results
Fk = F('x0',[0.2; 0.3],'p',0.4);
disp(Fk.xf)
disp(Fk.qf)

%% Formulate the NLP
% Single shooting: the only decision variables are the N control values;
% the states are obtained by symbolically propagating the initial state
% through the discrete-time map F.

% Bounds used in the NLP
u_min  = -1;     % lower bound on the control
u_max  =  1;     % upper bound on the control
x1_min = -0.25;  % lower bound on x1 (no upper bound)

% Preallocate instead of growing the containers inside the loop. All bound
% and initial-guess vectors are columns, matching vertcat(w{:}) / vertcat(g{:}).
w   = cell(N, 1);          % symbolic controls U_0 ... U_{N-1}
w0  = zeros(N, 1);         % initial guess for the controls: all zeros
lbw = u_min*ones(N, 1);    % control is bounded between u_min and u_max
ubw = u_max*ones(N, 1);
J   = 0;                   % objective, accumulated interval by interval
g   = cell(N, 1);          % one constraint on x1 at the end of each interval
lbg = x1_min*ones(N, 1);   % x1 is bounded between x1_min and infinity
ubg = inf(N, 1);

% Initial state, propagated symbolically through the horizon
Xk = [0; 1];
for k = 0:N-1
  % New NLP variable for the control on interval k
  Uk = MX.sym(['U_' num2str(k)]);
  w{k+1} = Uk;

  % Integrate till the end of the interval
  Fk = F('x0', Xk, 'p', Uk);
  Xk = Fk.xf;
  J = J + Fk.qf;

  % Inequality constraint on x1 at the end of the interval
  g{k+1} = Xk(1);
end

%% Solve the NLP
% Create an NLP solver
prob = struct( ...        % Problem struct
  'f', J, ...             % Objective function
  'x', vertcat(w{:}), ... % Decision variables (control)
  'g', vertcat(g{:}));    % Constraints

% Use the IPOPT solver (interior point optimizer)
solver = nlpsol('solver', 'ipopt', prob);

% Call the solver
sol = solver( ...
  'x0', w0, ...    % Initial guess
  'lbx', lbw, ...  % Bounds for the decision variables (control)
  'ubx', ubw, ...
  'lbg', lbg, ...  % Bounds for the constraints
  'ubg', ubg);

% Don't silently use (and plot) the result of a failed solve. The field
% check keeps this working on CasADi versions whose stats lack 'success'.
stats = solver.stats();
if isfield(stats, 'success') && ~stats.success
  warning('NLP solver did not report success (return status: %s); the solution may be invalid.', ...
          stats.return_status);
end

w_opt = full(sol.x);

%% Plot the solution
% Re-simulate with the optimal controls to recover the state trajectory
% at the N+1 grid points t_0 ... t_N
u_opt = w_opt;
x_opt = zeros(2, N+1);   % preallocated; column k holds the state at tgrid(k)
x_opt(:, 1) = [0; 1];    % must match the initial state used in the NLP
for k = 1:N
  Fk = F('x0', x_opt(:, k), 'p', u_opt(k));
  x_opt(:, k+1) = full(Fk.xf);
end
x1_opt = x_opt(1,:);
x2_opt = x_opt(2,:);
tgrid = linspace(0, T, N+1);
clf;
hold on
plot(tgrid, x1_opt, '--')
plot(tgrid, x2_opt, '-')
% stairs draws u_opt(k) from tgrid(k) to tgrid(k+1); the trailing NaN pads
% the control to N+1 samples without drawing an extra step at the end
stairs(tgrid, [u_opt; nan], '-.')
xlabel('t')
legend('x1','x2','u')
hold off
