"""
MongoDB Service Layer

This module provides all MongoDB database access for the IGF Admin application.
All database operations must go through this service to maintain separation of concerns.

Key principles:
- Single source of truth: MongoDB database
- No hardcoded collection names or database names
- All configuration comes from app_config.yaml
- Read/write operations are explicit and controlled
"""

import os
from typing import List, Dict, Any, Optional
from pymongo import MongoClient
from pymongo.database import Database
from pymongo.collection import Collection
import yaml


class MongoService:
    """
    MongoDB service for managing database connections and operations.
    
    This service loads configuration from app_config.yaml and provides
    safe, controlled access to MongoDB collections.
    """
    
    def __init__(self, config_path: str = "config/app_config.yaml"):
        """
        Initialize MongoDB service with configuration.
        
        Args:
            config_path: Path to YAML configuration file
        """
        self.config = self._load_config(config_path)
        self.client: Optional[MongoClient] = None
        self.db: Optional[Database] = None
        self._connection_error: Optional[str] = None
        
    def _load_config(self, config_path: str) -> Dict[str, Any]:
        """
        Load configuration from YAML file.
        
        Args:
            config_path: Path to configuration file
            
        Returns:
            Configuration dictionary
        """
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        
        # Replace environment variables in config
        uri_env = config.get('mongodb', {}).get('uri_env', 'MONGO_URI')
        env_uri = os.getenv(uri_env)
        
        if env_uri:
            config['mongodb']['uri'] = env_uri
        else:
            # Fallback to default if not set
            config['mongodb']['uri'] = "mongodb://AdminUser:WnE0TjGIafVJutreZL7cTIoWNWU52YyxLVgwCUqIGCGg3YUT@18.184.249.241:27017/?authMechanism=DEFAULT"
        
        return config
    
    def connect(self) -> bool:
        """
        Establish connection to MongoDB.
        
        Returns:
            True if connection successful, False otherwise
        """
        try:
            mongo_uri = self.config['mongodb']['uri']
            database_name = self.config['mongodb']['database']
            connection_timeout = self.config['mongodb'].get('connection_timeout', 5000)
            server_selection_timeout = self.config['mongodb'].get('server_selection_timeout', 5000)
            
            # Create MongoDB client with timeout from config
            self.client = MongoClient(
                mongo_uri,
                serverSelectionTimeoutMS=server_selection_timeout,
                connectTimeoutMS=connection_timeout
            )
            
            # Test connection
            self.client.admin.command('ping')
            
            # Get database
            self.db = self.client[database_name]
            
            self._connection_error = None
            return True
            
        except Exception as e:
            self._connection_error = str(e)
            return False
    
    def is_connected(self) -> bool:
        """
        Check if MongoDB connection is active.
        
        Returns:
            True if connected, False otherwise
        """
        if self.client is None or self.db is None:
            return False
        
        try:
            self.client.admin.command('ping')
            return True
        except:
            return False
    
    def get_connection_error(self) -> Optional[str]:
        """
        Get the last connection error message.
        
        Returns:
            Error message or None if no error
        """
        return self._connection_error
    
    def get_database_name(self) -> str:
        """
        Get the configured database name.
        
        Returns:
            Database name from configuration
        """
        return self.config['mongodb']['database']
    
    def list_collections(self) -> List[str]:
        """
        Get list of enabled collections from configuration.
        
        Returns:
            List of collection names that are enabled
        """
        collections_config = self.config.get('collections', {})
        return [name for name, config in collections_config.items() if config.get('enabled', True)]
    
    def get_collection(self, name: str) -> Optional[Collection]:
        """
        Get a MongoDB collection by name.
        
        Args:
            name: Collection name
            
        Returns:
            Collection object or None if not connected
        """
        if self.db is None:
            return None
        
        return self.db[name]
    
    def count_documents(self, collection_name: str) -> int:
        """
        Count total documents in a collection.
        
        Args:
            collection_name: Name of the collection
            
        Returns:
            Number of documents, or 0 if error
        """
        try:
            collection = self.get_collection(collection_name)
            if collection is None:
                return 0
            return collection.count_documents({})
        except Exception:
            return 0
    
    def sample_documents(
        self, 
        collection_name: str, 
        limit: int = 50,
        skip: int = 0
    ) -> List[Dict[str, Any]]:
        """
        Retrieve sample documents from a collection.
        
        Args:
            collection_name: Name of the collection
            limit: Maximum number of documents to return
            skip: Number of documents to skip
            
        Returns:
            List of documents as dictionaries
        """
        try:
            collection = self.get_collection(collection_name)
            if collection is None:
                return []
            
            # Fetch documents with limit and skip
            cursor = collection.find({}).skip(skip).limit(limit)
            documents = list(cursor)
            
            return documents
            
        except Exception as e:
            print(f"Error sampling documents from {collection_name}: {e}")
            return []
    
    def get_collection_stats(self, collection_name: str) -> Dict[str, Any]:
        """
        Get statistics for a collection.
        
        Args:
            collection_name: Name of the collection
            
        Returns:
            Dictionary with collection statistics
        """
        try:
            collection = self.get_collection(collection_name)
            if collection is None:
                return {}
            
            stats = self.db.command("collStats", collection_name)
            
            return {
                'count': stats.get('count', 0),
                'size': stats.get('size', 0),
                'avgObjSize': stats.get('avgObjSize', 0),
                'storageSize': stats.get('storageSize', 0),
                'indexes': stats.get('nindexes', 0)
            }
            
        except Exception:
            return {
                'count': self.count_documents(collection_name),
                'size': 0,
                'avgObjSize': 0,
                'storageSize': 0,
                'indexes': 0
            }
    
    def list_all_databases(self) -> List[str]:
        """
        List all databases on the MongoDB server.
        
        Returns:
            List of database names
        """
        try:
            if not self.is_connected():
                return []
            
            # Get list of all databases
            db_list = self.client.list_database_names()
            
            # Filter out system databases
            return [db for db in db_list if db not in ['admin', 'local', 'config']]
            
        except Exception as e:
            print(f"Error listing databases: {e}")
            return []
    
    def list_collections_in_database(self, database_name: str) -> List[str]:
        """
        List all collections in a specific database.
        
        Args:
            database_name: Name of the database
            
        Returns:
            List of collection names
        """
        try:
            if not self.is_connected():
                return []
            
            db = self.client[database_name]
            return db.list_collection_names()
            
        except Exception as e:
            print(f"Error listing collections in {database_name}: {e}")
            return []
    
    def export_collection_data(
        self,
        database_name: str,
        collection_name: str,
        limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Export all documents from a collection.
        
        Args:
            database_name: Name of the database
            collection_name: Name of the collection
            limit: Optional limit on number of documents
            
        Returns:
            List of documents
        """
        try:
            if not self.is_connected():
                return []
            
            db = self.client[database_name]
            collection = db[collection_name]
            
            if limit:
                cursor = collection.find({}).limit(limit)
            else:
                cursor = collection.find({})
            
            return list(cursor)
            
        except Exception as e:
            print(f"Error exporting {database_name}.{collection_name}: {e}")
            return []
    
    def get_database_stats(self, database_name: str) -> Dict[str, Any]:
        """
        Get statistics for a database.
        
        Args:
            database_name: Name of the database
            
        Returns:
            Dictionary with database statistics
        """
        try:
            if not self.is_connected():
                return {}
            
            db = self.client[database_name]
            stats = db.command("dbStats")
            
            return {
                'collections': stats.get('collections', 0),
                'dataSize': stats.get('dataSize', 0),
                'storageSize': stats.get('storageSize', 0),
                'indexes': stats.get('indexes', 0),
                'objects': stats.get('objects', 0)
            }
            
        except Exception as e:
            print(f"Error getting stats for {database_name}: {e}")
            return {}
    
    def get_distinct_values(self, collection_name: str, field_name: str) -> List[Any]:
        """
        Get distinct values for a field in a collection.
        
        Args:
            collection_name: Name of the collection
            field_name: Name of the field
        
        Returns:
            List of distinct values
        """
        try:
            collection = self.db[collection_name]
            return collection.distinct(field_name)
        except Exception as e:
            print(f"Error getting distinct values from {collection_name}.{field_name}: {e}")
            return []
    
    def list_indexes(self, collection_name: str) -> List[Dict[str, Any]]:
        """
        List all indexes for a collection.
        
        Args:
            collection_name: Name of the collection
            
        Returns:
            List of index information dictionaries
        """
        try:
            collection = self.db[collection_name]
            indexes = list(collection.list_indexes())
            return indexes
        except Exception as e:
            print(f"Error listing indexes for {collection_name}: {e}")
            return []
    
    def drop_index(self, collection_name: str, index_name: str) -> bool:
        """
        Drop an index from a collection.
        
        Args:
            collection_name: Name of the collection
            index_name: Name of the index to drop
            
        Returns:
            True if successful, False otherwise
        """
        try:
            collection = self.db[collection_name]
            collection.drop_index(index_name)
            print(f"Successfully dropped index '{index_name}' from {collection_name}")
            return True
        except Exception as e:
            print(f"Error dropping index '{index_name}' from {collection_name}: {e}")
            return False
    
    def create_partial_unique_index(self, collection_name: str, field_name: str) -> bool:
        """
        Create a partial unique index that only applies to non-null values.
        This allows multiple null values while maintaining uniqueness for actual values.
        
        Args:
            collection_name: Name of the collection
            field_name: Name of the field to index
            
        Returns:
            True if successful, False otherwise
        """
        try:
            collection = self.db[collection_name]
            # Create partial index that only indexes documents where field exists and is not null
            # Using $exists and $type to properly filter out null/missing values
            collection.create_index(
                [(field_name, 1)],
                unique=True,
                partialFilterExpression={
                    field_name: {
                        "$exists": True,
                        "$type": "string"  # Only index string values (not null, not missing)
                    }
                },
                name=f"{field_name}_1_partial"
            )
            print(f"Successfully created partial unique index on {collection_name}.{field_name}")
            return True
        except Exception as e:
            print(f"Error creating partial unique index on {collection_name}.{field_name}: {e}")
            return False
    
    def close(self):
        """
        Close MongoDB connection.
        """
        if self.client is not None:
            self.client.close()
            self.client = None
            self.db = None
