"""
SeqMaster Runtime - Database Repository
Layer: DATABASE

Repository pattern for database operations.
Provides clean interface for CRUD operations.
"""

from datetime import datetime
from typing import List, Optional, Dict, Any
import hashlib
import json

from sqlalchemy import select, update, delete, desc
from sqlalchemy.ext.asyncio import AsyncSession

from src.database.models import (
    User, UserRole, Sequence, PropertySet,
    TestSession, StepResult, TestStatus, StepStatus,
    AuditLog, SystemState
)


class UserRepository:
    """Repository for user operations.

    Wraps an ``AsyncSession``. Write methods flush or execute statements but
    never commit, so the transaction boundary is owned by the caller.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, username: str, password_hash: str,
                     role: UserRole, **kwargs) -> User:
        """Create and flush a new user.

        Args:
            username: Login name to store.
            password_hash: Password hash to store as-is (no hashing is done here).
            role: Role assigned to the new user.
            **kwargs: Additional ``User`` column values passed to the constructor.

        Returns:
            The new ``User``. It has been flushed, so database-generated values
            (e.g. an autoincrement primary key) are populated.
        """
        user = User(
            username=username,
            password_hash=password_hash,
            role=role,
            **kwargs
        )
        self.session.add(user)
        await self.session.flush()
        return user

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with primary key ``user_id``, or None."""
        result = await self.session.execute(
            select(User).where(User.id == user_id)
        )
        return result.scalar_one_or_none()

    async def get_by_username(self, username: str) -> Optional[User]:
        """Return the user named ``username``, or None."""
        result = await self.session.execute(
            select(User).where(User.username == username)
        )
        return result.scalar_one_or_none()

    async def get_all_active(self) -> List[User]:
        """Return all users with ``active`` set, ordered by username.

        Ordering matches ``get_all`` so listings are deterministic rather than
        depending on the database's unspecified row order.
        """
        result = await self.session.execute(
            select(User)
            .where(User.active == True)  # noqa: E712 - builds a SQL expression
            .order_by(User.username)
        )
        return list(result.scalars().all())

    async def get_all(self) -> List[User]:
        """Return every user (active or not), ordered by username."""
        result = await self.session.execute(
            select(User).order_by(User.username)
        )
        return list(result.scalars().all())

    async def update_last_login(self, user_id: int) -> None:
        """Set ``last_login`` for ``user_id`` to the current naive UTC time.

        Issues a bulk UPDATE; no error is raised if the user does not exist.
        """
        await self.session.execute(
            update(User)
            .where(User.id == user_id)
            .values(last_login=datetime.utcnow())
        )


class SequenceRepository:
    """Repository for sequence operations.

    Each row is one version of a sequence identified by ``sequence_id``. The
    draft workflow below uses ``status`` values ``'draft'`` and
    ``'published'``; deletion is soft, via the ``active`` flag. Methods flush
    or execute statements but never commit; the caller owns the transaction.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, sequence_id: str, version: str, name: str,
                     content: str, **kwargs) -> Sequence:
        """Create and flush a new sequence version.

        ``content_hash`` is the SHA-256 hex digest of the UTF-8-encoded
        ``content``. Extra ``Sequence`` column values may be given via
        ``**kwargs``.
        """
        content_hash = hashlib.sha256(content.encode()).hexdigest()

        sequence = Sequence(
            sequence_id=sequence_id,
            version=version,
            name=name,
            content=content,
            content_hash=content_hash,
            **kwargs
        )
        self.session.add(sequence)
        await self.session.flush()
        return sequence

    async def get_latest(self, sequence_id: str) -> Optional[Sequence]:
        """Return the most recently created active row for ``sequence_id``.

        NOTE(review): no ``status`` filter is applied, so a draft created after
        the last published version is what gets returned here - confirm that
        callers expect drafts to be included.
        """
        result = await self.session.execute(
            select(Sequence)
            .where(Sequence.sequence_id == sequence_id)
            .where(Sequence.active == True)  # noqa: E712 - SQL expression
            .order_by(desc(Sequence.created_at))
            .limit(1)
        )
        return result.scalar_one_or_none()

    async def get_by_version(self, sequence_id: str, version: str) -> Optional[Sequence]:
        """Return the row matching ``sequence_id`` and ``version``, or None.

        NOTE(review): ``scalar_one_or_none`` raises ``MultipleResultsFound`` if
        more than one row matches - e.g. a draft saved with the same version
        string as a published row, unless a unique constraint prevents it.
        """
        result = await self.session.execute(
            select(Sequence)
            .where(Sequence.sequence_id == sequence_id)
            .where(Sequence.version == version)
        )
        return result.scalar_one_or_none()

    async def get_all(self) -> List[Sequence]:
        """Return every active sequence row, newest first.

        Rows are NOT collapsed to one per ``sequence_id``: all active versions,
        drafts included, are returned.
        """
        result = await self.session.execute(
            select(Sequence)
            .where(Sequence.active == True)  # noqa: E712 - SQL expression
            .order_by(desc(Sequence.created_at))
        )
        return list(result.scalars().all())

    async def get_versions(self, sequence_id: str) -> List[Sequence]:
        """Return the published versions of a sequence, newest first (no drafts)."""
        result = await self.session.execute(
            select(Sequence)
            .where(Sequence.sequence_id == sequence_id)
            .where(Sequence.status == 'published')
            .order_by(desc(Sequence.created_at))
        )
        return list(result.scalars().all())

    async def get_draft(self, sequence_id: str) -> Optional[Sequence]:
        """Return the draft row for ``sequence_id``, or None.

        If several draft rows exist, an arbitrary one is returned (no ordering
        is applied before ``LIMIT 1``).
        """
        result = await self.session.execute(
            select(Sequence)
            .where(Sequence.sequence_id == sequence_id)
            .where(Sequence.status == 'draft')
            .limit(1)
        )
        return result.scalar_one_or_none()

    async def save_draft(self, sequence_id: str, name: str, content: str,
                         version: str, **kwargs) -> Sequence:
        """Save the draft for ``sequence_id``, overwriting any existing draft.

        When a draft exists, its name/content/hash/version are replaced in
        place, and each ``kwargs`` item is applied only if the row has that
        attribute (unknown keys are ignored). Otherwise a new row with
        ``status='draft'`` is created from the same values.

        Returns:
            The updated or newly created draft row (flushed).
        """
        content_hash = hashlib.sha256(content.encode()).hexdigest()

        existing_draft = await self.get_draft(sequence_id)

        if existing_draft:
            existing_draft.name = name
            existing_draft.content = content
            existing_draft.content_hash = content_hash
            existing_draft.version = version
            for key, value in kwargs.items():
                if hasattr(existing_draft, key):
                    setattr(existing_draft, key, value)
            await self.session.flush()
            return existing_draft

        draft = Sequence(
            sequence_id=sequence_id,
            version=version,
            name=name,
            content=content,
            content_hash=content_hash,
            status='draft',
            **kwargs
        )
        self.session.add(draft)
        await self.session.flush()
        return draft

    @staticmethod
    def _next_version(latest: Optional[str], taken: set) -> str:
        """Compute the version string to publish after ``latest``.

        Keeps the first dot-separated component of ``latest`` and increments
        the integer second component (``"1.4"`` -> ``"1.5"``). If ``latest`` is
        None or does not parse as ``<major>.<int>...``, starts from ``"1.0"``.
        The minor component is then bumped further until the result is not
        in ``taken``, so an already-published version string is never reused.
        """
        major, minor = "1", -1  # yields "1.0" when nothing can be parsed
        if latest is not None:
            parts = latest.split('.')
            try:
                major, minor = parts[0], int(parts[1])
            except (IndexError, ValueError):
                pass
        candidate = f"{major}.{minor + 1}"
        while candidate in taken:
            minor += 1
            candidate = f"{major}.{minor + 1}"
        return candidate

    async def publish_draft(self, sequence_id: str) -> Optional[Sequence]:
        """Promote the current draft to a new published version.

        The version is derived from the newest published row (by
        ``created_at``) via ``_next_version``, which also guarantees the
        result does not duplicate an existing published version string.

        Returns:
            The now-published row, or None if there is no draft.
        """
        draft = await self.get_draft(sequence_id)
        if not draft:
            return None

        published = await self.get_versions(sequence_id)
        latest = published[0].version if published else None
        new_version = self._next_version(
            latest, {row.version for row in published}
        )

        # The draft row itself becomes the published version (no copy made).
        draft.status = 'published'
        draft.version = new_version
        draft.published_at = datetime.utcnow()
        await self.session.flush()
        return draft

    async def delete_draft(self, sequence_id: str) -> bool:
        """Hard-delete the draft row, leaving published versions untouched.

        Returns:
            True if a draft was deleted, False if none existed.
        """
        draft = await self.get_draft(sequence_id)
        if draft:
            await self.session.delete(draft)
            await self.session.flush()
            return True
        return False

    async def delete_sequence(self, sequence_id: str) -> bool:
        """Soft-delete a sequence by marking all of its rows inactive.

        Every row for ``sequence_id`` - published versions and drafts alike -
        gets ``active=False``; nothing is physically removed.

        Returns:
            True if at least one row matched ``sequence_id``, False if the
            sequence does not exist (consistent with ``delete_draft``).
        """
        result = await self.session.execute(
            update(Sequence)
            .where(Sequence.sequence_id == sequence_id)
            .values(active=False)
        )
        return result.rowcount > 0


class PropertySetRepository:
    """Data-access helpers for versioned ``PropertySet`` rows.

    Several rows may share a ``property_set_id``; they are told apart by
    ``version`` and ``created_at``. Nothing here commits.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def get_all(self) -> List[PropertySet]:
        """Return the newest active row of every property set, sorted by name.

        The newest row per ``property_set_id`` is selected by joining against a
        grouped subquery of ``max(created_at)`` over active rows.

        NOTE(review): rows sharing the exact max ``created_at`` for one id
        would all be returned - confirm timestamps are unique per id.
        """
        from sqlalchemy import func

        newest = (
            select(
                PropertySet.property_set_id,
                func.max(PropertySet.created_at).label('max_created'),
            )
            .where(PropertySet.active == True)  # noqa: E712 - SQL expression
            .group_by(PropertySet.property_set_id)
            .subquery()
        )
        same_set = PropertySet.property_set_id == newest.c.property_set_id
        is_newest = PropertySet.created_at == newest.c.max_created

        stmt = (
            select(PropertySet)
            .join(newest, same_set & is_newest)
            .order_by(PropertySet.name)
        )
        rows = await self.session.execute(stmt)
        return list(rows.scalars().all())

    async def create(self, property_set_id: str, version: str, name: str,
                     content: str, **kwargs) -> PropertySet:
        """Insert and flush a new property-set version.

        ``content_hash`` is the SHA-256 hex digest of the UTF-8 ``content``.
        """
        digest = hashlib.sha256(content.encode()).hexdigest()

        new_row = PropertySet(
            property_set_id=property_set_id,
            version=version,
            name=name,
            content=content,
            content_hash=digest,
            **kwargs
        )
        self.session.add(new_row)
        await self.session.flush()
        return new_row

    async def get_latest(self, property_set_id: str) -> Optional[PropertySet]:
        """Return the most recently created active row for ``property_set_id``."""
        stmt = (
            select(PropertySet)
            .where(PropertySet.property_set_id == property_set_id)
            .where(PropertySet.active == True)  # noqa: E712 - SQL expression
            .order_by(desc(PropertySet.created_at))
            .limit(1)
        )
        rows = await self.session.execute(stmt)
        return rows.scalar_one_or_none()

    async def get_for_dut(self, dut_type: str, dut_revision: Optional[str] = None) -> Optional[PropertySet]:
        """Return the newest active property set for a DUT type.

        ``dut_revision`` narrows the match only when it is truthy; None or an
        empty string matches any revision.
        """
        conditions = [
            PropertySet.dut_type == dut_type,
            PropertySet.active == True,  # noqa: E712 - SQL expression
        ]
        if dut_revision:
            conditions.append(PropertySet.dut_revision == dut_revision)

        stmt = (
            select(PropertySet)
            .where(*conditions)
            .order_by(desc(PropertySet.created_at))
            .limit(1)
        )
        rows = await self.session.execute(stmt)
        return rows.scalar_one_or_none()


class TestSessionRepository:
    """Data-access helpers for ``TestSession`` rows, addressed by ``session_id``.

    Read methods eager-load ``TestSession.sequence`` with a joined load.
    Nothing here commits.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, session_id: str, dut_id: str, **kwargs) -> TestSession:
        """Insert and flush a new test session; extra columns go in ``kwargs``."""
        row = TestSession(session_id=session_id, dut_id=dut_id, **kwargs)
        self.session.add(row)
        await self.session.flush()
        return row

    async def get_by_session_id(self, session_id: str) -> Optional[TestSession]:
        """Return the test session with this ``session_id`` (sequence loaded), or None."""
        from sqlalchemy.orm import joinedload

        stmt = (
            select(TestSession)
            .options(joinedload(TestSession.sequence))
            .where(TestSession.session_id == session_id)
        )
        outcome = await self.session.execute(stmt)
        return outcome.scalar_one_or_none()

    async def get_by_dut(self, dut_id: str, limit: int = 100) -> List[TestSession]:
        """Return up to ``limit`` sessions for ``dut_id``, most recently started first."""
        from sqlalchemy.orm import joinedload

        stmt = (
            select(TestSession)
            .options(joinedload(TestSession.sequence))
            .where(TestSession.dut_id == dut_id)
            .order_by(desc(TestSession.started_at))
            .limit(limit)
        )
        outcome = await self.session.execute(stmt)
        return list(outcome.unique().scalars().all())

    async def get_recent(self, limit: int = 50) -> List[TestSession]:
        """Return up to ``limit`` sessions across all DUTs, most recently started first."""
        from sqlalchemy.orm import joinedload

        stmt = (
            select(TestSession)
            .options(joinedload(TestSession.sequence))
            .order_by(desc(TestSession.started_at))
            .limit(limit)
        )
        outcome = await self.session.execute(stmt)
        return list(outcome.unique().scalars().all())

    async def update_status(self, session_id: str, status: TestStatus,
                           **kwargs) -> None:
        """Set ``status`` (plus any extra columns in ``kwargs``) on a session.

        For PASSED, FAILED, ERROR and ABORTED, ``completed_at`` is also set to
        the current naive UTC time, overriding any ``completed_at`` in kwargs.
        """
        terminal = (
            TestStatus.PASSED,
            TestStatus.FAILED,
            TestStatus.ERROR,
            TestStatus.ABORTED,
        )
        changes = {"status": status}
        changes.update(kwargs)
        if status in terminal:
            changes["completed_at"] = datetime.utcnow()

        stmt = (
            update(TestSession)
            .where(TestSession.session_id == session_id)
            .values(**changes)
        )
        await self.session.execute(stmt)

    async def update_progress(self, session_id: str, current_step_index: int,
                             current_step_id: str, progress_percent: float) -> None:
        """Record the current step and progress percentage for a session."""
        stmt = (
            update(TestSession)
            .where(TestSession.session_id == session_id)
            .values(
                current_step_index=current_step_index,
                current_step_id=current_step_id,
                progress_percent=progress_percent,
            )
        )
        await self.session.execute(stmt)


class StepResultRepository:
    """Data-access helpers for ``StepResult`` rows belonging to a test session."""

    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, session_id: int, step_id: str,
                     step_index: int, **kwargs) -> StepResult:
        """Insert and flush a step result.

        ``session_id`` is the integer key of the owning test session row.
        """
        row = StepResult(
            session_id=session_id,
            step_id=step_id,
            step_index=step_index,
            **kwargs
        )
        self.session.add(row)
        await self.session.flush()
        return row

    async def get_by_session(
        self,
        test_session_id: int,
        limit: Optional[int] = None,
        newest_first: bool = False
    ) -> List[StepResult]:
        """Return the step results of one test session ordered by ``step_index``.

        Args:
            test_session_id: Integer key of the owning test session.
            limit: Maximum rows to return; any falsy value (None or 0) means
                no limit.
            newest_first: Sort by descending ``step_index`` instead of ascending.
        """
        if newest_first:
            ordering = desc(StepResult.step_index)
        else:
            ordering = StepResult.step_index

        stmt = (
            select(StepResult)
            .where(StepResult.session_id == test_session_id)
            .order_by(ordering)
        )
        if limit:
            stmt = stmt.limit(limit)

        rows = await self.session.execute(stmt)
        return list(rows.scalars().all())

    async def update_result(self, step_result_id: int, status: StepStatus,
                           measured_value: Optional[float] = None,
                           **kwargs) -> None:
        """Store the outcome of a step.

        ``completed_at`` defaults to the current naive UTC time but can be
        overridden through ``kwargs``. ``measured_value`` is written only when
        it is not None.
        """
        changes: Dict[str, Any] = {
            "status": status,
            "completed_at": datetime.utcnow(),
        }
        changes.update(kwargs)
        if measured_value is not None:
            changes["measured_value"] = measured_value

        stmt = (
            update(StepResult)
            .where(StepResult.id == step_result_id)
            .values(**changes)
        )
        await self.session.execute(stmt)


class AuditLogRepository:
    """Data-access helpers for ``AuditLog`` entries."""

    def __init__(self, session: AsyncSession):
        self.session = session

    async def log(self, action: str, username: Optional[str] = None,
                  user_id: Optional[int] = None, **kwargs) -> AuditLog:
        """Insert and flush one audit entry.

        Args:
            action: Description of the audited action.
            username: Acting user's name, if known.
            user_id: Acting user's id, if known.
            **kwargs: Extra ``AuditLog`` column values.
        """
        entry = AuditLog(action=action, username=username, user_id=user_id, **kwargs)
        self.session.add(entry)
        await self.session.flush()
        return entry

    async def get_recent(self, limit: int = 100) -> List[AuditLog]:
        """Return up to ``limit`` audit entries, newest ``timestamp`` first."""
        stmt = select(AuditLog).order_by(desc(AuditLog.timestamp)).limit(limit)
        rows = await self.session.execute(stmt)
        return list(rows.scalars().all())


class SystemStateRepository:
    """Key/value access to ``SystemState`` rows (one row per ``key``)."""

    def __init__(self, session: AsyncSession):
        self.session = session

    async def _lookup(self, key: str) -> Optional[SystemState]:
        """Fetch the ``SystemState`` row stored under ``key``, or None."""
        found = await self.session.execute(
            select(SystemState).where(SystemState.key == key)
        )
        return found.scalar_one_or_none()

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored under ``key``, or None if absent."""
        row = await self._lookup(key)
        if not row:
            return None
        return row.value

    async def set(self, key: str, value: str) -> None:
        """Upsert ``key`` -> ``value`` and flush.

        An existing row gets its value replaced and ``updated_at`` refreshed
        (naive UTC); otherwise a new row is added. The read-then-write is not
        atomic across concurrent sessions.
        """
        row = await self._lookup(key)
        if not row:
            self.session.add(SystemState(key=key, value=value))
        else:
            row.value = value
            row.updated_at = datetime.utcnow()
        await self.session.flush()

    async def delete(self, key: str) -> None:
        """Delete the row stored under ``key``; a missing key is not an error."""
        stmt = delete(SystemState).where(SystemState.key == key)
        await self.session.execute(stmt)
