# -*- coding: utf-8 -*-
"""
Created on Fri Oct 14 16:39:44 2022

@author: 102194
"""
import numpy as np
import matplotlib.pyplot as plt
import scipy.linalg as la
import scipy.sparse as sp
import scipy.sparse.linalg as sla

from mpl_toolkits.mplot3d import Axes3D


def f(x,y):
    """Return the quadratic test function (x - 1)**2 + 3*y**2.

    Uses only arithmetic operators, so it evaluates element-wise when given
    NumPy arrays (e.g. the meshgrid arrays X, Y) as well as on scalars.
    """
    shifted_x = x - 1
    return shifted_x**2 + 3*y**2

# Solve Laplace's equation  u_xx + u_yy = 0  on [-2, 2] x [-1, 1] with a
# 5-point finite-difference scheme and Dirichlet data: u = 1 on the top edge
# (y = y[-1]) and u = 0 on the other three edges; then plot the solution.

x = np.linspace(-2, 2, 50)
y = np.linspace(-1, 1, 50)
Nx = len(x)
Ny = len(y)
N = Nx  # legacy name; nothing below assumes Nx == Ny any more

# meshgrid's default 'xy' indexing gives arrays of shape (Ny, Nx). After
# C-order flattening, grid point (x[j], y[i]) is unknown k = i*Nx + j, so its
# x-neighbours are k +/- 1 and its y-neighbours are k +/- Nx.
X, Y = np.meshgrid(x, y)

# Analytic test function, used only to sanity-check the difference operators.
F = f(X, Y)

hx = x[1] - x[0]
hy = y[1] - y[0]
n_pts = Nx * Ny

Xv = np.reshape(X, (n_pts, 1))
Yv = np.reshape(Y, (n_pts, 1))
Fv = np.reshape(F, (n_pts, 1))

# Second-difference operators on the flattened grid. The +/-1 diagonals of D2X
# also link the last point of one grid row to the first point of the next;
# such links only occur in boundary rows, which are replaced below.
D2Y = sp.csc_matrix(1/hy**2*sp.diags([1, -2, 1], [-Nx, 0, Nx], shape=(n_pts, n_pts)))
D2X = sp.csc_matrix(1/hx**2*sp.diags([1, -2, 1], [-1, 0, 1], shape=(n_pts, n_pts)))

# Sanity check: for f = (x-1)**2 + 3*y**2 the interior entries of these
# vectors should be (up to rounding) 2 and 6 respectively.
D2Xv = D2X @ Fv
D2Yv = D2Y @ Fv

# Mark the unknowns that lie on the edge of the rectangle.
grid_row, grid_col = np.divmod(np.arange(n_pts), Nx)
on_boundary = (
    (grid_row == 0) | (grid_row == Ny - 1) | (grid_col == 0) | (grid_col == Nx - 1)
)

# Keep the discrete Laplacian in interior rows and turn boundary rows into
# identity rows (u_k = g_k). Row-scaling by diagonal masks builds this in one
# shot instead of slice-assigning into a CSC matrix, which repeatedly changes
# its sparsity structure (SparseEfficiencyWarning) and is slow.
keep_interior = sp.diags((~on_boundary).astype(float))
boundary_identity = sp.diags(on_boundary.astype(float))
A = sp.csc_matrix(keep_interior @ (D2X + D2Y) + boundary_identity)
AM = A

# Right-hand side: zero for interior rows, Dirichlet values for boundary rows.
# The top grid row (i = Ny-1, i.e. y = y[-1]) is held at 1, all other edges at 0.
B = np.zeros(n_pts)
B[(Ny - 1) * Nx:] = 1.0

u = sla.spsolve(AM, B)
U = np.reshape(u, (Ny, Nx))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, U)
plt.show()



