# -*- coding: utf-8 -*-
"""
Created on Sat Aug  6 12:36:27 2022

@author: Chris

Lab 2 Question 3
"""

import numpy as np
import matplotlib.pyplot as plt

def bisection(f,a,b,tol):
    """Find a zero of ``f`` on the interval [a, b] by bisection.

    Parameters
    ----------
    f : callable
        Scalar function f(x) whose zero is sought.
    a, b : float
        Ends of the initial search interval. For a guaranteed result f(a)
        and f(b) should have opposite signs (the interval brackets a root).
    tol : float
        The search stops once the bracket width |xb - xa| is below ``tol``.

    Returns
    -------
    float
        The midpoint of the final bracket, or a point at which ``f``
        evaluated to exactly zero. Returns NaN (after printing an error)
        if no root is found within NMAX - 1 iterations.

    Notes
    -----
    Progress (iteration number and current bracket) is printed every
    iteration. If the bracket has no sign change, a warning is printed and
    the upper end is pulled to the midpoint as a last-ditch search.
    """
    NMAX = 1000     # maximum iterations for the bisection

    xa = a
    xb = b

    for N in range(1, NMAX):
        print(N,xa,xb)

        fa = f(xa)
        fb = f(xb)

        # An endpoint that is exactly a root gives fa*fb == 0, which would
        # otherwise be mistaken for a lost bracket and never converge.
        if fa == 0:
            print("Final root found at",xa)
            return xa
        if fb == 0:
            print("Final root found at",xb)
            return xb

        # Compare signs rather than the product fa*fb: the product can
        # underflow to 0.0 for very small function values.
        if np.sign(fa)*np.sign(fb) < 0:
            if abs(xa-xb)<tol:
                xf = 0.5*(xa+xb)
                print("Final root found at",xf)
                return xf

            xm = 0.5*(xa+xb)  # do the bisection
            fm = f(xm)

            # If the midpoint is exactly a root, neither half has a sign
            # change, so the bracket could never be updated.
            if fm == 0:
                print("Final root found at",xm)
                return xm

            # Keep whichever half still brackets the sign change.
            if np.sign(fa)*np.sign(fm) < 0:
                xb = xm
            else:
                xa = xm
        else:
            print("Warning: bracket lost. Searching anyway...")
            xb = 0.5*(xa+xb)

    # If we end up here then we've bisected too many times
    # without achieving the tolerance
    print("Error in bisection routine: zero not found to required tolerance.")
    xf = float("NaN")  # assign "not a number" to the final result

    return xf

def golden(f,a,b,tol):
    """Locate a minimum of ``f`` on [a, b] by golden-section search.

    Parameters
    ----------
    f : callable
        Scalar function f(x); assumed unimodal on [a, b] for the result to
        be the global minimum of that interval.
    a, b : float
        Ends of the search interval.
    tol : float
        The search stops once the bracket width |b - a| is at most ``tol``.

    Returns
    -------
    float
        Midpoint of the final bracket. If the initial interval is already
        no wider than ``tol``, this is simply (a + b) / 2.
    """
    invphi = (5**0.5 - 1) / 2  # 1/phi, where phi is the golden ratio

    # Two interior points dividing [a, b] in the golden ratio.
    c = b - (b-a)*invphi
    d = a + (b-a)*invphi
    fc = f(c)
    fd = f(d)

    while abs(b-a) > tol:
        if fd < fc:
            # Minimum lies in [c, b]. Because invphi**2 == 1 - invphi, the
            # old d becomes the new c, so its value can be reused.
            a = c
            c, fc = d, fd
            d = a + (b-a)*invphi
            fd = f(d)
        else:
            # Minimum lies in [a, d]; the old c becomes the new d.
            b = d
            d, fd = c, fc
            c = b - (b-a)*invphi
            fc = f(c)

    return 0.5*(a+b)

def minbracket(f,a,b,N):
    """Find sub-intervals of [a, b] that bracket a local minimum of ``f``.

    The interval is sampled on a uniform grid of N + 1 points
    x_i = a + i*h, with h = (b - a)/N. Whenever an interior point is
    strictly lower than both of its neighbours, that point's neighbours
    [x_{i-1}, x_{i+1}] are reported as a bracket.

    Parameters
    ----------
    f : callable
        Scalar function f(x).
    a, b : float
        Ends of the search interval.
    N : int
        Number of grid steps.

    Returns
    -------
    list of [float, float]
        One ``[x_left, x_right]`` pair per detected local minimum, in
        increasing order of position. Minima at the endpoints a or b, or on
        flat plateaus (ties), are not reported.
    """
    h = (b-a)/N
    xs = [a + i*h for i in range(N+1)]
    # Evaluate f once per grid point; each point is shared by up to three
    # neighbouring triples.
    fs = [f(x) for x in xs]

    return [
        [xs[i-1], xs[i+1]]
        for i in range(1, N)
        if fs[i] < fs[i-1] and fs[i] < fs[i+1]
    ]
        

