"""
Weather Data Provider Utilities

This module contains utility functions for processing weather data, including:
- Data parsing and validation
- Time series operations
- Data smoothing and normalization
- Visualization helpers

These utilities support the main weather data provider by handling common
data processing tasks and mathematical operations.
"""

from __future__ import annotations

import base64
import io
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, TypedDict, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from numpy.typing import NDArray

# Select the non-interactive 'Agg' backend so figures can be rendered without a
# GUI toolkit being available.
# NOTE(review): this call runs after `matplotlib.pyplot` has already been
# imported above — confirm the supported matplotlib versions still switch the
# backend at this point.
matplotlib.use('Agg')  # Use non-interactive backend to avoid GUI warnings


# ---- Typed return payloads ---------------------------------------------------

class SignalRecord(TypedDict):
    """One signal of a record: its sample values plus the matching timestamps."""
    # Sample values of the signal; the element type is deliberately left generic.
    data: NDArray[Any]                # e.g., float, int; stays generic
    # Name of the signal.
    data_name: str
    # Timestamps of the samples. trim_data() slices `data` and `times` with the
    # same indices, so the two are presumably the same length — TODO confirm.
    times: NDArray[Any]               # typically datetime64[ns, UTC]
    # NOTE(review): presumably True when the signal could not be fetched or
    # processed; nothing in this module sets or reads it — verify against the provider.
    errored: bool

class FetchDataResult(TypedDict):
    """Payload describing one fetched record and the signals it contains."""
    # Identifier of the record. NOTE(review): format is not visible in this module.
    id: str
    # Signals belonging to the record; trim_data() iterates over this list and
    # trims each entry's `data` and `times`.
    signals: List[SignalRecord]
    # NOTE(review): presumably the names of signals that failed to load —
    # confirm against the code that builds this payload.
    errored_signals: List[str]


def parse_value(value: Any) -> Union[datetime, float, int, str, None]:
    """
    Parse a value to its most appropriate data type (datetime, float, int, or original).

    Conversions are attempted in the following priority order:
    1. Datetime (numeric inputs are read as epoch timestamps with automatic
       unit detection; everything else goes through ``pd.to_datetime``)
    2. Float
    3. Integer
    4. Original value (if all conversions fail)

    Args:
        value: Input value to parse (can be string, number, or datetime)

    Returns:
        A UTC-aware ``datetime`` when the value can be read as one, otherwise
        a float or int, otherwise the value unchanged. Empty strings yield None.

    Timestamp Unit Detection (based on the number of *integer* digits, so the
    sign and any fractional part of a float do not affect the result):
        - 10 digits: seconds since epoch
        - 13 digits: milliseconds since epoch
        - 16 digits: microseconds since epoch
        - 19 digits: nanoseconds since epoch
        - Default: milliseconds if uncertain
    """
    unit_by_digit_count = {10: "s", 13: "ms", 16: "us", 19: "ns"}

    # First, try to parse as datetime
    try:
        if isinstance(value, (int, float)):
            # Count only the digits of the integer part. len(str(value)) would
            # also count "-", "." and the fraction, so 1700000000.0 (seconds)
            # would be misread as milliseconds.
            try:
                digit_count = len(str(abs(int(value))))
            except (ValueError, OverflowError):
                # NaN / infinity have no integer part; use the default unit.
                digit_count = 0
            unit = unit_by_digit_count.get(digit_count, "ms")

            # Use UTC timezone to avoid timezone conversion issues
            dt = pd.to_datetime(value, unit=unit, utc=True)
        else:
            # Strings or already-datetime objects
            dt = pd.to_datetime(value, utc=True)

        if pd.isna(dt):  # NaT (or None) means "not a datetime"
            raise ValueError
        return dt.to_pydatetime()
    except (ValueError, TypeError, OverflowError):
        pass

    # Not a datetime: fall back to a plain number. float() is tried first, so
    # a numeric string such as "42" comes back as 42.0; int() only matters for
    # inputs that float() rejects.
    for convert in (float, int):
        try:
            return convert(value)
        except (ValueError, TypeError, OverflowError):
            continue

    # If all conversions fail, return the original value (None for "")
    if isinstance(value, str) and len(value) == 0:
        return None
    return value


def find_nearest_datetime_index(
    arr: NDArray[Any],
    target_datetime: Optional[datetime]
) -> Optional[int]:
    """
    Find the index of the nearest datetime in an array to a target datetime.

    Both sides are compared as integer nanoseconds since the Unix epoch. The
    array is converted with ``astype("datetime64[ns]")``; the target is
    converted with integer arithmetic so that no float rounding occurs and the
    machine's local timezone never influences the result.

    Args:
        arr (array-like): Array of datetime values to search
        target_datetime (datetime): Target datetime to find nearest match for

    Returns:
        int or None: Index of nearest datetime, or None if the target is None,
        the array is empty, or a conversion error occurs (the error is printed)

    Timezone Handling:
        - Array and target both timezone-aware: target is used as-is
        - Otherwise the target's wall-clock fields are interpreted as UTC
          (any timezone on the target is dropped first), matching the
          epoch-based values of a naive datetime64 array
    """
    if target_datetime is None:
        return None

    arr = np.asarray(arr)
    if arr.size == 0:
        return None

    # Best-effort check whether the array holds timezone-aware datetimes; any
    # failure to build a Timestamp from the first element counts as "naive".
    try:
        arr_has_timezone = pd.Timestamp(arr[0]).tzinfo is not None
    except Exception:
        arr_has_timezone = False

    # Convert target_datetime to nanoseconds since epoch
    try:
        if not (arr_has_timezone and target_datetime.tzinfo is not None):
            # Pin the wall-clock time to UTC. Calling .timestamp() on a naive
            # datetime would interpret it in the machine's local timezone and
            # shift the result by the local UTC offset.
            target_datetime = target_datetime.replace(tzinfo=timezone.utc)

        offset = target_datetime - datetime(1970, 1, 1, tzinfo=timezone.utc)
        target_ns = (
            (offset.days * 86_400 + offset.seconds) * 1_000_000_000
            + offset.microseconds * 1_000
            # pandas Timedelta (from a pd.Timestamp target) also carries nanoseconds
            + getattr(offset, "nanoseconds", 0)
        )
    except (AttributeError, TypeError, ValueError, OverflowError) as e:
        print(f"Error converting target datetime to timestamp: {e}")
        return None

    try:
        # Convert datetime64 array to nanoseconds since epoch
        arr_ns = arr.astype("datetime64[ns]").astype(np.int64)

        # Find the index of the nearest value; return a plain Python int
        return int(np.abs(arr_ns - target_ns).argmin())
    except Exception as e:
        print(f"Error finding nearest datetime index: {e}")
        return None


def simple_moving_average(
    _record_name: str,
    _signal_name: str,
    raw_signal: NDArray[Any],
    _times: NDArray[Any],
    parameters: Dict[str, Any]
) -> NDArray[Any]:
    """
    Apply simple moving average smoothing to a 1D array.

    The data is convolved with a uniform window and the edges are padded so
    the result has the same length as the input. The window size is clamped
    to the range [1, len(data)].

    Args:
        _record_name (str): The name of the record being analyzed (unused)
        _signal_name (str): The name of the signal (unused)
        raw_signal (array-like): Input data to smooth
        _times (array-like): Time array of the record (unused)
        parameters (dict): Dictionary containing smoothing parameters

    Parameters Dictionary Keys:
        - "simple_moving_average_moving_average_window_size" or "moving_average_window_size": int
          Window size for moving average (default: 1, min: 1, max: data length).
          The longer key takes precedence when both are present.

    Returns:
        numpy.array: Smoothed data with same length as input (an empty float
        array when the input is empty)

    Edge Handling:
        - Uses 'valid' convolution then pads edges with first/last smoothed values
        - Left padding: first smoothed value repeated (window_size - 1) // 2 times
        - Right padding: last smoothed value repeated for the remaining slots
    """
    data = np.asarray(raw_signal)

    # Nothing to smooth. Without this guard the window would be clamped to 0
    # and np.convolve would raise ValueError on the empty operands.
    if data.size == 0:
        return np.array([], dtype=float)

    # Try both parameter naming conventions for compatibility
    window_size = parameters.get(
        "simple_moving_average_moving_average_window_size",
        parameters.get("moving_average_window_size", 1),
    )
    window_size = int(window_size)

    if window_size < 1:
        print("Window size must be >= 1")
        window_size = 1

    if window_size > data.shape[0]:
        print("Window size must be less than total length")
        window_size = data.shape[0]
        print(f"Using window size equal to full data length {window_size}")

    # Uniform weights: every sample in the window contributes equally
    weights = np.ones(window_size) / window_size

    # 'valid' keeps only positions where the window fully overlaps the data,
    # which shortens the output by window_size - 1 samples
    smoothed = np.convolve(data, weights, mode="valid")

    # Pad the edges to restore the original length
    padding = window_size - 1
    left_pad = np.full(padding // 2, smoothed[0])
    right_pad = np.full(padding - padding // 2, smoothed[-1])

    return np.concatenate([left_pad, smoothed, right_pad])


def robust_scaling(
    record_name: str,
    signal_name: str,
    raw_signal: NDArray[Any],
    times: NDArray[Any],
    parameters: Dict[str, Any]
) -> NDArray[Any]:
    """
    Scale a 1D array by its median and interquartile range.

    Centering on the median and dividing by the IQR (75th minus 25th
    percentile) keeps the result less sensitive to outliers than a
    mean / standard-deviation scaling would be.

    Formula: (x - median) / IQR

    Args:
        record_name (str): The name of the record being analyzed (unused)
        signal_name (str): The name of the signal (unused)
        raw_signal (array-like): Input data to scale
        times (array-like): Time array of the record (unused)
        parameters (dict): The dictionary of available parameters (unused)

    Returns:
        numpy.array: Robustly scaled data

    Edge Case Handling:
        - Empty input: returns an empty array
        - Zero IQR and every value equal to the first: zeros shaped and typed like the input
        - Zero IQR otherwise: the standard deviation replaces the IQR
          (or 1.0 should that be zero as well)
        - Single element that did not hit the cases above: returns [0.0]
        - Result is clipped to the range [-1e6, 1e6]
    """
    values = np.asarray(raw_signal)

    if not values.size:
        return np.array([])

    # Robust location and spread estimates
    center = np.median(values)
    lower_quartile, upper_quartile = np.percentile(values, [25, 75])
    spread = upper_quartile - lower_quartile

    if spread == 0:
        is_constant = np.all(values == values[0])
        if is_constant:
            return np.zeros_like(values)

        # Quartiles coincide although the data varies: use the standard
        # deviation as the divisor, guarding against a zero there too.
        std_dev = np.std(values)
        spread = std_dev if std_dev != 0 else 1.0

    if values.size == 1:
        # A lone value is its own median, so it scales to zero
        return np.array([0.0])

    # Scale, then bound the result so extreme ratios stay finite
    return np.clip((values - center) / spread, -1e6, 1e6)


def matplotlib_to_plotly_base64(fig_mpl: matplotlib.figure.Figure) -> str:
    """
    Convert matplotlib figure to base64-encoded PNG for embedding in Plotly.

    The figure is rendered to an in-memory PNG, base64-encoded and wrapped in
    a ``data:image/png;base64,...`` URL.

    Args:
        fig_mpl (matplotlib.figure.Figure): Matplotlib figure to convert

    Returns:
        str: Base64-encoded PNG data URL

    Raises:
        Whatever ``fig_mpl.savefig`` raises; the figure is still closed.

    Memory Management:
        - The figure is always closed via ``plt.close``, even when rendering
          fails, so it cannot be used again after this call
        - The BytesIO buffer is released by its context manager
        - Rendered at 150 DPI with a tight bounding box
    """
    try:
        with io.BytesIO() as buf:
            fig_mpl.savefig(buf, format='png', bbox_inches='tight', dpi=150)
            # getvalue() returns the whole buffer regardless of stream position
            img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # Close the figure on success and on failure to free its memory
        plt.close(fig_mpl)
    return f"data:image/png;base64,{img_base64}"


def _parse_trim_time(trim_value: Optional[str]) -> Optional[datetime]:
    """Parse one trim specification into a datetime, or None when absent or unusable."""
    if trim_value is None or (isinstance(trim_value, str) and not trim_value):
        return None
    parsed = parse_value(trim_value)
    if not isinstance(parsed, datetime):
        # parse_value falls back to float / str for inputs that are not
        # timestamps; those cannot be located on a time axis.
        print(f"Could not interpret trim value as a datetime: {trim_value!r}")
        return None
    return parsed


def _times_are_tz_aware(times: Any) -> bool:
    """Return True when the first entry of *times* converts to a timezone-aware Timestamp."""
    if len(times) == 0:
        return False
    try:
        return pd.Timestamp(times[0]).tzinfo is not None
    except Exception:
        # Best-effort detection: anything pandas cannot convert counts as naive
        return False


def _match_timezone(moment: Optional[datetime], tz_aware: bool) -> Optional[datetime]:
    """Make *moment* naive or UTC-aware so it matches the awareness of the time axis."""
    if moment is None:
        return None
    if tz_aware and moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    if not tz_aware and moment.tzinfo is not None:
        return moment.replace(tzinfo=None)
    return moment


def trim_data(
    record_data: FetchDataResult,
    trim_1: Optional[str] = None,
    trim_2: Optional[str] = None
) -> FetchDataResult:
    """
    Trim record data based on time range specifications.

    For every signal in the record this function:
    1. Detects whether the signal's time axis is timezone-aware
    2. Aligns the parsed trim values with that awareness
    3. Finds the nearest time indices for the trim values
    4. Slices both "data" and "times" with those indices

    Args:
        record_data (FetchDataResult): Record data dictionary containing signals
        trim_1 (str, optional): Start time for trimming (anything parse_value
            can turn into a datetime)
        trim_2 (str, optional): End time for trimming (same formats)

    Returns:
        FetchDataResult: The same record_data object; its signal dictionaries
        are updated in place with the trimmed "data" and "times"

    Record Data Structure:
        Expected: {"signals": [list of signal dictionaries]}
        Signal: {"data": array, "times": array, ...}

    Trimming Logic:
        - trim_1 only: trim from start time to end of data
        - trim_2 only: trim from beginning of data to end time
        - Both: trim to specified time range
        - Neither: return data unchanged
        - A trim value that cannot be parsed as a datetime is reported and
          then treated as if it had not been given
        - The end index follows slice semantics: the sample nearest to
          trim_2 is itself excluded
    """
    # The trim strings do not depend on the signal, so parse them only once
    start_time = _parse_trim_time(trim_1)
    end_time = _parse_trim_time(trim_2)

    signals = []
    for signal in record_data["signals"]:
        times = signal["times"]
        tz_aware = _times_are_tz_aware(times)

        start_index = None
        end_index = None
        if start_time is not None:
            start_index = find_nearest_datetime_index(
                times, _match_timezone(start_time, tz_aware)
            )
        if end_time is not None:
            end_index = find_nearest_datetime_index(
                times, _match_timezone(end_time, tz_aware)
            )

        # Leave the arrays untouched when no usable trim index was found
        if start_index is not None or end_index is not None:
            window = slice(start_index, end_index)
            signal["data"] = signal["data"][window]
            signal["times"] = times[window]

        signals.append(signal)

    record_data["signals"] = signals
    return record_data