import numpy as np

def print_tableau(tableau, basic_vars, non_basic_vars, step):
    """Print the simplex tableau for one iteration as an aligned text table.

    Columns are labelled by position: the first ``len(non_basic_vars)``
    columns are ``x1..xn``, the remaining coefficient columns are
    ``s1..sm`` and the last column is ``RHS``. Each constraint row is
    labelled with the variable currently basic in that row; the final row
    of ``tableau`` is printed as the ``Z`` row.

    Args:
        tableau: 2-D NumPy array whose last column is the right-hand side
            and whose last row is the objective row.
        basic_vars: Column index of the basic variable for each constraint
            row, in row order.
        non_basic_vars: Sequence whose length is taken as the number of
            decision (``x``) columns; its contents are not read.
        step: Iteration number shown in the banner.
    """
    print(f"\n=== Tableau at Iteration {step} ===")
    cols = tableau.shape[1]
    n_decision = len(non_basic_vars)

    def label(col):
        # A column keeps the same name for the whole run, whatever the
        # current basis is: decision variables first, then slacks.
        if col < n_decision:
            return f"x{col + 1}"
        return f"s{col - n_decision + 1}"

    header = ["BV"] + [label(j) for j in range(cols - 1)] + ["RHS"]
    header_line = " | ".join(f"{h:^8}" for h in header)
    rule = "-" * len(header_line)
    print(header_line)
    print(rule)

    for i, bv in enumerate(basic_vars):
        # Pad the label to the same width as the numeric cells so the rows
        # line up with the header.
        cells = [f"{label(bv):^8}"] + [f"{tableau[i, j]:8.3f}" for j in range(cols)]
        print(" | ".join(cells))
    print(rule)
    z_cells = [f"{'Z':^8}"] + [f"{tableau[-1, j]:8.3f}" for j in range(cols)]
    print(" | ".join(z_cells))


def is_optimal(tableau,ismin):
    """Report whether the objective row of ``tableau`` satisfies the stop test.

    Only the last row is inspected, with its final (right-hand side) entry
    left out. When ``ismin`` is falsy the test passes if every inspected
    entry is >= 0; when ``ismin`` is truthy it passes if every entry is <= 0.
    A message is printed when the test passes.

    Returns:
        True if the stop test passes, otherwise False.
    """
    objective_coeffs = tableau[-1, :-1]
    if ismin:
        satisfied = np.all(objective_coeffs <= 0)
    else:
        satisfied = np.all(objective_coeffs >= 0)

    if not satisfied:
        return False
    print("\nOptimal solution found!")
    return True


def simplex_method(c, A, b, ismin):
    """Solve a linear program with the tableau simplex method.

    The problem is to maximise (or, when ``ismin`` is true, minimise)
    ``c @ x`` subject to ``A @ x <= b`` and ``x >= 0``. One slack variable
    is added per constraint and the slacks form the starting basis. The
    tableau is printed before the first pivot and after every pivot.

    NOTE(review): the slack starting basis is only feasible when every
    entry of ``b`` is non-negative; this is not checked here.

    Args:
        c: Objective coefficients, length ``n``.
        A: Constraint matrix, ``m`` rows by ``n`` columns.
        b: Constraint right-hand sides, length ``m``.
        ismin: True to minimise the objective, False to maximise it.

    Returns:
        ``(solution, Z)`` where ``solution`` is a length-``n`` array of
        decision-variable values and ``Z`` is the objective value, or
        ``None`` if the problem is detected as unbounded.
    """
    A = np.array(A, dtype=float)
    b = np.array(b, dtype=float)
    c = np.array(c, dtype=float)

    m, n = A.shape
    tableau = np.zeros((m + 1, n + m + 1))

    tableau[:m, :n] = A
    tableau[:m, n:n + m] = np.eye(m)
    tableau[:m, -1] = b
    # Objective row encodes Z - c @ x = 0, so it starts as -c with RHS 0.
    tableau[-1, :n] = -c

    basic_vars = list(range(n, n + m))
    non_basic_vars = list(range(n))

    step = 0
    print_tableau(tableau, basic_vars, non_basic_vars, step)

    while not is_optimal(tableau, ismin):
        step += 1

        # With Z = RHS - sum(row_j * x_j), raising x_j increases Z when
        # row_j < 0 and decreases Z when row_j > 0. Maximisation therefore
        # enters the most negative column and minimisation the most
        # positive one, mirroring the stop test in is_optimal.
        objective_row = tableau[-1, :-1]
        if ismin:
            pivot_col = int(np.argmax(objective_row))
        else:
            pivot_col = int(np.argmin(objective_row))

        column = tableau[:m, pivot_col]
        if np.all(column <= 0):
            print("\nUnbounded solution!")
            return

        # Ratio test: only rows with a positive pivot-column entry limit
        # how far the entering variable can grow.
        ratios = np.full(m, np.inf)
        eligible = column > 0
        ratios[eligible] = tableau[:m, -1][eligible] / column[eligible]
        pivot_row = int(np.argmin(ratios))

        # Normalise the pivot row, then eliminate the pivot column from
        # every other row (objective row included).
        tableau[pivot_row, :] /= tableau[pivot_row, pivot_col]
        for i in range(m + 1):
            if i != pivot_row:
                tableau[i, :] -= tableau[i, pivot_col] * tableau[pivot_row, :]

        basic_vars[pivot_row] = pivot_col
        print_tableau(tableau, basic_vars, non_basic_vars, step)

    # Extract solution: decision variables outside the basis stay at zero.
    solution = np.zeros(n)
    for i, bv in enumerate(basic_vars):
        if bv < n:
            solution[bv] = tableau[i, -1]
    # Read Z unconditionally so it is defined even when no decision
    # variable is basic (optimum at the origin).
    Z = tableau[-1, -1]

    print("\nOptimal solution:")
    print("x =", solution)
    print("Z =", Z)
    return solution, Z

# Example usage
if __name__ == "__main__":
    # Maximize Z = 3x1 + 5x2
    # subject to:
    # 2x1 + 3x2 <= 8
    # 2x1 + x2 <= 6
    # x1, x2 >= 0
    c = [3, 5]
    A = [[2, 3], [2, 1]]
    b = [8, 6]

    # The call stays inside the guard: c, A and b only exist when the file
    # is run as a script, so importing the module must not run the solver.
    simplex_method(c, A, b, False)
