2903 lines
99 KiB
Python
Executable File
2903 lines
99 KiB
Python
Executable File
# Standard library imports
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
import platform
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from collections import defaultdict, deque
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
# Third-party imports
|
|
import jwt
|
|
import psutil
|
|
import requests
|
|
import structlog
|
|
from cryptography.hazmat.primitives import serialization
|
|
from dotenv import load_dotenv
|
|
from fastapi import (
|
|
BackgroundTasks,
|
|
Depends,
|
|
FastAPI,
|
|
HTTPException,
|
|
Request,
|
|
Response,
|
|
WebSocket,
|
|
WebSocketDisconnect,
|
|
)
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from pydantic import BaseModel
|
|
|
|
# Local imports
|
|
from core import (
|
|
ConnectionRequest,
|
|
ConnectionResponse,
|
|
GuacamoleAuthenticator,
|
|
LoginRequest,
|
|
LoginResponse,
|
|
PermissionChecker,
|
|
UserInfo,
|
|
UserRole,
|
|
create_jwt_token,
|
|
verify_jwt_token,
|
|
)
|
|
from core.audit_logger import immutable_audit_logger
|
|
from core.brute_force_protection import brute_force_protection
|
|
from core.csrf_protection import csrf_protection
|
|
from core.log_sanitizer import log_sanitizer, sanitize_log_processor
|
|
from core.middleware import get_current_user, get_current_user_token, jwt_auth_middleware
|
|
from core.models import (
|
|
BulkHealthCheckRequest,
|
|
BulkHealthCheckResponse,
|
|
BulkHealthCheckResult,
|
|
BulkSSHCommandRequest,
|
|
BulkSSHCommandResponse,
|
|
BulkSSHCommandResult,
|
|
ConnectionHistoryCreate,
|
|
ConnectionHistoryResponse,
|
|
SavedMachineCreate,
|
|
SavedMachineList,
|
|
SavedMachineResponse,
|
|
SavedMachineUpdate,
|
|
SSHCredentials,
|
|
)
|
|
from core.rate_limiter import redis_rate_limiter
|
|
from core.redis_storage import redis_connection_storage
|
|
from core.saved_machines_db import saved_machines_db
|
|
from core.session_storage import session_storage
|
|
from core.ssrf_protection import ssrf_protection
|
|
from core.token_blacklist import token_blacklist
|
|
from core.websocket_manager import websocket_manager
|
|
from security_config import SecurityConfig
|
|
from services.system_service import SystemService
|
|
from routers import bulk_router
|
|
|
|
load_dotenv()
|
|
|
|
# Interactive API docs (Swagger UI / ReDoc) can be switched off in
# production by setting ENABLE_DOCS=false.
enable_docs = os.getenv("ENABLE_DOCS", "true").lower() == "true"

# Tag groups used to organize endpoints in the generated OpenAPI docs UI.
tags_metadata = [
    {
        "name": "System",
        "description": "System health checks and service status"
    },
    {
        "name": "Authentication",
        "description": "User authentication and authorization (login, logout, profile, permissions)"
    },
    {
        "name": "Connections",
        "description": "Remote desktop connection management (create, list, delete, extend TTL)"
    },
    {
        "name": "Machines",
        "description": "Machine management and saved machines CRUD operations"
    },
    {
        "name": "Bulk Operations",
        "description": "Bulk operations on multiple machines (health checks, SSH commands)"
    }
]

# Docs endpoints are mounted under /api/*; passing None removes them entirely
# when docs are disabled.
app = FastAPI(
    title="Remote Access API",
    description="Remote desktop management API via Apache Guacamole. Supports RDP, VNC, SSH protocols with JWT authentication.",
    version="1.0.0",
    docs_url="/api/docs" if enable_docs else None,
    redoc_url="/api/redoc" if enable_docs else None,
    openapi_url="/api/openapi.json" if enable_docs else None,
    openapi_tags=tags_metadata
)

# Bulk operations endpoints live in their own router module.
app.include_router(bulk_router)

# HTTP Bearer scheme used by endpoint dependencies to extract JWT tokens.
security = HTTPBearer()
|
|
|
|
# Structured logging configuration
# LOG_LEVEL sets the stdlib root logger level; LOG_FORMAT selects the
# processor chain (machine-readable "json" vs human-readable console).
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
LOG_FORMAT = os.getenv("LOG_FORMAT", "json")
|
|
|
|
def add_caller_info(
    logger: Any, name: str, event_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """Add caller information to log entry.

    structlog processor: walks up the call stack to find the first frame
    outside structlog/logging internals and records its file, function and
    line number under the "caller" key.
    """
    frame = sys._getframe()
    try:
        while frame and frame.f_back:
            filename = frame.f_code.co_filename
            # Skip frames belonging to the logging machinery itself; the
            # first other frame is the application code that emitted the log.
            if "structlog" not in filename and "logging" not in filename:
                event_dict["caller"] = {
                    "file": filename.split("/")[-1],
                    "function": frame.f_code.co_name,
                    "line": frame.f_lineno,
                }
                break
            frame = frame.f_back
    finally:
        # Drop the frame reference to break the reference cycle frames create
        # (recommended by the CPython docs for sys._getframe users).
        del frame
    return event_dict
|
|
|
|
|
|
def add_service_context(
    logger: Any, name: str, event_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """structlog processor: stamp each entry with the service's identity."""
    event_dict.update({"service": "remote-access-api", "version": "1.0.0"})
    return event_dict
|
|
|
|
# Processor chain for machine-readable JSON logs (production default).
# Order matters: sanitization runs before rendering so secrets never reach
# the output.
json_processors = [
    structlog.processors.TimeStamper(fmt="iso"),
    add_service_context,
    add_caller_info,
    sanitize_log_processor,
    structlog.processors.add_log_level,
    structlog.processors.JSONRenderer()
]

# Processor chain for human-readable, colorized console logs (development).
# NOTE(review): add_caller_info is absent here — presumably to keep console
# output compact; confirm that is intentional.
text_processors = [
    structlog.processors.TimeStamper(fmt="iso"),
    add_service_context,
    sanitize_log_processor,
    structlog.processors.add_log_level,
    structlog.dev.ConsoleRenderer(colors=True)
]

structlog.configure(
    processors=json_processors if LOG_FORMAT == "json" else text_processors,
    wrapper_class=structlog.stdlib.BoundLogger,
    logger_factory=structlog.stdlib.LoggerFactory(),
    cache_logger_on_first_use=True,
)

# stdlib logging performs the final emission; structlog has already rendered
# the full message, so the stdlib format is just the message itself.
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format='%(message)s'
)

# Dedicated loggers per concern so streams can be filtered/routed independently.
logger = structlog.get_logger()
security_logger = structlog.get_logger("security")
audit_logger = structlog.get_logger("audit")
performance_logger = structlog.get_logger("performance")
error_logger = structlog.get_logger("error")
|
|
|
|
|
|
def record_request_metric(
    endpoint: str,
    method: str,
    status_code: int,
    response_time_ms: float,
    client_ip: str,
) -> None:
    """Record one completed request in the in-process metrics store."""
    request_stats = metrics_storage["requests"]
    request_stats["total"] += 1
    for bucket, key in (
        ("by_endpoint", endpoint),
        ("by_method", method),
        ("by_status", status_code),
        ("by_ip", client_ip),
    ):
        request_stats[bucket][key] += 1

    # Keep a sliding window of the most recent 100 latency samples per endpoint.
    samples = metrics_storage["performance"]["response_times"][endpoint]
    samples.append(response_time_ms)
    if len(samples) > 100:
        del samples[0]
|
|
|
|
def record_connection_metric(
    protocol: str,
    client_ip: str,
    creation_time_ms: float,
    success: bool = True,
) -> None:
    """Record one connection-creation attempt in the metrics store."""
    if not success:
        metrics_storage["errors"]["connection_failures"] += 1
        return

    connection_stats = metrics_storage["connections"]
    connection_stats["total_created"] += 1
    connection_stats["by_protocol"][protocol] += 1
    connection_stats["by_ip"][client_ip] += 1
    # The active count mirrors Redis (the source of truth), not a local counter.
    connection_stats["active_count"] = len(
        redis_connection_storage.get_all_connections()
    )

    # Sliding window of the latest 100 creation-time samples.
    samples = metrics_storage["performance"]["connection_creation_times"]
    samples.append(creation_time_ms)
    if len(samples) > 100:
        del samples[0]
|
|
|
|
def record_host_check_metric(check_time_ms: float, success: bool = True) -> None:
    """Record one host reachability check in the metrics store."""
    if not success:
        metrics_storage["connections"]["failed_host_checks"] += 1
        metrics_storage["errors"]["host_unreachable"] += 1
        return

    # Sliding window of the latest 100 successful check durations.
    samples = metrics_storage["performance"]["host_check_times"]
    samples.append(check_time_ms)
    if len(samples) > 100:
        del samples[0]
|
|
|
|
|
|
def record_error_metric(error_type: str) -> None:
    """Increment a known error counter; unknown error types are ignored."""
    counters = metrics_storage["errors"]
    if error_type in counters:
        counters[error_type] += 1
|
|
|
|
def calculate_percentiles(values: List[float]) -> Dict[str, float]:
    """Compute p50/p90/p95/p99 of *values* via nearest-rank (floor) indexing.

    Returns all-zero percentiles for an empty input.
    """
    quantiles = (("p50", 0.5), ("p90", 0.9), ("p95", 0.95), ("p99", 0.99))
    if not values:
        return {name: 0 for name, _ in quantiles}

    ordered = sorted(values)
    count = len(ordered)
    # int(count * q) < count for q < 1, so indexing is always in range.
    return {name: ordered[int(count * q)] for name, q in quantiles}
|
|
|
|
def get_metrics_summary() -> Dict[str, Any]:
    """Get metrics summary.

    Aggregates the in-process metrics_storage counters and sample lists into
    one report: request rates, connection stats, latency percentiles, the
    busiest client IPs, and error counters.
    """
    uptime_seconds = int((datetime.now() - service_start_time).total_seconds())

    # Per-endpoint latency stats; endpoints with no samples yet are skipped.
    endpoint_stats = {
        endpoint: {
            "request_count": len(times),
            "avg_response_time_ms": round(sum(times) / len(times), 2),
            "percentiles": calculate_percentiles(times),
        }
        for endpoint, times in metrics_storage["performance"]["response_times"].items()
        if times
    }

    connection_times = metrics_storage["performance"]["connection_creation_times"]
    connection_stats = (
        {
            "avg_creation_time_ms": round(sum(connection_times) / len(connection_times), 2),
            "percentiles": calculate_percentiles(connection_times),
        }
        if connection_times
        else {}
    )

    host_check_times = metrics_storage["performance"]["host_check_times"]
    host_check_stats = (
        {
            "avg_check_time_ms": round(sum(host_check_times) / len(host_check_times), 2),
            "percentiles": calculate_percentiles(host_check_times),
        }
        if host_check_times
        else {}
    )

    # Ten busiest client IPs by request count.
    top_ips = dict(
        sorted(
            metrics_storage["requests"]["by_ip"].items(),
            key=lambda x: x[1],
            reverse=True,
        )[:10]
    )

    # All protocols ordered by connection count (not truncated).
    top_protocols = dict(
        sorted(
            metrics_storage["connections"]["by_protocol"].items(),
            key=lambda x: x[1],
            reverse=True,
        )
    )

    return {
        "uptime_seconds": uptime_seconds,
        "requests": {
            "total": metrics_storage["requests"]["total"],
            # max(..., 1) guards against division by zero right after startup.
            "requests_per_second": round(
                metrics_storage["requests"]["total"] / max(uptime_seconds, 1), 2
            ),
            "by_status": dict(metrics_storage["requests"]["by_status"]),
            "by_method": dict(metrics_storage["requests"]["by_method"]),
            "top_ips": top_ips,
        },
        "connections": {
            "total_created": metrics_storage["connections"]["total_created"],
            # Live count comes from Redis, the source of truth for sessions.
            "currently_active": len(redis_connection_storage.get_all_connections()),
            "by_protocol": top_protocols,
            "failed_host_checks": metrics_storage["connections"]["failed_host_checks"],
        },
        "performance": {
            "endpoints": endpoint_stats,
            "connection_creation": connection_stats,
            "host_checks": host_check_stats,
        },
        "errors": dict(metrics_storage["errors"]),
    }
|
|
|
|
def log_security_event(
    event_type: str,
    client_ip: str,
    user_agent: Optional[str] = None,
    details: Optional[Dict[str, Any]] = None,
    severity: str = "info",
    username: Optional[str] = None,
) -> None:
    """Log a security event to the immutable audit log and the security stream.

    The tamper-evident record is written first; the structured security
    logger then emits the same event at a level derived from *severity*.
    """
    immutable_audit_logger.log_security_event(
        event_type=event_type,
        client_ip=client_ip,
        user_agent=user_agent,
        details=details,
        severity=severity,
        username=username,
    )

    event: Dict[str, Any] = {
        "event_type": event_type,
        "client_ip": client_ip,
        "user_agent": user_agent or "unknown",
        "severity": severity,
        "timestamp": datetime.now().isoformat(),
        **(details or {}),
    }

    # Map severity to a log method; anything unrecognized falls back to info.
    severity_loggers = {
        "critical": security_logger.critical,
        "high": security_logger.error,
        "medium": security_logger.warning,
    }
    emit = severity_loggers.get(severity, security_logger.info)
    emit("Security event", **event)
|
|
|
|
def log_audit_event(
    action: str,
    resource: str,
    client_ip: str,
    user_agent: Optional[str] = None,
    result: str = "success",
    details: Optional[Dict[str, Any]] = None,
    username: Optional[str] = None,
) -> None:
    """Log an audit event to the immutable audit log and the audit stream."""
    immutable_audit_logger.log_audit_event(
        action=action,
        resource=resource,
        client_ip=client_ip,
        user_agent=user_agent,
        result=result,
        details=details,
        username=username,
    )

    event: Dict[str, Any] = {
        "action": action,
        "resource": resource,
        "client_ip": client_ip,
        "user_agent": user_agent or "unknown",
        "result": result,
        "timestamp": datetime.now().isoformat(),
        **(details or {}),
    }

    # Failures are logged at WARNING so they stand out in the audit stream.
    emit = audit_logger.warning if result == "failure" else audit_logger.info
    emit("Audit event", **event)
|
|
|
|
def log_performance_event(
    operation: str,
    duration_ms: float,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """Log a timed operation, escalating the level with its duration."""
    event: Dict[str, Any] = {
        "operation": operation,
        "duration_ms": round(duration_ms, 2),
        "timestamp": datetime.now().isoformat(),
        **(details or {}),
    }

    # >5s is a warning, >1s is info, everything else stays at debug.
    if duration_ms > 5000:
        performance_logger.warning("Slow operation", **event)
    elif duration_ms > 1000:
        performance_logger.info("Performance event", **event)
    else:
        performance_logger.debug("Performance event", **event)
|
|
|
|
def log_connection_lifecycle(
    connection_id: str,
    action: str,
    client_ip: str,
    hostname: str,
    protocol: str,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """Log a connection lifecycle event (created / deleted / failed / ...)."""
    event: Dict[str, Any] = {
        "connection_id": connection_id,
        "action": action,
        "client_ip": client_ip,
        "hostname": hostname,
        "protocol": protocol,
        "timestamp": datetime.now().isoformat(),
        **(details or {}),
    }

    # Failed lifecycle transitions are errors; everything else is info.
    emit = audit_logger.error if action == "failed" else audit_logger.info
    emit("Connection lifecycle", **event)
|
|
|
|
def log_error_with_context(
    error: Exception,
    operation: str,
    context: Optional[Dict[str, Any]] = None,
) -> None:
    """Log an exception with its type, message and optional extra context."""
    event: Dict[str, Any] = {
        "operation": operation,
        "error_type": type(error).__name__,
        "error_message": str(error),
        "timestamp": datetime.now().isoformat(),
        **(context or {}),
    }
    error_logger.error("Application error", **event)
|
|
|
|
|
|
|
|
# CORS middleware - configured via .env
allowed_origins_str = os.getenv("ALLOWED_ORIGINS")

# Fail fast at startup: running with no CORS origin list would silently
# break every browser client.
if not allowed_origins_str:
    logger.error("ALLOWED_ORIGINS environment variable is not set!")
    raise RuntimeError(
        "ALLOWED_ORIGINS must be set in .env file or docker-compose.yml. "
        "Example: ALLOWED_ORIGINS=https://mc.exbytestudios.com,http://localhost:5173"
    )

# Comma-separated list; blank entries (e.g. from trailing commas) are dropped.
allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",") if origin.strip()]

if not allowed_origins:
    logger.error("ALLOWED_ORIGINS is empty after parsing!")
    raise RuntimeError("ALLOWED_ORIGINS must contain at least one valid origin")

logger.info("CORS configured", allowed_origins=allowed_origins)

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
    allow_headers=["*"],
)
|
|
|
|
# Electron desktop app CORS middleware
@app.middleware("http")
async def electron_cors_middleware(
    request: Request, call_next: Any
) -> Response:
    """Handle CORS for Electron desktop app (missing/custom Origin headers).

    Runs in addition to CORSMiddleware: regular browser origins are handled
    there and pass through here untouched; only requests with no Origin or
    an Electron custom-protocol Origin get headers added.
    """
    origin = request.headers.get("origin")
    response = await call_next(request)

    cors_headers = {
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH",
        "Access-Control-Allow-Headers": "*",
    }

    # No Origin (or the literal string "null"): native clients such as
    # Electron or plain API tools.
    # NOTE(review): "*" combined with Allow-Credentials=true is rejected by
    # browsers per the CORS spec; it presumably works here only because these
    # clients do not enforce CORS — confirm this is intentional.
    if not origin or origin == "null":
        logger.debug(
            "Request without Origin header (Electron or API client)",
            path=request.url.path,
            method=request.method,
            user_agent=request.headers.get("user-agent", "unknown")[:50],
        )
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers.update(cors_headers)
        return response

    # Electron custom protocols: echo the origin back so credentials work.
    if origin.startswith(("file://", "app://")):
        logger.debug(
            "Request from Electron with custom protocol",
            origin=origin,
            path=request.url.path,
            method=request.method,
        )
        response.headers["Access-Control-Allow-Origin"] = origin
        response.headers.update(cors_headers)
        return response

    return response
|
|
|
|
@app.middleware("http")
async def auth_middleware(request: Request, call_next: Any) -> Response:
    """Delegate every request to the shared JWT authentication middleware."""
    response = await jwt_auth_middleware(request, call_next)
    return response
|
|
|
|
@app.middleware("http")
async def logging_and_metrics_middleware(
    request: Request, call_next: Any
) -> Response:
    """Request logging and metrics collection middleware.

    Assigns a short correlation ID (echoed back in the X-Request-ID header),
    records per-endpoint metrics, logs completion at INFO/WARNING depending
    on status code, and emits a performance event for requests over 1s.
    """
    start_time = time.time()
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    # 8 characters of a UUID4 is enough to grep a log stream for one request.
    request_id = str(uuid.uuid4())[:8]
    logger.debug(
        "Request started",
        request_id=request_id,
        method=request.method,
        path=request.url.path,
        query=str(request.query_params) if request.query_params else None,
        client_ip=client_ip,
        user_agent=user_agent,
    )

    try:
        response = await call_next(request)
        response_time_ms = (time.time() - start_time) * 1000

        record_request_metric(
            endpoint=request.url.path,
            method=request.method,
            status_code=response.status_code,
            response_time_ms=response_time_ms,
            client_ip=client_ip,
        )

        # 4xx/5xx responses are logged at WARNING so they stand out.
        log_level = logging.WARNING if response.status_code >= 400 else logging.INFO
        logger.log(
            log_level,
            "Request completed",
            request_id=request_id,
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
            response_time_ms=round(response_time_ms, 2),
            client_ip=client_ip,
        )

        # Requests slower than 1s additionally go to the performance stream.
        if response_time_ms > 1000:
            log_performance_event(
                operation=f"{request.method} {request.url.path}",
                duration_ms=response_time_ms,
                details={
                    "request_id": request_id,
                    "client_ip": client_ip,
                    "status_code": response.status_code,
                },
            )

        response.headers["X-Request-ID"] = request_id
        return response

    except Exception as e:
        response_time_ms = (time.time() - start_time) * 1000

        # Log with full request context, then let FastAPI's error handling
        # produce the actual 500 response.
        log_error_with_context(
            error=e,
            operation=f"{request.method} {request.url.path}",
            context={
                "request_id": request_id,
                "client_ip": client_ip,
                "user_agent": user_agent,
                "response_time_ms": round(response_time_ms, 2),
            },
        )

        raise
|
|
|
|
@app.middleware("http")
async def csrf_middleware(request: Request, call_next: Any) -> Response:
    """CSRF protection middleware.

    JWT (Authorization: Bearer) requests are exempt — the token itself
    cannot be sent by a cross-site form. Cookie-authenticated requests must
    carry a valid X-CSRF-Token header, which is rotated on every response.

    NOTE(review): user_id is only ever set from a Bearer JWT here, and
    Bearer requests skip CSRF validation below, so the X-CSRF-Token branch
    appears unreachable unless identity can come from somewhere else —
    confirm against the cookie-auth flow.
    """
    # Only state-changing endpoints (per csrf_protection's policy) are checked.
    if not csrf_protection.should_protect_endpoint(
        request.method, request.url.path
    ):
        return await call_next(request)

    try:
        user_info = None
        user_id = None

        # Best-effort identification of the caller from a Bearer JWT.
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header.split(" ")[1]
            try:
                jwt_payload = verify_jwt_token(token)
                if jwt_payload:
                    user_info = jwt_payload
                    user_id = jwt_payload.get("username")
            except Exception as e:
                logger.debug(
                    "JWT token validation failed in CSRF middleware",
                    error=str(e),
                )

        # Anonymous / unidentified requests are not CSRF-checked here.
        if not user_info or not user_id:
            return await call_next(request)

        # Valid JWT in the Authorization header: CSRF does not apply.
        if auth_header and auth_header.startswith("Bearer "):
            logger.debug(
                "JWT authentication detected, skipping CSRF validation",
                user_id=user_id,
                method=request.method,
                path=request.url.path,
            )
            return await call_next(request)

        csrf_token = request.headers.get("X-CSRF-Token")

        if not csrf_token:
            logger.warning(
                "CSRF token missing for non-JWT auth",
                user_id=user_id,
                method=request.method,
                path=request.url.path,
            )

            return Response(
                content=json.dumps({
                    "error": "CSRF token required",
                    "message": "X-CSRF-Token header is required for cookie-based authentication",
                }),
                status_code=403,
                media_type="application/json",
            )

        if not csrf_protection.validate_csrf_token(csrf_token, user_id):
            logger.warning(
                "CSRF token validation failed",
                user_id=user_id,
                method=request.method,
                path=request.url.path,
                # Only a prefix is logged so the full token never hits the logs.
                token_preview=csrf_token[:16] + "...",
            )

            return Response(
                content=json.dumps({
                    "error": "Invalid CSRF token",
                    "message": "CSRF token validation failed",
                }),
                status_code=403,
                media_type="application/json",
            )

        # Token accepted: rotate it so each response carries a fresh one.
        response = await call_next(request)
        new_csrf_token = csrf_protection.generate_csrf_token(user_id)
        response.headers["X-CSRF-Token"] = new_csrf_token

        return response

    except Exception as e:
        # Fail closed: an internal CSRF error must not let the request through.
        logger.error("CSRF middleware error", error=str(e))
        return Response(
            content=json.dumps({
                "error": "CSRF protection error",
                "message": "Internal server error",
            }),
            status_code=500,
            media_type="application/json",
        )
|
|
|
|
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next: Any) -> Response:
    """Rate limiting middleware with Redis.

    Enforces a per-client-IP limit of RATE_LIMIT_REQUESTS requests per
    RATE_LIMIT_WINDOW seconds. Health, docs, metrics and status endpoints
    are exempt, and the whole check can be disabled via RATE_LIMIT_ENABLED.
    X-RateLimit-* headers from the limiter are attached to every checked
    response.
    """
    # Fix: honor the RATE_LIMIT_ENABLED env flag, which was read at module
    # level but never consulted — the limiter previously always ran.
    if not RATE_LIMIT_ENABLED:
        return await call_next(request)

    excluded_paths = {
        "/health",
        "/health/detailed",
        "/health/ready",
        "/health/live",
        "/",
        "/docs",
        "/openapi.json",
        # Fix: docs are actually mounted under /api/* (see FastAPI app
        # config), so the exemption must cover those paths too.
        "/api/docs",
        "/api/redoc",
        "/api/openapi.json",
        "/rate-limit/status",
        "/metrics",
        "/stats",
    }
    if request.url.path in excluded_paths:
        return await call_next(request)

    client_ip = request.client.host if request.client else "unknown"

    allowed, headers = redis_rate_limiter.check_rate_limit(
        client_ip=client_ip,
        requests_limit=RATE_LIMIT_REQUESTS,
        window_seconds=RATE_LIMIT_WINDOW,
    )

    if not allowed:
        user_agent = request.headers.get("user-agent", "unknown")

        log_security_event(
            event_type="rate_limit_exceeded",
            client_ip=client_ip,
            user_agent=user_agent,
            details={
                "path": request.url.path,
                "method": request.method,
                "limit": RATE_LIMIT_REQUESTS,
                "window_seconds": RATE_LIMIT_WINDOW,
            },
            severity="medium",
        )

        record_error_metric("rate_limit_blocks")

        response = Response(
            content=json.dumps({
                "error": "Rate limit exceeded",
                "message": f"Too many requests. Limit: {RATE_LIMIT_REQUESTS} per {RATE_LIMIT_WINDOW} seconds",
                "retry_after": headers.get(
                    "X-RateLimit-Reset", int(time.time() + RATE_LIMIT_WINDOW)
                ),
            }),
            status_code=429,
            media_type="application/json",
        )
    else:
        response = await call_next(request)

    # Expose the limiter's X-RateLimit-* headers on every checked response.
    for header_name, header_value in headers.items():
        response.headers[header_name] = str(header_value)

    return response
|
|
|
|
# Guacamole configuration
# GUACAMOLE_URL is the internal address this API talks to;
# GUACAMOLE_PUBLIC_URL is what gets embedded in client-facing connection
# URLs (defaults to the internal one).
GUACAMOLE_URL = os.getenv("GUACAMOLE_URL", "http://localhost:8080")
GUACAMOLE_PUBLIC_URL = os.getenv("GUACAMOLE_PUBLIC_URL", GUACAMOLE_URL)

# Shared authenticator instance; also reused by GuacamoleClient below.
guacamole_authenticator = GuacamoleAuthenticator()

# Rate limiting configuration
# Per-client-IP budget: RATE_LIMIT_REQUESTS requests per RATE_LIMIT_WINDOW
# seconds. RATE_LIMIT_ENABLED is intended as the master on/off switch.
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "10"))
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
RATE_LIMIT_ENABLED = os.getenv("RATE_LIMIT_ENABLED", "true").lower() == "true"


# Process start time; used for uptime reporting in the metrics summary.
service_start_time = datetime.now()
|
|
|
# Metrics storage
# In-process metrics: reset on restart and local to each worker process
# (not shared). Counters use defaultdict so new keys start at zero; the
# sample lists under "performance" are capped at 100 entries by the
# record_* helper functions.
metrics_storage = {
    "requests": {
        "total": 0,
        "by_endpoint": defaultdict(int),
        "by_method": defaultdict(int),
        "by_status": defaultdict(int),
        "by_ip": defaultdict(int),
    },
    "connections": {
        "total_created": 0,
        "by_protocol": defaultdict(int),
        "by_ip": defaultdict(int),
        "active_count": 0,
        "failed_host_checks": 0,
    },
    "performance": {
        "response_times": defaultdict(list),
        "connection_creation_times": [],
        "host_check_times": [],
    },
    "errors": {
        "rate_limit_blocks": 0,
        "connection_failures": 0,
        "host_unreachable": 0,
        "authentication_failures": 0,
    },
}
|
|
|
|
|
|
def decrypt_password_from_request(
    encrypted_password: str,
    request: Request,
    context: Optional[Dict[str, Any]] = None,
) -> str:
    """Return password as-is (protected by HTTPS).

    Kept as a seam so real client-side password decryption can be added
    later without touching call sites.

    Args:
        encrypted_password: Password string from client
        request: FastAPI request object (currently unused)
        context: Optional context dict for logging (currently unused)

    Returns:
        Password as provided by client
    """
    return encrypted_password
|
|
|
|
|
|
def generate_connection_url(connection_id: str, guacamole_token: str) -> str:
    """Build the Guacamole client URL for an existing connection.

    Guacamole addresses a client session by the base64 encoding of
    "<id>\\0c\\0<datasource>"; the datasource here is "postgresql".

    Args:
        connection_id: Guacamole connection ID
        guacamole_token: Guacamole auth token (NOT JWT)

    Returns:
        Full URL for Guacamole client connection
    """
    raw_identifier = f"{connection_id}\0c\0postgresql".encode()
    encoded_connection_id = base64.b64encode(raw_identifier).decode()

    connection_url = (
        f"{GUACAMOLE_PUBLIC_URL}/guacamole/?token={guacamole_token}"
        f"#/client/{encoded_connection_id}"
    )

    # Token is never logged — only the connection id and URL length.
    logger.debug(
        "Connection URL generated",
        connection_id=connection_id,
        encoded_connection_id=encoded_connection_id,
        url_length=len(connection_url),
    )

    return connection_url
|
|
|
|
class GuacamoleClient:
    """Client for interacting with Guacamole API.

    Thin wrapper around the shared GuacamoleAuthenticator: builds
    connection configurations and delegates the actual REST calls to it.
    """

    def __init__(self) -> None:
        """Initialize Guacamole client."""
        self.base_url = GUACAMOLE_URL
        self.session = requests.Session()
        # Shared module-level authenticator (holds Guacamole tokens).
        self.authenticator = guacamole_authenticator

    def get_system_token(self) -> str:
        """Get system token for service operations."""
        return self.authenticator.get_system_token()

    def create_connection_with_user_token(
        self, connection_request: ConnectionRequest, guacamole_token: str
    ) -> Dict[str, Any]:
        """Create new Guacamole connection with Guacamole token.

        Note: guacamole_token is GUACAMOLE auth token, NOT JWT
        """
        # Default the port from the protocol when the caller omitted it.
        if not connection_request.port:
            port_map = {"rdp": 3389, "vnc": 5900, "ssh": 22}
            connection_request.port = port_map.get(
                connection_request.protocol, 3389
            )

        # Resolve the hostname to an IPv4 address up front (guacd may not
        # share this process's DNS view); fall back to the raw hostname.
        original_hostname = connection_request.hostname
        resolved_ip = original_hostname

        try:
            resolved_info = socket.getaddrinfo(
                original_hostname, None, socket.AF_INET
            )
            if resolved_info:
                resolved_ip = resolved_info[0][4][0]
                logger.info(
                    "Hostname resolved for Guacamole connection",
                    original_hostname=original_hostname,
                    resolved_ip=resolved_ip,
                    protocol=connection_request.protocol,
                    port=connection_request.port,
                )
        except (socket.gaierror, socket.herror, OSError) as e:
            logger.warning(
                "Failed to resolve hostname, using as-is",
                hostname=original_hostname,
                error=str(e),
                message="Guacamole will receive the original hostname",
            )
            resolved_ip = original_hostname

        # Base config; the name is made unique by host + creation timestamp.
        connection_config = {
            "name": f"Auto-{original_hostname}-{int(time.time())}",
            "protocol": connection_request.protocol,
            "parameters": {
                "hostname": resolved_ip,
                "port": str(connection_request.port),
            },
            "attributes": {},
        }

        # Protocol-specific parameters; credentials only sent when provided.
        if connection_request.protocol == "rdp":
            connection_config["parameters"].update({
                "security": "any",
                "ignore-cert": "true",
                "enable-wallpaper": "false",
            })
            if connection_request.username:
                connection_config["parameters"][
                    "username"
                ] = connection_request.username
            if connection_request.password:
                connection_config["parameters"][
                    "password"
                ] = connection_request.password

        elif connection_request.protocol == "vnc":
            if connection_request.password:
                connection_config["parameters"][
                    "password"
                ] = connection_request.password

        elif connection_request.protocol == "ssh":
            if connection_request.username:
                connection_config["parameters"][
                    "username"
                ] = connection_request.username
            if connection_request.password:
                connection_config["parameters"][
                    "password"
                ] = connection_request.password

            # SFTP: an explicit request wins; otherwise it defaults to on.
            if connection_request.enable_sftp is not None:
                connection_config["parameters"]["enable-sftp"] = (
                    "true" if connection_request.enable_sftp else "false"
                )

                if (
                    connection_request.enable_sftp
                    and connection_request.sftp_root_directory
                ):
                    connection_config["parameters"][
                        "sftp-root-directory"
                    ] = connection_request.sftp_root_directory

                if (
                    connection_request.enable_sftp
                    and connection_request.sftp_server_alive_interval
                    and connection_request.sftp_server_alive_interval > 0
                ):
                    connection_config["parameters"][
                        "server-alive-interval"
                    ] = str(connection_request.sftp_server_alive_interval)
            else:
                connection_config["parameters"]["enable-sftp"] = "true"

        created_connection = self.authenticator.create_connection_with_token(
            connection_config, guacamole_token
        )

        if not created_connection:
            raise HTTPException(
                status_code=500, detail="Failed to create connection in Guacamole"
            )

        connection_id = created_connection.get("identifier")
        connection_url = generate_connection_url(connection_id, guacamole_token)

        logger.info(
            "Connection created",
            connection_id=connection_id,
            token_type="guacamole",
        )

        return {
            "connection_id": connection_id,
            "connection_url": connection_url,
            "status": "created",
            "auth_token": guacamole_token,
        }

    def delete_connection_with_user_token(
        self, connection_id: str, auth_token: str
    ) -> bool:
        """Delete Guacamole connection using user token."""
        return self.authenticator.delete_connection_with_token(
            connection_id, auth_token
        )

    def get_user_connections(self, auth_token: str) -> List[Dict[str, Any]]:
        """Get user connections list."""
        return self.authenticator.get_user_connections(auth_token)

    def get_all_connections_with_system_token(self) -> List[Dict[str, Any]]:
        """Get all connections using system token."""
        system_token = self.get_system_token()
        return self.authenticator.get_user_connections(system_token)

    def delete_connection_with_system_token(self, connection_id: str) -> bool:
        """Delete connection using system token."""
        system_token = self.get_system_token()
        return self.authenticator.delete_connection_with_token(
            connection_id, system_token
        )
|
|
|
|
# Module-level singleton shared by all request handlers and cleanup tasks.
guacamole_client = GuacamoleClient()
|
|
|
|
async def wait_for_guacamole(
    timeout_seconds: int = 30, check_interval: float = 1.0
) -> bool:
    """Poll Guacamole until it responds or the timeout expires.

    Args:
        timeout_seconds: Maximum wait time in seconds
        check_interval: Check interval in seconds

    Returns:
        True if Guacamole is available, False on timeout
    """
    # Any of these statuses proves the servlet is up, even when auth is
    # required or the path 404s.
    reachable_statuses = (200, 401, 403, 404)
    start_time = time.time()
    attempt = 0

    logger.info(
        "Waiting for Guacamole to become available...",
        timeout_seconds=timeout_seconds,
        guacamole_url=GUACAMOLE_URL,
    )

    while (time.time() - start_time) < timeout_seconds:
        attempt += 1
        try:
            # requests is blocking; run it off the event loop thread.
            response = await asyncio.to_thread(
                requests.get, f"{GUACAMOLE_URL}/guacamole/", timeout=2
            )
            if response.status_code in reachable_statuses:
                logger.info(
                    "Guacamole is available",
                    attempt=attempt,
                    elapsed_seconds=round(time.time() - start_time, 2),
                    status_code=response.status_code,
                )
                return True
        except requests.exceptions.RequestException as exc:
            logger.debug(
                "Guacamole not ready yet",
                attempt=attempt,
                elapsed_seconds=round(time.time() - start_time, 2),
                error=str(exc)[:100],
            )

        await asyncio.sleep(check_interval)

    logger.warning(
        "Guacamole did not become available within timeout",
        timeout_seconds=timeout_seconds,
        total_attempts=attempt,
    )
    return False
|
|
|
|
async def cleanup_orphaned_guacamole_connections():
    """Clean up orphaned Guacamole connections if Redis is empty.

    Needed after FLUSHDB or first startup after a crash, when Guacamole
    may still hold connections that Redis no longer tracks.

    Returns:
        Number of deleted connections
    """
    try:
        tracked = redis_connection_storage.get_all_connections()

        # Any Redis record means state is intact; nothing to reconcile.
        if len(tracked) > 0:
            logger.info("Redis has active connections, skipping orphaned cleanup",
                        redis_connections_count=len(tracked))
            return 0

        logger.warning("Redis is empty, checking for orphaned Guacamole connections",
                       message="This usually happens after FLUSHDB or service restart")

        guac_connections = guacamole_client.get_all_connections_with_system_token()

        if not guac_connections or len(guac_connections) == 0:
            logger.info("No Guacamole connections found, nothing to clean up")
            return 0

        logger.warning("Found orphaned Guacamole connections",
                       guacamole_connections_count=len(guac_connections),
                       message="Deleting all orphaned connections")

        removed = 0
        for orphan in guac_connections:
            orphan_id = orphan.get('identifier')
            if not orphan_id:
                continue
            try:
                if guacamole_client.delete_connection_with_system_token(orphan_id):
                    removed += 1
                    logger.debug("Deleted orphaned connection",
                                 connection_id=orphan_id,
                                 connection_name=orphan.get('name', 'unknown'))
            except Exception as exc:
                # Keep going: one failed delete must not stop the sweep.
                logger.error("Failed to delete orphaned connection",
                             connection_id=orphan_id,
                             error=str(exc))

        if removed > 0:
            logger.info("Orphaned connections cleanup completed",
                        deleted_count=removed,
                        total_found=len(guac_connections))

        return removed

    except Exception as exc:
        logger.error("Error during orphaned connections cleanup", error=str(exc))
        return 0
|
|
|
|
async def cleanup_expired_connections_once(log_action: str = "expired"):
    """Execute one iteration of expired connections cleanup.

    Scans all connection records in Redis, deletes the expired ones from
    Guacamole, notifies owners over WebSocket, and removes the Redis records.

    Args:
        log_action: Action for logging (expired, startup_cleanup, etc.)

    Returns:
        Number of deleted connections (0 on error)
    """
    try:
        current_time = datetime.now(timezone.utc)
        expired_connections = []

        # Pass 1: collect the IDs of all connections whose TTL has elapsed.
        all_connections = redis_connection_storage.get_all_connections()
        for conn_id, conn_data in all_connections.items():
            expires_at = datetime.fromisoformat(conn_data['expires_at'])
            if expires_at.tzinfo is None:
                # Naive timestamps (legacy records) are interpreted as UTC.
                expires_at = expires_at.replace(tzinfo=timezone.utc)

            if expires_at <= current_time:
                expired_connections.append(conn_id)

        # Pass 2: delete each expired connection and notify its owner.
        deleted_count = 0
        for conn_id in expired_connections:
            conn_data = redis_connection_storage.get_connection(conn_id)
            if conn_data:
                if guacamole_client.delete_connection_with_system_token(conn_id):
                    deleted_count += 1
                    log_connection_lifecycle(
                        connection_id=conn_id,
                        action=log_action,
                        client_ip="system",
                        hostname=conn_data.get('hostname', 'unknown'),
                        protocol=conn_data.get('protocol', 'unknown'),
                        details={
                            "ttl_minutes": conn_data.get('ttl_minutes'),
                            "created_at": conn_data.get('created_at'),
                            "expires_at": conn_data.get('expires_at')
                        }
                    )

                # Best-effort WebSocket notification; failure must not block cleanup.
                owner_username = conn_data.get('owner_username')
                if owner_username:
                    try:
                        await websocket_manager.send_connection_expired(
                            username=owner_username,
                            connection_id=conn_id,
                            hostname=conn_data.get('hostname', 'unknown'),
                            protocol=conn_data.get('protocol', 'unknown')
                        )
                    except Exception as ws_error:
                        logger.warning("Failed to send WebSocket notification",
                                       connection_id=conn_id,
                                       error=str(ws_error))

                # Remove the Redis record regardless of the Guacamole delete
                # outcome so stale records do not linger forever.
                redis_connection_storage.delete_connection(conn_id)

        if deleted_count > 0:
            logger.info("Cleanup completed",
                        action=log_action,
                        expired_count=len(expired_connections),
                        deleted_count=deleted_count)

        return deleted_count

    except Exception as e:
        logger.error("Error during cleanup", error=str(e))
        return 0
|
|
|
|
async def cleanup_expired_connections():
    """Periodic background task: remove expired connections every 60 seconds."""
    interval_seconds = 60
    while True:
        try:
            await cleanup_expired_connections_once("expired")
        except Exception as exc:
            # Swallow and log: the loop must survive any single failure.
            logger.error("Error during cleanup task", error=str(exc))

        await asyncio.sleep(interval_seconds)
|
|
|
|
async def check_expiring_connections():
    """Background task to check connections expiring soon (warns 5 min before)"""
    # IDs already warned about, to avoid sending duplicate notifications.
    warned_connections = set()

    while True:
        try:
            current_time = datetime.now(timezone.utc)
            warning_threshold = current_time + timedelta(minutes=5)

            all_connections = redis_connection_storage.get_all_connections()

            for conn_id, conn_data in all_connections.items():
                expires_at = datetime.fromisoformat(conn_data['expires_at'])
                if expires_at.tzinfo is None:
                    # Naive timestamps are interpreted as UTC.
                    expires_at = expires_at.replace(tzinfo=timezone.utc)

                # Inside the 5-minute warning window and not warned yet.
                if current_time < expires_at <= warning_threshold and conn_id not in warned_connections:
                    owner_username = conn_data.get('owner_username')
                    if owner_username:
                        # Clamp to at least 1 so the UI never says "0 minutes".
                        minutes_remaining = max(1, int((expires_at - current_time).total_seconds() / 60))

                        try:
                            await websocket_manager.send_connection_will_expire(
                                username=owner_username,
                                connection_id=conn_id,
                                hostname=conn_data.get('hostname', 'unknown'),
                                protocol=conn_data.get('protocol', 'unknown'),
                                minutes_remaining=minutes_remaining
                            )
                            warned_connections.add(conn_id)

                            logger.info("Connection expiration warning sent",
                                        connection_id=conn_id,
                                        username=owner_username,
                                        minutes_remaining=minutes_remaining)
                        except Exception as ws_error:
                            # Not added to warned_connections: retry next cycle.
                            logger.warning("Failed to send expiration warning",
                                           connection_id=conn_id,
                                           error=str(ws_error))

                # Already expired: drop from the warned set so a future
                # connection reusing this ID can be warned again.
                elif expires_at <= current_time and conn_id in warned_connections:
                    warned_connections.discard(conn_id)

            # Prune entries for connections that no longer exist in Redis.
            current_conn_ids = set(all_connections.keys())
            warned_connections &= current_conn_ids

        except Exception as e:
            logger.error("Error during expiring connections check", error=str(e))

        await asyncio.sleep(30)
|
|
|
|
|
|
async def cleanup_ssrf_cache():
    """Periodic background task: evict expired SSRF-check cache entries."""
    interval_seconds = 180
    while True:
        try:
            ssrf_protection.cleanup_expired_cache()
        except Exception as exc:
            logger.error("Error during SSRF cache cleanup", error=str(exc))

        await asyncio.sleep(interval_seconds)
|
|
|
|
async def cleanup_csrf_tokens():
    """Periodic background task: purge expired CSRF tokens every 10 minutes."""
    interval_seconds = 600
    while True:
        try:
            csrf_protection.cleanup_expired_tokens()
        except Exception as exc:
            logger.error("Error during CSRF token cleanup", error=str(exc))

        await asyncio.sleep(interval_seconds)
|
|
|
|
def schedule_connection_deletion(
    connection_id: str,
    ttl_minutes: int,
    auth_token: str,
    guacamole_username: str,
    hostname: str = "unknown",
    protocol: str = "unknown",
    owner_username: str = "unknown",
    owner_role: str = "unknown"
):
    """Schedule connection deletion via TTL using Redis.

    Records the connection with its expiry timestamp; the record itself
    carries no Redis TTL (ttl_seconds=None) — expiry is enforced by the
    cleanup_expired_connections background task comparing 'expires_at'
    against the current time.

    Args:
        connection_id: Guacamole connection identifier.
        ttl_minutes: Lifetime of the connection in minutes.
        auth_token: Guacamole auth token that created the connection.
        guacamole_username: Guacamole account owning the connection.
        hostname: Target host (for lifecycle logging).
        protocol: Connection protocol (rdp/vnc/ssh).
        owner_username: Application user who requested the connection.
        owner_role: Role of the requesting user.
    """
    # Fix: derive both timestamps from a single clock read so that
    # expires_at - created_at is exactly ttl_minutes (previously two separate
    # datetime.now() calls could drift by microseconds).
    created_at = datetime.now(timezone.utc)
    expires_at = created_at + timedelta(minutes=ttl_minutes)

    connection_data = {
        'connection_id': connection_id,
        'created_at': created_at.isoformat(),
        'expires_at': expires_at.isoformat(),
        'ttl_minutes': ttl_minutes,
        'auth_token': auth_token,
        'guacamole_username': guacamole_username,
        'hostname': hostname,
        'protocol': protocol,
        'owner_username': owner_username,
        'owner_role': owner_role
    }

    redis_connection_storage.add_connection(
        connection_id,
        connection_data,
        ttl_seconds=None
    )
|
|
|
|
@app.on_event("startup")
|
|
async def startup_event():
|
|
"""Application startup initialization"""
|
|
startup_info = {
|
|
"guacamole_url": GUACAMOLE_URL,
|
|
"guacamole_public_url": GUACAMOLE_PUBLIC_URL,
|
|
"rate_limiting_enabled": RATE_LIMIT_ENABLED,
|
|
"rate_limit_config": f"{RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds" if RATE_LIMIT_ENABLED else None,
|
|
"log_level": LOG_LEVEL,
|
|
"log_format": LOG_FORMAT,
|
|
"python_version": sys.version.split()[0],
|
|
"platform": platform.system()
|
|
}
|
|
|
|
print("Starting Remote Access API...")
|
|
print(f"Guacamole URL (internal): {GUACAMOLE_URL}")
|
|
print(f"Guacamole Public URL (client): {GUACAMOLE_PUBLIC_URL}")
|
|
print(f"Rate Limiting: {'Enabled' if RATE_LIMIT_ENABLED else 'Disabled'}")
|
|
if RATE_LIMIT_ENABLED:
|
|
print(f"Rate Limit: {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds")
|
|
print(f"Log Level: {LOG_LEVEL}, Format: {LOG_FORMAT}")
|
|
|
|
logger.info("Application startup", **startup_info)
|
|
|
|
log_audit_event(
|
|
action="application_started",
|
|
resource="system",
|
|
client_ip="system",
|
|
details=startup_info
|
|
)
|
|
|
|
# Cleanup expired connections on startup
|
|
guacamole_ready = await wait_for_guacamole(timeout_seconds=30, check_interval=1.0)
|
|
|
|
if not guacamole_ready:
|
|
logger.warning("Guacamole not available, skipping startup cleanup",
|
|
message="Cleanup will be performed by background task when Guacamole becomes available")
|
|
else:
|
|
logger.info("Checking for orphaned Guacamole connections...")
|
|
orphaned_count = await cleanup_orphaned_guacamole_connections()
|
|
|
|
if orphaned_count > 0:
|
|
logger.warning(
|
|
"Orphaned cleanup completed",
|
|
deleted_connections=orphaned_count,
|
|
message="Removed orphaned Guacamole connections (no Redis records)"
|
|
)
|
|
|
|
logger.info("Starting cleanup of expired connections from previous runs...")
|
|
deleted_count = await cleanup_expired_connections_once("startup_cleanup")
|
|
|
|
if deleted_count > 0:
|
|
logger.info(
|
|
"Startup cleanup completed",
|
|
deleted_connections=deleted_count,
|
|
message="Removed expired connections from previous application runs"
|
|
)
|
|
else:
|
|
logger.info(
|
|
"Startup cleanup completed",
|
|
deleted_connections=0,
|
|
message="No expired connections found"
|
|
)
|
|
|
|
asyncio.create_task(cleanup_expired_connections())
|
|
asyncio.create_task(check_expiring_connections())
|
|
asyncio.create_task(cleanup_ssrf_cache())
|
|
asyncio.create_task(cleanup_csrf_tokens())
|
|
logger.info(
|
|
"Background tasks started",
|
|
ttl_cleanup=True,
|
|
expiring_connections_check=True,
|
|
rate_limit_cleanup=RATE_LIMIT_ENABLED,
|
|
ssrf_cache_cleanup=True,
|
|
csrf_token_cleanup=True
|
|
)
|
|
|
|
@app.get("/", tags=["System"])
|
|
async def root():
|
|
return {"message": "Remote Access API is running"}
|
|
|
|
@app.get("/api/health", tags=["System"])
|
|
async def health_check():
|
|
"""Health check with component status"""
|
|
start_time = time.time()
|
|
|
|
try:
|
|
response = await asyncio.to_thread(
|
|
requests.get,
|
|
f"{GUACAMOLE_URL}/guacamole",
|
|
timeout=5
|
|
)
|
|
guacamole_web = {
|
|
"status": "ok" if response.status_code == 200 else "error",
|
|
"response_time_ms": round(response.elapsed.total_seconds() * 1000, 2),
|
|
"status_code": response.status_code
|
|
}
|
|
except Exception as e:
|
|
guacamole_web = {"status": "error", "error": str(e)}
|
|
|
|
database = SystemService.check_database_connection(guacamole_client, GUACAMOLE_URL)
|
|
guacd = SystemService.check_guacd_daemon()
|
|
system = SystemService.check_system_resources()
|
|
system_info = SystemService(service_start_time).get_system_info()
|
|
|
|
components = [guacamole_web, database, guacd, system]
|
|
overall_status = "ok"
|
|
|
|
for component in components:
|
|
if component.get("status") == "error":
|
|
overall_status = "error"
|
|
break
|
|
elif component.get("status") == "critical":
|
|
overall_status = "critical"
|
|
elif component.get("status") == "warning" and overall_status == "ok":
|
|
overall_status = "warning"
|
|
|
|
check_duration = round((time.time() - start_time) * 1000, 2)
|
|
|
|
return {
|
|
"overall_status": overall_status,
|
|
"timestamp": datetime.now().isoformat(),
|
|
"check_duration_ms": check_duration,
|
|
"system_info": system_info,
|
|
"components": {
|
|
"guacamole_web": guacamole_web,
|
|
"database": database,
|
|
"guacd_daemon": guacd,
|
|
"system_resources": system
|
|
},
|
|
"statistics": {
|
|
"active_connections": len(redis_connection_storage.get_all_connections()),
|
|
"rate_limiting": {
|
|
"enabled": RATE_LIMIT_ENABLED,
|
|
"limit": RATE_LIMIT_REQUESTS,
|
|
"window_seconds": RATE_LIMIT_WINDOW,
|
|
"active_clients": redis_rate_limiter.get_stats().get("active_rate_limits", 0) if RATE_LIMIT_ENABLED else 0
|
|
}
|
|
}
|
|
}
|
|
|
|
@app.websocket("/ws/notifications")
|
|
async def websocket_notifications(websocket: WebSocket):
|
|
"""
|
|
WebSocket endpoint for real-time notifications
|
|
|
|
Events:
|
|
- connection_expired: Connection expired
|
|
- connection_deleted: Connection deleted
|
|
- connection_will_expire: Connection will expire soon (5 min warning)
|
|
- jwt_will_expire: JWT will expire soon (5 min warning)
|
|
- jwt_expired: JWT expired
|
|
- connection_extended: Connection extended
|
|
|
|
Connection protocol:
|
|
1. Client sends JWT token on connection
|
|
2. Server validates token
|
|
3. Server sends confirmation
|
|
4. Server starts sending notifications
|
|
"""
|
|
username = None
|
|
|
|
try:
|
|
await websocket.accept()
|
|
logger.info("WebSocket connection accepted, waiting for auth")
|
|
|
|
auth_message = await asyncio.wait_for(
|
|
websocket.receive_json(),
|
|
timeout=5.0
|
|
)
|
|
|
|
if auth_message.get("type") != "auth" or not auth_message.get("token"):
|
|
await websocket.close(code=4001, reason="Authentication required")
|
|
logger.warning("WebSocket connection rejected: no auth")
|
|
return
|
|
|
|
token = auth_message["token"]
|
|
payload = verify_jwt_token(token)
|
|
|
|
if not payload:
|
|
await websocket.close(code=4001, reason="Invalid token")
|
|
logger.warning("WebSocket connection rejected: invalid token")
|
|
return
|
|
|
|
username = payload.get("username")
|
|
if not username:
|
|
await websocket.close(code=4001, reason="Invalid token payload")
|
|
return
|
|
|
|
await websocket_manager.connect(websocket, username)
|
|
|
|
await websocket.send_json({
|
|
"type": "connected",
|
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
"data": {
|
|
"username": username,
|
|
"message": "Successfully connected to notifications stream"
|
|
}
|
|
})
|
|
|
|
logger.info("WebSocket client authenticated and connected",
|
|
username=username)
|
|
|
|
while True:
|
|
try:
|
|
message = await asyncio.wait_for(
|
|
websocket.receive_json(),
|
|
timeout=30.0
|
|
)
|
|
|
|
if message.get("type") == "ping":
|
|
await websocket.send_json({
|
|
"type": "pong",
|
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
|
})
|
|
|
|
except asyncio.TimeoutError:
|
|
# Timeout - send ping from server
|
|
try:
|
|
await websocket.send_json({
|
|
"type": "ping",
|
|
"timestamp": datetime.now(timezone.utc).isoformat()
|
|
})
|
|
except (WebSocketDisconnect, ConnectionError, RuntimeError):
|
|
break
|
|
except WebSocketDisconnect:
|
|
break
|
|
except Exception as e:
|
|
logger.error("Error in WebSocket loop",
|
|
username=username,
|
|
error=str(e))
|
|
break
|
|
|
|
except asyncio.TimeoutError:
|
|
try:
|
|
await websocket.close(code=4408, reason="Authentication timeout")
|
|
except (WebSocketDisconnect, ConnectionError, RuntimeError):
|
|
pass # WebSocket may not be accepted yet
|
|
logger.warning("WebSocket connection timeout during auth")
|
|
|
|
except WebSocketDisconnect:
|
|
logger.info("WebSocket client disconnected",
|
|
username=username)
|
|
|
|
except Exception as e:
|
|
logger.error("WebSocket error",
|
|
username=username,
|
|
error=str(e))
|
|
try:
|
|
await websocket.close(code=1011, reason="Internal error")
|
|
except (WebSocketDisconnect, ConnectionError, RuntimeError):
|
|
pass
|
|
|
|
finally:
|
|
if username:
|
|
await websocket_manager.disconnect(websocket, username)
|
|
|
|
@app.post(
    "/api/auth/login",
    tags=["Authentication"],
    response_model=LoginResponse,
    summary="Authenticate user",
    description="Login with username and password to receive JWT token"
)
async def login(login_request: LoginRequest, request: Request):
    """Authenticate against Guacamole and return a JWT session token.

    Flow: brute-force gate -> Guacamole credential check -> JWT issuance,
    with audit/security events recorded at each decision point.

    Raises:
        HTTPException: 429 when brute-force protection blocks the attempt,
            401 on bad credentials, 500 on unexpected failures.
    """
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    try:
        # Check brute-force protection
        allowed, reason, protection_details = brute_force_protection.check_login_allowed(
            client_ip, login_request.username
        )

        if not allowed:
            # Log security event
            immutable_audit_logger.log_security_event(
                event_type="login_blocked",
                client_ip=client_ip,
                user_agent=user_agent,
                details={
                    "username": login_request.username,
                    "reason": reason,
                    "protection_details": protection_details
                },
                severity="high",
                username=login_request.username
            )

            raise HTTPException(
                status_code=429,
                detail=f"Login blocked: {reason}"
            )

        # Validate credentials against Guacamole itself.
        user_info = guacamole_authenticator.authenticate_user(
            login_request.username,
            login_request.password
        )

        if not user_info:
            # Count this failure toward the lockout threshold.
            brute_force_protection.record_failed_login(
                client_ip, login_request.username, "invalid_credentials"
            )

            immutable_audit_logger.log_security_event(
                event_type="login_failed",
                client_ip=client_ip,
                user_agent=user_agent,
                details={"username": login_request.username},
                severity="medium",
                username=login_request.username
            )

            record_error_metric("authentication_failures")

            raise HTTPException(
                status_code=401,
                detail="Invalid username or password"
            )

        # Success: reset failure counters and mint the JWT.
        brute_force_protection.record_successful_login(client_ip, login_request.username)
        jwt_token = guacamole_authenticator.create_jwt_for_user(user_info)

        immutable_audit_logger.log_audit_event(
            action="user_login",
            resource=f"user/{login_request.username}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="success",
            details={
                "role": user_info["role"],
                "permissions_count": len(user_info.get("permissions", []))
            },
            username=login_request.username
        )

        # Token lifetime in seconds: env value is in minutes, default 60 min.
        expires_in = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "60")) * 60

        return LoginResponse(
            access_token=jwt_token,
            token_type="bearer",
            expires_in=expires_in,
            user_info={
                "username": user_info["username"],
                "role": user_info["role"],
                "permissions": user_info.get("permissions", []),
                "full_name": user_info.get("full_name"),
                "email": user_info.get("email"),
                "organization": user_info.get("organization"),
                "organizational_role": user_info.get("organizational_role")
            }
        )

    except HTTPException:
        # Re-raise intentional HTTP errors (429/401) unchanged.
        raise
    except Exception as e:
        logger.error("Unexpected error during login",
                     username=login_request.username,
                     client_ip=client_ip,
                     error=str(e))

        record_error_metric("authentication_failures")

        raise HTTPException(
            status_code=500,
            detail="Internal server error during authentication"
        )
|
|
|
|
@app.get(
    "/api/auth/profile",
    tags=["Authentication"],
    summary="Get user profile",
    description="Retrieve current user profile information"
)
async def get_user_profile(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the authenticated user's profile fields."""
    user_info = get_current_user(request)

    # Build the response in the documented key order; optional fields fall
    # back to None via .get().
    profile = {
        "username": user_info["username"],
        "role": user_info["role"],
        "permissions": user_info.get("permissions", []),
    }
    for optional_key in ("full_name", "email", "organization", "organizational_role"):
        profile[optional_key] = user_info.get(optional_key)

    return profile
|
|
|
|
@app.get(
    "/api/auth/permissions",
    tags=["Authentication"],
    summary="Get user permissions",
    description="List all permissions for current user role"
)
async def get_user_permissions(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return role-derived permissions alongside the raw system permissions."""
    user_info = get_current_user(request)
    role = UserRole(user_info["role"])

    role_permissions = PermissionChecker.get_user_permissions_list(role)

    return {
        "username": user_info["username"],
        "role": user_info["role"],
        "permissions": role_permissions,
        "system_permissions": user_info.get("permissions", [])
    }
|
|
|
|
@app.post(
    "/api/auth/logout",
    tags=["Authentication"],
    summary="Logout user",
    description="Revoke current JWT token and end session"
)
async def logout(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Blacklist the caller's JWT and record an audit event."""
    user_info = get_current_user(request)
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    # Best-effort revocation: only when a Bearer token is actually present.
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith("Bearer "):
        _, bearer_token = auth_header.split(" ", 1)
        token_blacklist.revoke_token(bearer_token, "logout", user_info["username"])

    immutable_audit_logger.log_audit_event(
        action="user_logout",
        resource=f"user/{user_info['username']}",
        client_ip=client_ip,
        user_agent=user_agent,
        result="success",
        username=user_info["username"]
    )

    return {
        "message": "Successfully logged out",
        "note": "JWT token has been revoked and added to blacklist"
    }
|
|
|
|
@app.get(
    "/api/auth/limits",
    tags=["Authentication"],
    summary="Get user limits",
    description="Retrieve role-based limits and allowed networks"
)
async def get_user_limits(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the role-based limits plus globally blocked hosts/networks."""
    user_info = get_current_user(request)

    if not user_info:
        raise HTTPException(status_code=401, detail="Not authenticated")

    role = UserRole(user_info["role"])
    limits = SecurityConfig.get_role_limits(role)

    return {
        "username": user_info["username"],
        "role": role.value,
        "limits": limits,
        "security_info": {
            "blocked_hosts": list(SecurityConfig.BLOCKED_HOSTS),
            "blocked_networks": SecurityConfig.BLOCKED_NETWORKS
        }
    }
|
|
|
|
@app.post(
    "/api/auth/revoke",
    tags=["Authentication"],
    summary="Revoke token",
    description="Revoke JWT token and invalidate session"
)
async def revoke_token(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Explicitly revoke the presented JWT (adds it to the blacklist)."""
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    user_info = get_current_user(request)
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        raise HTTPException(status_code=400, detail="No valid token provided")

    token = auth_header.split(" ", 1)[1]

    # NOTE(review): the blacklist reason is recorded as "logout" even for an
    # explicit revoke — confirm whether a distinct reason is wanted.
    if not token_blacklist.revoke_token(token, "logout", user_info["username"]):
        raise HTTPException(status_code=500, detail="Failed to revoke token")

    immutable_audit_logger.log_audit_event(
        action="token_revoked",
        resource=f"token/{token[:20]}...",
        client_ip=client_ip,
        user_agent=user_agent,
        result="success",
        username=user_info["username"]
    )

    return {"message": "Token revoked successfully"}
|
|
|
|
class MachineAvailabilityRequest(BaseModel):
    """Machine availability check request"""
    # DNS hostname (or IP) of the machine to probe.
    hostname: str
    # TCP port to probe; when omitted the endpoint defaults to 3389 (RDP).
    port: Optional[int] = None
|
|
|
|
class MachineAvailabilityResponse(BaseModel):
    """Machine availability check response"""
    # True when a TCP connection to hostname:port succeeded.
    available: bool
    # Echo of the probed hostname.
    hostname: str
    # Port that was actually probed.
    port: int
    # Round-trip time of the probe in milliseconds, if measured.
    response_time_ms: Optional[float] = None
    # ISO-8601 timestamp of when the check ran.
    checked_at: str
|
|
|
|
@app.post(
    "/api/machines/check-availability",
    tags=["Machines"],
    response_model=MachineAvailabilityResponse,
    summary="Check machine availability",
    description="Test if machine is reachable via TCP connection"
)
async def check_machine_availability(
    availability_request: MachineAvailabilityRequest,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Check machine availability (quick TCP probe).

    Args:
        availability_request: Target hostname and optional port (default 3389/RDP).
        request: Incoming request, used for client IP / logging context.
        credentials: Bearer credentials enforced by the security dependency.

    Returns:
        MachineAvailabilityResponse with reachability flag and response time.
    """
    client_ip = request.client.host if request.client else "unknown"
    user_info = get_current_user(request)
    username = user_info.get("username", "unknown")

    hostname = availability_request.hostname
    # NOTE(review): hostname is user-supplied and probed directly; consider
    # validating it via ssrf_protection / SecurityConfig.is_host_allowed.
    port = availability_request.port if availability_request.port else 3389

    logger.debug("Machine availability check requested",
                 hostname=hostname,
                 port=port,
                 username=username,
                 client_ip=client_ip)

    start_time = time.time()

    try:
        with socket.create_connection((hostname, port), timeout=2):
            response_time_ms = (time.time() - start_time) * 1000
            available = True

        logger.info("Machine is available",
                    hostname=hostname,
                    port=port,
                    response_time_ms=round(response_time_ms, 2),
                    username=username)
    except (socket.timeout, socket.error, ConnectionRefusedError, OSError) as e:
        response_time_ms = (time.time() - start_time) * 1000
        available = False

        logger.info("Machine is not available",
                    hostname=hostname,
                    port=port,
                    response_time_ms=round(response_time_ms, 2),
                    error=str(e),
                    username=username)

    return MachineAvailabilityResponse(
        available=available,
        hostname=hostname,
        port=port,
        # Fix: response_time_ms is assigned on both paths, and the previous
        # truthiness test ("if response_time_ms") would wrongly yield None
        # for a 0.0 ms measurement.
        response_time_ms=round(response_time_ms, 2),
        # Fix: timezone-aware UTC timestamp, consistent with the rest of the API.
        checked_at=datetime.now(timezone.utc).isoformat()
    )
|
|
|
|
@app.post(
    "/api/connections",
    tags=["Connections"],
    response_model=ConnectionResponse,
    summary="Create connection",
    description="Create new remote desktop connection (RDP/VNC/SSH)"
)
async def create_remote_connection(
    connection_request: ConnectionRequest,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a TTL-bound remote desktop connection in Guacamole.

    Pipeline: auth context -> password decryption -> role/TTL/host policy
    checks -> TCP reachability probe -> Guacamole creation -> TTL scheduling,
    with metrics and audit records on both success and failure paths.

    Raises:
        HTTPException: 401 missing auth context, 403 policy denial,
            400 invalid TTL or unreachable host, 500 creation failure.
    """
    start_time = time.time()
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    # User identity and Guacamole token are injected by the JWT middleware.
    user_info = get_current_user(request)
    guacamole_token = get_current_user_token(request)

    if not user_info or not guacamole_token:
        logger.error("Missing user info or token from middleware",
                     has_user_info=bool(user_info),
                     has_token=bool(guacamole_token))
        raise HTTPException(status_code=401, detail="Authentication required")

    username = user_info["username"]
    user_role = UserRole(user_info["role"])

    logger.info("Creating connection for authenticated user",
                username=username,
                role=user_role.value,
                guac_token_length=len(guacamole_token))

    # The client sends the target machine's password encrypted; decrypt in place.
    if connection_request.password:
        connection_request.password = decrypt_password_from_request(
            connection_request.password,
            request,
            context={
                "username": username,
                "hostname": connection_request.hostname,
                "protocol": connection_request.protocol
            }
        )

    role_limits = SecurityConfig.get_role_limits(user_role)

    if not role_limits["can_create_connections"]:
        raise HTTPException(
            status_code=403,
            detail=f"Role {user_role.value} cannot create connections"
        )

    # TTL must be globally valid AND within the caller's role limit.
    ttl_valid, ttl_reason = SecurityConfig.validate_ttl(connection_request.ttl_minutes)
    if not ttl_valid:
        raise HTTPException(status_code=400, detail=f"Invalid TTL: {ttl_reason}")

    if connection_request.ttl_minutes > role_limits["max_ttl_minutes"]:
        raise HTTPException(
            status_code=403,
            detail=f"TTL {connection_request.ttl_minutes} exceeds role limit {role_limits['max_ttl_minutes']} minutes"
        )

    # Host allow/deny policy per role (blocked networks, etc.).
    host_allowed, host_reason = SecurityConfig.is_host_allowed(connection_request.hostname, user_role)
    if not host_allowed:
        logger.warning("Host access denied",
                       username=username,
                       role=user_role.value,
                       hostname=connection_request.hostname,
                       reason=host_reason,
                       client_ip=client_ip)

        log_security_event(
            event_type="forbidden_host_access_attempt",
            client_ip=client_ip,
            user_agent=user_agent,
            details={
                "username": username,
                "role": user_role.value,
                "hostname": connection_request.hostname,
                "reason": host_reason
            },
            severity="high"
        )

        raise HTTPException(
            status_code=403,
            detail=f"Access to host denied: {host_reason}"
        )

    log_audit_event(
        action="connection_creation_started",
        resource=f"{connection_request.protocol}://{connection_request.hostname}",
        client_ip=client_ip,
        user_agent=user_agent,
        details={
            "protocol": connection_request.protocol,
            "hostname": connection_request.hostname,
            "port": connection_request.port,
            "ttl_minutes": connection_request.ttl_minutes,
            "username": username,
            "role": user_role.value
        }
    )

    # Resolve the probe port: explicit port wins, else protocol default.
    if not connection_request.port:
        port_map = {"rdp": 3389, "vnc": 5900, "ssh": 22}
        check_port = port_map.get(connection_request.protocol, 3389)
    else:
        check_port = connection_request.port

    logger.debug("Starting host connectivity check",
                 target_host=connection_request.hostname,
                 port=check_port,
                 client_ip=client_ip)

    # Check host availability before creating connection
    try:
        with socket.create_connection((connection_request.hostname, check_port), timeout=3):
            pass  # Connection successful
    except (socket.timeout, socket.error, ConnectionRefusedError, OSError) as e:
        log_security_event(
            event_type="unreachable_host_attempt",
            client_ip=client_ip,
            user_agent=user_agent,
            details={
                "hostname": connection_request.hostname,
                "port": check_port,
                "protocol": connection_request.protocol,
                "error": str(e)
            },
            severity="low"
        )

        raise HTTPException(
            status_code=400,
            detail=f"Host {connection_request.hostname}:{check_port} is not accessible"
        )

    logger.debug("Host connectivity check passed",
                 target_host=connection_request.hostname,
                 port=check_port,
                 client_ip=client_ip)

    try:

        # Create the connection in Guacamole under the caller's own token.
        result = guacamole_client.create_connection_with_user_token(connection_request, guacamole_token)
        connection_id = result.get("connection_id")

        # Register the connection in Redis so background tasks can expire it.
        expires_at = datetime.now(timezone.utc) + timedelta(minutes=connection_request.ttl_minutes)
        schedule_connection_deletion(
            connection_id=connection_id,
            ttl_minutes=connection_request.ttl_minutes,
            auth_token=result['auth_token'],
            guacamole_username=user_info["username"],
            hostname=connection_request.hostname,
            protocol=connection_request.protocol,
            owner_username=user_info["username"],
            owner_role=user_info["role"]
        )

        result['expires_at'] = expires_at.isoformat()
        result['ttl_minutes'] = connection_request.ttl_minutes

        # Only the client-safe subset is returned (auth_token stays internal).
        public_result = {
            "connection_id": result["connection_id"],
            "connection_url": result["connection_url"],
            "status": result["status"],
            "expires_at": result["expires_at"],
            "ttl_minutes": result["ttl_minutes"]
        }

        duration_ms = (time.time() - start_time) * 1000

        record_connection_metric(
            protocol=connection_request.protocol,
            client_ip=client_ip,
            creation_time_ms=duration_ms,
            success=True
        )

        log_connection_lifecycle(
            connection_id=connection_id,
            action="created",
            client_ip=client_ip,
            hostname=connection_request.hostname,
            protocol=connection_request.protocol,
            details={
                "ttl_minutes": connection_request.ttl_minutes,
                "expires_at": expires_at.isoformat(),
                "creation_duration_ms": round(duration_ms, 2),
                "user_agent": user_agent,
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )

        log_audit_event(
            action="connection_created",
            resource=f"{connection_request.protocol}://{connection_request.hostname}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="success",
            details={
                "connection_id": connection_id,
                "duration_ms": round(duration_ms, 2),
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )

        return ConnectionResponse(**public_result)
    except Exception as e:
        # Failure path: record metrics/audit with the same shape as success.
        duration_ms = (time.time() - start_time) * 1000

        record_connection_metric(
            protocol=connection_request.protocol,
            client_ip=client_ip,
            creation_time_ms=duration_ms,
            success=False
        )

        log_connection_lifecycle(
            connection_id="failed",
            action="failed",
            client_ip=client_ip,
            hostname=connection_request.hostname,
            protocol=connection_request.protocol,
            details={
                "error": str(e),
                "duration_ms": round(duration_ms, 2),
                "user_agent": user_agent,
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )

        log_audit_event(
            action="connection_creation_failed",
            resource=f"{connection_request.protocol}://{connection_request.hostname}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="failure",
            details={
                "error": str(e),
                "duration_ms": round(duration_ms, 2),
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get(
    "/api/connections",
    tags=["Connections"],
    summary="List connections",
    description="Retrieve active connections based on user role"
)
async def list_connections(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    Returns active connections with connection_url for session restoration.
    Used to restore connections after user re-login.

    Note: URLs are generated with current user token, not old token from Redis.
    This allows restoring connections after logout/login without 403 errors.

    Raises:
        HTTPException 401: no Guacamole token is attached to the request.
        HTTPException 429: per-user/per-IP rate limit exceeded.
    """
    user_info = get_current_user(request)
    user_role = UserRole(user_info["role"])
    username = user_info["username"]

    # Fresh Guacamole token of the *current* session; required to build
    # working connection URLs (old tokens stored with connections may be stale).
    current_guac_token = get_current_user_token(request)

    if not current_guac_token:
        logger.error("No Guacamole token available for user",
                     username=username)
        raise HTTPException(
            status_code=401,
            detail="Authentication token not available"
        )
    client_ip = request.client.host if request.client else "unknown"
    # Rate limit is keyed per username+IP pair: 60 requests per 60 seconds.
    rate_limit_key = f"get_connections:{username}:{client_ip}"

    allowed, rate_limit_headers = redis_rate_limiter.check_rate_limit(
        rate_limit_key,
        requests_limit=60,
        window_seconds=60
    )

    if not allowed:
        logger.warning("Rate limit exceeded for get connections",
                       username=username,
                       client_ip=client_ip,
                       rate_limit_headers=rate_limit_headers)
        raise HTTPException(
            status_code=429,
            detail="Too many requests. Please try again later."
        )
    current_time = datetime.now(timezone.utc)
    connections_with_ttl = []
    all_connections = redis_connection_storage.get_all_connections()

    for conn_id, conn_data in all_connections.items():
        # Non-privileged roles only see connections they own.
        if not PermissionChecker.can_view_all_connections(user_role):
            if conn_data.get('owner_username') != username:
                continue

        expires_at = datetime.fromisoformat(conn_data['expires_at'])
        # Stored timestamps may be naive; normalize to UTC before comparing.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)

        remaining_minutes = max(0, int((expires_at - current_time).total_seconds() / 60))

        connection_url = None

        # Only generate a URL for connections that are still alive.
        # A URL-generation failure is logged but does not fail the request;
        # the entry is returned with connection_url=None instead.
        if remaining_minutes > 0:
            try:
                connection_url = generate_connection_url(conn_id, current_guac_token)
                logger.debug("Connection URL generated with current user token",
                             connection_id=conn_id,
                             username=username,
                             remaining_minutes=remaining_minutes)
            except Exception as e:
                logger.error("Failed to generate connection URL",
                             connection_id=conn_id,
                             error=str(e))

        connections_with_ttl.append({
            "connection_id": conn_id,
            "hostname": conn_data.get('hostname', 'unknown'),
            "protocol": conn_data.get('protocol', 'unknown'),
            "owner_username": conn_data.get('owner_username', 'unknown'),
            "owner_role": conn_data.get('owner_role', 'unknown'),
            "created_at": conn_data['created_at'],
            "expires_at": conn_data['expires_at'],
            "ttl_minutes": conn_data['ttl_minutes'],
            "remaining_minutes": remaining_minutes,
            "status": "active" if remaining_minutes > 0 else "expired",
            "connection_url": connection_url
        })

    logger.info(
        "User retrieved connections list with refreshed tokens",
        username=username,
        total_connections=len(connections_with_ttl),
        active_connections=len([c for c in connections_with_ttl if c['status'] == 'active']),
        using_current_token=True
    )

    return {
        "total_connections": len(connections_with_ttl),
        "active_connections": len([c for c in connections_with_ttl if c['status'] == 'active']),
        "connections": connections_with_ttl
    }
|
|
|
|
import asyncio
|
|
import inspect
|
|
from typing import Any
|
|
|
|
async def _maybe_call(func: Any, *args, **kwargs):
    """Invoke *func* with the given arguments and await its result.

    Coroutine functions are awaited directly; plain callables are pushed to a
    worker thread via ``asyncio.to_thread`` so they never block the event loop.
    """
    if not inspect.iscoroutinefunction(func):
        return await asyncio.to_thread(func, *args, **kwargs)
    return await func(*args, **kwargs)
|
|
|
|
|
|
@app.delete("/api/connections/{connection_id}", tags=["Connections"])
async def delete_connection(
    connection_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Force delete connection before TTL expiration.

    Deletes the connection from Guacamole first (with the caller's current
    token), then removes it from Redis storage, emitting lifecycle, audit and
    security log events along the way.

    Raises:
        HTTPException 401: no Guacamole token on the request.
        HTTPException 403: caller is not the owner and lacks override rights.
        HTTPException 404: connection does not exist in Redis.
        HTTPException 500: Guacamole deletion failed or an unexpected error.
    """
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    # _maybe_call bridges sync/async helpers: awaited if coroutine functions,
    # otherwise executed in a worker thread.
    user_info = await _maybe_call(get_current_user, request)
    user_role = UserRole(user_info["role"])
    current_user_token = await _maybe_call(get_current_user_token, request)

    if not current_user_token:
        logger.error(
            "No Guacamole token available for user",
            username=user_info.get("username"),
            connection_id=connection_id,
        )
        raise HTTPException(status_code=401, detail="Authentication token not available")

    conn_data = await _maybe_call(redis_connection_storage.get_connection, connection_id)

    if not conn_data:
        # Deleting a nonexistent connection is recorded as a (low-severity)
        # security event before returning 404.
        log_security_event(
            event_type="delete_nonexistent_connection",
            client_ip=client_ip,
            user_agent=user_agent,
            details={"connection_id": connection_id, "username": user_info.get("username")},
            severity="low",
        )
        raise HTTPException(status_code=404, detail="Connection not found")

    # Permission check
    allowed, reason = PermissionChecker.check_connection_ownership(
        user_role, user_info.get("username"), conn_data.get("owner_username", "")
    )

    if not allowed:
        log_security_event(
            event_type="unauthorized_connection_deletion",
            client_ip=client_ip,
            user_agent=user_agent,
            details={
                "connection_id": connection_id,
                "username": user_info.get("username"),
                "owner": conn_data.get("owner_username", ""),
                "reason": reason,
            },
            severity="medium",
        )
        raise HTTPException(status_code=403, detail=reason)

    try:

        # Delete on the Guacamole side first; Redis is only cleaned up
        # after Guacamole confirms the deletion.
        deletion_success = await _maybe_call(
            guacamole_client.delete_connection_with_user_token, connection_id, current_user_token
        )

        if not deletion_success:
            logger.warning(
                "Failed to delete connection from Guacamole",
                connection_id=connection_id,
                username=user_info.get("username"),
            )
            log_audit_event(
                action="connection_deletion_failed",
                resource=f"connection/{connection_id}",
                client_ip=client_ip,
                user_agent=user_agent,
                result="failure",
                details={"error": "Failed to delete from Guacamole"},
            )
            raise HTTPException(status_code=500, detail="Failed to delete connection from Guacamole")

        await _maybe_call(redis_connection_storage.delete_connection, connection_id)

        log_connection_lifecycle(
            connection_id=connection_id,
            action="deleted",
            client_ip=client_ip,
            hostname=conn_data.get("hostname", "unknown"),
            protocol=conn_data.get("protocol", "unknown"),
            details={
                "user_agent": user_agent,
                "remaining_ttl_minutes": conn_data.get("ttl_minutes", 0),
                "deleted_manually": True,
                "deleted_by": user_info.get("username"),
                "deleter_role": user_info.get("role"),
            },
        )

        log_audit_event(
            action="connection_deleted",
            resource=f"connection/{connection_id}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="success",
            details={
                "deleted_by": user_info.get("username"),
                "deleter_role": user_info.get("role"),
                "owner": conn_data.get("owner_username", ""),
            },
        )

        return {"status": "deleted", "connection_id": connection_id}

    except HTTPException:
        # Re-raise the 4xx/5xx responses produced above untouched.
        raise

    except Exception as e:
        logger.error(
            "Exception during connection deletion",
            connection_id=connection_id,
            username=user_info.get("username"),
            error=str(e),
            error_type=type(e).__name__,
        )

        log_audit_event(
            action="connection_deletion_error",
            resource=f"connection/{connection_id}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="error",
            details={"error": str(e), "error_type": type(e).__name__},
        )

        raise HTTPException(
            status_code=500, detail=f"Internal error during connection deletion: {str(e)}"
        )
|
|
|
|
|
|
|
|
@app.post("/api/connections/{connection_id}/extend", tags=["Connections"])
async def extend_connection_ttl(
    connection_id: str,
    request: Request,
    additional_minutes: int = 30,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Extend active connection TTL

    Args:
        connection_id: Connection ID
        additional_minutes: Minutes to add to TTL (default 30)

    Returns:
        Updated connection information

    Raises:
        HTTPException 400: additional_minutes out of [1, 120] or connection
            already expired.
        HTTPException 403: caller does not own the connection.
        HTTPException 404: connection not found.
    """
    client_ip = request.client.host if request.client else "unknown"
    user_info = get_current_user(request)
    user_role = UserRole(user_info["role"])

    # Hard bounds on a single extension request.
    if additional_minutes < 1 or additional_minutes > 120:
        raise HTTPException(
            status_code=400,
            detail="Additional minutes must be between 1 and 120"
        )

    conn_data = redis_connection_storage.get_connection(connection_id)

    if not conn_data:
        raise HTTPException(status_code=404, detail="Connection not found")

    allowed, reason = PermissionChecker.check_connection_ownership(
        user_role, user_info["username"], conn_data.get('owner_username', '')
    )

    if not allowed:
        log_security_event(
            event_type="unauthorized_connection_extension",
            client_ip=client_ip,
            user_agent=request.headers.get("user-agent", "unknown"),
            details={
                "connection_id": connection_id,
                "username": user_info["username"],
                "owner": conn_data.get('owner_username', ''),
                "reason": reason
            },
            severity="medium"
        )
        raise HTTPException(status_code=403, detail=reason)

    current_expires_at = datetime.fromisoformat(conn_data['expires_at'])
    # Stored timestamps may be naive; normalize to UTC before arithmetic.
    if current_expires_at.tzinfo is None:
        current_expires_at = current_expires_at.replace(tzinfo=timezone.utc)

    current_time = datetime.now(timezone.utc)
    if current_expires_at <= current_time:
        raise HTTPException(
            status_code=400,
            detail="Connection has already expired and cannot be extended"
        )

    # Extension is relative to the current expiry, not to "now".
    new_expires_at = current_expires_at + timedelta(minutes=additional_minutes)
    new_ttl_minutes = int(conn_data.get('ttl_minutes', 60)) + additional_minutes

    conn_data['expires_at'] = new_expires_at.isoformat()
    conn_data['ttl_minutes'] = new_ttl_minutes

    # Re-store the mutated record under the same connection id.
    # NOTE(review): ttl_seconds=None — presumably the storage layer derives
    # the Redis key TTL from conn_data['expires_at'] (or keeps the existing
    # one); confirm the extended expiry is actually applied to the Redis key.
    redis_connection_storage.add_connection(
        connection_id,
        conn_data,
        ttl_seconds=None
    )

    log_connection_lifecycle(
        connection_id=connection_id,
        action="extended",
        client_ip=client_ip,
        hostname=conn_data.get('hostname', 'unknown'),
        protocol=conn_data.get('protocol', 'unknown'),
        details={
            "extended_by": user_info["username"],
            "additional_minutes": additional_minutes,
            "new_ttl_minutes": new_ttl_minutes,
            "new_expires_at": new_expires_at.isoformat()
        }
    )

    # Best-effort WebSocket push to the owner; failures are logged only.
    owner_username = conn_data.get('owner_username')
    if owner_username:
        try:
            await websocket_manager.send_connection_extended(
                username=owner_username,
                connection_id=connection_id,
                hostname=conn_data.get('hostname', 'unknown'),
                new_expires_at=new_expires_at,
                additional_minutes=additional_minutes
            )
        except Exception as ws_error:
            logger.warning("Failed to send WebSocket notification",
                           connection_id=connection_id,
                           error=str(ws_error))

    logger.info("Connection TTL extended",
                connection_id=connection_id,
                username=user_info["username"],
                additional_minutes=additional_minutes,
                new_expires_at=new_expires_at.isoformat())

    return {
        "status": "extended",
        "connection_id": connection_id,
        "additional_minutes": additional_minutes,
        "new_ttl_minutes": new_ttl_minutes,
        "new_expires_at": new_expires_at.isoformat(),
        "remaining_minutes": int((new_expires_at - current_time).total_seconds() / 60)
    }
|
|
|
|
@app.get(
    "/api/machines/saved",
    tags=["Machines"],
    response_model=SavedMachineList,
    summary="List saved machines",
    description="Retrieve all saved machines for current user"
)
async def get_saved_machines(
    request: Request,
    include_stats: bool = False,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return every machine saved by the current user.

    Args:
        include_stats: Include connection statistics (optional)
    """
    try:
        current_user = get_current_user(request)
        owner = current_user["username"]

        records = saved_machines_db.get_user_machines(owner, include_stats=include_stats)

        items = []
        for record in records:
            last_seen = record.get('last_connected_at')
            payload = {
                "id": str(record['id']),
                "user_id": record['user_id'],
                "name": record['name'],
                "hostname": record['hostname'],
                "port": record['port'],
                "protocol": record['protocol'],
                "os": record.get('os'),
                "username": record.get('username'),
                "description": record.get('description'),
                "tags": record.get('tags') or [],
                "is_favorite": record.get('is_favorite', False),
                "created_at": record['created_at'].isoformat(),
                "updated_at": record['updated_at'].isoformat(),
                "last_connected_at": last_seen.isoformat() if last_seen else None,
            }

            # Statistics are attached only when requested and present.
            if include_stats and 'connection_stats' in record:
                payload['connection_stats'] = record['connection_stats']

            items.append(SavedMachineResponse(**payload))

        logger.info(
            "Retrieved saved machines",
            user_id=owner,
            count=len(items)
        )

        return SavedMachineList(
            total=len(items),
            machines=items
        )

    except Exception as e:
        logger.error("Failed to get saved machines", error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to retrieve saved machines: {str(e)}")
|
|
|
|
|
|
@app.post(
    "/api/machines/saved",
    tags=["Machines"],
    response_model=SavedMachineResponse,
    summary="Save machine",
    description="Create new saved machine entry with credentials"
)
async def create_saved_machine(
    machine: SavedMachineCreate,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """
    Save a new machine in user profile

    Security: Password transmitted over HTTPS and encrypted in DB (AES-256).

    Raises:
        HTTPException 400: protocol is not one of rdp/ssh/vnc/telnet.
        HTTPException 500: unexpected database failure.
    """
    try:
        user_info = get_current_user(request)
        user_id = user_info["username"]
        # NOTE(review): user_token is retrieved but never used in this
        # handler — presumably kept for parity with other endpoints;
        # confirm before removing (the getter may also validate the request).
        user_token = get_current_user_token(request)

        valid_protocols = ['rdp', 'ssh', 'vnc', 'telnet']
        if machine.protocol.lower() not in valid_protocols:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid protocol. Must be one of: {', '.join(valid_protocols)}"
            )

        # Create machine in DB (passwords are NOT stored)
        created_machine = saved_machines_db.create_machine(
            user_id=user_id,
            name=machine.name,
            hostname=machine.hostname,
            port=machine.port,
            protocol=machine.protocol.lower(),
            os=machine.os,
            description=machine.description,
            tags=machine.tags or [],
            is_favorite=machine.is_favorite
        )

        logger.info(
            "Saved machine created",
            machine_id=created_machine['id'],
            user_id=user_id,
            name=machine.name
        )

        return SavedMachineResponse(
            id=str(created_machine['id']),
            user_id=created_machine['user_id'],
            name=created_machine['name'],
            hostname=created_machine['hostname'],
            port=created_machine['port'],
            protocol=created_machine['protocol'],
            os=created_machine.get('os'),
            username=created_machine.get('username'),
            description=created_machine.get('description'),
            tags=created_machine.get('tags') or [],
            is_favorite=created_machine.get('is_favorite', False),
            created_at=created_machine['created_at'].isoformat(),
            updated_at=created_machine['updated_at'].isoformat(),
            last_connected_at=None
        )

    except HTTPException:
        raise
    except Exception as e:
        # user_info may not exist if get_current_user itself raised.
        logger.error("Failed to create saved machine",
                     error=str(e),
                     user_id=user_info.get("username") if 'user_info' in locals() else "unknown")
        raise HTTPException(status_code=500, detail=f"Failed to create saved machine: {str(e)}")
|
|
|
|
|
|
@app.get(
    "/api/machines/saved/{machine_id}",
    tags=["Machines"],
    response_model=SavedMachineResponse,
    summary="Get saved machine",
    description="Retrieve specific saved machine details"
)
async def get_saved_machine(
    machine_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Get saved machine information

    Only owner can access the machine.

    Raises:
        HTTPException 404: machine not found or owned by another user.
        HTTPException 500: unexpected database/serialization failure.
    """
    try:
        user_info = get_current_user(request)
        user_id = user_info["username"]

        # Lookup is scoped to the caller, so foreign machines come back None.
        machine = saved_machines_db.get_machine_by_id(machine_id, user_id)

        if not machine:
            raise HTTPException(status_code=404, detail="Machine not found")

        return SavedMachineResponse(
            id=str(machine['id']),
            user_id=machine['user_id'],
            name=machine['name'],
            hostname=machine['hostname'],
            port=machine['port'],
            protocol=machine['protocol'],
            os=machine.get('os'),
            # Fix: include the stored username; the list (GET /api/machines/saved)
            # and update (PUT) endpoints already return this field, but it was
            # silently omitted here.
            username=machine.get('username'),
            description=machine.get('description'),
            tags=machine.get('tags') or [],
            is_favorite=machine.get('is_favorite', False),
            created_at=machine['created_at'].isoformat(),
            updated_at=machine['updated_at'].isoformat(),
            last_connected_at=machine['last_connected_at'].isoformat() if machine.get('last_connected_at') else None
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get saved machine",
                     error=str(e),
                     machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve machine: {str(e)}")
|
|
|
|
|
|
@app.put(
    "/api/machines/saved/{machine_id}",
    tags=["Machines"],
    response_model=SavedMachineResponse,
    summary="Update saved machine",
    description="Modify machine configuration and credentials"
)
async def update_saved_machine(
    machine_id: str,
    machine: SavedMachineUpdate,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """
    Update saved machine

    Security: Password transmitted over HTTPS and encrypted in DB (AES-256).

    Only fields present (non-None) in the request body are updated; all
    others are left untouched.

    Raises:
        HTTPException 400: invalid protocol value.
        HTTPException 404: machine not found (before or after update).
        HTTPException 500: unexpected database failure.
    """
    try:
        user_info = get_current_user(request)
        user_id = user_info["username"]
        # NOTE(review): user_token is retrieved but never used in this
        # handler — presumably kept for parity with other endpoints;
        # confirm before removing.
        user_token = get_current_user_token(request)

        # Ownership check: lookup is scoped to the caller.
        existing = saved_machines_db.get_machine_by_id(machine_id, user_id)
        if not existing:
            raise HTTPException(status_code=404, detail="Machine not found")

        # Build a partial-update dict from only the fields the client sent.
        update_data = {}

        if machine.name is not None:
            update_data['name'] = machine.name
        if machine.hostname is not None:
            update_data['hostname'] = machine.hostname
        if machine.port is not None:
            update_data['port'] = machine.port
        if machine.protocol is not None:
            valid_protocols = ['rdp', 'ssh', 'vnc', 'telnet']
            if machine.protocol.lower() not in valid_protocols:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid protocol. Must be one of: {', '.join(valid_protocols)}"
                )
            update_data['protocol'] = machine.protocol.lower()
        if machine.os is not None:
            update_data['os'] = machine.os
        # Note: credentials are NOT stored, they are requested at connection time
        if machine.description is not None:
            update_data['description'] = machine.description
        if machine.tags is not None:
            update_data['tags'] = machine.tags
        if machine.is_favorite is not None:
            update_data['is_favorite'] = machine.is_favorite

        # Update in DB
        updated_machine = saved_machines_db.update_machine(machine_id, user_id, **update_data)

        if not updated_machine:
            raise HTTPException(status_code=404, detail="Machine not found after update")

        logger.info(
            "Saved machine updated",
            machine_id=machine_id,
            user_id=user_id,
            updated_fields=list(update_data.keys())
        )

        return SavedMachineResponse(
            id=str(updated_machine['id']),
            user_id=updated_machine['user_id'],
            name=updated_machine['name'],
            hostname=updated_machine['hostname'],
            port=updated_machine['port'],
            protocol=updated_machine['protocol'],
            os=updated_machine.get('os'),
            username=updated_machine.get('username'),
            description=updated_machine.get('description'),
            tags=updated_machine.get('tags') or [],
            is_favorite=updated_machine.get('is_favorite', False),
            created_at=updated_machine['created_at'].isoformat(),
            updated_at=updated_machine['updated_at'].isoformat(),
            last_connected_at=updated_machine['last_connected_at'].isoformat() if updated_machine.get('last_connected_at') else None
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to update saved machine",
                     error=str(e),
                     machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to update machine: {str(e)}")
|
|
|
|
|
|
@app.delete(
    "/api/machines/saved/{machine_id}",
    tags=["Machines"],
    summary="Delete saved machine",
    description="Remove saved machine from user profile"
)
async def delete_saved_machine(
    machine_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Remove a saved machine owned by the current user.

    Only the owner can delete the machine; its connection history is
    removed as well (CASCADE).
    """
    try:
        requester = get_current_user(request)["username"]

        # delete_machine is scoped to the owner; a falsy result means the
        # machine does not exist or belongs to someone else.
        if not saved_machines_db.delete_machine(machine_id, requester):
            raise HTTPException(status_code=404, detail="Machine not found")

        logger.info(
            "Saved machine deleted",
            machine_id=machine_id,
            user_id=requester
        )

        return {"success": True, "message": "Machine deleted successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to delete saved machine",
                     error=str(e),
                     machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete machine: {str(e)}")
|
|
|
|
|
|
@app.post(
    "/api/machines/saved/{machine_id}/connect",
    tags=["Machines"],
    summary="Connect to saved machine",
    description="Record connection attempt and update timestamp"
)
async def connect_to_saved_machine(
    machine_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Record a connection to a saved machine.

    Updates the machine's last_connected_at timestamp and appends a
    connection-history entry for the current user.
    """
    try:
        caller = get_current_user(request)
        caller_id = caller["username"]
        source_ip = request.client.host if request.client else "unknown"

        # The lookup is owner-scoped, so foreign machines yield 404.
        if saved_machines_db.get_machine_by_id(machine_id, caller_id) is None:
            raise HTTPException(status_code=404, detail="Machine not found")

        # Touch the last-connected timestamp, then append the history row.
        saved_machines_db.update_last_connected(machine_id, caller_id)

        history_id = saved_machines_db.add_connection_history(
            user_id=caller_id,
            machine_id=machine_id,
            success=True,
            client_ip=source_ip
        )

        logger.info(
            "Connection to saved machine recorded",
            machine_id=machine_id,
            user_id=caller_id,
            history_id=history_id
        )

        return {
            "success": True,
            "message": "Connection recorded",
            "history_id": history_id
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to record connection",
                     error=str(e),
                     machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to record connection: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the FastAPI app with uvicorn
    # on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)