#!/usr/bin/env python3
# -*- coding: utf-8 -*-
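"""
Lagrange polynomial interpolation demo.

Builds the interpolating polynomial through four sample points and plots it
over the interval [-1, 2].

Created on Tue Sep  3 14:17:06 2024

@author: sam
"""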

import numpy as np
import matplotlib.pyplot as plt

# Evaluation grid: 50 evenly spaced points on [-1, 2] (the np.linspace default count).
x = np.linspace(-1, 2, num=50)

# Interpolation nodes and their sample values, stored as (4, 1) column arrays.
xp = np.array([-1, 0, 1, 2]).reshape(-1, 1)
yp = np.array([2, -1, 1, 0]).reshape(-1, 1)



def lagx(j, x, xp):
    """Evaluate the j-th Lagrange basis polynomial at ``x``.

    The basis polynomial for node ``xp[j]`` is

        l_j(x) = prod_{i != j} (x - xp[i]) / (xp[j] - xp[i])

    so it equals 1 at ``xp[j]`` and 0 at every other node.

    Parameters
    ----------
    j : int
        Index of the node whose basis polynomial is evaluated;
        must satisfy ``0 <= j < len(xp)``.
    x : scalar or array_like
        Point(s) at which to evaluate the polynomial.
    xp : sequence or ndarray
        Interpolation nodes. In this script they are a column array of
        shape (n, 1), so each ``xp[i]`` has shape (1,) and broadcasts
        against ``x``. Nodes are assumed distinct; a repeated node makes a
        denominator zero.

    Returns
    -------
    ndarray
        Values of l_j at ``x``. With a single node the (empty) product is
        the constant polynomial 1.

    Raises
    ------
    IndexError
        If ``j`` is outside ``range(len(xp))``.
    """
    n = len(xp)
    # Reject negative indices explicitly: xp[-1] would otherwise be paired
    # with itself in the product and divide by zero.
    if not 0 <= j < n:
        raise IndexError(f"node index {j} out of range for {n} nodes")
    xj = xp[j]

    # Start from ones shaped like x so that a single node (empty product)
    # evaluates to the constant polynomial 1.
    l = np.ones(np.shape(x))
    for i, xi in enumerate(xp):
        if i == j:
            continue
        # Form each factor before multiplying so the floating-point
        # evaluation order is (x - xi) / (xj - xi), then the running product.
        l = l * ((x - xi) / (xj - xi))
    return l

def L(x, xp, yp):
    """Evaluate the Lagrange interpolating polynomial through (xp, yp) at x.

    The result is the sum over node indices k of ``lagx(k, x, xp) * yp[k]``,
    i.e. each basis polynomial weighted by its sample value. Only the first
    ``len(xp)`` entries of ``yp`` are read. With no nodes the sum is 0.
    """
    weighted_terms = (lagx(k, x, xp) * yp[k] for k in range(len(xp)))
    return sum(weighted_terms)



# Plot the interpolating polynomial over the grid together with the data
# points it passes through. plt.show() is needed for the figure to appear
# when the file is run as a standalone script.
plt.plot(x, L(x, xp, yp), label="Lagrange interpolant")
plt.plot(xp, yp, "o", label="data points")
plt.legend()
plt.show()