#!/usr/bin/env python3
"""
Whisper Transcription Pipeline

Batch transcribes WAV audio files using OpenAI Whisper CLI.
Saves transcripts as JSON with metadata.

Usage:
    python3 whisper_pipeline.py --input-dir ./audio --output-dir ./transcripts
    python3 whisper_pipeline.py --input-dir ./audio --output-dir ./transcripts --model base --language en
"""

import json
import logging
import shutil
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional

# Configure the root logger once for this script: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)


class WhisperPipeline:
    """Batch transcription with the Whisper CLI.

    Each audio file is transcribed by running the ``whisper`` executable as a
    subprocess. The JSON it writes is read back, annotated with pipeline
    metadata, and (in :meth:`process_directory`) saved under a
    ``<output_dir>/<YYYY-MM-DD>/`` folder.
    """

    # Model sizes offered as command-line choices.
    WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
    DEFAULT_MODEL = "base"
    DEFAULT_LANGUAGE = "en"
    # Historical hard-coded install location; still preferred when present so
    # existing setups keep using the same binary.
    DEFAULT_WHISPER_PATH = "/opt/homebrew/bin/whisper"
    # Per-file wall-clock limit for the Whisper subprocess, in seconds.
    DEFAULT_TIMEOUT = 600

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        language: str = DEFAULT_LANGUAGE,
        output_format: str = "json",
        device: str = "cpu",
        whisper_path: Optional[str] = None,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Configure the pipeline and locate the Whisper executable.

        Args:
            model: Whisper model name passed to ``--model``.
            language: Language code passed to ``--language``.
            output_format: Stored for backward compatibility only; Whisper is
                always asked for JSON because the pipeline parses it back.
            device: Device string passed to ``--device``.
            whisper_path: Explicit path (or command name on ``PATH``) of the
                Whisper executable. When ``None``, ``DEFAULT_WHISPER_PATH`` is
                used if it exists, otherwise ``whisper`` is looked up on ``PATH``.
            timeout: Seconds to allow each Whisper run before giving up.

        Raises:
            RuntimeError: If the resolved Whisper executable does not exist.
        """
        self.model = model
        self.language = language
        self.output_format = output_format
        self.device = device
        self.timeout = timeout
        self.whisper_path = self._resolve_whisper_path(whisper_path)

        if not Path(self.whisper_path).exists():
            raise RuntimeError(f"Whisper not found at {self.whisper_path}")

        logger.info(f"Using Whisper model: {model}, language: {language}")

    @classmethod
    def _resolve_whisper_path(cls, whisper_path: Optional[str]) -> str:
        """Return the Whisper executable path to use (may not exist)."""
        if whisper_path is not None:
            # Accept either a path or a bare command name found on PATH.
            return shutil.which(str(whisper_path)) or str(whisper_path)
        if Path(cls.DEFAULT_WHISPER_PATH).exists():
            return cls.DEFAULT_WHISPER_PATH
        return shutil.which("whisper") or cls.DEFAULT_WHISPER_PATH

    def transcribe_file(self, audio_file: Path) -> Optional[Dict]:
        """Transcribe a single audio file with Whisper.

        Whisper's JSON output is written next to the audio file as
        ``<stem>.json`` and left there. The parsed transcript is returned with
        ``source``, ``processed_at``, ``audio_file`` and ``model`` keys added.

        Args:
            audio_file: Path to the audio file (``str`` is also accepted).

        Returns:
            The annotated transcript dict, or ``None`` on any failure: missing
            file, non-zero Whisper exit, missing output, timeout, or an
            unexpected error. Failures are logged rather than raised.
        """
        try:
            audio_file = Path(audio_file)

            if not audio_file.exists():
                logger.error(f"Audio file not found: {audio_file}")
                return None

            logger.info(f"Transcribing: {audio_file.name}")

            # Whisper writes its outputs into --output_dir, which defaults to
            # the *current working directory*. Point it at the audio file's
            # folder so the JSON lands where we look for it below.
            whisper_output_dir = audio_file.parent
            cmd = [
                self.whisper_path,
                str(audio_file),
                "--model", self.model,
                "--language", self.language,
                "--output_format", "json",
                "--output_dir", str(whisper_output_dir),
                "--device", self.device,
                "--verbose", "False",
            ]

            result = subprocess.run(
                cmd, capture_output=True, text=True, timeout=self.timeout
            )

            if result.returncode != 0:
                logger.error(f"Whisper error: {result.stderr}")
                return None

            json_output = whisper_output_dir / f"{audio_file.stem}.json"

            if not json_output.exists():
                logger.error(f"Whisper output not found: {json_output}")
                return None

            with open(json_output, encoding="utf-8") as f:
                transcript_data = json.load(f)

            # Pipeline metadata added on top of Whisper's own fields.
            transcript_data["source"] = "omi_ble_direct"
            transcript_data["processed_at"] = datetime.now().isoformat()
            transcript_data["audio_file"] = str(audio_file)
            transcript_data["model"] = self.model

            logger.info(f"✅ Transcribed: {audio_file.name}")
            logger.info(f"   Duration: {transcript_data.get('duration', 'unknown')}s")
            logger.info(f"   Language: {transcript_data.get('language', 'unknown')}")

            return transcript_data

        except subprocess.TimeoutExpired:
            logger.error(f"Transcription timed out: {audio_file}")
            return None
        except Exception as e:
            # Best-effort per file: log with traceback and let the batch go on.
            logger.exception(f"Transcription failed: {e}")
            return None

    def process_directory(
        self,
        input_dir: Path,
        output_dir: Path,
        pattern: str = "*.wav",
        delete_after_transcript: bool = False
    ) -> List[Path]:
        """Transcribe every file in ``input_dir`` matching ``pattern``.

        Each transcript is saved as
        ``<output_dir>/<YYYY-MM-DD>/transcript_<HHMMSS>_<stem>.json``.

        Args:
            input_dir: Directory searched (non-recursively) for audio files.
            output_dir: Root directory for saved transcripts; created if needed.
            pattern: Glob pattern selecting audio files within ``input_dir``.
            delete_after_transcript: Delete each audio file after its
                transcript has been saved.

        Returns:
            Paths of the saved transcript files, in processing (sorted) order.
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)

        if not input_dir.is_dir():
            logger.error(f"Input directory not found: {input_dir}")
            return []

        output_dir.mkdir(parents=True, exist_ok=True)

        audio_files = sorted(input_dir.glob(pattern))
        logger.info(f"Found {len(audio_files)} audio files to process")

        if not audio_files:
            logger.warning("No audio files found")
            return []

        processed_transcripts = []

        for i, audio_file in enumerate(audio_files, 1):
            logger.info(f"\n[{i}/{len(audio_files)}] Processing {audio_file.name}")

            transcript_data = self.transcribe_file(audio_file)

            if transcript_data is None:
                logger.warning(f"Skipped: {audio_file.name}")
                continue

            # One clock reading so the date folder and the time prefix agree,
            # even when a file finishes right around midnight.
            now = datetime.now()
            date_folder = output_dir / now.strftime("%Y-%m-%d")
            date_folder.mkdir(parents=True, exist_ok=True)
            transcript_file = (
                date_folder / f"transcript_{now.strftime('%H%M%S')}_{audio_file.stem}.json"
            )

            with open(transcript_file, 'w', encoding="utf-8") as f:
                json.dump(transcript_data, f, indent=2)

            logger.info(f"Saved transcript: {transcript_file.name}")
            processed_transcripts.append(transcript_file)

            # Only reached after the transcript is safely on disk.
            if delete_after_transcript:
                try:
                    audio_file.unlink()
                    logger.info(f"Deleted: {audio_file.name}")
                except OSError as e:
                    logger.warning(f"Could not delete {audio_file.name}: {e}")

        logger.info("\n=== Summary ===")
        logger.info(f"Processed: {len(processed_transcripts)}/{len(audio_files)} files")

        return processed_transcripts


def main(argv: Optional[List[str]] = None):
    """Parse command-line arguments and run the transcription pipeline.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``, for
            programmatic calls and tests. ``None`` keeps normal CLI behavior.

    Exits the process with status 0 when at least one file was transcribed,
    and with status 1 otherwise.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Whisper Transcription Pipeline")
    parser.add_argument("--input-dir", required=True, help="Input audio directory")
    parser.add_argument("--output-dir", required=True, help="Output transcripts directory")
    parser.add_argument("--model", choices=WhisperPipeline.WHISPER_MODELS,
                        default=WhisperPipeline.DEFAULT_MODEL,
                        help="Whisper model size")
    # Reuse the class default so the CLI and the library cannot drift apart.
    parser.add_argument("--language", default=WhisperPipeline.DEFAULT_LANGUAGE,
                        help="Language code")
    parser.add_argument("--device", default="cpu",
                        help="Device string passed to Whisper's --device option")
    parser.add_argument("--pattern", default="*.wav",
                        help="Glob pattern selecting audio files in --input-dir")
    parser.add_argument("--delete-after", action="store_true",
                        help="Delete audio files after successful transcription")

    args = parser.parse_args(argv)

    pipeline = WhisperPipeline(
        model=args.model,
        language=args.language,
        device=args.device,
    )

    transcripts = pipeline.process_directory(
        Path(args.input_dir),
        Path(args.output_dir),
        pattern=args.pattern,
        delete_after_transcript=args.delete_after,
    )

    if transcripts:
        print(f"\n✅ Successfully transcribed {len(transcripts)} files")
        sys.exit(0)
    print("\n❌ No files were transcribed")
    sys.exit(1)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        # 128 + SIGINT (2): the conventional status for a Ctrl-C'd process, so
        # shells and schedulers don't mistake an aborted run for success.
        sys.exit(130)
    except Exception as e:
        # Last-resort boundary: keep the traceback for debugging.
        logger.exception(f"Fatal error: {e}")
        sys.exit(1)
