"""
SeqMaster Runtime - Relay Output Driver
Layer: DRIVERS

Driver for USB relay modules.
"""

import asyncio
from typing import Any, Dict, List, Optional

import structlog

from src.drivers.base import BaseDriver, DriverInfo, DriverResult, DriverStatus

# Module-level structlog logger; calls below pass context (device_id, port,
# error, ...) as keyword arguments.
logger = structlog.get_logger(__name__)


class RelayDriver(BaseDriver):
    """
    Driver for USB relay output modules.

    Supports USB relay boards with a serial command interface. Each relay
    is switched by a 4-byte frame (see ``_build_command``). Relay states
    are tracked locally: nothing is read back from the board, so ``read``
    and ``get_relay`` report the last state this driver commanded.
    """

    # First byte of every relay command frame.
    _FRAME_HEADER = 0xA0

    def __init__(self, device_id: str, port: str,
                 num_channels: int = 8,
                 config: Optional[Dict[str, Any]] = None):
        """
        Initialize relay driver.

        Does not open the serial port; call ``connect`` for that.

        Args:
            device_id: Unique identifier
            port: Serial port (COM3, /dev/ttyUSB0)
            num_channels: Number of relay channels
            config: Additional configuration (forwarded to BaseDriver)
        """
        super().__init__(device_id, config)
        self._port = port
        self._num_channels = num_channels
        # serial.Serial instance while a port is open, otherwise None.
        self._serial = None
        # Last commanded state per channel; index 0 is channel 1.
        self._relay_states: List[bool] = [False] * num_channels

    @property
    def info(self) -> DriverInfo:
        """Static metadata describing this driver."""
        return DriverInfo(
            name="Relay Output Driver",
            version="1.0.0",
            driver_type="relay",
            description=f"USB Relay module with {self._num_channels} channels",
            capabilities=["set_relay", "get_relay", "set_all", "reset"]
        )

    @property
    def num_channels(self) -> int:
        """Get number of relay channels."""
        return self._num_channels

    @property
    def relay_states(self) -> List[bool]:
        """Get a copy of the last commanded relay states."""
        return self._relay_states.copy()

    @classmethod
    def _build_command(cls, channel: int, state: bool) -> bytes:
        """
        Build the command frame that switches one relay.

        Frame layout: header (0xA0), channel, state (0x01 on / 0x00 off),
        checksum. The checksum is the sum of the first three bytes truncated
        to a single byte, so it is always a valid byte value.

        Raises:
            ValueError: if ``channel`` is outside 0-255.
        """
        state_byte = 0x01 if state else 0x00
        checksum = (cls._FRAME_HEADER + channel + state_byte) & 0xFF
        return bytes([cls._FRAME_HEADER, channel, state_byte, checksum])

    def _close_serial_quietly(self) -> None:
        """Close and drop the serial port, ignoring errors (cleanup only)."""
        port, self._serial = self._serial, None
        if port is not None:
            try:
                port.close()
            except Exception:
                # Best effort: the caller reports the original failure.
                pass

    async def connect(self) -> DriverResult:
        """
        Open the serial port and switch every relay OFF.

        Returns:
            DriverResult(success=True) on success. On failure, the port is
            closed again, the driver is put in the error state, and
            DriverResult(success=False, error=...) is returned.
        """
        try:
            # Imported here so a missing pyserial is reported as a connect()
            # failure rather than breaking import of this module.
            import serial

            self._status = DriverStatus.CONNECTING

            self._serial = serial.Serial(
                port=self._port,
                baudrate=9600,
                timeout=1
            )

            # Force a known all-OFF state. reset() cannot be used here: it
            # returns "Not connected" unless is_connected is true, and
            # _set_connected() has not been called yet (presumably
            # is_connected is False while CONNECTING). So write the frames
            # directly.
            for channel in range(1, self._num_channels + 1):
                self._serial.write(self._build_command(channel, False))
            self._relay_states[:] = [False] * self._num_channels

            self._set_connected()
            logger.info("Relay module connected",
                        device_id=self._device_id,
                        port=self._port,
                        channels=self._num_channels)

            return DriverResult(success=True)

        except Exception as e:
            # Don't leave a half-initialised port open.
            self._close_serial_quietly()
            self._set_error(str(e))
            logger.error("Relay connection failed", error=str(e))
            return DriverResult(success=False, error=str(e))

    async def disconnect(self) -> DriverResult:
        """
        Switch all relays OFF and close the serial port.

        The port is closed even if switching the relays off fails. A failed
        reset is logged as a warning and does not fail the disconnect.
        An error while closing the port is returned as a failed result.
        """
        try:
            if self._serial is not None:
                try:
                    reset_result = await self.reset()  # Turn off all relays
                    if not reset_result.success:
                        logger.warning("Relay reset before disconnect failed",
                                       device_id=self._device_id,
                                       error=reset_result.error)
                finally:
                    # Detach first so _serial is cleared even if close() raises.
                    port, self._serial = self._serial, None
                    port.close()

            self._set_disconnected()
            return DriverResult(success=True)

        except Exception as e:
            return DriverResult(success=False, error=str(e))

    async def read(self, command: Optional[str] = None, **kwargs) -> DriverResult:
        """
        Return the last commanded relay states.

        ``command`` and ``kwargs`` are accepted for interface compatibility
        and are ignored. The returned list is a copy, so callers cannot
        modify the driver's internal state through it.
        """
        return DriverResult(success=True, value=self._relay_states.copy())

    async def write(self, command: str, **kwargs) -> DriverResult:
        """
        Send a raw command string to the relay module.

        The string is encoded with ``str.encode()`` (UTF-8) and written
        as-is. The tracked relay states are not updated.
        """
        if not self.is_connected:
            return DriverResult(success=False, error="Not connected")

        try:
            self._serial.write(command.encode())
            return DriverResult(success=True)
        except Exception as e:
            return DriverResult(success=False, error=str(e))

    async def set_relay(self, channel: int, state: bool) -> DriverResult:
        """
        Set single relay state.

        Args:
            channel: Relay channel (1-indexed)
            state: True = ON, False = OFF

        Returns:
            DriverResult with value {"channel": ..., "state": ...} on
            success. The tracked state changes only if the frame was written.
        """
        if not self.is_connected:
            return DriverResult(success=False, error="Not connected")

        if channel < 1 or channel > self._num_channels:
            return DriverResult(
                success=False,
                error=f"Invalid channel {channel}, must be 1-{self._num_channels}"
            )

        try:
            self._serial.write(self._build_command(channel, state))
        except Exception as e:
            return DriverResult(success=False, error=str(e))

        self._relay_states[channel - 1] = state
        return DriverResult(
            success=True,
            value={"channel": channel, "state": state}
        )

    async def get_relay(self, channel: int) -> DriverResult:
        """
        Get the last commanded state of a single relay.

        Args:
            channel: Relay channel (1-indexed)
        """
        if channel < 1 or channel > self._num_channels:
            return DriverResult(
                success=False,
                error=f"Invalid channel {channel}"
            )

        return DriverResult(
            success=True,
            value=self._relay_states[channel - 1]
        )

    async def set_all(self, states: List[bool]) -> DriverResult:
        """
        Set all relay states at once.

        Channels are switched in order. On the first failure, that failure
        is returned; channels before it have already been switched.

        Args:
            states: List of states for each channel (index 0 = channel 1)
        """
        if len(states) != self._num_channels:
            return DriverResult(
                success=False,
                error=f"Expected {self._num_channels} states, got {len(states)}"
            )

        for channel, state in enumerate(states, start=1):
            result = await self.set_relay(channel, state)
            if not result.success:
                return result

        return DriverResult(success=True, value=list(states))

    async def reset(self) -> DriverResult:
        """
        Turn off all relays.

        Every channel is attempted even if an earlier one fails, so as many
        relays as possible end up OFF. The first failure, if any, is returned.
        """
        if not self.is_connected:
            return DriverResult(success=False, error="Not connected")

        try:
            first_failure: Optional[DriverResult] = None
            for channel in range(1, self._num_channels + 1):
                result = await self.set_relay(channel, False)
                if not result.success and first_failure is None:
                    first_failure = result

            if first_failure is not None:
                return first_failure
            return DriverResult(success=True)

        except Exception as e:
            return DriverResult(success=False, error=str(e))

    async def selftest(self) -> DriverResult:
        """
        Test the module by pulsing each relay ON for 0.1 s, then OFF.

        Stops at the first channel whose command fails and returns that
        failure. A relay switched ON is always switched back OFF, including
        when the task is cancelled during the pause.
        """
        if not self.is_connected:
            return DriverResult(success=False, error="Not connected")

        try:
            for channel in range(1, self._num_channels + 1):
                on_result = await self.set_relay(channel, True)
                if not on_result.success:
                    return on_result
                try:
                    await asyncio.sleep(0.1)
                finally:
                    off_result = await self.set_relay(channel, False)
                if not off_result.success:
                    return off_result

            return DriverResult(success=True, value={"tested_channels": self._num_channels})

        except Exception as e:
            return DriverResult(success=False, error=str(e))
