# -*- coding: utf-8 -*-
"""
Created on Sun Oct  9 11:00:22 2022

@author: Chris
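Lab 10 Q2: solve dy/dx = x(1 + 4y^2), y(0) = 0, on [0, 1] with the classical
4th-order Runge-Kutta method and plot the result.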
"""

import numpy as np
import matplotlib.pyplot as plt

def f(x, y):
    """Return the right-hand side of the ODE dy/dx = x * (1 + 4*y**2).

    Works element-wise when ``y`` is a numpy array; with plain ints it stays
    integer arithmetic.
    """
    slope_factor = 4 * y**2 + 1
    return x * slope_factor


def rk4solve(f, x0, y0, xmax, h):
    """Integrate dy/dx = f(x, y) from x0 to xmax with classical 4th-order Runge-Kutta.

    Parameters
    ----------
    f : callable
        Right-hand side, called as ``f(x, y)``. It must return a value that can
        be scaled and added to ``y`` (a scalar for scalar ``y0``, an array of
        the same length for a 1-D array ``y0``).
    x0 : float
        Initial value of the independent variable.
    y0 : float or 1-D numpy.ndarray
        Value of the solution at ``x0``.
    xmax : float
        End of the integration interval. The final step is shortened if
        necessary so the last grid point is exactly ``xmax``.
    h : float
        Nominal step size; must be positive.

    Returns
    -------
    xvals : numpy.ndarray, shape (n, 1)
        Grid points from ``x0`` to ``xmax`` (just ``x0`` if ``x0 >= xmax``).
    yvals : numpy.ndarray, shape (n, 1) for scalar ``y0``, (n, m) for length-m ``y0``
        Numerical solution at each grid point.

    Raises
    ------
    ValueError
        If ``h`` is not positive (the stepping loop would never terminate).
    """
    if h <= 0:
        raise ValueError("step size h must be positive")

    # Anything within this distance of xmax counts as "arrived". This absorbs
    # round-off in x0 + i*h so we never take a spurious ~1e-16 extra step.
    tol = 1e-9 * h

    xs = [x0]
    ys = [y0]
    x, y = x0, y0
    i = 0
    while xmax - x > tol:
        i += 1
        # Compute the grid point from x0 directly instead of repeatedly adding
        # h. Repeated addition drifts (e.g. ten additions of 0.1 < 1.0) and
        # causes an overshooting extra step.
        xn = x0 + i * h
        if xmax - xn <= tol:
            # Clip the last step so the grid ends exactly on xmax.
            xn = xmax
        step = xn - x

        k1 = step * f(x, y)
        k2 = step * f(x + 0.5 * step, y + 0.5 * k1)
        k3 = step * f(x + 0.5 * step, y + 0.5 * k2)
        k4 = step * f(x + step, y + k3)
        yn = y + k1 / 6.0 + k2 / 3.0 + k3 / 3.0 + k4 / 6.0

        xs.append(xn)
        ys.append(yn)
        x, y = xn, yn

    # Stack once at the end. Stacking inside the loop copied the whole history
    # each step (O(n^2)). vstack turns scalars into rows, giving (n, 1) for a
    # scalar y and (n, m) for length-m vectors, even when no step was taken.
    return np.vstack(xs), np.vstack(ys)
    
    
# Lab 10 Q2: solve dy/dx = x*(1 + 4*y**2), y(0) = 0, on [0, 1] with RK4.
x0 = 0
y0 = 0

h = 0.01

xmax = 1.0

xvals, yvals = rk4solve(f, x0, y0, xmax, h)

# Exact solution for comparison, valid only for the f above. Separating
# variables gives arctan(2y)/2 = x**2/2 with y(0) = 0, i.e. y = tan(x**2)/2.
yexact = 0.5 * np.tan(xvals**2)

plt.plot(xvals, yvals, label=f"RK4, h = {h}")
plt.plot(xvals, yexact, "--", label="exact: tan(x^2)/2")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
# Without show(), running this file as a plain script exits before the
# figure is ever displayed.
plt.show()

    
    


