#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plot the Lagrange basis polynomials L_j(x) for a small set of interpolation nodes.

Created on Tue Sep  3 14:16:55 2024

@author: sam
"""

import numpy as np
import matplotlib.pyplot as plt

# Evaluation grid: 50 evenly spaced samples (np.linspace's default count)
# covering the node interval [-1, 2].
x = np.linspace(-1, 2, num=50)

# Interpolation nodes and their data values, stored as (4, 1) column vectors,
# so each xp[i] is a length-1 array that broadcasts against the grid x.
# NOTE(review): yp is not referenced by the plotting code in this file;
# presumably intended for the interpolant sum_j yp[j] * L_j(x) — confirm.
xp = np.array([-1, 0, 1, 2]).reshape(-1, 1)
yp = np.array([2, -1, 1, 0]).reshape(-1, 1)



def lagx(j, x, xp):
    """Evaluate the j-th Lagrange basis polynomial at ``x``.

    L_j(x) = prod_{i != j} (x - xp[i]) / (xp[j] - xp[i])

    L_j equals 1 at node ``xp[j]`` and 0 at every other node.

    Parameters
    ----------
    j : int
        Index of the node the basis polynomial belongs to; 0 <= j < len(xp).
    x : scalar or array_like
        Point(s) at which to evaluate L_j.
    xp : indexable sequence of numbers
        Interpolation nodes; they must be distinct (a repeated node divides
        by zero). In this script it is a (4, 1) column array, so each xp[i]
        is a length-1 array that broadcasts against ``x``.

    Returns
    -------
    scalar or ndarray
        L_j evaluated at ``x``, with the broadcast shape of ``x`` and
        ``xp[j]``. With a single node the empty product gives all ones.

    Raises
    ------
    IndexError
        If ``j`` is not in ``range(len(xp))``.
    """
    n = len(xp)
    # Reject out-of-range j up front; a negative j would otherwise silently
    # reuse a node twice and divide by zero.
    if not 0 <= j < n:
        raise IndexError(f"node index j={j} out of range for {n} nodes")

    basis = None
    for i in range(n):
        if i == j:
            continue  # the product runs over every node except xp[j] itself
        factor = (x - xp[i]) / (xp[j] - xp[i])
        basis = factor if basis is None else basis * factor

    if basis is None:
        # Only one node: the empty product is 1, so L_0(x) == 1 everywhere.
        basis = np.ones(np.broadcast(x, xp[j]).shape)
    return basis



# Plot every Lagrange basis polynomial L_i on one figure. Each curve equals 1
# at its own node xp[i] and 0 at all the other nodes.
for i in range(len(xp)):
    F = lagx(i, x, xp)
    plt.plot(x, F, label=f"$L_{i}(x)$")

plt.legend()
# Needed for the figure to appear when the file is run as a plain script;
# harmless in interactive/inline backends.
plt.show()