#!/usr/bin/python3
import numpy as np
import math
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as Axes3D


def lagx(j, x, xp):
    """Evaluate the j-th Lagrange basis polynomial at ``x``.

    Computes the product over all ``i != j`` of
    ``(x - xp[i]) / (xp[j] - xp[i])``.

    Parameters
    ----------
    j : int
        Index into ``xp`` of the node this basis polynomial belongs to.
    x : scalar or numpy.ndarray
        Point(s) at which to evaluate; array inputs are handled
        element-wise through NumPy broadcasting against each node.
    xp : sequence
        Interpolation nodes. They must be pairwise distinct: a node equal
        to ``xp[j]`` makes the denominator zero.

    Returns
    -------
    scalar or numpy.ndarray
        The basis value: 1 where ``x == xp[j]`` and 0 where ``x`` equals
        any other node.
    """
    num = 1
    den = 1
    for i, node in enumerate(xp):
        if i != j:
            # Rebind instead of using `*=`: once `num` is an integer ndarray,
            # an in-place multiply by a float factor raises a casting error
            # (e.g. integer `x` with nodes [0, 0.5, 1]).
            num = num * (x - node)
            den = den * (xp[j] - node)
    return num / den

def lagint(x, xp, yp):
    """Evaluate the Lagrange interpolating polynomial at ``x``.

    Sums ``lagx(i, x, xp) * yp[i]`` over every node index ``i`` of ``xp``.

    Parameters
    ----------
    x : scalar or numpy.ndarray
        Point(s) at which to evaluate the interpolant.
    xp : sequence
        Interpolation nodes.
    yp : sequence
        Values at the nodes; ``yp[i]`` is read for every index of ``xp``.

    Returns
    -------
    scalar or numpy.ndarray
        The interpolated value(s); the integer 0 when ``xp`` is empty.
    """
    ans = 0
    for i in range(len(xp)):
        # Rebind instead of using `+=`: an in-place add on an ndarray cannot
        # widen its dtype, so a float first term followed by e.g. a complex
        # value would raise a casting error.
        ans = ans + lagx(i, x, xp) * yp[i]
    return ans



# Sample data: four interpolation nodes and the values to hit at those nodes,
# each stored as a (4, 1) column vector.
xp = np.array([-1, 0, 1, 2]).reshape(4, 1)
yp = np.array([2, -1, 1, 0]).reshape(4, 1)

# Evaluate the interpolant at 100 evenly spaced points spanning the nodes
# and display the resulting curve.
x = np.linspace(start=-1, stop=2, num=100)
plt.plot(x, lagint(x, xp, yp))
plt.show()

