"""
Additive Manufacturing (AM) Data Provider

This provider ingests layer CSVs for 4 parts (part01–part04), each containing
rows of measured points and 40 feature columns (per the attached spec).

Design choices:
- Treat each layer CSV as a single "record"; record_id encodes part and layer
  as "partXX_LYYYY" (e.g., part01_L0001). This is stable, readable, and
  unique across the dataset.
- Since CSV rows are ordered scan points, create a numeric time axis based
  on the row index (0..N-1). Set is_date=False to use numeric mode.
- Provide trim_data (reused from am_utilities) that trims by numeric index
  (consistent with GA offline).
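- Provide four custom graphers:
  1) layer_trend_by_part: aggregate a chosen feature (per-layer mean) across
     layers for each part and plot the trend against layer index (1..250)
  2) laser_path_map: XY scan path of a single layer, colored by a feature
  3) melt_pool_threshold_scatter: melt pool length vs width for one threshold,
     with marker size and color taken from the melt pool area
  4) layer_histogram_heatmap: per-layer histograms of one feature for a part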

File structure expected:
apps/additive_manufacturing/am_data/
  ├── part01/ L0001.csv ... L0250.csv
  ├── part02/ ...
  ├── part03/ ...
  └── part04/ ...

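CSV assumption:
- No header row. Each CSV is expected to have 40 columns, which receive the
  canonical names from a fixed schema (see AM_COLUMN_SCHEMA below); a file
  with any other column count gets generic names feature_01..feature_NN
  instead. Time is synthesized from the row index.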
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import plotly.graph_objects as go

# Reuse robust utilities from GA offline (numeric trimming and smoothing/normalizing)
from am_utilities import (
    trim_data,
    simple_moving_average,
    robust_scaling,
)

from template_layouts import layout_options
from Labeler.plotly_themes import get_theme_template


# Part sub-folder names under the data root: part01 .. part04.
PART_DIR_NAMES = [f"part{index:02d}" for index in range(1, 5)]

# Canonical 40-column schema, in the column order of the provided spec tables.
# Each spec section is written out as its own tuple and the sections are
# flattened into a single list, so every column name stays greppable.
AM_COLUMN_SCHEMA = [
    column_name
    for section in (
        # Columns 1–10: part/time, then command (cmd_*) and real (real_*)
        # laser position, power and scan speed
        (
            "part_number",
            "build_time_us",
            "cmd_laser_pos_x_mm",
            "cmd_laser_pos_y_mm",
            "cmd_laser_power_W",
            "cmd_scan_speed_mm_per_s",
            "real_laser_pos_x_mm",
            "real_laser_pos_y_mm",
            "real_laser_power_W",
            "real_scan_speed_mm_per_s",
        ),
        # Columns 11–19: melt pool length/width/area at thresholds 80, 100, 120
        (
            "melt_pool_length_thr80_mm",
            "melt_pool_width_thr80_mm",
            "melt_pool_area_thr80_mm2",
            "melt_pool_length_thr100_mm",
            "melt_pool_width_thr100_mm",
            "melt_pool_area_thr100_mm2",
            "melt_pool_length_thr120_mm",
            "melt_pool_width_thr120_mm",
            "melt_pool_area_thr120_mm2",
        ),
        # Columns 20–37: LWI pixel values — powder/exposure image, LED A/B/C,
        # original / 3x3 / 5x5 variants
        (
            "lwi_powder_ledA_orig",
            "lwi_powder_ledA_mean3",
            "lwi_powder_ledA_mean5",
            "lwi_powder_ledB_orig",
            "lwi_powder_ledB_mean3",
            "lwi_powder_ledB_mean5",
            "lwi_powder_ledC_orig",
            "lwi_powder_ledC_mean3",
            "lwi_powder_ledC_mean5",
            "lwi_exposure_ledA_orig",
            "lwi_exposure_ledA_mean3",
            "lwi_exposure_ledA_mean5",
            "lwi_exposure_ledB_orig",
            "lwi_exposure_ledB_mean3",
            "lwi_exposure_ledB_mean5",
            "lwi_exposure_ledC_orig",
            "lwi_exposure_ledC_mean3",
            "lwi_exposure_ledC_mean5",
        ),
        # Columns 38–40: XCT voxel values
        (
            "xct_voxel_orig",
            "xct_voxel_mean3",
            "xct_voxel_mean5",
        ),
    )
    for column_name in section
]


def _list_layer_files(data_root: Path, part_dir_name: str) -> List[Path]:
    """Return the part's ``L*.csv`` files sorted by path, or ``[]`` if the part folder is absent."""
    folder = data_root / part_dir_name
    return sorted(folder.glob("L*.csv")) if folder.exists() else []


def _format_record_id(part_dir_name: str, csv_path: Path) -> str:
    """Build a record id ``<part>_<layer stem>``, e.g. ``part01`` + ``L0001.csv`` -> ``part01_L0001``."""
    return "{}_{}".format(part_dir_name, csv_path.stem)


def _parse_record_id(record_id: str) -> Tuple[str, str]:
    """
    Split a ``partXX_LYYYY`` record id into ``(part_dir_name, layer_stem)``.

    Only the first underscore separates the two parts. An id with no
    underscore is treated as a bare layer stem belonging to the first part
    in ``PART_DIR_NAMES``.
    """
    head, separator, tail = record_id.partition("_")
    if not separator:
        # Malformed id: fall back to the first part and keep the whole id as the stem.
        return PART_DIR_NAMES[0], record_id
    return head, tail


def _load_layer_csv(data_root: Path, part_dir_name: str, layer_stem: str) -> pd.DataFrame:
    """
    Read ``<data_root>/<part_dir_name>/<layer_stem>.csv`` and name its columns.

    The CSV is read without a header row. When it has exactly
    ``len(AM_COLUMN_SCHEMA)`` columns they get the canonical schema names;
    otherwise they are named ``feature_01 .. feature_NN``.

    ``data_root`` may be a ``Path`` or a path string: it is coerced with
    ``Path()`` because some callers pass the coordinator's ``data_folder``
    through unchanged.

    Raises whatever ``pandas.read_csv`` raises for a missing, empty or
    unparseable file.
    """
    csv_path = Path(data_root) / part_dir_name / f"{layer_stem}.csv"
    # No header present → read as raw and assign names ourselves.
    df = pd.read_csv(csv_path, header=None)
    if df.shape[1] == len(AM_COLUMN_SCHEMA):
        df.columns = AM_COLUMN_SCHEMA
    else:
        # Unexpected width: generic 1-based names so the data is still usable.
        df.columns = [f"feature_{i:02d}" for i in range(1, df.shape[1] + 1)]
    return df


def _build_signal_records_from_df(df: pd.DataFrame, signals: List[str]) -> Tuple[List[Dict], List[str]]:
    """
    Package the requested columns of ``df`` as signal dictionaries.

    All records share one time axis, ``np.arange(len(df))`` (the row index).
    Each record holds the column values (``data``), the signal name
    (``data_name``), that axis (``times``), an empty ``units`` dict and
    ``dims == ("times",)``.

    Returns ``(records, missing)``: ``records`` follows the order of
    ``signals``; ``missing`` lists, in order, the names not found in ``df``.
    """
    row_axis = np.arange(len(df))
    records: List[Dict] = []
    missing: List[str] = []

    for name in signals:
        if name not in df.columns:
            missing.append(name)
            continue
        records.append(
            {
                "data": df[name].to_numpy(),
                "data_name": name,
                "times": row_axis,
                "units": {},
                "dims": ("times",),
            }
        )
    return records, missing


def _extract_all_signals(data_root: Path) -> List[str]:
    """
    Return the column names of the first layer CSV found, scanning parts in order.

    Parts are visited in ``PART_DIR_NAMES`` order, and the first (sorted)
    layer file of the first non-empty part is loaded. Returns ``[]`` when no
    part has any layer files.
    """
    for part_name in PART_DIR_NAMES:
        layer_files = _list_layer_files(data_root, part_name)
        if not layer_files:
            continue
        # Load through the shared loader so schema/fallback naming is applied.
        first_layer = _load_layer_csv(data_root, part_name, layer_files[0].stem)
        return first_layer.columns.tolist()
    return []


def laser_path_map(app_control_parameters: Dict, parameters: Dict):
    """
    Plot the XY scan path of the current layer, with markers colored by a feature.

    ``parameters`` keys (all optional; ``None``/``""`` mean "use the default"):

    - ``laser_path_map_position_source``: ``"real"`` (default) plots the
      ``real_laser_pos_*`` columns; any other value plots ``cmd_laser_pos_*``.
    - ``laser_path_map_color_feature``: column used for marker color. If unset
      or absent from the layer, the first available of real scan speed, real
      laser power and melt pool area (thr 100) is used, else the first column.
    - ``laser_path_map_point_size``: marker size (default 6); invalid values
      fall back to 6.

    Returns a plotly Figure, or a message string when no record is selected.
    """
    theme_value = app_control_parameters["theme_value"]
    record_id = app_control_parameters["record_id"]
    data_coordinator = app_control_parameters["data_coordinator"]

    # _parse_record_id cannot handle None; mirror the other graphers' message returns.
    if record_id is None:
        return "Select a record to visualize the laser path."

    # UI widgets may send None or "" for unset values.
    source = parameters.get("laser_path_map_position_source") or "real"
    color_feature = parameters.get("laser_path_map_color_feature") or None
    size_param = parameters.get("laser_path_map_point_size", 6)
    try:
        point_size = float(size_param) if size_param not in (None, "") else 6
    except (TypeError, ValueError):
        point_size = 6
    if not np.isfinite(point_size) or point_size < 0:
        # plotly rejects negative / non-finite marker sizes.
        point_size = 6

    part_dir_name, layer_stem = _parse_record_id(record_id)
    df = _load_layer_csv(Path(data_coordinator.data_folder), part_dir_name, layer_stem)

    x_col = "real_laser_pos_x_mm" if source == "real" else "cmd_laser_pos_x_mm"
    y_col = "real_laser_pos_y_mm" if source == "real" else "cmd_laser_pos_y_mm"

    # Unset or missing color feature: pick a meaningful default that exists in this layer.
    if color_feature not in df.columns:
        default_candidates = [
            "real_scan_speed_mm_per_s",
            "real_laser_power_W",
            "melt_pool_area_thr100_mm2",
        ]
        color_feature = next((cand for cand in default_candidates if cand in df.columns), None)
        if color_feature is None and len(df.columns) > 0:
            color_feature = df.columns[0]

    fig = go.Figure(
        data=[
            go.Scattergl(
                x=df.get(x_col, []),
                y=df.get(y_col, []),
                mode="markers+lines",
                line=dict(width=1, color="rgba(150,150,150,0.5)"),
                marker=dict(
                    color=df.get(color_feature, []),
                    colorscale="Viridis",
                    showscale=True,
                    size=point_size,
                ),
                name=f"{source} path",
            )
        ]
    )
    fig.update_layout(
        title=f"Laser Path Map - {record_id} (color: {color_feature})",
        xaxis_title=f"{x_col}",
        yaxis_title=f"{y_col}",
        template=get_theme_template(theme_value),
        # Equal x/y scaling so the scan geometry is not distorted.
        yaxis=dict(scaleanchor="x", scaleratio=1),
    )
    return fig


def layer_trend_by_part(app_control_parameters: Dict, parameters: Dict):
    """
    Plot, for every part with data, the per-layer mean of one feature.

    ``parameters`` keys:

    - ``layer_trend_by_part_feature_name``: column to aggregate (required).
    - ``layer_trend_by_part_layer_start`` / ``layer_trend_by_part_layer_end``:
      1-indexed inclusive layer range. Missing, empty or non-integer values
      fall back to 1 and, per part, to the highest layer number found among
      that part's ``L*.csv`` files.

    Layers whose CSV is missing, cannot be read/parsed, or lacks the feature
    are skipped. Non-numeric cells are ignored in the mean.

    Returns a plotly Figure, or a prompt string when no feature is selected.
    """

    def _int_param(key: str, default):
        # UI values may be None, "" or numeric strings; anything unparseable -> default.
        raw = parameters.get(key, default)
        if raw in (None, ""):
            return default
        try:
            return int(raw)
        except (TypeError, ValueError, OverflowError):
            return default

    data_coordinator = app_control_parameters["data_coordinator"]
    data_root = Path(data_coordinator.data_folder)

    feature_name = parameters.get("layer_trend_by_part_feature_name", None)
    if not feature_name:
        return "Select a feature to visualize layer trend."

    start_layer = _int_param("layer_trend_by_part_layer_start", 1)
    end_layer = _int_param("layer_trend_by_part_layer_end", None)

    fig = go.Figure()
    theme_value = app_control_parameters["theme_value"]

    for part in PART_DIR_NAMES:
        files = _list_layer_files(data_root, part)
        if not files:
            continue

        if end_layer is None:
            # Use the highest layer number present rather than the file count,
            # so a gap in the numbering does not hide the top layers.
            numbered = [int(p.stem[1:]) for p in files if p.stem[1:].isdecimal()]
            last_layer = max(numbered) if numbered else len(files)
        else:
            last_layer = end_layer

        x_layers: List[int] = []
        y_values: List[float] = []

        for layer_idx in range(start_layer, last_layer + 1):
            layer_stem = f"L{layer_idx:04d}"
            if not (data_root / part / f"{layer_stem}.csv").exists():
                continue
            try:
                df = _load_layer_csv(data_root, part, layer_stem)
            except (OSError, ValueError):
                # Unreadable, empty or malformed CSV (pandas parse errors are
                # ValueError subclasses): skip this layer, best effort.
                continue
            if feature_name not in df.columns:
                continue
            layer_mean = pd.to_numeric(df[feature_name], errors="coerce").mean()
            x_layers.append(layer_idx)
            y_values.append(float(layer_mean))

        if x_layers:
            fig.add_trace(
                go.Scatter(x=x_layers, y=y_values, mode="lines+markers", name=part)
            )

    fig.update_layout(
        title=f"Layer Trend by Part: {feature_name}",
        xaxis_title="Layer Index",
        yaxis_title=feature_name,
        template=get_theme_template(theme_value),
    )
    return fig


def melt_pool_threshold_scatter(app_control_parameters: Dict, parameters: Dict):
    """
    Scatter melt pool length (x) against width (y) for one threshold.

    ``parameters`` keys:

    - ``melt_pool_threshold_scatter_threshold``: threshold suffix (default
      ``"100"``) used to build the ``melt_pool_*_thr<threshold>_*`` names.
    - ``melt_pool_threshold_scatter_marker_size_scale``: size range for the
      markers (default 30; invalid, negative or non-finite values -> 30).

    Marker color is the melt pool area. Marker size is the area min-max
    scaled over its finite values to ``[2, 2 + scale]``. Rows with a
    missing/non-finite area get size 2. A constant or entirely missing area
    gives every marker size ``scale``.

    Returns a plotly Figure, or a message string when no record is selected
    or the threshold's columns are not present.
    """
    theme_value = app_control_parameters["theme_value"]
    record_id = app_control_parameters["record_id"]
    data_coordinator = app_control_parameters["data_coordinator"]

    # _parse_record_id cannot handle None; return a message like the other graphers.
    if record_id is None:
        return "Select a record to visualize melt pool metrics."

    # Robust parameter retrieval (UI may pass None or strings)
    thr_param = parameters.get("melt_pool_threshold_scatter_threshold", None)
    threshold = str(thr_param) if thr_param not in (None, "") else "100"
    size_param = parameters.get("melt_pool_threshold_scatter_marker_size_scale", 30)
    try:
        size_scale = float(size_param) if size_param not in (None, "") else 30.0
    except (TypeError, ValueError):
        size_scale = 30.0
    if not np.isfinite(size_scale) or size_scale < 0:
        size_scale = 30.0

    part_dir_name, layer_stem = _parse_record_id(record_id)
    df = _load_layer_csv(Path(data_coordinator.data_folder), part_dir_name, layer_stem)

    len_col = f"melt_pool_length_thr{threshold}_mm"
    wid_col = f"melt_pool_width_thr{threshold}_mm"
    area_col = f"melt_pool_area_thr{threshold}_mm2"

    if not all(col in df.columns for col in [len_col, wid_col, area_col]):
        return f"Selected threshold columns not found: {threshold}"

    length = pd.to_numeric(df[len_col], errors="coerce")
    width = pd.to_numeric(df[wid_col], errors="coerce")
    area = pd.to_numeric(df[area_col], errors="coerce")

    # Normalize area over its finite values for marker sizing.
    area_values = area.to_numpy(dtype=float)
    finite_area = area_values[np.isfinite(area_values)]
    if finite_area.size == 0 or finite_area.max() - finite_area.min() <= 0:
        size = np.full(len(area_values), size_scale)
    else:
        a_min = finite_area.min()
        span = finite_area.max() - a_min
        size = (area_values - a_min) / span * size_scale + 2
        # plotly rejects NaN marker sizes; draw rows with missing area at the minimum size.
        size = np.where(np.isfinite(size), size, 2.0)

    fig = go.Figure(
        data=[
            go.Scatter(
                x=length,
                y=width,
                mode="markers",
                marker=dict(size=size, color=area, colorscale="Inferno", showscale=True),
                name=f"thr {threshold}",
            )
        ]
    )
    fig.update_layout(
        title=f"Melt Pool (thr {threshold}) Length vs Width (size/color=Area)",
        xaxis_title=len_col,
        yaxis_title=wid_col,
        template=get_theme_template(theme_value),
    )
    return fig


def _collect_layer_feature_values(
    data_root: Path, part: str, feature: str, start_layer: int, end_layer: int
) -> List[Tuple[int, np.ndarray]]:
    """
    Read ``feature`` once from each existing layer CSV of ``part`` in range.

    Returns ``(layer_idx, values)`` pairs, in layer order, for every layer in
    ``[start_layer, end_layer]`` whose CSV exists and contains ``feature``.
    ``values`` keeps only finite numbers (non-numeric cells, NaN and ±inf are
    dropped) and may be empty.
    """
    collected: List[Tuple[int, np.ndarray]] = []
    for layer_idx in range(start_layer, end_layer + 1):
        layer_stem = f"L{layer_idx:04d}"
        if not (data_root / part / f"{layer_stem}.csv").exists():
            continue
        df = _load_layer_csv(data_root, part, layer_stem)
        if feature not in df.columns:
            continue
        values = pd.to_numeric(df[feature], errors="coerce").to_numpy(dtype=float)
        collected.append((layer_idx, values[np.isfinite(values)]))
    return collected


def _histogram_bin_edges(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Return ``n_bins + 1`` evenly spaced edges spanning non-empty, finite ``values``."""
    vmin = float(values.min())
    vmax = float(values.max())
    if vmin == vmax:
        # Zero-width range: pad by 0.5 on each side (as numpy.histogram does)
        # rather than a fixed 1e-6, which is lost to rounding at large magnitudes.
        vmin -= 0.5
        vmax += 0.5
    return np.linspace(vmin, vmax, n_bins + 1)


def layer_histogram_heatmap(app_control_parameters: Dict, parameters: Dict):
    """
    Heatmap of per-layer histograms of one feature for a selected part.

    ``parameters`` keys (None/""/unparseable values fall back to defaults):

    - ``layer_histogram_heatmap_part_id``: part folder (default first part).
    - ``layer_histogram_heatmap_feature_name``: column to histogram (required).
    - ``layer_histogram_heatmap_bins``: number of bins (default 30, at least 1).
    - ``layer_histogram_heatmap_layer_start`` / ``..._layer_end``: inclusive
      layer range (defaults 1 and 250).

    All layers share one set of bin edges computed from the finite values of
    every selected layer. Rows are layers (y), columns are bin centers (x).

    Returns a plotly Figure, or a message string when no feature is selected
    or no data is found.
    """
    theme_value = app_control_parameters["theme_value"]
    data_coordinator = app_control_parameters["data_coordinator"]
    data_root = Path(data_coordinator.data_folder)

    def _int_param(key: str, default: int) -> int:
        # UI values may be None, "" or numeric strings; anything unparseable -> default.
        raw = parameters.get(key, default)
        if raw in (None, ""):
            return default
        try:
            return int(raw)
        except (TypeError, ValueError, OverflowError):
            return default

    part = parameters.get("layer_histogram_heatmap_part_id", None) or PART_DIR_NAMES[0]
    feature = parameters.get("layer_histogram_heatmap_feature_name", None)
    # np.linspace needs at least two edges, i.e. one bin.
    n_bins = max(_int_param("layer_histogram_heatmap_bins", 30), 1)
    start_layer = _int_param("layer_histogram_heatmap_layer_start", 1)
    end_layer = _int_param("layer_histogram_heatmap_layer_end", 250)

    if not feature:
        return "Select a feature for histogram heatmap."

    # Single pass over the CSVs: the cached values feed both the shared bin
    # edges and the per-layer histograms.
    layer_values = _collect_layer_feature_values(data_root, part, feature, start_layer, end_layer)
    non_empty = [values for _, values in layer_values if values.size]
    if not non_empty:
        return "No data available for selected range/feature."

    bins = _histogram_bin_edges(np.concatenate(non_empty), n_bins)
    # Layers with the feature but no finite values still get an all-zero row.
    heat = np.array([np.histogram(values, bins=bins)[0] for _, values in layer_values])  # layers x bins
    layer_ticks = [layer_idx for layer_idx, _ in layer_values]
    # Center bin positions for x-axis
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fig = go.Figure(
        data=[
            go.Heatmap(
                z=heat,
                x=bin_centers,
                y=layer_ticks,
                colorscale="Viridis",
                colorbar=dict(title="Count"),
            )
        ]
    )
    fig.update_layout(
        title=f"Histogram Heatmap • {part} • {feature}",
        xaxis_title=f"{feature} (bins)",
        yaxis_title="Layer",
        template=get_theme_template(theme_value),
    )
    return fig


def get_provider(_):
    """
    Assemble the AM data provider configuration.

    Layer CSVs are looked up under ``am_data`` next to this module. The signal
    list is taken once from the first layer CSV found. The returned dict wires
    record discovery, per-record loading, the smoothing/normalizing options
    reused from ``am_utilities``, the label set and the custom graphers
    defined in this module.
    """
    data_folder = Path(__file__).parent / "am_data"

    # Determine all_possible_signals from the first available CSV
    all_possible_signals = _extract_all_signals(data_folder)
    # Default selection for every feature dropdown (first known signal, if any).
    first_signal = all_possible_signals[0] if all_possible_signals else None

    def fetch_record_ids_for_dataset_id(data_folder: Path, _dataset_id=None) -> List[str]:
        """Return a ``partXX_LYYYY`` id for every layer file, part by part in file order."""
        return [
            _format_record_id(part, csv_path)
            for part in PART_DIR_NAMES
            for csv_path in _list_layer_files(data_folder, part)
        ]

    def fetch_data(
        data_folder: Path,
        _dataset_id,
        record_id: str,
        signals: List[str],
        _global_data_params,
        data_trim_1=None,
        data_trim_2=None,
    ) -> Dict:
        """
        Load one layer and return ``{"id", "signals", "errored_signals"}``.

        A ``None`` record id yields an empty result. When no signals are
        requested, only the first known signal is loaded. ``data_trim_1`` and
        ``data_trim_2`` are accepted but not used in this function.
        """
        if record_id is None:
            return {"id": None, "signals": [], "errored_signals": []}
        if signals is None or len(signals) == 0:
            signals = [first_signal] if all_possible_signals else []

        part_dir_name, layer_stem = _parse_record_id(record_id)
        layer_df = _load_layer_csv(Path(data_folder), part_dir_name, layer_stem)
        signal_records, errored = _build_signal_records_from_df(layer_df, signals)
        return {"id": record_id, "signals": signal_records, "errored_signals": errored}

    def feature_parameter(display_name: str) -> Dict:
        # Built fresh per call so no two grapher parameters share dict objects.
        return {
            "default": first_signal,
            "options": {name: name for name in all_possible_signals},
            "display_name": display_name,
        }

    def layer_parameter(default: int, display_name: str) -> Dict:
        return {"default": default, "min": 1, "max": 250, "display_name": display_name}

    # Smoothing / normalizing options
    smoothing_options = {
        "simple_moving_average": {
            "display_name": "Simple Moving Average",
            "parameters": {
                "moving_average_window_size": {
                    "default": 5,
                    "min": 1,
                    "max": None,
                    "display_name": "Window Size",
                }
            },
            "function": simple_moving_average,
        }
    }
    normalizing_options = {
        "robust_scaling": {
            "display_name": "Robust Scaling",
            "parameters": None,
            "function": robust_scaling,
        }
    }

    # Custom graphers for AM; each function reads its parameters from the
    # graph parameters using "<grapher key>_<parameter key>" names.
    graphers = {
        "layer_trend_by_part": {
            "display_name": "Layer Trend by Part (aggregate)",
            "parameters": {
                "feature_name": feature_parameter("Feature"),
                "layer_start": layer_parameter(1, "Start Layer"),
                "layer_end": layer_parameter(250, "End Layer"),
            },
            "function": layer_trend_by_part,
        },
        "laser_path_map": {
            "display_name": "Laser Path Map (XY, color by feature)",
            "parameters": {
                "position_source": {
                    "default": "real",
                    "options": {"real": "real", "command": "command"},
                    "display_name": "Position Source",
                },
                "color_feature": feature_parameter("Color Feature"),
                "point_size": {"default": 6, "min": 1, "max": 20, "display_name": "Point Size"},
            },
            "function": laser_path_map,
        },
        "melt_pool_threshold_scatter": {
            "display_name": "Melt Pool Threshold Scatter",
            "parameters": {
                "threshold": {
                    "default": "100",
                    "options": {value: value for value in ("80", "100", "120")},
                    "display_name": "Threshold",
                },
                "marker_size_scale": {"default": 30, "min": 5, "max": 80, "display_name": "Marker Size Scale"},
            },
            "function": melt_pool_threshold_scatter,
        },
        "layer_histogram_heatmap": {
            "display_name": "Histogram Heatmap by Layer",
            "parameters": {
                "part_id": {
                    "default": PART_DIR_NAMES[0],
                    "options": {part: part for part in PART_DIR_NAMES},
                    "display_name": "Part",
                },
                "feature_name": feature_parameter("Feature"),
                "bins": {"default": 30, "min": 5, "max": 200, "display_name": "Bins"},
                "layer_start": layer_parameter(1, "Start Layer"),
                "layer_end": layer_parameter(250, "End Layer"),
            },
            "function": layer_histogram_heatmap,
        },
    }

    return {
        "fetch_data": fetch_data,
        "dataset_id": "additive_manufacturing",
        "fetch_record_ids_for_dataset_id": fetch_record_ids_for_dataset_id,
        "all_possible_signals": all_possible_signals,
        "custom_smoothing_options": smoothing_options,
        "custom_normalizing_options": normalizing_options,
        "spline_path": "spline_parameters.csv",
        "auto_label_function_dictionary": {},
        "all_labels": ["Defect", "Anomaly", "Keyhole", "Lack of Fusion"],
        "custom_grapher_dictionary": graphers,
        "is_date": False,  # numeric time axis (row index)
        "trim_data": trim_data,
        "data_folder": data_folder,
        "layout_options": layout_options,
    }


