#!/usr/bin/env python3
"""
Rolling Conversation Buffer for OpenClaw
Maintains recent conversation context across restarts and model switches.

Usage:
  ./conversation-buffer.py append --role user --model opus --message "..." [--important]
  ./conversation-buffer.py read [--hours 2] [--limit 50]
  ./conversation-buffer.py search "query"
  ./conversation-buffer.py prune [--hours 2] [--limit 50] [--model opus]
  ./conversation-buffer.py summary
  ./conversation-buffer.py startup  # Formatted context for session start
  ./conversation-buffer.py promote  # Find entries worth adding to MEMORY.md
"""

import json
import sys
import os
import re
from datetime import datetime, timedelta, timezone
from pathlib import Path
import argparse

# JSONL buffer file (one JSON object per line), stored under the user's home.
BUFFER_PATH = Path.home() / ".openclaw/workspace/memory/recent-conversation.jsonl"
# Default lookback window (hours) and entry cap used by read/prune.
DEFAULT_HOURS = 2
DEFAULT_LIMIT = 50

# Model tiers - future-proof with pattern matching
# (substring match against the lowercased model name; see get_model_tier)
TIER_1_PATTERNS = ["opus", "claude-opus"]  # Best brain - can always prune
TIER_2_PATTERNS = ["sonnet", "claude-sonnet"]  # Second best - can prune
# Everything else (haiku, gpt, local, etc.) = Tier 3+ = read-only, no prune

# Keywords that suggest importance
# (substring match against the lowercased message; see detect_importance)
IMPORTANCE_KEYWORDS = [
    "decision", "decided", "approved", "rejected", "confirmed",
    "config", "setting", "changed", "updated", "fixed", "implemented",
    "remember", "important", "critical", "urgent", "priority",
    "password", "key", "token", "secret", "credential",
    "error", "bug", "issue", "problem", "broken",
    "schedule", "reminder", "deadline", "appointment",
    "todo", "action item", "next step", "follow up"
]

def now_utc():
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)

def get_model_tier(model):
    """Classify a model name into a capability tier (1 = most capable).

    Matching is case-insensitive substring search against the tier pattern
    lists; anything unmatched (or empty/None) is Tier 3.
    """
    if not model:
        return 3
    name = model.lower()
    for tier, patterns in ((1, TIER_1_PATTERNS), (2, TIER_2_PATTERNS)):
        if any(pattern in name for pattern in patterns):
            return tier
    return 3

def can_prune(model):
    """Pruning is restricted to Tier 1/2 models; everything else is read-only."""
    tier = get_model_tier(model)
    return tier <= 2

def is_primary_model(model):
    """True when the model is a primary (Tier 1 or Tier 2) model."""
    # get_model_tier only ever returns 1, 2, or 3, so "< 3" == "<= 2".
    return get_model_tier(model) < 3

def ensure_buffer():
    """Create the buffer file (and any missing parent directories) if absent."""
    BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
    # touch() with exist_ok=True is idempotent, so no exists() check needed.
    BUFFER_PATH.touch(exist_ok=True)

def read_entries():
    """Load every valid JSON line from the buffer, skipping blank/corrupt lines."""
    ensure_buffer()
    parsed = []
    for raw_line in BUFFER_PATH.read_text().splitlines():
        raw_line = raw_line.strip()
        if not raw_line:
            continue
        try:
            parsed.append(json.loads(raw_line))
        except json.JSONDecodeError:
            # A corrupt line is dropped rather than aborting the whole read.
            pass
    return parsed

def write_entries(entries):
    """Overwrite the buffer with the given entries, one JSON object per line."""
    ensure_buffer()
    payload = ''.join(json.dumps(entry) + '\n' for entry in entries)
    with open(BUFFER_PATH, 'w') as f:
        f.write(payload)

def compress_message(message, max_len=2000):
    """
    Compress long messages intelligently.

    Messages at or under max_len are returned unchanged.  Longer messages
    are truncated near max_len, preferring a sentence or word boundary,
    and a trailing marker records how many characters were dropped.

    Args:
        message: Raw message text.
        max_len: Target maximum length for the kept portion (a short
            truncation marker is appended on top of it).

    Returns:
        The original message, or a truncated copy ending in
        "[...N more chars truncated]".
    """
    if len(message) <= max_len:
        return message

    # Reserve ~100 chars of headroom for the truncation marker.
    keep_len = max_len - 100
    kept = message[:keep_len]

    # Prefer a sentence end, else a word boundary, so the kept text does
    # not stop mid-word; only accept boundaries near the cut point so we
    # don't throw away too much.
    last_period = kept.rfind('.')
    last_space = kept.rfind(' ')

    if last_period > keep_len - 200:
        kept = kept[:last_period + 1]
    elif last_space > keep_len - 50:
        kept = kept[:last_space]

    # Recompute AFTER boundary trimming so the marker reports the true
    # number of dropped characters (previously it was computed before the
    # trim and under-reported whenever a boundary was used).
    cut_len = len(message) - len(kept)
    return f"{kept} [...{cut_len} more chars truncated]"

def detect_importance(message):
    """
    Score a message's importance by keyword matches.

    Returns:
        (score, keywords) - score is 0-10 (two points per matched keyword,
        capped at 10); keywords are the matching IMPORTANCE_KEYWORDS entries.
    """
    lowered = message.lower()
    found_keywords = [kw for kw in IMPORTANCE_KEYWORDS if kw in lowered]
    return min(len(found_keywords) * 2, 10), found_keywords

def append_entry(role, model, message, important=False):
    """Append one entry to the buffer and return the stored record.

    The message is compressed if long and scanned for importance
    keywords; the entry is flagged important when the caller says so
    or the auto-detected score reaches the threshold.
    """
    score, hits = detect_importance(message)
    entry = {
        "ts": now_utc().isoformat(),
        "model": model,
        "role": role,
        "message": compress_message(message),
        # Explicit flag wins; otherwise a score of 4+ auto-flags.
        "important": important or score >= 4,
        "importance_score": score,
        # Slicing an empty list yields [], so no conditional is needed.
        "keywords": hits[:5],
    }

    ensure_buffer()
    with open(BUFFER_PATH, 'a') as f:
        f.write(json.dumps(entry) + '\n')
    return entry

def search_entries(query, case_sensitive=False):
    """Return all buffer entries whose message contains the query string."""
    needle = query if case_sensitive else query.lower()
    matches = []
    for entry in read_entries():
        haystack = entry.get('message', '')
        if not case_sensitive:
            haystack = haystack.lower()
        if needle in haystack:
            matches.append(entry)
    return matches

def prune_entries(hours=DEFAULT_HOURS, limit=DEFAULT_LIMIT, preserve_fallback=True, current_model=None):
    """Prune old entries, respecting tier permissions and importance.

    Entries newer than `hours` are kept; older ones survive only if flagged
    important.  If the result still exceeds `limit`, the OLDEST non-important
    entries are dropped first.  Tier-3 models may not prune at all, and a
    buffer that mixes primary and fallback entries is preserved (unless far
    over the limit) so a primary model can review fallback output first.

    Returns:
        (entries, status) - surviving entries and a human-readable status.
    """
    entries = read_entries()
    if not entries:
        return [], "empty"

    # Tier gate: only Tier 1/2 models are trusted to destroy history.
    if current_model and not can_prune(current_model):
        return entries, f"blocked: {current_model} is Tier {get_model_tier(current_model)}, cannot prune"

    # Fallback guard: keep everything while both kinds are present,
    # unless the buffer has ballooned past twice the limit.
    if preserve_fallback:
        has_fallback = any(not is_primary_model(e.get('model', '')) for e in entries)
        has_primary = any(is_primary_model(e.get('model', '')) for e in entries)
        if has_fallback and has_primary and len(entries) <= limit * 2:
            return entries, "preserved: fallback entries present"

    # Time-based filtering.
    cutoff = now_utc() - timedelta(hours=hours)
    time_filtered = []
    important_preserved = 0

    for e in entries:
        try:
            ts_str = e['ts'].replace('Z', '+00:00')
            # Heuristic: no offset after the date part => assume naive UTC.
            if '+' not in ts_str and '-' not in ts_str[10:]:
                ts_str += '+00:00'
            ts = datetime.fromisoformat(ts_str)

            # Keep if recent OR important.
            if ts >= cutoff:
                time_filtered.append(e)
            elif e.get('important', False):
                time_filtered.append(e)
                important_preserved += 1
        except (KeyError, ValueError):
            # Unparseable timestamp: keep rather than silently destroy data.
            time_filtered.append(e)

    # Size cap: always keep important entries, then drop the OLDEST regular
    # entries, preserving chronological order.  (The previous version built
    # `important + regular[-k:]`, which scrambled ordering that readers of
    # the buffer - [-limit:] slices, oldest/newest summary - rely on.)
    if len(time_filtered) > limit:
        important = [e for e in time_filtered if e.get('important', False)]
        keep_regular = limit - len(important)
        if keep_regular > 0:
            regular = [e for e in time_filtered if not e.get('important', False)]
            dropped = {id(e) for e in regular[:-keep_regular]}
            time_filtered = [e for e in time_filtered if id(e) not in dropped]
        else:
            # Too many important entries alone: keep the most recent ones.
            time_filtered = important[-limit:]

    write_entries(time_filtered)
    pruned = len(entries) - len(time_filtered)
    status = f"pruned {pruned} entries"
    if important_preserved:
        status += f", preserved {important_preserved} important"
    return time_filtered, status

def get_summary():
    """Summarize buffer state: totals, per-model/per-tier counts, time span."""
    entries = read_entries()
    if not entries:
        return {"total": 0, "models": {}, "oldest": None, "newest": None}

    models = {}
    tiers = {}
    important_count = 0

    for entry in entries:
        model_name = entry.get('model', 'unknown')
        models[model_name] = models.get(model_name, 0) + 1
        tier_key = f"tier_{get_model_tier(model_name)}"
        tiers[tier_key] = tiers.get(tier_key, 0) + 1
        if entry.get('important', False):
            important_count += 1

    # `entries` is guaranteed non-empty here, so first/last are safe.
    return {
        "total": len(entries),
        "important": important_count,
        "models": models,
        "tiers": tiers,
        "oldest": entries[0].get('ts'),
        "newest": entries[-1].get('ts'),
        "has_fallback": any(not is_primary_model(e.get('model', '')) for e in entries),
    }

def format_for_context(entries, max_chars=8000):
    """Render entries as one line each for context injection.

    The character budget is spent newest-first (so the most recent entries
    always make the cut), but the output is returned in chronological order.
    Messages are clipped to 500 chars; non-primary models get a [model] tag.
    """
    selected = []
    budget_used = 0

    for entry in reversed(entries):
        who = entry.get('role', '?')
        model = entry.get('model', '?')
        text = entry.get('message', '')[:500]
        stamp = entry.get('ts', '')[:19]
        star = "⭐" if entry.get('important', False) else ""
        tag = "" if is_primary_model(model) else f" [{model}]"
        rendered = f"{star}[{stamp}] {who}{tag}: {text}"

        if budget_used + len(rendered) > max_chars:
            break
        selected.append(rendered)
        budget_used += len(rendered) + 1  # +1 for the joining newline

    selected.reverse()
    return '\n'.join(selected)

def generate_startup_context(hours=2):
    """Generate a concise startup context summary.

    Counts entries in the recent window, flags fallback-model and important
    entries, and appends the formatted recent context.

    Args:
        hours: Lookback window in hours (default 2, matching the previous
            hard-coded behavior).
    """
    entries = read_entries()
    if not entries:
        return "No recent conversation history."

    # Filter to the recent window.
    cutoff = now_utc() - timedelta(hours=hours)
    recent = []
    for e in entries:
        try:
            ts_str = e['ts'].replace('Z', '+00:00')
            # Heuristic: no offset after the date part => assume naive UTC.
            if '+' not in ts_str and '-' not in ts_str[10:]:
                ts_str += '+00:00'
            if datetime.fromisoformat(ts_str) >= cutoff:
                recent.append(e)
        except (KeyError, ValueError):
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            continue

    if not recent:
        return f"No conversation in the last {hours} hours."

    # Flag anything needing reviewer attention up front.
    fallback_entries = [e for e in recent if not is_primary_model(e.get('model', ''))]
    important_entries = [e for e in recent if e.get('important', False)]

    output = []
    output.append(f"📋 CONVERSATION BUFFER: {len(recent)} entries in last {hours} hours")

    if fallback_entries:
        models = set(e.get('model', 'unknown') for e in fallback_entries)
        output.append(f"⚠️  FALLBACK DETECTED: {len(fallback_entries)} entries from {', '.join(models)}")
        output.append("   Review these for potential errors/corrections needed.")

    if important_entries:
        output.append(f"⭐ {len(important_entries)} important entries flagged")

    output.append("\n--- Recent Context ---")
    output.append(format_for_context(recent[-20:], max_chars=4000))

    return '\n'.join(output)

def generate_handoff_summary(current_channel, hours=4):
    """Generate a summary of recent activity for a channel handoff.

    Args:
        current_channel: Channel the session is now on (e.g. telegram/webchat).
        hours: Lookback window in hours (default 4, matching the previous
            hard-coded behavior).
    """
    entries = read_entries()
    if not entries:
        return "No recent conversation history."

    # Filter to the recent window.
    cutoff = now_utc() - timedelta(hours=hours)
    recent = []
    for e in entries:
        try:
            ts_str = e['ts'].replace('Z', '+00:00')
            # Heuristic: no offset after the date part => assume naive UTC.
            if '+' not in ts_str and '-' not in ts_str[10:]:
                ts_str += '+00:00'
            if datetime.fromisoformat(ts_str) >= cutoff:
                recent.append(e)
        except (KeyError, ValueError):
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            continue

    if not recent:
        return f"No conversation in the last {hours} hours."

    # Channel attribution is not stored per-entry yet, so show all recent
    # activity for context.
    output = []
    output.append(f"🔄 HANDOFF SUMMARY (last {hours} hours, {len(recent)} entries)")
    output.append(f"   You're now on: {current_channel}")
    output.append("")

    # Surface key decisions/flags first.
    important = [e for e in recent if e.get('important', False)]
    if important:
        output.append("⭐ Important items:")
        for e in important[-5:]:
            msg = e.get('message', '')[:100]
            output.append(f"   • {msg}")
        output.append("")

    output.append("--- Recent Activity ---")
    output.append(format_for_context(recent[-15:], max_chars=3000))

    return '\n'.join(output)

def find_promotable_entries():
    """Collect entries that look worth promoting into MEMORY.md.

    An entry qualifies on a high stored importance score, or on any of a
    small set of promotion-worthy keywords recorded at append time.
    """
    promote_triggers = ('decision', 'remember', 'important', 'config', 'implemented')
    candidates = []

    for entry in read_entries():
        score = entry.get('importance_score', 0)
        keywords = entry.get('keywords', [])
        keyword_hit = any(k in promote_triggers for k in keywords)

        if score >= 6 or keyword_hit:
            candidates.append({
                "ts": entry.get('ts'),
                "role": entry.get('role'),
                "message": entry.get('message', '')[:200],
                "score": score,
                "keywords": keywords,
                "reason": "High importance score" if score >= 6 else f"Keywords: {', '.join(keywords)}",
            })

    return candidates

def main():
    """CLI entry point: build the argument parser and dispatch to a command."""
    parser = argparse.ArgumentParser(description='Rolling Conversation Buffer')
    subparsers = parser.add_subparsers(dest='command', help='Commands')

    # Append: record one conversation turn.
    append_p = subparsers.add_parser('append', help='Append an entry')
    append_p.add_argument('--role', required=True, choices=['user', 'assistant'])
    append_p.add_argument('--model', required=True)
    append_p.add_argument('--message', required=True)
    append_p.add_argument('--important', action='store_true', help='Flag as important')

    # Read: show recent entries.
    read_p = subparsers.add_parser('read', help='Read recent entries')
    read_p.add_argument('--hours', type=float, default=DEFAULT_HOURS)
    read_p.add_argument('--limit', type=int, default=DEFAULT_LIMIT)
    read_p.add_argument('--format', choices=['json', 'context'], default='context')

    # Search: substring search over messages.
    search_p = subparsers.add_parser('search', help='Search entries')
    search_p.add_argument('query', help='Search query')
    search_p.add_argument('--case-sensitive', action='store_true')

    # Prune: trim old/excess entries (tier-gated).
    prune_p = subparsers.add_parser('prune', help='Prune old entries')
    prune_p.add_argument('--hours', type=float, default=DEFAULT_HOURS)
    prune_p.add_argument('--limit', type=int, default=DEFAULT_LIMIT)
    # NOTE: store_true with default=True means this flag can never be turned
    # off from the CLI; --force is the intended way to bypass preservation.
    prune_p.add_argument('--preserve-fallback', action='store_true', default=True)
    prune_p.add_argument('--force', action='store_true')
    prune_p.add_argument('--model', help='Current model for tier check')

    # Summary: buffer statistics.
    subparsers.add_parser('summary', help='Get buffer summary')

    # Startup: formatted context for session start.
    subparsers.add_parser('startup', help='Generate startup context')

    # Promote: candidates for MEMORY.md.
    subparsers.add_parser('promote', help='Find entries to promote to MEMORY.md')

    # Handoff: cross-channel context summary.
    handoff_p = subparsers.add_parser('handoff', help='Get handoff summary when switching channels')
    handoff_p.add_argument('--channel', default='unknown', help='Current channel (telegram/webchat)')

    args = parser.parse_args()

    if args.command == 'append':
        entry = append_entry(args.role, args.model, args.message, args.important)
        print(json.dumps(entry))

    elif args.command == 'read':
        entries = read_entries()
        cutoff = now_utc() - timedelta(hours=args.hours)
        filtered = []
        for e in entries:
            try:
                ts_str = e['ts'].replace('Z', '+00:00')
                # Heuristic: no offset after the date part => assume naive UTC.
                if '+' not in ts_str and '-' not in ts_str[10:]:
                    ts_str += '+00:00'
                if datetime.fromisoformat(ts_str) >= cutoff:
                    filtered.append(e)
            except (KeyError, ValueError):
                # Unparseable timestamp: include rather than hide the entry.
                # (Narrowed from a bare `except:`.)
                filtered.append(e)

        if len(filtered) > args.limit:
            filtered = filtered[-args.limit:]

        if args.format == 'json':
            print(json.dumps(filtered, indent=2))
        else:
            print(format_for_context(filtered))

    elif args.command == 'search':
        results = search_entries(args.query, args.case_sensitive)
        if results:
            print(f"Found {len(results)} matches:\n")
            print(format_for_context(results))
        else:
            print("No matches found.")

    elif args.command == 'prune':
        # --force overrides fallback preservation.
        preserve = args.preserve_fallback and not args.force
        entries, status = prune_entries(args.hours, args.limit, preserve, args.model)
        print(f"Buffer: {len(entries)} entries ({status})")

    elif args.command == 'summary':
        print(json.dumps(get_summary(), indent=2))

    elif args.command == 'startup':
        print(generate_startup_context())

    elif args.command == 'promote':
        promotable = find_promotable_entries()
        if promotable:
            print(f"Found {len(promotable)} entries worth considering for MEMORY.md:\n")
            for p in promotable:
                print(f"[{p['ts'][:19]}] {p['role']}: {p['message']}")
                print(f"  → {p['reason']}\n")
        else:
            print("No entries flagged for promotion.")

    elif args.command == 'handoff':
        print(generate_handoff_summary(args.channel))

    else:
        parser.print_help()

# Run the CLI only when executed as a script (not on import).
if __name__ == '__main__':
    main()

# TONY-APPROVED: 2026-03-01 | sha:9d77731f
