#!/usr/bin/env python3
"""
Omi Direct BLE Client - Batch Sync Mode (Phase 1)

Connects to the Omi pendant via BLE, records the audio it streams over
notifications for a fixed duration, reassembles the packets in sequence
order, and saves the raw audio for later decoding.

Usage:
    python3 omi_ble_client.py --scan              # Scan for devices
    python3 omi_ble_client.py --connect           # Connect and sync audio
    python3 omi_ble_client.py --connect --duration 60  # Sync for 60 seconds
"""

import asyncio
import json
import logging
import struct
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional, List

try:
    from bleak import BleakClient, BleakScanner
except ImportError:
    print("❌ bleak not installed. Install with: pip3 install --break-system-packages bleak")
    sys.exit(1)

# Configuration
# Advertised name used by --scan to flag the pendant, and the default --device
# address. NOTE(review): the UUID-style address looks like a CoreBluetooth
# (macOS) identifier, which is host-specific — confirm before reusing elsewhere.
OMI_DEVICE_NAME = "Omi"
OMI_DEVICE_ADDRESS = "BD3FF870-B161-ACF9-84B5-2A6B0963C2BF"
# GATT UUIDs. OMI_AUDIO_CHAR is subscribed to for audio notifications and
# OMI_CODEC_CHAR is read once per connection; OMI_AUDIO_SERVICE is not
# referenced by the code in this file.
OMI_AUDIO_SERVICE = "19b10000-e8f2-537e-4f6c-d104768a1214"
OMI_AUDIO_CHAR = "19b10001-e8f2-537e-4f6c-d104768a1214"
OMI_CODEC_CHAR = "19b10002-e8f2-537e-4f6c-d104768a1214"

# Output locations: raw audio files go to AUDIO_DIR (created on first save);
# the log file path is fixed under the user's home directory.
DATA_DIR = Path.home() / ".openclaw/workspace/projects/active/omi-direct-pipeline/omi-data"
AUDIO_DIR = DATA_DIR / "audio"
LOG_FILE = Path.home() / ".openclaw/workspace/projects/active/omi-direct-pipeline/logs/ble-client.log"

# Setup logging
# Runs at import time: importing this module creates the log directory and
# attaches handlers that write to both LOG_FILE and stdout.
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)


class OmiAudioBuffer:
    """Reorder incoming audio packets by their sequence number.

    Payloads are stored by packet number and released in order starting at
    ``expected_next``. The buffer starts un-anchored: the first packet added
    after construction (or ``clear()``) sets ``expected_next`` to its own
    number. The sender's counter is not guaranteed to be 0 when a
    subscription begins, so a fixed start of 0 would leave every packet
    stranded. Sequence numbers wrap from ``max_packet_number`` back to 0.
    """

    def __init__(self, max_packet_number=65535):
        self.packets = {}  # packet number -> payload bytes not yet released
        self.max_packet_number = max_packet_number
        self.expected_next = 0
        # False until the first packet fixes where the sequence starts.
        self._anchored = False

    def add_packet(self, packet_num: int, data: bytes) -> bool:
        """Store one packet's payload.

        The first packet after construction or ``clear()`` anchors the
        sequence at ``packet_num``. A packet number that is already buffered
        is overwritten.

        Returns:
            True if the packet at ``expected_next`` is available, i.e.
            ``get_contiguous_audio()`` has something to release.
        """
        if not self._anchored:
            self.expected_next = packet_num
            self._anchored = True
        self.packets[packet_num] = data
        return self.expected_next in self.packets

    def get_contiguous_audio(self) -> Optional[bytes]:
        """Pop and concatenate the in-order run of payloads at ``expected_next``.

        Advances ``expected_next`` past every released packet, wrapping to 0
        after ``max_packet_number``. Stops at the first missing number; later
        packets stay buffered.

        Returns:
            The concatenated payload bytes, or None if nothing (or only empty
            payloads) was released.
        """
        chunks = []
        while self.expected_next in self.packets:
            chunks.append(self.packets.pop(self.expected_next))
            self.expected_next += 1
            if self.expected_next > self.max_packet_number:
                self.expected_next = 0  # wrap around

        # join avoids the quadratic cost of repeated bytes concatenation.
        return b"".join(chunks) or None

    def clear(self):
        """Drop all buffered packets and un-anchor, ready for a new session."""
        self.packets.clear()
        self.expected_next = 0
        self._anchored = False


class OmiBLEClient:
    """Direct BLE client for the Omi pendant.

    Subscribes to the pendant's audio characteristic for a fixed duration,
    reorders the notification packets through an OmiAudioBuffer, and writes
    the resulting bytes to a timestamped ``.raw`` file under AUDIO_DIR.
    """

    def __init__(self, device_address: str = OMI_DEVICE_ADDRESS):
        self.device_address = device_address
        self.client = None  # BleakClient during a session, None otherwise
        self.audio_buffer = OmiAudioBuffer()
        self.total_packets = 0  # notifications received in the current session
        self.session_start = None  # datetime when notifications were enabled

    async def scan_devices(self):
        """Discover nearby BLE devices, print them, and flag any named Omi.

        Returns:
            List of discovered BLEDevice objects, or [] if the scan failed.
        """
        logger.info("Scanning for BLE devices...")
        try:
            # return_adv=True yields {address: (BLEDevice, AdvertisementData)};
            # RSSI is reported on AdvertisementData (BLEDevice.rssi is
            # deprecated/removed in current bleak).
            discovered = await BleakScanner.discover(return_adv=True)
        except Exception as e:
            logger.error(f"Scan failed: {e}")
            return []

        logger.info(f"Found {len(discovered)} BLE devices")
        devices = []
        for device, adv in discovered.values():
            # Many devices advertise no name; formatting None with ":20s" raises.
            name = device.name or adv.local_name or "(unknown)"
            print(f"  {name:20s} | {device.address:20s} | RSSI: {adv.rssi}")
            if OMI_DEVICE_NAME in (device.name, adv.local_name):
                logger.info(f"✅ Found Omi device: {device.address}")
            devices.append(device)
        return devices

    async def get_codec_type(self) -> str:
        """Read and log the device's audio codec.

        Returns:
            A human-readable codec label, "Unknown (<id>)" for unmapped IDs,
            or "unknown" if not connected or the read fails.
        """
        try:
            if not self.client or not self.client.is_connected:
                return "unknown"

            codec_data = await self.client.read_gatt_char(OMI_CODEC_CHAR)
            codec_byte = codec_data[0] if codec_data else 0

            # Codec IDs as used by the Omi app/firmware.
            # NOTE(review): taken from Omi's codec ID list — confirm against
            # the firmware version in use.
            codec_map = {
                0: "PCM 16kHz",
                1: "PCM 8kHz",
                10: "Mu-law 16kHz",
                11: "Mu-law 8kHz",
                20: "Opus 16kHz",
            }

            codec = codec_map.get(codec_byte, f"Unknown ({codec_byte})")
            logger.info(f"Device codec: {codec}")
            return codec
        except Exception as e:
            logger.warning(f"Could not read codec: {e}")
            return "unknown"

    def notification_handler(self, sender, data: bytearray):
        """Parse one audio notification and pass its payload to the buffer.

        Layout: 2-byte little-endian packet number, 1-byte index, payload.
        """
        if len(data) < 3:
            logger.warning(f"Invalid packet (too short): {len(data)} bytes")
            return

        (packet_num,) = struct.unpack_from('<H', data, 0)
        # data[2] (the index byte) is not used for reassembly here.
        # NOTE(review): if it marks fragments of a larger codec frame (e.g.
        # Opus), frame boundaries are lost in the concatenated output.
        self.audio_buffer.add_packet(packet_num, bytes(data[3:]))
        self.total_packets += 1

        if self.total_packets % 100 == 0:
            logger.info(f"Received {self.total_packets} packets ({len(self.audio_buffer.packets)} buffered)")

    async def connect_and_sync(self, duration_seconds: int = 60) -> bool:
        """Connect, record notified audio for ``duration_seconds``, and save it.

        Recording ends early if the pendant disconnects. Audio received so far
        is still saved if the session ends with a BLE error or is cancelled
        (e.g. Ctrl-C), so a dropped link does not discard it.

        Returns:
            True if audio was saved; False if the connection failed before any
            audio arrived or nothing was received.
        """
        # Each session starts clean so repeated calls never mix data.
        self.audio_buffer.clear()
        self.total_packets = 0

        loop = asyncio.get_running_loop()
        disconnected = asyncio.Event()

        def on_disconnect(_client):
            # Thread-safe hand-off regardless of which thread bleak calls from.
            loop.call_soon_threadsafe(disconnected.set)

        try:
            logger.info(f"Connecting to {self.device_address}...")

            async with BleakClient(
                self.device_address, disconnected_callback=on_disconnect
            ) as client:
                self.client = client

                if not client.is_connected:
                    logger.error("Failed to connect")
                    return False

                logger.info("✅ Connected to Omi pendant")
                await self.get_codec_type()

                logger.info("Subscribing to audio characteristic...")
                await client.start_notify(OMI_AUDIO_CHAR, self.notification_handler)

                self.session_start = datetime.now()
                logger.info(f"Syncing audio for {duration_seconds} seconds...")
                try:
                    await asyncio.wait_for(disconnected.wait(), timeout=duration_seconds)
                    logger.warning("Device disconnected before the sync duration elapsed")
                except asyncio.TimeoutError:
                    pass  # normal case: the full duration elapsed while connected

                # stop_notify on a dropped link raises; skip it in that case.
                if client.is_connected:
                    await client.stop_notify(OMI_AUDIO_CHAR)
                logger.info(f"Stopped notifications. Total packets: {self.total_packets}")
        except asyncio.CancelledError:
            logger.warning("Sync cancelled; saving audio received so far")
            self._save_buffered_audio()
            raise
        except Exception as e:
            # Log and fall through: whatever was received is still saved below.
            logger.error(f"Connection failed: {e}")
        finally:
            self.client = None

        return self._save_buffered_audio()

    def _drain_audio_buffer(self) -> bytes:
        """Return all buffered audio in sequence order, skipping lost packets."""
        buffer = self.audio_buffer
        modulus = buffer.max_packet_number + 1
        chunks = []
        while True:
            chunk = buffer.get_contiguous_audio()
            if chunk:
                chunks.append(chunk)
            if not buffer.packets:
                break
            # A packet is missing: jump to the nearest buffered number after
            # expected_next (wrap-aware) instead of discarding the remainder.
            start = buffer.expected_next
            nxt = min(buffer.packets, key=lambda n: (n - start) % modulus)
            logger.warning(f"Gap in packet sequence: jumping from {start} to {nxt}")
            buffer.expected_next = nxt
        return b"".join(chunks)

    def _save_buffered_audio(self) -> bool:
        """Drain the buffer and save it; return True if anything was written."""
        audio_data = self._drain_audio_buffer()
        if not audio_data:
            logger.warning("No audio data to save")
            return False
        filename = self.save_audio_data(audio_data)
        logger.info(f"Saved audio: {filename}")
        return True

    def save_audio_data(self, data: bytes) -> str:
        """Write ``data`` to a timestamped ``.raw`` file in AUDIO_DIR.

        Returns:
            The path of the written file as a string.
        """
        AUDIO_DIR.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = AUDIO_DIR / f"omi_audio_{timestamp}.raw"
        filename.write_bytes(data)

        logger.info(f"Saved {len(data)} bytes to {filename}")
        return str(filename)


async def main():
    """Parse command-line options and run a scan or a timed audio sync.

    With --connect, the process exits 0 if audio was saved and 1 otherwise.
    With neither --scan nor --connect, usage help is printed.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Omi Direct BLE Client")
    cli.add_argument("--scan", action="store_true", help="Scan for BLE devices")
    cli.add_argument("--connect", action="store_true", help="Connect and sync audio")
    cli.add_argument("--duration", type=int, default=30, help="Sync duration in seconds")
    cli.add_argument("--device", default=OMI_DEVICE_ADDRESS, help="Device BLE address")
    opts = cli.parse_args()

    ble = OmiBLEClient(opts.device)

    # --scan takes precedence when both flags are given.
    if opts.scan:
        await ble.scan_devices()
        return

    if not opts.connect:
        cli.print_help()
        return

    saved = await ble.connect_and_sync(opts.duration)
    sys.exit(0 if saved else 1)


if __name__ == "__main__":
    # Script entry point: run main() and turn interrupts and unexpected
    # errors into logged, non-zero exit statuses.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        # 128 + SIGINT(2): lets callers tell an interrupted run from success.
        sys.exit(130)
    except Exception as e:
        # logger.exception also records the traceback, not just the message.
        logger.exception(f"Fatal error: {e}")
        sys.exit(1)
