"""
SeqMaster Runtime - Analytics API

Provides endpoints for step-level data analysis, statistics, and Pareto.
Uses existing StepResult table - no separate analytics table needed.

Uses async SQLAlchemy to match runtime architecture.
"""

import csv
import io
import json
import re
from datetime import datetime, timedelta
from typing import List, Optional
from urllib.parse import quote

from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import select, func, and_, desc, asc, case
from sqlalchemy.ext.asyncio import AsyncSession

from src.database.connection import get_db_session
from src.database.models import StepResult, TestSession, Sequence
from src.services.statistics import StatisticsCalculator, ParetoCalculator
from src.core.constants import FLOW_CONTROL_TYPES
from src.core.license import get_license_service


async def _require_analytics_feature():
    """FastAPI dependency that rejects the request unless analytics is licensed.

    The license service is asked whether the "analytics" feature is available
    (it is not in demo mode); if it is, the dependency does nothing.

    Raises:
        HTTPException: 403 when the active license lacks the analytics feature.
    """
    if get_license_service().has_feature("analytics"):
        return
    raise HTTPException(
        status_code=403,
        detail="Analytics kræver Professional licens eller højere."
    )


# Every route registered on this router runs _require_analytics_feature first,
# so all analytics endpoints share the same license gate.
router = APIRouter(
    dependencies=[Depends(_require_analytics_feature)],
    tags=["Analytics"],
    prefix="/analytics",
)


def _parse_csv_list(value: Optional[str]) -> List[str]:
    if not value:
        return []
    return [v.strip() for v in value.split(',') if v.strip()]


def _build_step_conditions(
    *,
    step_name: Optional[str] = None,
    group_name: Optional[str] = None,
    sequence_id: Optional[str] = None,  # Logical sequence_id (e.g. SEQ-xxx)
    dut_type: Optional[str] = None,
    testers: Optional[List[str]] = None,
    operator: Optional[str] = None,
    passed: Optional[bool] = None,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
    include_flow_control: bool = False
) -> List:
    """Assemble the WHERE clauses shared by the step-level analytics queries.

    Each filter argument contributes one clause when it is set (truthy, or
    ``is not None`` for ``passed``). Flow-control step types are excluded
    unless ``include_flow_control`` is True.

    ``sequence_id`` is the logical ID (e.g. SEQ-ML47Q353-J2X), not the database
    row ID, so the caller's query must join ``Sequence``. The ``dut_type``,
    ``testers`` and ``operator`` clauses reference ``TestSession`` and likewise
    need that table joined.

    Returns:
        A fresh list of SQLAlchemy boolean expressions; callers may append to it.
    """
    # (enabled?, clause factory) pairs in the historical clause order.
    # Factories are lazy so e.g. ``in_`` is never evaluated with testers=None.
    candidates = (
        (bool(step_name), lambda: StepResult.step_name == step_name),
        (not include_flow_control, lambda: StepResult.step_type.notin_(FLOW_CONTROL_TYPES)),
        (bool(group_name), lambda: StepResult.group_name == group_name),
        (bool(sequence_id), lambda: Sequence.sequence_id == sequence_id),
        (bool(dut_type), lambda: TestSession.dut_type == dut_type),
        (bool(testers), lambda: TestSession.tester_id.in_(testers)),
        (bool(operator), lambda: TestSession.operator == operator),
        (passed is not None, lambda: StepResult.passed == passed),
        (bool(since), lambda: StepResult.completed_at >= since),
        (bool(until), lambda: StepResult.completed_at <= until),
    )
    return [make_clause() for enabled, make_clause in candidates if enabled]


# ============================================
# PYDANTIC MODELS
# ============================================

class StepDataPoint(BaseModel):
    """A single step measurement data point for charting.

    Built by get_step_data from one (StepResult, TestSession) row pair.
    """
    id: int  # StepResult.id
    step_name: str
    dut_serial: str  # TestSession.dut_serial, falling back to TestSession.dut_id
    measured_value: Optional[float]
    lower_limit: Optional[float]
    upper_limit: Optional[float]
    unit: Optional[str]
    passed: bool  # StepResult.passed; a NULL value is reported as False
    tester_id: Optional[str]
    dut_type: Optional[str]
    executed_at: datetime  # StepResult.completed_at, falling back to TestSession.started_at
    error_code: Optional[str]


class StepDataResponse(BaseModel):
    """Response for step data query: one page of data points plus paging info."""
    data: List[StepDataPoint]
    total_count: int  # Rows matching the filters across all pages
    page: int  # 1-based page number that was requested
    page_size: int
    has_more: bool  # True when further rows exist after this page


class DistributionBinResponse(BaseModel):
    """Histogram bin, copied from a StatisticsCalculator distribution bin."""
    bin_start: float  # Lower edge of the bin
    bin_end: float  # Upper edge of the bin
    count: int  # Measurements falling in this bin
    percentage: float  # Share of all measurements; scale (0-1 vs 0-100) set by StatisticsCalculator - confirm


class StatisticsResponse(BaseModel):
    """Statistics calculation result for one step.

    Fields are copied from the StatisticsCalculator.calculate result in
    get_step_statistics.
    """
    step_name: str
    count: int  # Number of measurements used, as reported by the calculator
    mean: float
    median: float
    std_dev: float
    min_value: float
    max_value: float
    range_value: float  # Presumably max_value - min_value; computed by the calculator
    cp: float  # Process capability; +inf from the calculator is sent as 999.99
    cpk: float  # Process capability index; +inf from the calculator is sent as 999.99
    cpk_interpretation: str  # Human-readable rating produced by the calculator
    suggested_lower_limit: float  # Limit proposed by the calculator from the data
    suggested_upper_limit: float  # Limit proposed by the calculator from the data
    current_lower_limit: Optional[float]  # Limit stored on the reference StepResult (None if unset)
    current_upper_limit: Optional[float]  # Limit stored on the reference StepResult (None if unset)
    distribution: List[DistributionBinResponse]  # Histogram of the measurements


class ParetoItemResponse(BaseModel):
    """Pareto analysis item: one failing (step_name, group_name) pair."""
    step_name: str
    group_name: Optional[str]
    failure_count: int  # Failed StepResult rows for this step/group
    failure_percentage: float  # Share of failures, as computed by ParetoCalculator
    cumulative_percentage: float  # Presumably the running total of failure_percentage in rank order


class StepListItem(BaseModel):
    """Step summary for step selection list.

    One item per (step_name, group_name), aggregated across sequence versions.
    """
    step_name: str
    group_name: Optional[str]

    total_count: int  # Number of StepResult rows
    passed_count: int  # Rows with passed == True
    failed_count: int  # total_count - passed_count (includes rows whose passed is NULL)
    pass_rate: float  # passed_count / total_count * 100, or 0 when total_count is 0
    last_seen: Optional[datetime]  # Latest StepResult.completed_at


class SequenceOption(BaseModel):
    """Sequence option for dropdown - uses logical sequence_id.

    One option per logical ID even when several sequence versions exist.
    """
    id: str  # Logical sequence_id (e.g. SEQ-ML47Q353-J2X), not the DB row ID
    name: Optional[str]  # Display name from one of the versions


class FilterOptions(BaseModel):
    """Available filter options, as returned by the /filters endpoint."""
    testers: List[str]  # Distinct TestSession.tester_id values
    sequences: List[SequenceOption]  # Distinct logical sequences
    dut_types: List[str]  # Distinct TestSession.dut_type values
    operators: List[str]  # Distinct TestSession.operator values
    group_names: List[str]  # Distinct non-flow-control StepResult.group_name values
    date_range: dict  # {"min": ISO-8601 str | None, "max": ISO-8601 str | None} of TestSession.started_at


# ============================================
# ENDPOINTS
# ============================================

@router.get("/steps", response_model=List[StepListItem])
async def list_steps(
    sequence_id: Optional[str] = None,  # Logical sequence_id (e.g. SEQ-xxx)
    dut_type: Optional[str] = None,
    since: Optional[datetime] = None,
    search: Optional[str] = None,
    testers: Optional[str] = None,
    step_type: Optional[str] = None,  # Optional: filter by step_type (test, action, etc.)
    sort_by: str = Query(default="step_index", pattern="^(step_name|step_index|total_count|pass_rate|last_seen)$"),
    sort_order: str = Query(default="asc", pattern="^(asc|desc)$"),
    limit: int = Query(default=100, ge=1, le=500),
    db: AsyncSession = Depends(get_db_session)
):
    """
    List available steps with pass/fail summary.

    Used for step selection in the analysis UI.
    Flow control steps are automatically filtered out.
    Steps with same name and group are aggregated across all sequence versions
    and all positions (step_index) they appear at.

    Args:
        sequence_id: Logical sequence ID (e.g. SEQ-ML47Q353-J2X), not database
            row ID. When omitted an empty list is returned.
        dut_type: Only count sessions of this DUT type.
        since: Only count step results completed at or after this time.
        search: Case-insensitive substring match on step_name.
        testers: Comma-separated tester IDs to include.
        step_type: Only include steps of this step_type.
        sort_by / sort_order: Ordering of the aggregated rows; ties are broken
            by step_name then group_name so results are deterministic.
        limit: Maximum number of rows returned (1-500).

    Returns:
        One StepListItem per (step_name, group_name).
    """
    if not sequence_id:
        return []  # Require sequence selection

    # Aggregates per (step_name, group_name) - deliberately NOT per step_index.
    total_count_expr = func.count(StepResult.id)
    passed_count_expr = func.sum(case((StepResult.passed == True, 1), else_=0))  # noqa: E712
    last_seen_expr = func.max(StepResult.completed_at)
    pass_rate_expr = passed_count_expr * 100.0 / func.nullif(total_count_expr, 0)
    min_step_index_expr = func.min(StepResult.step_index)

    # Shared filters: logical sequence_id, flow-control exclusion, DUT type,
    # testers and the completed_at lower bound.
    conditions = _build_step_conditions(
        sequence_id=sequence_id,
        dut_type=dut_type,
        testers=_parse_csv_list(testers),
        since=since,
    )
    if step_type:
        conditions.append(StepResult.step_type == step_type)
    if search:
        conditions.append(StepResult.step_name.ilike(f"%{search}%"))

    # Sorting and limiting happen in SQL for performance
    sort_map = {
        "step_name": StepResult.step_name,
        "step_index": min_step_index_expr,
        "total_count": total_count_expr,
        "pass_rate": pass_rate_expr,
        "last_seen": last_seen_expr,
    }
    order_col = sort_map.get(sort_by, min_step_index_expr)
    direction = desc if sort_order == "desc" else asc

    stmt = (
        select(
            StepResult.step_name,
            StepResult.group_name,
            min_step_index_expr.label('min_step_index'),
            total_count_expr.label('total_count'),
            passed_count_expr.label('passed_count'),
            last_seen_expr.label('last_seen')
        )
        .join(TestSession, StepResult.session_id == TestSession.id)
        .join(Sequence, TestSession.sequence_id == Sequence.id)
        .where(and_(*conditions))
        .group_by(StepResult.step_name, StepResult.group_name)
        # Secondary keys make ties (equal counts, rates, dates) deterministic,
        # so LIMIT always keeps the same rows.
        .order_by(direction(order_col), asc(StepResult.step_name), asc(StepResult.group_name))
        .limit(limit)
    )

    rows = (await db.execute(stmt)).all()

    items = []
    for row in rows:
        total = row.total_count or 0
        passed = row.passed_count or 0
        items.append(StepListItem(
            step_name=row.step_name,
            group_name=row.group_name,
            total_count=total,
            passed_count=passed,
            failed_count=total - passed,
            pass_rate=(passed / total * 100) if total > 0 else 0.0,
            last_seen=row.last_seen
        ))

    return items


@router.get("/step-data/{step_name}", response_model=StepDataResponse)
async def get_step_data(
    step_name: str,
    # Filters
    group_name: Optional[str] = None,
    step_index: Optional[int] = None,
    sequence_id: Optional[str] = None,  # Logical sequence_id
    dut_type: Optional[str] = None,
    dut_serial: Optional[str] = None,
    testers: Optional[str] = None,
    operator: Optional[str] = None,
    passed: Optional[bool] = None,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
    # Pagination
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=1000, ge=1, le=5000),
    # Sorting
    sort_by: str = Query(default="executed_at", pattern="^(executed_at|measured_value|dut_serial)$"),
    sort_order: str = Query(default="asc", pattern="^(asc|desc)$"),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Get step measurement data for charting.

    Returns paginated data points for a specific step across all DUTs.
    Use group_name and step_index to distinguish steps with same name.
    Data from all versions of the same sequence is aggregated.

    Args:
        step_name: Exact step name (path parameter).
        sequence_id: Logical sequence ID (e.g. SEQ-xxx); required.
        dut_serial: Case-insensitive substring match on TestSession.dut_serial.
        testers: Comma-separated tester IDs.
        page: 1-based page number.
        page_size: Rows per page (1-5000). Zero is rejected because it would
            make ``has_more`` true forever.
        sort_by / sort_order: Primary ordering. StepResult.id is always added
            as a tie-breaker so consecutive pages never repeat or skip rows.

    Raises:
        HTTPException: 400 when sequence_id is missing.
    """
    if not sequence_id:
        raise HTTPException(status_code=400, detail="sequence_id is required")

    testers_list = _parse_csv_list(testers)
    base_conditions = _build_step_conditions(
        step_name=step_name,
        group_name=group_name,
        sequence_id=sequence_id,
        dut_type=dut_type,
        testers=testers_list,
        operator=operator,
        passed=passed,
        since=since,
        until=until
    )
    if step_index is not None:
        base_conditions.append(StepResult.step_index == step_index)
    if dut_serial:
        base_conditions.append(TestSession.dut_serial.ilike(f"%{dut_serial}%"))

    # Count query - must join Sequence for the logical sequence_id filter
    count_stmt = (
        select(func.count(StepResult.id))
        .select_from(StepResult)
        .join(TestSession, StepResult.session_id == TestSession.id)
        .join(Sequence, TestSession.sequence_id == Sequence.id)
        .where(and_(*base_conditions))
    )
    total_count = (await db.execute(count_stmt)).scalar() or 0

    sort_columns = {
        "executed_at": StepResult.completed_at,
        "measured_value": StepResult.measured_value,
        "dut_serial": TestSession.dut_serial,
    }
    order_col = sort_columns.get(sort_by, TestSession.dut_serial)
    direction = desc if sort_order == "desc" else asc

    offset = (page - 1) * page_size

    # Data query - must join Sequence for the logical sequence_id filter.
    # The primary sort key is not unique, so StepResult.id is appended to give
    # OFFSET/LIMIT a total order (stable pages).
    data_stmt = (
        select(StepResult, TestSession)
        .join(TestSession, StepResult.session_id == TestSession.id)
        .join(Sequence, TestSession.sequence_id == Sequence.id)
        .where(and_(*base_conditions))
        .order_by(direction(order_col), direction(StepResult.id))
        .offset(offset)
        .limit(page_size)
    )
    rows = (await db.execute(data_stmt)).all()

    data = [
        StepDataPoint(
            id=step_result.id,
            step_name=step_result.step_name,
            dut_serial=test_session.dut_serial or test_session.dut_id,
            measured_value=step_result.measured_value,
            lower_limit=step_result.lower_limit,
            upper_limit=step_result.upper_limit,
            unit=step_result.unit,
            passed=step_result.passed or False,
            tester_id=test_session.tester_id,
            dut_type=test_session.dut_type,
            executed_at=step_result.completed_at or test_session.started_at,
            error_code=step_result.error_code
        )
        for step_result, test_session in rows
    ]

    return StepDataResponse(
        data=data,
        total_count=total_count,
        page=page,
        page_size=page_size,
        has_more=(offset + len(data)) < total_count
    )


@router.get("/statistics/{step_name}", response_model=Optional[StatisticsResponse])
async def get_step_statistics(
    step_name: str,
    # Filters
    group_name: Optional[str] = None,
    step_index: Optional[int] = None,
    sequence_id: Optional[str] = None,  # Logical sequence_id
    dut_type: Optional[str] = None,
    testers: Optional[str] = None,
    operator: Optional[str] = None,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
    bin_count: int = Query(default=20, ge=5, le=100),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Calculate statistics for a step's measurements.

    Returns mean, median, std dev, Cp, Cpk, and histogram distribution.
    Returns null if no numeric measurements found (e.g. delay steps).
    Use group_name and step_index to distinguish steps with same name.
    Data from all versions of the same sequence is aggregated.

    At most the 50,000 most recent measurements (by completed_at) are used,
    and the "current" limits are taken from the most recent one. Infinite
    Cp/Cpk values are reported as +/-999.99 because JSON has no Infinity.

    Raises:
        HTTPException: 400 when sequence_id is missing; 500 when the
            calculator returns no result.
    """
    if not sequence_id:
        raise HTTPException(status_code=400, detail="sequence_id is required")

    testers_list = _parse_csv_list(testers)
    conditions = _build_step_conditions(
        step_name=step_name,
        group_name=group_name,
        sequence_id=sequence_id,
        dut_type=dut_type,
        testers=testers_list,
        operator=operator,
        since=since,
        until=until
    )
    if step_index is not None:
        conditions.append(StepResult.step_index == step_index)
    conditions.append(StepResult.measured_value.isnot(None))  # Numeric values only

    stmt = (
        select(
            StepResult.measured_value,
            StepResult.lower_limit,
            StepResult.upper_limit
        )
        .select_from(StepResult)
        .join(TestSession, StepResult.session_id == TestSession.id)
        .join(Sequence, TestSession.sequence_id == Sequence.id)
        .where(and_(*conditions))
        # Newest first: without an ORDER BY the row order (and which rows
        # survive the cap) is unspecified. "IS NULL" as the first key puts
        # undated rows last on every backend; id breaks ties.
        .order_by(
            StepResult.completed_at.is_(None),
            desc(StepResult.completed_at),
            desc(StepResult.id),
        )
        .limit(50000)
    )

    rows = (await db.execute(stmt)).all()

    if not rows:
        return None  # No data for this step

    # Hand the calculator values in chronological order (query is newest-first)
    values = [r.measured_value for r in reversed(rows) if r.measured_value is not None]

    if not values:
        return None  # No numeric measurements (e.g. delay steps)

    latest = rows[0]  # Most recent record carries the current limits

    stats = StatisticsCalculator.calculate(
        values=values,
        current_lower=latest.lower_limit,
        current_upper=latest.upper_limit,
        step_name=step_name,
        bin_count=bin_count
    )

    if not stats:
        raise HTTPException(status_code=500, detail="Statistics calculation failed")

    def _json_safe(index_value: float) -> float:
        # JSON cannot represent Infinity; report an unbounded index as +/-999.99
        if index_value == float('inf'):
            return 999.99
        if index_value == float('-inf'):
            return -999.99
        return index_value

    return StatisticsResponse(
        step_name=stats.step_name,
        count=stats.count,
        mean=stats.mean,
        median=stats.median,
        std_dev=stats.std_dev,
        min_value=stats.min_value,
        max_value=stats.max_value,
        range_value=stats.range_value,
        cp=_json_safe(stats.cp),
        cpk=_json_safe(stats.cpk),
        cpk_interpretation=stats.cpk_interpretation,
        suggested_lower_limit=stats.suggested_lower_limit,
        suggested_upper_limit=stats.suggested_upper_limit,
        current_lower_limit=stats.current_lower_limit,
        current_upper_limit=stats.current_upper_limit,
        distribution=[
            DistributionBinResponse(
                bin_start=b.bin_start,
                bin_end=b.bin_end,
                count=b.count,
                percentage=b.percentage
            )
            for b in stats.distribution
        ]
    )


@router.get("/pareto", response_model=List[ParetoItemResponse])
async def get_pareto_analysis(
    # Filters
    sequence_id: Optional[str] = None,  # Logical sequence_id
    dut_type: Optional[str] = None,
    testers: Optional[str] = None,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
    top_count: int = Query(default=10, ge=1, le=50),
    db: AsyncSession = Depends(get_db_session)
):
    """
    Get Pareto analysis of failing steps.

    Returns top failure modes sorted by count with cumulative percentages.
    Flow control steps are excluded.
    Data from all versions of the same sequence is aggregated.

    All failure modes are handed to ParetoCalculator (not only the top N), so
    percentages are relative to every failure in the filtered data; at most
    ``top_count`` items are returned.

    Raises:
        HTTPException: 400 when sequence_id is missing.
    """
    if not sequence_id:
        raise HTTPException(status_code=400, detail="sequence_id is required")

    # Show all failed steps, exclude flow control
    testers_list = _parse_csv_list(testers)
    conditions = _build_step_conditions(
        sequence_id=sequence_id,
        dut_type=dut_type,
        testers=testers_list,
        since=since,
        until=until
    )
    conditions.append(StepResult.passed == False)  # noqa: E712

    # One row per (step_name, group_name). No SQL LIMIT: truncating here would
    # make the calculator's totals cover only the top N failure modes.
    stmt = (
        select(
            StepResult.step_name,
            StepResult.group_name,
            func.count(StepResult.id).label('failure_count')
        )
        .select_from(StepResult)
        .join(TestSession, StepResult.session_id == TestSession.id)
        .join(Sequence, TestSession.sequence_id == Sequence.id)
        .where(and_(*conditions))
        .group_by(StepResult.step_name, StepResult.group_name)
        # Name/group tie-breakers keep equal counts in a deterministic order
        .order_by(desc('failure_count'), asc(StepResult.step_name), asc(StepResult.group_name))
    )

    rows = (await db.execute(stmt)).all()

    failure_data = [
        {'step_name': r.step_name, 'group_name': r.group_name, 'failure_count': r.failure_count}
        for r in rows
    ]

    # The calculator is asked for top_count items; cap defensively as well
    pareto_items = list(ParetoCalculator.calculate(failure_data, top_count))[:top_count]

    return [
        ParetoItemResponse(
            step_name=p.step_name,
            group_name=p.group_name,
            failure_count=p.failure_count,
            failure_percentage=p.failure_percentage,
            cumulative_percentage=p.cumulative_percentage
        )
        for p in pareto_items
    ]


@router.get("/filters")
async def get_filter_options(
    since: Optional[datetime] = None,
    db: AsyncSession = Depends(get_db_session)
) -> FilterOptions:
    """
    Collect the values offered in the analytics filter dropdowns.

    Everything is derived from existing test sessions; when ``since`` is given
    only sessions started at or after that time are considered. Sequences are
    reported once per logical ID (e.g. SEQ-xxx) across all of its versions,
    and group names ignore flow-control steps.
    """
    # Time-window restriction applied to every TestSession-based query
    window = [TestSession.started_at >= since] if since else []

    async def distinct_session_values(column) -> List[str]:
        """Return the distinct, non-empty values of a TestSession column."""
        query = select(func.distinct(column)).where(and_(column.isnot(None), *window))
        fetched = (await db.execute(query)).all()
        return [row[0] for row in fetched if row[0]]

    # Logical sequences: first name seen per SEQ-xxx wins
    seq_query = (
        select(func.distinct(Sequence.sequence_id), Sequence.name)
        .join(TestSession, TestSession.sequence_id == Sequence.id)
    )
    if window:
        seq_query = seq_query.where(*window)
    seq_query = seq_query.group_by(Sequence.sequence_id, Sequence.name)

    options_by_id = {}
    for logical_id, display_name in (await db.execute(seq_query)).all():
        if logical_id and logical_id not in options_by_id:
            options_by_id[logical_id] = SequenceOption(id=logical_id, name=display_name)

    dut_types = await distinct_session_values(TestSession.dut_type)
    testers = await distinct_session_values(TestSession.tester_id)
    operators = await distinct_session_values(TestSession.operator)

    # Step group names, excluding flow-control steps
    group_query = (
        select(func.distinct(StepResult.group_name))
        .join(TestSession, StepResult.session_id == TestSession.id)
        .where(and_(
            StepResult.group_name.isnot(None),
            StepResult.step_type.notin_(FLOW_CONTROL_TYPES),
            *window,
        ))
    )
    group_names = [row[0] for row in (await db.execute(group_query)).all() if row[0]]

    # Earliest / latest session start inside the window
    bounds_query = select(
        func.min(TestSession.started_at),
        func.max(TestSession.started_at)
    )
    if window:
        bounds_query = bounds_query.where(and_(*window))
    earliest, latest = (await db.execute(bounds_query)).one()

    return FilterOptions(
        testers=sorted(testers),
        sequences=sorted(options_by_id.values(), key=lambda opt: (opt.name or '', opt.id)),
        dut_types=sorted(dut_types),
        operators=sorted(operators),
        group_names=sorted(group_names),
        date_range={
            "min": earliest.isoformat() if earliest else None,
            "max": latest.isoformat() if latest else None,
        },
    )


def _attachment_headers(step_name: str, extension: str) -> dict:
    """Build a Content-Disposition header that is valid for any step name.

    Starlette encodes header values as latin-1, so a raw step name with e.g.
    Danish letters would raise; quotes or semicolons would corrupt the header.
    An ASCII-only quoted ``filename`` fallback is sent together with an
    RFC 6266 ``filename*`` carrying the real name percent-encoded as UTF-8.
    """
    stem = step_name.replace(' ', '_')
    ascii_stem = re.sub(r'[^A-Za-z0-9._-]', '_', stem) or "step"
    full_name = f"{stem}_export.{extension}"
    return {
        "Content-Disposition": (
            f'attachment; filename="{ascii_stem}_export.{extension}"; '
            f"filename*=UTF-8''{quote(full_name, safe='')}"
        )
    }


@router.get("/export/{step_name}")
async def export_step_data(
    step_name: str,
    format: str = Query(default="csv", pattern="^(csv|json)$"),
    # Filters
    group_name: Optional[str] = None,
    sequence_id: Optional[str] = None,  # Logical sequence_id (e.g. SEQ-xxx)
    dut_type: Optional[str] = None,
    testers: Optional[str] = None,
    operator: Optional[str] = None,
    passed: Optional[bool] = None,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
    db: AsyncSession = Depends(get_db_session),
    step_index: Optional[int] = None,
):
    """
    Export step data to CSV or JSON.

    Uses the same filters as the step-data endpoint. At most 50,000 rows are
    exported, oldest first.

    Args:
        sequence_id: Logical sequence ID (e.g. SEQ-xxx), matched against
            Sequence.sequence_id; required.
        step_index: Optional position filter to separate same-named steps.

    Raises:
        HTTPException: 400 when sequence_id is missing.
    """
    if not sequence_id:
        raise HTTPException(status_code=400, detail="sequence_id is required to avoid mixing sequences")

    testers_list = _parse_csv_list(testers)
    conditions = _build_step_conditions(
        step_name=step_name,
        group_name=group_name,
        sequence_id=sequence_id,
        dut_type=dut_type,
        testers=testers_list,
        operator=operator,
        passed=passed,
        since=since,
        until=until
    )
    if step_index is not None:
        conditions.append(StepResult.step_index == step_index)

    stmt = (
        select(StepResult, TestSession)
        .join(TestSession, StepResult.session_id == TestSession.id)
        # Required by the Sequence.sequence_id condition; without this join
        # Sequence is cross-joined and the filter does not restrict the rows.
        .join(Sequence, TestSession.sequence_id == Sequence.id)
        .where(and_(*conditions))
        .order_by(asc(StepResult.completed_at), asc(StepResult.id))
        .limit(50000)
    )

    rows = (await db.execute(stmt)).all()

    if format == "csv":
        output = io.StringIO()
        writer = csv.writer(output)

        writer.writerow([
            "Step Name", "DUT Serial", "Measured Value", "Lower Limit", "Upper Limit",
            "Unit", "Passed", "DUT Type", "Operator", "Executed At", "Error Code"
        ])
        for step_result, test_session in rows:
            writer.writerow([
                step_result.step_name,
                test_session.dut_serial or test_session.dut_id,
                step_result.measured_value,
                step_result.lower_limit,
                step_result.upper_limit,
                step_result.unit,
                step_result.passed,
                test_session.dut_type,
                test_session.operator,
                step_result.completed_at.isoformat() if step_result.completed_at else "",
                step_result.error_code
            ])

        return StreamingResponse(
            iter([output.getvalue()]),
            media_type="text/csv",
            headers=_attachment_headers(step_name, "csv"),
        )

    # JSON
    data = [
        {
            "step_name": step_result.step_name,
            "dut_serial": test_session.dut_serial or test_session.dut_id,
            "measured_value": step_result.measured_value,
            "lower_limit": step_result.lower_limit,
            "upper_limit": step_result.upper_limit,
            "unit": step_result.unit,
            "passed": step_result.passed,
            "dut_type": test_session.dut_type,
            "operator": test_session.operator,
            "executed_at": step_result.completed_at.isoformat() if step_result.completed_at else None,
            "error_code": step_result.error_code
        }
        for step_result, test_session in rows
    ]

    return StreamingResponse(
        iter([json.dumps(data, indent=2)]),
        media_type="application/json",
        headers=_attachment_headers(step_name, "json"),
    )
