"""Bulk operations router for mass machine operations.""" import asyncio import socket import time from datetime import datetime, timezone from types import SimpleNamespace from typing import Dict, List from uuid import UUID import structlog from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from core.middleware import get_current_user from core.models import ( BulkHealthCheckRequest, BulkHealthCheckResponse, BulkHealthCheckResult, BulkSSHCommandRequest, BulkSSHCommandResponse, BulkSSHCommandResult, UserRole, ) from core.permissions import PermissionChecker from core.saved_machines_db import saved_machines_db from core.audit_logger import immutable_audit_logger logger = structlog.get_logger(__name__) security = HTTPBearer() bulk_router = APIRouter(prefix="/api/bulk", tags=["Bulk Operations"]) ROLE_HEALTH_CHECK_LIMITS = { UserRole.GUEST: 10, UserRole.USER: 50, UserRole.ADMIN: 200, UserRole.SUPER_ADMIN: 200, } ROLE_SSH_COMMAND_LIMITS = { UserRole.GUEST: 0, UserRole.USER: 20, UserRole.ADMIN: 100, UserRole.SUPER_ADMIN: 100, } async def check_host_availability( hostname: str, port: int = 22, timeout: int = 5 ) -> tuple[bool, float | None, str | None]: """Check if host is available via TCP connection.""" start_time = time.time() try: reader, writer = await asyncio.wait_for( asyncio.open_connection(hostname, port), timeout=timeout ) writer.close() await writer.wait_closed() response_time = (time.time() - start_time) * 1000 return True, response_time, None except asyncio.TimeoutError: return False, None, "Connection timeout" except socket.gaierror: return False, None, "DNS resolution failed" except ConnectionRefusedError: return False, None, "Connection refused" except Exception as e: return False, None, f"Connection error: {str(e)}" @bulk_router.post( "/health-check", response_model=BulkHealthCheckResponse, summary="Bulk health check", description="Check availability of multiple machines in parallel" ) async def bulk_health_check( request_data: BulkHealthCheckRequest, request: Request, credentials: HTTPAuthorizationCredentials = Depends(security), ): """Bulk machine availability check with role-based limits.""" user_info = get_current_user(request) if not user_info: raise HTTPException(status_code=401, detail="Authentication required") username = user_info["username"] user_role = UserRole(user_info["role"]) client_ip = request.client.host if request.client else "unknown" max_machines = ROLE_HEALTH_CHECK_LIMITS.get(user_role, 10) machine_count = len(request_data.machine_ids) if machine_count > max_machines: logger.warning( "Bulk health check limit exceeded", username=username, role=user_role.value, requested=machine_count, limit=max_machines, ) raise HTTPException( status_code=403, detail=f"Role {user_role.value} can check max {max_machines} machines at once", ) logger.info( "Bulk health check started", username=username, machine_count=machine_count, timeout=request_data.timeout, ) started_at = datetime.now(timezone.utc) start_time = time.time() machines = [] for machine_id in request_data.machine_ids: # Try to get from saved machines first (UUID format) try: UUID(machine_id) machine_dict = saved_machines_db.get_machine_by_id(machine_id, username) if machine_dict: # Convert dict to object with attributes for uniform access machine = SimpleNamespace( id=machine_dict['id'], name=machine_dict['name'], ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')), hostname=machine_dict.get('hostname', 'unknown'), ) machines.append(machine) continue except (ValueError, AttributeError): # Not a UUID pass logger.warning( "Machine not found or invalid UUID", username=username, machine_id=machine_id, ) async def check_machine(machine): checked_at = datetime.now(timezone.utc).isoformat() try: available, response_time, error = await check_host_availability( machine.ip, timeout=request_data.timeout ) return BulkHealthCheckResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="success" if available else "failed", available=available, response_time_ms=int(response_time) if response_time else None, error=error, checked_at=checked_at, ) except Exception as e: logger.error( "Health check error", machine_id=str(machine.id), error=str(e) ) return BulkHealthCheckResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="failed", available=False, error=str(e), checked_at=checked_at, ) results = await asyncio.gather(*[check_machine(m) for m in machines]) completed_at = datetime.now(timezone.utc) execution_time_ms = int((time.time() - start_time) * 1000) success_count = sum(1 for r in results if r.status == "success") failed_count = len(results) - success_count available_count = sum(1 for r in results if r.available) unavailable_count = len(results) - available_count immutable_audit_logger.log_security_event( event_type="bulk_health_check", client_ip=client_ip, user_agent=request.headers.get("user-agent", "unknown"), details={ "machine_count": len(results), "available": available_count, "unavailable": unavailable_count, "execution_time_ms": execution_time_ms, }, severity="info", username=username, ) logger.info( "Bulk health check completed", username=username, total=len(results), available=available_count, execution_time_ms=execution_time_ms, ) return BulkHealthCheckResponse( total=len(results), success=success_count, failed=failed_count, available=available_count, unavailable=unavailable_count, results=results, execution_time_ms=execution_time_ms, started_at=started_at.isoformat(), completed_at=completed_at.isoformat(), ) @bulk_router.post( "/ssh-command", response_model=BulkSSHCommandResponse, summary="Bulk SSH command", description="Execute SSH commands on multiple machines in parallel" ) async def bulk_ssh_command( request_data: BulkSSHCommandRequest, request: Request, credentials: HTTPAuthorizationCredentials = Depends(security), ): """Bulk SSH command execution with role-based limits.""" user_info = get_current_user(request) if not user_info: raise HTTPException(status_code=401, detail="Authentication required") username = user_info["username"] user_role = UserRole(user_info["role"]) client_ip = request.client.host if request.client else "unknown" if user_role == UserRole.GUEST: raise HTTPException( status_code=403, detail="GUEST role cannot execute SSH commands" ) max_machines = ROLE_SSH_COMMAND_LIMITS.get(user_role, 0) machine_count = len(request_data.machine_ids) if machine_count > max_machines: logger.warning( "Bulk SSH command limit exceeded", username=username, role=user_role.value, requested=machine_count, limit=max_machines, ) raise HTTPException( status_code=403, detail=f"Role {user_role.value} can execute commands on max {max_machines} machines at once", ) logger.info( "Bulk SSH command started", username=username, machine_count=machine_count, command=request_data.command[:50], mode=request_data.credentials_mode, ) started_at = datetime.now(timezone.utc) start_time = time.time() machines = [] for machine_id in request_data.machine_ids: # Try to get from saved machines first (UUID format) try: UUID(machine_id) machine_dict = saved_machines_db.get_machine_by_id(machine_id, username) if machine_dict: # Convert dict to object with attributes for uniform access machine = SimpleNamespace( id=machine_dict['id'], name=machine_dict['name'], ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')), hostname=machine_dict.get('hostname', 'unknown'), ) machines.append(machine) continue except (ValueError, AttributeError): # Not a UUID, check if hostname provided pass # Check if hostname provided for non-saved machine (mock machines) if request_data.machine_hostnames and machine_id in request_data.machine_hostnames: hostname = request_data.machine_hostnames[machine_id] # Create mock machine object for non-saved machines mock_machine = SimpleNamespace( id=machine_id, name=f'Mock-{machine_id}', ip=hostname, hostname=hostname, ) machines.append(mock_machine) logger.info( "Using non-saved machine (mock)", username=username, machine_id=machine_id, hostname=hostname, ) continue logger.warning( "Machine not found and no hostname provided", username=username, machine_id=machine_id, ) semaphore = asyncio.Semaphore(10) async def execute_command(machine): async with semaphore: executed_at = datetime.now(timezone.utc).isoformat() cmd_start = time.time() try: ssh_username = None ssh_password = None if request_data.credentials_mode == "global": if not request_data.global_credentials: return BulkSSHCommandResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="no_credentials", error="Global credentials not provided", executed_at=executed_at, ) ssh_username = request_data.global_credentials.username ssh_password = request_data.global_credentials.password else: # custom mode if not request_data.machine_credentials or str( machine.id ) not in request_data.machine_credentials: return BulkSSHCommandResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="no_credentials", error="Custom credentials not provided for this machine", executed_at=executed_at, ) creds = request_data.machine_credentials[str(machine.id)] ssh_username = creds.username ssh_password = creds.password if not ssh_username or not ssh_password: return BulkSSHCommandResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="no_credentials", error="Credentials missing", executed_at=executed_at, ) import paramiko ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) await asyncio.wait_for( asyncio.get_event_loop().run_in_executor( None, lambda: ssh.connect( machine.ip, username=ssh_username, password=ssh_password, timeout=request_data.timeout, ), ), timeout=request_data.timeout, ) stdin, stdout, stderr = ssh.exec_command(request_data.command) stdout_text = stdout.read().decode("utf-8", errors="ignore") stderr_text = stderr.read().decode("utf-8", errors="ignore") exit_code = stdout.channel.recv_exit_status() ssh.close() execution_time = int((time.time() - cmd_start) * 1000) return BulkSSHCommandResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="success" if exit_code == 0 else "failed", exit_code=exit_code, stdout=stdout_text[:5000], stderr=stderr_text[:5000], execution_time_ms=execution_time, executed_at=executed_at, ) except asyncio.TimeoutError: return BulkSSHCommandResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="timeout", error="Command execution timeout", executed_at=executed_at, ) except Exception as e: logger.error( "SSH command error", machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, error=str(e), error_type=type(e).__name__ ) return BulkSSHCommandResult( machine_id=str(machine.id), machine_name=machine.name, hostname=machine.ip, status="failed", error=str(e)[:500], executed_at=executed_at, ) results = await asyncio.gather(*[execute_command(m) for m in machines]) completed_at = datetime.now(timezone.utc) execution_time_ms = int((time.time() - start_time) * 1000) success_count = sum(1 for r in results if r.status == "success") failed_count = len(results) - success_count immutable_audit_logger.log_security_event( event_type="bulk_ssh_command", client_ip=client_ip, user_agent=request.headers.get("user-agent", "unknown"), details={ "machine_count": len(results), "command": request_data.command[:100], "credentials_mode": request_data.credentials_mode, "success": success_count, "failed": failed_count, "execution_time_ms": execution_time_ms, }, severity="high", username=username, ) logger.info( "Bulk SSH command completed", username=username, total=len(results), success=success_count, failed=failed_count, execution_time_ms=execution_time_ms, ) return BulkSSHCommandResponse( total=len(results), success=success_count, failed=failed_count, results=results, execution_time_ms=execution_time_ms, command=request_data.command, started_at=started_at.isoformat(), completed_at=completed_at.isoformat(), )