#!/usr/bin/env python3
"""
Opus Audio Decoder

Converts raw Opus-encoded audio from Omi BLE packets to WAV format.
Also supports PCM and Mu-law decoding.

Usage:
    python3 opus_decoder.py <input.raw> --codec opus --sample-rate 16000
    python3 opus_decoder.py <input.raw> --codec pcm --sample-rate 16000
"""

import logging
import subprocess
import sys
from pathlib import Path
from typing import Optional

# Module-wide logger, plus a root configuration that stamps every record with
# time and level at INFO verbosity.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)


class AudioDecoder:
    """Decode raw audio files to 16-bit PCM WAV using the ffmpeg CLI.

    Codec, sample rate and channel count are fixed per instance. Each
    ``decode_*_to_wav`` method builds an ffmpeg command for one input format
    and reports success as a bool (errors are logged, not raised); ``decode``
    dispatches on ``self.codec`` and returns the output path or ``None``.

    The ``ffmpeg`` executable must be on ``PATH``; if it is missing, decoding
    fails with a logged error rather than an exception.
    """

    SUPPORTED_CODECS = ["opus", "pcm", "mulaw"]

    # ffmpeg raw signed-PCM demuxer names keyed by bit depth. 8-bit samples
    # have no byte order, so ffmpeg names that format "s8" (there is no "s8le").
    _PCM_FORMATS = {8: "s8", 16: "s16le", 24: "s24le", 32: "s32le"}

    def __init__(
        self,
        codec: str = "opus",
        sample_rate: int = 16000,
        channels: int = 1,
        timeout: Optional[float] = 120,
    ):
        """Configure the decoder.

        Args:
            codec: One of SUPPORTED_CODECS (case-insensitive).
            sample_rate: Sample rate in Hz; must be positive.
            channels: Number of audio channels; must be positive.
            timeout: Seconds allowed for each ffmpeg run (``None`` = no limit).

        Raises:
            ValueError: If the codec is unsupported, or sample_rate/channels
                is not positive.
        """
        normalized = codec.lower()
        if normalized not in self.SUPPORTED_CODECS:
            raise ValueError(f"Unsupported codec: {codec}. Supported: {self.SUPPORTED_CODECS}")
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")
        if channels <= 0:
            raise ValueError(f"channels must be positive, got {channels}")

        self.codec = normalized
        self.sample_rate = sample_rate
        self.channels = channels
        self.timeout = timeout

    @classmethod
    def _pcm_input_format(cls, bits: int) -> str:
        """Return ffmpeg's raw signed little-endian PCM format name for *bits*.

        Raises:
            ValueError: If *bits* is not one of the depths in ``_PCM_FORMATS``.
        """
        try:
            return cls._PCM_FORMATS[bits]
        except KeyError:
            raise ValueError(
                f"Unsupported PCM bit depth: {bits}. Supported: {sorted(cls._PCM_FORMATS)}"
            ) from None

    @staticmethod
    def _build_ffmpeg_cmd(input_args: list, input_file: Path, output_args: list, output_file: Path) -> list:
        """Assemble an ffmpeg argv that writes 16-bit little-endian PCM WAV.

        ``input_args`` are placed before ``-i`` and so describe the input
        stream; ``output_args`` are placed after it and so apply to the output.
        """
        return [
            "ffmpeg",
            "-hide_banner",  # keep the logged stderr focused on the actual error
            *input_args,
            "-i", str(input_file),
            "-acodec", "pcm_s16le",
            *output_args,
            "-y",  # Overwrite output
            str(output_file),
        ]

    def _run_ffmpeg(self, cmd: list, output_file: Path) -> bool:
        """Run *cmd*, log the outcome, and return True iff ffmpeg exits with 0."""
        try:
            result = subprocess.run(
                cmd,
                # Stop ffmpeg from reading interactive commands from our stdin.
                stdin=subprocess.DEVNULL,
                capture_output=True,
                text=True,
                # ffmpeg stderr is not guaranteed to be valid UTF-8.
                errors="replace",
                timeout=self.timeout,
            )
        except FileNotFoundError:
            logger.error("Decoding failed: ffmpeg executable not found (is it installed and on PATH?)")
            return False
        except subprocess.TimeoutExpired:
            logger.error(f"Decoding timed out after {self.timeout}s")
            return False
        except OSError as e:
            logger.error(f"Decoding failed: {e}")
            return False

        if result.returncode != 0:
            logger.error(f"ffmpeg error: {result.stderr}")
            return False

        logger.info(f"✅ Decoded to WAV: {output_file}")
        return True

    def decode_opus_to_wav(self, input_file: Path, output_file: Path) -> bool:
        """Decode Ogg-encapsulated Opus to WAV at the configured rate/channels.

        ffmpeg reads Opus through its Ogg demuxer; "opus" is only a muxer name,
        so ``-f opus`` is rejected as an input format. Unframed Opus packets
        (e.g. BLE frames concatenated into one file) carry no packet boundaries
        and need a packet-level decoder instead.
        NOTE(review): confirm what the capture side actually writes.

        Returns:
            True if ffmpeg succeeded, False otherwise (the error is logged).
        """
        input_file, output_file = Path(input_file), Path(output_file)
        logger.info(f"Decoding Opus: {input_file.name} -> {output_file.name}")
        cmd = self._build_ffmpeg_cmd(
            ["-f", "ogg"],
            input_file,
            # The container declares its own rate; these resample/remix the output.
            ["-ar", str(self.sample_rate), "-ac", str(self.channels)],
            output_file,
        )
        return self._run_ffmpeg(cmd, output_file)

    def decode_pcm_to_wav(self, input_file: Path, output_file: Path, bits: int = 16) -> bool:
        """Wrap headerless signed little-endian PCM of *bits* depth into a WAV.

        Returns:
            True if ffmpeg succeeded, False otherwise (including an unsupported
            bit depth); errors are logged.
        """
        input_file, output_file = Path(input_file), Path(output_file)
        try:
            pcm_format = self._pcm_input_format(bits)
        except ValueError as e:
            logger.error(f"Decoding failed: {e}")
            return False

        logger.info(f"Decoding PCM ({bits}-bit): {input_file.name} -> {output_file.name}")
        cmd = self._build_ffmpeg_cmd(
            # Raw PCM has no header, so rate and channel count are input options.
            ["-f", pcm_format, "-ar", str(self.sample_rate), "-ac", str(self.channels)],
            input_file,
            [],
            output_file,
        )
        return self._run_ffmpeg(cmd, output_file)

    def decode_mulaw_to_wav(self, input_file: Path, output_file: Path) -> bool:
        """Decode headerless Mu-law samples to 16-bit PCM WAV.

        Returns:
            True if ffmpeg succeeded, False otherwise (the error is logged).
        """
        input_file, output_file = Path(input_file), Path(output_file)
        logger.info(f"Decoding Mu-law: {input_file.name} -> {output_file.name}")
        cmd = self._build_ffmpeg_cmd(
            # Headerless input: rate and channel count must be declared up front.
            ["-f", "mulaw", "-ar", str(self.sample_rate), "-ac", str(self.channels)],
            input_file,
            [],
            output_file,
        )
        return self._run_ffmpeg(cmd, output_file)

    def decode(self, input_file: Path, output_file: Optional[Path] = None) -> Optional[Path]:
        """
        Decode audio file based on codec.

        Args:
            input_file: Raw audio input file
            output_file: Output WAV file; defaults to ``<stem>_decoded.wav``
                next to the input. Missing parent directories are created.

        Returns:
            Path to output WAV file, or None if failed
        """
        input_file = Path(input_file)

        if not input_file.is_file():
            logger.error(f"Input file not found: {input_file}")
            return None

        if output_file is None:
            output_file = input_file.parent / f"{input_file.stem}_decoded.wav"
        else:
            output_file = Path(output_file)

        decoders = {
            "opus": self.decode_opus_to_wav,
            "pcm": self.decode_pcm_to_wav,
            "mulaw": self.decode_mulaw_to_wav,
        }
        decoder = decoders.get(self.codec)
        if decoder is None:
            # Only reachable if self.codec was reassigned after __init__.
            logger.error(f"Unsupported codec: {self.codec}")
            return None

        output_file.parent.mkdir(parents=True, exist_ok=True)
        return output_file if decoder(input_file, output_file) else None


def main(argv: Optional[list] = None):
    """Parse command-line arguments, decode the input file, and exit.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``; ``None`` keeps
            the default (useful for programmatic calls and tests).

    Exits with status 0 when decoding succeeds and 1 when it fails; argparse
    itself exits with status 2 on invalid arguments.
    """
    import argparse

    def positive_int(text: str) -> int:
        """argparse ``type=`` converter that rejects zero and negative values."""
        value = int(text)  # a ValueError makes argparse report an invalid value
        if value <= 0:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    parser = argparse.ArgumentParser(description="Decode Omi audio files to WAV")
    parser.add_argument("input", help="Input audio file (raw)")
    parser.add_argument("--output", help="Output WAV file (auto-generated if not specified)")
    parser.add_argument("--codec", choices=AudioDecoder.SUPPORTED_CODECS, default="opus",
                        help="Audio codec (default: opus)")
    # Reject nonsensical values here, with a usage message, rather than
    # letting ffmpeg fail later with an opaque error.
    parser.add_argument("--sample-rate", type=positive_int, default=16000, help="Sample rate in Hz")
    parser.add_argument("--channels", type=positive_int, default=1, help="Number of channels")

    args = parser.parse_args(argv)

    decoder = AudioDecoder(
        codec=args.codec,
        sample_rate=args.sample_rate,
        channels=args.channels,
    )

    output_file = decoder.decode(
        Path(args.input),
        Path(args.output) if args.output else None,
    )

    if output_file:
        print(f"\n✅ Successfully decoded to: {output_file}")
        sys.exit(0)

    print("\n❌ Decoding failed")
    sys.exit(1)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        # 128 + SIGINT(2): the conventional status for Ctrl-C, so shells and
        # callers do not mistake an aborted decode for success.
        sys.exit(130)
    except Exception as e:
        # Top-level boundary: log with the traceback so unexpected failures
        # are debuggable, then signal failure.
        logger.exception(f"Fatal error: {e}")
        sys.exit(1)
