#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Performance Optimization Utilities
Provides parallel processing, caching, batch operations, and adaptive timeouts
"""

import time
import hashlib
import logging
from typing import Callable, Any, Optional, Dict, List, Tuple
from functools import wraps
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock


class SimpleCache:
    """Simple thread-safe in-memory cache with a time-to-live per entry.

    Each entry is stored in ``self.cache`` as ``(value, timestamp)`` where
    ``timestamp`` comes from ``time.monotonic()``. A monotonic clock is used
    instead of the naive local ``datetime.now()`` so that wall-clock jumps
    (DST transitions, NTP corrections, manual clock changes) cannot make
    entries expire early or outlive their TTL.

    Note: ``get`` returns ``None`` both for a miss and for a cached ``None``
    value; callers that need to cache ``None`` must wrap their values.
    """

    def __init__(self, ttl: int = 3600):
        """Initialize cache.

        Args:
            ttl: Time to live in seconds
        """
        self.cache: Dict[str, Tuple[Any, float]] = {}
        self.ttl = ttl
        self.lock = Lock()

    def _generate_key(self, *args, **kwargs) -> str:
        """Generate cache key from arguments.

        Keys are derived from the string form (repr) of the arguments, so
        distinct objects with identical string forms share a key.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Cache key string (hex MD5 digest)
        """
        # kwargs are sorted so call-site keyword order does not change the key.
        key_str = str(args) + str(sorted(kwargs.items()))
        # MD5 only compacts the key to a fixed length; it is not a security use.
        return hashlib.md5(key_str.encode()).hexdigest()

    def _is_expired(self, timestamp: float, now: float) -> bool:
        """Return True when an entry stamped at ``timestamp`` is past its TTL."""
        return now - timestamp >= self.ttl

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache.

        Expired entries are evicted lazily when read.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found or expired
        """
        with self.lock:
            entry = self.cache.get(key)
            if entry is None:
                return None
            value, timestamp = entry
            if self._is_expired(timestamp, time.monotonic()):
                del self.cache[key]
                return None
            return value

    def set(self, key: str, value: Any):
        """Set value in cache, (re)starting its TTL now.

        Args:
            key: Cache key
            value: Value to cache
        """
        with self.lock:
            self.cache[key] = (value, time.monotonic())

    def clear(self):
        """Clear all cache entries."""
        with self.lock:
            self.cache.clear()

    def cleanup_expired(self):
        """Remove all expired entries."""
        with self.lock:
            now = time.monotonic()
            expired_keys = [
                key for key, (_, timestamp) in self.cache.items()
                if self._is_expired(timestamp, now)
            ]
            # Collected first: deleting while iterating the dict would raise.
            for key in expired_keys:
                del self.cache[key]


class AdaptiveTimeout:
    """Adaptive timeout based on historical performance.

    The timeout tracks ``factor`` times the mean of the most recent
    ``history_size`` recorded durations, clamped to
    ``[min_timeout, max_timeout]``.
    """

    def __init__(self, min_timeout: float = 5.0, max_timeout: float = 60.0,
                 initial_timeout: float = 30.0, factor: float = 2.0,
                 history_size: int = 10):
        """Initialize adaptive timeout.

        Args:
            min_timeout: Minimum timeout in seconds
            max_timeout: Maximum timeout in seconds
            initial_timeout: Initial timeout value (clamped to the bounds)
            factor: Multiplier applied to the average recorded duration.
                Defaults to 2.0, the multiplier this class has always applied.
            history_size: Number of most recent durations to average

        Raises:
            ValueError: If history_size is less than 1.
        """
        if history_size < 1:
            raise ValueError(f"history_size must be >= 1, got {history_size}")
        self.min_timeout = min_timeout
        self.max_timeout = max_timeout
        self.factor = factor
        self.history_size = history_size
        self.history: List[float] = []
        self.lock = Lock()
        self.current_timeout = self._clamp(initial_timeout)

    def _clamp(self, value: float) -> float:
        """Restrict ``value`` to the configured [min_timeout, max_timeout] range."""
        return max(self.min_timeout, min(self.max_timeout, value))

    def get_timeout(self) -> float:
        """Get current timeout value.

        Returns:
            Timeout in seconds
        """
        with self.lock:
            return self.current_timeout

    def record_duration(self, duration: float):
        """Record operation duration and adjust timeout.

        Args:
            duration: Operation duration in seconds
        """
        with self.lock:
            self.history.append(duration)
            # Keep only the most recent `history_size` measurements.
            excess = len(self.history) - self.history_size
            if excess > 0:
                del self.history[:excess]

            avg_duration = sum(self.history) / len(self.history)
            self.current_timeout = self._clamp(avg_duration * self.factor)


def cached(ttl: int = 3600, cache_instance: Optional[SimpleCache] = None):
    """Decorator for caching function results.

    Results are stored wrapped in a 1-tuple, so ``None`` return values are
    cached too (``SimpleCache.get`` uses ``None`` to signal a miss). Keys are
    namespaced by the decorated function, so several functions may safely
    share one ``cache_instance``. Concurrent first calls with the same
    arguments may each compute the result; the last one stored wins.

    Args:
        ttl: Time to live in seconds (ignored when cache_instance is given)
        cache_instance: Optional cache instance to use

    Returns:
        Decorated function
    """
    cache = cache_instance if cache_instance is not None else SimpleCache(ttl=ttl)

    def decorator(func: Callable) -> Callable:
        # Qualified name plus id() keeps distinct functions apart, including
        # same-named lambdas; the closure holds `func`, so its id stays unique
        # while this wrapper exists.
        qualname = getattr(func, "__qualname__", None) or repr(func)
        namespace = f"{getattr(func, '__module__', None)}.{qualname}:{id(func)}"

        @wraps(func)
        def wrapper(*args, **kwargs):
            key = cache._generate_key(namespace, *args, **kwargs)
            hit = cache.get(key)
            if hit is not None:
                return hit[0]

            result = func(*args, **kwargs)
            cache.set(key, (result,))
            return result
        return wrapper
    return decorator


def parallel_process(items: List[Any], func: Callable, max_workers: int = 2, 
                    timeout: Optional[float] = None) -> List[Tuple[int, Any, Optional[Exception]]]:
    """Process items in parallel on a thread pool.

    Exceptions raised by ``func`` are captured per item and returned rather
    than raised.

    Args:
        items: List of items to process
        func: Function to process each item, called as ``func(item, index)``
        max_workers: Maximum number of worker threads
        timeout: Optional overall deadline in seconds for collecting all
            results. This is not a per-item limit. If it expires,
            ``concurrent.futures.TimeoutError`` is raised, items not yet
            started are cancelled, and only already-running items are waited
            for before the error propagates.

    Returns:
        List of tuples (index, result, error), sorted by index
    """
    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(func, item, idx): idx for idx, item in enumerate(items)}

        try:
            for future in as_completed(futures, timeout=timeout):
                idx = futures[future]
                try:
                    # The future is already done, so no timeout is needed here.
                    result = future.result()
                except Exception as e:
                    results.append((idx, None, e))
                else:
                    results.append((idx, result, None))
        finally:
            # If as_completed() raised (deadline hit or interrupt), drop
            # queued work so the executor's shutdown waits only for running
            # items. cancel() is a no-op on futures that are already done.
            for pending in futures:
                pending.cancel()

    # Sort by index to maintain input order.
    results.sort(key=lambda x: x[0])
    return results


def batch_process(items: List[Any], func: Callable, batch_size: int = 10) -> List[Any]:
    """Process items in consecutive batches and collect the results.

    A ``list`` returned by ``func`` is spliced into the output; any other
    return value is appended as a single element. When ``func`` raises, the
    error is logged and one ``None`` per item of that batch is appended.
    NOTE: if ``func`` returns a non-list value, successful batches contribute
    one element while failed ones contribute ``len(batch)`` Nones.

    Args:
        items: List of items to process
        func: Function to process each batch (takes list of items)
        batch_size: Number of items per batch (must be >= 1)

    Returns:
        List of results

    Raises:
        ValueError: If batch_size is less than 1.
    """
    if batch_size < 1:
        # range() already rejects 0, but a negative size would silently
        # produce no batches and return [].
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results: List[Any] = []

    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        try:
            batch_result = func(batch)
        except Exception as e:
            logging.error("Error processing batch %d: %s", start // batch_size + 1, e)
            # Add None for each item in the failed batch.
            results.extend([None] * len(batch))
            continue

        if isinstance(batch_result, list):
            results.extend(batch_result)
        else:
            results.append(batch_result)

    return results


def measure_time(func: Callable) -> Callable:
    """Decorator to measure function execution time.

    Timing uses ``time.perf_counter()``, a monotonic high-resolution clock
    intended for measuring intervals; ``time.time()`` can jump when the
    system clock is adjusted and may even yield negative durations.
    Exceptions raised by ``func`` propagate unchanged, and no duration is
    returned in that case.

    Returns:
        Decorated function that returns (result, duration_in_seconds)
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start
        return result, duration
    return wrapper


def retry_with_backoff(max_retries: int = 3, base_delay: float = 1.0, 
                      max_delay: float = 10.0, exponential_base: float = 2.0,
                      exceptions: Tuple[type, ...] = (Exception,)):
    """Decorator for retrying with exponential backoff.

    Args:
        max_retries: Maximum number of attempts in total, including the first
            call. Values below 1 are treated as 1, so the function is always
            called at least once.
        base_delay: Delay in seconds before the second attempt
        max_delay: Maximum delay in seconds between attempts
        exponential_base: Growth factor of the delay per attempt
        exceptions: Exception types that trigger a retry; any other exception
            propagates immediately. Defaults to ``(Exception,)``.

    Returns:
        Decorated function. After the final failed attempt, the last
        exception is re-raised.
    """
    attempts = max(1, max_retries)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Each iteration either returns or raises, so the loop never
            # falls through.
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt >= attempts:
                        raise
                    delay = min(base_delay * (exponential_base ** (attempt - 1)), max_delay)
                    time.sleep(delay)
        return wrapper
    return decorator

