"""
Import Preview Engine

This module builds editable preview tables for import decisions.
It tracks user decisions per row (import/skip) without performing any writes.

Key principles:
- NO database writes
- User decisions tracked in memory
- Field-level editing supported
- All changes reversible until apply
"""

import re
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import pandas as pd

from services.json_flattener import JSONFlattener


class ImportPreviewEngine:
    """
    Service for building and managing import preview data.
    
    Allows users to review, edit, and make decisions about records
    before importing them to MongoDB.
    """
    
    def __init__(self):
        """Initialize import preview engine."""
        self.flattener = JSONFlattener()
    
    def build_preview_dataframe(
        self,
        records: List[Dict[str, Any]],
        record_type: str = 'new'
    ) -> pd.DataFrame:
        """
        Build a DataFrame for preview and editing.
        
        Args:
            records: List of records to preview
            record_type: Type of records ('new', 'duplicate', 'exact')
            
        Returns:
            DataFrame with preview data
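
        Example (illustrative; the 'name' field is hypothetical)::

            >>> engine = ImportPreviewEngine()
            >>> df = engine.build_preview_dataframe(
            ...     [{'json_record': {'name': 'Ada'}}], record_type='new'
            ... )
            >>> list(df.columns)
            ['name', '_row_index', '_import_decision']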
        """
        if not records:
            return pd.DataFrame()
        
        # Extract record data based on type
        preview_data = []
        
        for i, record in enumerate(records):
            if record_type == 'new':
                # New records - just the JSON record
                row_data = record.get('json_record', {}).copy()
            elif record_type == 'duplicate':
                # Duplicates - merge JSON and DB records, showing differences
                json_rec = record.get('json_record', {})
                db_rec = record.get('db_record', {})
                comparison = record.get('comparison', {})
                
                # Start with JSON record
                row_data = json_rec.copy()
                
                # Add metadata about match
                row_data['_match_type'] = record.get('match_type', 'unknown')
                row_data['_db_id'] = db_rec.get('_id', '')
                row_data['_difference_count'] = comparison.get('difference_count', 0)
            elif record_type == 'exact':
                # Exact matches - records are already plain dictionaries from DiffEngine
                # Check if it's wrapped or plain
                if 'json_record' in record:
                    row_data = record['json_record'].copy()
                    row_data['_db_id'] = record.get('db_record', {}).get('_id', '')
                else:
                    # Plain record (most common case)
                    row_data = record.copy()
            else:
                row_data = {}
            
            # Add row index for tracking
            row_data['_row_index'] = i
            
            # Add decision tracking field
            row_data['_import_decision'] = 'pending'
            
            preview_data.append(row_data)
        
        # Create DataFrame
        df = pd.DataFrame(preview_data)
        
        # Replace NaN with None so missing values round-trip cleanly
        df = df.where(pd.notnull(df), None)
        
        return df
    
    def build_duplicate_comparison_dataframe(
        self,
        duplicates: List[Dict[str, Any]]
    ) -> pd.DataFrame:
        """
        Build a side-by-side comparison DataFrame for duplicates.
        
        Args:
            duplicates: List of duplicate records with comparison data
            
        Returns:
            DataFrame with JSON vs DB comparison
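
        Example (illustrative single-field duplicate)::

            >>> engine = ImportPreviewEngine()
            >>> df = engine.build_duplicate_comparison_dataframe([{
            ...     'json_record': {'name': 'Ada'},
            ...     'db_record': {'_id': 'abc', 'name': 'Ada L.'},
            ...     'comparison': {'differences': {'name': {'status': 'different'}}},
            ... }])
            >>> df.loc[0, ['field', 'json_value', 'db_value', 'status']].tolist()
            ['name', 'Ada', 'Ada L.', 'different']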
        """
        if not duplicates:
            return pd.DataFrame()
        
        comparison_data = []
        
        for i, dup in enumerate(duplicates):
            json_rec = dup.get('json_record', {})
            db_rec = dup.get('db_record', {})
            comparison = dup.get('comparison', {})
            differences = comparison.get('differences', {})
            
            # Get all fields
            all_fields = set(json_rec.keys()) | set(db_rec.keys())
            
            for field in all_fields:
                if field == '_id':
                    continue
                
                json_value = json_rec.get(field)
                db_value = db_rec.get(field)
                
                # Determine status
                if field in differences:
                    status = differences[field]['status']
                else:
                    status = 'match'
                
                # Convert values to strings to avoid Arrow serialization errors
                json_value_str = str(json_value) if json_value is not None else None
                db_value_str = str(db_value) if db_value is not None else None
                
                comparison_data.append({
                    '_row_index': i,
                    '_match_type': dup.get('match_type', 'unknown'),
                    '_db_id': str(db_rec.get('_id', '')),
                    'field': field,
                    'json_value': json_value_str,
                    'db_value': db_value_str,
                    'status': status,
                    '_import_decision': 'pending'
                })
        
        df = pd.DataFrame(comparison_data)
        df = df.where(pd.notnull(df), None)
        
        return df
    
    def extract_import_decisions(
        self,
        preview_df: pd.DataFrame
    ) -> Dict[str, List[int]]:
        """
        Extract user import decisions from preview DataFrame.
        
        Args:
            preview_df: DataFrame with user decisions
            
        Returns:
            Dictionary with lists of row indices by decision
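
        Example::

            >>> df = pd.DataFrame({
            ...     '_row_index': [0, 1],
            ...     '_import_decision': ['import', 'skip'],
            ... })
            >>> ImportPreviewEngine().extract_import_decisions(df)
            {'import': [0], 'skip': [1], 'pending': []}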
        """
        if preview_df.empty or '_import_decision' not in preview_df.columns:
            return {
                'import': [],
                'skip': [],
                'pending': []
            }
        
        decisions = {
            'import': [],
            'skip': [],
            'pending': []
        }
        
        for idx, row in preview_df.iterrows():
            decision = row.get('_import_decision', 'pending')
            row_index = row.get('_row_index', idx)
            
            if decision in decisions:
                decisions[decision].append(int(row_index))
            else:
                decisions['pending'].append(int(row_index))
        
        return decisions
    
    def prepare_records_for_import(
        self,
        preview_df: pd.DataFrame,
        original_records: List[Dict[str, Any]],
        import_indices: List[int],
        unflatten: bool = False,
        collection_name: Optional[str] = None,
        mongo_service: Optional[Any] = None
    ) -> List[Dict[str, Any]]:
        """
        Prepare records for import based on user decisions and edits.
        For the 'people' collection, duplicate rows are automatically
        consolidated by (given name, family name, birth date).
        
        Args:
            preview_df: DataFrame with user edits (flattened columns)
            original_records: Original record list (currently unused; edited
                values are read from preview_df, the parameter is kept for
                API compatibility)
            import_indices: List of row indices to import
            unflatten: If True, convert flattened records back to nested structure
            collection_name: Name of target collection (for people consolidation)
            mongo_service: MongoDB service for checking existing records
            
        Returns:
            List of records ready for import (flattened or nested)
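
        Example (sketch; assumes preview_df was produced by
        build_preview_dataframe and then edited by the user)::

            decisions = engine.extract_import_decisions(preview_df)
            records = engine.prepare_records_for_import(
                preview_df, original_records, decisions['import'], unflatten=True
            )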
        """
        records_to_import = []
        
        for idx in import_indices:
            # Indices are positional: extract_import_decisions returns _row_index
            # values, which match iloc positions as long as the preview rows were
            # not reordered after build_preview_dataframe.
            if idx >= len(preview_df):
                continue
            
            # Get the edited row from the DataFrame
            row = preview_df.iloc[idx]
            
            # Convert to dictionary, removing metadata fields
            record = row.to_dict()
            
            # Remove metadata fields (all fields starting with _)
            metadata_fields = [
                '_row_index',
                '_import_decision',
                '_match_type',
                '_db_id',
                '_difference_count',
                '_source_file',
                '_source_path',
                '_source_competition',
                '_target_collection',
                '_validation_passed',
                '_document_type',
                '_expected_types'
            ]
            
            for field in metadata_fields:
                record.pop(field, None)
            
            # Also remove any other fields starting with _ (catch-all for metadata)
            record = {k: v for k, v in record.items() if not k.startswith('_')}
            
            # Remove None/NaN values (flattened values are scalars here, so
            # pd.notna is unambiguous; it would raise for list-valued fields)
            record = {k: v for k, v in record.items() if pd.notna(v)}
            
            records_to_import.append(record)
        
        # Auto-consolidate for people collection BEFORE unflatten
        if collection_name == 'people' and mongo_service:
            records_to_import = self._consolidate_people_records(
                records_to_import, 
                mongo_service
            )
        
        # Unflatten after consolidation
        if unflatten:
            records_to_import = [
                self.flattener.unflatten(rec) for rec in records_to_import
            ]
        
        return records_to_import
    
    def _consolidate_people_records(
        self,
        new_records: List[Dict[str, Any]],
        mongo_service
    ) -> List[Dict[str, Any]]:
        """
        Consolidate people records by merging duplicates (same name + birthdate).
        Stores all codes in 'codes' array.
        
        Args:
            new_records: Records to import
            mongo_service: MongoDB service
            
        Returns:
            Consolidated records with codes arrays
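
        Example (hypothetical rows; both map to one identity)::

            [{'given_name': 'Ada', 'family_name': 'Lovelace',
              'birth_date': '1815-12-10', 'code': 7},
             {'given_name': 'Ada', 'family_name': 'Lovelace',
              'birth_date': '1815-12-10', 'code': 9}]
            # -> one consolidated record with codes == ['7', '9']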
        """
        
        print(f"\n🔄 Auto-consolidation: Processing {len(new_records)} people records...")
        
        people_collection = mongo_service.db['people']
        
        # Group new records by identity
        identity_groups = defaultdict(list)
        
        for record in new_records:
            given_name = str(record.get('given_name', '')).strip().lower()
            family_name = str(record.get('family_name', '')).strip().lower()
            birth_date = str(record.get('birth_date', '')).strip()
            
            if not given_name or not family_name:
                continue
            
            identity_key = (given_name, family_name, birth_date)
            identity_groups[identity_key].append(record)
        
        # Process each identity group
        consolidated_records = []
        updated_count = 0
        new_count = 0
        
        for identity_key, records in identity_groups.items():
            given_name, family_name, birth_date = identity_key
            
            # Collect all codes from new records
            all_codes = set()
            for record in records:
                code = record.get('code')
                if code:
                    # Normalize numeric codes (e.g. 42.0 -> '42'); keep strings as-is
                    code_str = str(int(code)) if isinstance(code, (int, float)) else str(code)
                    all_codes.add(code_str)
            
            # Check if the person already exists in the DB (case-insensitive exact
            # match; escape names so regex metacharacters in the data can't break
            # or widen the query)
            existing_person = people_collection.find_one({
                'given_name': {'$regex': f'^{re.escape(given_name)}$', '$options': 'i'},
                'family_name': {'$regex': f'^{re.escape(family_name)}$', '$options': 'i'},
                'birth_date': birth_date
            })
            
            if existing_person:
                # Person exists - update with new codes and missing fields
                existing_codes = set(existing_person.get('codes', []))
                if 'code' in existing_person and existing_person['code']:
                    existing_codes.add(str(existing_person['code']))
                
                # Merge codes
                all_codes.update(existing_codes)
                codes_list = sorted(all_codes)
                
                # Build the update: refreshed codes plus any fields from the new
                # records that the stored document is missing (added below)
                update_fields = {
                    'codes': codes_list,
                    'code': codes_list[0] if codes_list else None,
                    'updated_at': datetime.now(timezone.utc).isoformat()
                }
                
                # Add missing fields from new records (even if empty)
                for record in records:
                    for field, value in record.items():
                        # Skip internal fields and fields that already exist
                        if field.startswith('_') or field in ['codes', 'code']:
                            continue
                        
                        # Only add the field if it doesn't exist on the stored person
                        if field not in existing_person:
                            # Keep falsy-but-meaningful values (0, False); only map
                            # None to '' so the field is still created
                            update_fields[field] = value if value is not None else ''
                
                # Update existing record
                people_collection.update_one(
                    {'_id': existing_person['_id']},
                    {'$set': update_fields}
                )
                updated_count += 1
                # Don't add to import list (already updated)
            else:
                # New person - merge all duplicate rows into one record
                # (first non-empty value wins for each field)
                merged_record = {}
                for record in records:
                    for field, value in record.items():
                        if field not in merged_record or not merged_record[field]:
                            merged_record[field] = value
                
                codes_list = sorted(all_codes)
                merged_record['codes'] = codes_list
                merged_record['code'] = codes_list[0] if codes_list else None
                merged_record['created_at'] = datetime.now(timezone.utc).isoformat()
                
                consolidated_records.append(merged_record)
                new_count += 1
        
        print(f"   ✅ Updated {updated_count} existing people with new codes")
        print(f"   ✅ Prepared {new_count} new people with codes array")
        print(f"   📊 Total unique identities: {len(identity_groups)}")
        
        return consolidated_records
    
    def build_import_summary(
        self,
        decisions: Dict[str, List[int]],
        record_type: str
    ) -> Dict[str, Any]:
        """
        Build a summary of import decisions.
        
        Args:
            decisions: Dictionary with decision lists
            record_type: Type of records
            
        Returns:
            Summary dictionary
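
        Example::

            >>> ImportPreviewEngine().build_import_summary(
            ...     {'import': [0, 2], 'skip': [1], 'pending': []}, 'new'
            ... )['to_import']
            2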
        """
        return {
            'record_type': record_type,
            'total_records': sum(len(v) for v in decisions.values()),
            'to_import': len(decisions.get('import', [])),
            'to_skip': len(decisions.get('skip', [])),
            'pending': len(decisions.get('pending', [])),
            'decisions': decisions
        }
    
    def mark_all_for_import(
        self,
        preview_df: pd.DataFrame
    ) -> pd.DataFrame:
        """
        Mark all records in preview for import.
        
        Args:
            preview_df: Preview DataFrame
            
        Returns:
            The same DataFrame, updated in place
        """
        if '_import_decision' in preview_df.columns:
            preview_df['_import_decision'] = 'import'
        
        return preview_df
    
    def mark_all_for_skip(
        self,
        preview_df: pd.DataFrame
    ) -> pd.DataFrame:
        """
        Mark all records in preview to skip.
        
        Args:
            preview_df: Preview DataFrame
            
        Returns:
            The same DataFrame, updated in place
        """
        if '_import_decision' in preview_df.columns:
            preview_df['_import_decision'] = 'skip'
        
        return preview_df
    
    def get_field_statistics(
        self,
        preview_df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Get statistics about fields in preview data.
        
        Args:
            preview_df: Preview DataFrame
            
        Returns:
            Dictionary with field statistics
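
        Example (illustrative columns)::

            >>> df = pd.DataFrame({'name': ['Ada', None], '_row_index': [0, 1]})
            >>> ImportPreviewEngine().get_field_statistics(df)['name']
            {'non_null_count': 1, 'null_count': 1, 'coverage_percentage': 50.0}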
        """
        if preview_df.empty:
            return {}
        
        # Exclude metadata fields
        metadata_fields = [
            '_row_index',
            '_import_decision',
            '_match_type',
            '_db_id',
            '_difference_count'
        ]
        
        data_fields = [col for col in preview_df.columns if col not in metadata_fields]
        
        field_stats = {}
        
        for field in data_fields:
            non_null_count = preview_df[field].notna().sum()
            null_count = preview_df[field].isna().sum()
            
            field_stats[field] = {
                'non_null_count': int(non_null_count),
                'null_count': int(null_count),
                # preview_df is non-empty here (guarded above), so division is safe
                'coverage_percentage': round(float(non_null_count) / len(preview_df) * 100, 2)
            }
        
        return field_stats
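

if __name__ == '__main__':
    # Minimal smoke test with illustrative data (no MongoDB needed; assumes
    # services.json_flattener is importable, as required at the top of this file).
    engine = ImportPreviewEngine()
    preview = engine.build_preview_dataframe(
        [{'json_record': {'given_name': 'Ada', 'family_name': 'Lovelace'}}],
        record_type='new',
    )
    preview = engine.mark_all_for_import(preview)
    decisions = engine.extract_import_decisions(preview)
    print(engine.build_import_summary(decisions, 'new'))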
