# -*- coding: utf-8 -*-
"""
Created on Tue Aug 16 15:05:18 2022

@author: 102194
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import mysearch as mys

def f(x, y):
    """Return the test objective ``(x - 1)**2 + y**2``.

    Works element-wise when ``x`` and ``y`` are NumPy arrays, because only
    arithmetic operators are used.
    """
    # Alternative, harder test surface kept for experimentation:
    #   ((4 - 2.1*x**2 + x**4 / 3.) * x**2 + x * y
    #    + (-4 + 4*y**2) * y**2)
    return (x - 1)**2 + y**2

def powellsearch(f, x0, tol, *, nmax=100, t0=-5, t1=5, nbracket=50):
    """Minimise ``f`` with a Powell-type direction-set search.

    Each iteration line-minimises along every direction in the current
    set, then performs one extra line search along the net displacement of
    that iteration and replaces the direction that gave the largest
    decrease in ``f`` with that displacement.

    Parameters
    ----------
    f : callable
        Function to minimise. It is called with one positional argument
        per coordinate, i.e. ``f(x, y)`` for the 2-D case.
    x0 : array-like
        Initial search point (1-D). It is copied and never modified.
    tol : float
        Tolerance. Passed to ``mys.golden`` for each line search and used
        as the threshold on the distance moved in one full iteration.
    nmax : int, optional
        Maximum number of outer iterations.
    t0, t1 : float, optional
        Lower and upper limits of the line parameter ``t`` handed to
        ``mys.minbracket`` for every line search.
    nbracket : int, optional
        Fourth argument forwarded to ``mys.minbracket``
        (presumably the number of sample points -- confirm in mysearch).

    Returns
    -------
    numpy.ndarray or None
        The located minimum, or ``None`` when the search could not make
        progress or did not converge within ``nmax`` iterations.
    """

    def line_minimum(origin, direction):
        # Minimise f along origin + t*direction; return t, or None when
        # mys.minbracket reports no bracket inside [t0, t1].
        def fline(t):
            return f(*(origin + t * direction))

        # NOTE(review): assumes minbracket returns a sequence of
        # (lower, upper) pairs -- confirm against mysearch.
        blist = mys.minbracket(fline, t0, t1, nbracket)
        if len(blist) == 0:
            return None
        # Only the first bracket reported is refined.
        return mys.golden(fline, blist[0][0], blist[0][1], tol)

    # Float copy: keeps the caller's array untouched and avoids integer
    # arithmetic when x0 is given as ints.
    xn = np.array(x0, dtype=float)
    xnm = xn  # point at the start of the current iteration
    ndim = xn.size

    # Initial directions are the coordinate axes. The array must be float:
    # an integer array would silently truncate the directions stored below.
    pset = np.eye(ndim)
    deltaf = np.zeros(ndim)

    for _ in range(nmax):
        # Start every iteration with a clean record of per-direction gains.
        deltaf[:] = 0.0
        fprev = f(*xn)
        lost = False

        for j in range(ndim):
            tmin = line_minimum(xn, pset[j])
            if tmin is None:
                print('Minimum lost')
                lost = True
                break
            xn = xn + pset[j] * tmin
            fnew = f(*xn)
            # Decrease achieved by this direction alone, measured from the
            # value reached by the previous line search.
            deltaf[j] = fprev - fnew
            fprev = fnew

        # Net displacement of this iteration; it becomes the new direction.
        step = xn - xnm
        moved = bool(np.any(step != 0))

        if lost and not moved:
            # Nothing changed, so the next iteration would fail identically.
            return None

        if moved:
            # Extra line search along the net displacement. It is skipped
            # when no bracket is found; the point from the loop is kept.
            tmin = line_minimum(xn, step)
            if tmin is not None:
                xn = xn + step * tmin

        print(xn)

        # Check if within tolerance: distance moved in this iteration.
        change = np.sqrt(np.sum((xn - xnm) ** 2))
        if change < tol:
            print('Minimum found at', xn)
            return xn

        if moved:
            # Replace the direction with the largest decrease in f.
            jmax = np.argmax(abs(deltaf))
            pset[jmax] = step

        xnm = xn

    print('No convergence after', nmax, 'iterations')
    return None


    
# Settings for the Powell search test: tolerance and starting point
# (the lower-left corner of the plotted region).
tol = 1e-5
x0 = np.array([-2, -1])

# Sample the objective on a regular 100 x 100 grid for plotting.
x = np.linspace(-2, 2, 100)
y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(x, y)

# Alternative 2-D views of the same data:
#plt.pcolor(X,Y,f(X,Y))
#plt.contour(X,Y,f(X,Y),100)

# Draw the objective as a 3-D surface.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, f(X, Y))

# test powell search:
xf = powellsearch(f, x0, tol)




