"""
SeqMaster Runtime - WebSocket API
Layer: API

WebSocket endpoints for real-time communication with HMI.
Provides live test status updates and operator interaction.
"""

import asyncio
import json
from datetime import datetime
from typing import Dict, Set, Optional, Any

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
from pydantic import BaseModel
import structlog

from src.core.service_registry import ServiceRegistry
from src.core.constants import STATUS_UPDATE_MS
from src.security.auth import verify_token_ws

# Structured logger scoped to this module's import path.
logger = structlog.get_logger(__name__)

# Router holding this module's WebSocket route(s); exported for inclusion in the app.
websocket_router = APIRouter()


class ConnectionManager:
    """
    Registry of live WebSocket connections and their topic subscriptions.

    Connections are keyed by a caller-supplied client id. Messages can be sent to
    a single client (``send_personal``) or broadcast to all clients / the
    subscribers of a topic (``broadcast``). A client whose send fails is removed
    from the registry (its socket is not closed here).
    """

    def __init__(self):
        # client_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # topic -> ids of the clients subscribed to it
        self.subscriptions: Dict[str, Set[str]] = {}

    async def connect(self, websocket: WebSocket, client_id: str) -> None:
        """Accept the WebSocket handshake and register it under *client_id*."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        logger.info("WebSocket client connected", client_id=client_id)

    def disconnect(self, client_id: str) -> None:
        """Forget *client_id*: drop its connection and all of its subscriptions.

        Safe to call repeatedly or for unknown ids; the disconnect is only logged
        when a registered connection was actually removed, so the several cleanup
        paths that may call this for the same client do not produce duplicate logs.
        """
        removed = self.active_connections.pop(client_id, None)

        for subscribers in self.subscriptions.values():
            subscribers.discard(client_id)

        if removed is not None:
            logger.info("WebSocket client disconnected", client_id=client_id)

    def subscribe(self, client_id: str, topic: str) -> None:
        """Subscribe *client_id* to *topic*, creating the topic on first use."""
        self.subscriptions.setdefault(topic, set()).add(client_id)

    def unsubscribe(self, client_id: str, topic: str) -> None:
        """Remove *client_id* from *topic*; unknown topics/clients are ignored."""
        subscribers = self.subscriptions.get(topic)
        if subscribers is not None:
            subscribers.discard(client_id)

    async def send_personal(self, client_id: str, message: dict) -> None:
        """Send *message* as JSON to one client.

        Unknown ids are ignored. On a send error the failure is logged and the
        client is dropped from the registry.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error("Failed to send message", client_id=client_id, error=str(e))
            self.disconnect(client_id)

    async def broadcast(self, message: dict, topic: Optional[str] = None) -> None:
        """Send *message* as JSON to the subscribers of *topic*, or to everyone.

        Recipient selection:
        - *topic* is None/empty -> every connected client.
        - *topic* has been subscribed to at least once -> its current subscribers
          (possibly nobody, if all of them unsubscribed).
        - *topic* has never been subscribed to -> every connected client (fallback).

        Clients whose send fails are logged and dropped after the loop.
        """
        if topic and topic in self.subscriptions:
            client_ids = list(self.subscriptions[topic])
        else:
            # NOTE(review): a never-subscribed topic falls back to *all* clients,
            # whereas a subscribed-then-emptied topic reaches nobody. Confirm the
            # HMI relies on this fallback before changing it.
            client_ids = list(self.active_connections)

        failed = []
        for client_id in client_ids:
            # Re-check each time: an earlier await may have let another task
            # disconnect this client.
            websocket = self.active_connections.get(client_id)
            if websocket is None:
                continue
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(
                    "Broadcast send failed; dropping client",
                    client_id=client_id,
                    error=str(e),
                )
                failed.append(client_id)

        for client_id in failed:
            self.disconnect(client_id)


# Process-wide connection manager shared by the endpoint and the broadcast callbacks.
manager = ConnectionManager()


def get_connection_manager() -> ConnectionManager:
    """Return the module-level ``ConnectionManager`` singleton.

    This only returns the object; it does not register anything in the
    service registry.
    """
    return manager


async def _send_ack(client_id: str, ack_type: str, **extra_data) -> None:
    """Send an acknowledgement message of type *ack_type* to one client.

    The payload is ``{"type": ack_type}`` with every keyword argument merged in
    afterwards, so an explicit ``type=`` keyword would take precedence.
    """
    payload = dict(type=ack_type)
    payload.update(extra_data)
    await manager.send_personal(client_id, payload)


@websocket_router.websocket("/live")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
    """
    Main WebSocket endpoint for live updates.

    Protocol:
    - Client sends: {"type": "subscribe", "topics": ["status", "steps"]}
    - Client sends: {"type": "control", "action": "pause"}
    - Server sends: {"type": "status", "data": {...}}
    - Server sends: {"type": "step_complete", "data": {...}}
    - Server sends: {"type": "operator_input_required", "data": {...}}

    Lifecycle: when ``token`` is given it is verified first (close code 4001 on
    failure). The socket is then registered with ``manager``, sent an initial
    status snapshot, and served until it disconnects or an error occurs.
    Cleanup always runs in ``finally``: the periodic status task is cancelled
    and the client is unregistered, including when this coroutine is cancelled.
    """
    import uuid

    # Short random id; prefixed with the username when a valid token is supplied.
    client_id = str(uuid.uuid4())[:8]

    # NOTE(review): authentication is optional -- without a token the client may
    # still send "control" messages (pause/resume/abort). Confirm this is intended.
    if token:
        try:
            user = await verify_token_ws(token)
            client_id = f"{user.username}-{client_id}"
        except Exception:
            await websocket.close(code=4001, reason="Invalid token")
            return

    await manager.connect(websocket, client_id)

    # Bound before the try so the finally clause can always reference it.
    status_task: Optional[asyncio.Task] = None
    try:
        # Send initial status, then start periodic updates.
        await send_current_status(client_id)
        status_task = asyncio.create_task(status_update_loop(client_id))

        while True:
            data = await websocket.receive_json()
            await handle_message(client_id, data)

    except WebSocketDisconnect:
        pass  # normal client-initiated close; cleanup happens below
    except Exception as e:
        logger.error("WebSocket error", client_id=client_id, error=str(e))
    finally:
        if status_task is not None:
            status_task.cancel()
        manager.disconnect(client_id)


# Control actions the executor is asked to perform; anything else is rejected.
_CONTROL_ACTIONS = ("pause", "resume", "abort")
# Step-failure actions forwarded verbatim; any other action ('end', 'end_print',
# missing or unknown) is translated to "abort".
_STEP_FAILURE_PASSTHROUGH = ("retry", "continue")


def _message_data(message: dict) -> dict:
    """Return ``message["data"]`` when it is a JSON object, otherwise ``{}``."""
    data = message.get("data")
    return data if isinstance(data, dict) else {}


def _topic_list(raw: Any) -> list:
    """Normalise a ``topics`` field: missing -> ``[]``, a single string -> ``[string]``."""
    if raw is None:
        return []
    if isinstance(raw, str):
        # Avoid iterating a bare string character by character.
        return [raw]
    return list(raw)


def _store_pending_inputs(entries: Dict[Any, Any]) -> None:
    """Merge *entries* into the registry's ``pending_operator_inputs`` mapping."""
    pending_inputs = ServiceRegistry.get("pending_operator_inputs") or {}
    pending_inputs.update(entries)
    ServiceRegistry.set("pending_operator_inputs", pending_inputs)


async def _handle_subscribe(client_id: str, message: dict) -> None:
    """Subscribe the client to each requested topic and acknowledge."""
    topics = _topic_list(message.get("topics"))
    for topic in topics:
        manager.subscribe(client_id, topic)
    await _send_ack(client_id, "subscribed", topics=topics)


async def _handle_unsubscribe(client_id: str, message: dict) -> None:
    """Unsubscribe the client from each listed topic (no acknowledgement)."""
    for topic in _topic_list(message.get("topics")):
        manager.unsubscribe(client_id, topic)


async def _handle_control(client_id: str, message: dict) -> None:
    """Forward pause/resume/abort to the executor; the ack reports whether it happened."""
    action = message.get("action")
    executor = ServiceRegistry.get("executor")

    success = False
    if not executor:
        logger.warning("Control request ignored: no executor", client_id=client_id, action=action)
    elif action not in _CONTROL_ACTIONS:
        logger.warning("Control request ignored: unknown action", client_id=client_id, action=action)
    else:
        # Safe: *action* is restricted to the whitelist above.
        await getattr(executor, action)()
        success = True

    await _send_ack(client_id, "control_ack", action=action, success=success)


async def _handle_operator_input(client_id: str, message: dict) -> None:
    """Store an operator's response for the executor, keyed by step id."""
    data = _message_data(message)

    # Prefer the nested fields; fall back to the legacy top-level ones. Use
    # ``is None`` rather than ``or`` so falsy answers (0, False, "") survive.
    step_id = data.get("step_id")
    if step_id is None:
        step_id = message.get("session_id")
    input_value = data.get("response")
    if input_value is None:
        input_value = message.get("value")

    logger.info("Received operator input", step_id=step_id, value=input_value)

    if step_id is None:
        # Nothing can be keyed by a missing id; report failure instead of storing under None.
        logger.warning("Operator input ignored: missing step_id", client_id=client_id)
        await _send_ack(client_id, "operator_input_ack", success=False)
        return

    _store_pending_inputs({step_id: input_value})
    await _send_ack(client_id, "operator_input_ack", success=True)


async def _handle_step_failure_response(client_id: str, message: dict) -> None:
    """Record the operator's decision (retry / continue / end) for a failed step."""
    data = _message_data(message)
    step_id = data.get("step_id")
    action = data.get("action")  # 'retry', 'continue', 'end', 'end_print'

    logger.info("Received step failure response", step_id=step_id, action=action)

    if step_id is None:
        logger.warning("Step failure response ignored: missing step_id", client_id=client_id)
        await _send_ack(client_id, "step_failure_response_ack", success=False, action=action)
        return

    decision = action if action in _STEP_FAILURE_PASSTHROUGH else "abort"
    # The decision is written under both step_id and step_id_retry.
    _store_pending_inputs({step_id: decision, f"{step_id}_retry": decision})

    await _send_ack(client_id, "step_failure_response_ack", success=True, action=action)


async def _handle_ping(client_id: str, message: dict) -> None:
    """Answer a ping with a pong carrying the server's UTC timestamp."""
    await _send_ack(client_id, "pong", timestamp=datetime.utcnow().isoformat())


_MESSAGE_HANDLERS = {
    "subscribe": _handle_subscribe,
    "unsubscribe": _handle_unsubscribe,
    "control": _handle_control,
    "operator_input": _handle_operator_input,
    "step_failure_response": _handle_step_failure_response,
    "ping": _handle_ping,
}


async def handle_message(client_id: str, message: dict) -> None:
    """Dispatch one incoming WebSocket message to its handler by ``type``.

    Non-object payloads and unknown message types are ignored (logged).
    Exceptions raised by the executor propagate to the caller.
    """
    if not isinstance(message, dict):
        logger.warning("Ignoring non-object WebSocket message", client_id=client_id)
        return

    msg_type = message.get("type")
    handler = _MESSAGE_HANDLERS.get(msg_type) if isinstance(msg_type, str) else None
    if handler is None:
        logger.debug("Ignoring unknown WebSocket message type", client_id=client_id, msg_type=msg_type)
        return

    await handler(client_id, message)


async def send_current_status(client_id: str) -> None:
    """Send a single ``status`` snapshot to one client.

    The snapshot comes from the registered executor's ``_get_status()`` when an
    executor exists and its ``context`` is truthy; otherwise it is
    ``{"state": "idle"}``.
    """
    executor = ServiceRegistry.get("executor")
    executor_active = executor and executor.context
    snapshot = executor._get_status() if executor_active else {"state": "idle"}
    await manager.send_personal(client_id, {"type": "status", "data": snapshot})


async def status_update_loop(client_id: str) -> None:
    """Push a status snapshot to *client_id* every ``STATUS_UPDATE_MS`` milliseconds.

    Runs while the client is present in ``manager.active_connections``.
    Cancellation is not swallowed, so a cancelled task ends in the cancelled
    state. Any other error is logged and stops the loop.
    """
    update_interval = STATUS_UPDATE_MS / 1000  # milliseconds -> seconds
    try:
        while client_id in manager.active_connections:
            await asyncio.sleep(update_interval)
            await send_current_status(client_id)
    except Exception as e:
        # asyncio.CancelledError derives from BaseException (3.8+), so it
        # propagates past this handler.
        logger.error("Status update loop stopped", client_id=client_id, error=str(e))


# ============================================
# EXECUTOR CALLBACKS
# ============================================

async def broadcast_status_update(status: dict) -> None:
    """Broadcast an executor status snapshot on the ``status`` topic, with a UTC timestamp."""
    envelope = {
        "type": "status",
        "data": status,
        "timestamp": datetime.utcnow().isoformat(),
    }
    await manager.broadcast(envelope, topic="status")


async def broadcast_step_complete(step_result: Any) -> None:
    """Broadcast a finished step's result on the ``steps`` topic.

    The listed attributes of *step_result* are copied into ``data`` in a fixed
    order; ``status`` is sent as ``status.value`` (presumably an Enum member).
    """
    wire_fields = (
        "step_id", "step_name", "step_type", "status", "passed",
        "group_id", "group_name", "measured_value", "expected_value",
        "lower_limit", "upper_limit", "duration_ms", "error_message",
        "error_code", "allow_retry", "applicable_error_codes",
        "retry_count", "loop_iteration", "loop_total",
    )
    details = {
        field: step_result.status.value if field == "status" else getattr(step_result, field)
        for field in wire_fields
    }
    envelope = {
        "type": "step_complete",
        "data": details,
        "timestamp": datetime.utcnow().isoformat(),
    }
    await manager.broadcast(envelope, topic="steps")


async def broadcast_operator_input_required(step: Any) -> None:
    """Broadcast an operator-input request for *step* to every connected client.

    No topic is passed to ``manager.broadcast``, so every connected client
    receives it. The prompt is taken from the first non-empty of
    ``operator_prompt`` (standard steps), ``message_text`` (MESSAGE steps) and
    ``description``, defaulting to "Operator action required". Buttons default
    to ``['OK']``.
    """
    # All optional attributes are read with getattr so a step lacking any of
    # them (including ``description``) falls through to the default prompt.
    prompt = (
        getattr(step, 'operator_prompt', None)
        or getattr(step, 'message_text', None)
        or getattr(step, 'description', None)
        or "Operator action required"
    )
    buttons = getattr(step, 'message_buttons', None) or ['OK']
    # Enum-typed step types are sent by value; anything else as its string form.
    step_type = step.type.value if hasattr(step.type, 'value') else str(step.type)

    await manager.broadcast({
        "type": "operator_input_required",
        "data": {
            "step_id": step.id,
            "step_name": step.name,
            "step_type": step_type,
            "prompt": prompt,
            "buttons": buttons
        },
        "timestamp": datetime.utcnow().isoformat()
    })


# Publish the broadcast helpers through the service registry so the executor
# can call them without importing this module.
for _registry_key, _callback in (
    ("ws_broadcast_status", broadcast_status_update),
    ("ws_broadcast_step", broadcast_step_complete),
    ("ws_broadcast_operator_input", broadcast_operator_input_required),
):
    ServiceRegistry.set(_registry_key, _callback)
del _registry_key, _callback
