#!/usr/bin/env python3
"""
Network Scanner — macOS/Linux compatible local network device discovery.
Uses nmap for scanning. Only scans private RFC 1918 networks (safety enforced).

Usage:
    scan.py                        # auto-detect current network
    scan.py 192.168.1.0/24         # specific CIDR
    scan.py home                   # named network from config
    scan.py --json                 # JSON output
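    scan.py 192.168.1.0/24 --json  # CIDR + JSON

Named networks are read from ~/.config/network-scanner/networks.json, e.g.:
    {"networks": {"home": {"cidr": "192.168.1.0/24"}}}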
"""

import subprocess
import json
import sys
import os
import argparse
import ipaddress
import platform
import time
from datetime import datetime
from pathlib import Path

# Default named-network config: a JSON file whose top-level "networks" object
# maps a network name to an entry containing at least a "cidr" string.
CONFIG_PATH = Path.home() / ".config" / "network-scanner" / "networks.json"

# Common MAC vendor prefixes (first three octets of the MAC), stored lowercase
# and colon-separated because lookup_vendor() compares against mac[:8].lower().
# Used as a fallback when nmap's output does not name the vendor itself.
MAC_VENDORS = {
    "b8:27:eb": "Raspberry Pi", "dc:a6:32": "Raspberry Pi", "e4:5f:01": "Raspberry Pi",
    "d8:3a:dd": "Raspberry Pi", "00:17:88": "Philips Hue",
    "94:9f:3e": "Sonos", "b8:e9:37": "Sonos", "78:28:ca": "Sonos",
    "f0:99:bf": "Apple", "3c:22:fb": "Apple", "a4:83:e7": "Apple",
    "a4:5e:60": "Apple", "14:98:77": "Apple", "64:a2:00": "Apple",
    "fc:ec:da": "Ubiquiti", "78:8a:20": "Ubiquiti", "74:ac:b9": "Ubiquiti",
    "9c:05:d6": "Synology", "00:11:32": "Synology",
    "00:0c:29": "VMware", "00:50:56": "VMware",
    "70:ee:50": "Netatmo", "24:5a:4c": "Ubiquiti",
}


def is_private_network(cidr: str) -> tuple[bool, str]:
    """Verify CIDR lies entirely inside a private RFC 1918 range.

    Only 10.0.0.0/8, 172.16.0.0/12 and 192.168.0.0/16 are accepted. Host bits
    are ignored (``strict=False``), so "192.168.1.77/24" means 192.168.1.0/24.

    ``ipaddress``' ``is_private`` is deliberately NOT used. It also covers
    loopback, link-local, documentation and reserved ranges. On older CPython
    releases it only checked a network's first and last address, which let
    e.g. ``0.0.0.0/0`` (the whole IPv4 internet) pass as "private".

    Args:
        cidr: Network in CIDR notation.

    Returns:
        ``(True, "ok")`` when the scan is allowed, otherwise ``(False, reason)``
        with a user-facing explanation.
    """
    try:
        net = ipaddress.ip_network(cidr, strict=False)
    except ValueError as e:
        return False, f"❌ Invalid CIDR: {e}"

    # RFC 1918 is IPv4-only; checking the version first also avoids the
    # TypeError subnet_of() raises when comparing IPv4 with IPv6 networks.
    if net.version != 4:
        return False, f"❌ {cidr} is not an IPv4 network — only RFC 1918 ranges may be scanned"

    rfc1918_blocks = (
        ipaddress.IPv4Network("10.0.0.0/8"),
        ipaddress.IPv4Network("172.16.0.0/12"),
        ipaddress.IPv4Network("192.168.0.0/16"),
    )
    if not any(net.subnet_of(block) for block in rfc1918_blocks):
        return False, (
            f"❌ {cidr} is not inside a private RFC 1918 range "
            "(10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) — scanning blocked to prevent abuse"
        )
    return True, "ok"


def get_default_gateway_macos() -> str | None:
    """Get default gateway IP on macOS."""
    try:
        result = subprocess.run(
            ["route", "-n", "get", "default"],
            capture_output=True, text=True, timeout=5
        )
        for line in result.stdout.splitlines():
            if "gateway:" in line:
                return line.split("gateway:")[-1].strip()
    except Exception:
        pass
    return None


def get_default_gateway_linux() -> str | None:
    """Get default gateway IP on Linux."""
    try:
        result = subprocess.run(
            ["ip", "route", "show", "default"],
            capture_output=True, text=True, timeout=5
        )
        parts = result.stdout.split()
        if "via" in parts:
            return parts[parts.index("via") + 1]
    except Exception:
        pass
    return None


def auto_detect_network() -> str | None:
    """Detect the current local network CIDR."""
    is_mac = platform.system() == "Darwin"

    # Get local IP via connection trick
    try:
        import socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(("8.8.8.8", 80))
        local_ip = s.getsockname()[0]
        s.close()

        # Assume /24
        parts = local_ip.rsplit(".", 1)
        return f"{parts[0]}.0/24"
    except Exception:
        pass

    return None


def load_named_networks() -> dict:
    """Load named networks from config file."""
    if CONFIG_PATH.exists():
        try:
            with open(CONFIG_PATH) as f:
                data = json.load(f)
                return data.get("networks", {})
        except Exception:
            pass
    return {}


def lookup_vendor(mac: str) -> str:
    """Map a MAC address to a vendor name via its first three octets.

    The prefix (first 8 characters, lowercased) is looked up in MAC_VENDORS.
    Returns "" for an empty/None MAC, the "N/A" placeholder, or an unknown prefix.
    """
    if mac and mac != "N/A":
        return MAC_VENDORS.get(mac[:8].lower(), "")
    return ""


def _build_nmap_command(cidr: str) -> list[str]:
    """Build the nmap ping-scan argv for *cidr*, prefixed with ``sudo`` when usable.

    Privileged nmap uses ARP on the local segment, which is what yields MAC
    addresses. ``sudo -n`` never prompts, so sudo is only added when it works
    non-interactively for nmap. When already root, sudo is never added.
    """
    # -sn: ping scan (no port scan, faster)
    # -T4: aggressive timing
    # --host-timeout: don't hang on slow hosts
    cmd = ["nmap", "-sn", "-T4", "--host-timeout", "5s", cidr]
    if os.geteuid() == 0:
        return cmd  # already root
    try:
        probe = subprocess.run(
            ["sudo", "-n", "nmap", "--version"],
            capture_output=True, timeout=3
        )
    except (OSError, subprocess.SubprocessError):
        # sudo missing or the probe hung: fall back to an unprivileged scan.
        return cmd
    return ["sudo", *cmd] if probe.returncode == 0 else cmd


def _parse_nmap_output(output: str) -> list[dict]:
    """Parse nmap's normal ``-sn`` output into device dicts.

    Each device has ``ip``, ``hostname`` ("" if nmap printed none), ``mac``
    ("N/A" unless a ``MAC Address:`` line follows) and ``vendor``.
    """
    report_prefix = "Nmap scan report for "
    mac_prefix = "MAC Address:"
    devices: list[dict] = []
    current: dict = {}
    for raw_line in output.splitlines():
        line = raw_line.strip()
        if line.startswith(report_prefix):
            if current.get("ip"):
                devices.append(current)
            # "Nmap scan report for hostname (192.168.1.1)" or just the IP.
            target = line[len(report_prefix):]
            if "(" in target:
                # rpartition: the IP is the last parenthesised group, even if the
                # hostname itself contains "(".
                host_part, _, ip_part = target.rpartition("(")
                hostname, ip = host_part.strip(), ip_part.rstrip(")").strip()
            else:
                hostname, ip = "", target.strip()
            current = {"ip": ip, "hostname": hostname, "mac": "N/A", "vendor": ""}

        elif line.startswith(mac_prefix) and current:
            # "MAC Address: AA:BB:CC:DD:EE:FF (Vendor Name)". Split on the FIRST "("
            # and drop one closing ")" so vendor names containing parentheses survive.
            mac_part, _, vendor_part = line[len(mac_prefix):].partition("(")
            mac = mac_part.strip()
            nmap_vendor = vendor_part.strip()
            if nmap_vendor.endswith(")"):
                nmap_vendor = nmap_vendor[:-1].strip()
            current["mac"] = mac
            # nmap prints "(Unknown)" when its own OUI database has no match; that is
            # exactly when the local MAC_VENDORS table should be consulted. Keep
            # nmap's answer if the local table misses too.
            if not nmap_vendor or nmap_vendor.lower() == "unknown":
                current["vendor"] = lookup_vendor(mac) or nmap_vendor
            else:
                current["vendor"] = nmap_vendor

    # Don't forget last device
    if current.get("ip"):
        devices.append(current)
    return devices


def run_nmap(cidr: str) -> list[dict]:
    """Ping-scan *cidr* with nmap and return the devices that responded.

    Args:
        cidr: Network to scan, passed straight to nmap. main() validates it with
            is_private_network() before calling this.

    Returns:
        Device dicts with ``ip``, ``hostname``, ``mac`` and ``vendor`` keys;
        ``[]`` if the scan hit the 2-minute timeout.

    Exits the process with status 1 if nmap is not installed. A non-zero nmap
    exit status is reported on stderr, and whatever output nmap produced is
    still parsed.
    """
    cmd = _build_nmap_command(cidr)
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=120
        )
    except subprocess.TimeoutExpired:
        print("⚠️  Scan timed out (2 min limit)", file=sys.stderr)
        return []
    except FileNotFoundError:
        print(
            "❌ nmap not found. Install with: brew install nmap (macOS) "
            "or your package manager, e.g. apt install nmap (Linux)",
            file=sys.stderr,
        )
        sys.exit(1)

    if result.returncode != 0:
        # Surface nmap's own error instead of silently reporting zero devices.
        detail = result.stderr.strip() or f"exit status {result.returncode}"
        print(f"⚠️  nmap failed: {detail}", file=sys.stderr)

    return _parse_nmap_output(result.stdout)


def format_table(devices: list[dict], cidr: str, duration: float) -> str:
    """Render scan results as a fixed-width, human-readable table.

    Rows are ordered numerically by dotted-quad IP, so .2 comes before .10.
    A missing or None hostname shows as blank, a missing MAC as "N/A", and a
    missing vendor as blank.

    Args:
        devices: Device dicts; each needs at least an ``"ip"`` key.
        cidr: Scanned network, shown in the title line.
        duration: Scan time in seconds, shown with one decimal place.

    Returns:
        The whole table as one newline-joined string (no trailing newline).
    """
    def octets(device: dict) -> list[int]:
        # Numeric key: lexical sorting would put 10.0.0.10 before 10.0.0.2.
        return [int(part) for part in device["ip"].split(".")]

    def render_row(device: dict) -> str:
        host = device.get("hostname", "") or ""
        mac = device.get("mac", "N/A")
        vendor = device.get("vendor", "")
        return f"{device['ip']:<18} {host:<28} {mac:<20} {vendor}"

    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    header = [
        f"🔍 Network Scan: {cidr}",
        f"Scanned: {stamp} | Duration: {duration:.1f}s | Devices: {len(devices)}",
        "",
        f"{'IP':<18} {'Hostname':<28} {'MAC':<20} {'Vendor'}",
        f"{'─'*18} {'─'*28} {'─'*20} {'─'*20}",
    ]
    body = [render_row(dev) for dev in sorted(devices, key=octets)]
    return "\n".join(header + body)


def _read_networks_config(path: Path) -> dict:
    """Return the ``networks`` object from the JSON config at *path*.

    A missing file yields {}. An unreadable or malformed file also yields {},
    but a warning is printed first so a later "Unknown network" is explainable.
    """
    if not path.exists():
        return {}
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        print(f"⚠️  Could not read config {path}: {e}", file=sys.stderr)
        return {}
    networks = data.get("networks", {}) if isinstance(data, dict) else {}
    return networks if isinstance(networks, dict) else {}


def _resolve_target(target: str | None, config_path: Path) -> str:
    """Translate the CLI *target* argument into a CIDR string.

    None or "auto" auto-detects the local network. Anything containing "/" is
    taken as a CIDR. Any other value is looked up by name in *config_path*.
    Prints the reason and exits with status 1 when no CIDR can be determined.
    """
    if not target or target == "auto":
        cidr = auto_detect_network()
        if not cidr:
            print("❌ Could not auto-detect network. Specify a CIDR.", file=sys.stderr)
            sys.exit(1)
        print(f"ℹ️  Auto-detected network: {cidr}", file=sys.stderr)
        return cidr

    if "/" in target:
        return target

    # Named network lookup
    networks = _read_networks_config(config_path)
    if target not in networks:
        print(f"❌ Unknown network '{target}'. Available: {list(networks.keys())}", file=sys.stderr)
        sys.exit(1)
    entry = networks[target]
    cidr = entry.get("cidr") if isinstance(entry, dict) else None
    if not isinstance(cidr, str) or not cidr:
        # A clear message instead of a KeyError/TypeError traceback.
        print(f"❌ Named network '{target}' in {config_path} has no 'cidr' string", file=sys.stderr)
        sys.exit(1)
    print(f"ℹ️  Using named network '{target}': {cidr}", file=sys.stderr)
    return cidr


def main():
    """Command-line entry point.

    Resolves the target to a CIDR and refuses it if is_private_network() does
    not approve it. Then runs the nmap scan and prints either a table or JSON
    to stdout.
    """
    parser = argparse.ArgumentParser(description="Scan local network for devices")
    parser.add_argument("target", nargs="?", help="CIDR, named network, or 'auto'")
    parser.add_argument("--json", action="store_true", help="Output as JSON")
    parser.add_argument("--config", default=str(CONFIG_PATH), help="Config file path")
    args = parser.parse_args()

    # Resolve target; named networks come from --config (default: CONFIG_PATH).
    cidr = _resolve_target(args.target, Path(args.config).expanduser())

    # Safety check
    safe, reason = is_private_network(cidr)
    if not safe:
        print(reason, file=sys.stderr)
        sys.exit(1)

    # Scan. A monotonic clock keeps wall-clock adjustments from skewing the duration.
    start = time.monotonic()
    devices = run_nmap(cidr)
    duration = time.monotonic() - start

    if args.json:
        output = {
            "cidr": cidr,
            "scanned_at": datetime.now().isoformat(),
            "duration_seconds": round(duration, 2),
            "device_count": len(devices),
            "devices": devices
        }
        print(json.dumps(output, indent=2))
    else:
        print(format_table(devices, cidr, duration))


if __name__ == "__main__":
    main()
