478 lines
16 KiB
Python
Executable File
478 lines
16 KiB
Python
Executable File
"""Bulk operations router for mass machine operations."""
|
|
|
|
import asyncio
|
|
import socket
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from types import SimpleNamespace
|
|
from typing import Dict, List
|
|
from uuid import UUID
|
|
|
|
import structlog
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
|
from core.middleware import get_current_user
|
|
from core.models import (
|
|
BulkHealthCheckRequest,
|
|
BulkHealthCheckResponse,
|
|
BulkHealthCheckResult,
|
|
BulkSSHCommandRequest,
|
|
BulkSSHCommandResponse,
|
|
BulkSSHCommandResult,
|
|
UserRole,
|
|
)
|
|
from core.permissions import PermissionChecker
|
|
from core.saved_machines_db import saved_machines_db
|
|
from core.audit_logger import immutable_audit_logger
|
|
|
|
# Structured logger shared by every endpoint in this module.
logger = structlog.get_logger(__name__)

# HTTP bearer security scheme; the endpoints below declare it through
# ``Depends(security)``.
security = HTTPBearer()

# Every route in this module is mounted under /api/bulk.
bulk_router = APIRouter(prefix="/api/bulk", tags=["Bulk Operations"])

# Roles that share the top tier of per-request machine limits.
_PRIVILEGED_ROLES = (UserRole.ADMIN, UserRole.SUPER_ADMIN)

# Upper bound on how many machines one /health-check request may target,
# keyed by the caller's role.
ROLE_HEALTH_CHECK_LIMITS = {
    UserRole.GUEST: 10,
    UserRole.USER: 50,
    **dict.fromkeys(_PRIVILEGED_ROLES, 200),
}

# Upper bound on how many machines one /ssh-command request may target,
# keyed by the caller's role. GUEST is 0: guests may not run commands at all.
ROLE_SSH_COMMAND_LIMITS = {
    UserRole.GUEST: 0,
    UserRole.USER: 20,
    **dict.fromkeys(_PRIVILEGED_ROLES, 100),
}
|
|
|
|
|
|
async def check_host_availability(
|
|
hostname: str, port: int = 22, timeout: int = 5
|
|
) -> tuple[bool, float | None, str | None]:
|
|
"""Check if host is available via TCP connection."""
|
|
start_time = time.time()
|
|
try:
|
|
reader, writer = await asyncio.wait_for(
|
|
asyncio.open_connection(hostname, port), timeout=timeout
|
|
)
|
|
writer.close()
|
|
await writer.wait_closed()
|
|
response_time = (time.time() - start_time) * 1000
|
|
return True, response_time, None
|
|
except asyncio.TimeoutError:
|
|
return False, None, "Connection timeout"
|
|
except socket.gaierror:
|
|
return False, None, "DNS resolution failed"
|
|
except ConnectionRefusedError:
|
|
return False, None, "Connection refused"
|
|
except Exception as e:
|
|
return False, None, f"Connection error: {str(e)}"
|
|
|
|
|
|
@bulk_router.post(
    "/health-check",
    response_model=BulkHealthCheckResponse,
    summary="Bulk health check",
    description="Check availability of multiple machines in parallel"
)
async def bulk_health_check(
    request_data: BulkHealthCheckRequest,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Bulk machine availability check with role-based limits.

    Each requested id is looked up among the caller's saved machines. Every
    machine found is probed concurrently with ``check_host_availability``
    (default port 22, ``request_data.timeout`` seconds). The run is written
    to the immutable audit log and summarised in the response.

    Ids that are not UUIDs, or that match no saved machine of the caller, are
    logged and skipped. They do not appear in ``results`` and are not counted
    in ``total``.

    Raises:
        HTTPException: 401 when no authenticated user is attached to the
            request; 403 when more machines are requested than
            ``ROLE_HEALTH_CHECK_LIMITS`` allows for the caller's role.
    """
    # ``credentials`` is not read here; the caller's identity comes from
    # get_current_user(request).
    user_info = get_current_user(request)
    if not user_info:
        raise HTTPException(status_code=401, detail="Authentication required")

    username = user_info["username"]
    user_role = UserRole(user_info["role"])
    client_ip = request.client.host if request.client else "unknown"

    max_machines = ROLE_HEALTH_CHECK_LIMITS.get(user_role, 10)
    machine_count = len(request_data.machine_ids)

    if machine_count > max_machines:
        logger.warning(
            "Bulk health check limit exceeded",
            username=username,
            role=user_role.value,
            requested=machine_count,
            limit=max_machines,
        )
        raise HTTPException(
            status_code=403,
            detail=f"Role {user_role.value} can check max {max_machines} machines at once",
        )

    logger.info(
        "Bulk health check started",
        username=username,
        machine_count=machine_count,
        timeout=request_data.timeout,
    )

    started_at = datetime.now(timezone.utc)
    # Monotonic clock: elapsed time cannot go negative or jump if the wall
    # clock is adjusted mid-request.
    start_time = time.perf_counter()

    machines = []
    for machine_id in request_data.machine_ids:
        # Only the UUID parse is guarded, so that a failure inside the
        # saved-machines lookup is not mistaken for "not a UUID".
        try:
            UUID(machine_id)
        except (ValueError, AttributeError):
            machine_dict = None  # Not a UUID, so it cannot be a saved machine
        else:
            machine_dict = saved_machines_db.get_machine_by_id(machine_id, username)

        if not machine_dict:
            logger.warning(
                "Machine not found or invalid UUID",
                username=username,
                machine_id=machine_id,
            )
            continue

        # Convert dict to object with attributes for uniform access
        machines.append(
            SimpleNamespace(
                id=machine_dict['id'],
                name=machine_dict['name'],
                ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')),
                hostname=machine_dict.get('hostname', 'unknown'),
            )
        )

    async def check_machine(machine):
        """Probe one machine; never raises, so one failure cannot sink gather()."""
        checked_at = datetime.now(timezone.utc).isoformat()
        try:
            available, response_time, error = await check_host_availability(
                machine.ip, timeout=request_data.timeout
            )

            return BulkHealthCheckResult(
                machine_id=str(machine.id),
                machine_name=machine.name,
                hostname=machine.ip,
                status="success" if available else "failed",
                available=available,
                # ``is not None``: a 0.0 ms reading is still a measurement.
                response_time_ms=int(response_time) if response_time is not None else None,
                error=error,
                checked_at=checked_at,
            )
        except Exception as e:
            logger.error(
                "Health check error", machine_id=str(machine.id), error=str(e)
            )
            return BulkHealthCheckResult(
                machine_id=str(machine.id),
                machine_name=machine.name,
                hostname=machine.ip,
                status="failed",
                available=False,
                error=str(e),
                checked_at=checked_at,
            )

    results = await asyncio.gather(*[check_machine(m) for m in machines])

    completed_at = datetime.now(timezone.utc)
    execution_time_ms = int((time.perf_counter() - start_time) * 1000)

    success_count = sum(1 for r in results if r.status == "success")
    failed_count = len(results) - success_count
    available_count = sum(1 for r in results if r.available)
    unavailable_count = len(results) - available_count

    immutable_audit_logger.log_security_event(
        event_type="bulk_health_check",
        client_ip=client_ip,
        user_agent=request.headers.get("user-agent", "unknown"),
        details={
            # "requested" vs "machine_count" exposes ids skipped during lookup.
            "requested": machine_count,
            "machine_count": len(results),
            "available": available_count,
            "unavailable": unavailable_count,
            "execution_time_ms": execution_time_ms,
        },
        severity="info",
        username=username,
    )

    logger.info(
        "Bulk health check completed",
        username=username,
        total=len(results),
        available=available_count,
        execution_time_ms=execution_time_ms,
    )

    return BulkHealthCheckResponse(
        total=len(results),
        success=success_count,
        failed=failed_count,
        available=available_count,
        unavailable=unavailable_count,
        results=results,
        execution_time_ms=execution_time_ms,
        started_at=started_at.isoformat(),
        completed_at=completed_at.isoformat(),
    )
|
|
|
|
|
|
def _run_remote_command(client, command):
    """Execute ``command`` on an already-connected paramiko SSH ``client``.

    This is blocking network I/O. It is meant to run in a worker thread via
    ``loop.run_in_executor`` so the event loop stays responsive.

    Returns:
        ``(exit_code, stdout_text, stderr_text)``. Output is decoded as UTF-8
        and undecodable bytes are dropped.
    """
    _stdin, stdout, stderr = client.exec_command(command)
    # NOTE(review): stdout is drained to EOF before stderr is read. A command
    # that writes heavily to stderr might stall on SSH channel flow control;
    # confirm whether concurrent or combined reads are needed.
    stdout_text = stdout.read().decode("utf-8", errors="ignore")
    stderr_text = stderr.read().decode("utf-8", errors="ignore")
    exit_code = stdout.channel.recv_exit_status()
    return exit_code, stdout_text, stderr_text


def _close_ssh_client_after(connect_future, client):
    """Close ``client`` in a worker thread once ``connect_future`` has settled.

    After a timeout, the connect attempt may still be running in its executor
    thread. Closing only once it settles ensures any transport it opens is
    torn down instead of leaking.
    """
    loop = connect_future.get_loop()

    def _close_quietly():
        try:
            client.close()
        except Exception:
            # Best-effort teardown: the per-machine result is already decided.
            pass

    def _on_settled(fut):
        # Retrieve the outcome so an abandoned, failed connect does not trigger
        # asyncio's "exception was never retrieved" warning.
        if not fut.cancelled():
            fut.exception()
        loop.run_in_executor(None, _close_quietly)

    connect_future.add_done_callback(_on_settled)


@bulk_router.post(
    "/ssh-command",
    response_model=BulkSSHCommandResponse,
    summary="Bulk SSH command",
    description="Execute SSH commands on multiple machines in parallel"
)
async def bulk_ssh_command(
    request_data: BulkSSHCommandRequest,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Bulk SSH command execution with role-based limits.

    Targets are resolved from the caller's saved machines (UUID ids), or from
    ``request_data.machine_hostnames`` for ids that are not saved machines.
    Ids matching neither are logged and skipped.

    ``request_data.command`` is run on at most 10 machines at a time. SSH
    credentials come from ``global_credentials`` or per-machine
    ``machine_credentials``, depending on ``credentials_mode``. All blocking
    paramiko work runs in executor threads, and every client is closed
    whatever the outcome.

    Only the connection phase is bounded by ``request_data.timeout``. Command
    execution itself is not time-bounded.

    Raises:
        HTTPException: 401 when no authenticated user is attached to the
            request; 403 for GUEST callers or when more machines are requested
            than ``ROLE_SSH_COMMAND_LIMITS`` allows for the caller's role.
    """
    # ``credentials`` is not read here; the caller's identity comes from
    # get_current_user(request).
    user_info = get_current_user(request)
    if not user_info:
        raise HTTPException(status_code=401, detail="Authentication required")

    username = user_info["username"]
    user_role = UserRole(user_info["role"])
    client_ip = request.client.host if request.client else "unknown"

    if user_role == UserRole.GUEST:
        raise HTTPException(
            status_code=403, detail="GUEST role cannot execute SSH commands"
        )

    max_machines = ROLE_SSH_COMMAND_LIMITS.get(user_role, 0)
    machine_count = len(request_data.machine_ids)

    if machine_count > max_machines:
        logger.warning(
            "Bulk SSH command limit exceeded",
            username=username,
            role=user_role.value,
            requested=machine_count,
            limit=max_machines,
        )
        raise HTTPException(
            status_code=403,
            detail=f"Role {user_role.value} can execute commands on max {max_machines} machines at once",
        )

    logger.info(
        "Bulk SSH command started",
        username=username,
        machine_count=machine_count,
        command=request_data.command[:50],
        mode=request_data.credentials_mode,
    )

    started_at = datetime.now(timezone.utc)
    # Monotonic clock for durations (immune to wall-clock adjustments).
    start_time = time.perf_counter()

    machines = []
    for machine_id in request_data.machine_ids:
        # Try saved machines first (UUID format). Only the UUID parse is
        # guarded, so that a failure inside the lookup is not mistaken for
        # "not a UUID".
        try:
            UUID(machine_id)
        except (ValueError, AttributeError):
            machine_dict = None  # Not a UUID, check if hostname provided
        else:
            machine_dict = saved_machines_db.get_machine_by_id(machine_id, username)

        if machine_dict:
            # Convert dict to object with attributes for uniform access
            machines.append(
                SimpleNamespace(
                    id=machine_dict['id'],
                    name=machine_dict['name'],
                    ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')),
                    hostname=machine_dict.get('hostname', 'unknown'),
                )
            )
            continue

        # Check if hostname provided for non-saved machine (mock machines).
        # SECURITY NOTE(review): this hostname comes straight from the request
        # body, so callers can target hosts outside their saved machines.
        # Confirm this is intended.
        if request_data.machine_hostnames and machine_id in request_data.machine_hostnames:
            hostname = request_data.machine_hostnames[machine_id]
            machines.append(
                SimpleNamespace(
                    id=machine_id,
                    name=f'Mock-{machine_id}',
                    ip=hostname,
                    hostname=hostname,
                )
            )
            logger.info(
                "Using non-saved machine (mock)",
                username=username,
                machine_id=machine_id,
                hostname=hostname,
            )
            continue

        logger.warning(
            "Machine not found and no hostname provided",
            username=username,
            machine_id=machine_id,
        )

    # Caps concurrent SSH sessions for this request.
    semaphore = asyncio.Semaphore(10)

    async def execute_command(machine):
        """Run the command on one machine; never raises, so gather() always completes."""
        async with semaphore:
            executed_at = datetime.now(timezone.utc).isoformat()
            cmd_start = time.perf_counter()

            def make_result(status, **fields):
                # Every result carries the same identity and timestamp fields.
                return BulkSSHCommandResult(
                    machine_id=str(machine.id),
                    machine_name=machine.name,
                    hostname=machine.ip,
                    status=status,
                    executed_at=executed_at,
                    **fields,
                )

            try:
                if request_data.credentials_mode == "global":
                    if not request_data.global_credentials:
                        return make_result(
                            "no_credentials", error="Global credentials not provided"
                        )
                    ssh_username = request_data.global_credentials.username
                    ssh_password = request_data.global_credentials.password
                else:  # custom mode
                    if not request_data.machine_credentials or str(
                        machine.id
                    ) not in request_data.machine_credentials:
                        return make_result(
                            "no_credentials",
                            error="Custom credentials not provided for this machine",
                        )
                    creds = request_data.machine_credentials[str(machine.id)]
                    ssh_username = creds.username
                    ssh_password = creds.password

                if not ssh_username or not ssh_password:
                    return make_result("no_credentials", error="Credentials missing")

                # Imported lazily: a missing paramiko surfaces as a per-machine
                # "failed" result rather than an endpoint error.
                import paramiko

                ssh = paramiko.SSHClient()
                # SECURITY NOTE(review): AutoAddPolicy accepts any unknown host
                # key, so connections are open to man-in-the-middle attacks.
                # Consider loading known_hosts and rejecting unknown keys.
                ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

                loop = asyncio.get_running_loop()
                connect_future = loop.run_in_executor(
                    None,
                    lambda: ssh.connect(
                        machine.ip,
                        username=ssh_username,
                        password=ssh_password,
                        timeout=request_data.timeout,
                    ),
                )
                try:
                    # shield(): on timeout we stop waiting but keep the still
                    # running connect alive, so the client is closed after it
                    # settles instead of leaking.
                    await asyncio.wait_for(
                        asyncio.shield(connect_future), timeout=request_data.timeout
                    )
                    # Blocking exec/read/exit-status run off the event loop.
                    exit_code, stdout_text, stderr_text = await loop.run_in_executor(
                        None, _run_remote_command, ssh, request_data.command
                    )
                finally:
                    _close_ssh_client_after(connect_future, ssh)

                return make_result(
                    "success" if exit_code == 0 else "failed",
                    exit_code=exit_code,
                    stdout=stdout_text[:5000],
                    stderr=stderr_text[:5000],
                    execution_time_ms=int((time.perf_counter() - cmd_start) * 1000),
                )

            except asyncio.TimeoutError:
                return make_result("timeout", error="Command execution timeout")
            except Exception as e:
                logger.error(
                    "SSH command error",
                    machine_id=str(machine.id),
                    machine_name=machine.name,
                    hostname=machine.ip,
                    error=str(e),
                    error_type=type(e).__name__
                )
                return make_result("failed", error=str(e)[:500])

    results = await asyncio.gather(*[execute_command(m) for m in machines])

    completed_at = datetime.now(timezone.utc)
    execution_time_ms = int((time.perf_counter() - start_time) * 1000)

    success_count = sum(1 for r in results if r.status == "success")
    failed_count = len(results) - success_count

    immutable_audit_logger.log_security_event(
        event_type="bulk_ssh_command",
        client_ip=client_ip,
        user_agent=request.headers.get("user-agent", "unknown"),
        details={
            # "requested" vs "machine_count" exposes ids skipped during lookup.
            "requested": machine_count,
            "machine_count": len(results),
            "command": request_data.command[:100],
            "credentials_mode": request_data.credentials_mode,
            "success": success_count,
            "failed": failed_count,
            "execution_time_ms": execution_time_ms,
        },
        severity="high",
        username=username,
    )

    logger.info(
        "Bulk SSH command completed",
        username=username,
        total=len(results),
        success=success_count,
        failed=failed_count,
        execution_time_ms=execution_time_ms,
    )

    return BulkSSHCommandResponse(
        total=len(results),
        success=success_count,
        failed=failed_count,
        results=results,
        execution_time_ms=execution_time_ms,
        command=request_data.command,
        started_at=started_at.isoformat(),
        completed_at=completed_at.isoformat(),
    )
|
|
|