"""
Check Relationships Script

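This script checks that person codes referenced by other collections exist in
the people collection (and that organisation codes referenced by people exist
in the organisations collection) to verify data integrity and identify missing
references.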
"""

from services.mongo import MongoService
from collections import defaultdict

def _normalize_code(value):
    """Return ``value`` as a canonical code string, or ``None`` if unusable.

    Codes may be stored as ints, floats or strings, so every code is reduced
    to one string form before comparison. Empty/zero values and containers
    are not treated as codes.
    """
    if not value or isinstance(value, (dict, list)):
        return None
    if isinstance(value, float) and not value.is_integer():
        # Fractional, NaN and infinite floats cannot be converted losslessly
        # (int() would truncate or raise), so keep their plain string form.
        return str(value)
    if isinstance(value, (int, float)):
        return str(int(value))
    return str(value)


def _iter_field_values(node, parts):
    """Yield every value found at the key path ``parts`` inside ``node``.

    Lists met along the path (and at its end) are expanded element by
    element, mirroring how MongoDB dot notation reaches into arrays.
    Missing keys simply yield nothing.
    """
    if isinstance(node, list):
        for item in node:
            yield from _iter_field_values(item, parts)
    elif not parts:
        yield node
    elif isinstance(node, dict) and parts[0] in node:
        yield from _iter_field_values(node[parts[0]], parts[1:])


def _collect_codes(documents, fields):
    """Return the set of normalized codes found under the dotted ``fields``."""
    paths = [field.split('.') for field in fields]
    codes = set()
    for doc in documents:
        for parts in paths:
            for raw in _iter_field_values(doc, parts):
                code = _normalize_code(raw)
                if code is not None:
                    codes.add(code)
    return codes


def check_relationships():
    """Check that codes referenced by other collections exist in ``people``.

    Connects through ``MongoService``, gathers every person code (both the
    ``code`` field and the ``codes`` array), then reports, per referencing
    collection, how many distinct referenced codes are known or unknown.
    Finally checks that organisation codes used by people exist in the
    ``organisations`` collection. All findings are printed to stdout.

    Returns:
        bool: ``False`` if the MongoDB connection could not be established,
        ``True`` once the report has been printed (even if invalid
        references or per-collection errors were found).
    """
    print("=" * 80)
    print("Collection Relationships Checker")
    print("=" * 80)

    # Initialize MongoDB service
    print("\n1. Connecting to MongoDB...")
    mongo_service = MongoService()

    if not mongo_service.connect():
        print("❌ Failed to connect to MongoDB")
        return False

    print("✅ Connected to MongoDB")

    # Get people codes (including codes array used by consolidated records)
    print("\n2. Loading people collection...")
    people_collection = mongo_service.db['people']
    people_codes = _collect_codes(
        people_collection.find({}, {'code': 1, 'codes': 1}),
        ('code', 'codes'),
    )

    print(f"✅ Found {len(people_codes)} unique people codes (including all competition codes)")

    # Check relationships with other collections
    print("\n3. Checking relationships with other collections...")
    print("=" * 80)

    relationships = {
        'participants': {
            'field': 'code',
            'description': 'Participant registrations'
        },
        'odf_cumulative_results': {
            'field': 'competitor.composition.athlete.code',
            'description': 'Cumulative results'
        },
        'odf_rankings': {
            'field': 'competitor.composition.athlete.code',
            'description': 'Rankings'
        },
        'odf_medallists': {
            'field': 'competitor.composition.athlete.code',
            'description': 'Medal winners'
        },
        'odf_results': {
            'field': 'competitor.composition.athlete.code',
            'description': 'Competition results'
        },
        'odf_statistics': {
            'field': 'competitor.composition.athlete.code',
            'description': 'Statistics'
        }
    }

    results = {}

    for collection_name, config in relationships.items():
        print(f"\n📋 Checking {collection_name} ({config['description']})...")

        try:
            collection = mongo_service.db[collection_name]
            total_count = collection.count_documents({})

            if total_count == 0:
                print("   ℹ️  Collection is empty")
                results[collection_name] = {
                    'total': 0,
                    'referenced_codes': 0,
                    'valid': 0,
                    'invalid': 0,
                    'missing_codes': []
                }
                continue

            field = config['field']

            # Only fetch documents that carry the field, and only that field:
            # the rest of each document is irrelevant to this check.
            referenced_codes = _collect_codes(
                collection.find({field: {"$exists": True}}, {field: 1}),
                (field,),
            )
            missing_codes = referenced_codes - people_codes

            valid_count = len(referenced_codes) - len(missing_codes)
            invalid_count = len(missing_codes)

            results[collection_name] = {
                'total': total_count,
                'referenced_codes': len(referenced_codes),
                'valid': valid_count,
                'invalid': invalid_count,
                # Sorted so repeated runs show the same sample of 10.
                'missing_codes': sorted(missing_codes)[:10]
            }

            print(f"   📊 Total records: {total_count}")
            print(f"   📊 Unique codes referenced: {len(referenced_codes)}")
            print(f"   ✅ Valid references: {valid_count}")
            print(f"   ❌ Invalid references: {invalid_count}")

            if invalid_count > 0:
                print(f"   ⚠️  Missing codes (first 10): {', '.join(results[collection_name]['missing_codes'])}")

        except Exception as e:
            # Keep going: one unreadable collection should not hide the
            # state of the others. The failure is surfaced in the summary.
            print(f"   ❌ Error checking {collection_name}: {e}")
            results[collection_name] = {'error': str(e)}

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    total_invalid = sum(r.get('invalid', 0) for r in results.values())
    total_valid = sum(r.get('valid', 0) for r in results.values())
    failed_collections = [name for name, r in results.items() if 'error' in r]

    print(f"\n✅ Total valid references: {total_valid}")
    print(f"❌ Total invalid references: {total_invalid}")

    if total_invalid > 0:
        print("\n⚠️  WARNING: Some collections reference people codes that don't exist!")
        print("   This may indicate:")
        print("   1. People data is incomplete")
        print("   2. Different code formats are used")
        print("   3. Data needs to be synchronized")

    if failed_collections:
        # Errored collections contribute nothing to the totals above, so a
        # zero invalid count must not be reported as a clean result.
        print(f"\n⚠️  WARNING: Could not check: {', '.join(failed_collections)}")

    if total_invalid == 0 and not failed_collections:
        print("\n🎉 All relationships are valid!")

    # Check organisations relationship
    print("\n" + "=" * 80)
    print("Checking organisations relationship...")
    print("=" * 80)

    try:
        orgs_collection = mongo_service.db['organisations']
        org_count = orgs_collection.count_documents({})
        print(f"📊 Total organisations: {org_count}")

        # Check if people reference organisations
        org_filter = {'organisation': {'$exists': True, '$ne': None}}
        people_with_org = people_collection.count_documents(org_filter)
        print(f"📊 People with organisation field: {people_with_org}")

        if people_with_org > 0:
            # Both sides are normalized to strings so that an int code on one
            # side still matches the same code stored as a string on the
            # other; non-scalar values are skipped rather than crashing.
            org_codes_in_people = _collect_codes(
                people_collection.find(org_filter, {'organisation': 1}),
                ('organisation',),
            )

            print(f"📊 Unique organisation codes in people: {len(org_codes_in_people)}")

            # Check which exist in organisations collection
            org_codes_in_db = _collect_codes(
                orgs_collection.find({}, {'code': 1}),
                ('code',),
            )

            missing_orgs = org_codes_in_people - org_codes_in_db
            print(f"✅ Valid organisation references: {len(org_codes_in_people) - len(missing_orgs)}")
            print(f"❌ Missing organisations: {len(missing_orgs)}")

            if missing_orgs:
                print(f"   Missing org codes (first 10): {', '.join(sorted(missing_orgs)[:10])}")

    except Exception as e:
        print(f"❌ Error checking organisations: {e}")

    print("\n" + "=" * 80)
    return True

if __name__ == "__main__":
    # Script entry point: run the checker and report the outcome through the
    # process exit status (0 = report produced, 1 = connection failure or
    # unexpected error) so shells and CI jobs can detect a failed run.
    import sys
    import traceback

    try:
        completed = check_relationships()
    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        sys.exit(1)

    # check_relationships() returns False when MongoDB is unreachable.
    sys.exit(0 if completed else 1)
