"""
Universal JSON Flattener

This module provides recursive JSON flattening functionality to convert
deeply nested JSON structures into flat key-value pairs suitable for
tabular display and editing.

Key principles:
- NO data loss - ALL fields at ANY depth are preserved
- NO hardcoded field names - works with ANY JSON structure
- Recursive parsing of nested objects and arrays
- Dot notation for nested paths (e.g., "parent.child.0.field")
- Type preservation for all values

CRITICAL: This is a data integrity critical component.
Every field from source JSON MUST appear in flattened output.
"""

from typing import Dict, Any, List, Union, Optional
import json


class JSONFlattener:
    """
    Universal JSON flattener that converts nested structures to flat key-value pairs.

    Handles:
    - Nested dictionaries (any depth, within Python's recursion limit)
    - Arrays of primitives
    - Arrays of objects
    - Mixed nested structures, including arrays of arrays
    - Null values
    - Empty nested objects/arrays (kept as ``{}`` / ``[]`` leaf values)
    - All primitive types

    Path format: segments are joined with ``separator``; list elements use
    their decimal index as the segment (``"items.0.id"``).

    NOTE: the flat format is inherently ambiguous in two ways, so an exact
    round trip is not possible for such data:
    - a dict key that itself contains ``separator`` is indistinguishable
      from a nested path;
    - a nested dict whose keys are all canonical integers (``"0"``, ``"12"``)
      is rebuilt as a list by ``unflatten``.
    """

    class _Node(dict):
        """Interior container created by ``unflatten``.

        A dedicated subclass lets ``unflatten`` distinguish containers it
        built from path segments from leaf *values* that happen to be dicts
        (e.g. an empty ``{}`` or a metadata payload).
        """

    # Flat keys that unflatten() copies verbatim to the top level of its
    # result, without splitting them on the separator.
    _METADATA_KEYS = frozenset(
        {"_id", "_source_file", "_source_path", "_source_competition"}
    )

    def __init__(self, separator: str = "."):
        """
        Initialize flattener.

        Args:
            separator: Character to use for path separation (default: ".")
        """
        self.separator = separator

    # ------------------------------------------------------------------ #
    # Flattening
    # ------------------------------------------------------------------ #

    def flatten(
        self,
        data: Union[Dict[str, Any], List[Any]],
        parent_key: str = "",
        preserve_arrays: bool = False
    ) -> Dict[str, Any]:
        """
        Recursively flatten a nested JSON structure.

        Args:
            data: Input data (dict or list; any other value is treated as a
                single primitive)
            parent_key: Current path prefix
            preserve_arrays: If True, keep non-empty arrays whose items are
                all primitives as-is instead of flattening them

        Returns:
            Flattened dictionary with separator-joined string keys. Empty
            nested dicts/lists are kept as ``{}`` / ``[]`` values so that no
            field disappears. A primitive ``data`` with no ``parent_key`` is
            stored under ``"_value"``.

        Examples:
            Input: {"a": {"b": {"c": 1}}}
            Output: {"a.b.c": 1}

            Input: {"items": [{"id": 1}, {"id": 2}]}
            Output: {"items.0.id": 1, "items.1.id": 2}

            Input: {"tags": ["red", "blue"]}
            Output: {"tags.0": "red", "tags.1": "blue"}
            (or {"tags": ["red", "blue"]} if preserve_arrays=True)

            Input: {"meta": {}, "list": []}
            Output: {"meta": {}, "list": []}

            Input: [{"x": 1}, 2]
            Output: {"0.x": 1, "1": 2}
        """
        flattened: Dict[str, Any] = {}

        if isinstance(data, dict):
            self._flatten_dict_into(data, parent_key, preserve_arrays, flattened)
        elif isinstance(data, list):
            self._flatten_list_into(data, parent_key, preserve_arrays, flattened)
        elif parent_key:
            # Primitive value with an explicit path
            flattened[parent_key] = data
        else:
            # Primitive value at root
            flattened["_value"] = data

        return flattened

    def _flatten_list(
        self,
        items: List[Any],
        parent_key: str,
        preserve_arrays: bool
    ) -> Dict[str, Any]:
        """
        Flatten a list into indexed keys.

        Args:
            items: List to flatten
            parent_key: Current path prefix ("" for a top-level list, whose
                keys then start directly with the index, e.g. "0.id")
            preserve_arrays: If True, keep primitive arrays as-is

        Returns:
            Flattened dictionary
        """
        flattened: Dict[str, Any] = {}
        self._flatten_list_into(items, parent_key, preserve_arrays, flattened)
        return flattened

    def _join(self, parent_key: str, segment: Any) -> str:
        """Append *segment* to *parent_key*; a top-level segment gets no prefix."""
        if parent_key:
            return f"{parent_key}{self.separator}{segment}"
        # str() so non-string dict keys (e.g. ints) still yield string paths
        return str(segment)

    def _flatten_dict_into(
        self,
        data: Dict[Any, Any],
        parent_key: str,
        preserve_arrays: bool,
        out: Dict[str, Any]
    ) -> None:
        """Write every leaf below dict *data* into *out* (no intermediate copies)."""
        for key, value in data.items():
            self._flatten_value_into(
                value, self._join(parent_key, key), preserve_arrays, out
            )

    def _flatten_list_into(
        self,
        items: List[Any],
        parent_key: str,
        preserve_arrays: bool,
        out: Dict[str, Any]
    ) -> None:
        """Write every leaf below list *items* into *out*, using indices as segments."""
        if preserve_arrays and items and all(
            not isinstance(item, (dict, list)) for item in items
        ):
            # Keep primitive arrays intact
            out[parent_key] = items
            return

        for index, item in enumerate(items):
            self._flatten_value_into(
                item, self._join(parent_key, index), preserve_arrays, out
            )

    def _flatten_value_into(
        self,
        value: Any,
        key: str,
        preserve_arrays: bool,
        out: Dict[str, Any]
    ) -> None:
        """Store *value* at *key*, recursing into non-empty containers."""
        if isinstance(value, (dict, list)) and not value:
            # An empty container has no leaves; store it as a value so the
            # field is not silently dropped.
            out[key] = {} if isinstance(value, dict) else []
        elif isinstance(value, dict):
            self._flatten_dict_into(value, key, preserve_arrays, out)
        elif isinstance(value, list):
            self._flatten_list_into(value, key, preserve_arrays, out)
        else:
            out[key] = value

    # ------------------------------------------------------------------ #
    # Unflattening
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_index(part: str) -> bool:
        """Return True if *part* is a canonical list index such as "0" or "12".

        Only ASCII digits count (``str.isdigit`` would accept e.g. "²", which
        ``int()`` rejects), and leading zeros are excluded because
        ``flatten`` never emits them for list indices.
        """
        if not part or any(ch < "0" or ch > "9" for ch in part):
            return False
        return part == "0" or not part.startswith("0")

    @staticmethod
    def _is_filler(value: Any) -> bool:
        """Return True for values that carry no data of their own: None, {} and []."""
        return value is None or (isinstance(value, (dict, list)) and not value)

    def unflatten(self, flat_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Reconstruct nested structure from flattened data.

        Keys are split on the separator and the path is rebuilt as nested
        containers. A nested container whose keys are all canonical integer
        indices becomes a list; missing indices are filled with None. The
        top level is always a dict. Keys listed in ``_METADATA_KEYS`` are
        copied to the top level without splitting.

        Conflicts between a leaf value and deeper keys (e.g. "a" and "a.b"):
        - a filler value (None, {} or []) never overrides nested data, so
          padding added by ``normalize_records`` is harmless;
        - two real values that cannot both be represented raise TypeError
          instead of silently dropping one of them.

        Args:
            flat_data: Flattened dictionary with separator-joined keys

        Returns:
            Nested dictionary/list structure

        Raises:
            TypeError: if two non-filler values claim the same path.

        Note: This is used when writing back to MongoDB or JSON.
        """
        root = self._Node()

        for flat_key, value in flat_data.items():
            if flat_key in self._METADATA_KEYS:
                root[flat_key] = value
                continue
            parts = str(flat_key).split(self.separator)
            self._insert(root, parts, value, flat_key)

        return self._materialize(root, is_root=True)

    def _insert(
        self,
        root: Dict[str, Any],
        parts: List[str],
        value: Any,
        flat_key: Any
    ) -> None:
        """Place *value* at the path *parts* inside the node tree rooted at *root*."""
        node = root
        for part in parts[:-1]:
            child = node.get(part)
            if isinstance(child, self._Node):
                node = child
                continue
            if not self._is_filler(child):
                # A real leaf value sits where this key needs a container.
                if self._is_filler(value):
                    # Nothing to store (e.g. normalize_records padding), so
                    # skipping loses no data.
                    return
                raise TypeError(
                    f"Cannot unflatten key {flat_key!r}: segment {part!r} "
                    f"already holds the value {child!r}"
                )
            # Missing slot, or a filler value: replace it with a container.
            child = self._Node()
            node[part] = child
            node = child

        last = parts[-1]
        if isinstance(node.get(last), self._Node):
            # Deeper keys already populated this slot.
            if self._is_filler(value):
                return
            raise TypeError(
                f"Cannot unflatten key {flat_key!r}: it conflicts with "
                f"nested keys below it"
            )
        node[last] = value

    def _materialize(
        self,
        node: Dict[str, Any],
        is_root: bool = False
    ) -> Union[Dict[str, Any], List[Any]]:
        """Convert a node tree into plain dicts/lists (all-index nodes become lists)."""
        built = {
            key: self._materialize(child) if isinstance(child, self._Node) else child
            for key, child in node.items()
        }
        if is_root or not built or not all(self._is_index(key) for key in built):
            return built

        items: List[Any] = [None] * (max(int(key) for key in built) + 1)
        for key, child in built.items():
            items[int(key)] = child
        return items

    # ------------------------------------------------------------------ #
    # Record helpers
    # ------------------------------------------------------------------ #

    def get_all_keys(self, records: List[Dict[str, Any]]) -> List[str]:
        """
        Extract all unique keys from a list of flattened records.

        Args:
            records: List of flattened dictionaries

        Returns:
            Deterministically ordered list of unique keys: keys starting with
            "_" (metadata) first, then all other keys; each group sorted
            alphabetically.
        """
        all_keys = set()
        for record in records:
            all_keys.update(record.keys())

        # (False, k) sorts before (True, k): metadata keys come first.
        return sorted(all_keys, key=lambda k: (not k.startswith("_"), k))

    def normalize_records(
        self,
        records: List[Dict[str, Any]],
        fill_missing: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Ensure all records have the same keys (columns).

        Args:
            records: List of flattened dictionaries
            fill_missing: If True, add missing keys with None values

        Returns:
            New records with keys in ``get_all_keys`` order when
            ``fill_missing`` is True; otherwise ``records`` itself, unchanged
            (an empty input always yields a new empty list).
        """
        if not records:
            return []

        if not fill_missing:
            return records

        all_keys = self.get_all_keys(records)
        return [{key: record.get(key) for key in all_keys} for record in records]

    def flatten_json_file(
        self,
        json_data: Dict[str, Any],
        extract_arrays: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Flatten a JSON file that may contain arrays of records.

        This is specifically designed for ODF JSON files where the actual
        records are nested inside arrays within the structure.

        Each flat key is assigned to the array at its FIRST index segment.
        The array whose base path has the most separators is chosen as the
        record array; ties go to the array seen first. Each element of it
        becomes one record, keyed by paths relative to the element (a
        primitive element is stored under "_value"). Every flat key that
        does not belong to that array, including elements of other arrays,
        is copied into every record under its full path.

        NOTE(review): shared keys are written after the element's own
        fields, so a shared key equal to an element-relative key (e.g. a
        top-level "code" next to a participant "code") overwrites the
        element's value. Confirm this is acceptable for the data sets used.

        Args:
            json_data: Parsed JSON data
            extract_arrays: If True, extract array items as separate records

        Returns:
            List of flattened records (a single record when extract_arrays is
            False or no array is present)

        Example:
            Input: {
                "odf_body": {
                    "competition": {
                        "code": "C1",
                        "participant": [
                            {"code": 1, "name": "A"},
                            {"code": 2, "name": "B"}
                        ]
                    }
                }
            }

            Output: [
                {"code": 1, "name": "A", "odf_body.competition.code": "C1"},
                {"code": 2, "name": "B", "odf_body.competition.code": "C1"}
            ]
        """
        flat = self.flatten(json_data)

        if not extract_arrays:
            # Simple flatten - entire JSON becomes one record
            return [flat]

        # base path -> {element index -> {element-relative path -> value}}
        array_items: Dict[str, Dict[int, Dict[str, Any]]] = {}
        # base path -> flat keys that were grouped under that array
        array_members: Dict[str, List[str]] = {}

        for key, value in flat.items():
            parts = key.split(self.separator)
            for position, part in enumerate(parts):
                if not self._is_index(part):
                    continue
                base = self.separator.join(parts[:position])
                suffix = self.separator.join(parts[position + 1:]) or "_value"
                element = array_items.setdefault(base, {}).setdefault(int(part), {})
                element[suffix] = value
                array_members.setdefault(base, []).append(key)
                break  # only the first index segment decides the array

        if not array_items:
            # No arrays found, return single record
            return [flat]

        # The deepest array is usually the data array,
        # e.g. "odf_body.competition.participant".
        deepest = max(array_items, key=lambda base: base.count(self.separator))

        # Membership test (not a string-prefix test) so a top-level array,
        # whose base path is "", is handled correctly.
        extracted = set(array_members[deepest])
        shared = {key: value for key, value in flat.items() if key not in extracted}

        records = []
        for index in sorted(array_items[deepest]):
            record = dict(array_items[deepest][index])
            record.update(shared)
            records.append(record)
        return records


def flatten_json(
    data: Union[Dict[str, Any], List[Any]],
    separator: str = "."
) -> Dict[str, Any]:
    """
    Flatten nested JSON data in one call.

    Shortcut that builds a throwaway ``JSONFlattener`` configured with
    *separator* and runs its ``flatten`` method with default options.

    Args:
        data: JSON-like dict or list to flatten
        separator: Path separator (default: ".")

    Returns:
        Mapping of separator-joined paths to leaf values
    """
    return JSONFlattener(separator=separator).flatten(data)


def unflatten_json(
    flat_data: Dict[str, Any],
    separator: str = "."
) -> Dict[str, Any]:
    """
    Rebuild nested data from a flat mapping in one call.

    Shortcut that builds a throwaway ``JSONFlattener`` configured with
    *separator* and runs its ``unflatten`` method.

    Args:
        flat_data: Mapping of separator-joined paths to values
        separator: Path separator (default: ".")

    Returns:
        Nested structure rebuilt from the paths
    """
    return JSONFlattener(separator=separator).unflatten(flat_data)
