#!/usr/bin/env python3
"""
CSV Toolkit — analyze, filter, transform CSV/TSV files.
Pure Python stdlib. No external dependencies. Safe — never modifies source files.
"""

import argparse
import csv
import json
import sys
from collections import Counter
from pathlib import Path
from typing import List, Dict, Optional


# ── Loader ────────────────────────────────────────────────────────────────────

def detect_delimiter(path: Path) -> str:
    """Return the field delimiter for *path*: "\\t" for tab, "," for comma.

    A ``.tsv`` suffix forces tab. Otherwise the first 2 KiB of the file are
    scanned and whichever of tab/comma occurs more often wins (comma on a
    tie). Tab and comma are ASCII bytes, so counting the raw bytes gives the
    same answer as decoding first — and, unlike the decode-then-count
    approach, cannot leave a variable unbound or fail on odd encodings.
    """
    if path.suffix.lower() == ".tsv":
        return "\t"
    sample = path.read_bytes()[:2048]
    return "\t" if sample.count(b"\t") > sample.count(b",") else ","


def load_csv(path: Path) -> tuple[List[str], List[Dict]]:
    """Load a delimited file and return a ``(headers, rows)`` pair.

    Rows are dicts keyed by header name. Encodings are attempted in order:
    UTF-8 with optional BOM, plain UTF-8, then latin-1. If none succeeds,
    an error is printed to stderr and the process exits with status 1.
    """
    delimiter = detect_delimiter(path)
    for encoding in ("utf-8-sig", "utf-8", "latin-1"):
        try:
            with path.open(newline="", encoding=encoding) as handle:
                reader = csv.DictReader(handle, delimiter=delimiter)
                records = list(reader)
                return list(reader.fieldnames or []), records
        except UnicodeDecodeError:
            continue
    print("❌ Could not decode file", file=sys.stderr)
    sys.exit(1)


# ── Filters ───────────────────────────────────────────────────────────────────

def apply_filter(rows: List[Dict], expr: str) -> List[Dict]:
    """Apply one filter expression to *rows* and return the matching subset.

    Supported forms, checked in this order (first operator found wins):
      col~val  — case-insensitive substring match
      col>val  — numeric greater-than ("$" and "," stripped from cells)
      col<val  — numeric less-than
      col=val  — exact string equality (both sides whitespace-trimmed)

    Non-numeric cells never match the numeric operators; an unparseable
    threshold or an expression with no operator leaves *rows* unchanged.
    """
    def cell(row: Dict, col: str) -> str:
        # DictReader fills missing fields with None; normalize to "" so the
        # string operations below cannot raise AttributeError.
        return str(row.get(col) or "")

    if "~" in expr:
        col, _, val = expr.partition("~")
        needle = val.strip().lower()
        col = col.strip()
        return [r for r in rows if needle in cell(r, col).lower()]

    # Both numeric operators share one code path; only the comparison differs.
    for op_char, compare in ((">", lambda a, b: a > b), ("<", lambda a, b: a < b)):
        if op_char in expr:
            col, _, val = expr.partition(op_char)
            col = col.strip()
            try:
                threshold = float(val.strip())
            except ValueError:
                return rows  # malformed threshold: treat the filter as a no-op
            result = []
            for r in rows:
                raw = cell(r, col).replace(",", "").replace("$", "").strip()
                try:
                    if compare(float(raw), threshold):
                        result.append(r)
                except ValueError:
                    pass  # non-numeric cell: excluded, not an error
            return result

    if "=" in expr:
        col, _, val = expr.partition("=")
        val = val.strip()
        col = col.strip()
        return [r for r in rows if cell(r, col).strip() == val]
    return rows


# ── Stats ─────────────────────────────────────────────────────────────────────

def numeric_stats(rows: List[Dict], col: str) -> None:
    """Print count/sum/mean/median/min/max for the numeric values in *col*.

    Cells are cleaned of "$" and "," before parsing; non-numeric or missing
    cells are skipped. If nothing parses, a notice is printed and the
    function returns without raising.
    """
    values = []
    for r in rows:
        # DictReader yields None for missing fields — coerce to "" before the
        # cleanup chain so .replace() cannot raise AttributeError.
        raw = str(r.get(col) or "").replace(",", "").replace("$", "").strip()
        try:
            values.append(float(raw))
        except ValueError:
            pass
    if not values:
        print(f"No numeric values found in column '{col}'")
        return
    values.sort()
    n = len(values)
    total = sum(values)
    mean = total / n
    # Middle element for odd n; average of the two middle elements for even n.
    median = values[n // 2] if n % 2 else (values[n // 2 - 1] + values[n // 2]) / 2
    print(f"\n📊 Stats for '{col}':")
    print(f"  Count:   {n:,}")
    print(f"  Sum:     {total:,.2f}")
    print(f"  Mean:    {mean:,.2f}")
    print(f"  Median:  {median:,.2f}")
    print(f"  Min:     {values[0]:,.2f}")
    print(f"  Max:     {values[-1]:,.2f}")


# ── Output ────────────────────────────────────────────────────────────────────

def to_markdown(headers: List[str], rows: List[Dict]) -> str:
    """Render *rows* as a Markdown table, padding every column to its
    widest cell (or header, whichever is longer). Missing cells render as
    empty strings."""
    widths = {}
    for h in headers:
        longest_cell = max((len(str(r.get(h, ""))) for r in rows), default=0)
        widths[h] = max(len(h), longest_cell)

    def fmt(cells):
        # One table line: cells padded to column width, pipe-separated.
        padded = (text.ljust(widths[h]) for h, text in zip(headers, cells))
        return "| " + " | ".join(padded) + " |"

    lines = [fmt(headers)]
    lines.append("| " + " | ".join("-" * widths[h] for h in headers) + " |")
    for r in rows:
        lines.append(fmt([str(r.get(h, "")) for h in headers]))
    return "\n".join(lines)


def write_csv(headers: List[str], rows: List[Dict], output: Optional[Path]):
    """Emit *rows* as CSV — to *output* when given, otherwise to stdout.

    Row keys not listed in *headers* are silently dropped
    (``extrasaction="ignore"``). A confirmation line is printed only when
    writing to a file.
    """
    if not output:
        writer = csv.DictWriter(sys.stdout, fieldnames=headers, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(rows)
        return
    with open(output, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=headers, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(rows)
    print(f"✅ Saved {len(rows):,} rows → {output}")


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    """CLI entry point.

    Parses arguments, loads the file, then runs the stages in a fixed order:
    info → filter → search → dedup → sort → aggregate (count-by / stats) →
    column select → head/tail → output (JSON > Markdown > CSV precedence).
    The source file is never modified.
    """
    parser = argparse.ArgumentParser(description="CSV Toolkit — analyze and transform CSV files")
    parser.add_argument("file", help="CSV or TSV file path")
    parser.add_argument("--head", type=int, metavar="N", help="Show first N rows")
    parser.add_argument("--tail", type=int, metavar="N", help="Show last N rows")
    parser.add_argument("--info", action="store_true", help="Show column names and row count")
    parser.add_argument("--cols", metavar="COL1,COL2", help="Select specific columns")
    parser.add_argument("--filter", action="append", dest="filters", metavar="EXPR",
                        help="Filter: col=val | col>val | col<val | col~val (repeatable, AND logic)")
    parser.add_argument("--sort", metavar="COLUMN", help="Sort by column")
    parser.add_argument("--desc", action="store_true", help="Sort descending")
    parser.add_argument("--stats", metavar="COLUMN", help="Compute numeric stats for column")
    parser.add_argument("--dedup", metavar="COLUMN", help="Deduplicate by column (keep first)")
    parser.add_argument("--search", metavar="TEXT", help="Search all columns for text (case-insensitive)")
    parser.add_argument("--count-by", metavar="COLUMN", help="Count occurrences by column value")
    parser.add_argument("--markdown", action="store_true", help="Output as Markdown table")
    parser.add_argument("--json", action="store_true", dest="as_json", help="Output as JSON")
    parser.add_argument("--output", "-o", metavar="FILE", help="Save output to file")

    args = parser.parse_args()

    path = Path(args.file)
    if not path.exists():
        print(f"❌ File not found: {path}", file=sys.stderr)
        sys.exit(1)

    headers, rows = load_csv(path)

    # --info is a standalone report; it ignores every other option.
    if args.info:
        print(f"\n📄 {path.name}")
        print(f"   Rows:    {len(rows):,}")
        print(f"   Columns: {len(headers)}")
        print(f"\nColumns:")
        for i, h in enumerate(headers, 1):
            print(f"  {i:2}. {h}")
        return

    # Filters are ANDed: each expression narrows the surviving rows.
    if args.filters:
        for expr in args.filters:
            rows = apply_filter(rows, expr)

    # Free-text search across every column, case-insensitive.
    if args.search:
        term = args.search.lower()
        rows = [r for r in rows if any(term in str(v).lower() for v in r.values())]

    # Deduplicate by key column, keeping the first occurrence of each value.
    if args.dedup:
        col = args.dedup
        seen = set()
        deduped = []
        for r in rows:
            key = r.get(col, "")
            if key not in seen:
                seen.add(key)
                deduped.append(r)
        print(f"ℹ️  Deduplicated: {len(rows):,} → {len(deduped):,} rows (by '{col}')", file=sys.stderr)
        rows = deduped

    # Sort: numeric values (after stripping "$" and ",") group before text.
    if args.sort:
        col = args.sort
        def sort_key(r):
            # Fix: DictReader yields None for missing fields — coerce to ""
            # so the cleanup chain cannot raise AttributeError.
            v = str(r.get(col) or "").replace(",", "").replace("$", "").strip()
            try:
                return (0, float(v))
            except ValueError:
                return (1, v.lower())
        rows.sort(key=sort_key, reverse=args.desc)

    # Aggregation modes print a report and stop the pipeline.
    if args.count_by:
        col = args.count_by
        counter = Counter(r.get(col, "(blank)") for r in rows)
        print(f"\n📊 Count by '{col}' ({len(rows):,} rows):\n")
        for val, count in counter.most_common():
            pct = count / len(rows) * 100 if rows else 0
            print(f"  {count:5,}  {pct:5.1f}%  {val}")
        return

    if args.stats:
        numeric_stats(rows, args.stats)
        return

    # Column projection; columns absent from a row become empty strings.
    selected_headers = headers
    if args.cols:
        selected_headers = [c.strip() for c in args.cols.split(",")]
        rows = [{h: r.get(h, "") for h in selected_headers} for r in rows]

    # Fix: compare against None so `--head 0` / `--tail 0` mean "zero rows"
    # instead of being silently ignored (and rows[-0:] would return all rows).
    if args.head is not None:
        rows = rows[:args.head]
    elif args.tail is not None:
        rows = rows[-args.tail:] if args.tail > 0 else []

    if not rows:
        print("ℹ️  No rows match.", file=sys.stderr)
        return

    output_path = Path(args.output) if args.output else None

    # Output format precedence: JSON, then Markdown, then plain CSV.
    if args.as_json:
        out = json.dumps(rows, indent=2, ensure_ascii=False)
        if output_path:
            output_path.write_text(out, encoding="utf-8")
            print(f"✅ Saved {len(rows):,} rows → {output_path}")
        else:
            print(out)
    elif args.markdown:
        md = to_markdown(selected_headers, rows)
        if output_path:
            output_path.write_text(md, encoding="utf-8")
            print(f"✅ Saved Markdown table → {output_path}")
        else:
            print(md)
    else:
        write_csv(selected_headers, rows, output_path)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
