# -*- coding: utf-8 -*-
"""
Created on Fri Oct 14 16:39:44 2022

@author: 102194
"""
import numpy as np
import matplotlib.pyplot as plt
import scipy.linalg as la
import scipy.sparse as sp
import scipy.sparse.linalg as sla

from mpl_toolkits.mplot3d import Axes3D


def f(x, y):
    """Evaluate the test surface (x - 1)**2 + 3*y**2.

    Only arithmetic operators are used, so scalars and NumPy arrays
    (evaluated element-wise) are both accepted.
    """
    shifted = x - 1
    return shifted ** 2 + 3 * y ** 2

def _second_difference_1d(n, h):
    """Return the (n, n) central second-difference matrix for grid spacing h.

    Row i approximates u''(x_i) as (u[i-1] - 2*u[i] + u[i+1]) / h**2; on the
    first and last rows the neighbour outside the grid is taken as 0.
    """
    return (1 / h**2) * sp.diags([1, -2, 1], [-1, 0, 1], shape=(n, n))


# Sample grid. np.meshgrid's default 'xy' indexing returns arrays of shape
# (len(y), len(x)): the row index follows y, the column index follows x.
x = np.linspace(-2, 2, 10)
y = np.linspace(-1, 1, 10)

X, Y = np.meshgrid(x, y)

F = f(X, Y)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

hx = x[1] - x[0]
hy = y[1] - y[0]
Nx = len(x)
Ny = len(y)
N = Nx  # number of x points; kept for existing references to N

# Flatten row-major: grid point (i, j) -> i*Nx + j, so x is the fast index
# (stride 1) and y the slow index (stride Nx).
Xv = np.reshape(X, (Ny * Nx, 1))
Yv = np.reshape(Y, (Ny * Nx, 1))

Fv = np.reshape(F, (Ny * Nx, 1))

# Build the 2-D operators as Kronecker products of the 1-D stencils. A single
# banded matrix with offsets -1/0/+1 over all Nx*Ny points would couple the
# last x-point of one y-row to the first x-point of the next y-row;
# kron(I_Ny, Dxx) keeps every x-stencil inside its own y-row.
D2Y = sp.csc_matrix(sp.kron(_second_difference_1d(Ny, hy), sp.identity(Nx)))
D2X = sp.csc_matrix(sp.kron(sp.identity(Ny), _second_difference_1d(Nx, hx)))

D2Xv = D2X @ Fv

D2Yv = D2Y @ Fv

d2fdx = np.reshape(D2Xv, (Ny, Nx))
d2fdy = np.reshape(D2Yv, (Ny, Nx))

# NOTE: the operators treat values outside the grid as 0. As a result, the
# first and last rows of d2fdy (and the first and last columns of d2fdx) do
# not approximate the true second derivative.
surf = ax.plot_surface(X, Y, d2fdy)

plt.show()



