#!/usr/bin/python3
import math
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import p_roots #for pulling zeroes of Legendre orthogonal polynomials
from scipy.stats import qmc


def qtrapn(f, a, b, N):
    """Composite trapezoid rule for the integral of f over [a, b] with N equal panels.

    Parameters
    ----------
    f : callable
        Integrand. It is called once on the 1-D array of the N + 1 grid
        points; a scalar return value (e.g. a constant integrand) is
        broadcast to every grid point.
    a, b : float
        Integration limits.
    N : int
        Number of panels; must be >= 1.

    Returns
    -------
    float
        dx * (f(x_0)/2 + f(x_1) + ... + f(x_{N-1}) + f(x_N)/2), with dx = (b - a)/N.

    Raises
    ------
    ValueError
        If N < 1.
    """
    if N < 1:
        raise ValueError(f"qtrapn needs at least one panel, got N={N}")
    x = np.linspace(a, b, N + 1)
    # One vectorized call; broadcasting lets a scalar-valued f contribute at every node.
    y = np.broadcast_to(f(x), x.shape)
    dx = (b - a) / N
    # np.sum runs in C; the builtin sum() iterated the array element by element.
    return dx * ((y[0] + y[-1]) / 2 + np.sum(y[1:-1]))




def qsimpn(f, a, b, N):
    """Composite Simpson's rule for the integral of f over [a, b] with N equal panels.

    Interior nodes with odd index get weight 4 and those with even index get
    weight 2; the endpoints get weight 1. The whole sum is scaled by dx/3.
    This pairing only matches Simpson's rule when N is even. With an odd N,
    the last interior node would silently receive the wrong weight, so odd N
    is rejected.

    Parameters
    ----------
    f : callable
        Integrand. It is called once on the 1-D array of the N + 1 grid
        points; a scalar return value is broadcast to every grid point.
    a, b : float
        Integration limits.
    N : int
        Number of panels; must be a positive even integer.

    Returns
    -------
    float
        The Simpson estimate of the integral.

    Raises
    ------
    ValueError
        If N is not a positive even integer.
    """
    if N < 2 or N % 2:
        raise ValueError(f"Simpson's rule needs a positive even N, got N={N}")
    x = np.linspace(a, b, N + 1)
    y = np.broadcast_to(f(x), x.shape)
    dx = (b - a) / N
    odd = np.sum(y[1:-1:2])   # x_1, x_3, ..., x_{N-1}
    even = np.sum(y[2:-1:2])  # x_2, x_4, ..., x_{N-2}
    return dx / 3 * (y[0] + y[-1] + 4 * odd + 2 * even)




def qtrapn(f, a, b, N):
    """Composite trapezoid rule for the integral of f over [a, b] with N equal panels.

    NOTE(review): qtrapn is also defined earlier in this file. Python rebinds
    the name at each ``def``, so this later definition is the one in effect;
    consider deleting one of the two.

    Parameters
    ----------
    f : callable
        Integrand. It is called once on the 1-D array of the N + 1 grid
        points; a scalar return value (e.g. a constant integrand) is
        broadcast to every grid point.
    a, b : float
        Integration limits.
    N : int
        Number of panels; must be >= 1.

    Returns
    -------
    float
        dx * (f(x_0)/2 + f(x_1) + ... + f(x_{N-1}) + f(x_N)/2), with dx = (b - a)/N.

    Raises
    ------
    ValueError
        If N < 1.
    """
    if N < 1:
        raise ValueError(f"qtrapn needs at least one panel, got N={N}")
    x = np.linspace(a, b, N + 1)
    # One vectorized call; broadcasting lets a scalar-valued f contribute at every node.
    y = np.broadcast_to(f(x), x.shape)
    dx = (b - a) / N
    # np.sum runs in C; the builtin sum() iterated the array element by element.
    return dx * ((y[0] + y[-1]) / 2 + np.sum(y[1:-1]))



def qtrapz(f, a, b, tol, *, maxp=25, verbose=True):  # fast
    """Trapezoid rule with repeated panel halving until two successive estimates agree.

    Refinement p (p = 0, 1, ...) doubles the panel count from 2**p to
    2**(p + 1). It reuses the previous estimate, so f is only evaluated at the
    2**p new midpoints:

        T_new = T_old / 2 + (h / 2) * sum(f(midpoints))

    where h is the panel width before halving.

    Parameters
    ----------
    f : callable
        Integrand. It is called with the scalars a and b, then with 1-D arrays
        of midpoints; a scalar return value is broadcast over the midpoints.
    a, b : float
        Integration limits.
    tol : float
        Absolute tolerance on the change between successive estimates. The
        check only starts at p = 4 (presumably so that an accidental early
        agreement cannot stop the refinement).
    maxp : int, keyword-only
        Maximum number of refinements. Refinement p allocates 2**p points, so
        the former hard-coded cap of 100 could never actually be reached:
        memory runs out long before. The default of 25 bounds the last step at
        2**24 points.
    verbose : bool, keyword-only
        Print ``p`` and the current estimate after each refinement. This
        defaults to True, matching the original always-on output.

    Returns
    -------
    float or None
        The converged estimate. Returns None (after printing
        "tol unreached") if ``maxp`` refinements did not reach ``tol``.
    """
    h = b - a
    prev = h / 2 * (f(a) + f(b))  # single-panel trapezoid estimate
    for p in range(maxp):
        mid = a + h / 2 + h * np.arange(2**p)  # midpoints of the current panels
        fm = np.broadcast_to(f(mid), mid.shape)
        est = prev / 2 + h / 2 * np.sum(fm)
        if verbose:
            print(p, est)
        if p > 3 and abs(est - prev) < tol:
            return est
        h /= 2
        prev = est
    print("tol unreached")
    return None


def qtrapz0(f, a, b, tol, *, maxp=25, verbose=True):  # slow
    """Trapezoid rule with 2**p panels for p = 0, 1, ..., each recomputed from scratch.

    Every step calls qtrapn on a fresh grid and re-evaluates f at points
    already used by earlier steps; qtrapz reuses them instead.

    Parameters
    ----------
    f : callable
        Integrand, passed straight to qtrapn.
    a, b : float
        Integration limits.
    tol : float
        Absolute tolerance on the change between successive estimates. The
        check only starts at p = 4.
    maxp : int, keyword-only
        Maximum number of steps. Step p uses 2**p panels, so the former
        hard-coded cap of 100 could never actually be reached. The default of
        25 bounds the last step at 2**24 panels.
    verbose : bool, keyword-only
        Print ``p`` and the current estimate after each step. This defaults to
        True, matching the original always-on output.

    Returns
    -------
    float or None
        The converged estimate. Returns None (after printing
        "tol unreached") if ``maxp`` steps did not reach ``tol``.
    """
    prev = None
    for p in range(maxp):
        est = qtrapn(f, a, b, 2**p)
        if verbose:
            print(p, est)
        # p > 3 guarantees prev already holds a real earlier estimate.
        if p > 3 and abs(est - prev) < tol:
            return est
        prev = est
    print("tol unreached")
    return None





def quadz(f, a, b, tol, *, maxn=100, verbose=True):
    """Integrate f over [a, b] with Gauss–Legendre rules of increasing order.

    For n = 1, 2, ..., maxn - 1, the n-point Legendre nodes and weights on
    [-1, 1] are mapped linearly onto [a, b]. The function returns the first
    estimate that differs from the previous order's estimate by less than tol.

    Parameters
    ----------
    f : callable
        Integrand. It is called with the 1-D array of mapped nodes for each
        order.
    a, b : float
        Integration limits.
    tol : float
        Absolute tolerance on the change between consecutive orders.
    maxn : int, keyword-only
        Orders 1 .. maxn - 1 are tried (default 100, as before).
    verbose : bool, keyword-only
        Print ``n`` and the estimate for each order. This defaults to True,
        matching the original always-on output.

    Returns
    -------
    float or None
        The converged estimate. Returns None (after printing
        "tol unreached") if no two consecutive orders agree within ``tol``.
    """
    # roots_legendre is the current SciPy name; p_roots is its legacy alias.
    from scipy.special import roots_legendre

    half, center = (b - a) / 2, (a + b) / 2
    prev = None
    for n in range(1, maxn):
        x, w = roots_legendre(n)
        est = half * np.sum(w * f(half * x + center))
        if verbose:
            print(n, est)
        # Compare only against a genuine previous estimate. Seeding the
        # comparison with 0 made any integrand that vanishes at the midpoint
        # (e.g. x**2 on [-1, 1]) "converge" to 0 after a single node.
        if prev is not None and abs(prev - est) < tol:
            return est
        prev = est
    print("tol unreached")
    return None



def monty2d(f, region, a, b, M):
    """Quasi-Monte-Carlo estimate of the integral of f over a region inside [a, b] x [a, b].

    The function draws 2**M points of a Sobol sequence, scaled onto the square
    [a, b]^2. It averages f(x, y) * region(x, y) over those points and
    multiplies by the square's area.

    Parameters
    ----------
    f : callable
        f(x, y) evaluated on 1-D arrays of sample coordinates.
    region : callable
        region(x, y), multiplied pointwise with f. Presumably an indicator
        (bool or 0/1) selecting the integration region; confirm with callers.
    a, b : float
        Bounds of the sampling square along both axes.
    M : int
        Base-2 exponent of the sample count (2**M points).

    Returns
    -------
    float
        (b - a)**2 * mean(f * region) over the samples. SciPy's ``qmc.Sobol``
        scrambles by default, so results vary slightly between calls.
    """
    sampler = qmc.Sobol(d=2)
    pts = a + (b - a) * sampler.random_base2(M)
    xp, yp = pts[:, 0], pts[:, 1]
    # Broadcast so a scalar-valued f * region still counts once per sample.
    fr = np.broadcast_to(f(xp, yp) * region(xp, yp), xp.shape)
    return (b - a) ** 2 * np.mean(fr)



def monty3d(f, region, a, b, M):
    """Quasi-Monte-Carlo estimate of the integral of f over a region inside the cube [a, b]^3.

    The function draws 2**M points of a Sobol sequence, scaled onto [a, b]^3.
    It averages f(x, y, z) * region(x, y, z) over those points and multiplies
    by the cube's volume.

    Parameters
    ----------
    f : callable
        f(x, y, z) evaluated on 1-D arrays of sample coordinates.
    region : callable
        region(x, y, z), multiplied pointwise with f. Presumably an indicator
        (bool or 0/1) selecting the integration region; confirm with callers.
    a, b : float
        Bounds of the sampling cube along every axis.
    M : int
        Base-2 exponent of the sample count (2**M points).

    Returns
    -------
    float
        (b - a)**3 * mean(f * region) over the samples. SciPy's ``qmc.Sobol``
        scrambles by default, so results vary slightly between calls.
    """
    sampler = qmc.Sobol(d=3)
    pts = a + (b - a) * sampler.random_base2(M)
    xp, yp, zp = pts[:, 0], pts[:, 1], pts[:, 2]
    # Broadcast so a scalar-valued f * region still counts once per sample.
    fr = np.broadcast_to(f(xp, yp, zp) * region(xp, yp, zp), xp.shape)
    return (b - a) ** 3 * np.mean(fr)


def nestquad(f, a, b, c, d):
    """Integrate f(x, y) over the rectangle [a, b] x [c, d] as an iterated integral.

    Both the inner integral over y (for a fixed x) and the outer integral over
    x are computed with quadz at tolerance 1e-15.

    Parameters
    ----------
    f : callable
        f(x, y), called with x a single outer node and y a 1-D array of inner
        nodes.
    a, b : float
        Outer (x) integration limits.
    c, d : float
        Inner (y) integration limits.

    Returns
    -------
    float or None
        The iterated integral, or None if the outer quadz does not converge.

    Raises
    ------
    RuntimeError
        If the inner integral fails to converge at some x node.
    """
    def yint(x):
        val = quadz(lambda y: f(x, y), c, d, 1e-15)
        if val is None:
            raise RuntimeError(f"nestquad: inner integral did not converge at x={x}")
        return val

    def f2(xs):
        # quadz hands its integrand the whole array of x nodes at once. Passing
        # that array straight into f(x, y) would pair x nodes with the inner y
        # nodes elementwise, or fail to broadcast once their counts differ.
        # Compute one inner integral per x node instead.
        return np.array([yint(x) for x in np.atleast_1d(xs)])

    return quadz(f2, a, b, 1e-15)
