"""
Foreign Key Validator Service

This service validates foreign key relationships between collections.
It checks if referenced records exist in target collections.
"""

import logging
from typing import Dict, Any, List, Tuple, Set

from services.mongo import MongoService


class ForeignKeyValidator:
    """
    Service for validating foreign key relationships.
    
    This service:
    - Checks if referenced records exist in target collections
    - Reports missing references
    - Validates data integrity across collections
    """
    
    def __init__(self, mongo_service: "MongoService"):
        """
        Initialize foreign key validator.
        
        Args:
            mongo_service: MongoDB service instance. It must provide
                ``get_distinct_values(collection, field)``, which is used to
                load the referenced values.
        """
        self.mongo = mongo_service
    
    def validate_foreign_keys(
        self,
        collection_name: str,
        records: List[Dict[str, Any]],
        foreign_keys: List[Dict[str, str]]
    ) -> Dict[str, Any]:
        """
        Validate foreign key relationships for a set of records.
        
        Null and empty-string values are allowed, so NOT NULL is not
        enforced. A non-empty value is invalid when:
        - the target collection/field could not be loaded or has no values;
        - the value is not present among the target's distinct values;
        - the value is unhashable (e.g. a list), so membership cannot be checked.
        
        Args:
            collection_name: Name of the collection being validated (currently
                not used by the checks themselves)
            records: List of records to validate
            foreign_keys: List of foreign key definitions, each with:
                - field: Field name in this collection (dot notation allowed)
                - references_collection: Target collection name
                - references_field: Field name in target collection
        
        Returns:
            Dictionary with validation results:
            - valid: List of valid records
            - invalid: List of ``{'record': ..., 'errors': [...]}`` entries
            - summary: Summary statistics
        """
        if not foreign_keys:
            # No foreign keys to validate: every record is trivially valid.
            return self._build_result(records, [], len(records), 0)
        
        # Load every referenced collection/field once, up front.
        fk_cache = self._build_foreign_key_cache(foreign_keys)
        
        valid_records = []
        invalid_records = []
        for record in records:
            errors = self._check_record(record, foreign_keys, fk_cache)
            if errors:
                invalid_records.append({
                    'record': record,
                    'errors': errors
                })
            else:
                valid_records.append(record)
        
        return self._build_result(
            valid_records, invalid_records, len(records), len(foreign_keys)
        )
    
    def _check_record(
        self,
        record: Dict[str, Any],
        foreign_keys: List[Dict[str, str]],
        fk_cache: Dict[str, Set[Any]]
    ) -> List[Dict[str, Any]]:
        """
        Check one record against every foreign key definition.
        
        Args:
            record: Record to check
            foreign_keys: Foreign key definitions (see validate_foreign_keys)
            fk_cache: Mapping "collection.field" -> set of existing values, as
                returned by _build_foreign_key_cache (failed loads are absent)
        
        Returns:
            List of error dicts with 'field', 'value' and 'error' keys; an
            empty list means the record is valid.
        """
        errors = []
        
        for fk in foreign_keys:
            field = fk['field']
            ref_collection = fk['references_collection']
            ref_field = fk['references_field']
            
            # Get value from record (handle nested fields with dot notation)
            value = self._get_nested_value(record, field)
            
            if value is None or value == '':
                # Null/empty values are allowed (not enforcing NOT NULL)
                continue
            
            # A missing key means the load failed; an empty set means the
            # target has no values. Both are reported as a target problem.
            existing = fk_cache.get(f"{ref_collection}.{ref_field}")
            if not existing:
                errors.append({
                    'field': field,
                    'value': value,
                    'error': f"Target collection '{ref_collection}' not found or empty"
                })
                continue
            
            try:
                found = value in existing
            except TypeError:
                # Unhashable values (lists, sub-documents) cannot be looked up
                # in a set; report instead of aborting the whole validation.
                errors.append({
                    'field': field,
                    'value': value,
                    'error': (
                        f"Value of unhashable type '{type(value).__name__}' "
                        f"cannot be checked against {ref_collection}.{ref_field}"
                    )
                })
                continue
            
            if not found:
                errors.append({
                    'field': field,
                    'value': value,
                    'error': f"Referenced value '{value}' not found in {ref_collection}.{ref_field}"
                })
        
        return errors
    
    @staticmethod
    def _build_result(
        valid: List[Dict[str, Any]],
        invalid: List[Dict[str, Any]],
        total: int,
        fk_count: int
    ) -> Dict[str, Any]:
        """Assemble the result dict returned by validate_foreign_keys()."""
        return {
            'valid': valid,
            'invalid': invalid,
            'summary': {
                'total_records': total,
                'valid_count': len(valid),
                'invalid_count': len(invalid),
                'foreign_keys_checked': fk_count
            }
        }
    
    def _build_foreign_key_cache(
        self,
        foreign_keys: List[Dict[str, str]]
    ) -> Dict[str, Set[Any]]:
        """
        Build a cache of existing values for foreign key validation.
        
        Each distinct "collection.field" target is queried at most once, even
        when the query fails. Targets whose values could not be loaded are
        left out of the result, so callers can tell them apart from loaded
        targets.
        
        Args:
            foreign_keys: List of foreign key definitions
        
        Returns:
            Dictionary mapping "collection.field" to set of existing values
            (only for targets that loaded successfully)
        """
        cache: Dict[str, Set[Any]] = {}
        failed: Set[str] = set()
        
        for fk in foreign_keys:
            ref_collection = fk['references_collection']
            ref_field = fk['references_field']
            cache_key = f"{ref_collection}.{ref_field}"
            
            if cache_key in cache or cache_key in failed:
                # Already loaded (or already known to fail)
                continue
            
            try:
                # Load all distinct values from target collection
                values = self.mongo.get_distinct_values(ref_collection, ref_field)
                cache[cache_key] = self._to_value_set(values)
            except Exception as e:
                # Best-effort: the collection or field might not exist. The
                # key stays absent so validation reports the target error.
                logging.getLogger(__name__).warning(
                    "Could not load values for %s: %s", cache_key, e
                )
                failed.add(cache_key)
        
        return cache
    
    @staticmethod
    def _to_value_set(values: Any) -> Set[Any]:
        """
        Collect the hashable items of ``values`` into a set.
        
        Unhashable items (e.g. sub-documents) are skipped instead of making
        the whole target unusable. Record values of those types are reported
        as errors by _check_record anyway.
        """
        result: Set[Any] = set()
        for item in values:
            try:
                result.add(item)
            except TypeError:
                continue
        return result
    
    def _get_nested_value(self, record: Dict[str, Any], field_path: str) -> Any:
        """
        Get value from nested dictionary using dot notation.
        
        Args:
            record: Dictionary to extract value from
            field_path: Field path (e.g., "person.name.given")
        
        Returns:
            Value at the specified path, or None if any path segment is
            missing or an intermediate value is not a dict
        """
        keys = field_path.split('.')
        value = record
        
        for key in keys:
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return None
        
        return value
    
    def get_validation_report(
        self,
        validation_results: Dict[str, Any]
    ) -> str:
        """
        Generate a human-readable validation report.
        
        Args:
            validation_results: Results from validate_foreign_keys()
        
        Returns:
            Formatted report string
        """
        summary = validation_results['summary']
        invalid = validation_results['invalid']
        
        report = []
        report.append("=" * 60)
        report.append("FOREIGN KEY VALIDATION REPORT")
        report.append("=" * 60)
        report.append(f"Total Records: {summary['total_records']}")
        report.append(f"Valid Records: {summary['valid_count']}")
        report.append(f"Invalid Records: {summary['invalid_count']}")
        report.append(f"Foreign Keys Checked: {summary['foreign_keys_checked']}")
        report.append("")
        
        if invalid:
            report.append("INVALID RECORDS:")
            report.append("-" * 60)
            
            for idx, item in enumerate(invalid, 1):
                record = item['record']
                errors = item['errors']
                
                # Identify the record by the first of _id / code / id present
                record_id = record.get('_id', record.get('code', record.get('id', 'Unknown')))
                report.append(f"\n{idx}. Record: {record_id}")
                
                for error in errors:
                    report.append(f"   ❌ Field '{error['field']}' = '{error['value']}'")
                    report.append(f"      Error: {error['error']}")
        else:
            report.append("✅ All records passed foreign key validation!")
        
        report.append("")
        report.append("=" * 60)
        
        return "\n".join(report)
