"""
SQreamDB MCP Server
A Model Context Protocol server for SQreamDB operations
"""
import asyncio
import glob
import html
import logging
import math
import os
import re
import sys
from collections import Counter
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
import pysqream
from mcp.server import Server
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

# --- SETUP PROPER LOGGING ---
log_file_path = os.path.join(os.path.dirname(__file__), 'sqreamdb_mcp.log')
logging.FileHandler(log_file_path, mode='w'),
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file_path, mode='w'),
        logging.StreamHandler(sys.stderr)
    ]
)
logger = logging.getLogger("sqreamdb-mcp-server")


class DocumentationIndex:
    """Lightweight in-memory documentation index for local HTML / MD / TXT files.

    Features:
    - Accepts a directory (scanned recursively) or a single supported file
    - Naive markup stripping (script/style blocks and tags removed, entities decoded)
    - Simple alphanumeric / underscore / dot tokenization
    - Chunking large files into overlapping windows
    - Term-frequency ranking (cosine similarity over term counts)
    - Returns top-k relevant chunks
    """

    SUPPORTED_EXT = {'.html', '.htm', '.md', '.txt'}
    _TOKEN_RE = re.compile(r'[A-Za-z0-9_\.]+')

    def __init__(self, root_dir: str, chunk_size: int = 1400, overlap: int = 200):
        """Create an (unloaded) index rooted at *root_dir*; call ``load()`` to populate it.

        Args:
            root_dir: Directory to scan recursively, or a single supported file.
            chunk_size: Maximum number of characters per chunk; must be > 0.
            overlap: Characters shared by consecutive chunks; must satisfy
                ``0 <= overlap < chunk_size`` or chunking could never advance.

        Raises:
            ValueError: If ``chunk_size`` or ``overlap`` is out of range.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}")
        self.root_dir = root_dir
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.chunks: List[Tuple[str, str]] = []  # (source_path, text)
        self._token_cache: List[Counter] = []  # term counts, parallel to self.chunks
        self._norm_cache: List[float] = []  # Euclidean norm of each term-count vector
        self._loaded = False

    def load(self) -> int:
        """(Re)build the index from disk and return the number of chunks.

        Previously loaded chunks are discarded first, so repeated calls do not
        duplicate entries. Unreadable files are logged and skipped. Returns 0
        (index left unloaded) when ``root_dir`` is neither an existing directory
        nor a readable file with a supported extension.
        """
        self.chunks = []
        self._token_cache = []
        self._norm_cache = []
        self._loaded = False

        # Single-file usage
        if os.path.isfile(self.root_dir):
            if os.path.splitext(self.root_dir)[1].lower() not in self.SUPPORTED_EXT:
                return 0
            if not self._load_file(self.root_dir):
                return 0
            self._loaded = True
            return len(self.chunks)

        if not os.path.isdir(self.root_dir):
            return 0
        pattern = os.path.join(self.root_dir, '**', '*')
        # Sorted so chunk order (and therefore tie-breaking in search) is deterministic.
        files = sorted(
            p for p in glob.glob(pattern, recursive=True)
            if os.path.isfile(p) and os.path.splitext(p)[1].lower() in self.SUPPORTED_EXT
        )
        for path in files:
            self._load_file(path)
        self._loaded = True
        return len(self.chunks)

    def _load_file(self, path: str) -> bool:
        """Read, strip and chunk one file; return False (after logging) if it cannot be read."""
        try:
            with open(path, 'r', encoding='utf-8', errors='ignore') as f:
                raw = f.read()
        except OSError:
            logging.exception(f"Failed loading doc file: {path}")
            return False
        self._chunk_file(path, self._strip_markup(raw))
        return True

    def _strip_markup(self, text: str) -> str:
        """Drop script/style blocks and HTML tags, decode entities, collapse whitespace."""
        # Naive removal of HTML tags; keep content
        text = re.sub(r'<script[\s\S]*?</script>', ' ', text, flags=re.IGNORECASE)
        text = re.sub(r'<style[\s\S]*?</style>', ' ', text, flags=re.IGNORECASE)
        text = re.sub(r'<[^>]+>', ' ', text)
        # Decode entities only after tags are removed so '&lt;b&gt;' stays literal text.
        text = html.unescape(text)
        return re.sub(r'\s+', ' ', text).strip()

    def _tokenize(self, text: str) -> Counter:
        """Return lower-cased token counts (runs of letters, digits, '_' and '.')."""
        return Counter(self._TOKEN_RE.findall(text.lower()))

    def _add_chunk(self, path: str, chunk: str) -> None:
        """Append one chunk with its token counts and precomputed vector norm."""
        tokens = self._tokenize(chunk)
        self.chunks.append((path, chunk))
        self._token_cache.append(tokens)
        self._norm_cache.append(math.sqrt(sum(v * v for v in tokens.values())) or 1.0)

    def _chunk_file(self, path: str, text: str):
        """Split *text* into windows of ``chunk_size`` chars overlapping by ``overlap``."""
        if not text:
            return  # nothing searchable in an empty file
        step = self.chunk_size - self.overlap  # > 0, guaranteed by __init__
        for start in range(0, len(text), step):
            self._add_chunk(path, text[start:start + self.chunk_size])
            if start + self.chunk_size >= len(text):
                break  # this window already reached the end of the text

    def is_loaded(self) -> bool:
        """Return True once ``load()`` has completed successfully."""
        return self._loaded

    def search(self, query: str, k: int = 5) -> List[Tuple[str, str, float]]:
        """Return up to *k* ``(path, snippet, score)`` hits ranked by cosine similarity.

        ``snippet`` is the first 1000 characters of the matching chunk. Chunks
        sharing no token with *query* are omitted. Returns ``[]`` when the index
        is not loaded, the query has no tokens, or ``k <= 0``.
        """
        if k <= 0 or not self._loaded or not self.chunks:
            return []
        q_tokens = self._tokenize(query)
        if not q_tokens:
            return []
        q_norm = math.sqrt(sum(v * v for v in q_tokens.values())) or 1.0
        scored: List[Tuple[str, str, float]] = []
        for (path, text), doc_tokens, d_norm in zip(self.chunks, self._token_cache, self._norm_cache):
            # dot product over query terms only
            dot = sum(count * doc_tokens.get(term, 0) for term, count in q_tokens.items())
            if dot == 0:
                continue
            scored.append((path, text[:1000], dot / (q_norm * d_norm)))
        scored.sort(key=lambda x: x[2], reverse=True)
        return scored[:k]


class SQreamDBMCPServer:
    """MCP server exposing SQreamDB query execution and documentation lookup as tools.

    Tools registered by ``setup_handlers``:
    - ``execute_query``: run a SQL statement through pysqream.
    - ``get_documentation``: search the local documentation index, or fall back
      to the remote SQreamDB docs homepage.
    - ``prime_documentation``: store a snippet that is echoed at the top of
      subsequent ``execute_query`` results.
    """

    # SQreamDB documentation URL - always the same
    SQREAM_DOCS_URL = "https://docs.sqream.com/en/latest/"

    # Connection-string keys whose values are converted before being passed to
    # pysqream.connect; every other key is passed through unchanged as a string.
    _INT_PARAMS = frozenset({'port'})
    _BOOL_PARAMS = frozenset({'clustered', 'use_ssl'})
    _TRUE_VALUES = frozenset({'true', '1'})

    def __init__(self, connection_string: str):
        """Parse *connection_string*, build the optional docs index and register MCP handlers.

        Args:
            connection_string: Whitespace-separated ``key=value`` pairs that are
                passed as keyword arguments to ``pysqream.connect``. Keys are
                lower-cased; ``port`` becomes an int, ``clustered``/``use_ssl`` a bool.

        Raises:
            ValueError: If a pair is not in ``key=value`` form or ``port`` is not an integer.
        """
        try:
            self.conn_params = self._parse_connection_string(connection_string)
        except ValueError as e:
            logger.critical(f"Could not parse connection string. Ensure it is in 'key=value' format. Error: {e}")
            raise

        # Log connection parameters with password masked for security. A fixed-width
        # mask is used so the log does not reveal the password length.
        safe_params = dict(self.conn_params)
        if 'password' in safe_params:
            safe_params['password'] = '********'
        logger.info(f"Parsed connection parameters: {safe_params}")

        self.connection = None
        # Serializes use of the single connection; created lazily on the running loop.
        self._query_lock: Optional[asyncio.Lock] = None
        self.documentation = ""  # (legacy single blob)
        self.documentation_loaded = False
        self._primed_docs: List[str] = []  # snippets stored by prime_documentation

        # Structured local index (optional)
        docs_location = os.environ.get('SQREAM_DOCS_DIR')
        if not docs_location:
            # Default to single HTML file next to this script
            potential = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'sqreamdb_complete_documentation.html')
            if os.path.exists(potential):
                docs_location = potential
        self.doc_index: Optional[DocumentationIndex] = None
        if docs_location:
            try:
                self.doc_index = DocumentationIndex(docs_location)
                loaded = self.doc_index.load()
                logger.info(f"Documentation index initialized from '{docs_location}' with {loaded} chunks")
            except Exception:
                logger.exception("Failed to build documentation index; continuing without it.")
        self.server = Server("sqreamdb-mcp-server")
        self.setup_handlers()

    @classmethod
    def _parse_connection_string(cls, connection_string: str) -> Dict[str, Any]:
        """Turn ``'k1=v1 k2=v2'`` into a kwargs dict for ``pysqream.connect``.

        Raises:
            ValueError: On a token without ``=`` or a non-integer ``port`` value.
        """
        params: Dict[str, Any] = {}
        for part in connection_string.split():
            key, value = part.split('=', 1)  # ValueError when there is no '='
            # Lower-cased to match pysqream.connect's lower-case keyword names (host, port, ...).
            key = key.lower()
            if key in cls._INT_PARAMS:
                params[key] = int(value)
            elif key in cls._BOOL_PARAMS:
                # Convert explicitly: passing the raw string "false" would be truthy.
                params[key] = value.lower() in cls._TRUE_VALUES
            else:
                params[key] = value
        return params

    async def _load_documentation(self):
        """Fetch the SQreamDB docs homepage into ``self.documentation`` (no-op once loaded).

        Raises:
            Exception: On a non-200 response, a 30 s timeout, or any client error.
        """
        if self.documentation_loaded:
            return  # Already loaded

        logger.info(f"Fetching SQreamDB documentation from: {self.SQREAM_DOCS_URL}")
        try:
            timeout = aiohttp.ClientTimeout(total=30)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(self.SQREAM_DOCS_URL) as response:
                    if response.status != 200:
                        logger.error(f"Failed to fetch documentation: HTTP {response.status}")
                        raise Exception(f"HTTP {response.status} when fetching documentation")
                    content = await response.text()
        except asyncio.TimeoutError as err:
            logger.error(f"Timeout while fetching documentation from {self.SQREAM_DOCS_URL}")
            raise Exception(f"Timeout while fetching documentation from {self.SQREAM_DOCS_URL}") from err
        except Exception as e:
            logger.error(f"Error loading documentation: {e}")
            raise

        self.documentation = content
        self.documentation_loaded = True
        logger.info(f"Successfully loaded SQreamDB documentation ({len(content)} characters)")

    def setup_handlers(self):
        """Register the list_tools / call_tool handlers on ``self.server``."""
        logger.info("Setting up MCP request handlers.")

        @self.server.list_tools()
        async def handle_list_tools() -> List[Tool]:
            logger.info("Request received: list_tools")
            return [
                Tool(
                    name="execute_query",
                    description="Execute a SQL statement against SQreamDB. If documentation was previously primed it will be used for context.",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "SQL statement to execute"}
                        },
                        "required": ["query"]
                    }
                ),
                Tool(
                    name="get_documentation",
                    description="Search local/indexed SQreamDB documentation. Provide an optional 'query' to retrieve the most relevant snippets.",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "Search query terms"},
                            "k": {"type": "number", "description": "Max number of snippets", "default": 5}
                        }
                    }
                ),
                Tool(
                    name="prime_documentation",
                    description="Store relevant documentation context for subsequent execute_query calls. Provide either a search 'query' or explicit 'text'.",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "Search then store top snippet"},
                            "text": {"type": "string", "description": "Directly provide documentation text to prime"}
                        }
                    }
                ),
            ]

        @self.server.call_tool()
        async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
            logger.info(f"Request received: call_tool '{name}' with args: {arguments}")
            tool_functions = {
                "execute_query": self._execute_query,
                "get_documentation": self._get_documentation,
                "prime_documentation": self._prime_documentation,
            }
            if name in tool_functions:
                return await tool_functions[name](arguments or {})
            logger.warning(f"Unknown tool name received: {name}")
            return [TextContent(type="text", text=f"Unknown tool: {name}")]

    def _index_ready(self) -> bool:
        """Return True when a local documentation index exists and loaded successfully."""
        return self.doc_index is not None and self.doc_index.is_loaded()

    @staticmethod
    def _coerce_k(value: Any, default: int = 5) -> int:
        """Convert the tool's ``k`` argument to an int >= 1, using *default* if it is unusable."""
        try:
            k = int(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return max(k, 1)

    def _display_path(self, path: str) -> str:
        """Return *path* relative to the index root, for display in search results.

        For a single-file index ``os.path.relpath(path, root)`` would be ``'.'``,
        so the path is made relative to that file's directory instead.
        """
        root = self.doc_index.root_dir
        base = (os.path.dirname(root) or os.curdir) if os.path.isfile(root) else root
        try:
            return os.path.relpath(path, base)
        except ValueError:  # e.g. path and base on different Windows drives
            return path

    def _remote_snippet(self, query: str) -> Optional[str]:
        """Return ~900 chars of the remote homepage around the first case-insensitive match, or None."""
        # Search the original text so the offset cannot drift (str.lower() may change lengths).
        match = re.search(re.escape(query), self.documentation, re.IGNORECASE)
        if match is None:
            return None
        start = max(0, match.start() - 200)
        return self.documentation[start:start + 900]

    async def _get_documentation(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Search or return instructions for using documentation.

        If a 'query' is provided and the local index is loaded, return the top-k
        snippets. If the index is loaded but no query is given, return usage
        guidance. Otherwise lazily fetch the remote homepage (legacy fallback)
        and return either the whole page or a snippet around the query.
        """
        arguments = arguments or {}
        query = arguments.get('query')
        k = self._coerce_k(arguments.get('k', 5))

        # Prefer local index
        if query and self._index_ready():
            results = self.doc_index.search(query, k=k)
            if not results:
                return [TextContent(type="text", text=f"No documentation snippets matched query: {query}")]
            assembled = ["DOCUMENTATION SEARCH RESULTS (top {}):".format(len(results))]
            for i, (path, text, score) in enumerate(results, 1):
                assembled.append(f"[{i}] {self._display_path(path)} (score={score:.3f})\n{text.strip()}")
            return [TextContent(type="text", text="\n\n".join(assembled))]

        # If no query but we have index, give guidance
        if self._index_ready():
            guidance = (
                "Local documentation index loaded. Use get_documentation with a 'query' property to search, e.g.\n"
                "{\n  'name': 'get_documentation', 'arguments': { 'query': 'create table compression' }\n}.\n"
                "To prime context for future queries, call prime_documentation with the same query."
            )
            return [TextContent(type="text", text=guidance)]

        # Fallback: remote fetch
        try:
            if not self.documentation_loaded:
                await self._load_documentation()
            if not query:
                return [TextContent(type="text", text=self.documentation)]
            snippet = self._remote_snippet(query)
            if snippet is None:
                return [TextContent(type="text", text=f"Query '{query}' not found in remote homepage content.")]
            return [TextContent(type="text", text=f"REMOTE HOMEPAGE SNIPPET for query '{query}':\n{snippet}")]
        except Exception as e:
            error_msg = f"Unable to fetch SQreamDB documentation remotely: {e}"
            logger.error(error_msg)
            return [TextContent(type="text", text=error_msg)]

    async def _prime_documentation(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Store one documentation snippet that later execute_query results will echo.

        Priority: explicit 'text'; else the top local-index hit for 'query';
        else (no index) a snippet of the remote homepage around 'query'.
        The previous primed snippet is kept when nothing new could be primed.
        """
        arguments = arguments or {}
        provided_text = arguments.get('text')
        query = arguments.get('query')
        stored = []
        if provided_text:
            self._primed_docs = [provided_text]
            stored.append('direct-text')
        elif query and self._index_ready():
            hits = self.doc_index.search(query, k=1)
            if hits:
                self._primed_docs = [hits[0][1]]
                stored.append('index-hit')
        elif query:
            # fallback try remote
            if not self.documentation_loaded:
                try:
                    await self._load_documentation()
                except Exception:
                    # Best effort: a failed fetch only means nothing gets primed.
                    logger.warning("Remote documentation unavailable for priming.", exc_info=True)
            if self.documentation:
                snippet = self._remote_snippet(query)
                if snippet is not None:
                    self._primed_docs = [snippet]
                    stored.append('remote-snippet')
        if not stored:
            return [TextContent(type="text", text="No documentation could be primed (no matches or empty input).")]
        return [TextContent(type="text", text=f"Primed documentation context using: {', '.join(stored)}.")]

    async def _close_connection(self):
        """Best-effort close of the current connection (if any); errors are logged, not raised."""
        conn, self.connection = self.connection, None
        if conn is None:
            return
        try:
            await asyncio.to_thread(conn.close)
        except Exception:
            logger.warning("Error while closing SQreamDB connection.", exc_info=True)

    async def _get_connection(self):
        """Return the cached connection, opening a new one (off the event loop) if needed.

        Raises:
            Exception: Whatever ``pysqream.connect`` raises, after logging it.
        """
        logger.info("Checking for database connection...")
        # NOTE(review): relies on the pysqream connection exposing a `closed` attribute;
        # if it does not, getattr() yields True and a new connection is opened on every
        # call (the old one is closed first) — confirm against the installed pysqream.
        if self.connection is not None and not getattr(self.connection, 'closed', True):
            return self.connection
        logger.info("Connection is missing or closed. Creating a new one...")
        await self._close_connection()  # release a stale handle before replacing it
        try:
            # The ** operator unpacks the dictionary into keyword arguments,
            # e.g., pysqream.connect(host='...', port=5000, ...)
            self.connection = await asyncio.to_thread(pysqream.connect, **self.conn_params)
        except Exception:
            logger.error("Failed to connect to SQreamDB!", exc_info=True)
            # Re-raise so the tool call fails with a clear message
            raise
        logger.info("Database connection established successfully.")
        return self.connection

    @staticmethod
    def _run_query(conn, query: str) -> Tuple[Optional[List[str]], List[Any], int]:
        """Execute *query* synchronously; return (column names or None, rows, rowcount)."""
        with conn.cursor() as cursor:
            cursor.execute(query)
            if cursor.description is None:  # statements that don't return rows (e.g., INSERT, UPDATE)
                return None, [], cursor.rowcount
            columns = [desc[0] for desc in cursor.description]
            return columns, cursor.fetchall(), cursor.rowcount

    def _primed_header(self) -> str:
        """Return the primed-documentation preamble for query results ('' if nothing primed)."""
        primed = self._primed_docs
        if not primed:
            return ""
        return "USING PRIMED DOCUMENTATION CONTEXT ({} snippet)\n---\n{}\n---\n".format(len(primed), primed[0][:500])

    async def _execute_query(self, arguments: Dict[str, Any]) -> List[TextContent]:
        """Run ``arguments['query']`` and format the result as a pipe-separated text table.

        The blocking pysqream calls run in a worker thread so the MCP event loop
        stays responsive; a lock keeps queries on the shared connection sequential.
        """
        query = arguments["query"]
        logger.info(f"Executing query: {query}")

        if self._query_lock is None:
            self._query_lock = asyncio.Lock()
        async with self._query_lock:
            conn = await self._get_connection()
            columns, rows, rowcount = await asyncio.to_thread(self._run_query, conn, query)

        header = self._primed_header()
        if columns is None:
            return [TextContent(type="text", text=header + f"Statement executed successfully. {rowcount} rows affected.")]

        column_line = " | ".join(columns)
        lines = [f"Query returned {len(rows)} rows.", "", column_line, "-" * len(column_line)]
        lines.extend(" | ".join(str(val) for val in row) for row in rows)
        return [TextContent(type="text", text=header + "\n".join(lines) + "\n")]

    async def run(self):
        """Serve MCP over stdio until the client disconnects, then close the DB connection."""
        logger.info("Starting server run loop...")
        try:
            async with stdio_server() as (read_stream, write_stream):
                # NOTE(review): minimal stand-in for mcp's NotificationOptions; presumably only
                # `tools_changed` is read because only tool handlers are registered — confirm.
                dummy_notification_options = SimpleNamespace(tools_changed=None)
                await self.server.run(
                    read_stream,
                    write_stream,
                    InitializationOptions(
                        server_name="sqreamdb-mcp-server",
                        server_version="1.0.0",
                        capabilities=self.server.get_capabilities(
                            notification_options=dummy_notification_options,
                            experimental_capabilities=None
                        ),
                    ),
                )
        finally:
            await self._close_connection()
        logger.info("Server run loop finished.")


def main():
    """Command-line entry point: ``sqreamdb_mcp_server.py <connection_string>``.

    Exits with status 1 when no connection string is given or when the server
    loop dies with an unexpected exception; Ctrl-C is logged as a normal shutdown.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: sqreamdb_mcp_server.py <connection_string>", file=sys.stderr)
        sys.exit(1)

    mcp_server = SQreamDBMCPServer(cli_args[0])
    try:
        asyncio.run(mcp_server.run())
    except KeyboardInterrupt:
        logger.info("Server shut down by user.")
    except Exception:
        logger.critical("A fatal error caused the server to exit.", exc_info=True)
        sys.exit(1)


# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    main()
