# -*- coding: utf-8 -*-
"""
Created on Tue Aug 16 15:05:18 2022

@author: 102194
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import mysearch as mys

def f(x,y):
    """Objective function: a paraboloid whose minimum value 0 is at (1, 0).

    Uses only arithmetic operators, so it works element-wise on numpy arrays
    (e.g. a meshgrid for plotting) as well as on plain scalars.

    An alternative test objective used previously was
    ((4 - 2.1*x**2 + x**4/3) * x**2 + x*y + (-4 + 4*y**2) * y**2).
    """
    shifted_x = x - 1
    return shifted_x**2 + y**2

def fline(t):
    """Evaluate f at the point xn + t*pn on the current search line.

    Takes only the step length t, so it can be handed to the 1-D routines in
    mysearch. The line is defined by the module-level globals xn (current
    point) and pn (current search direction), which the search loop assigns
    before each line search.
    """
    point = xn + pn * t
    return f(point[0], point[1])
    
# Sample the objective on a 100 x 100 grid over [-2, 2] x [-1, 1].
x = np.linspace(-2, 2, 100)
y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)

# Show f as a 3-D surface. For a flat view, plt.pcolor(X, Y, Z) or
# plt.contour(X, Y, Z, 100) can be used instead.
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
surf = ax.plot_surface(X, Y, Z)

NMAX = 10   # maximum number of outer iterations (full sweeps over pset)
NDIM = 2    # dimension of the search space; f takes two coordinates

# Line-search settings; they are the same for every step, so set them once.
t0 = -4     # lower end of the step-length window searched along a direction
t1 = 4      # upper end of that window
tol = 1e-5  # tolerance passed to mys.golden; also the convergence threshold
NBRACKET = 20  # passed to mys.minbracket; presumably the number of samples
               # or subintervals of [t0, t1] -- confirm against mysearch

# initial search point (float so every update keeps the same dtype)
xn = np.array([-2.0, -1.0])
xnm = xn  # point at the start of the current outer iteration

# Search directions: the coordinate axes.
# NOTE(review): pset is never updated below, so this is cyclic coordinate
# search rather than full Powell (conjugate-direction) minimisation.
pset = np.eye(NDIM)

lost = False
for i in range(NMAX):
    # One outer iteration = one line search along each direction in pset.
    for ni in range(NDIM):
        # fline() reads the module-level xn and pn, so pn must be assigned
        # here, before the 1-D routines call fline.
        pn = pset[ni]

        blist = mys.minbracket(fline, t0, t1, NBRACKET)
        if len(blist) == 0:
            lost = True
            break

        # Take the first bracket found; presumably each entry is (a, b) with
        # a minimum of fline inside [a, b] -- confirm against mysearch.
        tmin = mys.golden(fline, blist[0][0], blist[0][1], tol)
        xmin = xn + pn * tmin
        print(xmin)
        xn = xmin

    if lost:
        # A failed bracket must stop the whole search. Otherwise an unchanged
        # xn would pass the convergence test and be reported as the minimum.
        print('Minimum lost')
        break

    change = np.linalg.norm(xn - xnm)
    if change > tol:
        # min not yet found, keep going
        xnm = xn
    else:
        print('Minimum found at', xn)
        break
else:
    # Runs only if neither convergence nor a lost minimum ended the loop.
    print('No convergence after', NMAX, 'iterations; last point', xn)
    




