"""
SeqMaster Runtime - Expression Resolver
Layer: EXECUTOR

Safe expression evaluation for variable references and calculations.
Uses the simpleeval library when it is installed (no arbitrary code execution); without it, a restricted eval() fallback is used, which is less safe.

Supports:
- Variable references: Locals.voltage, Station.fixture_id
- Math operations: Locals.raw * 1.5 + 0.05
- Comparisons: Locals.voltage > 11.8 and Locals.voltage < 12.2
- Built-in functions: Len(), Sum(), Avg(), Min(), Max(), Abs(), Round(), Str(), Int(), Float(), Bool() (lowercase aliases such as len() are also accepted)
- Array indexing: Locals.readings[0], Locals.readings[RunState.LoopIndex]
"""

import ast
import json
import re
from typing import Any, Callable, Dict, List, Optional, Union

import structlog

# simpleeval is an optional dependency. When it cannot be imported,
# HAS_SIMPLEEVAL is False and ExpressionResolver uses its restricted eval()
# fallback instead.
# NOTE(review): DEFAULT_FUNCTIONS and DEFAULT_OPERATORS are imported but not
# referenced in this file, and they stay undefined when the import fails.
try:
    from simpleeval import SimpleEval, DEFAULT_FUNCTIONS, DEFAULT_OPERATORS
    HAS_SIMPLEEVAL = True
except ImportError:
    HAS_SIMPLEEVAL = False
    SimpleEval = None  # keep the name bound so later references still resolve

logger = structlog.get_logger(__name__)


# Built-in functions available in expressions
def _len(obj: Any) -> int:
    """Return length of a collection."""
    return len(obj) if hasattr(obj, '__len__') else 0


def _sum(values: List[Union[int, float]]) -> float:
    """Return sum of numeric values."""
    return sum(values) if values else 0.0


def _avg(values: List[Union[int, float]]) -> float:
    """Return average of numeric values."""
    if not values:
        return 0.0
    return sum(values) / len(values)


def _min_val(values: List[Union[int, float]]) -> float:
    """Return minimum value."""
    return min(values) if values else 0.0


def _max_val(values: List[Union[int, float]]) -> float:
    """Return maximum value."""
    return max(values) if values else 0.0


def _abs_val(value: Union[int, float]) -> float:
    """Return absolute value."""
    return abs(value)


def _round_val(value: Union[int, float], digits: int = 0) -> float:
    """Round to specified digits."""
    return round(value, digits)


def _str_val(value: Any) -> str:
    """Convert to string."""
    return str(value)


def _int_val(value: Any) -> int:
    """Convert to integer."""
    return int(value)


def _float_val(value: Any) -> float:
    """Convert to float."""
    return float(value)


def _bool_val(value: Any) -> bool:
    """Convert to boolean."""
    return bool(value)


# Canonical (capitalized) names of the functions callable from expressions.
_CANONICAL_EXPRESSION_FUNCTIONS: Dict[str, Callable] = {
    "Len": _len,
    "Sum": _sum,
    "Avg": _avg,
    "Min": _min_val,
    "Max": _max_val,
    "Abs": _abs_val,
    "Round": _round_val,
    "Str": _str_val,
    "Int": _int_val,
    "Float": _float_val,
    "Bool": _bool_val,
}

# Available functions in expressions: every canonical name followed by a
# lowercase alias for each (e.g. both "Avg" and "avg").
EXPRESSION_FUNCTIONS: Dict[str, Callable] = {
    **_CANONICAL_EXPRESSION_FUNCTIONS,
    **{name.lower(): func for name, func in _CANONICAL_EXPRESSION_FUNCTIONS.items()},
}


class ExpressionResolver:
    """
    Safe expression evaluator for sequence expressions.

    Literal inputs (true/false, numbers, and simply-quoted strings) are parsed
    directly and cached at class level. Anything else is evaluated against the
    supplied namespace with simpleeval when it is installed, or with a
    restricted ``eval()`` fallback otherwise. Evaluation errors are logged and
    replaced by the caller-supplied ``default``.

    Usage:
        resolver = ExpressionResolver()

        # With variable store
        namespace = store.get_expression_namespace()

        # Simple reference
        value = resolver.evaluate("Locals.voltage", namespace)

        # Math expression
        adjusted = resolver.evaluate(
            "Locals.raw * Locals.calibration.gain + Locals.calibration.offset",
            namespace
        )

        # Condition
        passed = resolver.evaluate(
            "Locals.voltage > 11.8 and Locals.voltage < 12.2",
            namespace
        )

        # With functions
        avg = resolver.evaluate("Avg(Locals.readings)", namespace)
        count = resolver.evaluate("Len(Locals.readings)", namespace)
    """

    # Class-level caches, shared by every resolver instance. Once a cache
    # reaches _cache_max_size it simply stops accepting new entries.
    _expression_type_cache: Dict[str, bool] = {}  # expression -> is_literal
    _literal_cache: Dict[str, Any] = {}  # expression -> parsed literal value
    _compiled_cache: Dict[str, Any] = {}  # expression -> compiled code object (not referenced in this class)
    _cache_max_size: int = 1000

    def __init__(self):
        """Initialize expression resolver."""
        self._evaluator: Optional[SimpleEval] = None

        if HAS_SIMPLEEVAL:
            self._evaluator = SimpleEval()
            self._evaluator.functions.update(EXPRESSION_FUNCTIONS)
            logger.debug("Expression resolver initialized with simpleeval")
        else:
            logger.warning(
                "simpleeval not installed, using restricted eval. "
                "Install simpleeval for safer expression evaluation."
            )

    def evaluate(
        self,
        expression: str,
        namespace: Dict[str, Any],
        default: Any = None
    ) -> Any:
        """
        Evaluate an expression with given namespace.

        Args:
            expression: Expression string
            namespace: Variable namespace (from VariableStore.get_expression_namespace())
            default: Value returned when the expression is empty or evaluation fails

        Returns:
            Evaluated result or default on error
        """
        if not expression:
            return default

        # Fast path: literals do not depend on the namespace, so their parsed
        # value can be reused across calls.
        if expression in self._literal_cache:
            return self._literal_cache[expression]

        # Classify once per distinct expression string.
        is_lit = self._expression_type_cache.get(expression)
        if is_lit is None:
            is_lit = self._is_literal(expression)
            if len(self._expression_type_cache) < self._cache_max_size:
                self._expression_type_cache[expression] = is_lit

        if is_lit:
            result = self._parse_literal(expression)
            if len(self._literal_cache) < self._cache_max_size:
                self._literal_cache[expression] = result
            return result

        try:
            if HAS_SIMPLEEVAL and self._evaluator:
                # NOTE(review): names are swapped on a shared evaluator per call,
                # so one resolver instance is not safe for concurrent use.
                self._evaluator.names = namespace
                result = self._evaluator.eval(expression)
            else:
                result = self._fallback_eval(expression, namespace)

            return result

        except Exception as e:
            # Deliberate best-effort: any evaluation failure degrades to `default`.
            logger.warning("Expression evaluation failed",
                          expression=expression,
                          error=str(e))
            return default

    # Pre-compiled regex for operator detection
    _OPERATOR_PATTERN = re.compile(r'[+\-*/()<>=!]|\band\b|\bor\b|\bnot\b')
    _SCOPE_PREFIXES = ("Locals.", "Parameters.", "Station.", "RunState.", "Step.", "Scanned.")
    # Attribute names the eval() fallback refuses, in addition to any name
    # starting with "_" (modelled on simpleeval's own restrictions).
    _FALLBACK_BLOCKED_ATTRS = frozenset({"format", "format_map", "mro"})

    def is_expression(self, value: Any) -> bool:
        """
        Check if a value is an expression that needs evaluation.

        Expressions start with a scope name (Locals, Parameters, etc.)
        or contain operators/parentheses. Non-strings are never expressions.
        """
        if not isinstance(value, str):
            return False

        value = value.strip()

        # str.startswith accepts a tuple of prefixes
        if value.startswith(self._SCOPE_PREFIXES):
            return True

        # Operators or function calls (pre-compiled regex)
        if self._OPERATOR_PATTERN.search(value):
            return True

        return False

    def resolve_inputs(
        self,
        inputs: Dict[str, Any],
        namespace: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Resolve all expressions in an inputs dictionary.

        Values that look like expressions are evaluated; if evaluation fails
        the original value is kept. Other values are copied unchanged.

        Args:
            inputs: Dictionary of input values (may contain expressions)
            namespace: Variable namespace

        Returns:
            Dictionary with all expressions evaluated
        """
        resolved = {}

        for key, value in inputs.items():
            if self.is_expression(value):
                resolved[key] = self.evaluate(value, namespace, default=value)
            else:
                resolved[key] = value

        return resolved

    @staticmethod
    def _is_quoted_string(text: str) -> bool:
        """Return True if *text* is one quoted string with no inner delimiter.

        Text such as ``"PASS" if Locals.ok else "FAIL"`` starts and ends with
        a quote but is an expression, so any occurrence of the delimiter
        between the outer quotes disqualifies it. Strings with escaped quotes
        are then left to the real evaluator, which parses them correctly.
        """
        if len(text) < 2:
            return False
        quote = text[0]
        return quote in ("'", '"') and text[-1] == quote and quote not in text[1:-1]

    def _is_literal(self, expression: str) -> bool:
        """Check if expression is a simple literal value."""
        expression = expression.strip()

        # Boolean (case-insensitive)
        if expression.lower() in ("true", "false"):
            return True

        # Number: anything float() accepts, including "1e-3", "inf", "nan"
        try:
            float(expression)
            return True
        except ValueError:
            pass

        return self._is_quoted_string(expression)

    def _parse_literal(self, expression: str) -> Any:
        """Parse a literal value previously accepted by _is_literal()."""
        expression = expression.strip()

        lowered = expression.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False

        # Prefer int so "42" stays integral; fall back to float for decimals,
        # exponents ("1e-3") and special values ("inf"), which _is_literal
        # already classifies as numbers.
        try:
            return int(expression)
        except ValueError:
            pass
        try:
            return float(expression)
        except ValueError:
            pass

        if self._is_quoted_string(expression):
            return expression[1:-1]

        return expression

    def _fallback_eval(self, expression: str, namespace: Dict[str, Any]) -> Any:
        """
        Fallback evaluation when simpleeval is not available.

        WARNING: This is less safe than simpleeval. eval() with empty
        __builtins__ is not a sandbox by itself; the AST check below rejects
        the usual escape routes (``_``-prefixed attribute access such as
        ``().__class__.__bases__``, dunder names, and ``str.format``-style
        attribute lookups). Expressions from untrusted sources should still
        only be evaluated with simpleeval installed.

        Raises:
            SyntaxError: if the expression is not a valid Python expression.
            ValueError: if the expression uses a blocked name or attribute.
        """
        tree = ast.parse(expression.strip(), mode="eval")
        for node in ast.walk(tree):
            if isinstance(node, ast.Attribute) and (
                node.attr.startswith("_") or node.attr in self._FALLBACK_BLOCKED_ATTRS
            ):
                raise ValueError(f"Access to attribute '{node.attr}' is not allowed")
            if isinstance(node, ast.Name) and node.id.startswith("__"):
                raise ValueError(f"Access to name '{node.id}' is not allowed")

        # Restricted globals: no builtins, only the expression functions
        safe_globals: Dict[str, Any] = {"__builtins__": {}}
        safe_globals.update(EXPRESSION_FUNCTIONS)

        # Copy so eval() cannot rebind entries in the caller's namespace
        safe_locals = dict(namespace)

        return eval(compile(tree, "<expression>", "eval"), safe_globals, safe_locals)

    def validate_expression(self, expression: str) -> Optional[str]:
        """
        Validate an expression's syntax without evaluating it.

        Only Python syntax is checked; names are not resolved.

        Returns:
            None if valid, error message if invalid
        """
        if not expression:
            return None

        if self._is_literal(expression):
            return None

        try:
            # Strip first: compile() (unlike eval()) rejects leading
            # whitespace with an IndentationError.
            compile(expression.strip(), "<expression>", "eval")
        except SyntaxError as e:
            return f"Syntax error: {e.msg}"
        except ValueError as e:
            # e.g. source containing null bytes on some Python versions
            return f"Syntax error: {e}"
        return None


class ExpressionBuilder:
    """
    Helper for building expressions programmatically.

    Every binary operation wraps its result in parentheses, so the generated
    string always has the same precedence as the Python code that built it.

    Usage:
        expr = ExpressionBuilder()

        # Build: (Locals.voltage * Locals.gain) + Locals.offset
        result = (expr.ref("Locals.voltage") * expr.ref("Locals.gain")
                  + expr.ref("Locals.offset"))
        print(result.build())  # "((Locals.voltage * Locals.gain) + Locals.offset)"

    Plain Python values may appear on either side of an arithmetic or
    logical operator (``1.5 * expr.ref("Locals.raw")``). Strings are emitted
    as double-quoted literals with quotes, backslashes and control characters
    escaped.

    Notes:
        - ``&`` / ``|`` produce ``and`` / ``or``. They bind tighter than
          comparisons in Python, so parenthesize the operands:
          ``(a > 1) & (b < 2)``.
        - ``==`` / ``!=`` return builders rather than booleans, so instances
          are unhashable (Python sets ``__hash__`` to None when ``__eq__`` is
          overridden).
    """

    def __init__(self, expr: str = ""):
        self._expr = expr

    def ref(self, reference: str) -> "ExpressionBuilder":
        """Create a variable reference."""
        return ExpressionBuilder(reference)

    def _binary(self, op: str, other: Any) -> "ExpressionBuilder":
        """Build ``(self <op> other)`` as a new builder."""
        return ExpressionBuilder(f"({self._expr} {op} {self._format(other)})")

    def _reflected(self, op: str, other: Any) -> "ExpressionBuilder":
        """Build ``(other <op> self)`` for reflected operators such as ``1.5 * ref``."""
        return ExpressionBuilder(f"({self._format(other)} {op} {self._expr})")

    def __add__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("+", other)

    def __radd__(self, other: Any) -> "ExpressionBuilder":
        return self._reflected("+", other)

    def __sub__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("-", other)

    def __rsub__(self, other: Any) -> "ExpressionBuilder":
        return self._reflected("-", other)

    def __mul__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("*", other)

    def __rmul__(self, other: Any) -> "ExpressionBuilder":
        return self._reflected("*", other)

    def __truediv__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("/", other)

    def __rtruediv__(self, other: Any) -> "ExpressionBuilder":
        return self._reflected("/", other)

    # Reflected comparisons need no extra methods: for ``3 < ref`` Python
    # falls back to ``ref.__gt__(3)`` automatically.
    def __gt__(self, other: Any) -> "ExpressionBuilder":
        return self._binary(">", other)

    def __lt__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("<", other)

    def __ge__(self, other: Any) -> "ExpressionBuilder":
        return self._binary(">=", other)

    def __le__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("<=", other)

    def __eq__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("==", other)

    def __ne__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("!=", other)

    def __and__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("and", other)

    def __rand__(self, other: Any) -> "ExpressionBuilder":
        return self._reflected("and", other)

    def __or__(self, other: Any) -> "ExpressionBuilder":
        return self._binary("or", other)

    def __ror__(self, other: Any) -> "ExpressionBuilder":
        return self._reflected("or", other)

    def _format(self, value: Any) -> str:
        """Render an operand as expression source text."""
        if isinstance(value, ExpressionBuilder):
            return value._expr
        if isinstance(value, str):
            # json.dumps gives a double-quoted literal whose escapes (\" \\ \n
            # \uXXXX ...) Python string syntax also accepts; plain text is
            # rendered exactly as '"text"'.
            return json.dumps(value, ensure_ascii=False)
        return str(value)

    def build(self) -> str:
        """Return the built expression string."""
        return self._expr

    def __str__(self) -> str:
        return self._expr

    def __repr__(self) -> str:
        return f"ExpressionBuilder({self._expr!r})"
