#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>

// Memory map of the emulated machine (byte addresses).
// Instruction image: 0x0000-0x03FF, filled from the first 1024 bytes of the binary in main().
#define INSTRUCTION_MEMORY_START 0x0000
#define INSTRUCTION_MEMORY_END   0x03FF
// Data image: 0x0400-0x07FF, filled from the second 1024 bytes of the binary in main().
#define DATA_MEMORY_START        0x0400
#define DATA_MEMORY_END          0x07FF
// Start of the memory-mapped virtual routines; execute_instruction forwards loads/stores
// whose address lies between here and HEAP_BANKS_START to virtual_routine().
#define VIRTUAL_ROUTINES_START   0x0800
// Heap banks: 0xB700-0xD6FF is 0x2000 bytes, i.e. 128 banks of 64 bytes (the count main() initialises).
// NOTE(review): these addresses lie beyond the 0x1000-byte Machine.memory array.
#define HEAP_BANKS_START         0xB700
#define HEAP_BANKS_END           0xD6FF

// Sizing constants
// Capacity of Machine.heap; only the first 128 entries (one per heap bank) are used by visible code.
#define MAX_ALLOCATIONS          1000
// Size of the register file (x0-x31).
#define NUM_REGISTERS            32

// Bookkeeping for one 64-byte heap bank; main() assigns bank i the address 0xb700 + i*64.
typedef struct {
    uint8_t allocated;          // 1 while handed out by the malloc virtual routine, 0 when free
    uint32_t start_addr;        // Emulated address of the first byte of this bank
    uint64_t mem64[8];          // 64 bytes of storage (8 x 64-bit cells); NOTE(review): zeroed in main but no visible load/store uses it
} block;

// Complete state of the emulated RISK-XVII machine.
typedef struct {
    uint8_t memory[0x1000];           // Byte-addressable memory for addresses 0x0000-0x0FFF (instruction + data images, virtual-routine area)
    int32_t registers[NUM_REGISTERS]; // Register file x0-x31; x0 is intended to always read as zero
    block heap[MAX_ALLOCATIONS];      // Heap bank bookkeeping; only heap[0..127] are initialised and searched
    int num_allocations;              // NOTE(review): not read or written by any visible code
    uint32_t pc;                      // Program counter: byte address of the next instruction to execute
} Machine;

// Decode and execute one 32-bit instruction word, updating registers, memory and machine->pc.
void execute_instruction(Machine* machine, uint32_t instruction);

// Service a load/store that targets a memory-mapped virtual routine at `address`.
// For stores (is_store == true) *value holds the stored value; for loads *value receives the result.
void virtual_routine(Machine* machine, uint32_t address, bool is_store, int32_t* value);


int main(int argc, char *argv[]) {
    // Check the number of command-line arguments
    if (argc != 2) {
        printf("Usage: %s <binary_file>\n", argv[0]);
        return 1;
    }

    // Initialize the machine struct with zero values
    Machine machine = {0};

    // Open the binary file for reading
    FILE *file = fopen(argv[1], "rb");
    if (!file) {
        printf("Error opening file: %s\n", argv[1]);
        return 1;
    }

    // Determine the size of the binary file
    fseek(file, 0, SEEK_END); //moves the file position indicator to the end of the file.
    size_t file_size = ftell(file); // retrieves the current position of the file position indicator (size of the file in bytes)
    fseek(file, 0, SEEK_SET); // resets the file position indicator back to the beginning of the file

    // Check if the binary file has the correct size
    if (file_size != 2048) {
        printf("Error: binary file must be exactly 2048 bytes\n");
        fclose(file);
        return 1;
    }

    // Read the binary file into the machine's memory
    fread(machine.memory + INSTRUCTION_MEMORY_START, 1, 1024, file);
    fread(machine.memory + DATA_MEMORY_START, 1, 1024, file);
    fclose(file);

    // Initialize the program counter
    machine.pc = 0;

    // Initialize heap memory blocks
    for (int i = 0; i < 128; i++) {
        machine.heap[i].allocated = 0;
        memset(machine.heap[i].mem64, 0, sizeof(uint64_t) * 8 );
        machine.heap[i].start_addr = 0xb700 + (i*64);
    }

    // Execute instructions until the program counter reaches the end of the instruction memory
    while (machine.pc <= INSTRUCTION_MEMORY_END) {
        uint32_t instruction = *((uint32_t *)(machine.memory + machine.pc));
        execute_instruction(&machine, instruction);
    }

    return 0;
}

// Execute a single RISC-XVII instruction
void execute_instruction(Machine* machine, uint32_t instruction) {
    // Extract fields from the instruction
    uint32_t opcode = instruction & 0x7F;
    uint32_t rd = (instruction >> 7) & 0x1F;
    uint32_t func3 = (instruction >> 12) & 0x07;
    uint32_t rs1 = (instruction >> 15) & 0x1F;
    uint32_t rs2 = (instruction >> 20) & 0x1F;
    uint32_t func7 = (instruction >> 25) & 0x7F;

    int32_t imm = 0;
    int32_t address = 0;
    int32_t value = 0;

    // Implement RISK-XVII instruction set
    switch (opcode) {
        case 0x13: // I-Type (addi, andi, ori, xori, slti, sltiu, slli, srli, srai)
            // Sign extend the immediate value if the 31st bit is set
            if ((instruction >> 31) & 0x1) {
                imm = 0xFFFFFFFF; 
            }
            // Extract the immediate value from the instruction
            imm = (imm << 11) | ((instruction >> 20) & 0x7FF);

            // Process I-Type instructions based on the func3 field
            if (func3 == 0x0) { // addi
                machine->registers[rd] = machine->registers[rs1] + imm;
                machine->pc += 4;
                // Ensure register 0 always has the value 0
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
            } else if (func3 == 0x1) { // slli
                machine->registers[rd] = machine->registers[rs1] << (imm & 0x1F);
                machine->pc += 4;
                // Ensure register 0 always has the value 0
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }

				//printf("done slli\n");
            } else if (func3 == 0x2) { // slti
                machine->registers[rd] = (machine->registers[rs1] < imm) ? 1 : 0;
                machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
				//printf("done slti\n");
            } else if (func3 == 0x3) { // sltiu
                machine->registers[rd] = (((uint32_t)machine->registers[rs1]) < ((uint32_t)imm)) ? 1 : 0;
                machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
				//printf("done sltiu\n");
            } else if (func3 == 0x4) { // xori
                machine->registers[rd] = machine->registers[rs1] ^ imm;
                machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
				//printf("done xori\n");
            } else if (func3 == 0x5) { // srli, srai
                if (func7 == 0x00) { // srli
                    machine->registers[rd] = ((uint32_t)machine->registers[rs1]) >> (imm & 0x1F);
                    machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
					//printf("done srli\n");
                } else if (func7 == 0x20) { // srai
                    machine->registers[rd] = machine->registers[rs1] >> (imm & 0x1F);
                    machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
					//printf("done srai\n");
                }
            } else if (func3 == 0x6) { // ori
                machine->registers[rd] = machine->registers[rs1] | imm;
                machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
				//printf("done ori\n");
            } else if (func3 == 0x7) { // andi
                machine->registers[rd] = machine->registers[rs1] & imm;
                machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
				//printf("done andi\n");
            }
            break;
        case 0x33: // R-Type (add, sub, and, or, xor, slt, sltu, sll, srl, sra)
            switch (func3) {
            case 0x0:
                if (func7 == 0x00) { // add
                    machine->registers[rd] = machine->registers[rs1] + machine->registers[rs2];
                    machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("%d\n", machine->registers[rs1]);
                    //printf("done add\n");
                } else if (func7 == 0x20) { // sub
                    machine->registers[rd] = machine->registers[rs1] - machine->registers[rs2];
                    machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("done sub\n");
                }
                break;
            case 0x1: // sll
                machine->registers[rd] = machine->registers[rs1] << (machine->registers[rs2] & 0x1F);
                machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
				//printf("done sll\n");
                break;
            case 0x2: // slt
                machine->registers[rd] = (machine->registers[rs1] < machine->registers[rs2]) ? 1 : 0;
                machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
				//printf("done slt\n");
                break;
            case 0x3: // sltu
                machine->registers[rd] = (((uint32_t)machine->registers[rs1]) < ((uint32_t)machine->registers[rs2])) ? 1 : 0;
				machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
                //printf("done sltu\n");
                break;
            case 0x4: // xor
                machine->registers[rd] = machine->registers[rs1] ^ machine->registers[rs2];
				machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
                //printf("done xor\n");
                break;
            case 0x5:
                if (func7 == 0x00) { // srl
                    machine->registers[rd] = ((uint32_t)machine->registers[rs1]) >> (machine->registers[rs2] & 0x1F);
					machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("done srl\n");
                } else if (func7 == 0x20) { // sra
                    machine->registers[rd] = machine->registers[rs1] >> (machine->registers[rs2] & 0x1F);
					machine->pc += 4;
                    //printf("done sra\n");
                }
                break;
            case 0x6: // or
                machine->registers[rd] = machine->registers[rs1] | machine->registers[rs2];
				machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
                //printf("done or\n");
                break;
            case 0x7: // and
                machine->registers[rd] = machine->registers[rs1] & machine->registers[rs2];
				machine->pc += 4;
                if (rd == 0) {
                    machine->registers[rd] = 0;
                }
                //printf("done and\n");
                break;
            }
            break;
            
        // Load instructions (lb, lh, lw, lbu, lhu)
        case 0x03:
            if ((instruction >> 31) & 0x1) {
                imm = 0xFFFFFFFF; 
            }
            imm = (imm << 11) | ((instruction >> 20) & 0x7FF);
            // imm = (int32_t)((instruction >> 20) & 0xFFF);
            address = machine->registers[rs1] + imm;
            switch (func3) {
                case 0x0: // lb
                    if ((machine->memory[address] >> 7) & 0x1) {
                    rd = 0xFFFFFFFF;
                    } 
                    machine->registers[rd] = machine->registers[rd] << 8;
                    machine->registers[rd] = machine->registers[rd] | machine->memory[address];
                    //machine->registers[rd] = (int32_t)(int8_t)machine->memory[address];
					//machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("done lb\n");
                    break;
                case 0x1: // lh
                   if ((machine->memory[address + 1] >> 7) & 0x1) {
                    rd = 0xFFFFFFFF;
                    } 
                    machine->registers[rd] = machine->registers[rd] << 8;
                    machine->registers[rd] = (machine->registers[rd] | machine->memory[address + 1]) << 8 | machine->memory[address];
                    //machine->registers[rd] = (int32_t)(int16_t)(machine->memory[address] | (machine->memory[address + 1] << 8));
					//machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("done lh\n");
                    break;
                case 0x2: // lw
                    machine->registers[rd] = (machine->memory[address + 3] << 8 | (machine->memory[address + 2] << 8) | (machine->memory[address + 1] << 8) | (machine->memory[address]));
                    //machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("imm-%d\n", imm);
                    //printf("done lw\n");
					break;
                case 0x4: // lbu
                    machine->registers[rd] = (uint32_t)machine->memory[address];
                    //machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("imm-%d\n", imm);
                    //printf("done lbu\n");
					break;
                case 0x5: // lhu
                    machine->registers[rd] = (uint32_t)(machine->memory[address] | (machine->memory[address + 1] << 8));
					//machine->pc += 4;
                    if (rd == 0) {
                        machine->registers[rd] = 0;
                    }
                    //printf("done lhu\n");
                    break;
                }
                if (address >= VIRTUAL_ROUTINES_START && address <= HEAP_BANKS_START) {
                    virtual_routine(machine, address, (opcode == 0x23), &value);
                    machine->registers[rd] = value;
                }
                machine->pc += 4;
                break;

        // Store instructions (sb, sh, sw) S instructions
        case 0x23:
            if ((instruction >> 31) & 0x1) {
                    imm = 0xFFFFFFFF; 
                }
            uint32_t imm11_to_5 = (instruction >> 25) & 0x7F;
            uint32_t imm4_to_0 = (instruction >> 7) & 0x1F;
            imm = (int32_t)((((instruction >> 31) & 0x1) << 11) | (imm11_to_5 << 5) | imm4_to_0);
            imm = (imm << 20) >> 20; // Sign-extend the 12-bit immediate value
            //imm = (int32_t)(((instruction >> 25) << 5) | ((instruction >> 7) & 0x1F));
            address = machine->registers[rs1] + imm;
            // printf("func-%x\n", address);
            switch (func3) {
                case 0x0: // sb
                    machine->memory[address] = (uint8_t)(machine->registers[rs2] & 0xFF);
                    value = machine->memory[address];
                    //printf("done sb\n");
                    break;
                case 0x1: // sh
                    machine->memory[address] = (uint8_t)(machine->registers[rs2] & 0xFF);
                    machine->memory[address + 1] = (uint8_t)((machine->registers[rs2] >> 8) & 0xFF);
                    value = ((machine->memory[address + 1] << 8) | machine->memory[address]);
                    //printf("done sh\n");
                    break;
                case 0x2: // sw
                    //machine->memory[address] = (uint8_t)(machine->registers[rs2]);
                    machine->memory[address] = (uint8_t)(machine->registers[rs2] & 0xFF);
                    machine->memory[address + 1] = (uint8_t)((machine->registers[rs2] >> 8) & 0xFF);
                    machine->memory[address + 2] = (uint8_t)((machine->registers[rs2] >> 16) & 0xFF);
                    machine->memory[address + 3] = (uint8_t)((machine->registers[rs2] >> 24) & 0xFF);
                    // value = (machine->memory[address + 3] << 24 |
                    //         (machine->memory[address + 2] << 16) |
                    //         (machine->memory[address + 1] << 8) |
                    //         (machine->memory[address]));
                    value = machine->registers[rs2];

                    // value = (machine->memory[address] |
                    //          (machine->memory[address + 1] << 8) |
                    //          (machine->memory[address + 2] << 16) |
                    //          (machine->memory[address + 3] << 24));
                    
                    //printf("done sw\n");
                    break;
                }
                if (address >= VIRTUAL_ROUTINES_START && address < HEAP_BANKS_START) {
                    virtual_routine(machine, address, (opcode == 0x23), &value);
                }
                // else if (address >= HEAP_BANKS_START && address <= HEAP_BANKS_END) {

                //     virtual_routine(machine, address, (opcode == 0x23), &value);
                // }
                    machine->pc += 4;
                    break;
        
        
        // Branch instructions (beq, bne, blt, bge, bltu, bgeu) SB instructions
        case 0x63:
            //imm = (int32_t)((((instruction >> 31) & 0x1) << 12) | (((instruction >> 7) & 0x1E) << 5) | (((instruction >> 25) & 0x3F) << 5) | (((instruction >> 8) & 0xF) << 1));
            uint32_t imm12 = ((instruction >> 31) & 0x1);
            uint32_t imm11 = (instruction >> 7) & 0x1;
            uint32_t imm10_to_5 = (instruction >> 25) & 0x3F;
            uint32_t imm4_to_1 = (instruction >> 8) & 0x1F;
            if (imm12 & 0x1) {
                imm = 0xFFFFFFFF;
            }
            imm = (imm << 1) | (imm12);
            imm = (imm << 1) | (imm11);
            imm = (imm << 6) | imm10_to_5;
            imm = (imm << 4) | imm4_to_1;
            imm = imm << 1;
            //printf("checking sb\n");
            switch (func3) {
                case 0x0: // beq
                    //printf("1-%u 2-%u 3-%d\n", machine->registers[rs1], machine->registers[rs2], imm);
                    if (machine->registers[rs1] == machine->registers[rs2]) {
                        machine->pc += imm;
                        //printf("done beq\n");
                    }
                    else {
                        //printf("imm = %d\n", imm);
                        machine->pc += 4;
                        //printf("done beq2\n");
                    }
                    // machine->pc += imm;
                    // printf("done beq\n");
                    break;
                case 0x1: // bne
                    if (machine->registers[rs1] != machine->registers[rs2]) {
                        machine->pc += imm;
                        //printf("%d\n", imm);
						//printf("done bne\n");
                    }
                    else {
                        machine->pc += 4;
                    }
                    break;
                case 0x4: // blt
                    if (machine->registers[rs1] < machine->registers[rs2]) {
                        machine->pc += imm;
						//printf("done blt\n");
                    } else {
                        machine->pc += 4;
                    }
                    break;
                case 0x5: // bge
                    if (machine->registers[rs1] >= machine->registers[rs2]) {
                        machine->pc += imm;
						//printf("done bge\n");
                    } else {
                        machine->pc += 4;
                    }
                    break;
                case 0x6: // bltu
                    if (((uint32_t)machine->registers[rs1]) < ((uint32_t)machine->registers[rs2])) {
                        machine->pc += imm;
						//printf("done bltu\n");
                    } else {
                        machine->pc += 4;
                    }
                    break;
                case 0x7: // bgeu
                    if (((uint32_t)machine->registers[rs1]) >= ((uint32_t)machine->registers[rs2])) {
                        machine->pc += imm;
						//printf("done bgeu\n");
                    } else {
                        machine->pc += 4;
                    }
                    break;
                }
                break;
            
        // JAL instruction (0x6F)
        case 0x6F:
            
            uint32_t imm20 = (instruction >> 31) & 0x1;
            uint32_t imm19_to_12 = (instruction >> 12) & 0xff;
            uint32_t imm11__ = (instruction >> 20) & 0x1;
            uint32_t imm10_to_1 = (instruction >> 21) & 0x3ff;
            imm |= (imm10_to_1);
            imm |= (imm11__ << 10);
            imm |= (imm19_to_12 << 11);
            imm |= (imm20 << 19);
            imm = (int32_t)((((instruction >> 31) & 0x1) << 20) | (((instruction >> 12) & 0xFF) << 12) | (((instruction >> 20) & 0x1) << 11) | (((instruction >> 21) & 0x3FF) << 1));

            if ((imm >> 20  & 0x1) == 1) {
                imm |= 0xffe00000;
            }
            machine->registers[rd] = machine->pc + 4; // Save the return address
            machine->pc = machine->pc + imm; // Jump to the target address
            if (rd == 0) {
                machine->registers[rd] = 0;
            }
            // printf("imm-%d\n", imm);
            // printf("done JAL\n");
            break;

        // JALR instruction (0x67) I type
        case 0x67:
            if ((instruction >> 31) & 0x1) {
                imm = 0xFFFFFFFF; 
            }
            imm = (imm << 11) | ((instruction >> 20) & 0x7FF);
            // imm = (int32_t)((instruction >> 20) & 0xFFF);
            uint32_t temp = machine->registers[rs1] + imm;
            machine->registers[rd] = machine->pc + 4; // Save the return address
            if (rd == 0) {
                machine->registers[rd] = 0;
            }
            machine->pc = temp; // Jump to the target address (force the least significant bit to 0)
            //printf("done JALR\n");
            break;

        case 0x37: // U-Type (lui)
            imm = (int32_t)((instruction >> 12) & 0xFFFFF);
            machine->registers[rd] = (imm << 12);
            machine->pc += 4;
            // printf("imm-%d\n", imm);
            // printf("done lui\n");
            break;
        

        default:
        // Handle undefined opcodes or other unimplemented instructions
            printf("Instruction Not Implemented: 0x%x\n", instruction);
            printf("PC = 0x%08x;\n", machine->pc);
            for(int i = 0; i < 32; i++) {
                printf("R[%d] = 0x%08x;\n", i, machine->registers[i]);
            }
            //printf("Error: Unimplemented or undefined opcode (0x%X) encountered.\n", opcode);
            exit(1);
    }
    // Check if the instruction accesses a virtual routine
    // if (address >= VIRTUAL_ROUTINES_START) {
    //     virtual_routine(machine, address, (opcode == 0x23), &value);
    // }
}

// Service a load or store that targets a memory-mapped virtual routine.
//
// machine  - emulator state; read by the dump routines, updated by malloc/free
//            (heap bank bookkeeping and R[28]).
// address  - effective address of the access.
// is_store - true for a store (sb/sh/sw), false for a load (lb/lh/lw/lbu/lhu).
// value    - for stores, the value being written; for loads, receives the value
//            read. Routines that only react to stores leave it untouched.
//
// Halt exits with status 0. An unknown address, or a Dump Memory Word request
// for an address outside machine->memory, prints an error and exits with 1.
void virtual_routine(Machine* machine, uint32_t address, bool is_store, int32_t* value) {
    // The heap region 0xB700-0xD6FF is split into 64-byte banks (128 of them).
    const int heap_banks = (HEAP_BANKS_END - HEAP_BANKS_START + 1) / 64;

    switch (address) {
        case 0x0800: // Console Write Character
            if (is_store) {
                printf("%c", (char)*value);
            }
            break;
        case 0x0804: // Console Write Signed Integer
            if (is_store) {
                printf("%d", *value);
            }
            break;
        case 0x0808: // Console Write Unsigned Integer (printed in hex)
            if (is_store) {
                printf("%x", (uint32_t)*value);
            }
            break;
        case 0x080C: // Halt
            if (is_store) {
                printf("CPU Halt Requested\n");
                exit(0);
            }
            break;
        case 0x0812: // Console Read Character
            if (!is_store) {
                // One raw character from stdin; EOF (-1) is passed through as the result.
                *value = getchar();
            }
            break;
        case 0x0816: // Console Read Signed Integer
            if (!is_store && scanf("%d", value) != 1) {
                // Malformed input or EOF: yield 0 rather than whatever *value held.
                *value = 0;
            }
            break;
        case 0x0820: // Dump PC
            if (is_store) {
                printf("PC: %08x\n", machine->pc);
            }
            break;
        case 0x0824: // Dump Register Banks
            if (is_store) {
                printf("Register Dump:\n");
                for (int i = 0; i < NUM_REGISTERS; i++) {
                    printf("R%d: %08x\n", i, (uint32_t)machine->registers[i]);
                }
            }
            break;
        case 0x0828: // Dump Memory Word at the stored address
            if (is_store) {
                const uint32_t target = (uint32_t)*value;
                // The word must lie inside the memory array; reading through a
                // uint32_t * cast would also break strict aliasing and alignment.
                if (target > sizeof machine->memory - 4) {
                    printf("Error: Memory Dump address (0x%08X) out of range.\n", target);
                    exit(1);
                }
                const uint8_t *bytes = machine->memory + target;
                const uint32_t word = (uint32_t)bytes[0]
                                    | ((uint32_t)bytes[1] << 8)
                                    | ((uint32_t)bytes[2] << 16)
                                    | ((uint32_t)bytes[3] << 24);
                printf("Memory Dump (%08x): %08x\n", target, word);
            }
            break;
        case 0x0830: // malloc: R[28] <- start of the first free 64-byte bank, or 0 if none
            // NOTE(review): the requested size (*value) is not consulted; each call hands out one bank.
            if (is_store) {
                machine->registers[28] = 0;
                for (int i = 0; i < heap_banks; i++) {
                    if (machine->heap[i].allocated == 0) {
                        machine->heap[i].allocated = 1;
                        machine->registers[28] = (int32_t)machine->heap[i].start_addr;
                        break;
                    }
                }
            }
            break;
        case 0x0834: // free: release the allocated bank that starts at the stored address
            if (is_store) {
                // start_addr is left intact so the bank can be handed out again later.
                // NOTE(review): freeing an address that is not allocated is silently ignored.
                for (int i = 0; i < heap_banks; i++) {
                    if (machine->heap[i].allocated && machine->heap[i].start_addr == (uint32_t)*value) {
                        machine->heap[i].allocated = 0;
                        break;
                    }
                }
            }
            break;
        default:
            printf("Error: Unimplemented or undefined virtual routine (0x%X) encountered.\n", address);
            exit(1);
            break;
    }
}