# management/commands/backfill_field_cache.py
from django.core.management.base import BaseCommand
from django.db import transaction
from apps.inventory.models import InventoryItem, FieldValueCache
from collections import defaultdict
import difflib
import re

class Command(BaseCommand):
    """Backfill ``FieldValueCache`` from existing ``InventoryItem`` rows.

    Pipeline: aggregate every cacheable field value per business, optionally
    merge near-duplicate values, then create/update cache rows (or, with
    ``--dry-run``, only report what would change).
    """

    help = 'Backfill field value cache from existing inventory with smart deduplication'

    # Item attributes always considered for caching, in addition to whatever
    # string values appear in ``item.custom_fields``.
    CORE_FIELDS = (
        'model',
        'processor',
        'ram',
        'storage',
        'graphics_card',
        'screen_size',
        'operating_system',
    )
    # Values shorter than this (after stripping whitespace) are ignored.
    MIN_VALUE_LENGTH = 2
    # Rows per DB round trip: queryset iterator chunk and bulk insert/update batch.
    BATCH_SIZE = 1000
    # Emit an aggregation progress line every N items.
    PROGRESS_INTERVAL = 1000
    # Minimum difflib character similarity for two values to be merged.
    SIMILARITY_THRESHOLD = 0.95
    # Minimum overlap between the "key part" sets of two values to be merged.
    KEY_PART_OVERLAP_THRESHOLD = 0.90
    # Columns written back when an existing cache row is refreshed.
    UPDATE_FIELDS = ['count', 'value', 'brand_id', 'category_id']

    _NUMBER_RE = re.compile(r'\d+')

    def add_arguments(self, parser):
        """Register the command-line options."""
        parser.add_argument(
            '--business-id',
            type=str,
            help='Backfill for specific business only',
        )
        parser.add_argument(
            '--clear',
            action='store_true',
            help='Clear existing cache before backfilling',
        )
        parser.add_argument(
            '--no-merge',
            action='store_true',
            help='Disable similarity merging (exact matches only)',
        )
        parser.add_argument(
            '--dry-run',
            action='store_true',
            help='Show what would be done without actually doing it',
        )

    def handle(self, *args, **options):
        """Run the backfill: aggregate, optionally merge, then save or report."""
        business_id = options.get('business_id')
        clear_cache = options.get('clear')
        no_merge = options.get('no_merge')
        dry_run = options.get('dry_run')

        if dry_run:
            self.stdout.write(self.style.WARNING('DRY RUN MODE - No changes will be made'))

        queryset = InventoryItem.objects.select_related('brand', 'category', 'business')
        if business_id:
            queryset = queryset.filter(business_id=business_id)

        total = queryset.count()
        self.stdout.write(f'Processing {total} inventory items...')

        aggregated_data = self._aggregate_items(queryset, total)
        self.stdout.write('Aggregation complete.')

        if no_merge:
            merged_data = aggregated_data
            self.stdout.write('Skipping similarity merging (--no-merge flag)')
        else:
            self.stdout.write('Detecting similar values...')
            merged_data = self._merge_similar_values(aggregated_data)

        if dry_run:
            if clear_cache:
                would_clear = self._cache_queryset(business_id).count()
                self.stdout.write(f'Would clear {would_clear} existing cache entries')
            self._show_dry_run_results(merged_data, aggregated_data)
            return

        if clear_cache:
            # Clear only after aggregation succeeded, and in the same
            # transaction as the rebuild, so a failed or interrupted run
            # never leaves the cache empty.
            with transaction.atomic():
                self._clear_cache(business_id)
                self._save_to_cache(merged_data, total)
        else:
            self._save_to_cache(merged_data, total)

        self.show_statistics(business_id)

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------

    @staticmethod
    def _new_value_bucket():
        """Return an empty aggregation bucket for one normalized value."""
        return {
            'value': '',
            'brand_id': None,
            'category_id': None,
            'count': 0,
            # Distinct original spellings -> number of occurrences.
            'variations': {},
        }

    def _is_cacheable(self, value):
        """Return True for string values long enough to be worth caching."""
        return isinstance(value, str) and len(value.strip()) >= self.MIN_VALUE_LENGTH

    def _iter_cacheable_values(self, item):
        """Yield ``(field_name, value)`` for every cacheable value on *item*.

        Core fields come first (in ``CORE_FIELDS`` order), then custom fields.
        """
        for field_name in self.CORE_FIELDS:
            value = getattr(item, field_name)
            if self._is_cacheable(value):
                yield field_name, value

        custom_fields = item.custom_fields
        # Skip anything that is not a mapping (e.g. a JSON list) instead of
        # crashing on ``.items()``.
        if isinstance(custom_fields, dict):
            for field_name, value in custom_fields.items():
                if self._is_cacheable(value):
                    yield field_name, value

    def _aggregate_items(self, queryset, total):
        """Group cacheable values by business, field and normalized text.

        Returns ``{business_id: {field_name: {normalized: bucket}}}``, where each
        bucket has the shape produced by ``_new_value_bucket``.
        """
        aggregated_data = defaultdict(
            lambda: defaultdict(lambda: defaultdict(self._new_value_bucket))
        )

        for index, item in enumerate(queryset.iterator(chunk_size=self.BATCH_SIZE), 1):
            business_key = str(item.business_id)
            brand_id = str(item.brand_id) if item.brand_id else None

            for field_name, value in self._iter_cacheable_values(item):
                self._aggregate_value(
                    aggregated_data[business_key][field_name],
                    value,
                    brand_id,
                    item.category_id,
                )

            if index % self.PROGRESS_INTERVAL == 0:
                self.stdout.write(f'Aggregated {index}/{total} items...')

        return aggregated_data

    @staticmethod
    def _is_well_cased(value):
        """Return True for title-case or upper-case spellings (preferred for display)."""
        return value.istitle() or value.isupper()

    def _aggregate_value(self, field_dict, value, brand_id, category_id):
        """Add one occurrence of *value* to *field_dict*, keyed by its normalized form.

        Args:
            field_dict: mapping of normalized value -> bucket (auto-creating).
            value: raw string value from the item; surrounding whitespace is
                stripped before it is stored or normalized.
            brand_id: brand id as a string, or None.
            category_id: category id of the item, or None.
        """
        value = value.strip()
        normalized = value.lower()

        field_data = field_dict[normalized]
        field_data['count'] += 1
        # Track distinct spellings with counts rather than one list entry per
        # occurrence, which grew without bound on large inventories.
        variations = field_data['variations']
        variations[value] = variations.get(value, 0) + 1

        # Prefer a title-case or upper-case spelling for display.
        if not field_data['value'] or self._is_well_cased(value):
            field_data['value'] = value

        # First non-empty brand/category association wins.
        if not field_data['brand_id']:
            field_data['brand_id'] = brand_id
        if not field_data['category_id']:
            field_data['category_id'] = category_id

    # ------------------------------------------------------------------
    # Similarity merging
    # ------------------------------------------------------------------

    @classmethod
    def _extract_key_parts(cls, text):
        """Return the set of digit runs plus lower-cased words longer than two characters."""
        numbers = cls._NUMBER_RE.findall(text)
        words = [word.lower() for word in text.split() if len(word) > 2]
        return set(numbers + words)

    def _is_similar_enough(self, str1, str2):
        """Return True if two normalized values are safe to merge.

        The rules are deliberately conservative, to avoid bad merges:

        1. The sequence of digit runs must be identical. Numbers usually carry
           the distinguishing information (e.g. "rtx 3060" vs "rtx 3070",
           "16gb" vs "32gb"). On long strings, a single differing number can
           otherwise still pass the similarity thresholds below.
        2. ``difflib.SequenceMatcher`` similarity must be at least
           ``SIMILARITY_THRESHOLD``. The cheap upper bounds
           ``real_quick_ratio``/``quick_ratio`` are checked first, so most
           non-matches skip the full ``ratio`` computation.
        3. The key-part sets must overlap by at least
           ``KEY_PART_OVERLAP_THRESHOLD``.
        """
        if self._NUMBER_RE.findall(str1) != self._NUMBER_RE.findall(str2):
            return False

        threshold = self.SIMILARITY_THRESHOLD
        matcher = difflib.SequenceMatcher(None, str1, str2)
        if (matcher.real_quick_ratio() < threshold
                or matcher.quick_ratio() < threshold
                or matcher.ratio() < threshold):
            return False

        parts1 = self._extract_key_parts(str1)
        parts2 = self._extract_key_parts(str2)
        if parts1 and parts2:
            overlap = len(parts1 & parts2) / max(len(parts1), len(parts2))
            if overlap < self.KEY_PART_OVERLAP_THRESHOLD:
                return False

        return True

    def _merge_similar_values(self, aggregated_data):
        """Merge near-duplicate values within each (business, field) pair.

        Values are processed most-frequent first. Each unabsorbed value becomes
        a canonical entry that absorbs every *later* similar value: its count
        and variations are summed into the canonical. A well-cased spelling
        replaces a badly-cased canonical display value.

        Returns a new nested mapping. *aggregated_data* is not modified.
        """
        merged_data = defaultdict(lambda: defaultdict(dict))
        merge_count = 0

        for business_id, fields in aggregated_data.items():
            for field_name, values in fields.items():
                # Most frequent first. Ties are broken by normalized text so
                # the choice of canonical value is reproducible between runs.
                sorted_values = sorted(
                    values.items(),
                    key=lambda entry: (-entry[1]['count'], entry[0]),
                )

                merged_values = {}
                absorbed = set()

                for index, (canonical, data) in enumerate(sorted_values):
                    if canonical in absorbed:
                        continue

                    # Copy the nested variations too, so the input data is
                    # never mutated.
                    canonical_data = dict(data)
                    canonical_data['variations'] = dict(data['variations'])

                    # Only look at later entries. Earlier ones are already
                    # canonical (or absorbed); revisiting them could fold an
                    # emitted canonical into this one and double-count it,
                    # because SequenceMatcher.ratio() is order-dependent.
                    for other_normalized, other_data in sorted_values[index + 1:]:
                        if other_normalized in absorbed:
                            continue
                        if not self._is_similar_enough(canonical, other_normalized):
                            continue

                        canonical_data['count'] += other_data['count']
                        variations = canonical_data['variations']
                        for variant, occurrences in other_data['variations'].items():
                            variations[variant] = variations.get(variant, 0) + occurrences

                        # Keep the better-cased display value.
                        if (self._is_well_cased(other_data['value'])
                                and not self._is_well_cased(canonical_data['value'])):
                            canonical_data['value'] = other_data['value']

                        absorbed.add(other_normalized)
                        merge_count += 1

                        self.stdout.write(
                            self.style.WARNING(
                                f"  Merging: '{other_normalized}' -> '{canonical}'"
                            )
                        )

                    merged_values[canonical] = canonical_data

                merged_data[business_id][field_name] = merged_values

        if merge_count > 0:
            self.stdout.write(f'Merged {merge_count} similar values')
        else:
            self.stdout.write('No similar values found to merge')

        return merged_data

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def _show_dry_run_results(self, merged_data, original_data):
        """Report, per field, how many distinct values merging would remove."""
        self.stdout.write('\n' + '=' * 50)
        self.stdout.write('DRY RUN RESULTS')
        self.stdout.write('=' * 50)

        total_original = 0
        total_merged = 0

        for business_id, fields in merged_data.items():
            # .get() avoids inserting keys into the defaultdict input.
            original_fields = original_data.get(business_id, {})
            for field_name, values in fields.items():
                original_count = len(original_fields.get(field_name, {}))
                merged_count = len(values)

                total_original += original_count
                total_merged += merged_count

                if original_count != merged_count:
                    reduction = original_count - merged_count
                    self.stdout.write(
                        f"{field_name}: {original_count} -> {merged_count} "
                        f"({reduction} merged)"
                    )

        self.stdout.write(f"\nTotal reduction: {total_original} -> {total_merged}")
        self.stdout.write('=' * 50 + '\n')

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _cache_queryset(self, business_id=None):
        """Return cache rows, restricted to *business_id* when given."""
        queryset = FieldValueCache.objects.all()
        if business_id:
            queryset = queryset.filter(business_id=business_id)
        return queryset

    def _clear_cache(self, business_id=None):
        """Delete cache rows (for one business or all) and report how many."""
        deleted, _ = self._cache_queryset(business_id).delete()
        if business_id:
            self.stdout.write(f'Cleared {deleted} cache entries for business {business_id}')
        else:
            self.stdout.write(f'Cleared {deleted} cache entries')

    def _load_existing_entries(self, business_id):
        """Return existing cache rows of one business keyed by ``(field_name, normalized)``.

        One query per business replaces one query per cache value. If duplicate
        rows exist, the lowest primary key wins.
        """
        existing = {}
        rows = (
            FieldValueCache.objects
            .filter(business_id=business_id)
            .order_by('pk')
            .iterator(chunk_size=self.BATCH_SIZE)
        )
        for entry in rows:
            existing.setdefault((entry.field_name, entry.normalized), entry)
        return existing

    def _flush_creates(self, entries, suffix=''):
        """Bulk-insert *entries* (conflicting rows are skipped) and report the batch."""
        with transaction.atomic():
            FieldValueCache.objects.bulk_create(entries, ignore_conflicts=True)
        self.stdout.write(f'Inserted {len(entries)} new entries{suffix}')

    def _flush_updates(self, entries, suffix=''):
        """Bulk-update *entries* on ``UPDATE_FIELDS`` and report the batch."""
        with transaction.atomic():
            FieldValueCache.objects.bulk_update(entries, self.UPDATE_FIELDS)
        self.stdout.write(f'Updated {len(entries)} existing entries{suffix}')

    def _save_to_cache(self, merged_data, total_items):
        """Create or refresh a ``FieldValueCache`` row for every merged value.

        Rows matching an existing ``(business_id, field_name, normalized)`` are
        bulk-updated; the rest are bulk-created. Writes are flushed in batches
        of ``BATCH_SIZE``.
        """
        self.stdout.write('Building cache entries...')

        to_create = []
        to_update = []
        total_entries = 0

        for business_id, fields in merged_data.items():
            existing = self._load_existing_entries(business_id)

            for field_name, values in fields.items():
                for normalized, data in values.items():
                    brand_id = str(data['brand_id']) if data['brand_id'] else None
                    entry = existing.get((field_name, normalized))

                    if entry is not None:
                        entry.count = data['count']
                        entry.value = data['value']
                        entry.brand_id = brand_id
                        entry.category_id = data['category_id']
                        to_update.append(entry)
                    else:
                        to_create.append(
                            FieldValueCache(
                                business_id=business_id,
                                field_name=field_name,
                                value=data['value'],
                                normalized=normalized,
                                count=data['count'],
                                brand_id=brand_id,
                                category_id=data['category_id'],
                            )
                        )

                    total_entries += 1

                    if len(to_create) >= self.BATCH_SIZE:
                        self._flush_creates(to_create, '...')
                        to_create = []

                    if len(to_update) >= self.BATCH_SIZE:
                        self._flush_updates(to_update, '...')
                        to_update = []

        if to_create:
            self._flush_creates(to_create)
        if to_update:
            self._flush_updates(to_update)

        self.stdout.write(self.style.SUCCESS(
            f'Successfully processed {total_entries} cache entries from {total_items} inventory items'
        ))

    def show_statistics(self, business_id=None):
        """Print cache totals and the ten fields with the most distinct values."""
        from django.db.models import Count, Sum, Max

        queryset = self._cache_queryset(business_id)

        stats = queryset.aggregate(
            total_entries=Count('id'),
            total_uses=Sum('count'),
            max_uses=Max('count'),
        )

        field_breakdown = queryset.values('field_name').annotate(
            entry_count=Count('id'),
            total_uses=Sum('count'),
        ).order_by('-entry_count')[:10]

        self.stdout.write('\n' + '=' * 50)
        self.stdout.write('CACHE STATISTICS')
        self.stdout.write('=' * 50)
        self.stdout.write(f"Total unique values: {stats['total_entries']}")
        # Sum/Max return None on an empty table; show 0 instead.
        self.stdout.write(f"Total usage count: {stats['total_uses'] or 0}")
        self.stdout.write(f"Max usage for single value: {stats['max_uses'] or 0}")

        self.stdout.write('\nTop Fields:')
        for field in field_breakdown:
            self.stdout.write(
                f"  {field['field_name']}: "
                f"{field['entry_count']} unique values, "
                f"{field['total_uses']} total uses"
            )
        self.stdout.write('=' * 50 + '\n')