"""
Annotation Server - Updates database with annotation scores
Run: python annotation_server.py
Then open: http://localhost:8001/annotation_tool.html
"""

import http.server
import json
import sqlite3
import csv
from pathlib import Path
from datetime import datetime

PORT = 8001
SCRIPT_DIR = Path(__file__).parent.absolute()
DB_PATH = SCRIPT_DIR.parent / "epstein_analysis.db"
CSV_DIR = SCRIPT_DIR / "annotations_csv"


def ensure_csv_dir():
    """Create the CSV directory if it doesn't exist."""
    CSV_DIR.mkdir(exist_ok=True)


def get_csv_path(run_id):
    """Get the CSV file path for a given run."""
    return CSV_DIR / f"annotations_{run_id}.csv"


def save_annotation_to_csv(run_id, thread_id, annotation, note=None, sample_data=None):
    """Save a single annotation to CSV file, replacing if exists."""
    ensure_csv_dir()
    csv_path = get_csv_path(run_id)

    timestamp = datetime.now().isoformat()

    # Read existing data
    existing_rows = []
    headers = ['timestamp', 'run_id', 'thread_id', 'annotation', 'note']
    if sample_data:
        headers.extend(sample_data.keys())

    if csv_path.exists():
        with open(csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            if reader.fieldnames:
                headers = list(reader.fieldnames)
                # Add note column if not present
                if 'note' not in headers:
                    headers.insert(4, 'note')
                # Merge in any sample_data columns missing from the existing header
                # so the DictWriter below does not raise on unexpected fields
                if sample_data:
                    headers.extend(k for k in sample_data if k not in headers)
            for row in reader:
                # Skip the row we're replacing
                if row.get('thread_id') != str(thread_id):
                    existing_rows.append(row)

    # Write all data back with the new/updated annotation
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=headers)
        writer.writeheader()

        # Write existing rows
        for row in existing_rows:
            writer.writerow(row)

        # Write the new/updated annotation
        new_row = {
            'timestamp': timestamp,
            'run_id': run_id,
            'thread_id': thread_id,
            'annotation': annotation,
            'note': note or ''
        }
        if sample_data:
            new_row.update(sample_data)
        writer.writerow(new_row)

    return csv_path
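
# Example of the CSV layout written by save_annotation_to_csv (values below are
# hypothetical; any extra columns come from the sample_data dict the client sends):
#
#   timestamp,run_id,thread_id,annotation,note,<sample_data columns...>
#   2024-01-01T12:00:00,run_001,42,good,looks fine,...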


def load_existing_annotations(run_id):
    """Load existing annotations and notes from CSV for a given run."""
    csv_path = get_csv_path(run_id)
    annotations = {}
    notes = {}

    if csv_path.exists():
        with open(csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                thread_id = row.get('thread_id')
                if thread_id:
                    annotations[thread_id] = row.get('annotation')
                    notes[thread_id] = row.get('note', '')

    return annotations, notes


class AnnotationHandler(http.server.SimpleHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, directory=str(SCRIPT_DIR), **kwargs)

    def translate_path(self, path):
        # Intercept database requests to serve from parent directory
        if path == "/epstein_analysis.db" or path.startswith("/epstein_analysis.db?"):
            return str(DB_PATH)
        return super().translate_path(path)

    def do_POST(self):
        if self.path == "/save_annotation":
            # Save single annotation incrementally
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            data = json.loads(post_data.decode("utf-8"))

            run_id = data.get("run_id")
            thread_id = data.get("thread_id")
            annotation = data.get("annotation")
            note = data.get("note", "")
            sample_data = data.get("sample_data", {})

            csv_path = save_annotation_to_csv(run_id, thread_id, annotation, note, sample_data)

            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Access-Control-Allow-Origin", "*")
            self.end_headers()
            self.wfile.write(json.dumps({
                "success": True,
                "run_id": run_id,
                "thread_id": thread_id,
                "annotation": annotation,
                "csv_file": str(csv_path.name),
            }).encode())

        elif self.path == "/load_annotations":
            # Load existing annotations for a run
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            data = json.loads(post_data.decode("utf-8"))

            run_id = data.get("run_id")
            annotations, notes = load_existing_annotations(run_id)

            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Access-Control-Allow-Origin", "*")
            self.end_headers()
            self.wfile.write(json.dumps({
                "success": True,
                "run_id": run_id,
                "annotations": annotations,
                "notes": notes,
            }).encode())

        elif self.path == "/save_annotations":
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            data = json.loads(post_data.decode("utf-8"))

            run_id = data.get("run_id")
            score = data.get("score")
            annotations = data.get("annotations", [])

            conn = sqlite3.connect(DB_PATH)
            cursor = conn.cursor()

            # Create annotations table if not exists
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS annotation_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    run_id TEXT,
                    thread_id TEXT,
                    annotation TEXT,
                    annotated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Update score in ai_classification_runs
            cursor.execute(
                "UPDATE ai_classification_runs SET score = ? WHERE run_id = ?",
                (score, run_id),
            )

            # Save individual annotations
            for ann in annotations:
                cursor.execute("""
                    INSERT INTO annotation_results (run_id, thread_id, annotation)
                    VALUES (?, ?, ?)
                """, (run_id, ann.get("thread_id"), ann.get("annotation")))

            conn.commit()
            conn.close()

            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Access-Control-Allow-Origin", "*")
            self.end_headers()
            self.wfile.write(json.dumps({
                "success": True,
                "run_id": run_id,
                "score": score,
                "annotations_saved": len(annotations),
            }).encode())
        else:
            self.send_response(404)
            self.end_headers()

    def do_OPTIONS(self):
        self.send_response(200)
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "Content-Type")
        self.end_headers()


if __name__ == "__main__":
    print(f"Starting annotation server on http://localhost:{PORT}")
    print(f"Database: {DB_PATH}")
    print(f"Open: http://localhost:{PORT}/annotation_tool.html")
    http.server.HTTPServer(("", PORT), AnnotationHandler).serve_forever()
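
# Example requests (sketch; payload values are hypothetical, but the endpoints
# and field names match the handlers above):
#
#   curl -X POST http://localhost:8001/save_annotation \
#        -H "Content-Type: application/json" \
#        -d '{"run_id": "run_001", "thread_id": "42", "annotation": "good", "note": ""}'
#
#   curl -X POST http://localhost:8001/load_annotations \
#        -H "Content-Type: application/json" \
#        -d '{"run_id": "run_001"}'
#
#   curl -X POST http://localhost:8001/save_annotations \
#        -H "Content-Type: application/json" \
#        -d '{"run_id": "run_001", "score": 0.9, "annotations": [{"thread_id": "42", "annotation": "good"}]}'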
