# -*- coding: utf-8 -*-
"""
Created on Sun Sep  4 17:35:57 2022

@author: Chris
Lab 6, Q4
"""

import numpy as np
import matplotlib.pyplot as plt

def f(x):
    """Return the test integrand 1 + x**3 (applied elementwise to arrays)."""
    return 1 + x**3


def qtrapn(f, a, b, N):
    """Apply the composite trapezoidal rule to f on [a, b] with N subintervals.

    Parameters
    ----------
    f : callable
        Integrand. It is called on a NumPy array of the interior nodes, so it
        must be vectorised; it is also called on the two end nodes.
    a, b : float
        Integration limits.
    N : int
        Number of equal-width subintervals (must be >= 1).

    Returns
    -------
    float
        h/2 * (f(x_0) + 2*sum(f(x_1), ..., f(x_{N-1})) + f(x_N)),
        where h = (b - a)/N and x_j = a + j*h.
    """
    h = (b - a) / N

    # linspace places the last node exactly at b (a + h*N can be off by one
    # rounding step) and builds the grid without a Python-level loop.
    xj = np.linspace(a, b, N + 1)

    # np.sum runs in C with pairwise summation; the builtin sum() would
    # iterate over the array element by element in Python.
    I = h / 2 * (f(xj[0]) + 2 * np.sum(f(xj[1:-1])) + f(xj[-1]))

    return I

def qtrapz0(f, a, b, tol, *, maxp=100):
    """Integrate f over [a, b] by repeated trapezoidal rule, recomputing each level.

    This is the "slow" refinement. At level p the full trapezoidal rule with
    n = 2**p subintervals is evaluated via qtrapn. All function values from
    the previous level are discarded. Each level's n and estimate are printed.

    Parameters
    ----------
    f : callable
        Integrand; passed straight to qtrapn, so it must be vectorised.
    a, b : float
        Integration limits.
    tol : float
        Absolute tolerance on the change between successive estimates.
    maxp : int, optional (keyword-only)
        Number of refinement levels to try (p = 0 .. maxp-1). The default of
        100 is the previous hard-coded value. Each level builds 2**p + 1
        grid points, so memory is exhausted long before p reaches 100.
        Pass a smaller value to reach the NaN fallback instead.

    Returns
    -------
    float
        The first estimate from a level p >= 4 that differs from the previous
        level's estimate by less than tol. Returns NaN (after printing a
        message) if no level within maxp qualifies.
    """
    Iold = 0

    for p in range(maxp):
        n = 2**p
        # "Slow" refinement: redo the whole rule on the finer grid instead of
        # reusing the previous level's function values.
        I = qtrapn(f, a, b, n)

        print(n, I)

        # Only test convergence once at least 16 subintervals are in use, so
        # coarse estimates cannot agree by accident.
        if p > 3 and abs(I - Iold) < tol:
            return I

        Iold = I

    print('qtrapz0: quadrature did not achieve specified tolerance.')
    return float('nan')

def qtrapz(f, a, b, tol, *, maxp=100):
    """Integrate f over [a, b] by repeated trapezoidal rule, reusing old evaluations.

    This is the "fast" refinement. Halving the step reuses the previous
    estimate and evaluates f only at the new midpoints:

        I_new = I_old / 2 + (h / 2) * sum(f(midpoints)),

    where h is the step width before halving. The number of subintervals and
    the estimate are printed for the initial single-interval rule and for
    every refinement.

    Parameters
    ----------
    f : callable
        Integrand. It is called on the scalar end points and on a NumPy array
        of midpoints, so it must be vectorised.
    a, b : float
        Integration limits.
    tol : float
        Absolute tolerance on the change between successive estimates.
    maxp : int, optional (keyword-only)
        Number of refinement steps to try; step p evaluates 2**p midpoints.
        The default of 100 is the previous hard-coded value, which memory
        limits make unreachable in practice. Pass a smaller value to reach
        the NaN fallback.

    Returns
    -------
    float
        The first refined estimate (using at least 32 subintervals) that
        differs from the previous one by less than tol. Returns NaN (after
        printing a message) if none does within maxp steps.
    """
    h = b - a

    I0 = h / 2 * (f(a) + f(b))  # single-interval trapezoidal estimate
    print(1, I0)

    for p in range(maxp):
        # Midpoints of the 2**p current subintervals of width h.
        xj = a + h / 2 + h * np.arange(2**p)

        I = I0 / 2 + h / 2 * np.sum(f(xj))

        # This estimate uses 2**(p+1) subintervals (the old ones plus the
        # newly inserted midpoints).
        print(2**(p + 1), I)

        if p > 3 and abs(I - I0) < tol:
            return I
        h = h / 2
        I0 = I

    print('qtrapz: quadrature did not achieve specified tolerance.')
    return float('nan')

# Integration limits for the Lab 6, Q4 test problem.
a, b = 0, 2

# Refine until two successive trapezoidal estimates agree to within 1e-8.
result = qtrapz0(f, a, b, tol=1e-8)
print('Final result:', result)

