import numpy as np
import matplotlib.pyplot as plt
import scipy.linalg as la
import scipy.sparse as sp
import scipy.sparse.linalg as sla

from mpl_toolkits.mplot3d import Axes3D

def f(x, y):
    """Right-hand side of the first-order system y0' = y1, y1' = -4*y0.

    This is y'' + 4*y = 0 rewritten as a system; the abscissa ``x`` is
    accepted for the solver calling convention but not used.

    Returns a length-2 NumPy array [dy0/dx, dy1/dx].
    """
    return np.array([y[1], -4*y[0]])


def rk4solve(f,x0,y0,xmax,h):
    """Integrate y' = f(x, y) from x0 to xmax with classical 4th-order Runge-Kutta.

    Parameters
    ----------
    f : callable
        ``f(x, y)`` returning dy/dx with the same shape as ``y``.
    x0 : float
        Initial abscissa.
    y0 : float or array_like
        Initial value (scalar or 1-D state vector).
    xmax : float
        Stepping stops at the first grid point x0 + n*h that reaches xmax
        (up to a tiny round-off tolerance).
    h : float
        Fixed step size; must be positive when xmax > x0.

    Returns
    -------
    xvals : ndarray, shape (n+1, 1)
        Column of abscissae x0, x0+h, ..., x0+n*h.
    yvals : ndarray, shape (n+1, d)
        Row i is the solution at xvals[i]; d == 1 for a scalar y0.

    Raises
    ------
    ValueError
        If h <= 0 while xmax > x0 (stepping could never reach xmax).
    """
    if xmax <= x0:
        nsteps = 0
    elif h <= 0:
        raise ValueError("step size h must be positive when xmax > x0")
    else:
        # Count steps from the interval length instead of testing an
        # accumulated x: ten additions of 0.1 give 0.9999999999999999, which
        # would otherwise trigger an extra step past xmax.
        nsteps = int(np.ceil((xmax - x0) / h - 1e-9))

    # Accumulate in lists and stack once (repeated vstack is quadratic).
    xs = [np.array([x0])]
    ys = [np.array([y0])]

    x = x0
    y = y0
    for step in range(1, nsteps + 1):
        k1 = h*f(x, y)
        k2 = h*f(x + 0.5*h, y + 0.5*k1)
        k3 = h*f(x + 0.5*h, y + 0.5*k2)
        k4 = h*f(x + h, y + k3)

        y = y + k1/6.0 + k2/3.0 + k3/3.0 + k4/6.0
        # Recompute x from x0 so round-off does not drift along the grid.
        x = x0 + step*h

        xs.append(x)
        ys.append(y)

    return np.vstack(xs), np.vstack(ys)
    

def rk2solve(f,x0,y0,xmax,h):
    """Integrate y' = f(x, y) from x0 to xmax with the 2nd-order midpoint method.

    Parameters
    ----------
    f : callable
        ``f(x, y)`` returning dy/dx with the same shape as ``y``.
    x0 : float
        Initial abscissa.
    y0 : float or array_like
        Initial value (scalar or 1-D state vector).
    xmax : float
        Stepping stops at the first grid point x0 + n*h that reaches xmax
        (up to a tiny round-off tolerance).
    h : float
        Fixed step size; must be positive when xmax > x0.

    Returns
    -------
    xvals : ndarray, shape (n+1,)
        Abscissae x0, x0+h, ..., x0+n*h.
    yvals : ndarray
        Shape (n+1,) for a scalar y0, or (n+1, d) for a length-d y0
        (row i is the state at xvals[i]).

    Raises
    ------
    ValueError
        If h <= 0 while xmax > x0 (stepping could never reach xmax).
    """
    if xmax <= x0:
        nsteps = 0
    elif h <= 0:
        raise ValueError("step size h must be positive when xmax > x0")
    else:
        # Count steps from the interval length instead of testing an
        # accumulated x, so round-off cannot add a step past xmax.
        nsteps = int(np.ceil((xmax - x0) / h - 1e-9))

    xs = [x0]
    ys = [y0]

    x = x0
    y = y0
    for step in range(1, nsteps + 1):
        k1 = h*f(x, y)
        k2 = h*f(x + 0.5*h, y + 0.5*k1)
        y = y + k2
        x = x0 + step*h

        xs.append(x)
        ys.append(y)

    # np.array keeps vector states as rows; np.append would flatten them
    # into one interleaved 1-D array.
    return np.array(xs), np.array(ys)



def rk1solve(f,x0,y0,xmax,h):
    """Integrate y' = f(x, y) from x0 to xmax with the forward Euler method.

    Parameters
    ----------
    f : callable
        ``f(x, y)`` returning dy/dx with the same shape as ``y``.
    x0 : float
        Initial abscissa.
    y0 : float or array_like
        Initial value (scalar or 1-D state vector).
    xmax : float
        Stepping stops at the first grid point x0 + n*h that reaches xmax
        (up to a tiny round-off tolerance).
    h : float
        Fixed step size; must be positive when xmax > x0.

    Returns
    -------
    xvals : ndarray, shape (n+1,)
        Abscissae x0, x0+h, ..., x0+n*h.
    yvals : ndarray
        Shape (n+1,) for a scalar y0, or (n+1, d) for a length-d y0
        (row i is the state at xvals[i]).

    Raises
    ------
    ValueError
        If h <= 0 while xmax > x0 (stepping could never reach xmax).
    """
    if xmax <= x0:
        nsteps = 0
    elif h <= 0:
        raise ValueError("step size h must be positive when xmax > x0")
    else:
        # Count steps from the interval length instead of testing an
        # accumulated x, so round-off cannot add a step past xmax.
        nsteps = int(np.ceil((xmax - x0) / h - 1e-9))

    xs = [x0]
    ys = [y0]

    x = x0
    y = y0
    for step in range(1, nsteps + 1):
        y = y + h*f(x, y)
        x = x0 + step*h

        xs.append(x)
        ys.append(y)

    # np.array keeps vector states as rows; np.append would flatten them
    # into one interleaved 1-D array.
    return np.array(xs), np.array(ys)
    



# 2D example: finite-difference Laplace problem on [-2, 2] x [-1, 1] with Dirichlet boundary data.

def f(x, y):
    """Quadratic test surface (x - 1)**2 + 3*y**2.

    Evaluated elementwise, so ``x`` and ``y`` may be scalars or
    broadcast-compatible NumPy arrays (e.g. meshgrid output).
    """
    dx = x - 1
    return dx**2 + 3*y**2

x = np.linspace(-2, 2, 50)
y = np.linspace(-1, 1, 50)

# meshgrid (default indexing='xy') returns arrays of shape (len(y), len(x))
# with X[i, j] = x[j] and Y[i, j] = y[i].
X, Y = np.meshgrid(x, y)

# Test surface from f(); not used by the solve below, kept for inspection.
F = f(X, Y)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

hx = x[1] - x[0]
hy = y[1] - y[0]
Nx = len(x)
Ny = len(y)

# Unknowns follow a C-order ravel of the (Ny, Nx) grid: k = i*Nx + j, so
# x-neighbours are k±1 and y-neighbours are k±Nx. Kronecker products size
# each direction by its own grid, so Nx and Ny are allowed to differ.
d2x = 1/hx**2*sp.diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx))
d2y = 1/hy**2*sp.diags([1, -2, 1], [-1, 0, 1], shape=(Ny, Ny))
D2X = sp.csc_matrix(sp.kron(sp.identity(Ny), d2x))
D2Y = sp.csc_matrix(sp.kron(d2y, sp.identity(Nx)))

# Mark the four edges of the rectangle.
on_boundary = np.zeros((Ny, Nx), dtype=bool)
on_boundary[0, :] = True    # Y min
on_boundary[-1, :] = True   # Y max
on_boundary[:, 0] = True    # X min
on_boundary[:, -1] = True   # X max
on_boundary = on_boundary.ravel()

# Interior rows: discrete Laplacian D2X + D2Y. Boundary rows: identity
# (Dirichlet). Selecting rows with diagonal masks avoids item assignment on
# a CSC matrix, which rewrites its sparsity structure (slow, emits
# SparseEfficiencyWarning, and leaves explicit zeros behind).
interior_rows = sp.diags((~on_boundary).astype(float))
boundary_rows = sp.diags(on_boundary.astype(float))
A = interior_rows @ (D2X + D2Y) + boundary_rows

AM = sp.csc_matrix(A)
AM.eliminate_zeros()

# Right-hand side: zero in the interior; boundary data u = 1 on the Y max
# edge (including its two corners) and u = 0 on the other edges.
B = np.zeros(Nx * Ny)
B[(Ny - 1) * Nx:] = 1.0

u = sla.spsolve(AM, B)

U = np.reshape(u, (Ny, Nx))

surf = ax.plot_surface(X, Y, U)
plt.show()

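# 1D example: boundary-value problem -u'' + 4u' + u = x*exp(-x) on [0, 2] with u(0) = 0 and u'(2) = -1.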
xmin = 0
xmax = 2

# Uniform grid on [xmin, xmax] with N nodes and spacing h.
x = np.linspace(xmin, xmax, 101)
h = x[1] - x[0]
N = len(x)

# Central-difference operators for d/dx and d^2/dx^2.
D1 = sp.diags([-1/(2*h), 0, 1/(2*h)], [-1, 0, 1], shape=(N, N))
D2 = 1/h**2*sp.diags([1, -2, 1], [-1, 0, 1], shape=(N, N))

I = sp.eye(N)

# Discretise -u'' + 4u' + u = x*exp(-x). Assemble in LIL format because the
# boundary rows are rewritten below; LIL supports row edits cheaply, while
# item assignment on CSC/CSR may change the sparsity structure.
A = sp.lil_matrix(-D2 + 4*D1 + I)

# Dense right-hand side; named rhs so the function f is not shadowed.
rhs = x*np.exp(-x)

# Dirichlet condition at the left end: u(xmin) = 0.
A[0, :] = 0
A[0, 0] = 1
rhs[0] = 0

# Neumann condition at the right end: u'(xmax) = -1, imposed with the
# first-order backward difference (u[N-1] - u[N-2]) / h.
A[-1, :] = 0
A[-1, -1] = 1/h
A[-1, -2] = -1/h
rhs[-1] = -1

A = sp.csc_matrix(A)
u = sla.spsolve(A, rhs)

plt.plot(x, u)
plt.show()










