# -*- coding: utf-8 -*-
"""
Created on Fri Oct 14 16:39:44 2022

@author: 102194
"""
import numpy as np
import matplotlib.pyplot as plt
import scipy.linalg as la
import scipy.sparse as sp
import scipy.sparse.linalg as sla

from mpl_toolkits.mplot3d import Axes3D


def f(x,y):
    """Return the paraboloid (x - 1)**2 + y**2, evaluated elementwise.

    Accepts scalars or NumPy arrays (any shapes NumPy can broadcast together).
    """
    shifted_x = x - 1
    return shifted_x**2 + y**2

# Sample the rectangle [-2, 2] x [-1, 1] with 10 points per axis.
x = np.linspace(start=-2, stop=2, num=10)
y = np.linspace(start=-1, stop=1, num=10)

# Default 'xy' indexing: X and Y have shape (len(y), len(x)), x varying along axis 1.
X, Y = np.meshgrid(x, y)

F = f(X, Y)

# Surface plot of f over the sampled grid.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
surf = ax.plot_surface(X, Y, F)

plt.show()

# Uniform grid spacings and point counts along each axis.
hx = x[1] - x[0]
hy = y[1] - y[0]
Nx = len(x)
Ny = len(y)
N = Nx  # original name for the x point count; kept so later code still finds it

# Flatten the grids into column vectors.  np.meshgrid's default 'xy' indexing
# gives arrays of shape (Ny, Nx) with x varying along axis 1, so after C-order
# flattening x-neighbours sit 1 apart and y-neighbours sit Nx apart.
Xv = np.reshape(X, (Nx * Ny, 1))
Yv = np.reshape(Y, (Nx * Ny, 1))

Fv = np.reshape(F, (Nx * Ny, 1))

plt.plot(Xv)
plt.plot(Yv)

# 1-D second-difference operators [1, -2, 1] / h**2 for each axis.  The end
# rows keep only one neighbour, i.e. values outside the grid are treated as 0.
d2x_1d = 1 / hx**2 * sp.diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx))
d2y_1d = 1 / hy**2 * sp.diags([1, -2, 1], [-1, 0, 1], shape=(Ny, Ny))

# Lift to the flattened 2-D grid with Kronecker products.
# - A single banded (Nx*Ny)^2 matrix with offsets +-1 would wrongly couple the
#   last point of one grid row to the first point of the next.
#   kron(I_Ny, d2x_1d) keeps the x-stencil inside each row.
# - The x-stencil (offset 1) is scaled by hx and the y-stencil (offset Nx) by hy.
D2X = sp.kron(sp.eye(Ny), d2x_1d, format='csr')
D2Y = sp.kron(d2y_1d, sp.eye(Nx), format='csr')




