"""
Data Service Layer

This module handles loading and saving tabular data from MongoDB collections.
It provides the bridge between MongoDB documents and pandas DataFrames for UI editing.

Key responsibilities:
- Load collection documents as normalized DataFrames
- Handle schema normalization (missing fields → None)
- Convert ObjectId to string for UI display
- Track changes between original and edited data
- Apply changes back to MongoDB with user confirmation
"""

import pandas as pd
from typing import Dict, List, Any, Tuple, Optional
from bson import ObjectId
from datetime import datetime
from services.json_flattener import JSONFlattener


class DataService:
    """
    Service for loading and saving collection data in tabular format.

    Documents are normalized into pandas DataFrames for editing in the UI.
    The differences between the originally loaded frame and the edited frame
    are then translated into insert / update / delete operations against
    MongoDB.
    """

    def __init__(self, mongo_service):
        """
        Initialize data service with MongoDB connection.

        Args:
            mongo_service: MongoService instance. This class calls its
                ``sample_documents(collection_name, limit, skip)`` and
                ``get_collection(collection_name)`` methods.
        """
        self.mongo = mongo_service
        self.flattener = JSONFlattener()

    # ------------------------------------------------------------------
    # Value helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_display_value(value: Any) -> Any:
        """
        Convert a single value into a UI/JSON friendly form.

        ObjectId becomes its hex string and datetime becomes an ISO-8601
        string. Everything else, including nested dicts and lists, is
        returned unchanged. ObjectIds or datetimes *inside* nested structures
        are therefore not converted.

        NOTE: the conversion is one-way. ``apply_changes`` writes edited
        values back as they appear in the edited DataFrame.
        """
        if isinstance(value, ObjectId):
            return str(value)
        if isinstance(value, datetime):
            return value.isoformat()
        return value

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """
        Return True if ``value`` is a scalar missing marker.

        Missing markers are None, NaN, NaT and pd.NA. Non-scalar values such
        as lists and dicts are never considered missing. ``pd.isna`` on a list
        returns an element-wise array whose truth value is ambiguous, so it is
        only applied to scalars.
        """
        if value is None:
            return True
        if not pd.api.types.is_scalar(value):
            return False
        return bool(pd.isna(value))

    @classmethod
    def _drop_missing(cls, record: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``record`` without keys whose value is missing."""
        return {key: value for key, value in record.items()
                if not cls._is_missing(value)}

    @staticmethod
    def _to_mongo_id(doc_id: Any) -> Any:
        """
        Map a display ``_id`` back to the value used in MongoDB filters.

        Strings that are valid ObjectId hex are converted back to ObjectId.
        Any other value (custom string or numeric ids) is used as-is, so
        collections with non-ObjectId ``_id`` values still work.

        NOTE(review): a 24-hex-char ``_id`` that was genuinely stored as a
        string cannot be told apart after display normalization and will be
        converted to ObjectId.
        """
        if isinstance(doc_id, str) and ObjectId.is_valid(doc_id):
            return ObjectId(doc_id)
        return doc_id

    @staticmethod
    def _write_matched(result: Any, count_attr: str) -> bool:
        """
        Tell whether a single-document write actually hit a document.

        Args:
            result: pymongo write result (UpdateResult / DeleteResult).
            count_attr: ``'matched_count'`` or ``'deleted_count'``.

        Unacknowledged writes (w=0) carry no counts, and pymongo raises if
        they are read. Such writes are assumed to have applied.
        """
        if not getattr(result, 'acknowledged', True):
            return True
        return bool(getattr(result, count_attr))

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _normalize_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
        """
        Normalize a MongoDB document for DataFrame display.

        Top-level ObjectId values become strings and datetime values become
        ISO-8601 strings. Nested dicts and lists are kept as-is, so the table
        widget decides how to render them.

        Args:
            doc: MongoDB document

        Returns:
            New dict with the same keys and display-friendly values.
        """
        return {key: self._to_display_value(value) for key, value in doc.items()}

    def _get_all_fields(self, documents: List[Dict[str, Any]]) -> List[str]:
        """
        Get the union of all field names across documents.

        Args:
            documents: List of (normalized) documents

        Returns:
            Sorted list of unique field names, with ``_id`` moved to the front
            when present.
        """
        all_fields = set()
        for doc in documents:
            all_fields.update(doc.keys())

        # (False, ...) sorts before (True, ...), which places '_id' first.
        return sorted(all_fields, key=lambda field: (field != '_id', field))

    def load_collection_as_dataframe(
        self,
        collection_name: str,
        limit: int = 500,
        skip: int = 0
    ) -> pd.DataFrame:
        """
        Load a MongoDB collection as a normalized pandas DataFrame.

        All documents are normalized to share the same columns, and missing
        fields are filled with null values. Top-level ObjectId and datetime
        values are converted to strings.

        Args:
            collection_name: Name of the collection
            limit: Maximum number of documents to load
            skip: Number of documents to skip

        Returns:
            DataFrame with normalized documents. When the collection yields
            no documents, this is an empty frame with only an ``_id`` column.
        """
        documents = self.mongo.sample_documents(collection_name, limit, skip)

        if not documents:
            return pd.DataFrame(columns=['_id'])

        normalized_docs = [self._normalize_document(doc) for doc in documents]
        all_fields = self._get_all_fields(normalized_docs)

        # Passing explicit columns makes every row share the same schema.
        # Fields absent from a document become NaN here.
        df = pd.DataFrame(normalized_docs, columns=all_fields)

        # Replace NaN with None for display consistency.
        # NOTE(review): on some pandas versions float columns may keep NaN
        # here. _identify_changes treats NaN and None alike, so this only
        # affects display.
        df = df.where(pd.notnull(df), None)

        return df

    def load_flattened_records(
        self,
        collection_name: str,
        limit: int = 5000,
        skip: int = 0
    ) -> List[Dict[str, Any]]:
        """
        Load MongoDB records and flatten them for comparison with JSON records.

        This method:
        1. Loads documents from MongoDB
        2. Converts a top-level ObjectId ``_id`` to string, on a copy so the
           documents returned by the mongo service are not mutated
        3. Flattens each document using JSONFlattener
        4. Converts ObjectId/datetime values in the flattened output to strings

        Args:
            collection_name: Name of the collection
            limit: Maximum number of documents to load
            skip: Number of documents to skip

        Returns:
            List of flattened document dictionaries
        """
        documents = self.mongo.sample_documents(collection_name, limit, skip)

        if not documents:
            return []

        flattened_records = []
        for doc in documents:
            if isinstance(doc.get('_id'), ObjectId):
                # Shallow copy: the caller's document keeps its ObjectId.
                doc = {**doc, '_id': str(doc['_id'])}

            flattened = self.flattener.flatten(doc)

            flattened_records.append({
                key: self._to_display_value(value)
                for key, value in flattened.items()
            })

        return flattened_records

    # ------------------------------------------------------------------
    # Change tracking
    # ------------------------------------------------------------------

    def _index_rows_by_id(self, df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
        """
        Map ``str(_id)`` to that row's non-missing fields.

        When an ``_id`` appears more than once, the first occurrence wins.
        Rows without an ``_id`` are skipped. Returns ``{}`` if the frame has
        no ``_id`` column.
        """
        if '_id' not in df.columns:
            return {}

        rows: Dict[str, Dict[str, Any]] = {}
        for record in df.to_dict('records'):
            row_id = record.get('_id')
            if self._is_missing(row_id):
                continue
            key = str(row_id)
            if key not in rows:
                rows[key] = self._drop_missing(record)
        return rows

    def _identify_changes(
        self,
        original_df: pd.DataFrame,
        edited_df: pd.DataFrame
    ) -> Tuple[List[Dict], List[Dict], List[str]]:
        """
        Identify new, updated, and deleted documents.

        Rows are matched on the string form of ``_id``. Missing cells (None,
        NaN, NaT, pd.NA) are treated as absent fields. List and dict values
        are compared with ``!=``.

        Args:
            original_df: Original DataFrame from MongoDB
            edited_df: Edited DataFrame from UI

        Returns:
            Tuple of (new_docs, updated_docs, deleted_ids):
            - new_docs: field dicts with ``_id`` removed, for edited rows with
              no ``_id``, an empty-string ``_id``, or an ``_id`` that is not
              in the original frame.
            - updated_docs: ``{'_id': <edited _id>, 'updates': {...}}``.
              A field removed in the edited row maps to ``None``.
            - deleted_ids: ``str(_id)`` of original rows absent from the
              edited frame, in original row order.
        """
        original_rows = self._index_rows_by_id(original_df)
        edited_records = edited_df.to_dict('records')

        edited_ids = {
            str(record['_id'])
            for record in edited_records
            if not self._is_missing(record.get('_id'))
        }
        # Iterate the dict, not a set, so the order matches the original rows.
        deleted_ids = [row_id for row_id in original_rows if row_id not in edited_ids]

        new_docs: List[Dict] = []
        updated_docs: List[Dict] = []

        for record in edited_records:
            row_dict = self._drop_missing(record)
            doc_id = row_dict.get('_id')

            # Explicit None/'' checks (not truthiness) so falsy ids like 0
            # are still matched against the original rows.
            if doc_id is None or doc_id == '' or str(doc_id) not in original_rows:
                # New document. Any user-supplied _id is dropped and
                # MongoDB assigns a fresh one on insert.
                row_dict.pop('_id', None)
                new_docs.append(row_dict)
                continue

            original_dict = original_rows[str(doc_id)]

            updates = {
                key: value
                for key, value in row_dict.items()
                if key != '_id' and value != original_dict.get(key)
            }

            # Fields present originally but cleared/removed in the edit.
            for key in original_dict:
                if key != '_id' and key not in row_dict:
                    updates[key] = None

            if updates:
                updated_docs.append({
                    '_id': doc_id,
                    'updates': updates
                })

        return new_docs, updated_docs, deleted_ids

    def preview_changes(
        self,
        original_df: pd.DataFrame,
        edited_df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Preview changes between original and edited DataFrames.

        Args:
            original_df: Original DataFrame
            edited_df: Edited DataFrame

        Returns:
            Dictionary with ``new_count``, ``updated_count``, ``deleted_count``
            and the underlying ``new_docs``, ``updated_docs`` and
            ``deleted_ids`` lists from ``_identify_changes``.
        """
        new_docs, updated_docs, deleted_ids = self._identify_changes(
            original_df, edited_df
        )

        return {
            'new_count': len(new_docs),
            'updated_count': len(updated_docs),
            'deleted_count': len(deleted_ids),
            'new_docs': new_docs,
            'updated_docs': updated_docs,
            'deleted_ids': deleted_ids
        }

    def apply_changes(
        self,
        collection_name: str,
        original_df: pd.DataFrame,
        edited_df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Apply changes from the edited DataFrame back to MongoDB.

        This method:
        1. Identifies new, updated, and deleted documents
        2. Inserts new documents with ``insert_many``
        3. Updates existing documents with ``$set`` for changed fields and
           ``$unset`` for removed fields
        4. Deletes removed documents

        Per-document failures are collected in ``errors`` and do not stop
        the remaining operations. This includes writes that matched no
        document. Counts reflect documents actually matched, deleted or
        inserted (for acknowledged writes).

        Args:
            collection_name: Name of the collection
            original_df: Original DataFrame from MongoDB
            edited_df: Edited DataFrame from UI

        Returns:
            Dictionary with ``success``, ``inserted``, ``updated``,
            ``deleted`` and ``errors``. If the collection is unavailable,
            it is ``{'success': False, 'error': ...}`` instead.
        """
        collection = self.mongo.get_collection(collection_name)

        if collection is None:
            return {
                'success': False,
                'error': 'Collection not found or not connected to MongoDB'
            }

        new_docs, updated_docs, deleted_ids = self._identify_changes(
            original_df, edited_df
        )

        results = {
            'success': True,
            'inserted': 0,
            'updated': 0,
            'deleted': 0,
            'errors': []
        }

        try:
            # Insert new documents
            if new_docs:
                try:
                    insert_result = collection.insert_many(new_docs)
                    results['inserted'] = len(insert_result.inserted_ids)
                except Exception as e:
                    # pymongo's BulkWriteError reports documents inserted
                    # before the failure in details['nInserted'].
                    details = getattr(e, 'details', None)
                    if isinstance(details, dict):
                        results['inserted'] = details.get('nInserted', 0)
                    results['errors'].append(f"Insert error: {str(e)}")

            # Update existing documents
            for update_doc in updated_docs:
                mongo_id = self._to_mongo_id(update_doc['_id'])
                updates = update_doc['updates']

                # None marks a field removed in the edited row -> $unset.
                set_updates = {k: v for k, v in updates.items() if v is not None}
                unset_updates = {k: "" for k, v in updates.items() if v is None}

                update_op = {}
                if set_updates:
                    update_op['$set'] = set_updates
                if unset_updates:
                    update_op['$unset'] = unset_updates
                if not update_op:
                    continue

                try:
                    update_result = collection.update_one({'_id': mongo_id}, update_op)
                except Exception as e:
                    results['errors'].append(f"Update error for {mongo_id}: {str(e)}")
                    continue

                if self._write_matched(update_result, 'matched_count'):
                    results['updated'] += 1
                else:
                    results['errors'].append(
                        f"Update error for {mongo_id}: no matching document"
                    )

            # Delete removed documents
            for raw_id in deleted_ids:
                mongo_id = self._to_mongo_id(raw_id)
                try:
                    delete_result = collection.delete_one({'_id': mongo_id})
                except Exception as e:
                    results['errors'].append(f"Delete error for {mongo_id}: {str(e)}")
                    continue

                if self._write_matched(delete_result, 'deleted_count'):
                    results['deleted'] += 1
                else:
                    results['errors'].append(
                        f"Delete error for {mongo_id}: no matching document"
                    )

        except Exception as e:
            results['success'] = False
            results['errors'].append(f"General error: {str(e)}")

        return results
