#!/usr/bin/env python3
"""
Safe file deduplication tool for OpenClaw
Uses SHA-256 hashing and macOS trash for recoverable removal
"""

import argparse
import hashlib
import json
import os
import shutil
import subprocess
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

# Patterns excluded from scanning by default: VCS metadata, build/test caches,
# OS cruft, and temporary-file extensions ('*.ext' entries match on suffix).
DEFAULT_EXCLUDES = [
    '.git',
    '.svn',
    '.hg',
    'node_modules',
    '__pycache__',
    '.pytest_cache',
    '.DS_Store',
    'Thumbs.db',
    'desktop.ini',
    '*.tmp',
    '*.temp',
    '*.cache',
]

class DuplicateFinder:
    """Find duplicate files by content hash, size + partial hash, or filename.

    Files are collected recursively from the given directories, filtered by
    size limits and exclusion patterns, then grouped by the configured
    detection method.  Groups with more than one member are duplicates.
    """

    def __init__(self, method='content', min_size=0, max_size=None, excludes=None):
        # method: 'content' (full SHA-256), 'size' (size + first-8KB hash),
        # or 'name' (exact filename match)
        self.method = method
        self.min_size = min_size    # skip files smaller than this (bytes)
        self.max_size = max_size    # skip files larger than this (bytes); None = unlimited
        self.excludes = excludes or DEFAULT_EXCLUDES
        self.scan_results = {}
        # group key (hash / "size:hash" / filename) -> list of Paths sharing it
        self.file_groups = defaultdict(list)

    def should_exclude(self, path: Path) -> bool:
        """Return True if *path* matches an exclusion rule.

        Dot-prefixed (hidden) names are excluded, except '.gitkeep'.
        Patterns of the form '*.ext' match on filename suffix; any other
        pattern is excluded if it appears anywhere in the full path string.
        """
        path_str = str(path)
        name = path.name

        # Skip hidden files/dirs, with .gitkeep whitelisted
        if name.startswith('.') and name not in ['.gitkeep']:
            return True

        for pattern in self.excludes:
            if pattern.startswith('*.'):
                # Suffix match: '*.tmp' -> name.endswith('.tmp')
                if name.endswith(pattern[1:]):
                    return True
            elif pattern in path_str:
                # Substring match anywhere in the path (e.g. 'node_modules')
                return True

        return False

    def hash_file(self, path: Path, partial=False) -> Optional[str]:
        """Return the SHA-256 hex digest of *path*'s content, or None on read error.

        With partial=True only the first 8 KB is hashed (fast mode);
        otherwise the whole file is streamed in 64 KB chunks.
        """
        hasher = hashlib.sha256()

        try:
            with open(path, 'rb') as f:
                if partial:
                    # Read first 8KB for fast mode
                    hasher.update(f.read(8192))
                else:
                    # Stream the entire file
                    while chunk := f.read(65536):
                        hasher.update(chunk)
            return hasher.hexdigest()
        except (IOError, OSError) as e:
            print(f"⚠️  Error reading {path}: {e}", file=sys.stderr)
            return None

    def scan_directories(self, directories: List[str], progress=True) -> Dict:
        """Walk *directories*, group candidate files, and return scan results.

        Nonexistent directories are reported to stderr and skipped;
        unreadable files are silently ignored.  Returns the results dict
        produced by _generate_results().
        """
        all_files = []

        # Collect all candidate files
        print("📂 Scanning directories...")
        for directory in directories:
            dir_path = Path(directory).resolve()
            if not dir_path.exists():
                print(f"⚠️  Directory not found: {directory}", file=sys.stderr)
                continue

            for root, dirs, files in os.walk(dir_path):
                # Prune excluded directories in place so os.walk skips them
                dirs[:] = [d for d in dirs if not self.should_exclude(Path(root) / d)]

                for filename in files:
                    file_path = Path(root) / filename

                    if self.should_exclude(file_path):
                        continue

                    if not file_path.is_file():
                        continue  # skip broken symlinks, sockets, etc.

                    try:
                        size = file_path.stat().st_size

                        # Apply size filters
                        if size < self.min_size:
                            continue
                        if self.max_size and size > self.max_size:
                            continue

                        all_files.append(file_path)

                    except (OSError, IOError):
                        continue

        print(f"✓ Found {len(all_files)} files to analyze\n")

        # Group files by the configured detection method
        if self.method == 'content':
            self._group_by_content_hash(all_files, progress)
        elif self.method == 'size':
            self._group_by_size_and_partial_hash(all_files, progress)
        elif self.method == 'name':
            self._group_by_name(all_files, progress)

        return self._generate_results()

    def _group_by_content_hash(self, files: List[Path], progress: bool):
        """Group files by full content SHA-256 (exact duplicate detection)."""
        print("🔍 Computing SHA-256 hashes...")
        total = len(files)

        for i, file_path in enumerate(files, 1):
            if progress and i % 100 == 0:
                print(f"   {i}/{total} files processed...")

            file_hash = self.hash_file(file_path)
            if file_hash:
                self.file_groups[file_hash].append(file_path)

    def _group_by_size_and_partial_hash(self, files: List[Path], progress: bool):
        """Group files by size + partial hash (fast mode).

        Only the first 8 KB is hashed, so files that share size and prefix
        but differ later could be falsely grouped; use method='content' for
        exact results.
        """
        print("🔍 Computing sizes and partial hashes...")
        size_groups = defaultdict(list)

        # First bucket by size: files with a unique size cannot be duplicates
        for file_path in files:
            try:
                size_groups[file_path.stat().st_size].append(file_path)
            except (OSError, IOError):
                continue

        # Then hash only files whose size collides with another file's
        total = sum(len(group) for group in size_groups.values() if len(group) > 1)
        processed = 0

        for size, file_list in size_groups.items():
            if len(file_list) <= 1:
                continue
            for file_path in file_list:
                processed += 1
                if progress and processed % 100 == 0:
                    print(f"   {processed}/{total} potential duplicates checked...")

                # BUG FIX: the hash was previously interpolated into the key
                # f-string *before* the truthiness check, so a failed hash
                # (None) produced the truthy key "<size>:None" and all
                # unreadable files of one size were grouped as "duplicates".
                partial_hash = self.hash_file(file_path, partial=True)
                if partial_hash:
                    self.file_groups[f"{size}:{partial_hash}"].append(file_path)

    def _group_by_name(self, files: List[Path], progress: bool):
        """Group files by exact filename (contents are not compared)."""
        print("🔍 Grouping by filename...")
        for file_path in files:
            self.file_groups[file_path.name].append(file_path)

    def _generate_results(self) -> Dict:
        """Build the results dict from self.file_groups.

        Returns a dict with scan metadata plus a 'groups' list; each group
        records its key, copy count, per-file size, wasted space (all copies
        but one), and per-file path/size/mtime/ctime info.
        """
        duplicate_groups = []
        total_duplicates = 0
        total_wasted_space = 0

        for key, file_list in self.file_groups.items():
            if len(file_list) > 1:
                files_info = []
                for file_path in file_list:
                    try:
                        stat = file_path.stat()
                        files_info.append({
                            'path': str(file_path),
                            'size': stat.st_size,
                            'modified': datetime.fromtimestamp(stat.st_mtime).isoformat(),
                            'created': datetime.fromtimestamp(stat.st_ctime).isoformat()
                        })
                    except (OSError, IOError):
                        # File vanished between grouping and stat: drop it
                        continue

                if len(files_info) > 1:
                    # Wasted space = size of all but one surviving copy
                    size = files_info[0]['size']
                    wasted = size * (len(files_info) - 1)
                    total_wasted_space += wasted
                    total_duplicates += len(files_info) - 1

                    duplicate_groups.append({
                        'key': key,
                        'count': len(files_info),
                        'size': size,
                        'wasted_space': wasted,
                        'files': files_info
                    })

        return {
            'scan_time': datetime.now().isoformat(),
            'method': self.method,
            'total_groups': len(duplicate_groups),
            'total_duplicates': total_duplicates,
            'total_wasted_space': total_wasted_space,
            'groups': duplicate_groups
        }

class DuplicateCleaner:
    """Remove duplicate files safely, via the system trash or an archive tree.

    A 'keep strategy' decides which copy in each duplicate group survives;
    the rest are trashed (recoverable) or moved under an archive directory.
    Every removal is recorded in a timestamped log file.
    """

    def __init__(self, keep_strategy='newest'):
        # keep_strategy: 'newest' | 'oldest' | 'largest' | 'smallest' | 'first-found'
        self.keep_strategy = keep_strategy
        self.trash_command = self._find_trash_command()

    def _find_trash_command(self) -> Optional[str]:
        """Return the name of an installed trash CLI, or None if none exists."""
        if shutil.which('trash'):
            return 'trash'
        elif shutil.which('trash-put'):
            return 'trash-put'
        else:
            return None

    def check_trash_available(self) -> bool:
        """Report whether a trash command exists; print install hints if not."""
        if not self.trash_command:
            print("❌ ERROR: No trash command found!", file=sys.stderr)
            print("   Install with: brew install trash", file=sys.stderr)
            print("   or: brew install trash-cli", file=sys.stderr)
            return False
        return True

    def select_files_to_keep(self, group: Dict) -> Tuple[Dict, List[Dict]]:
        """Split one duplicate group into (file_to_keep, files_to_remove).

        'newest'/'oldest' sort the ISO-8601 'modified' strings, which order
        lexicographically the same as the timestamps they encode.
        """
        files = group['files']

        if self.keep_strategy == 'newest':
            files_sorted = sorted(files, key=lambda f: f['modified'], reverse=True)
        elif self.keep_strategy == 'oldest':
            files_sorted = sorted(files, key=lambda f: f['modified'])
        elif self.keep_strategy == 'largest':
            files_sorted = sorted(files, key=lambda f: f['size'], reverse=True)
        elif self.keep_strategy == 'smallest':
            files_sorted = sorted(files, key=lambda f: f['size'])
        else:  # first-found: keep scan order
            files_sorted = files

        keep = files_sorted[0]
        remove = files_sorted[1:]

        return keep, remove

    def clean_duplicates(self, scan_results: Dict, action='trash', archive_path=None, dry_run=True, confirm=False):
        """Remove the duplicates described by *scan_results*.

        action: 'trash' (recoverable delete) or 'archive' (move under archive_path).
        dry_run: if True, only print the removal plan and return.
        confirm: if True, prompt interactively before removing anything.
        Returns True on completion (or dry-run), False when aborted.
        """

        if action == 'trash' and not self.check_trash_available():
            return False

        if action == 'archive':
            # BUG FIX: previously a missing archive_path skipped this setup
            # entirely, then the removal loop crashed with a NameError on
            # the undefined archive_dir.
            if not archive_path:
                print("❌ ERROR: --archive-path is required for archive action", file=sys.stderr)
                return False
            archive_dir = Path(archive_path)
            archive_dir.mkdir(parents=True, exist_ok=True)

        # Build removal plan
        files_to_keep = []
        files_to_remove = []
        total_space_to_free = 0

        for group in scan_results['groups']:
            keep, remove = self.select_files_to_keep(group)
            files_to_keep.append(keep)

            for file_info in remove:
                files_to_remove.append(file_info)
                total_space_to_free += file_info['size']

        # Show summary
        print(f"\n{'='*70}")
        print(f"DUPLICATE REMOVAL SUMMARY")
        print(f"{'='*70}")
        print(f"Keep strategy: {self.keep_strategy}")
        print(f"Action: {action}")
        print(f"Files to keep: {len(files_to_keep)}")
        print(f"Files to remove: {len(files_to_remove)}")
        print(f"Space to free: {self._format_size(total_space_to_free)}")
        print(f"{'='*70}\n")

        # Show first few examples
        print("Examples of files that will be removed:")
        for file_info in files_to_remove[:5]:
            print(f"  ❌ {file_info['path']}")
        if len(files_to_remove) > 5:
            print(f"  ... and {len(files_to_remove) - 5} more\n")

        if dry_run:
            print("🔒 DRY-RUN MODE: No files will be removed")
            print("   Add --confirm to actually remove files\n")
            return True

        # Final interactive confirmation
        if confirm:
            response = input(f"\n⚠️  Really {action} {len(files_to_remove)} files? [y/N]: ")
            if response.lower() != 'y':
                print("❌ Cancelled by user")
                return False

        # Execute removal, logging every action
        log_file = Path.home() / '.openclaw' / 'workspace' / 'logs' / f"dedup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        log_file.parent.mkdir(parents=True, exist_ok=True)

        success_count = 0
        error_count = 0

        print(f"\n🗑️  {'Archiving' if action == 'archive' else 'Trashing'} duplicates...")

        with open(log_file, 'w') as log:
            log.write(f"Deduplication log - {datetime.now().isoformat()}\n")
            log.write(f"Method: {scan_results['method']}\n")
            log.write(f"Keep strategy: {self.keep_strategy}\n")
            log.write(f"Action: {action}\n\n")

            for file_info in files_to_remove:
                file_path = Path(file_info['path'])

                try:
                    if action == 'archive':
                        # Mirror the source path (minus drive/root anchor)
                        # under the archive directory
                        rel_path = file_path.relative_to(file_path.anchor)
                        dest = archive_dir / rel_path
                        dest.parent.mkdir(parents=True, exist_ok=True)
                        shutil.move(str(file_path), str(dest))
                        log.write(f"ARCHIVED: {file_path} -> {dest}\n")
                    else:
                        # Use trash command (subprocess avoids shell quoting issues with apostrophes)
                        result = subprocess.run([self.trash_command, str(file_path)], capture_output=True)
                        if result.returncode == 0:
                            log.write(f"TRASHED: {file_path}\n")
                        else:
                            raise RuntimeError(f"trash command returned {result.returncode}")

                    success_count += 1

                except Exception as e:
                    # Best-effort: record the failure and continue with the rest
                    error_count += 1
                    log.write(f"ERROR: {file_path} - {e}\n")
                    print(f"   ⚠️  Failed: {file_path}", file=sys.stderr)

        print(f"\n✓ Complete!")
        print(f"  Success: {success_count}")
        print(f"  Errors: {error_count}")
        print(f"  Log: {log_file}\n")

        return True

    def _format_size(self, num_bytes: int) -> str:
        """Format a byte count as a human-readable size (e.g. '1.50 KB')."""
        size = float(num_bytes)  # avoid shadowing/mutating the builtin-named arg
        for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
            if size < 1024:
                return f"{size:.2f} {unit}"
            size /= 1024
        return f"{size:.2f} PB"

def cmd_scan(args):
    """Handle the 'scan' subcommand: find duplicates, report, and save results.

    Prints either a JSON dump or a human-readable summary (first three groups
    only), then persists the full results to a timestamped JSON file under
    ~/.openclaw/workspace/logs.  Returns the scan-results dict.
    """
    finder = DuplicateFinder(
        method=args.method,
        min_size=args.min_size,
        max_size=args.max_size,
        excludes=args.exclude if args.exclude else DEFAULT_EXCLUDES
    )

    results = finder.scan_directories(args.directories, progress=True)

    if args.json:
        print(json.dumps(results, indent=2))
    else:
        # One formatter, instead of constructing a throwaway DuplicateCleaner
        # per formatted value (which re-probed the trash command each time)
        fmt = DuplicateCleaner()._format_size

        print("\n" + "="*70)
        print("DUPLICATE SCAN RESULTS")
        print("="*70)
        print(f"Method: {results['method']}")
        print(f"Duplicate groups: {results['total_groups']}")
        print(f"Total duplicate files: {results['total_duplicates']}")
        print(f"Wasted space: {fmt(results['total_wasted_space'])}")
        print("="*70 + "\n")

        # Preview only the first few groups
        for i, group in enumerate(results['groups'][:3], 1):
            print(f"Group {i}: {group['count']} copies, {fmt(group['size'])} each")
            for file_info in group['files']:
                print(f"  • {file_info['path']}")
                print(f"    Modified: {file_info['modified']}")
            print()

        if len(results['groups']) > 3:
            print(f"... and {len(results['groups']) - 3} more groups\n")

    # Persist full results so 'report' can display them later
    output_file = Path.home() / '.openclaw' / 'workspace' / 'logs' / f"scan_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"💾 Results saved to: {output_file}")
    return results

def cmd_clean(args):
    """Handle the 'clean' subcommand: scan the directories, then remove duplicates."""
    scanner = DuplicateFinder(
        method=args.method,
        min_size=args.min_size,
        max_size=args.max_size,
        excludes=args.exclude if args.exclude else DEFAULT_EXCLUDES,
    )
    scan_results = scanner.scan_directories(args.directories, progress=True)

    # Nothing to do when the scan came back clean
    if not scan_results['total_duplicates']:
        print("\n✓ No duplicates found!")
        return

    # Without --confirm we stay in dry-run mode: plan is printed, nothing removed
    DuplicateCleaner(keep_strategy=args.keep).clean_duplicates(
        scan_results,
        action=args.action,
        archive_path=args.archive_path,
        dry_run=not args.confirm,
        confirm=args.confirm,
    )

def cmd_report(args):
    """Handle the 'report' subcommand: pretty-print a saved scan-result JSON file."""
    report_path = Path(args.report_file)
    if not report_path.exists():
        print(f"❌ Report file not found: {args.report_file}", file=sys.stderr)
        sys.exit(1)

    with open(report_path) as fh:
        results = json.load(fh)

    human_size = DuplicateCleaner()._format_size
    banner = "=" * 70

    print("\n" + banner)
    print("DUPLICATE SCAN REPORT")
    print(banner)
    print(f"Scan time: {results['scan_time']}")
    print(f"Method: {results['method']}")
    print(f"Duplicate groups: {results['total_groups']}")
    print(f"Total duplicate files: {results['total_duplicates']}")
    print(f"Wasted space: {human_size(results['total_wasted_space'])}")
    print(banner + "\n")

    # Unlike 'scan', the report lists every group in full
    for idx, group in enumerate(results['groups'], 1):
        print(f"Group {idx}: {group['count']} copies, {human_size(group['size'])} each")
        for entry in group['files']:
            print(f"  • {entry['path']}")
            print(f"    Modified: {entry['modified']}")
        print()

def main():
    """CLI entry point: build the argument parser and dispatch to a subcommand."""
    parser = argparse.ArgumentParser(
        description='Safe file deduplication tool - Uses SHA-256 and trash for safety'
    )
    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # 'scan' — read-only duplicate detection
    scan = subparsers.add_parser('scan', help='Scan for duplicate files')
    scan.add_argument('directories', nargs='+', help='Directories to scan')
    scan.add_argument('--method', choices=['content', 'size', 'name'], default='content',
                      help='Detection method (default: content)')
    scan.add_argument('--min-size', type=int, default=0, help='Minimum file size in bytes')
    scan.add_argument('--max-size', type=int, help='Maximum file size in bytes')
    scan.add_argument('--exclude', action='append', help='Exclusion patterns')
    scan.add_argument('--json', action='store_true', help='Output as JSON')

    # 'clean' — scan plus removal (dry-run unless --confirm)
    clean = subparsers.add_parser('clean', help='Remove duplicate files')
    clean.add_argument('directories', nargs='+', help='Directories to clean')
    clean.add_argument('--method', choices=['content', 'size', 'name'], default='content',
                       help='Detection method (default: content)')
    clean.add_argument('--keep', choices=['newest', 'oldest', 'largest', 'smallest', 'first-found'],
                       default='newest', help='Which file to keep (default: newest)')
    clean.add_argument('--action', choices=['trash', 'archive'], default='trash',
                       help='Action to take (default: trash)')
    clean.add_argument('--archive-path', help='Archive directory (required for archive action)')
    clean.add_argument('--min-size', type=int, default=0, help='Minimum file size in bytes')
    clean.add_argument('--max-size', type=int, help='Maximum file size in bytes')
    clean.add_argument('--exclude', action='append', help='Exclusion patterns')
    clean.add_argument('--confirm', action='store_true',
                       help='Actually remove files (default is dry-run)')

    # 'report' — display a previously saved scan
    report = subparsers.add_parser('report', help='Display a previous scan result')
    report.add_argument('report_file', help='Path to scan result JSON file')

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        sys.exit(1)

    # 'archive' needs a destination before we hand off to the clean handler
    if args.command == 'clean' and args.action == 'archive' and not args.archive_path:
        print("❌ --archive-path is required when using --action archive", file=sys.stderr)
        sys.exit(1)

    dispatch = {'scan': cmd_scan, 'clean': cmd_clean, 'report': cmd_report}
    dispatch[args.command](args)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
