#!/usr/bin/env python3
"""
Math Calculator with Step-by-Step Solutions
Uses SymPy for symbolic mathematics
"""

import sys
import argparse
import re as regex_module
from sympy import *
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
import matplotlib.pyplot as plt
import numpy as np

# Restore re module (sympy's star import shadows it)
re = regex_module

def preprocess_expr(text):
    """Rewrite caret exponentiation into Python/SymPy ``**`` notation.

    Args:
        text: Expression text that may use ``^`` for powers.

    Returns:
        The same text with every ``^`` replaced by ``**``.
    """
    # A one-entry translation table performs the substitution in a single pass.
    caret_to_power = {ord('^'): '**'}
    return text.translate(caret_to_power)

def parse_problem(text):
    """Classify a natural-language math request and extract its operands.

    The text is lower-cased and stripped, then tested against the command
    patterns in a fixed priority order: solve, derivative/differentiate,
    integrate/integral, simplify, factor, graph/plot, calculate/compute.
    The first pattern that matches wins.  A keyword that only appears as a
    substring (for example ``factor`` inside ``factorial``) does not stop the
    later patterns from being tried.

    Args:
        text: Raw problem statement, e.g. ``"solve x^2 - 4 = 0 for x"``.

    Returns:
        A dict whose ``'type'`` is one of ``'solve'``, ``'derivative'``,
        ``'integrate'``, ``'simplify'``, ``'factor'``, ``'graph'`` or
        ``'evaluate'``, and whose ``'expr'`` is the extracted expression text.
        Extra keys by type: ``'var'`` (solve, derivative), ``'lower'`` and
        ``'upper'`` (integrate; ``None`` when no bounds were given),
        ``'x_min'`` and ``'x_max'`` (graph; default ``'-10'`` and ``'10'``).
        Text matching no pattern is returned whole as an ``'evaluate'`` request.
    """
    text = text.lower().strip()

    found = re.search(r'solve\s+(.+?)(?:\s+for\s+(\w+))?$', text)
    if found:
        return {'type': 'solve', 'expr': found.group(1), 'var': found.group(2) or 'x'}

    found = re.search(
        r'(?:derivative|differentiate)\s+(?:of\s+)?(.+?)(?:\s+with respect to\s+(\w+))?$',
        text,
    )
    if found:
        return {'type': 'derivative', 'expr': found.group(1), 'var': found.group(2) or 'x'}

    # Accept both "integrate ..." and "integral of ..."; bounds are optional.
    found = re.search(
        r'(?:integrate|integral)\s+(?:of\s+)?(.+?)(?:\s+from\s+(.+?)\s+to\s+(.+?))?$',
        text,
    )
    if found:
        return {
            'type': 'integrate',
            'expr': found.group(1),
            'lower': found.group(2),
            'upper': found.group(3),
        }

    found = re.search(r'simplify\s+(.+)$', text)
    if found:
        return {'type': 'simplify', 'expr': found.group(1)}

    # Requires whitespace after "factor", so "factorial(5)" is not a factor request.
    found = re.search(r'factor\s+(.+)$', text)
    if found:
        return {'type': 'factor', 'expr': found.group(1)}

    # The "x =" prefix on each bound is optional: "from -5 to 5" works too.
    found = re.search(
        r'(?:graph|plot)\s+(?:y\s*=\s*)?(.+?)'
        r'(?:\s+from\s+(?:x\s*=\s*)?(.+?)\s+to\s+(?:x\s*=\s*)?(.+?))?$',
        text,
    )
    if found:
        return {
            'type': 'graph',
            'expr': found.group(1),
            'x_min': found.group(2) or '-10',
            'x_max': found.group(3) or '10',
        }

    found = re.search(r'(?:calculate|compute)\s+(.+)$', text)
    if found:
        return {'type': 'evaluate', 'expr': found.group(1)}

    # Default: treat the whole text as an expression to evaluate.
    return {'type': 'evaluate', 'expr': text}

def solve_equation(expr_text, var_name='x', show_steps=True):
    """Solve an equation for one variable, optionally narrating the steps.

    The text may be an equation (``lhs = rhs``; ``==`` is accepted as a
    synonym) or a bare expression, which is taken to equal zero.  Both forms
    are reduced to ``expr = 0`` and handed to SymPy's ``solve``.

    Args:
        expr_text: Equation or expression text; ``^`` is accepted for powers.
        var_name: Name of the symbol to solve for.
        show_steps: When true, print the classification/factoring steps and a
            verification of each solution.

    Returns:
        The list of solutions from ``solve`` (possibly empty), or ``None`` if
        parsing or solving raised an error (the error is printed).
    """
    try:
        var = symbols(var_name)
        expr_text = preprocess_expr(expr_text)

        if any(op in expr_text for op in ('<=', '>=', '!=')):
            raise ValueError("inequalities are not supported; use a single '='")

        # Normalise "==" to "=", then allow at most one equals sign so that a
        # malformed equation gives a clear message instead of an unpack error.
        sides = [side.strip() for side in expr_text.replace('==', '=').split('=')]
        if len(sides) > 2:
            raise ValueError("expected at most one '=' in the equation")
        if not all(sides):
            raise ValueError("the equation is empty or has an empty side")

        parsed = [parse_expr(side, transformations='all') for side in sides]
        expr = parsed[0] - parsed[1] if len(parsed) == 2 else parsed[0]

        print(f"\nProblem: Solve {expr} = 0\n")

        if show_steps:
            # Classify the equation so the narration matches what is being done.
            degree = Poly(expr, var).degree() if expr.is_polynomial(var) else None

            if degree == 1:
                print("Step 1: This is a linear equation")
                print(f"        Simplify: {expr} = 0\n")
                print(f"Step 2: Isolate {var}")

            elif degree == 2:
                a, b, c = Poly(expr, var).all_coeffs()
                print("Step 1: This is a quadratic equation in the form ax² + bx + c = 0")
                print(f"        where a={a}, b={b}, c={c}\n")

                factored = factor(expr)
                if factored != expr:
                    print("Step 2: Factor the expression")
                    print(f"        {expr} = {factored}\n")
                    print("Step 3: Set each factor to zero")
                else:
                    print("Step 2: Apply the quadratic formula")

            else:
                print(f"Step 1: Solve for {var}")

        solutions = solve(expr, var)

        if not solutions:
            # solve() returns [] both for "no solution" and for identities
            # such as x = x; tell the two apart for the user.
            if simplify(expr) == 0:
                print(f"\nSolutions: every value of {var} satisfies the equation\n")
            else:
                print("\nSolutions: none found\n")
            return solutions

        if show_steps:
            for sol in solutions:
                print(f"        {var} = {sol}")

        print(f"\nSolutions: {var} = {', '.join(str(s) for s in solutions)}\n")

        if show_steps:
            print("Verification:")
            for sol in solutions:
                # Substitution alone leaves radicals/complex terms unexpanded,
                # so simplify before comparing the residual with zero.
                residual = simplify(expr.subs(var, sol))
                check = "✓" if residual == 0 else "✗"
                print(f"  When {var}={sol}: {expr} = {residual} {check}")

        return solutions

    except Exception as e:
        # CLI boundary: report the failure and signal it with None.
        print(f"Error solving equation: {e}")
        return None

def find_derivative(expr_text, var_name='x', show_steps=True):
    """Differentiate an expression, optionally showing a term-by-term breakdown.

    Args:
        expr_text: Expression text; ``^`` powers and implicit multiplication
            such as ``2x`` are accepted.
        var_name: Name of the variable to differentiate with respect to.
        show_steps: When true, print the rule being applied and the derivative
            of each additive term before the combined result.

    Returns:
        The derivative as a SymPy expression, or ``None`` if parsing or
        differentiation raised an error (the error is printed).
    """
    try:
        var = symbols(var_name)
        # Same parser settings as solve_equation, so "x^3 + 2x^2" is accepted.
        expr = parse_expr(preprocess_expr(expr_text.strip()), transformations='all')

        print(f"\nProblem: Find the derivative of {expr}\n")

        if show_steps:
            # Only advertise the power rule when it is the rule actually used.
            if expr.is_polynomial(var):
                print(f"Step 1: Apply the power rule: d/d{var}[{var}^n] = n·{var}^(n-1)\n")
            else:
                print("Step 1: Apply the standard differentiation rules (sum, product, chain)\n")

            print("Step 2: Differentiate each term")
            # Differentiation is linear, so the per-term view is valid for any sum.
            for term in Add.make_args(expr):
                print(f"        d/d{var}[{term}] = {diff(term, var)}")

            print("\nStep 3: Combine results")

        result = diff(expr, var)
        print(f"        f'({var}) = {result}\n")
        print(f"Answer: {result}")

        return result

    except Exception as e:
        # CLI boundary: report the failure and signal it with None.
        print(f"Error finding derivative: {e}")
        return None

def integrate_expr(expr_text, lower=None, upper=None, show_steps=True):
    """Compute an indefinite or definite integral of an expression.

    The integration variable is ``x`` unless the expression does not contain
    ``x`` and has exactly one free symbol, in which case that symbol is used
    (so ``t^2`` is integrated with respect to ``t``).

    Args:
        expr_text: Integrand text; ``^`` powers and implicit multiplication
            are accepted.
        lower: Lower bound (text or number), or ``None`` for an indefinite
            integral.
        upper: Upper bound (text or number), or ``None`` for an indefinite
            integral.
        show_steps: When true, print the intermediate steps.

    Returns:
        The definite value when both bounds are given, otherwise the
        antiderivative (without the constant), or ``None`` if an error was
        raised (the error is printed).  Supplying only one bound is reported
        as an error.
    """
    try:
        # Same parser settings as solve_equation, so "2x" and "x^2" are accepted.
        expr = parse_expr(preprocess_expr(expr_text.strip()), transformations='all')

        var = symbols('x')
        free = expr.free_symbols
        if var not in free and len(free) == 1:
            var = next(iter(free))

        if (lower is None) != (upper is None):
            raise ValueError("both a lower and an upper bound are required")
        definite = lower is not None

        if definite:
            print(f"\nProblem: Integrate {expr} from {lower} to {upper}\n")
        else:
            print(f"\nProblem: Find the indefinite integral of {expr}\n")

        if show_steps:
            print("Step 1: Apply integration rules")

        if not definite:
            result = integrate(expr, var)
            print(f"\nAnswer: {result} + C")
            return result

        # Bounds go through the same preprocessing as the integrand so that
        # "2^3" means a power rather than Python's XOR.
        lower_val = parse_expr(preprocess_expr(str(lower).strip()), transformations='all')
        upper_val = parse_expr(preprocess_expr(str(upper).strip()), transformations='all')
        definite_value = integrate(expr, (var, lower_val, upper_val))

        if show_steps:
            # The antiderivative is only needed for display, so compute it here.
            antiderivative = integrate(expr, var)
            print("Step 2: Evaluate at bounds")
            print(f"        F({var}) = {antiderivative}")
            print(f"        F({upper}) - F({lower}) = {definite_value}")

        print(f"\nAnswer: {definite_value}")
        return definite_value

    except Exception as e:
        # CLI boundary: report the failure and signal it with None.
        print(f"Error integrating: {e}")
        return None

def simplify_expr(expr_text, show_steps=True):
    """Simplify an expression and print the outcome.

    Args:
        expr_text: Expression text; ``^`` powers and implicit multiplication
            are accepted.
        show_steps: Accepted for signature parity with the other operations.
            The result is printed regardless, so that hiding steps never
            hides the answer.

    Returns:
        The simplified SymPy expression, or ``None`` if parsing or
        simplification raised an error (the error is printed).
    """
    try:
        # Same parser settings as solve_equation, so "2x" and "x^2" are accepted.
        expr = parse_expr(preprocess_expr(expr_text.strip()), transformations='all')
        print(f"\nProblem: Simplify {expr}\n")

        result = simplify(expr)

        if result == expr:
            print("Expression is already in simplest form")
        else:
            print(f"Simplified: {result}")

        return result

    except Exception as e:
        # CLI boundary: report the failure and signal it with None.
        print(f"Error simplifying: {e}")
        return None

def factor_expr(expr_text, show_steps=True):
    """Factor an expression and print the outcome.

    Args:
        expr_text: Expression text; ``^`` powers and implicit multiplication
            such as ``5x`` are accepted.
        show_steps: Accepted for signature parity with the other operations;
            it does not change the output.

    Returns:
        The factored SymPy expression (equal to the input when nothing could
        be factored), or ``None`` if an error was raised (the error is
        printed).
    """
    try:
        # Same parser settings as solve_equation, so "x^2 - 5x + 6" is accepted.
        expr = parse_expr(preprocess_expr(expr_text.strip()), transformations='all')
        print(f"\nProblem: Factor {expr}\n")

        result = factor(expr)

        if result != expr:
            print(f"Factored: {result}")
        else:
            print("Expression cannot be factored further")

        return result

    except Exception as e:
        # CLI boundary: report the failure and signal it with None.
        print(f"Error factoring: {e}")
        return None

def evaluate_expr(expr_text, show_steps=True):
    """Evaluate an expression numerically and print the answer.

    Args:
        expr_text: Expression text; ``^`` powers and implicit multiplication
            are accepted.
        show_steps: Accepted for signature parity with the other operations;
            it does not change the output.

    Returns:
        The result of ``evalf()`` on the parsed expression, or ``None`` if
        parsing or evaluation raised an error (the error is printed).
    """
    try:
        source = expr_text.strip()
        # Same parser settings as solve_equation, so "2x" and "x^2" are accepted.
        expr = parse_expr(preprocess_expr(source), transformations='all')

        # parse_expr already folds plain arithmetic, so echo the user's text;
        # printing the parsed value would show the answer as the problem.
        print(f"\nProblem: Calculate {source}\n")

        result = expr.evalf()
        # Show exact integers as integers rather than "132678.000000000".
        shown = expr if expr.is_Integer else result
        print(f"Answer: {shown}")

        return result

    except Exception as e:
        # CLI boundary: report the failure and signal it with None.
        print(f"Error evaluating: {e}")
        return None

def graph_function(expr_text, x_min='-10', x_max='10', output_path='/tmp/math-plot.png'):
    """Plot ``y = expr`` over an interval of ``x`` and save it as an image.

    Args:
        expr_text: Expression in ``x``; ``^`` powers and implicit
            multiplication are accepted.
        x_min: Left end of the domain, as expression text (e.g. ``'-pi'``).
        x_max: Right end of the domain, as expression text.
        output_path: File path the figure is written to.

    Returns:
        ``output_path`` on success, or ``None`` if parsing, evaluation or
        saving raised an error (the error is printed).
    """
    try:
        var = symbols('x')
        # Same parser settings as solve_equation, so "2x" and "x^2" are accepted.
        expr = parse_expr(preprocess_expr(expr_text.strip()), transformations='all')

        x_min_val = float(parse_expr(preprocess_expr(str(x_min)), transformations='all'))
        x_max_val = float(parse_expr(preprocess_expr(str(x_max)), transformations='all'))

        f = lambdify(var, expr, 'numpy')
        x_vals = np.linspace(x_min_val, x_max_val, 1000)
        # A constant expression lambdifies to a function returning a scalar;
        # broadcast it so the y-data always matches the x-data in shape.
        y_vals = np.broadcast_to(np.asarray(f(x_vals)), x_vals.shape)

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(x_vals, y_vals, 'b-', linewidth=2)
            ax.axhline(y=0, color='k', linestyle='-', linewidth=0.5)
            ax.axvline(x=0, color='k', linestyle='-', linewidth=0.5)
            ax.grid(True, alpha=0.3)
            ax.set_xlabel('x', fontsize=12)
            ax.set_ylabel('y', fontsize=12)
            ax.set_title(f'y = {expr}', fontsize=14)
            fig.tight_layout()
            fig.savefig(output_path, dpi=150)
        finally:
            # Release the figure so repeated plots in interactive mode do not
            # accumulate open figures.
            plt.close(fig)

        print(f"\nGraph saved to: {output_path}")
        print(f"Function: y = {expr}")
        print(f"Domain: x ∈ [{x_min_val}, {x_max_val}]")

        return output_path

    except Exception as e:
        # CLI boundary: report the failure and signal it with None.
        print(f"Error graphing: {e}")
        return None

def interactive_mode():
    """Run a read-evaluate-print loop for math problems.

    Reads one problem per line from standard input and passes it to
    ``process_problem``.  ``quit``, ``exit`` or ``q`` leaves the loop, as do
    Ctrl-C and end of input (Ctrl-D or exhausted piped input).  ``help``
    prints example problems.  Errors raised while processing a problem are
    printed and the loop continues.
    """
    print("\n=== Math Calculator - Interactive Mode ===")
    print("Type 'help' for usage, 'quit' to exit\n")

    while True:
        try:
            problem = input("math> ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOF must end the session: input() would raise it again on every
            # subsequent call, so treating it as an ordinary error loops forever.
            print("\nGoodbye!")
            break

        if not problem:
            continue

        command = problem.lower()

        if command in ('quit', 'exit', 'q'):
            print("Goodbye!")
            break

        if command == 'help':
            print("\nExamples:")
            print("  solve x^2 - 5x + 6 = 0")
            print("  derivative of x^3 + 2x^2")
            print("  integrate sin(x) from 0 to pi")
            print("  simplify (x^2 - 4) / (x - 2)")
            print("  factor x^2 - 4")
            print("  plot x^2 from x = -5 to x = 5")
            print("  calculate 234 * 567")
            print()
            continue

        try:
            process_problem(problem)
            print()
        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            # Keep the session alive after a failed problem.
            print(f"Error: {e}\n")

def process_problem(text, show_steps=True, output_path='/tmp/math-plot.png'):
    """Parse a problem statement and run the matching operation.

    Args:
        text: Problem statement, handed to ``parse_problem``.
        show_steps: Forwarded to the operation; not used for graphs.
        output_path: Image path, used only for graph requests.
    """
    parsed = parse_problem(text)
    kind = parsed['type']

    # Each entry defers its call, so only the selected operation runs.
    dispatch = {
        'solve': lambda: solve_equation(parsed['expr'], parsed.get('var', 'x'), show_steps),
        'derivative': lambda: find_derivative(parsed['expr'], parsed.get('var', 'x'), show_steps),
        'integrate': lambda: integrate_expr(
            parsed['expr'], parsed.get('lower'), parsed.get('upper'), show_steps
        ),
        'simplify': lambda: simplify_expr(parsed['expr'], show_steps),
        'factor': lambda: factor_expr(parsed['expr'], show_steps),
        'graph': lambda: graph_function(
            parsed['expr'], parsed.get('x_min', '-10'), parsed.get('x_max', '10'), output_path
        ),
        'evaluate': lambda: evaluate_expr(parsed['expr'], show_steps),
    }

    run = dispatch.get(kind)
    if run is None:
        print(f"Unknown problem type: {parsed['type']}")
        return

    run()

def main():
    """Command-line entry point.

    With ``-i``/``--interactive`` the REPL is started.  Otherwise the
    positional words are joined into one problem statement and processed, or
    the usage text is printed when no problem was given.
    """
    cli = argparse.ArgumentParser(description='Math Calculator with Step-by-Step Solutions')
    cli.add_argument('problem', nargs='*', help='Math problem to solve')
    cli.add_argument('-i', '--interactive', action='store_true', help='Start interactive mode')
    cli.add_argument('--no-steps', action='store_true', help='Hide step-by-step solutions')
    cli.add_argument('--output', default='/tmp/math-plot.png', help='Output path for graphs')

    opts = cli.parse_args()

    # Interactive mode takes precedence over any problem on the command line.
    if opts.interactive:
        interactive_mode()
        return

    if not opts.problem:
        cli.print_help()
        return

    statement = ' '.join(opts.problem)
    process_problem(statement, not opts.no_steps, opts.output)


if __name__ == '__main__':
    main()
