#!/usr/bin/env python

# Run logistic regression training for different learning rates with stochastic gradient descent.

import numpy as np
import scipy.special as sps
import matplotlib.pyplot as plt
import assignment2 as a2

#execfile("logistic_regression_mod.py")

# Stopping criteria: train for at most max_iter epochs, stopping early once the
# error changes by less than tol between two consecutive epochs.
max_iter, tol = 500, 1e-5

# Candidate step sizes (learning rates) for stochastic gradient descent,
# listed from largest to smallest.
etas = [0.5, 0.3, 0.1, 0.05, 0.01]



# Load the samples and shuffle their row order (no random seed is set here,
# so the order differs from run to run).
data = np.random.permutation(np.genfromtxt('data.txt'))

# Inputs are the first three columns, targets the fourth.
# NOTE(review): the inputs are presumably expected to include a constant-1
# bias column -- confirm its position in data.txt.
X = data[:, :3]
# Class labels: 0 for class 1, 1 for class 2.
t = data[:, 3]

# Index tuples and the matching input rows for each class, kept for plotting.
class1 = np.where(t == 0)
class2 = np.where(t == 1)
X1 = X[class1]
X2 = X[class2]

# Number of training samples.
n_train = t.size

def _mean_nll(X, t, w):
    """Return the mean negative log-likelihood (cross-entropy, Eqn 4.90) of w on (X, t)."""
    a = np.dot(X, w)
    # -log(sigmoid(a)) == logaddexp(0, -a) and -log(1 - sigmoid(a)) == logaddexp(0, a).
    # Evaluating it this way stays finite when sigmoid(a) rounds to exactly 0 or 1,
    # where np.log(y) / np.log(1 - y) would produce inf (and 0 * -inf = nan).
    return np.mean(t * np.logaddexp(0, -a) + (1 - t) * np.logaddexp(0, a))


def _train_sgd(X, t, eta, max_iter, tol):
    """Fit logistic-regression weights with stochastic gradient descent.

    Starting from w = [0.1, 0, 0], each epoch makes one pass over the rows of X
    in their current order and updates w after every sample. After each epoch
    the mean negative log-likelihood over all of (X, t) is recorded and printed.

    Parameters: X -- input matrix, one sample per row (3 columns, matching w);
    t -- 0/1 targets, one per row of X; eta -- step size; max_iter -- maximum
    number of epochs; tol -- stop early once the error changes by less than
    this between consecutive epochs.

    Returns the list of per-epoch errors.
    """
    w = np.array([0.1, 0, 0])
    e_all = []

    for epoch in range(max_iter):
        for x_n, t_n in zip(X, t):
            # Model output for sample x_n under the current weights.
            y_n = sps.expit(np.dot(x_n, w))
            # Per-sample gradient of the cross-entropy error: (y_n - t_n) * x_n.
            grad_e = (y_n - t_n) * x_n
            # Step *against* the gradient, since we are minimizing the error.
            w = w - eta * grad_e

        e = _mean_nll(X, t, w)
        e_all.append(e)

        print('eta={0}, epoch {1:d}, negative log-likelihood {2:.4f}, w={3}'.format(eta, epoch, e, w.T))

        # Converged: the error changed by less than tol since the previous epoch.
        if epoch > 0 and abs(e - e_all[-2]) < tol:
            break

    return e_all


# Per-epoch error curves, keyed by step size.
all_errors = dict()
for eta in etas:
    all_errors[eta] = _train_sgd(X, t, eta, max_iter, tol)

    
# Draw the per-epoch error curve of every step size on a single figure,
# ordered by increasing eta.
plt.figure(10)
plt.rcParams['font.size'] = 15
for step, curve in sorted(all_errors.items()):
    plt.plot(curve, label='sgd eta={}'.format(step))

plt.xlabel('Epoch')
plt.ylabel('Negative log likelihood')
plt.title('Training logistic regression with SGD')
# Fixed view window: the full epoch range on x, NLL between 0.2 and 0.7 on y.
plt.axis([0, max_iter, 0.2, 0.7])
plt.legend()
plt.show()
