#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 15 15:42:57 2024

@author: sam
"""

import numpy as np
import matplotlib.pyplot as plt



def f1(x, y):
    """Return the derivative of the first state component, i.e. ``y[1]``.

    ``x`` is not used; it is accepted so the signature matches ``f(x, y)``.
    """
    derivative = y[1]
    return derivative

def f2(x, y):
    """Return the derivative of the second state component, ``-4 * y[0]``.

    ``x`` is not used; it is accepted so the signature matches ``f(x, y)``.
    """
    position = y[0]
    return -4 * position

def f(x, y):
    """Return the system right-hand side as an ndarray ``[f1(x, y), f2(x, y)]``.

    Together, f1 and f2 encode y'' = -4y as the first-order system
    (y0, y1)' = (y1, -4*y0).
    """
    components = (f1(x, y), f2(x, y))
    return np.array(components)

def rk4solve(f, x0, y0, xmax, h):
    """Integrate dy/dx = f(x, y) with the classic fixed-step RK4 method.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(x, y)``. Its result is converted with
        ``np.asarray`` and must have the same shape as ``y``.
    x0 : float
        Initial value of the independent variable.
    y0 : scalar or array_like
        Initial state. It is copied, so the caller's object is never modified.
    xmax : float
        End of the interval. Steps are taken until the grid reaches
        ``xmax`` (up to rounding), so the last point may overshoot it by
        less than one step.
    h : float
        Step size; must be positive.

    Returns
    -------
    x : ndarray, shape (n + 1, 1)
        Grid points ``x0, x0 + h, ..., x0 + n*h`` as a column vector.
    y : ndarray, shape (n + 1, m)
        ``y[i]`` is the approximate state at ``x[i, 0]``. ``m`` is the
        number of state components (1 for a scalar ``y0``).

    Raises
    ------
    ValueError
        If ``h`` is not positive. Otherwise the integration would never
        advance toward ``xmax``.
    """
    if not h > 0:
        raise ValueError(f"step size h must be positive, got {h!r}")

    # Work on a private floating-point copy. For a list, ``y0 += ndarray``
    # extends the list instead of adding. For an ndarray, ``+=`` mutates the
    # caller's array, or fails outright for integer dtypes.
    y_init = np.atleast_1d(np.asarray(y0))
    dtype = np.result_type(y_init.dtype, float)
    yi = y_init.astype(dtype)  # astype returns a copy

    # Count the steps up front instead of accumulating ``x += h``. The
    # accumulation drifts and can add a spurious extra step (for example,
    # 500 * 0.01 may land just below 5). The tolerance absorbs rounding in
    # the division itself.
    n_steps = max(0, int(np.ceil((xmax - x0) / h - 1e-9)))
    xs = x0 + h * np.arange(n_steps + 1)

    ys = np.empty((n_steps + 1,) + yi.shape, dtype=dtype)
    ys[0] = yi
    for i in range(n_steps):
        xi = xs[i]
        k1 = np.asarray(f(xi, yi))
        k2 = np.asarray(f(xi + h / 2, yi + h * k1 / 2))
        k3 = np.asarray(f(xi + h / 2, yi + h * k2 / 2))
        k4 = np.asarray(f(xi + h, yi + h * k3))
        # Rebind rather than update in place, so no shared state is mutated.
        yi = yi + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
        ys[i + 1] = yi

    # Keep the (N, 1) column layout the original returned for x.
    return xs.reshape(-1, 1), ys



# Harmonic oscillator y'' = -4y with y(0) = 1, y'(0) = 0 on [0, 5], step 0.01.
# The exact solution is y = cos(2x), y' = -2 sin(2x).
# The initial state is a float ndarray: arithmetic updates on a plain Python
# list would concatenate instead of adding element-wise.
x, y = rk4solve(f, 0, np.array([1.0, 0.0]), 5, 0.01)
plt.plot(np.ravel(x), y[:, 0], label="y")
plt.plot(np.ravel(x), y[:, 1], label="dy/dx")
plt.xlabel("x")
plt.legend()
plt.show()  # required for the figure to appear when run as a plain script