#!/usr/bin/env python3
"""
YouTube Transcript Fetcher
Safely fetch YouTube video transcripts and subtitles without downloading video content.
"""

import sys
import json
import re
import argparse
from typing import Optional, List, Dict
from urllib.parse import urlparse, parse_qs

try:
    from youtube_transcript_api import YouTubeTranscriptApi
    TRANSCRIPT_API_AVAILABLE = True
except ImportError:
    TRANSCRIPT_API_AVAILABLE = False
    print("Warning: youtube-transcript-api not installed. Install with: pip3 install --user --break-system-packages youtube-transcript-api", file=sys.stderr)

import subprocess


def validate_youtube_url(url: str) -> bool:
    """Return True if ``url`` is a YouTube URL or a bare video ID.

    The CLI accepts "YouTube URL or video ID", so an 11-character ID made of
    ``[A-Za-z0-9_-]`` is accepted as-is. Otherwise the URL's host must be
    one of youtube.com, www.youtube.com, m.youtube.com or youtu.be.

    ``urlparse(...).hostname`` is used rather than ``netloc`` because it is
    lower-cased and excludes any port or ``user@`` prefix. As a result,
    ``https://WWW.YouTube.com:443/...`` is accepted, while look-alike hosts
    such as ``youtube.com.evil.example`` are still rejected.
    """
    if re.fullmatch(r'[A-Za-z0-9_-]{11}', url):
        return True
    allowed_domains = {'youtube.com', 'www.youtube.com', 'youtu.be', 'm.youtube.com'}
    return urlparse(url).hostname in allowed_domains


def extract_video_id(url: str) -> Optional[str]:
    """Extract the 11-character video ID from a YouTube URL or a bare ID.

    Recognised forms:
      * ``https://youtu.be/<id>`` (any further path or query is ignored)
      * ``https://[*.]youtube.com/watch?v=<id>``
      * ``https://[*.]youtube.com/{shorts,embed,live,v}/<id>``
      * a bare ``<id>``

    Every candidate must fully match ``[A-Za-z0-9_-]{11}``. Anything else,
    such as ``v=../x`` or an ID with trailing whitespace, yields None, so
    callers only ever receive a well-formed ID.

    Returns:
        The video ID, or None if none could be extracted.
    """
    id_re = re.compile(r'[A-Za-z0-9_-]{11}')

    def _checked(candidate: str) -> Optional[str]:
        # Accept only a complete, well-formed ID (fullmatch, not match+$).
        return candidate if id_re.fullmatch(candidate) else None

    parsed = urlparse(url)
    # hostname is lower-cased and has any port/credentials removed.
    host = parsed.hostname or ''

    # youtu.be/VIDEO_ID format
    if host == 'youtu.be':
        return _checked(parsed.path.strip('/').split('/', 1)[0])

    # youtube.com/watch?v=VIDEO_ID and /shorts|embed|live|v/VIDEO_ID formats.
    # An exact/suffix match avoids the substring test accepting unrelated hosts.
    if host == 'youtube.com' or host.endswith('.youtube.com'):
        v_values = parse_qs(parsed.query).get('v')
        if v_values:
            return _checked(v_values[0])
        segments = [seg for seg in parsed.path.split('/') if seg]
        if len(segments) >= 2 and segments[0] in ('shorts', 'embed', 'live', 'v'):
            return _checked(segments[1])
        return None

    # Direct video ID (11 characters, alphanumeric + - and _)
    return _checked(url)


def get_transcript_api(video_id: str, lang: str = 'en') -> Optional[List[Dict]]:
    """Fetch a transcript through youtube-transcript-api.

    ``lang`` is requested first. When it is not already ``'en'``, English is
    passed as a second preference.

    Returns:
        One dict per snippet with ``text`` (newlines replaced by spaces),
        ``start`` and ``duration``. Returns None when the library is not
        installed or when the fetch raises; in the latter case the error is
        printed to stderr.
    """
    if not TRANSCRIPT_API_AVAILABLE:
        return None

    # Ask for the requested language, with English as a fallback.
    preferred = [lang] if lang == 'en' else [lang, 'en']
    try:
        fetched = YouTubeTranscriptApi().fetch(video_id, languages=preferred)
        return [
            {
                'text': piece.text.replace('\n', ' '),
                'start': piece.start,
                'duration': piece.duration,
            }
            for piece in fetched
        ]
    except Exception as e:
        print(f"Error fetching transcript: {e}", file=sys.stderr)
        return None


def get_transcript_ytdlp(video_id: str, lang: str = 'en') -> Optional[List[Dict]]:
    """Fetch a transcript using yt-dlp as a fallback.

    Only the subtitle track for ``lang`` is downloaded, as WebVTT, into a
    private temporary directory; both manual subtitles and auto-generated
    captions are requested. The file is then parsed with
    ``_parse_vtt_cues``. The directory is removed on every exit path.

    Args:
        video_id: YouTube video ID, interpolated into the watch URL.
        lang: Language code passed to yt-dlp's ``--sub-langs``.

    Returns:
        A list of ``{'text', 'start', 'duration'}`` dicts. Returns None if
        yt-dlp exits non-zero, writes no ``.vtt`` file, yields no cue text,
        or any other error occurs; the reason is printed to stderr.
    """
    import os
    import tempfile

    url = f"https://www.youtube.com/watch?v={video_id}"
    try:
        # Use a private tempfile directory instead of a predictable
        # /tmp/yt_transcript_<id> path. This avoids collisions between
        # concurrent runs and symlink tricks, and the directory is cleaned up
        # even when yt-dlp or parsing fails.
        with tempfile.TemporaryDirectory(prefix='yt_transcript_') as tmp_dir:
            # Use yt-dlp to download subtitles only (canonical option names).
            cmd = [
                'yt-dlp',
                '--write-auto-subs',
                '--write-subs',
                '--sub-langs', lang,
                '--skip-download',
                '--sub-format', 'vtt',
                '-o', os.path.join(tmp_dir, 'transcript'),
                url,
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                print(f"yt-dlp error: {result.stderr}", file=sys.stderr)
                return None

            # Expected name is <output>.<lang>.vtt. If yt-dlp used a different
            # code (regional variant, regex match), take whichever .vtt it wrote.
            vtt_file = os.path.join(tmp_dir, f'transcript.{lang}.vtt')
            if not os.path.isfile(vtt_file):
                written = sorted(
                    name for name in os.listdir(tmp_dir) if name.endswith('.vtt')
                )
                if not written:
                    print(f"Subtitle file not found: {vtt_file}", file=sys.stderr)
                    return None
                vtt_file = os.path.join(tmp_dir, written[0])

            with open(vtt_file, 'r', encoding='utf-8') as f:
                content = f.read()

        transcript = _parse_vtt_cues(content)
        return transcript if transcript else None

    except Exception as e:
        print(f"Error using yt-dlp: {e}", file=sys.stderr)
        return None


def _parse_vtt_cues(content: str) -> List[Dict]:
    """Parse WebVTT text into ``{'text', 'start', 'duration'}`` entries.

    A cue starts at a ``start --> end [settings]`` timing line. It runs until
    the next empty line or the next timing line. Inline markup such as
    ``<c>`` or ``<00:00:01.000>`` is removed, HTML entities are decoded, and
    the remaining payload lines are joined with spaces. ``duration`` is
    ``end - start``, clamped at 0.
    """
    import html

    transcript = []
    lines = content.splitlines()
    i = 0
    while i < len(lines):
        timing = lines[i].strip()
        i += 1
        # Look for timestamp lines; everything else outside a cue is skipped.
        if '-->' not in timing:
            continue

        start_part, _, end_part = timing.partition('-->')
        start = parse_vtt_time(start_part)
        # The end time may be followed by cue settings, e.g. "align:start position:0%".
        end_tokens = end_part.split()
        try:
            end = parse_vtt_time(end_tokens[0]) if end_tokens else start
        except ValueError:
            end = start

        text_lines = []
        # Only a truly empty line ends a cue; whitespace-only payload lines are
        # part of the cue and simply contribute no text. The loop also stops at
        # the next timing line, which the outer loop then processes (it is not
        # skipped).
        while i < len(lines) and lines[i] != '' and '-->' not in lines[i]:
            cleaned = html.unescape(re.sub(r'<[^>]*>', '', lines[i])).strip()
            if cleaned:
                text_lines.append(cleaned)
            i += 1

        if text_lines:
            transcript.append({
                'text': ' '.join(text_lines),
                'start': start,
                'duration': max(0.0, end - start),
            })
    return transcript


def parse_vtt_time(time_str: str) -> float:
    """Convert a VTT timestamp to seconds.

    Accepts ``HH:MM:SS.mmm`` or ``MM:SS.mmm``, with surrounding whitespace
    ignored. Any other number of colon-separated fields returns ``0.0``.
    Non-numeric fields raise ``ValueError`` from ``int``/``float``.
    """
    fields = time_str.strip().split(':')
    if len(fields) not in (2, 3):
        return 0.0
    *whole_units, seconds = fields
    # Horner-style accumulation: [h, m] -> h*60 + m, then scaled to seconds.
    total = 0
    for unit in whole_units:
        total = total * 60 + int(unit)
    return total * 60 + float(seconds)


def get_metadata_ytdlp(video_id: str) -> Optional[Dict]:
    """Get video metadata using yt-dlp.

    Runs ``yt-dlp --skip-download --dump-json`` and picks a few fields from
    the JSON. The previous approach used a ``|||``-joined ``--print``
    template, which had two problems: a title or channel containing the
    delimiter shifted every field, and missing values arrived as the literal
    string ``'NA'``.

    Returns:
        ``{'title', 'duration', 'channel', 'upload_date'}``:
          * ``duration`` is whole seconds, or 0 if unknown.
          * ``title`` and ``channel`` fall back to ``'NA'`` when missing.
          * ``upload_date`` is yt-dlp's raw value, or None if missing.
        Returns None if yt-dlp fails or prints nothing; errors are reported
        on stderr.
    """
    url = f"https://www.youtube.com/watch?v={video_id}"
    cmd = ['yt-dlp', '--skip-download', '--dump-json', url]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"Error getting metadata: {result.stderr.strip()}", file=sys.stderr)
            return None

        lines = result.stdout.strip().splitlines()
        if not lines:
            return None
        # --dump-json prints one JSON object per line; a watch URL yields one.
        info = json.loads(lines[0])

        duration = info.get('duration')
        return {
            'title': info.get('title') or 'NA',
            'duration': int(duration) if duration else 0,
            'channel': info.get('channel') or info.get('uploader') or 'NA',
            'upload_date': info.get('upload_date') or None,
        }
    except Exception as e:
        print(f"Error getting metadata: {e}", file=sys.stderr)
        return None


def list_available_languages(video_id: str) -> List[str]:
    """List available subtitle languages.

    Each transcript reported by youtube-transcript-api becomes one entry of
    the form ``"<code> (<language name>)"``. Generated tracks have
    ``", auto-generated"`` appended inside the parentheses. Previously the
    raw listing was only dumped to stderr and ``[]`` was returned, so
    callers had nothing to display.

    Returns:
        The entries in the order the library yields them. Returns an empty
        list, after a message on stderr, when the library is missing or the
        lookup fails.
    """
    if not TRANSCRIPT_API_AVAILABLE:
        print("youtube-transcript-api required for listing languages", file=sys.stderr)
        return []

    try:
        api = YouTubeTranscriptApi()
        entries = []
        # api.list() returns a TranscriptList, which iterates Transcript objects.
        for transcript in api.list(video_id):
            label = transcript.language
            if transcript.is_generated:
                label += ', auto-generated'
            entries.append(f"{transcript.language_code} ({label})")
        return entries
    except Exception as e:
        print(f"Error listing languages: {e}", file=sys.stderr)
        return []


def format_transcript(transcript: List[Dict], format_type: str = 'text') -> str:
    """Render transcript entries as a single string.

    ``format_type`` selects the output:
      * ``'json'``: pretty-printed JSON (indent 2, non-ASCII kept as-is).
      * ``'srt'``: numbered SubRip cues; the end time is start + duration,
        with duration defaulting to 0.
      * ``'timestamped'``: one ``[MM:SS] text`` line per entry.
      * anything else: plain text, one entry per line.
    """
    if format_type == 'json':
        return json.dumps(transcript, indent=2, ensure_ascii=False)

    if format_type == 'srt':
        cues = []
        for number, item in enumerate(transcript, start=1):
            begin = item['start']
            finish = begin + item.get('duration', 0)
            cues.append(
                f"{number}\n"
                f"{format_srt_time(begin)} --> {format_srt_time(finish)}\n"
                f"{item['text']}\n"
            )
        return '\n'.join(cues)

    if format_type == 'timestamped':
        return '\n'.join(
            f"[{format_timestamp(item['start'])}] {item['text']}"
            for item in transcript
        )

    # Plain text fallback
    return '\n'.join(item['text'] for item in transcript)


def format_srt_time(seconds: float) -> str:
    """Format seconds as an SRT timestamp (HH:MM:SS,mmm).

    The value is first rounded to whole milliseconds and then split with
    integer arithmetic. Truncating ``(seconds % 1) * 1000`` instead would be
    off by one for values such as 2.3 (giving ``,299``) and would miss
    carries such as 59.9996 → ``00:01:00,000``. Negative inputs clamp to
    zero.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def format_timestamp(seconds: float) -> str:
    """Format seconds as a readable timestamp.

    Times under one hour use ``MM:SS``, unchanged from before. From one hour
    on, ``H:MM:SS`` is used instead of an ever-growing minute count (e.g.
    ``1:15:30`` rather than ``75:30``). Fractional seconds are truncated and
    negative inputs clamp to ``00:00``.
    """
    total = max(0, int(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    if hours:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"


def process_video(url: str, args) -> bool:
    """Fetch and print the transcript for a single video.

    Steps:
      1. Validate ``url`` and extract the video ID.
      2. If ``args.list_langs`` is set, print the available languages and
         stop.
      3. Otherwise fetch the transcript with youtube-transcript-api (when
         available), falling back to yt-dlp.
      4. If ``args.metadata`` is set, print metadata to stderr.
      5. Write the transcript, formatted per ``args.format``, to stdout.

    Args:
        url: Value supplied on the command line or read from the batch file.
        args: Parsed CLI namespace; reads ``list_langs``, ``lang``,
            ``metadata`` and ``format``.

    Returns:
        True on success; False if validation, ID extraction or transcript
        fetching failed.
    """
    def _err(message):
        # Progress and diagnostics go to stderr; stdout carries the result.
        print(message, file=sys.stderr)

    if not validate_youtube_url(url):
        _err(f"Error: Invalid YouTube URL: {url}")
        return False

    video_id = extract_video_id(url)
    if not video_id:
        _err(f"Error: Could not extract video ID from: {url}")
        return False

    _err(f"Processing video ID: {video_id}")

    if args.list_langs:
        print(f"\nAvailable subtitle languages for {video_id}:")
        for code in list_available_languages(video_id):
            print(f"  - {code}")
        return True

    transcript = get_transcript_api(video_id, args.lang) if TRANSCRIPT_API_AVAILABLE else None
    if not transcript:
        _err("Falling back to yt-dlp...")
        transcript = get_transcript_ytdlp(video_id, args.lang)
    if not transcript:
        _err(f"Error: Could not fetch transcript for {video_id}")
        return False

    if args.metadata:
        info = get_metadata_ytdlp(video_id)
        if info:
            _err("\n=== Metadata ===")
            _err(f"Title: {info['title']}")
            _err(f"Channel: {info['channel']}")
            _err(f"Duration: {info['duration']}s")
            if info.get('upload_date'):
                _err(f"Uploaded: {info['upload_date']}")
            _err("")

    print(format_transcript(transcript, args.format))
    return True


def main():
    """Parse command-line arguments and process one video or a batch file.

    In ``--batch FILE`` mode, every non-empty line of FILE (UTF-8) that does
    not start with ``#`` is processed as a URL. Otherwise the positional
    ``url`` is processed; when it is missing, the help text is printed.

    Returns:
        Process exit status: 0 on success, 1 on failure. In batch mode the
        status is 0 when at least one video succeeded.
    """
    parser = argparse.ArgumentParser(
        description='Fetch YouTube video transcripts safely (no video download)'
    )
    parser.add_argument('url', nargs='?', help='YouTube URL or video ID')
    parser.add_argument('--lang', default='en', help='Subtitle language code (default: en)')
    parser.add_argument('--format', choices=['text', 'timestamped', 'srt', 'json'],
                        default='text', help='Output format (default: text)')
    parser.add_argument('--metadata', action='store_true',
                        help='Include video metadata')
    parser.add_argument('--list-langs', action='store_true',
                        help='List available subtitle languages')
    parser.add_argument('--batch',
                        help='Process multiple URLs from a file (one per line; '
                             'blank lines and lines starting with # are skipped)')

    args = parser.parse_args()

    # Batch processing
    if args.batch:
        # Only the file read is guarded: a FileNotFoundError raised while
        # processing videos must not be misreported as a missing batch file.
        try:
            with open(args.batch, 'r', encoding='utf-8') as f:
                urls = [
                    entry for entry in (line.strip() for line in f)
                    if entry and not entry.startswith('#')
                ]
        except FileNotFoundError:
            print(f"Error: Batch file not found: {args.batch}", file=sys.stderr)
            return 1
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error: Could not read batch file {args.batch}: {e}", file=sys.stderr)
            return 1

        success_count = 0
        for url in urls:
            print(f"\n{'='*60}", file=sys.stderr)
            print(f"Processing: {url}", file=sys.stderr)
            print('='*60, file=sys.stderr)
            if process_video(url, args):
                success_count += 1

        print(f"\n\nProcessed {success_count}/{len(urls)} videos successfully", file=sys.stderr)
        return 0 if success_count > 0 else 1

    # Single video processing
    if not args.url:
        parser.print_help()
        return 1

    return 0 if process_video(args.url, args) else 1


if __name__ == '__main__':
    # Propagate main()'s integer status as the process exit code.
    exit_code = main()
    sys.exit(exit_code)
