#!/usr/bin/python3
import numpy as np
#import matplotlib.pyplot as plt
#import matplotlib as mpl
#from matplotlib import cm
import math as math
import time
#from mpl_toolkits.mplot3d import Axes3D  
from scipy.special import p_roots #for pulling zeroes of Legendre orthogonal polynomials


def diff(f,h,z): #numerical first derivative of f at z
    """Approximate f'(z) with the five-point central-difference stencil.

    f is sampled at z+2h, z+h, z-h and z-2h and the samples are combined
    with the weights (-1, 8, -8, 1)/(12h); the truncation error is O(h^4).

    Parameters
    ----------
    f : callable taking z and returning a value that supports + - * /
    h : step between stencil points
    z : point at which the derivative is estimated
    """
    outer_fwd = f(z+2*h)
    inner_fwd = f(z+h)
    inner_bwd = f(z-h)
    outer_bwd = f(z-2*h)
    weighted = -outer_fwd + 8*inner_fwd - 8*inner_bwd + outer_bwd
    return weighted/(12*h)


def trapquad(f,a,b,tol,maxn=25): # Trapezoidal quadrature
    """Integrate f along the straight line from a to b by successive halving.

    Starts from the one-panel trapezoid estimate and repeatedly doubles the
    number of panels, reusing the previous estimate so that only the new
    midpoints are evaluated. Refinement stops once two successive estimates
    differ by less than tol (at least two refinements are always performed).

    Parameters
    ----------
    f    : integrand; called with scalars (the end points) and with a 1-D
           NumPy array of nodes, for which it should return one value per
           node (a scalar return is treated as a constant integrand)
    a, b : end points; may be real or complex
    tol  : absolute difference between successive estimates that counts as
           converged
    maxn : maximum number of refinement levels. Level p evaluates 2**p new
           nodes, so the cost grows exponentially; the default keeps the
           largest node array at 2**24 entries.

    Returns
    -------
    The converged estimate, or None (after printing a message) when maxn
    levels were not enough.
    """
    h=b-a
    I0=(h/2)*(f(a)+f(b))
    for p in range(maxn):
        # New nodes are the midpoints of the current panels.
        x = a + h/2 + h*np.arange(2**p)
        # broadcast_to lets a scalar-returning (constant) integrand count
        # once per node instead of once in total.
        fx = np.broadcast_to(f(x), x.shape)
        I = I0/2 + (h/2)*fx.sum()
        if p>=1 and abs(I0-I)<tol:
            return I
        h=h/2
        I0=I
    print("Error occured: tolerance unreached")
    return None

def glquad(f,a,b,tol,maxn=100): # Gauss-Legendre quadrature
    """Integrate f along the straight line from a to b with Gauss-Legendre rules.

    Rules of increasing order n = 1, 2, ... are applied until two successive
    estimates differ by less than tol. The first estimate is never accepted
    on its own: there is nothing to compare it with, and a one-point rule
    returns 0 whenever f vanishes at the midpoint.

    Parameters
    ----------
    f    : integrand; called with a 1-D NumPy array of nodes
    a, b : end points; may be real or complex
    tol  : absolute difference between successive estimates that counts as
           converged
    maxn : rules of order 1 .. maxn-1 are tried

    Returns
    -------
    The converged estimate, or None (after printing a message) when no two
    successive rules agreed to within tol.
    """
    half_span=(b-a)/2
    centre=(b+a)/2
    previous=None
    for n in range(1,maxn):
        x,w=p_roots(n) # nodes and weights on [-1, 1]
        # Affine map from [-1, 1] onto the segment a -> b.
        I=half_span*np.sum(w*f(half_span*x + centre))
        if previous is not None and abs(previous-I)<tol:
            return I
        previous=I
    print("Error occured: tolerance unreached")
    return None

def contint(f,eff,dom,tol): #Closed contour integral around a complex rectangle
    """Integrate f once around the boundary of a rectangle in the complex plane.

    Parameters
    ----------
    f   : integrand, passed unchanged to the chosen quadrature routine
    eff : 'efficiency mode' switch; only the literal True selects trapquad,
          every other value selects glquad
    dom : [[x_lo, x_hi], [y_lo, y_hi]] describing the rectangle
    tol : tolerance handed to the quadrature routine for each side

    Returns
    -------
    The sum of the four side integrals, traversed from (x_lo, y_lo) through
    (x_hi, y_lo), (x_hi, y_hi) and (x_lo, y_hi) back to the start.
    """
    quadrature = trapquad if eff is True else glquad
    x_lo, x_hi = dom[0][0], dom[0][1]
    y_lo, y_hi = dom[1][0], dom[1][1]
    corners = [x_lo+y_lo*1j, x_hi+y_lo*1j, x_hi+y_hi*1j, x_lo+y_hi*1j]
    total = 0
    # Pair every corner with its successor; the last corner wraps to the first.
    for start, end in zip(corners, corners[1:] + corners[:1]):
        total += quadrature(f,start,end,tol)
    return total


def bracketpoles(f,bregion,n): #bracketing procedure
    """Return the grid cells of bregion that enclose more poles than zeros of f.

    The region is cut into (n-1) x (n-1) rectangular cells. Around each cell
    the argument-principle integral (1/(2*pi*i)) * closed integral of f'/f
    is evaluated with the trapezoidal rule at a loose tolerance and rounded
    to the nearest integer; that integer equals (zeros - poles) inside the
    cell, so a negative value flags a pole.

    Parameters
    ----------
    f       : function to analyse; it is called with scalars and with NumPy
              arrays of complex points
    bregion : [[x_lo, x_hi], [y_lo, y_hi]] search rectangle
    n       : number of grid lines per axis

    Returns
    -------
    List of cells, each in the form [[x_lo, x_hi], [y_lo, y_hi]].

    Notes
    -----
    A cell holding as many zeros as poles integrates to zero and is not
    reported. A pole lying on a grid line makes the integral ill-defined.
    """
    # Relative finite-difference step. A step near machine epsilon (such as
    # 1e-15) is only a few floating-point spacings wide at |z| ~ 1 and
    # collapses onto z itself for larger |Re z|, which corrupts f'.
    rel_step = 1e-5
    xs = np.linspace(bregion[0][0],bregion[0][1],n)
    ys = np.linspace(bregion[1][0],bregion[1][1],n)

    def cpaf(t): #cpaf; Cauchy principle argument function
        h = rel_step*np.maximum(1.0,np.abs(t))
        return diff(f,h,t)/(2*np.pi*1j*f(t))

    brac=[] #contains the rectangular brackets
    for i in range(n-1):
        for k in range(n-1):
            cell = [[xs[i],xs[i+1]],[ys[k],ys[k+1]]]
            # High accuracy is not needed: the result is rounded to an integer.
            # Only the real part carries the count; the imaginary part is
            # quadrature noise and must not take part in the comparison.
            count = np.round(np.real(contint(cpaf,True,cell,1e-1)),0)
            if count < 0:
                brac.append(cell) #Save bracket if pole detected
    return brac


def res(f,dom,n,tol): #Main function; the residue calculator
    """Locate the poles of f inside dom and compute a residue for each.

    Poles are first bracketed on an n x n grid by bracketpoles. From the
    midpoint of every bracket a Newton-Raphson iteration on 1/f is run; since
    (1/f)/(1/f)' = -f/f', the update is p <- p + f(p)/f'(p). Once the
    iteration has converged, the residue is taken as the closed integral of
    f around the bracket divided by 2*pi*i.

    Parameters
    ----------
    f   : function to analyse
    dom : [[x_lo, x_hi], [y_lo, y_hi]] search rectangle
    n   : number of grid lines per axis used for bracketing
    tol : convergence tolerance for the Newton step; also passed to the
          residue contour integral

    Returns
    -------
    List of [pole, residue] pairs. Brackets whose iteration does not converge
    within 100 steps, or produces a non-finite iterate, contribute nothing.

    Notes
    -----
    The residue is integrated around the whole bracket, so it is the sum of
    the residues of every pole inside that bracket.
    NOTE(review): nothing checks that the converged pole lies inside the
    bracket it started from - confirm whether callers rely on that.
    """
    max_iter = 100
    min_iter = 4 #convergence is only trusted from this iteration onwards
    residues = [] #List of residues
    for b in bracketpoles(f,dom,n): #Bracket the poles
        p=(b[0][0]+b[1][0]*1j+b[0][1]+b[1][1]*1j)/2 #Start pole convergence with midpoint of the complex rectangle
        for it in range(1,max_iter+1):
            p0 = p
            # f is differentiated right next to its pole, so the step must
            # stay far smaller than the distance to the pole; scaling it with
            # |p0| keeps it above the floating-point spacing at p0.
            h = 1e-15*max(1.0,abs(p0))
            p = p0 + f(p0)/diff(f,h,p0)
            step = abs(p0-p)
            if not np.isfinite(step):
                break #iteration broke down (division by zero / overflow)
            # An early small step does not end the search: keep iterating
            # until the minimum count is reached, then record the pole.
            if step < tol and it >= min_iter:
                residues.append([p,contint(f,False,b,tol)/(2*np.pi*1j)]) #Once pole converged, calculate residue
                break
    return residues

