#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 22 13:53:43 2024

@author: sam
"""

import numpy as np
import matplotlib.pyplot as plt



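# By letting y0 = y,
# y1 = y0' = y',
# y2 = y1' = y'',
# y3 = y2' = y''' (not an extra state variable: by the ODE, y''' = y - e^x*y'),

# we can rewrite the third-order ODE y''' + e^x*y' - y = 0 as a system of
# three first-order ODEs in y0, y1 and y2: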
#     [y0]   [    y1     ]
# d/dx[y1] = [    y2     ]
#     [y2]   [y0 - e^x*y1]

def f(x, y):
    """Right-hand side of the first-order system for y''' + e^x*y' - y = 0.

    Parameters
    ----------
    x : float
        Independent variable.
    y : sequence of numbers
        State vector whose first three entries are [y0, y1, y2] = [y, y', y''].

    Returns
    -------
    numpy.ndarray
        The derivatives [y0', y1', y2'] = [y1, y2, y0 - e^x*y1].
    """
    value, slope, curvature = y[0], y[1], y[2]
    # the ODE solved for the highest derivative: y''' = y - e^x * y'
    third_derivative = value - np.exp(x) * slope
    return np.array([slope, curvature, third_derivative])

def rk4solve(f,x0,y0,xmax,h):
    """Integrate y' = f(x, y) from x0 towards xmax with classic 4th-order Runge-Kutta.

    Parameters
    ----------
    f : callable
        f(x, y) returning dy/dx. For a system, y and the return value are
        1-D arrays of the same length.
    x0 : float
        Initial value of the independent variable.
    y0 : float or array_like
        Initial state y(x0) (real-valued). It is copied to a float array, so
        a list or an integer array is accepted and the caller's object is
        never modified.
    xmax : float
        End of the integration range. Stepping stops at the first grid point
        x0 + n*h that is >= xmax (up to a tiny rounding tolerance), so xmax
        itself is included when (xmax - x0) is a multiple of h.
    h : float
        Step size; must be positive.

    Returns
    -------
    x : numpy.ndarray, shape (N, 1)
        Grid points x0, x0 + h, ..., as a column vector.
    y : numpy.ndarray, shape (N, m)
        Solution at each grid point; row i is y(x[i]). m is the length of y0
        (1 for a scalar y0).

    Raises
    ------
    ValueError
        If h is not positive (the stepping loop would never terminate).
    """
    if h <= 0:
        raise ValueError("step size h must be positive")

    # Work on a float copy. `y0 += ...` on a Python list extends the list
    # instead of adding element-wise, and on an ndarray it would mutate the
    # caller's data (or fail for an integer array).
    y_n = np.array(y0, dtype=float)

    # Number of steps needed to reach xmax. The small tolerance (in units of
    # steps) keeps float rounding in (xmax - x0)/h, e.g. 9999.999999999998,
    # from adding a spurious extra step.
    n_steps = max(0, int(np.ceil((xmax - x0) / h - 1e-9)))

    ys = [y_n]
    for i in range(n_steps):
        # x is computed from the step index rather than accumulated with
        # x += h, which drifts over thousands of steps
        x_n = x0 + i * h

        k1 = np.asarray(f(x_n, y_n))                 # slope at the start of the step
        k2 = np.asarray(f(x_n + h/2, y_n + h*k1/2))  # slope at the midpoint, using k1
        k3 = np.asarray(f(x_n + h/2, y_n + h*k2/2))  # slope at the midpoint, using k2
        k4 = np.asarray(f(x_n + h, y_n + h*k3))      # slope at the end, using k3

        # weighted average of the four slopes; rebinding (not +=) creates a
        # new array so rows already stored in ys are not overwritten
        y_n = y_n + (h/6)*(k1 + 2*k2 + 2*k3 + k4)
        ys.append(y_n)

    # column vector, matching the (N, 1) shape the original vstack produced
    x = (x0 + h*np.arange(n_steps + 1)).reshape(-1, 1)
    y = np.vstack(ys)
    return x,y



x,y = rk4solve(f,0,[2,0,0],10,0.001)
# Solve the system defined in f with RK4 on x in [0, 10], step size h = 0.001,
# and initial conditions y(0) = 2, y'(0) = 0, y''(0) = 0.
# x is an (N, 1) column of grid points; y is (N, 3) with columns [y, y', y''].

plt.plot(x[:,0],y[:,0])
# Plot only the first state component, y0 = y(x), against x.
plt.xlabel("x")
plt.ylabel("y(x)")
plt.title("y''' + e^x y' - y = 0,  y(0)=2, y'(0)=0, y''(0)=0")
plt.show()
# plt.show() is needed for the figure to appear when this file is run as a
# plain script (IDEs/notebooks with inline plotting render it regardless).