# -*- coding: utf-8 -*-
"""
Created on Tue Aug 16 15:05:18 2022

@author: 102194
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import mysearch as mys

def f(x,y):
    """Objective function: squared distance of (x, y) from the point (1, 0).

    Works on scalars and, element-wise, on arrays that support arithmetic
    (it is called with meshgrid arrays for plotting).
    """
    # Alternative objective kept from the original source for experiments:
    #   ((4 - 2.1*x**2 + x**4 / 3.) * x**2 + x * y
    #    + (-4 + 4*y**2) * y **2)
    x_term = (x - 1)**2
    y_term = y**2
    return x_term + y_term

def fline(t):
    """Evaluate f at the point xn + t*pn on the current search line.

    xn (base point) and pn (search direction) are read from module scope,
    so they must be set before this function is called.
    """
    point = xn + t*pn
    return f(point[0], point[1])
    
# Sample the objective on a regular 100 x 100 grid for visualisation.
x = np.linspace(-2, 2, num=100)
y = np.linspace(-1, 1, num=100)
X, Y = np.meshgrid(x, y)

# Alternative 2D views of the same data (disabled):
#plt.pcolor(X,Y,f(X,Y))
#plt.contour(X,Y,f(X,Y),100)

# Draw f as a 3D surface on a single set of axes.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
surf = ax.plot_surface(X, Y, f(X, Y))

NMAX = 10  # maximum number of outer iterations
NDIM = 2   # number of search directions (one per coordinate of xn)

# Interval of the line parameter t handed to mys.minbracket, and the
# tolerance used both by mys.golden and by the convergence test below.
t0 = -4
t1 = 4
tol = 1e-5


deltaf = np.zeros(NDIM)  # decrease in f obtained along each direction
# initial search point (float, so every later update keeps one dtype)
xn = np.array([-2.0, -1.0])
xnm = xn  # point at the start of the current iteration

# initial conjugate directions; the array must be float, otherwise storing
# a new direction in a row would truncate its components to integers
pset = np.array([[1.0, 0.0], [0.0, 1.0]])

for i in range(NMAX):
    # loop over each iteration
    fprev = f(xn[0], xn[1])  # function value before the next line search
    lost = False
    for j in range(NDIM):
        # Do a search along each direction in the set
        pn = pset[j]  # fline reads the module-level xn and pn
        # NOTE(review): minbracket is assumed to return a list of (lo, hi)
        # brackets that is empty when none is found -- confirm in mysearch.
        blist = mys.minbracket(fline, t0, t1, 50)

        if len(blist) == 0:
            print('Minimum lost')
            lost = True
            break

        tmin = mys.golden(fline, blist[0][0], blist[0][1], tol)
        xmin = xn + pn*tmin
        fmin = fline(tmin)

        # decrease produced by THIS direction only; measuring against the
        # value at the start of the iteration would make the last
        # direction always look like the largest contributor
        deltaf[j] = fprev - fmin
        fprev = fmin
        print(xmin, pn)

        xn = xmin
        input()

    if lost:
        # No bracket along one of the directions: stop the whole search
        # instead of continuing with an incomplete sweep.
        break

    # do a final search along the net displacement of this iteration
    pn = xn - xnm
    if np.any(pn != 0):
        # a zero displacement defines no line, so there is nothing to search
        blist = mys.minbracket(fline, t0, t1, 50)
        if len(blist) > 0:
            tmin = mys.golden(fline, blist[0][0], blist[0][1], tol)
            # actually move to the minimum found along the new direction
            xn = xn + pn*tmin

    print(xn)
    input()

    # check if within tolerance: distance moved during this iteration
    change = np.linalg.norm(xn - xnm)
    if change < tol:
        print('Minimum found at', xn)
        break

    # if not, update the direction set:
    jmax = np.argmax(np.abs(deltaf))  # direction with the largest change in f

    pset[jmax] = pn  # replace it with the net displacement direction

    xnm = xn  # the current point becomes the start of the next iteration





