#!/usr/bin/env python3
"""Apple Health Export — Flexible Query Tool
Usage:
  python3 health_query.py --type StepCount --days 30
  python3 health_query.py --type HeartRate --days 7 --output /tmp/hr.csv
  python3 health_query.py --list-types
"""

import argparse
import csv
import os
import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from datetime import datetime, timedelta
from pathlib import Path

# Default location of the unzipped "Export All Health Data" archive.
EXPORT_PATH = (
    Path.home() / "Downloads" / "HealthExport" / "apple_health_export" / "export.xml"
)

# Short metric names accepted by --type, mapped to full HealthKit identifiers.
_QTY = "HKQuantityTypeIdentifier"
_CAT = "HKCategoryTypeIdentifier"
SHORT_NAMES = {
    short: prefix + short
    for prefix, short in (
        (_QTY, "StepCount"),
        (_QTY, "HeartRate"),
        (_QTY, "BodyMass"),
        (_QTY, "ActiveEnergyBurned"),
        (_QTY, "DistanceWalkingRunning"),
        (_QTY, "RestingHeartRate"),
        (_QTY, "BloodPressureSystolic"),
        (_QTY, "BloodPressureDiastolic"),
        (_CAT, "SleepAnalysis"),
        (_QTY, "BodyFatPercentage"),
    )
}

def find_export():
    """Return the first existing ``export.xml`` among the known locations.

    Checked in order: ``EXPORT_PATH``,
    ``~/Desktop/apple_health_export/export.xml`` and
    ``/tmp/apple_health_export/export.xml``.

    Returns:
        The first path that exists, or ``None`` when none of them do.
    """
    search_order = (
        EXPORT_PATH,
        Path.home() / "Desktop" / "apple_health_export" / "export.xml",
        Path("/tmp") / "apple_health_export" / "export.xml",
    )
    return next((candidate for candidate in search_order if candidate.exists()), None)

def resolve_type(type_arg):
    """Translate a short metric name into its full HealthKit identifier.

    Names present in ``SHORT_NAMES`` are mapped. Any other value is returned
    unchanged, on the assumption that it is already a full identifier.
    """
    return SHORT_NAMES.get(type_arg, type_arg)

# HealthKit prefixes stripped to form the short names shown by list_types().
_HK_PREFIXES = ("HKQuantityTypeIdentifier", "HKCategoryTypeIdentifier")


def _short_name(rtype):
    """Return ``rtype`` without its HealthKit quantity/category prefix."""
    for prefix in _HK_PREFIXES:
        if rtype.startswith(prefix):
            return rtype[len(prefix):]
    return rtype


def _iter_records(export_path):
    """Yield every ``<Record>`` element of the export, in document order.

    This includes records nested in other elements. The file is streamed with
    ``iterparse`` rather than loaded with ``ET.parse``. Health exports can be
    very large, and a full tree keeps every element in memory at once. Each
    direct child of the root is discarded once it has closed. The caller must
    therefore read what it needs from a yielded element before resuming the
    generator.
    """
    root = None
    depth = 0  # nesting level of the currently open element; the root is 1
    for event, elem in ET.iterparse(str(export_path), events=("start", "end")):
        if event == "start":
            if root is None:
                root = elem
            depth += 1
            continue
        depth -= 1
        if elem.tag == "Record":
            yield elem
        if depth == 1:
            # elem was a direct child of the root and is now complete. Drop it,
            # and everything nested in it, so memory stays flat.
            root.clear()


def query(export_path, hk_type, days):
    """Return ``(datetime, value, unit)`` tuples for one record type.

    Args:
        export_path: Path to the Health ``export.xml`` file.
        hk_type: A full identifier such as
            ``HKQuantityTypeIdentifierStepCount``, or the short name printed
            by ``list_types`` (e.g. ``StepCount`` or ``OxygenSaturation``).
        days: Keep only records whose start date is within this many days
            before now.

    Returns:
        A list sorted by start time; ties keep document order. ``value`` and
        ``unit`` are the raw attribute strings. ``value`` is ``"N/A"`` and
        ``unit`` is ``""`` when the attribute is missing. Records with an
        unparseable ``startDate`` are skipped.
    """
    cutoff = datetime.now() - timedelta(days=days)
    records = []
    for record in _iter_records(export_path):
        rtype = record.get("type", "")
        # Accept both the full identifier and the short name that
        # list_types() prints. Otherwise, short names missing from
        # SHORT_NAMES would never match anything.
        if rtype != hk_type and _short_name(rtype) != hk_type:
            continue
        # Only the first 19 characters ("YYYY-MM-DD HH:MM:SS") are parsed.
        # Anything after them, e.g. a UTC offset, is ignored, so the result is
        # a naive datetime that compares against the naive local cutoff.
        date_str = record.get("startDate", "")[:19]
        try:
            dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            continue
        if dt < cutoff:
            continue
        records.append((dt, record.get("value", "N/A"), record.get("unit", "")))
    records.sort(key=lambda rec: rec[0])
    return records


def list_types(export_path):
    """Print every record type in the export with its record count.

    Types are shown by short name (HealthKit quantity/category prefix
    stripped), most frequent first. ``query`` accepts these names as
    ``hk_type``.
    """
    types = defaultdict(int)
    for record in _iter_records(export_path):
        types[_short_name(record.get("type", "unknown"))] += 1

    print("\nAvailable health types in your export:\n")
    for t, count in sorted(types.items(), key=lambda x: -x[1]):
        print(f"  {t:45s} {count:>6,} records")
    print()

def _write_csv(path, records):
    """Write ``(datetime, value, unit)`` records to ``path`` as CSV with a header row."""
    # newline="" is what the csv module requires. An explicit utf-8 encoding
    # keeps the output independent of the platform's default encoding.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["datetime", "value", "unit"])
        writer.writerows(
            [dt.strftime("%Y-%m-%d %H:%M:%S"), val, unit] for dt, val, unit in records
        )


def _print_table(records, limit=50):
    """Print the last ``limit`` records as a table, then the count and avg/min/max.

    The avg/min/max line is skipped when any value cannot be converted to float.
    """
    print(f"\n{'Date':20s} {'Value':>12} {'Unit':10s}")
    print("-" * 45)
    for dt, val, unit in records[-limit:]:  # tail only, to avoid a wall of text
        print(f"  {dt.strftime('%Y-%m-%d %H:%M'):18s} {str(val):>12} {unit}")
    if len(records) > limit:
        print(f"  ... (showing last {limit} of {len(records)} records)")
    print(f"\n  Total records: {len(records)}")
    try:
        vals = [float(v) for _, v, _ in records]
    except ValueError:
        return  # a non-numeric value (e.g. the "N/A" placeholder): no summary
    if vals:
        print(f"  Avg: {sum(vals)/len(vals):.1f}  Min: {min(vals):.1f}  Max: {max(vals):.1f}")


def main():
    """Command-line entry point.

    Parses the arguments and locates ``export.xml``, either from ``--export``
    or by searching with ``find_export``. It then lists the available types,
    or queries one type and writes a CSV / prints a table.

    Exits with status 1 when the export cannot be found, or when neither
    ``--type`` nor ``--list-types`` was given. A negative ``--days`` is
    rejected by argparse.
    """
    parser = argparse.ArgumentParser(description="Query Apple Health export")
    parser.add_argument("--type", help="Metric type (e.g. StepCount, HeartRate)")
    parser.add_argument("--days", type=int, default=30, help="Days back to query (default: 30)")
    parser.add_argument("--output", help="CSV output path")
    parser.add_argument("--list-types", action="store_true", help="List all available types")
    parser.add_argument(
        "--export",
        type=Path,
        help="Path to export.xml (default: search the usual unzip locations)",
    )
    args = parser.parse_args()
    if args.days < 0:
        # A negative window would move the cutoff into the future.
        parser.error("--days must be zero or positive")

    if args.export is not None:
        export = args.export if args.export.is_file() else None
        looked_at = args.export
    else:
        export = find_export()
        looked_at = EXPORT_PATH
    if export is None:
        # Diagnostics go to stderr so they never land in redirected output.
        print(f"ERROR: Health export not found at {looked_at}", file=sys.stderr)
        print("  Export: iPhone → Health → Profile icon → Export All Health Data", file=sys.stderr)
        print(f"  Unzip to: {EXPORT_PATH.parent}", file=sys.stderr)
        sys.exit(1)

    if args.list_types:
        list_types(export)
        return

    if not args.type:
        parser.print_help(sys.stderr)
        sys.exit(1)

    hk_type = resolve_type(args.type)
    print(f"Querying: {hk_type} | Last {args.days} days")
    records = query(export, hk_type, args.days)

    if not records:
        print("No records found.")
        return

    if args.output:
        _write_csv(args.output, records)
        print(f"Exported {len(records)} records → {args.output}")
    else:
        _print_table(records)


if __name__ == "__main__":
    main()
