root
2025-11-25 10:11:32 +03:00
parent 48b1934def
commit 60792735ad
38 changed files with 12695 additions and 0 deletions


@@ -0,0 +1,477 @@
"""Bulk operations router for mass machine operations."""
import asyncio
import socket
import time
from datetime import datetime, timezone
from types import SimpleNamespace
from typing import Dict, List
from uuid import UUID
import structlog
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from core.middleware import get_current_user
from core.models import (
BulkHealthCheckRequest,
BulkHealthCheckResponse,
BulkHealthCheckResult,
BulkSSHCommandRequest,
BulkSSHCommandResponse,
BulkSSHCommandResult,
UserRole,
)
from core.permissions import PermissionChecker
from core.saved_machines_db import saved_machines_db
from core.audit_logger import immutable_audit_logger
logger = structlog.get_logger(__name__)
security = HTTPBearer()
bulk_router = APIRouter(prefix="/api/bulk", tags=["Bulk Operations"])
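# Per-role caps on how many machines a single bulk request may target.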
ROLE_HEALTH_CHECK_LIMITS = {
UserRole.GUEST: 10,
UserRole.USER: 50,
UserRole.ADMIN: 200,
UserRole.SUPER_ADMIN: 200,
}
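# GUEST gets a limit of 0 here; the SSH endpoint below also rejects GUEST explicitly.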
ROLE_SSH_COMMAND_LIMITS = {
UserRole.GUEST: 0,
UserRole.USER: 20,
UserRole.ADMIN: 100,
UserRole.SUPER_ADMIN: 100,
}
async def check_host_availability(
hostname: str, port: int = 22, timeout: int = 5
) -> tuple[bool, float | None, str | None]:
"""Check if host is available via TCP connection."""
start_time = time.time()
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(hostname, port), timeout=timeout
)
writer.close()
await writer.wait_closed()
response_time = (time.time() - start_time) * 1000
return True, response_time, None
except asyncio.TimeoutError:
return False, None, "Connection timeout"
except socket.gaierror:
return False, None, "DNS resolution failed"
except ConnectionRefusedError:
return False, None, "Connection refused"
except Exception as e:
return False, None, f"Connection error: {str(e)}"
@bulk_router.post(
"/health-check",
response_model=BulkHealthCheckResponse,
summary="Bulk health check",
description="Check availability of multiple machines in parallel"
)
async def bulk_health_check(
request_data: BulkHealthCheckRequest,
request: Request,
credentials: HTTPAuthorizationCredentials = Depends(security),
):
"""Bulk machine availability check with role-based limits."""
user_info = get_current_user(request)
if not user_info:
raise HTTPException(status_code=401, detail="Authentication required")
username = user_info["username"]
user_role = UserRole(user_info["role"])
client_ip = request.client.host if request.client else "unknown"
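# Enforce the per-role machine cap before doing any network work.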
max_machines = ROLE_HEALTH_CHECK_LIMITS.get(user_role, 10)
machine_count = len(request_data.machine_ids)
if machine_count > max_machines:
logger.warning(
"Bulk health check limit exceeded",
username=username,
role=user_role.value,
requested=machine_count,
limit=max_machines,
)
raise HTTPException(
status_code=403,
detail=f"Role {user_role.value} can check max {max_machines} machines at once",
)
logger.info(
"Bulk health check started",
username=username,
machine_count=machine_count,
timeout=request_data.timeout,
)
started_at = datetime.now(timezone.utc)
start_time = time.time()
machines = []
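# Resolve the requested IDs against the caller's saved machines; IDs that do not resolve are logged and skipped.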
for machine_id in request_data.machine_ids:
# Try to get from saved machines first (UUID format)
try:
UUID(machine_id)
machine_dict = saved_machines_db.get_machine_by_id(machine_id, username)
if machine_dict:
# Convert dict to object with attributes for uniform access
machine = SimpleNamespace(
id=machine_dict['id'],
name=machine_dict['name'],
ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')),
hostname=machine_dict.get('hostname', 'unknown'),
)
machines.append(machine)
continue
except (ValueError, AttributeError):
# Not a UUID
pass
logger.warning(
"Machine not found or invalid UUID",
username=username,
machine_id=machine_id,
)
async def check_machine(machine):
checked_at = datetime.now(timezone.utc).isoformat()
try:
available, response_time, error = await check_host_availability(
machine.ip, timeout=request_data.timeout
)
return BulkHealthCheckResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="success" if available else "failed",
available=available,
response_time_ms=int(response_time) if response_time is not None else None,
error=error,
checked_at=checked_at,
)
except Exception as e:
logger.error(
"Health check error", machine_id=str(machine.id), error=str(e)
)
return BulkHealthCheckResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="failed",
available=False,
error=str(e),
checked_at=checked_at,
)
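# Run all availability checks concurrently; gather preserves the order of the machines list.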
results = await asyncio.gather(*[check_machine(m) for m in machines])
completed_at = datetime.now(timezone.utc)
execution_time_ms = int((time.time() - start_time) * 1000)
success_count = sum(1 for r in results if r.status == "success")
failed_count = len(results) - success_count
available_count = sum(1 for r in results if r.available)
unavailable_count = len(results) - available_count
immutable_audit_logger.log_security_event(
event_type="bulk_health_check",
client_ip=client_ip,
user_agent=request.headers.get("user-agent", "unknown"),
details={
"machine_count": len(results),
"available": available_count,
"unavailable": unavailable_count,
"execution_time_ms": execution_time_ms,
},
severity="info",
username=username,
)
logger.info(
"Bulk health check completed",
username=username,
total=len(results),
available=available_count,
execution_time_ms=execution_time_ms,
)
return BulkHealthCheckResponse(
total=len(results),
success=success_count,
failed=failed_count,
available=available_count,
unavailable=unavailable_count,
results=results,
execution_time_ms=execution_time_ms,
started_at=started_at.isoformat(),
completed_at=completed_at.isoformat(),
)
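# Illustrative request body for POST /api/bulk/health-check. This is a sketch only:
# the UUID is a made-up placeholder, and the field names simply mirror how
# request_data is used above.
# {
#     "machine_ids": ["2f1e7c3a-0000-4000-8000-000000000000"],
#     "timeout": 5
# }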
@bulk_router.post(
"/ssh-command",
response_model=BulkSSHCommandResponse,
summary="Bulk SSH command",
description="Execute SSH commands on multiple machines in parallel"
)
async def bulk_ssh_command(
request_data: BulkSSHCommandRequest,
request: Request,
credentials: HTTPAuthorizationCredentials = Depends(security),
):
"""Bulk SSH command execution with role-based limits."""
user_info = get_current_user(request)
if not user_info:
raise HTTPException(status_code=401, detail="Authentication required")
username = user_info["username"]
user_role = UserRole(user_info["role"])
client_ip = request.client.host if request.client else "unknown"
if user_role == UserRole.GUEST:
raise HTTPException(
status_code=403, detail="GUEST role cannot execute SSH commands"
)
max_machines = ROLE_SSH_COMMAND_LIMITS.get(user_role, 0)
machine_count = len(request_data.machine_ids)
if machine_count > max_machines:
logger.warning(
"Bulk SSH command limit exceeded",
username=username,
role=user_role.value,
requested=machine_count,
limit=max_machines,
)
raise HTTPException(
status_code=403,
detail=f"Role {user_role.value} can execute commands on max {max_machines} machines at once",
)
logger.info(
"Bulk SSH command started",
username=username,
machine_count=machine_count,
command=request_data.command[:50],
mode=request_data.credentials_mode,
)
started_at = datetime.now(timezone.utc)
start_time = time.time()
machines = []
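# IDs may refer to saved machines (UUIDs) or to ad-hoc hosts supplied via machine_hostnames.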
for machine_id in request_data.machine_ids:
# Try to get from saved machines first (UUID format)
try:
UUID(machine_id)
machine_dict = saved_machines_db.get_machine_by_id(machine_id, username)
if machine_dict:
# Convert dict to object with attributes for uniform access
machine = SimpleNamespace(
id=machine_dict['id'],
name=machine_dict['name'],
ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')),
hostname=machine_dict.get('hostname', 'unknown'),
)
machines.append(machine)
continue
except (ValueError, AttributeError):
# Not a UUID, check if hostname provided
pass
# Check if hostname provided for non-saved machine (mock machines)
if request_data.machine_hostnames and machine_id in request_data.machine_hostnames:
hostname = request_data.machine_hostnames[machine_id]
# Create mock machine object for non-saved machines
mock_machine = SimpleNamespace(
id=machine_id,
name=f'Mock-{machine_id}',
ip=hostname,
hostname=hostname,
)
machines.append(mock_machine)
logger.info(
"Using non-saved machine (mock)",
username=username,
machine_id=machine_id,
hostname=hostname,
)
continue
logger.warning(
"Machine not found and no hostname provided",
username=username,
machine_id=machine_id,
)
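# Bound concurrency: at most 10 SSH sessions run at the same time.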
semaphore = asyncio.Semaphore(10)
async def execute_command(machine):
async with semaphore:
executed_at = datetime.now(timezone.utc).isoformat()
cmd_start = time.time()
try:
ssh_username = None
ssh_password = None
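# Resolve credentials: either one global login for every machine or per-machine entries keyed by machine id.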
if request_data.credentials_mode == "global":
if not request_data.global_credentials:
return BulkSSHCommandResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="no_credentials",
error="Global credentials not provided",
executed_at=executed_at,
)
ssh_username = request_data.global_credentials.username
ssh_password = request_data.global_credentials.password
else: # custom mode
if not request_data.machine_credentials or str(
machine.id
) not in request_data.machine_credentials:
return BulkSSHCommandResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="no_credentials",
error="Custom credentials not provided for this machine",
executed_at=executed_at,
)
creds = request_data.machine_credentials[str(machine.id)]
ssh_username = creds.username
ssh_password = creds.password
if not ssh_username or not ssh_password:
return BulkSSHCommandResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="no_credentials",
error="Credentials missing",
executed_at=executed_at,
)
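# paramiko is imported lazily; it is only required when this code path actually runs.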
import paramiko
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
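# The blocking connect runs in the default executor, wrapped in an overall timeout.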
await asyncio.wait_for(
asyncio.get_running_loop().run_in_executor(
None,
lambda: ssh.connect(
machine.ip,
username=ssh_username,
password=ssh_password,
timeout=request_data.timeout,
),
),
timeout=request_data.timeout,
)
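# exec_command() and read() are blocking paramiko calls; the semaphore slot is held while they run.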
stdin, stdout, stderr = ssh.exec_command(request_data.command)
stdout_text = stdout.read().decode("utf-8", errors="ignore")
stderr_text = stderr.read().decode("utf-8", errors="ignore")
exit_code = stdout.channel.recv_exit_status()
ssh.close()
execution_time = int((time.time() - cmd_start) * 1000)
return BulkSSHCommandResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="success" if exit_code == 0 else "failed",
exit_code=exit_code,
stdout=stdout_text[:5000],
stderr=stderr_text[:5000],
execution_time_ms=execution_time,
executed_at=executed_at,
)
except asyncio.TimeoutError:
return BulkSSHCommandResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="timeout",
error="Command execution timeout",
executed_at=executed_at,
)
except Exception as e:
logger.error(
"SSH command error",
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
error=str(e),
error_type=type(e).__name__
)
return BulkSSHCommandResult(
machine_id=str(machine.id),
machine_name=machine.name,
hostname=machine.ip,
status="failed",
error=str(e)[:500],
executed_at=executed_at,
)
results = await asyncio.gather(*[execute_command(m) for m in machines])
completed_at = datetime.now(timezone.utc)
execution_time_ms = int((time.time() - start_time) * 1000)
success_count = sum(1 for r in results if r.status == "success")
failed_count = len(results) - success_count
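# Record the bulk execution in the immutable audit log with high severity.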
immutable_audit_logger.log_security_event(
event_type="bulk_ssh_command",
client_ip=client_ip,
user_agent=request.headers.get("user-agent", "unknown"),
details={
"machine_count": len(results),
"command": request_data.command[:100],
"credentials_mode": request_data.credentials_mode,
"success": success_count,
"failed": failed_count,
"execution_time_ms": execution_time_ms,
},
severity="high",
username=username,
)
logger.info(
"Bulk SSH command completed",
username=username,
total=len(results),
success=success_count,
failed=failed_count,
execution_time_ms=execution_time_ms,
)
return BulkSSHCommandResponse(
total=len(results),
success=success_count,
failed=failed_count,
results=results,
execution_time_ms=execution_time_ms,
command=request_data.command,
started_at=started_at.isoformat(),
completed_at=completed_at.isoformat(),
)
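# Illustrative request body for POST /api/bulk/ssh-command. This is a sketch only:
# all values are made-up placeholders, and the field names simply mirror how
# request_data is used above.
# {
#     "machine_ids": ["2f1e7c3a-0000-4000-8000-000000000000"],
#     "command": "uptime",
#     "timeout": 10,
#     "credentials_mode": "global",
#     "global_credentials": {"username": "operator", "password": "secret"}
# }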