# -*- coding: utf-8 -*-
"""
Created on Sat Aug  6 12:36:27 2022

@author: Chris

Lab 2 Question 3
"""

import numpy as np
import matplotlib.pyplot as plt

def bisection(f,a,b,tol):
    """Find a zero of f on the interval [a, b] by bisection.

    Parameters
    ----------
    f : callable
        Scalar function of one variable.
    a, b : float
        Initial interval end points; f(a) and f(b) should differ in sign.
    tol : float
        Once the bracketing interval is narrower than tol, its midpoint
        is returned.

    Returns
    -------
    float
        The root estimate. Returns an end point or midpoint directly if
        f is exactly zero there. Returns NaN, with an error message
        printed, if the tolerance is not reached within NMAX bisections.
    """
    NMAX = 1000     # maximum iterations for the bisection

    xa = a
    xb = b

    for N in range(1, NMAX + 1):
        print(N,xa,xb)

        fa = f(xa)
        fb = f(xb)

        # An end point that is exactly a root ends the search. Otherwise
        # fa*fb == 0 would be treated as a lost bracket until NMAX.
        if fa == 0:
            print("Final root found at",xa)
            return xa
        if fb == 0:
            print("Final root found at",xb)
            return xb

        if fa*fb<0:
            if abs(xa-xb)<tol:
                xf = 0.5*(xa+xb)
                print("Final root found at",xf)
                return xf

            xm = 0.5*(xa+xb) # do the bisection
            fm = f(xm)
            if fm == 0:
                # The midpoint is exactly a root. Without this check neither
                # rebracket branch applies and the interval never shrinks.
                print("Final root found at",xm)
                return xm

            # Keep the half whose end points still differ in sign. Comparing
            # signs instead of the product fa*fm avoids underflow to 0.
            if (fa < 0) != (fm < 0):
                xb = xm
            else:
                xa = xm
        else:
            print("Warning: bracket lost. Searching anyway...")
            xb = 0.5*(xa+xb)

    # If we end up here then we've bisected NMAX times
    # without achieving the tolerance
    print("Error in bisection routine: zero not found to required tolerance.")
    return float("NaN") # "not a number" signals failure

def golden(f,a,b,tol):
    """Minimise a unimodal function f on [a, b] by golden-section search.

    Parameters
    ----------
    f : callable
        Scalar function of one variable, assumed to have a single minimum
        inside [a, b].
    a, b : float
        End points of the search interval.
    tol : float
        The search stops once the bracketing interval is narrower than tol.

    Returns
    -------
    float
        Midpoint of the final bracketing interval. If the interval is
        already narrower than tol, this is simply (a+b)/2.
    """
    invphi = (5**0.5 - 1) / 2   # 1/phi, which also equals phi - 1

    # Two interior points dividing [a, b] in the golden ratio.
    c = b - (b-a)*invphi
    d = a + (b-a)*invphi
    fc = f(c)
    fd = f(d)

    while abs(b-a)>tol:
        if fd<fc:
            # The minimum lies in [c, b]. Because of the golden ratio the
            # old d is exactly the new c, so only one new evaluation is needed.
            a, c, fc = c, d, fd
            d = a + (b-a)*invphi
            fd = f(d)
        else:
            # The minimum lies in [a, d], and the old c becomes the new d.
            b, d, fd = d, c, fc
            c = b - (b-a)*invphi
            fc = f(c)

    # The midpoint is defined even when the loop never runs (|b-a| <= tol).
    return 0.5*(a+b)

def minbracket(f,a,b,N):
    """Locate brackets around the local minima of f on [a, b].

    The interval is split into N equal steps. Every interior grid point
    whose value is strictly lower than both neighbours yields one bracket.

    Parameters
    ----------
    f : callable
        Scalar function of one variable.
    a, b : float
        End points of the scan.
    N : int
        Number of grid steps.

    Returns
    -------
    list of [float, float]
        One [left, right] pair per local minimum found, ordered from a
        towards b. The list is empty if N < 2 (no interior grid point)
        or if no minimum is found.
    """
    if N < 2:
        return []

    h = (b-a)/N
    xs = [a + i*h for i in range(N + 1)]
    # Evaluate f once per grid point and reuse the values for every
    # neighbour comparison.
    fs = [f(x) for x in xs]

    return [
        [xs[i-1], xs[i+1]]
        for i in range(1, N)
        if fs[i] < fs[i-1] and fs[i] < fs[i+1]
    ]
        








# -*- coding: utf-8 -*-
"""
Created on Tue Aug 16 15:05:18 2022

@author: 102194
"""
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import mysearch as mys

def f(x,y):
    """Objective for the Powell search test.

    A paraboloid whose minimum value, 0, is at (x, y) = (1, 0).
    """
    # Alternative multi-minimum test case (six-hump camelback), kept
    # for reference:
    #   ((4 - 2.1*x**2 + x**4/3.)*x**2 + x*y + (-4 + 4*y**2)*y**2)
    dx = x - 1
    return dx**2 + y**2

def powellsearch(f,x0,tol):
    """Minimise a two-variable function f(x, y) with Powell's direction-set method.

    Each iteration line-minimises f along every direction in the current
    set. It then line-minimises along the net displacement of that sweep
    and replaces the direction that gave the largest decrease in f with
    that displacement.

    Parameters
    ----------
    f : callable
        Objective taking two scalars, f(x, y).
    x0 : sequence of two floats
        Starting point.
    tol : float
        Tolerance for each golden-section line search. Also the convergence
        threshold on the distance moved during one sweep.

    Returns
    -------
    numpy.ndarray
        The estimated minimum. If NMAX iterations pass without convergence,
        a warning is printed and the latest estimate is returned.
    """
    #parameters for search
    NMAX = 100 # max iterations
    NDIM = 2   # dimension (fixed at 2 because f takes (x, y))
    t0 = -5    # minimum search distance along each direction
    t1 = 5     # maximum search distance along each direction
    NGRID = 50 # grid steps used by minbracket along each line

    def line_minimize(xbase, p):
        # Minimise f along xbase + t*p for t in [t0, t1]. Returns the new
        # point. Returns xbase itself if no bracketed minimum is lower than
        # f(xbase), and None if minbracket finds no minimum at all.
        def fline(t):
            x = xbase + t*p
            return f(x[0],x[1])

        blist = mys.minbracket(fline,t0,t1,NGRID)
        if not blist:
            return None
        # Refine every bracketed minimum and keep the lowest, instead of
        # always taking the first (left-most) bracket.
        tbest = min((mys.golden(fline,lo,hi,tol) for lo,hi in blist), key=fline)
        if fline(tbest) >= fline(0.0):
            return xbase   # never accept an uphill move
        return xbase + tbest*p

    # Float array copy: a list x0 would break the vector arithmetic below.
    xn = np.asarray(x0, dtype=float)
    # Initial conjugate directions. The float dtype matters: with an int
    # array, assigning a new direction below would truncate it to integers.
    pset = np.eye(NDIM)

    for i in range(0,NMAX):
        xstart = xn
        deltaf = np.zeros(NDIM)   # decrease in f along each direction this sweep
        for j in range(0,NDIM):
            # Search along each direction in the set. The decrease is measured
            # from the point where this line search started.
            fbefore = f(xn[0],xn[1])
            xnew = line_minimize(xn, pset[j])
            if xnew is None:
                print('Minimum lost')
                continue   # still try the remaining directions
            deltaf[j] = fbefore - f(xnew[0],xnew[1])
            xn = xnew

        # Check the tolerance BEFORE searching along the displacement,
        # because the displacement is (near) zero-length at convergence.
        pn = xn - xstart
        change = np.linalg.norm(pn)
        if change<tol:
            # NOTE(review): if every line search above was "lost", xn did not
            # move and this branch also reports convergence.
            print('Minimum found at',xn)
            return xn

        # Final search along the net displacement, then move there.
        xnew = line_minimize(xn, pn)
        if xnew is not None:
            xn = xnew
        print(xn)

        # Replace the direction that gave the largest decrease with the
        # current displacement.
        pset[np.argmax(deltaf)] = pn

    print('Warning: minimum not found to required tolerance after',NMAX,'iterations')
    return xn


