Files
Remote-Control-Center/guacamole_test_11_26/api/main.py
2025-11-25 09:58:37 +03:00

2903 lines
99 KiB
Python
Executable File

# Standard library imports
import asyncio
import base64
import json
import logging
import os
import platform
import socket
import subprocess
import sys
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
# Third-party imports
import jwt
import psutil
import requests
import structlog
from cryptography.hazmat.primitives import serialization
from dotenv import load_dotenv
from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
HTTPException,
Request,
Response,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
# Local imports
from core import (
ConnectionRequest,
ConnectionResponse,
GuacamoleAuthenticator,
LoginRequest,
LoginResponse,
PermissionChecker,
UserInfo,
UserRole,
create_jwt_token,
verify_jwt_token,
)
from core.audit_logger import immutable_audit_logger
from core.brute_force_protection import brute_force_protection
from core.csrf_protection import csrf_protection
from core.log_sanitizer import log_sanitizer, sanitize_log_processor
from core.middleware import get_current_user, get_current_user_token, jwt_auth_middleware
from core.models import (
BulkHealthCheckRequest,
BulkHealthCheckResponse,
BulkHealthCheckResult,
BulkSSHCommandRequest,
BulkSSHCommandResponse,
BulkSSHCommandResult,
ConnectionHistoryCreate,
ConnectionHistoryResponse,
SavedMachineCreate,
SavedMachineList,
SavedMachineResponse,
SavedMachineUpdate,
SSHCredentials,
)
from core.rate_limiter import redis_rate_limiter
from core.redis_storage import redis_connection_storage
from core.saved_machines_db import saved_machines_db
from core.session_storage import session_storage
from core.ssrf_protection import ssrf_protection
from core.token_blacklist import token_blacklist
from core.websocket_manager import websocket_manager
from security_config import SecurityConfig
from services.system_service import SystemService
from routers import bulk_router
load_dotenv()
# Interactive API docs are on unless ENABLE_DOCS is set to anything other
# than "true" (case-insensitive).
enable_docs = os.getenv("ENABLE_DOCS", "true").lower() == "true"
# OpenAPI tag groups, in display order, as (name, description) pairs.
_OPENAPI_TAG_DEFINITIONS = (
    ("System", "System health checks and service status"),
    (
        "Authentication",
        "User authentication and authorization (login, logout, profile, permissions)",
    ),
    (
        "Connections",
        "Remote desktop connection management (create, list, delete, extend TTL)",
    ),
    ("Machines", "Machine management and saved machines CRUD operations"),
    (
        "Bulk Operations",
        "Bulk operations on multiple machines (health checks, SSH commands)",
    ),
)
tags_metadata = [
    {"name": tag_name, "description": tag_description}
    for tag_name, tag_description in _OPENAPI_TAG_DEFINITIONS
]
# Swagger UI, ReDoc and the OpenAPI schema live under /api and are switched
# off together (all set to None) when docs are disabled.
_docs_endpoints = {
    "docs_url": "/api/docs",
    "redoc_url": "/api/redoc",
    "openapi_url": "/api/openapi.json",
}
if not enable_docs:
    _docs_endpoints = dict.fromkeys(_docs_endpoints)
app = FastAPI(
    title="Remote Access API",
    description="Remote desktop management API via Apache Guacamole. Supports RDP, VNC, SSH protocols with JWT authentication.",
    version="1.0.0",
    openapi_tags=tags_metadata,
    **_docs_endpoints,
)
app.include_router(bulk_router)
# Extracts "Authorization: Bearer <token>" credentials for endpoint dependencies.
security = HTTPBearer()
# Structured logging configuration, read from the environment:
# LOG_LEVEL is a stdlib level name (normalized to upper case) and
# LOG_FORMAT selects the renderer ("json" or anything else for console).
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
LOG_FORMAT = os.environ.get("LOG_FORMAT", "json")
def add_caller_info(
    logger: Any, name: str, event_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """Add caller information to log entry.

    structlog processor. It walks up the call stack and records the first
    frame whose file path contains neither "structlog" nor "logging" as
    ``event_dict["caller"] = {"file", "function", "line"}``, where file is the
    base name. If no such frame exists, ``event_dict`` is left unchanged.

    Args:
        logger: Wrapped logger instance (unused).
        name: Name of the invoked log method (unused).
        event_dict: Event being processed; updated in place.

    Returns:
        The same ``event_dict`` object.
    """
    # Start at our caller. Frame 0 is this processor itself, whose file
    # matches neither filter; starting there made every entry report
    # add_caller_info as its caller.
    frame = sys._getframe(1)
    try:
        while frame is not None:
            filename = frame.f_code.co_filename
            if "structlog" not in filename and "logging" not in filename:
                event_dict["caller"] = {
                    "file": filename.split("/")[-1],
                    "function": frame.f_code.co_name,
                    "line": frame.f_lineno,
                }
                break
            frame = frame.f_back
    finally:
        # Drop the frame reference so it cannot keep a reference cycle alive.
        del frame
    return event_dict
def add_service_context(
    logger: Any, name: str, event_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """structlog processor: stamp every entry with the service identity.

    Sets ``service`` to "remote-access-api" and ``version`` to "1.0.0",
    overwriting any existing values, and returns the same dict.
    """
    event_dict.update(service="remote-access-api", version="1.0.0")
    return event_dict
# structlog processor chains. The JSON chain also records the caller
# location; the console chain renders colored, human-readable lines.
json_processors = [
    structlog.processors.TimeStamper(fmt="iso"),
    add_service_context,
    add_caller_info,
    sanitize_log_processor,
    structlog.processors.add_log_level,
    structlog.processors.JSONRenderer(),
]
text_processors = [
    structlog.processors.TimeStamper(fmt="iso"),
    add_service_context,
    sanitize_log_processor,
    structlog.processors.add_log_level,
    structlog.dev.ConsoleRenderer(colors=True),
]
# Any LOG_FORMAT other than "json" falls back to the console chain.
active_processors = text_processors
if LOG_FORMAT == "json":
    active_processors = json_processors
structlog.configure(
    processors=active_processors,
    wrapper_class=structlog.stdlib.BoundLogger,
    logger_factory=structlog.stdlib.LoggerFactory(),
    cache_logger_on_first_use=True,
)
# structlog renders the whole line, so the stdlib handler prints only the
# message. Unknown LOG_LEVEL names fall back to INFO.
logging.basicConfig(
    format="%(message)s",
    level=getattr(logging, LOG_LEVEL, logging.INFO),
)
logger = structlog.get_logger()
security_logger, audit_logger, performance_logger, error_logger = (
    structlog.get_logger(channel_name)
    for channel_name in ("security", "audit", "performance", "error")
)
def record_request_metric(
    endpoint: str,
    method: str,
    status_code: int,
    response_time_ms: float,
    client_ip: str,
) -> None:
    """Count one finished request and store its latency.

    Increments the global total and the per-endpoint, per-method,
    per-status and per-IP counters. It then appends the response time to
    the endpoint's latency list, keeping only the newest 100 samples.
    """
    request_counters = metrics_storage["requests"]
    request_counters["total"] += 1
    for bucket_name, bucket_key in (
        ("by_endpoint", endpoint),
        ("by_method", method),
        ("by_status", status_code),
        ("by_ip", client_ip),
    ):
        request_counters[bucket_name][bucket_key] += 1
    latency_window = metrics_storage["performance"]["response_times"][endpoint]
    latency_window.append(response_time_ms)
    if len(latency_window) > 100:
        del latency_window[0]
def record_connection_metric(
    protocol: str,
    client_ip: str,
    creation_time_ms: float,
    success: bool = True,
) -> None:
    """Record the outcome of one connection-creation attempt.

    A failure only bumps ``errors.connection_failures``. A success bumps the
    created/protocol/IP counters and refreshes ``active_count`` from Redis.
    It also stores the creation time, keeping the newest 100 samples.
    """
    if not success:
        metrics_storage["errors"]["connection_failures"] += 1
        return
    connection_counters = metrics_storage["connections"]
    connection_counters["total_created"] += 1
    connection_counters["by_protocol"][protocol] += 1
    connection_counters["by_ip"][client_ip] += 1
    connection_counters["active_count"] = len(
        redis_connection_storage.get_all_connections()
    )
    creation_samples = metrics_storage["performance"]["connection_creation_times"]
    creation_samples.append(creation_time_ms)
    if len(creation_samples) > 100:
        del creation_samples[0]
def record_host_check_metric(check_time_ms: float, success: bool = True) -> None:
    """Record one host reachability check.

    A failure bumps both ``connections.failed_host_checks`` and
    ``errors.host_unreachable``. A success stores the check time, keeping
    the newest 100 samples.
    """
    if not success:
        metrics_storage["connections"]["failed_host_checks"] += 1
        metrics_storage["errors"]["host_unreachable"] += 1
        return
    check_samples = metrics_storage["performance"]["host_check_times"]
    check_samples.append(check_time_ms)
    if len(check_samples) > 100:
        del check_samples[0]
def record_error_metric(error_type: str) -> None:
    """Increment a predefined error counter; unknown error types are ignored."""
    error_counters = metrics_storage["errors"]
    if error_type not in error_counters:
        return
    error_counters[error_type] += 1
def calculate_percentiles(values: List[float]) -> Dict[str, float]:
    """Return the p50/p90/p95/p99 of ``values``.

    Each percentile is the element at index ``int(n * q)`` of the sorted
    values (no interpolation). An empty input yields 0 for every key.
    """
    quantiles = {"p50": 0.5, "p90": 0.9, "p95": 0.95, "p99": 0.99}
    if not values:
        return dict.fromkeys(quantiles, 0)
    ordered = sorted(values)
    size = len(ordered)
    return {label: ordered[int(size * q)] for label, q in quantiles.items()}
def _summarize_latencies(samples: List[float], avg_key: str) -> Dict[str, Any]:
    """Return ``{avg_key: mean, "percentiles": ...}`` for samples, or {} if empty."""
    if not samples:
        return {}
    return {
        avg_key: round(sum(samples) / len(samples), 2),
        "percentiles": calculate_percentiles(samples),
    }


def get_metrics_summary() -> Dict[str, Any]:
    """Get metrics summary.

    Counters are totals accumulated in ``metrics_storage`` since the module
    was loaded. Averages and percentiles are computed over the bounded
    latency sample lists.

    Returns:
        Dict with ``uptime_seconds``, ``requests``, ``connections``,
        ``performance`` and ``errors`` sections.
    """
    uptime_seconds = int((datetime.now() - service_start_time).total_seconds())
    request_metrics = metrics_storage["requests"]
    connection_metrics = metrics_storage["connections"]
    performance_metrics = metrics_storage["performance"]

    endpoint_totals = request_metrics["by_endpoint"]
    endpoint_stats = {
        endpoint: {
            # Take the lifetime count from by_endpoint. len(times) only
            # counts the bounded latency window, so it caps at the window size.
            "request_count": endpoint_totals.get(endpoint, len(times)),
            **_summarize_latencies(times, "avg_response_time_ms"),
        }
        for endpoint, times in performance_metrics["response_times"].items()
        if times
    }
    connection_stats = _summarize_latencies(
        performance_metrics["connection_creation_times"], "avg_creation_time_ms"
    )
    host_check_stats = _summarize_latencies(
        performance_metrics["host_check_times"], "avg_check_time_ms"
    )

    top_ips = dict(
        sorted(
            request_metrics["by_ip"].items(),
            key=lambda item: item[1],
            reverse=True,
        )[:10]
    )
    top_protocols = dict(
        sorted(
            connection_metrics["by_protocol"].items(),
            key=lambda item: item[1],
            reverse=True,
        )
    )
    return {
        "uptime_seconds": uptime_seconds,
        "requests": {
            "total": request_metrics["total"],
            "requests_per_second": round(
                request_metrics["total"] / max(uptime_seconds, 1), 2
            ),
            "by_status": dict(request_metrics["by_status"]),
            "by_method": dict(request_metrics["by_method"]),
            "top_ips": top_ips,
        },
        "connections": {
            "total_created": connection_metrics["total_created"],
            "currently_active": len(redis_connection_storage.get_all_connections()),
            "by_protocol": top_protocols,
            "failed_host_checks": connection_metrics["failed_host_checks"],
        },
        "performance": {
            "endpoints": endpoint_stats,
            "connection_creation": connection_stats,
            "host_checks": host_check_stats,
        },
        "errors": dict(metrics_storage["errors"]),
    }
def log_security_event(
    event_type: str,
    client_ip: str,
    user_agent: Optional[str] = None,
    details: Optional[Dict[str, Any]] = None,
    severity: str = "info",
    username: Optional[str] = None,
) -> None:
    """Record a security event in the immutable audit log and the security logger.

    The structured log level follows ``severity``: "critical" -> critical,
    "high" -> error, "medium" -> warning, anything else -> info. Keys in
    ``details`` are merged into (and may override) the logged fields.
    """
    immutable_audit_logger.log_security_event(
        event_type=event_type,
        client_ip=client_ip,
        user_agent=user_agent,
        details=details,
        severity=severity,
        username=username,
    )
    event = dict(
        event_type=event_type,
        client_ip=client_ip,
        user_agent=user_agent or "unknown",
        severity=severity,
        timestamp=datetime.now().isoformat(),
    )
    event.update(details or {})
    if severity == "critical":
        emit = security_logger.critical
    elif severity == "high":
        emit = security_logger.error
    elif severity == "medium":
        emit = security_logger.warning
    else:
        emit = security_logger.info
    emit("Security event", **event)
def log_audit_event(
    action: str,
    resource: str,
    client_ip: str,
    user_agent: Optional[str] = None,
    result: str = "success",
    details: Optional[Dict[str, Any]] = None,
    username: Optional[str] = None,
) -> None:
    """Record an audit event in the immutable audit log and the audit logger.

    A ``result`` of "failure" is logged at warning level, anything else at
    info. Keys in ``details`` are merged into the logged fields.
    """
    immutable_audit_logger.log_audit_event(
        action=action,
        resource=resource,
        client_ip=client_ip,
        user_agent=user_agent,
        result=result,
        details=details,
        username=username,
    )
    event = dict(
        action=action,
        resource=resource,
        client_ip=client_ip,
        user_agent=user_agent or "unknown",
        result=result,
        timestamp=datetime.now().isoformat(),
    )
    event.update(details or {})
    emit = audit_logger.warning if result == "failure" else audit_logger.info
    emit("Audit event", **event)
def log_performance_event(
    operation: str,
    duration_ms: float,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """Log a timed operation, raising the level as the duration grows.

    Over 5000 ms the entry is a warning ("Slow operation"). Over 1000 ms it
    is logged at info, otherwise at debug. ``duration_ms`` is rounded to two
    decimals in the entry.
    """
    event = dict(
        operation=operation,
        duration_ms=round(duration_ms, 2),
        timestamp=datetime.now().isoformat(),
    )
    if details:
        event.update(details)
    if duration_ms > 5000:
        performance_logger.warning("Slow operation", **event)
        return
    emit = performance_logger.info if duration_ms > 1000 else performance_logger.debug
    emit("Performance event", **event)
def log_connection_lifecycle(
    connection_id: str,
    action: str,
    client_ip: str,
    hostname: str,
    protocol: str,
    details: Optional[Dict[str, Any]] = None,
) -> None:
    """Log a connection lifecycle step on the audit logger.

    The action "failed" is logged at error level, every other action at
    info. Keys in ``details`` are merged into the logged fields.
    """
    event = dict(
        connection_id=connection_id,
        action=action,
        client_ip=client_ip,
        hostname=hostname,
        protocol=protocol,
        timestamp=datetime.now().isoformat(),
    )
    event.update(details or {})
    emit = audit_logger.error if action == "failed" else audit_logger.info
    emit("Connection lifecycle", **event)
def log_error_with_context(
    error: Exception,
    operation: str,
    context: Optional[Dict[str, Any]] = None,
) -> None:
    """Log an exception's class and message with the failing operation.

    Keys in ``context`` are merged into the entry, which is written at error
    level on the "error" logger.
    """
    event: Dict[str, Any] = {
        "operation": operation,
        "error_type": error.__class__.__name__,
        "error_message": f"{error}",
        "timestamp": datetime.now().isoformat(),
    }
    event.update(context or {})
    error_logger.error("Application error", **event)
# CORS middleware - configured via .env
# ALLOWED_ORIGINS is a comma-separated origin list. Startup fails fast if
# it is unset or has no non-blank entries.
allowed_origins_str = os.getenv("ALLOWED_ORIGINS")
if not allowed_origins_str:
    logger.error("ALLOWED_ORIGINS environment variable is not set!")
    raise RuntimeError(
        "ALLOWED_ORIGINS must be set in .env file or docker-compose.yml. "
        "Example: ALLOWED_ORIGINS=https://mc.exbytestudios.com,http://localhost:5173"
    )
# Trim each entry and drop blanks (e.g. from a trailing comma).
allowed_origins = list(filter(None, map(str.strip, allowed_origins_str.split(","))))
if not allowed_origins:
    logger.error("ALLOWED_ORIGINS is empty after parsing!")
    raise RuntimeError("ALLOWED_ORIGINS must contain at least one valid origin")
logger.info("CORS configured", allowed_origins=allowed_origins)
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods="GET POST PUT DELETE OPTIONS PATCH".split(),
    allow_headers=["*"],
)
# Electron desktop app CORS middleware
@app.middleware("http")
async def electron_cors_middleware(
    request: Request, call_next: Any
) -> Response:
    """Handle CORS for Electron desktop app (missing/custom Origin headers).

    A missing or "null" Origin gets ``Access-Control-Allow-Origin: *``. A
    ``file://`` or ``app://`` Origin is echoed back. Both cases also get the
    shared credentials/methods/headers CORS headers. Responses for every
    other origin are returned untouched.
    """
    origin = request.headers.get("origin")
    response = await call_next(request)
    if origin and origin != "null":
        if not origin.startswith(("file://", "app://")):
            return response
        logger.debug(
            "Request from Electron with custom protocol",
            origin=origin,
            path=request.url.path,
            method=request.method,
        )
        allow_origin = origin
    else:
        logger.debug(
            "Request without Origin header (Electron or API client)",
            path=request.url.path,
            method=request.method,
            user_agent=request.headers.get("user-agent", "unknown")[:50],
        )
        allow_origin = "*"
    response.headers["Access-Control-Allow-Origin"] = allow_origin
    response.headers.update({
        "Access-Control-Allow-Credentials": "true",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH",
        "Access-Control-Allow-Headers": "*",
    })
    return response
@app.middleware("http")
async def auth_middleware(request: Request, call_next: Any) -> Response:
    """Delegate JWT authentication to ``core.middleware.jwt_auth_middleware``."""
    response = await jwt_auth_middleware(request, call_next)
    return response
@app.middleware("http")
async def logging_and_metrics_middleware(
    request: Request, call_next: Any
) -> Response:
    """Request logging and metrics collection middleware.

    Gives each request a short ID (returned in ``X-Request-ID``), logs its
    start and completion, and records request metrics. Requests slower than
    one second also produce a performance event. If the downstream handler
    raises, the error is logged with context, the request is counted with
    status 500, and the exception is re-raised.
    """
    # perf_counter is monotonic, so measured durations cannot go negative or
    # jump when the wall clock is adjusted mid-request (time.time() can).
    start_time = time.perf_counter()
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    request_id = str(uuid.uuid4())[:8]
    logger.debug(
        "Request started",
        request_id=request_id,
        method=request.method,
        path=request.url.path,
        query=str(request.query_params) if request.query_params else None,
        client_ip=client_ip,
        user_agent=user_agent,
    )
    try:
        response = await call_next(request)
    except Exception as e:
        response_time_ms = (time.perf_counter() - start_time) * 1000
        # Count failed requests as well; previously they were missing from
        # every request metric. 500 is recorded because the exception is
        # unhandled at this point.
        record_request_metric(
            endpoint=request.url.path,
            method=request.method,
            status_code=500,
            response_time_ms=response_time_ms,
            client_ip=client_ip,
        )
        log_error_with_context(
            error=e,
            operation=f"{request.method} {request.url.path}",
            context={
                "request_id": request_id,
                "client_ip": client_ip,
                "user_agent": user_agent,
                "response_time_ms": round(response_time_ms, 2),
            },
        )
        raise

    response_time_ms = (time.perf_counter() - start_time) * 1000
    record_request_metric(
        endpoint=request.url.path,
        method=request.method,
        status_code=response.status_code,
        response_time_ms=response_time_ms,
        client_ip=client_ip,
    )
    log_level = logging.WARNING if response.status_code >= 400 else logging.INFO
    logger.log(
        log_level,
        "Request completed",
        request_id=request_id,
        method=request.method,
        path=request.url.path,
        status_code=response.status_code,
        response_time_ms=round(response_time_ms, 2),
        client_ip=client_ip,
    )
    if response_time_ms > 1000:
        log_performance_event(
            operation=f"{request.method} {request.url.path}",
            duration_ms=response_time_ms,
            details={
                "request_id": request_id,
                "client_ip": client_ip,
                "status_code": response.status_code,
            },
        )
    response.headers["X-Request-ID"] = request_id
    return response
def _csrf_user_id_from_bearer(auth_header: Optional[str]) -> Optional[str]:
    """Return the ``username`` claim of a valid ``Bearer`` JWT header, else None."""
    if not auth_header or not auth_header.startswith("Bearer "):
        return None
    token = auth_header.split(" ")[1]
    try:
        jwt_payload = verify_jwt_token(token)
        return jwt_payload.get("username") if jwt_payload else None
    except Exception as e:
        logger.debug(
            "JWT token validation failed in CSRF middleware",
            error=str(e),
        )
        return None


@app.middleware("http")
async def csrf_middleware(request: Request, call_next: Any) -> Response:
    """CSRF protection middleware.

    On endpoints that ``csrf_protection`` marks as protected, a request that
    carries a valid ``Authorization: Bearer <JWT>`` header needs no CSRF
    token: browsers never attach that header on their own, so such requests
    cannot be forged cross-site. Requests without a valid bearer token are
    passed on unchanged; authentication is presumably enforced by the JWT
    auth middleware (confirm).

    Downstream handler exceptions are no longer caught here. Previously
    they were wrapped into a generic "CSRF protection error" 500, which
    hid the real failure.
    """
    if not csrf_protection.should_protect_endpoint(
        request.method, request.url.path
    ):
        return await call_next(request)

    user_id = _csrf_user_id_from_bearer(request.headers.get("Authorization"))
    if user_id:
        logger.debug(
            "JWT authentication detected, skipping CSRF validation",
            user_id=user_id,
            method=request.method,
            path=request.url.path,
        )
    # NOTE(review): the previous version also had an X-CSRF-Token validation
    # branch (csrf_protection.validate_csrf_token / generate_csrf_token) for
    # non-JWT auth. It was unreachable because a user id could only come from
    # a Bearer header, which always skipped validation first. If cookie-based
    # auth is ever added, enforce the token here, outside any try/except that
    # wraps call_next.
    return await call_next(request)
# Paths never rate limited (health probes, docs, metrics). The set is built
# once at import time, not on every request. The app serves its
# interactive docs under /api/..., so those paths are listed alongside the
# bare ones.
RATE_LIMIT_EXCLUDED_PATHS = frozenset({
    "/health",
    "/health/detailed",
    "/health/ready",
    "/health/live",
    "/",
    "/docs",
    "/openapi.json",
    "/api/docs",
    "/api/redoc",
    "/api/openapi.json",
    "/rate-limit/status",
    "/metrics",
    "/stats",
})


@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next: Any) -> Response:
    """Rate limiting middleware with Redis.

    Allows each client IP RATE_LIMIT_REQUESTS requests per RATE_LIMIT_WINDOW
    seconds. The check is skipped entirely when RATE_LIMIT_ENABLED is false
    and for paths in RATE_LIMIT_EXCLUDED_PATHS. For checked requests, the
    limiter's headers are copied onto the response, including a 429.
    """
    # RATE_LIMIT_ENABLED was read from the environment but never consulted,
    # so RATE_LIMIT_ENABLED=false used to have no effect.
    if not RATE_LIMIT_ENABLED or request.url.path in RATE_LIMIT_EXCLUDED_PATHS:
        return await call_next(request)

    client_ip = request.client.host if request.client else "unknown"
    allowed, headers = redis_rate_limiter.check_rate_limit(
        client_ip=client_ip,
        requests_limit=RATE_LIMIT_REQUESTS,
        window_seconds=RATE_LIMIT_WINDOW,
    )
    if allowed:
        response = await call_next(request)
    else:
        log_security_event(
            event_type="rate_limit_exceeded",
            client_ip=client_ip,
            user_agent=request.headers.get("user-agent", "unknown"),
            details={
                "path": request.url.path,
                "method": request.method,
                "limit": RATE_LIMIT_REQUESTS,
                "window_seconds": RATE_LIMIT_WINDOW,
            },
            severity="medium",
        )
        record_error_metric("rate_limit_blocks")
        response = Response(
            content=json.dumps({
                "error": "Rate limit exceeded",
                "message": f"Too many requests. Limit: {RATE_LIMIT_REQUESTS} per {RATE_LIMIT_WINDOW} seconds",
                "retry_after": headers.get(
                    "X-RateLimit-Reset", int(time.time() + RATE_LIMIT_WINDOW)
                ),
            }),
            status_code=429,
            media_type="application/json",
        )
    for header_name, header_value in headers.items():
        response.headers[header_name] = str(header_value)
    return response
# Guacamole configuration: the API URL this service calls, and the URL put
# into client-facing links (defaults to the API URL).
GUACAMOLE_URL = os.environ.get("GUACAMOLE_URL", "http://localhost:8080")
GUACAMOLE_PUBLIC_URL = os.environ.get("GUACAMOLE_PUBLIC_URL", GUACAMOLE_URL)
guacamole_authenticator = GuacamoleAuthenticator()
# Rate limiting configuration: RATE_LIMIT_REQUESTS per RATE_LIMIT_WINDOW
# seconds, per client IP; RATE_LIMIT_ENABLED is true only for "true"
# (case-insensitive).
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "10"))
RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60"))
RATE_LIMIT_ENABLED = os.environ.get("RATE_LIMIT_ENABLED", "true").lower() == "true"
# Reference point for uptime reporting (naive local time).
service_start_time = datetime.now()
# Metrics storage: in-process counters and latency samples, reset on
# restart. The record_* helpers trim the latency lists to recent samples.
metrics_storage = {
    "requests": dict(
        total=0,
        by_endpoint=defaultdict(int),
        by_method=defaultdict(int),
        by_status=defaultdict(int),
        by_ip=defaultdict(int),
    ),
    "connections": dict(
        total_created=0,
        by_protocol=defaultdict(int),
        by_ip=defaultdict(int),
        active_count=0,
        failed_host_checks=0,
    ),
    "performance": dict(
        response_times=defaultdict(list),
        connection_creation_times=[],
        host_check_times=[],
    ),
    "errors": dict.fromkeys(
        (
            "rate_limit_blocks",
            "connection_failures",
            "host_unreachable",
            "authentication_failures",
        ),
        0,
    ),
}
def decrypt_password_from_request(
    encrypted_password: str,
    request: Request,
    context: Optional[Dict[str, Any]] = None,
) -> str:
    """Pass the client-supplied password through unchanged.

    Despite the name, no decryption is performed; the password is treated as
    protected in transit by HTTPS. ``request`` and ``context`` are accepted
    for call-site compatibility and are not used.

    Args:
        encrypted_password: Password string from client.
        request: FastAPI request object (unused).
        context: Optional context dict for logging (unused).

    Returns:
        The password exactly as provided by the client.
    """
    plain_password = encrypted_password
    return plain_password
def generate_connection_url(
    connection_id: str,
    guacamole_token: str,
    data_source: str = "postgresql",
) -> str:
    """Generate Guacamole connection URL.

    The client route identifies a connection by the base64 encoding of the
    NUL-separated connection id, the type marker ``c`` (connection) and the
    data-source name.

    Args:
        connection_id: Guacamole connection ID.
        guacamole_token: Guacamole auth token (NOT JWT).
        data_source: Guacamole data source holding the connection; defaults
            to "postgresql", the value previously hard-coded.

    Returns:
        Full URL for Guacamole client connection.
    """
    from urllib.parse import quote

    encoded_connection_id = base64.b64encode(
        f"{connection_id}\0c\0{data_source}".encode()
    ).decode()
    # Percent-encode the token so it is always a valid query-string value.
    # Tokens made only of unreserved characters (e.g. hex) are unchanged.
    safe_token = quote(guacamole_token, safe="")
    connection_url = (
        f"{GUACAMOLE_PUBLIC_URL}/guacamole/?token={safe_token}"
        f"#/client/{encoded_connection_id}"
    )
    logger.debug(
        "Connection URL generated",
        connection_id=connection_id,
        encoded_connection_id=encoded_connection_id,
        url_length=len(connection_url),
    )
    return connection_url
class GuacamoleClient:
    """Client for interacting with Guacamole API.

    Create/delete/list calls are delegated to the module-level
    ``guacamole_authenticator``. This class turns a ``ConnectionRequest``
    into a Guacamole connection config and builds the client URL.
    """

    # Port used when a request leaves ``port`` unset; unknown protocols
    # fall back to the RDP port.
    DEFAULT_PORTS: Dict[str, int] = {"rdp": 3389, "vnc": 5900, "ssh": 22}

    def __init__(self) -> None:
        """Initialize Guacamole client."""
        self.base_url = GUACAMOLE_URL
        # NOTE(review): no method of this class uses the session; presumably
        # kept for external callers -- confirm before removing.
        self.session = requests.Session()
        self.authenticator = guacamole_authenticator

    def get_system_token(self) -> str:
        """Get system token for service operations."""
        return self.authenticator.get_system_token()

    @staticmethod
    def _resolve_ipv4(hostname: str, protocol: str, port: Any) -> str:
        """Return the first IPv4 address for ``hostname``, or ``hostname`` if resolution fails."""
        try:
            resolved_info = socket.getaddrinfo(hostname, None, socket.AF_INET)
        except (socket.gaierror, socket.herror, OSError) as e:
            logger.warning(
                "Failed to resolve hostname, using as-is",
                hostname=hostname,
                error=str(e),
                message="Guacamole will receive the original hostname",
            )
            return hostname
        if not resolved_info:
            return hostname
        resolved_ip = resolved_info[0][4][0]
        logger.info(
            "Hostname resolved for Guacamole connection",
            original_hostname=hostname,
            resolved_ip=resolved_ip,
            protocol=protocol,
            port=port,
        )
        return resolved_ip

    @staticmethod
    def _protocol_parameters(connection_request: ConnectionRequest) -> Dict[str, str]:
        """Return the protocol-specific Guacamole parameters for a request."""
        protocol = connection_request.protocol
        params: Dict[str, str] = {}
        if protocol == "rdp":
            params.update({
                "security": "any",
                "ignore-cert": "true",
                "enable-wallpaper": "false",
            })
        # VNC authenticates with a password only; RDP and SSH also send a username.
        if protocol in ("rdp", "ssh") and connection_request.username:
            params["username"] = connection_request.username
        if protocol in ("rdp", "vnc", "ssh") and connection_request.password:
            params["password"] = connection_request.password
        if protocol == "ssh":
            enable_sftp = connection_request.enable_sftp
            if enable_sftp is None:
                # Not specified by the client: SFTP is enabled by default.
                params["enable-sftp"] = "true"
            else:
                params["enable-sftp"] = "true" if enable_sftp else "false"
                if enable_sftp and connection_request.sftp_root_directory:
                    params["sftp-root-directory"] = (
                        connection_request.sftp_root_directory
                    )
                alive_interval = connection_request.sftp_server_alive_interval
                if enable_sftp and alive_interval and alive_interval > 0:
                    params["server-alive-interval"] = str(alive_interval)
        return params

    def create_connection_with_user_token(
        self, connection_request: ConnectionRequest, guacamole_token: str
    ) -> Dict[str, Any]:
        """Create new Guacamole connection with Guacamole token.

        Note: guacamole_token is GUACAMOLE auth token, NOT JWT.

        Side effect: if ``connection_request.port`` is unset, it is filled in
        on the request object with the protocol's default port. The
        hostname is resolved to an IPv4 address when possible.

        Returns:
            Dict with ``connection_id``, ``connection_url``, ``status``
            ("created") and ``auth_token`` (the Guacamole token).

        Raises:
            HTTPException: 500 if Guacamole did not create the connection or
                returned no connection identifier.
        """
        if not connection_request.port:
            connection_request.port = self.DEFAULT_PORTS.get(
                connection_request.protocol, 3389
            )
        original_hostname = connection_request.hostname
        resolved_ip = self._resolve_ipv4(
            original_hostname,
            protocol=connection_request.protocol,
            port=connection_request.port,
        )
        connection_config = {
            "name": f"Auto-{original_hostname}-{int(time.time())}",
            "protocol": connection_request.protocol,
            "parameters": {
                "hostname": resolved_ip,
                "port": str(connection_request.port),
                **self._protocol_parameters(connection_request),
            },
            "attributes": {},
        }
        created_connection = self.authenticator.create_connection_with_token(
            connection_config, guacamole_token
        )
        if not created_connection:
            raise HTTPException(
                status_code=500, detail="Failed to create connection in Guacamole"
            )
        connection_id = created_connection.get("identifier")
        if connection_id is None or connection_id == "":
            # Without an identifier the client URL would point at a
            # connection literally named "None".
            raise HTTPException(
                status_code=500,
                detail="Guacamole did not return a connection identifier",
            )
        connection_url = generate_connection_url(connection_id, guacamole_token)
        logger.info(
            "Connection created",
            connection_id=connection_id,
            token_type="guacamole",
        )
        return {
            "connection_id": connection_id,
            "connection_url": connection_url,
            "status": "created",
            "auth_token": guacamole_token,
        }

    def delete_connection_with_user_token(
        self, connection_id: str, auth_token: str
    ) -> bool:
        """Delete Guacamole connection using user token."""
        return self.authenticator.delete_connection_with_token(
            connection_id, auth_token
        )

    def get_user_connections(self, auth_token: str) -> List[Dict[str, Any]]:
        """Get user connections list."""
        return self.authenticator.get_user_connections(auth_token)

    def get_all_connections_with_system_token(self) -> List[Dict[str, Any]]:
        """Get all connections using system token."""
        system_token = self.get_system_token()
        return self.authenticator.get_user_connections(system_token)

    def delete_connection_with_system_token(self, connection_id: str) -> bool:
        """Delete connection using system token."""
        system_token = self.get_system_token()
        return self.authenticator.delete_connection_with_token(
            connection_id, system_token
        )
# Module-level client instance; the cleanup tasks in this module use it.
guacamole_client: GuacamoleClient = GuacamoleClient()
async def wait_for_guacamole(
    timeout_seconds: int = 30, check_interval: float = 1.0
) -> bool:
    """Wait for Guacamole to become available.

    Polls ``{GUACAMOLE_URL}/guacamole/`` in a worker thread (2 s per-request
    timeout) until it answers with 200, 401, 403 or 404, or until
    ``timeout_seconds`` have elapsed.

    Args:
        timeout_seconds: Maximum wait time in seconds.
        check_interval: Pause between attempts in seconds.

    Returns:
        True if Guacamole is available, False on timeout.
    """
    # Monotonic clock: wall-clock adjustments cannot stretch or shorten the
    # timeout (time.time() can jump).
    start_time = time.monotonic()
    attempt = 0
    probe_url = f"{GUACAMOLE_URL}/guacamole/"
    logger.info(
        "Waiting for Guacamole to become available...",
        timeout_seconds=timeout_seconds,
        guacamole_url=GUACAMOLE_URL,
    )
    while (time.monotonic() - start_time) < timeout_seconds:
        attempt += 1
        try:
            response = await asyncio.to_thread(requests.get, probe_url, timeout=2)
        except requests.exceptions.RequestException as e:
            logger.debug(
                "Guacamole not ready yet",
                attempt=attempt,
                elapsed_seconds=round(time.monotonic() - start_time, 2),
                error=str(e)[:100],
            )
        else:
            # Any of these means the web app answered; auth errors and 404
            # still count as "up" for readiness purposes.
            if response.status_code in (200, 401, 403, 404):
                logger.info(
                    "Guacamole is available",
                    attempt=attempt,
                    elapsed_seconds=round(time.monotonic() - start_time, 2),
                    status_code=response.status_code,
                )
                return True
        # Never sleep past the deadline.
        remaining = timeout_seconds - (time.monotonic() - start_time)
        if remaining <= 0:
            break
        await asyncio.sleep(min(check_interval, remaining))
    logger.warning(
        "Guacamole did not become available within timeout",
        timeout_seconds=timeout_seconds,
        total_attempts=attempt,
    )
    return False
async def cleanup_orphaned_guacamole_connections():
    """Delete every Guacamole connection when Redis tracks none.

    Intended for after FLUSHDB, or the first start after a crash, when
    Guacamole may still hold connections that have no Redis record. If
    Redis has any tracked connection, nothing is done. Per-connection
    delete errors are logged and skipped, and any other failure is logged
    and yields 0.

    Returns:
        Number of deleted connections.
    """
    try:
        tracked_connections = redis_connection_storage.get_all_connections()
        if len(tracked_connections) > 0:
            logger.info("Redis has active connections, skipping orphaned cleanup",
                        redis_connections_count=len(tracked_connections))
            return 0
        logger.warning("Redis is empty, checking for orphaned Guacamole connections",
                       message="This usually happens after FLUSHDB or service restart")
        guacamole_connections = guacamole_client.get_all_connections_with_system_token()
        if not guacamole_connections:
            logger.info("No Guacamole connections found, nothing to clean up")
            return 0
        logger.warning("Found orphaned Guacamole connections",
                       guacamole_connections_count=len(guacamole_connections),
                       message="Deleting all orphaned connections")
        removed = 0
        for orphan in guacamole_connections:
            orphan_id = orphan.get('identifier')
            if not orphan_id:
                continue
            try:
                if not guacamole_client.delete_connection_with_system_token(orphan_id):
                    continue
                removed += 1
                logger.debug("Deleted orphaned connection",
                             connection_id=orphan_id,
                             connection_name=orphan.get('name', 'unknown'))
            except Exception as e:
                logger.error("Failed to delete orphaned connection",
                             connection_id=orphan_id,
                             error=str(e))
        if removed > 0:
            logger.info("Orphaned connections cleanup completed",
                        deleted_count=removed,
                        total_found=len(guacamole_connections))
        return removed
    except Exception as e:
        logger.error("Error during orphaned connections cleanup", error=str(e))
        return 0
async def cleanup_expired_connections_once(log_action: str = "expired"):
    """Execute one iteration of expired connections cleanup.

    Each Redis-tracked connection whose ``expires_at`` is not in the future
    is handled in three steps. It is deleted from Guacamole with the system
    token; on success it is logged and its owner notified over WebSocket.
    Finally its Redis record is removed. Naive ``expires_at`` values are
    treated as UTC. Records with a missing or unparseable ``expires_at``
    are skipped with a warning rather than aborting the whole pass.

    Args:
        log_action: Action for logging (expired, startup_cleanup, etc.)

    Returns:
        Number of connections successfully deleted from Guacamole; 0 if the
        pass failed.
    """
    try:
        current_time = datetime.now(timezone.utc)
        expired_connections = []
        all_connections = redis_connection_storage.get_all_connections()
        for conn_id, conn_data in all_connections.items():
            try:
                expires_at = datetime.fromisoformat(conn_data['expires_at'])
            except (KeyError, TypeError, ValueError) as parse_error:
                # One malformed record must not block cleanup of all others.
                logger.warning("Skipping connection with invalid expires_at",
                               connection_id=conn_id,
                               error=str(parse_error))
                continue
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at <= current_time:
                expired_connections.append(conn_id)

        deleted_count = 0
        for conn_id in expired_connections:
            # Re-read: the record may already be gone from Redis.
            conn_data = redis_connection_storage.get_connection(conn_id)
            if not conn_data:
                continue
            if guacamole_client.delete_connection_with_system_token(conn_id):
                deleted_count += 1
                log_connection_lifecycle(
                    connection_id=conn_id,
                    action=log_action,
                    client_ip="system",
                    hostname=conn_data.get('hostname', 'unknown'),
                    protocol=conn_data.get('protocol', 'unknown'),
                    details={
                        "ttl_minutes": conn_data.get('ttl_minutes'),
                        "created_at": conn_data.get('created_at'),
                        "expires_at": conn_data.get('expires_at')
                    }
                )
                owner_username = conn_data.get('owner_username')
                if owner_username:
                    try:
                        await websocket_manager.send_connection_expired(
                            username=owner_username,
                            connection_id=conn_id,
                            hostname=conn_data.get('hostname', 'unknown'),
                            protocol=conn_data.get('protocol', 'unknown')
                        )
                    except Exception as ws_error:
                        logger.warning("Failed to send WebSocket notification",
                                       connection_id=conn_id,
                                       error=str(ws_error))
            # NOTE(review): the Redis record is removed even when the
            # Guacamole delete reports failure, so an entry whose Guacamole
            # connection is already gone is not retried forever. Confirm this
            # is intended; the alternative is to keep the record and retry it
            # on the next pass.
            redis_connection_storage.delete_connection(conn_id)
        if deleted_count > 0:
            logger.info("Cleanup completed",
                        action=log_action,
                        expired_count=len(expired_connections),
                        deleted_count=deleted_count)
        return deleted_count
    except Exception as e:
        logger.error("Error during cleanup", error=str(e))
        return 0
async def cleanup_expired_connections():
    """Background loop: run one expired-connection cleanup pass every 60 seconds, forever.

    Errors from a pass are logged and the loop keeps running.
    """
    pause_seconds = 60
    while True:
        try:
            await cleanup_expired_connections_once("expired")
        except Exception as e:
            logger.error("Error during cleanup task", error=str(e))
        await asyncio.sleep(pause_seconds)
async def check_expiring_connections():
    """Background task: warn owners about connections expiring within 5 minutes.

    Every 30 seconds all connection records are read from Redis. For each
    record whose ``expires_at`` falls in ``(now, now + 5 min]`` and that has an
    ``owner_username``, a ``connection_will_expire`` WebSocket notification is
    sent with the whole minutes remaining (at least 1).

    A warning is sent at most once per (connection id, ``expires_at`` value).
    The ``expires_at`` string is remembered next to the id, so a connection
    whose TTL is extended after being warned gets a fresh warning when its
    new deadline approaches. A set keyed only by id would never warn it again.
    Entries for expired or vanished connections are forgotten.

    A record with a missing or unparsable ``expires_at`` is logged and
    skipped, so one bad record does not abort the scan for everyone else.
    """
    # conn_id -> the raw ``expires_at`` value a warning was already sent for.
    warned_connections: Dict[str, str] = {}
    while True:
        try:
            current_time = datetime.now(timezone.utc)
            warning_threshold = current_time + timedelta(minutes=5)
            all_connections = redis_connection_storage.get_all_connections()
            for conn_id, conn_data in all_connections.items():
                raw_expires_at = conn_data.get('expires_at')
                try:
                    expires_at = datetime.fromisoformat(raw_expires_at)
                except (TypeError, ValueError) as parse_error:
                    logger.warning("Skipping connection with invalid expires_at",
                                   connection_id=conn_id,
                                   error=str(parse_error))
                    continue
                # Records written without an offset are treated as UTC.
                if expires_at.tzinfo is None:
                    expires_at = expires_at.replace(tzinfo=timezone.utc)
                if expires_at <= current_time:
                    # Already expired: the cleanup task owns it from here.
                    warned_connections.pop(conn_id, None)
                    continue
                if expires_at > warning_threshold:
                    continue
                if warned_connections.get(conn_id) == raw_expires_at:
                    # Already warned for this exact deadline.
                    continue
                owner_username = conn_data.get('owner_username')
                if not owner_username:
                    continue
                minutes_remaining = max(1, int((expires_at - current_time).total_seconds() / 60))
                try:
                    await websocket_manager.send_connection_will_expire(
                        username=owner_username,
                        connection_id=conn_id,
                        hostname=conn_data.get('hostname', 'unknown'),
                        protocol=conn_data.get('protocol', 'unknown'),
                        minutes_remaining=minutes_remaining
                    )
                    warned_connections[conn_id] = raw_expires_at
                    logger.info("Connection expiration warning sent",
                                connection_id=conn_id,
                                username=owner_username,
                                minutes_remaining=minutes_remaining)
                except Exception as ws_error:
                    logger.warning("Failed to send expiration warning",
                                   connection_id=conn_id,
                                   error=str(ws_error))
            # Forget connections that are no longer stored at all.
            for stale_id in set(warned_connections) - set(all_connections):
                del warned_connections[stale_id]
        except Exception as e:
            logger.error("Error during expiring connections check", error=str(e))
        await asyncio.sleep(30)
async def cleanup_ssrf_cache():
    """Periodically evict stale entries from the SSRF-protection cache.

    Loops forever: calls ``ssrf_protection.cleanup_expired_cache()`` and then
    sleeps 180 seconds. A failing cleanup is logged and does not stop the loop.
    """
    pause_seconds = 180
    while True:
        try:
            ssrf_protection.cleanup_expired_cache()
        except Exception as cleanup_error:
            logger.error("Error during SSRF cache cleanup", error=str(cleanup_error))
        await asyncio.sleep(pause_seconds)
async def cleanup_csrf_tokens():
    """Periodically drop expired CSRF tokens.

    Loops forever: calls ``csrf_protection.cleanup_expired_tokens()`` and then
    sleeps 600 seconds. A failing cleanup is logged and does not stop the loop.
    """
    pause_seconds = 600
    while True:
        try:
            csrf_protection.cleanup_expired_tokens()
        except Exception as cleanup_error:
            logger.error("Error during CSRF token cleanup", error=str(cleanup_error))
        await asyncio.sleep(pause_seconds)
def schedule_connection_deletion(
    connection_id: str,
    ttl_minutes: int,
    auth_token: str,
    guacamole_username: str,
    hostname: str = "unknown",
    protocol: str = "unknown",
    owner_username: str = "unknown",
    owner_role: str = "unknown"
):
    """Store a connection record in Redis together with its expiry deadline.

    The record is saved without a Redis-level TTL (``ttl_seconds=None``).
    Removal is driven by ``expires_at``: the cleanup sweep
    (``cleanup_expired_connections_once``) deletes records whose
    ``expires_at`` is in the past.

    Args:
        connection_id: Guacamole connection identifier (also the Redis key).
        ttl_minutes: Lifetime of the connection in minutes.
        auth_token: Guacamole auth token stored with the record.
        guacamole_username: Guacamole user the connection was created for.
        hostname: Target host, for logging/notifications.
        protocol: Connection protocol, for logging/notifications.
        owner_username: API user that owns the connection.
        owner_role: Role of the owning user.
    """
    # Take one timestamp so expires_at - created_at is exactly ttl_minutes.
    created_at = datetime.now(timezone.utc)
    expires_at = created_at + timedelta(minutes=ttl_minutes)
    connection_data = {
        'connection_id': connection_id,
        'created_at': created_at.isoformat(),
        'expires_at': expires_at.isoformat(),
        'ttl_minutes': ttl_minutes,
        'auth_token': auth_token,
        'guacamole_username': guacamole_username,
        'hostname': hostname,
        'protocol': protocol,
        'owner_username': owner_username,
        'owner_role': owner_role
    }
    redis_connection_storage.add_connection(
        connection_id,
        connection_data,
        ttl_seconds=None
    )
@app.on_event("startup")
async def startup_event():
    """Initialise the application when FastAPI starts.

    1. Print/log the effective configuration and write an
       ``application_started`` audit event.
    2. Wait up to 30 s for Guacamole. If it is reachable, delete orphaned
       Guacamole connections (no Redis record) and connections that expired
       while the API was down. Otherwise skip this step and leave it to the
       periodic background sweep.
    3. Start the long-running background tasks: expired-connection sweep,
       expiry warnings, SSRF cache cleanup and CSRF token cleanup.

    The event loop keeps only weak references to tasks. A task created with
    ``asyncio.create_task`` and not referenced anywhere else may be garbage
    collected mid-run, silently killing its loop. The task objects are
    therefore kept on ``app.state.background_tasks``.
    """
    startup_info = {
        "guacamole_url": GUACAMOLE_URL,
        "guacamole_public_url": GUACAMOLE_PUBLIC_URL,
        "rate_limiting_enabled": RATE_LIMIT_ENABLED,
        "rate_limit_config": f"{RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds" if RATE_LIMIT_ENABLED else None,
        "log_level": LOG_LEVEL,
        "log_format": LOG_FORMAT,
        "python_version": sys.version.split()[0],
        "platform": platform.system()
    }
    print("Starting Remote Access API...")
    print(f"Guacamole URL (internal): {GUACAMOLE_URL}")
    print(f"Guacamole Public URL (client): {GUACAMOLE_PUBLIC_URL}")
    print(f"Rate Limiting: {'Enabled' if RATE_LIMIT_ENABLED else 'Disabled'}")
    if RATE_LIMIT_ENABLED:
        print(f"Rate Limit: {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds")
    print(f"Log Level: {LOG_LEVEL}, Format: {LOG_FORMAT}")
    logger.info("Application startup", **startup_info)
    log_audit_event(
        action="application_started",
        resource="system",
        client_ip="system",
        details=startup_info
    )
    # Cleanup leftovers from previous runs once Guacamole is reachable.
    guacamole_ready = await wait_for_guacamole(timeout_seconds=30, check_interval=1.0)
    if not guacamole_ready:
        logger.warning("Guacamole not available, skipping startup cleanup",
                       message="Cleanup will be performed by background task when Guacamole becomes available")
    else:
        logger.info("Checking for orphaned Guacamole connections...")
        orphaned_count = await cleanup_orphaned_guacamole_connections()
        if orphaned_count > 0:
            logger.warning(
                "Orphaned cleanup completed",
                deleted_connections=orphaned_count,
                message="Removed orphaned Guacamole connections (no Redis records)"
            )
        logger.info("Starting cleanup of expired connections from previous runs...")
        deleted_count = await cleanup_expired_connections_once("startup_cleanup")
        if deleted_count > 0:
            logger.info(
                "Startup cleanup completed",
                deleted_connections=deleted_count,
                message="Removed expired connections from previous application runs"
            )
        else:
            logger.info(
                "Startup cleanup completed",
                deleted_connections=0,
                message="No expired connections found"
            )
    background_tasks = [
        asyncio.create_task(cleanup_expired_connections(), name="cleanup_expired_connections"),
        asyncio.create_task(check_expiring_connections(), name="check_expiring_connections"),
        asyncio.create_task(cleanup_ssrf_cache(), name="cleanup_ssrf_cache"),
        asyncio.create_task(cleanup_csrf_tokens(), name="cleanup_csrf_tokens"),
    ]
    # Strong references so the tasks cannot be garbage collected while running.
    app.state.background_tasks = background_tasks
    logger.info(
        "Background tasks started",
        ttl_cleanup=True,
        expiring_connections_check=True,
        # NOTE(review): no rate-limit cleanup task is started here; this flag
        # only mirrors RATE_LIMIT_ENABLED.
        rate_limit_cleanup=RATE_LIMIT_ENABLED,
        ssrf_cache_cleanup=True,
        csrf_token_cleanup=True
    )
@app.get("/", tags=["System"])
async def root():
    """Liveness endpoint: confirms the API process is up and answering."""
    status_text = "Remote Access API is running"
    return {"message": status_text}
@app.get("/api/health", tags=["System"])
async def health_check():
    """Health check with component status.

    Components (each a dict with at least a ``status`` key):
      * ``guacamole_web``: HTTP GET ``{GUACAMOLE_URL}/guacamole`` (5 s timeout);
        ``ok`` only on HTTP 200.
      * ``database``: ``SystemService.check_database_connection``.
      * ``guacd_daemon``: ``SystemService.check_guacd_daemon``.
      * ``system_resources``: ``SystemService.check_system_resources``.

    ``overall_status`` is the most severe component status, ranked
    ``ok`` < ``warning`` < ``critical`` < ``error``. Unrecognised statuses
    count as ``ok``.

    The checks are plain synchronous calls, so each runs in a worker thread
    (``asyncio.to_thread``) to avoid blocking the event loop while it waits.
    ``timestamp`` is timezone-aware UTC, matching the other timestamps the API
    emits. The duration uses a monotonic clock.
    """
    start_time = time.perf_counter()
    try:
        response = await asyncio.to_thread(
            requests.get,
            f"{GUACAMOLE_URL}/guacamole",
            timeout=5
        )
        guacamole_web = {
            "status": "ok" if response.status_code == 200 else "error",
            "response_time_ms": round(response.elapsed.total_seconds() * 1000, 2),
            "status_code": response.status_code
        }
    except Exception as e:
        guacamole_web = {"status": "error", "error": str(e)}
    database = await asyncio.to_thread(
        SystemService.check_database_connection, guacamole_client, GUACAMOLE_URL
    )
    guacd = await asyncio.to_thread(SystemService.check_guacd_daemon)
    system = await asyncio.to_thread(SystemService.check_system_resources)
    system_info = SystemService(service_start_time).get_system_info()

    # Worst status wins; unknown statuses rank as "ok".
    severity_rank = {"ok": 0, "warning": 1, "critical": 2, "error": 3}
    overall_status = "ok"
    for component in (guacamole_web, database, guacd, system):
        component_status = component.get("status")
        if severity_rank.get(component_status, 0) > severity_rank[overall_status]:
            overall_status = component_status

    check_duration = round((time.perf_counter() - start_time) * 1000, 2)
    return {
        "overall_status": overall_status,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "check_duration_ms": check_duration,
        "system_info": system_info,
        "components": {
            "guacamole_web": guacamole_web,
            "database": database,
            "guacd_daemon": guacd,
            "system_resources": system
        },
        "statistics": {
            "active_connections": len(redis_connection_storage.get_all_connections()),
            "rate_limiting": {
                "enabled": RATE_LIMIT_ENABLED,
                "limit": RATE_LIMIT_REQUESTS,
                "window_seconds": RATE_LIMIT_WINDOW,
                "active_clients": redis_rate_limiter.get_stats().get("active_rate_limits", 0) if RATE_LIMIT_ENABLED else 0
            }
        }
    }
@app.websocket("/ws/notifications")
async def websocket_notifications(websocket: WebSocket):
    """
    WebSocket endpoint for real-time notifications.

    Events:
    - connection_expired: Connection expired
    - connection_deleted: Connection deleted
    - connection_will_expire: Connection will expire soon (5 min warning)
    - jwt_will_expire: JWT will expire soon (5 min warning)
    - jwt_expired: JWT expired
    - connection_extended: Connection extended

    Connection protocol:
    1. Client sends ``{"type": "auth", "token": <JWT>}`` within 5 seconds of
       the socket being accepted (otherwise close code 4408).
    2. Server validates the token with ``verify_jwt_token``. A missing or
       invalid token or a payload without ``username`` gets close code 4001.
    3. Server registers the socket with ``websocket_manager`` under the
       username and sends a ``connected`` confirmation.
    4. Keep-alive: a client ``ping`` is answered with ``pong``. If the client
       is silent for 30 seconds the server sends its own ``ping``.

    Any unexpected error closes the socket with code 1011. The ``finally``
    block unregisters the socket if authentication got far enough to set
    ``username``.
    """
    # Stays None until a valid username is extracted; gates the cleanup in finally.
    username = None
    try:
        await websocket.accept()
        logger.info("WebSocket connection accepted, waiting for auth")
        # A TimeoutError here propagates to the outer handler (code 4408).
        auth_message = await asyncio.wait_for(
            websocket.receive_json(),
            timeout=5.0
        )
        if auth_message.get("type") != "auth" or not auth_message.get("token"):
            await websocket.close(code=4001, reason="Authentication required")
            logger.warning("WebSocket connection rejected: no auth")
            return
        token = auth_message["token"]
        # NOTE(review): only verify_jwt_token is consulted here; confirm it also
        # rejects tokens revoked via token_blacklist (logout / revoke endpoints).
        payload = verify_jwt_token(token)
        if not payload:
            await websocket.close(code=4001, reason="Invalid token")
            logger.warning("WebSocket connection rejected: invalid token")
            return
        username = payload.get("username")
        if not username:
            await websocket.close(code=4001, reason="Invalid token payload")
            return
        await websocket_manager.connect(websocket, username)
        await websocket.send_json({
            "type": "connected",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "data": {
                "username": username,
                "message": "Successfully connected to notifications stream"
            }
        })
        logger.info("WebSocket client authenticated and connected",
                    username=username)
        # Keep-alive loop. Notifications themselves are pushed elsewhere via
        # websocket_manager; this loop only handles ping/pong and detects
        # disconnects.
        while True:
            try:
                message = await asyncio.wait_for(
                    websocket.receive_json(),
                    timeout=30.0
                )
                if message.get("type") == "ping":
                    await websocket.send_json({
                        "type": "pong",
                        "timestamp": datetime.now(timezone.utc).isoformat()
                    })
            except asyncio.TimeoutError:
                # Client silent for 30 s: send a server-side ping; a failed send
                # means the socket is gone, so leave the loop.
                try:
                    await websocket.send_json({
                        "type": "ping",
                        "timestamp": datetime.now(timezone.utc).isoformat()
                    })
                except (WebSocketDisconnect, ConnectionError, RuntimeError):
                    break
            except WebSocketDisconnect:
                break
            except Exception as e:
                # E.g. malformed JSON from the client: log and drop the connection.
                logger.error("Error in WebSocket loop",
                             username=username,
                             error=str(e))
                break
    except asyncio.TimeoutError:
        # Reached when the auth message did not arrive within 5 s.
        try:
            await websocket.close(code=4408, reason="Authentication timeout")
        except (WebSocketDisconnect, ConnectionError, RuntimeError):
            pass  # close() can fail if the client has already gone away
        logger.warning("WebSocket connection timeout during auth")
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected",
                    username=username)
    except Exception as e:
        logger.error("WebSocket error",
                     username=username,
                     error=str(e))
        try:
            await websocket.close(code=1011, reason="Internal error")
        except (WebSocketDisconnect, ConnectionError, RuntimeError):
            pass
    finally:
        # Only sockets that passed authentication were registered with the manager.
        if username:
            await websocket_manager.disconnect(websocket, username)
@app.post(
    "/api/auth/login",
    tags=["Authentication"],
    response_model=LoginResponse,
    summary="Authenticate user",
    description="Login with username and password to receive JWT token"
)
async def login(login_request: LoginRequest, request: Request):
    """Authenticate a user against Guacamole and issue a JWT.

    Flow:
      1. Ask ``brute_force_protection`` whether this IP/username may attempt
         a login. If not, write a ``login_blocked`` security event and reply 429.
      2. Verify the credentials with ``guacamole_authenticator``. On failure,
         record the failed attempt, write a ``login_failed`` security event,
         bump the ``authentication_failures`` metric and reply 401.
      3. On success, record the successful attempt, mint a JWT with
         ``create_jwt_for_user``, write a ``user_login`` audit event and
         return the token, its lifetime in seconds and the public profile
         fields.

    Raises:
        HTTPException: 429 when blocked, 401 on bad credentials, 500 on any
            unexpected error (which also bumps ``authentication_failures``).
    """
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    try:
        # Check brute-force protection before touching the authenticator.
        allowed, reason, protection_details = brute_force_protection.check_login_allowed(
            client_ip, login_request.username
        )
        if not allowed:
            # Log security event
            immutable_audit_logger.log_security_event(
                event_type="login_blocked",
                client_ip=client_ip,
                user_agent=user_agent,
                details={
                    "username": login_request.username,
                    "reason": reason,
                    "protection_details": protection_details
                },
                severity="high",
                username=login_request.username
            )
            raise HTTPException(
                status_code=429,
                detail=f"Login blocked: {reason}"
            )
        user_info = guacamole_authenticator.authenticate_user(
            login_request.username,
            login_request.password
        )
        if not user_info:
            # Count the failure so repeated bad passwords trigger the block above.
            brute_force_protection.record_failed_login(
                client_ip, login_request.username, "invalid_credentials"
            )
            immutable_audit_logger.log_security_event(
                event_type="login_failed",
                client_ip=client_ip,
                user_agent=user_agent,
                details={"username": login_request.username},
                severity="medium",
                username=login_request.username
            )
            record_error_metric("authentication_failures")
            raise HTTPException(
                status_code=401,
                detail="Invalid username or password"
            )
        brute_force_protection.record_successful_login(client_ip, login_request.username)
        jwt_token = guacamole_authenticator.create_jwt_for_user(user_info)
        immutable_audit_logger.log_audit_event(
            action="user_login",
            resource=f"user/{login_request.username}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="success",
            details={
                "role": user_info["role"],
                "permissions_count": len(user_info.get("permissions", []))
            },
            username=login_request.username
        )
        # NOTE(review): expires_in is derived from JWT_ACCESS_TOKEN_EXPIRE_MINUTES
        # here; confirm create_jwt_for_user uses the same setting, otherwise the
        # client is told the wrong token lifetime.
        expires_in = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "60")) * 60
        return LoginResponse(
            access_token=jwt_token,
            token_type="bearer",
            expires_in=expires_in,
            user_info={
                "username": user_info["username"],
                "role": user_info["role"],
                "permissions": user_info.get("permissions", []),
                "full_name": user_info.get("full_name"),
                "email": user_info.get("email"),
                "organization": user_info.get("organization"),
                "organizational_role": user_info.get("organizational_role")
            }
        )
    except HTTPException:
        # Preserve the deliberate 429/401 responses raised above.
        raise
    except Exception as e:
        logger.error("Unexpected error during login",
                     username=login_request.username,
                     client_ip=client_ip,
                     error=str(e))
        record_error_metric("authentication_failures")
        raise HTTPException(
            status_code=500,
            detail="Internal server error during authentication"
        )
@app.get(
    "/api/auth/profile",
    tags=["Authentication"],
    summary="Get user profile",
    description="Retrieve current user profile information"
)
async def get_user_profile(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the public profile of the authenticated caller.

    ``username`` and ``role`` are required in the stored user info.
    ``permissions`` defaults to an empty list. The descriptive fields are
    ``None`` when absent.
    """
    current = get_current_user(request)
    profile = {
        "username": current["username"],
        "role": current["role"],
        "permissions": current.get("permissions", []),
    }
    for optional_field in ("full_name", "email", "organization", "organizational_role"):
        profile[optional_field] = current.get(optional_field)
    return profile
@app.get(
    "/api/auth/permissions",
    tags=["Authentication"],
    summary="Get user permissions",
    description="List all permissions for current user role"
)
async def get_user_permissions(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return role-derived permissions alongside the caller's system permissions.

    ``permissions`` comes from ``PermissionChecker`` for the caller's
    ``UserRole``. ``system_permissions`` is the raw list stored in the user
    info, or an empty list.
    """
    caller = get_current_user(request)
    role_permissions = PermissionChecker.get_user_permissions_list(UserRole(caller["role"]))
    return dict(
        username=caller["username"],
        role=caller["role"],
        permissions=role_permissions,
        system_permissions=caller.get("permissions", []),
    )
@app.post(
    "/api/auth/logout",
    tags=["Authentication"],
    summary="Logout user",
    description="Revoke current JWT token and end session"
)
async def logout(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Revoke the caller's JWT and record a ``user_logout`` audit event.

    The token is taken from ``credentials`` (resolved by the ``security``
    dependency), falling back to the ``Authorization`` header. In the
    fallback the scheme is compared case-insensitively, because HTTP
    authentication schemes are case-insensitive: a header like
    ``bearer <token>`` is still revoked rather than silently skipped. If no
    token is found or revocation fails, a warning is logged.

    Returns:
        A confirmation message dict.
    """
    user_info = get_current_user(request)
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    token = getattr(credentials, "credentials", None)
    if not token:
        scheme, _, param = (request.headers.get("Authorization") or "").partition(" ")
        if scheme.lower() == "bearer":
            token = param.strip() or None

    revoked = False
    if token:
        revoked = bool(token_blacklist.revoke_token(token, "logout", user_info["username"]))
    if not revoked:
        logger.warning("Token revocation failed during logout",
                       username=user_info["username"],
                       client_ip=client_ip,
                       token_present=bool(token))

    immutable_audit_logger.log_audit_event(
        action="user_logout",
        resource=f"user/{user_info['username']}",
        client_ip=client_ip,
        user_agent=user_agent,
        result="success",
        username=user_info["username"]
    )
    return {
        "message": "Successfully logged out",
        "note": "JWT token has been revoked and added to blacklist"
    }
@app.get(
    "/api/auth/limits",
    tags=["Authentication"],
    summary="Get user limits",
    description="Retrieve role-based limits and allowed networks"
)
async def get_user_limits(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the caller's role limits and the globally blocked hosts/networks.

    Raises:
        HTTPException: 401 if no user is attached to the request.
    """
    current_user = get_current_user(request)
    if not current_user:
        raise HTTPException(status_code=401, detail="Not authenticated")
    role = UserRole(current_user["role"])
    limits_for_role = SecurityConfig.get_role_limits(role)
    security_info = {
        "blocked_hosts": list(SecurityConfig.BLOCKED_HOSTS),
        "blocked_networks": SecurityConfig.BLOCKED_NETWORKS,
    }
    return {
        "username": current_user["username"],
        "role": role.value,
        "limits": limits_for_role,
        "security_info": security_info,
    }
@app.post(
    "/api/auth/revoke",
    tags=["Authentication"],
    summary="Revoke token",
    description="Revoke JWT token and invalidate session"
)
async def revoke_token(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Revoke the caller's current JWT.

    The token is taken from ``credentials`` (resolved by the ``security``
    dependency), falling back to the ``Authorization`` header. In the
    fallback the scheme is compared case-insensitively, matching how HTTP
    authentication schemes are defined.

    Raises:
        HTTPException: 400 if no bearer token can be found; 500 if the
            blacklist reports that revocation failed.
    """
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    user_info = get_current_user(request)

    token = getattr(credentials, "credentials", None)
    if not token:
        scheme, _, param = (request.headers.get("Authorization") or "").partition(" ")
        if scheme.lower() == "bearer":
            token = param.strip() or None
    if not token:
        raise HTTPException(status_code=400, detail="No valid token provided")

    success = token_blacklist.revoke_token(token, "logout", user_info["username"])
    if not success:
        raise HTTPException(status_code=500, detail="Failed to revoke token")
    immutable_audit_logger.log_audit_event(
        action="token_revoked",
        # Only a short prefix of the token is written to the audit log.
        resource=f"token/{token[:20]}...",
        client_ip=client_ip,
        user_agent=user_agent,
        result="success",
        username=user_info["username"]
    )
    return {"message": "Token revoked successfully"}
class MachineAvailabilityRequest(BaseModel):
    """Request body for ``POST /api/machines/check-availability``.

    ``hostname`` is the host to probe. ``port`` is optional: when it is
    omitted (or falsy) the endpoint probes port 3389.
    """
    hostname: str
    port: Optional[int] = None
class MachineAvailabilityResponse(BaseModel):
    """Result of a machine availability (TCP connect) probe."""
    # True when a TCP connection to hostname:port succeeded.
    available: bool
    # Host that was probed, echoed from the request.
    hostname: str
    # Port actually probed, after defaulting.
    port: int
    # Duration of the connect attempt in milliseconds (rounded), or None.
    response_time_ms: Optional[float] = None
    # ISO-8601 timestamp of when the check was performed.
    checked_at: str
@app.post(
    "/api/machines/check-availability",
    tags=["Machines"],
    response_model=MachineAvailabilityResponse,
    summary="Check machine availability",
    description="Test if machine is reachable via TCP connection"
)
async def check_machine_availability(
    availability_request: MachineAvailabilityRequest,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Check machine availability (quick TCP connect with a 2 s timeout).

    Args:
        availability_request: ``hostname`` to probe and optional ``port``
            (default 3389 for RDP).

    Returns:
        MachineAvailabilityResponse with the result, the connect time in ms
        and a UTC timestamp.

    Raises:
        HTTPException: 400 if the port is outside 1-65535. 403 if the
            caller's role may not access the host. This is the same
            ``SecurityConfig.is_host_allowed`` policy applied when creating
            connections, so the endpoint cannot be used to probe hosts the
            caller may not reach.

    The blocking connect (DNS lookup + TCP handshake) runs in a worker
    thread so a slow or unreachable host does not stall the event loop.
    """
    client_ip = request.client.host if request.client else "unknown"
    user_info = get_current_user(request)
    username = user_info.get("username", "unknown")
    hostname = availability_request.hostname
    port = availability_request.port if availability_request.port else 3389
    logger.debug("Machine availability check requested",
                 hostname=hostname,
                 port=port,
                 username=username,
                 client_ip=client_ip)

    # socket.create_connection raises OverflowError (not OSError) for such ports.
    if not 1 <= port <= 65535:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid port {port}: must be between 1 and 65535"
        )

    host_allowed, host_reason = SecurityConfig.is_host_allowed(hostname, UserRole(user_info["role"]))
    if not host_allowed:
        logger.warning("Availability check denied for host",
                       hostname=hostname,
                       port=port,
                       username=username,
                       reason=host_reason,
                       client_ip=client_ip)
        raise HTTPException(
            status_code=403,
            detail=f"Access to host denied: {host_reason}"
        )

    start_time = time.perf_counter()
    try:
        probe_socket = await asyncio.to_thread(socket.create_connection, (hostname, port), 2)
        probe_socket.close()
        response_time_ms = (time.perf_counter() - start_time) * 1000
        available = True
        logger.info("Machine is available",
                    hostname=hostname,
                    port=port,
                    response_time_ms=round(response_time_ms, 2),
                    username=username)
    except (OSError, ValueError) as e:
        # OSError covers timeouts, refusals and DNS failures; ValueError covers
        # hostnames the IDNA codec rejects.
        response_time_ms = (time.perf_counter() - start_time) * 1000
        available = False
        logger.info("Machine is not available",
                    hostname=hostname,
                    port=port,
                    response_time_ms=round(response_time_ms, 2),
                    error=str(e),
                    username=username)
    return MachineAvailabilityResponse(
        available=available,
        hostname=hostname,
        port=port,
        response_time_ms=round(response_time_ms, 2),
        checked_at=datetime.now(timezone.utc).isoformat()
    )
@app.post(
    "/api/connections",
    tags=["Connections"],
    response_model=ConnectionResponse,
    summary="Create connection",
    description="Create new remote desktop connection (RDP/VNC/SSH)"
)
async def create_remote_connection(
    connection_request: ConnectionRequest,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a Guacamole connection for the authenticated user.

    Steps, in order:
      1. Resolve the caller and their Guacamole token from the auth middleware.
      2. Decrypt the client-supplied password, if any.
      3. Enforce role policy: permission to create connections, TTL validity
         and role maximum, and host allow-listing (``SecurityConfig``).
      4. Probe ``hostname:port`` over TCP (3 s timeout). When no port is
         given, the protocol default (rdp 3389, vnc 5900, ssh 22) is used.
      5. Create the connection with the caller's Guacamole token, store it in
         Redis with its expiry, emit metrics/lifecycle/audit logs and return
         the public fields.

    The TCP probe and the synchronous Guacamole client call run in worker
    threads (``asyncio.to_thread``), so a slow host or API call does not stall
    other requests on the event loop.

    Raises:
        HTTPException: 401 when authentication data is missing. 403 on
            role/TTL/host policy violations. 400 for an invalid TTL or an
            unreachable host, including hostnames or ports the socket layer
            rejects outright. 500 when creation fails.
    """
    start_time = time.time()
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    user_info = get_current_user(request)
    guacamole_token = get_current_user_token(request)
    if not user_info or not guacamole_token:
        logger.error("Missing user info or token from middleware",
                     has_user_info=bool(user_info),
                     has_token=bool(guacamole_token))
        raise HTTPException(status_code=401, detail="Authentication required")
    username = user_info["username"]
    user_role = UserRole(user_info["role"])
    logger.info("Creating connection for authenticated user",
                username=username,
                role=user_role.value,
                guac_token_length=len(guacamole_token))
    if connection_request.password:
        connection_request.password = decrypt_password_from_request(
            connection_request.password,
            request,
            context={
                "username": username,
                "hostname": connection_request.hostname,
                "protocol": connection_request.protocol
            }
        )
    role_limits = SecurityConfig.get_role_limits(user_role)
    if not role_limits["can_create_connections"]:
        raise HTTPException(
            status_code=403,
            detail=f"Role {user_role.value} cannot create connections"
        )
    ttl_valid, ttl_reason = SecurityConfig.validate_ttl(connection_request.ttl_minutes)
    if not ttl_valid:
        raise HTTPException(status_code=400, detail=f"Invalid TTL: {ttl_reason}")
    if connection_request.ttl_minutes > role_limits["max_ttl_minutes"]:
        raise HTTPException(
            status_code=403,
            detail=f"TTL {connection_request.ttl_minutes} exceeds role limit {role_limits['max_ttl_minutes']} minutes"
        )
    host_allowed, host_reason = SecurityConfig.is_host_allowed(connection_request.hostname, user_role)
    if not host_allowed:
        logger.warning("Host access denied",
                       username=username,
                       role=user_role.value,
                       hostname=connection_request.hostname,
                       reason=host_reason,
                       client_ip=client_ip)
        log_security_event(
            event_type="forbidden_host_access_attempt",
            client_ip=client_ip,
            user_agent=user_agent,
            details={
                "username": username,
                "role": user_role.value,
                "hostname": connection_request.hostname,
                "reason": host_reason
            },
            severity="high"
        )
        raise HTTPException(
            status_code=403,
            detail=f"Access to host denied: {host_reason}"
        )
    log_audit_event(
        action="connection_creation_started",
        resource=f"{connection_request.protocol}://{connection_request.hostname}",
        client_ip=client_ip,
        user_agent=user_agent,
        details={
            "protocol": connection_request.protocol,
            "hostname": connection_request.hostname,
            "port": connection_request.port,
            "ttl_minutes": connection_request.ttl_minutes,
            "username": username,
            "role": user_role.value
        }
    )
    if not connection_request.port:
        port_map = {"rdp": 3389, "vnc": 5900, "ssh": 22}
        check_port = port_map.get(connection_request.protocol, 3389)
    else:
        check_port = connection_request.port
    logger.debug("Starting host connectivity check",
                 target_host=connection_request.hostname,
                 port=check_port,
                 client_ip=client_ip)
    # Check host availability before creating the connection. The connect runs
    # in a worker thread so it does not block the event loop for up to 3 s.
    try:
        probe_socket = await asyncio.to_thread(
            socket.create_connection, (connection_request.hostname, check_port), 3
        )
        probe_socket.close()
    except (OSError, ValueError, OverflowError) as e:
        # OSError: timeout/refused/DNS failure. ValueError: hostname rejected
        # by the IDNA codec. OverflowError: port outside 0-65535.
        log_security_event(
            event_type="unreachable_host_attempt",
            client_ip=client_ip,
            user_agent=user_agent,
            details={
                "hostname": connection_request.hostname,
                "port": check_port,
                "protocol": connection_request.protocol,
                "error": str(e)
            },
            severity="low"
        )
        raise HTTPException(
            status_code=400,
            detail=f"Host {connection_request.hostname}:{check_port} is not accessible"
        )
    logger.debug("Host connectivity check passed",
                 target_host=connection_request.hostname,
                 port=check_port,
                 client_ip=client_ip)
    try:
        result = await asyncio.to_thread(
            guacamole_client.create_connection_with_user_token, connection_request, guacamole_token
        )
        connection_id = result.get("connection_id")
        expires_at = datetime.now(timezone.utc) + timedelta(minutes=connection_request.ttl_minutes)
        schedule_connection_deletion(
            connection_id=connection_id,
            ttl_minutes=connection_request.ttl_minutes,
            auth_token=result['auth_token'],
            guacamole_username=user_info["username"],
            hostname=connection_request.hostname,
            protocol=connection_request.protocol,
            owner_username=user_info["username"],
            owner_role=user_info["role"]
        )
        result['expires_at'] = expires_at.isoformat()
        result['ttl_minutes'] = connection_request.ttl_minutes
        # Only these fields go back to the client (no auth_token).
        public_result = {
            "connection_id": result["connection_id"],
            "connection_url": result["connection_url"],
            "status": result["status"],
            "expires_at": result["expires_at"],
            "ttl_minutes": result["ttl_minutes"]
        }
        duration_ms = (time.time() - start_time) * 1000
        record_connection_metric(
            protocol=connection_request.protocol,
            client_ip=client_ip,
            creation_time_ms=duration_ms,
            success=True
        )
        log_connection_lifecycle(
            connection_id=connection_id,
            action="created",
            client_ip=client_ip,
            hostname=connection_request.hostname,
            protocol=connection_request.protocol,
            details={
                "ttl_minutes": connection_request.ttl_minutes,
                "expires_at": expires_at.isoformat(),
                "creation_duration_ms": round(duration_ms, 2),
                "user_agent": user_agent,
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )
        log_audit_event(
            action="connection_created",
            resource=f"{connection_request.protocol}://{connection_request.hostname}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="success",
            details={
                "connection_id": connection_id,
                "duration_ms": round(duration_ms, 2),
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )
        return ConnectionResponse(**public_result)
    except Exception as e:
        duration_ms = (time.time() - start_time) * 1000
        record_connection_metric(
            protocol=connection_request.protocol,
            client_ip=client_ip,
            creation_time_ms=duration_ms,
            success=False
        )
        log_connection_lifecycle(
            connection_id="failed",
            action="failed",
            client_ip=client_ip,
            hostname=connection_request.hostname,
            protocol=connection_request.protocol,
            details={
                "error": str(e),
                "duration_ms": round(duration_ms, 2),
                "user_agent": user_agent,
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )
        log_audit_event(
            action="connection_creation_failed",
            resource=f"{connection_request.protocol}://{connection_request.hostname}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="failure",
            details={
                "error": str(e),
                "duration_ms": round(duration_ms, 2),
                "username": user_info["username"],
                "role": user_info["role"]
            }
        )
        raise HTTPException(status_code=500, detail=str(e))
@app.get(
    "/api/connections",
    tags=["Connections"],
    summary="List connections",
    description="Retrieve active connections based on user role"
)
async def list_connections(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    Returns active connections with connection_url for session restoration.

    Used to restore connections after user re-login. Users whose role cannot
    view all connections see only their own. Requests are rate-limited to
    60 per minute per (username, client IP).

    Note: URLs are generated with the current user token, not the old token
    stored in Redis. This allows restoring connections after logout/login
    without 403 errors.

    A connection is "active" while any time remains before ``expires_at``.
    ``remaining_minutes`` is the floored whole-minute count, so it can be 0
    for an active connection in its last minute. Records with a missing or
    unparsable ``expires_at`` are logged and skipped instead of failing the
    whole request.
    """
    user_info = get_current_user(request)
    user_role = UserRole(user_info["role"])
    username = user_info["username"]
    current_guac_token = get_current_user_token(request)
    if not current_guac_token:
        logger.error("No Guacamole token available for user",
                     username=username)
        raise HTTPException(
            status_code=401,
            detail="Authentication token not available"
        )
    client_ip = request.client.host if request.client else "unknown"
    rate_limit_key = f"get_connections:{username}:{client_ip}"
    allowed, rate_limit_headers = redis_rate_limiter.check_rate_limit(
        rate_limit_key,
        requests_limit=60,
        window_seconds=60
    )
    if not allowed:
        logger.warning("Rate limit exceeded for get connections",
                       username=username,
                       client_ip=client_ip,
                       rate_limit_headers=rate_limit_headers)
        raise HTTPException(
            status_code=429,
            detail="Too many requests. Please try again later."
        )
    # The permission depends only on the role, so evaluate it once.
    can_view_all = PermissionChecker.can_view_all_connections(user_role)
    current_time = datetime.now(timezone.utc)
    connections_with_ttl = []
    all_connections = redis_connection_storage.get_all_connections()
    for conn_id, conn_data in all_connections.items():
        if not can_view_all and conn_data.get('owner_username') != username:
            continue
        try:
            expires_at = datetime.fromisoformat(conn_data['expires_at'])
        except (KeyError, TypeError, ValueError) as parse_error:
            logger.warning("Skipping connection with invalid expires_at",
                           connection_id=conn_id,
                           error=str(parse_error))
            continue
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        remaining_seconds = (expires_at - current_time).total_seconds()
        # Decide activity on seconds: flooring to minutes first would report a
        # connection with <60 s left as "expired" and withhold its URL.
        is_active = remaining_seconds > 0
        remaining_minutes = max(0, int(remaining_seconds / 60))
        connection_url = None
        if is_active:
            try:
                connection_url = generate_connection_url(conn_id, current_guac_token)
                logger.debug("Connection URL generated with current user token",
                             connection_id=conn_id,
                             username=username,
                             remaining_minutes=remaining_minutes)
            except Exception as e:
                logger.error("Failed to generate connection URL",
                             connection_id=conn_id,
                             error=str(e))
        connections_with_ttl.append({
            "connection_id": conn_id,
            "hostname": conn_data.get('hostname', 'unknown'),
            "protocol": conn_data.get('protocol', 'unknown'),
            "owner_username": conn_data.get('owner_username', 'unknown'),
            "owner_role": conn_data.get('owner_role', 'unknown'),
            "created_at": conn_data.get('created_at'),
            "expires_at": conn_data['expires_at'],
            "ttl_minutes": conn_data.get('ttl_minutes'),
            "remaining_minutes": remaining_minutes,
            "status": "active" if is_active else "expired",
            "connection_url": connection_url
        })
    active_count = sum(1 for c in connections_with_ttl if c['status'] == 'active')
    logger.info(
        "User retrieved connections list with refreshed tokens",
        username=username,
        total_connections=len(connections_with_ttl),
        active_connections=active_count,
        using_current_token=True
    )
    return {
        "total_connections": len(connections_with_ttl),
        "active_connections": active_count,
        "connections": connections_with_ttl
    }
import asyncio
import inspect
from typing import Any
async def _maybe_call(func: Any, *args, **kwargs):
    """Call ``func`` without blocking the event loop and return its result.

    Coroutine functions are awaited directly. Any other callable runs in a
    worker thread via ``asyncio.to_thread``. If that call returns an
    awaitable, it is awaited here on the event loop. This covers objects
    with an ``async def __call__`` and sync wrappers that return a coroutine,
    which ``inspect.iscoroutinefunction`` does not recognise. Without this,
    the caller would receive an un-awaited coroutine instead of a value.

    Args:
        func: Sync or async callable.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The (awaited) result of ``func(*args, **kwargs)``.
    """
    if inspect.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    result = await asyncio.to_thread(func, *args, **kwargs)
    if inspect.isawaitable(result):
        return await result
    return result
@app.delete("/api/connections/{connection_id}", tags=["Connections"])
async def delete_connection(
    connection_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Force delete connection before TTL expiration.

    Deletes the connection in Guacamole using the caller's Guacamole token.
    Only if that succeeds is the Redis record removed; the deletion is then
    logged as a lifecycle and audit event.

    Raises:
        HTTPException: 401 if the caller has no Guacamole token. 404 if the
            connection is not in Redis. 403 if the ownership check fails.
            500 if Guacamole deletion fails or any other error occurs (the
            exception text is included in ``detail``).
    """
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    # Helpers go through _maybe_call so synchronous implementations run in a
    # worker thread instead of on the event loop.
    user_info = await _maybe_call(get_current_user, request)
    user_role = UserRole(user_info["role"])
    current_user_token = await _maybe_call(get_current_user_token, request)
    if not current_user_token:
        logger.error(
            "No Guacamole token available for user",
            username=user_info.get("username"),
            connection_id=connection_id,
        )
        raise HTTPException(status_code=401, detail="Authentication token not available")
    conn_data = await _maybe_call(redis_connection_storage.get_connection, connection_id)
    if not conn_data:
        log_security_event(
            event_type="delete_nonexistent_connection",
            client_ip=client_ip,
            user_agent=user_agent,
            details={"connection_id": connection_id, "username": user_info.get("username")},
            severity="low",
        )
        raise HTTPException(status_code=404, detail="Connection not found")
    # Permission check: the ownership rule for the caller's role decides
    # whether they may delete a connection owned by someone else.
    allowed, reason = PermissionChecker.check_connection_ownership(
        user_role, user_info.get("username"), conn_data.get("owner_username", "")
    )
    if not allowed:
        log_security_event(
            event_type="unauthorized_connection_deletion",
            client_ip=client_ip,
            user_agent=user_agent,
            details={
                "connection_id": connection_id,
                "username": user_info.get("username"),
                "owner": conn_data.get("owner_username", ""),
                "reason": reason,
            },
            severity="medium",
        )
        raise HTTPException(status_code=403, detail=reason)
    try:
        # Delete in Guacamole first; the Redis record is only removed after
        # Guacamole reports success, so a failed deletion leaves it in place.
        deletion_success = await _maybe_call(
            guacamole_client.delete_connection_with_user_token, connection_id, current_user_token
        )
        if not deletion_success:
            logger.warning(
                "Failed to delete connection from Guacamole",
                connection_id=connection_id,
                username=user_info.get("username"),
            )
            log_audit_event(
                action="connection_deletion_failed",
                resource=f"connection/{connection_id}",
                client_ip=client_ip,
                user_agent=user_agent,
                result="failure",
                details={"error": "Failed to delete from Guacamole"},
            )
            raise HTTPException(status_code=500, detail="Failed to delete connection from Guacamole")
        await _maybe_call(redis_connection_storage.delete_connection, connection_id)
        log_connection_lifecycle(
            connection_id=connection_id,
            action="deleted",
            client_ip=client_ip,
            hostname=conn_data.get("hostname", "unknown"),
            protocol=conn_data.get("protocol", "unknown"),
            details={
                "user_agent": user_agent,
                # NOTE(review): this is the configured ttl_minutes of the record,
                # not the time actually remaining; confirm the key name is intended.
                "remaining_ttl_minutes": conn_data.get("ttl_minutes", 0),
                "deleted_manually": True,
                "deleted_by": user_info.get("username"),
                "deleter_role": user_info.get("role"),
            },
        )
        log_audit_event(
            action="connection_deleted",
            resource=f"connection/{connection_id}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="success",
            details={
                "deleted_by": user_info.get("username"),
                "deleter_role": user_info.get("role"),
                "owner": conn_data.get("owner_username", ""),
            },
        )
        return {"status": "deleted", "connection_id": connection_id}
    except HTTPException:
        # Keep the deliberate 500 raised above instead of re-wrapping it.
        raise
    except Exception as e:
        logger.error(
            "Exception during connection deletion",
            connection_id=connection_id,
            username=user_info.get("username"),
            error=str(e),
            error_type=type(e).__name__,
        )
        log_audit_event(
            action="connection_deletion_error",
            resource=f"connection/{connection_id}",
            client_ip=client_ip,
            user_agent=user_agent,
            result="error",
            details={"error": str(e), "error_type": type(e).__name__},
        )
        raise HTTPException(
            status_code=500, detail=f"Internal error during connection deletion: {str(e)}"
        )
@app.post("/api/connections/{connection_id}/extend", tags=["Connections"])
async def extend_connection_ttl(
    connection_id: str,
    request: Request,
    additional_minutes: int = 30,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Push back the expiry of an active connection.

    The caller must pass ``PermissionChecker.check_connection_ownership`` for
    the stored owner. On success the stored record's ``expires_at`` and
    ``ttl_minutes`` are advanced and written back. A lifecycle event is
    logged, and the owner (if recorded) is notified over WebSocket on a
    best-effort basis.

    Args:
        connection_id: Identifier of the connection to extend.
        request: Incoming request; supplies the client IP and user info.
        additional_minutes: Minutes to add, 1-120 inclusive (default 30).
        credentials: Bearer credentials enforced by the ``security`` dependency.

    Returns:
        Dict with status, the added minutes, the new TTL, the new expiry
        timestamp (ISO 8601) and the whole minutes remaining from "now".

    Raises:
        HTTPException: 400 for an out-of-range extension or an already expired
            connection, 404 when the connection is unknown, 403 when the
            caller may not modify it.
    """
    peer_ip = request.client.host if request.client else "unknown"
    caller = get_current_user(request)
    caller_role = UserRole(caller["role"])

    if not 1 <= additional_minutes <= 120:
        raise HTTPException(
            status_code=400,
            detail="Additional minutes must be between 1 and 120"
        )

    record = redis_connection_storage.get_connection(connection_id)
    if not record:
        raise HTTPException(status_code=404, detail="Connection not found")

    owner = record.get('owner_username', '')
    permitted, denial_reason = PermissionChecker.check_connection_ownership(
        caller_role, caller["username"], owner
    )
    if not permitted:
        log_security_event(
            event_type="unauthorized_connection_extension",
            client_ip=peer_ip,
            user_agent=request.headers.get("user-agent", "unknown"),
            details={
                "connection_id": connection_id,
                "username": caller["username"],
                "owner": owner,
                "reason": denial_reason
            },
            severity="medium"
        )
        raise HTTPException(status_code=403, detail=denial_reason)

    # Stored timestamps may lack a tzinfo; treat those as UTC so they can be
    # compared against an aware "now".
    old_expiry = datetime.fromisoformat(record['expires_at'])
    if old_expiry.tzinfo is None:
        old_expiry = old_expiry.replace(tzinfo=timezone.utc)
    now = datetime.now(timezone.utc)
    if now >= old_expiry:
        raise HTTPException(
            status_code=400,
            detail="Connection has already expired and cannot be extended"
        )

    extended_expiry = old_expiry + timedelta(minutes=additional_minutes)
    extended_ttl = int(record.get('ttl_minutes', 60)) + additional_minutes
    extended_expiry_iso = extended_expiry.isoformat()
    record.update(expires_at=extended_expiry_iso, ttl_minutes=extended_ttl)

    # NOTE(review): ttl_seconds=None is forwarded as-is; how that affects the
    # key's expiry is decided by redis_connection_storage (not visible here).
    redis_connection_storage.add_connection(
        connection_id,
        record,
        ttl_seconds=None
    )

    log_connection_lifecycle(
        connection_id=connection_id,
        action="extended",
        client_ip=peer_ip,
        hostname=record.get('hostname', 'unknown'),
        protocol=record.get('protocol', 'unknown'),
        details={
            "extended_by": caller["username"],
            "additional_minutes": additional_minutes,
            "new_ttl_minutes": extended_ttl,
            "new_expires_at": extended_expiry_iso
        }
    )

    notify_user = record.get('owner_username')
    if notify_user:
        # Best effort: a failed push must not undo the extension.
        try:
            await websocket_manager.send_connection_extended(
                username=notify_user,
                connection_id=connection_id,
                hostname=record.get('hostname', 'unknown'),
                new_expires_at=extended_expiry,
                additional_minutes=additional_minutes
            )
        except Exception as ws_error:
            logger.warning("Failed to send WebSocket notification",
                           connection_id=connection_id,
                           error=str(ws_error))

    logger.info("Connection TTL extended",
                connection_id=connection_id,
                username=caller["username"],
                additional_minutes=additional_minutes,
                new_expires_at=extended_expiry_iso)

    minutes_left = int((extended_expiry - now).total_seconds() / 60)
    return {
        "status": "extended",
        "connection_id": connection_id,
        "additional_minutes": additional_minutes,
        "new_ttl_minutes": extended_ttl,
        "new_expires_at": extended_expiry_iso,
        "remaining_minutes": minutes_left
    }
@app.get(
    "/api/machines/saved",
    tags=["Machines"],
    response_model=SavedMachineList,
    summary="List saved machines",
    description="Retrieve all saved machines for current user"
)
async def get_saved_machines(
    request: Request,
    include_stats: bool = False,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return every saved machine owned by the calling user.

    Args:
        request: Incoming request; the caller's username is read from it.
        include_stats: When true, ask the DB layer for connection statistics
            and copy any ``connection_stats`` it returns into each entry.
        credentials: Bearer credentials enforced by the ``security`` dependency.

    Returns:
        SavedMachineList with the total count and the machine entries.

    Raises:
        HTTPException: re-raised unchanged when raised inside the handler
            (e.g. if ``get_current_user`` rejects the request); any other
            failure is reported as a 500.
    """
    try:
        user_info = get_current_user(request)
        user_id = user_info["username"]
        machines = saved_machines_db.get_user_machines(user_id, include_stats=include_stats)
        machines_response = []
        for machine in machines:
            last_connected = machine.get('last_connected_at')
            machine_dict = {
                "id": str(machine['id']),
                "user_id": machine['user_id'],
                "name": machine['name'],
                "hostname": machine['hostname'],
                "port": machine['port'],
                "protocol": machine['protocol'],
                "os": machine.get('os'),
                "username": machine.get('username'),
                "description": machine.get('description'),
                "tags": machine.get('tags') or [],
                "is_favorite": machine.get('is_favorite', False),
                "created_at": machine['created_at'].isoformat(),
                "updated_at": machine['updated_at'].isoformat(),
                "last_connected_at": last_connected.isoformat() if last_connected else None,
            }
            if include_stats and 'connection_stats' in machine:
                machine_dict['connection_stats'] = machine['connection_stats']
            machines_response.append(SavedMachineResponse(**machine_dict))
        logger.info(
            "Retrieved saved machines",
            user_id=user_id,
            count=len(machines_response)
        )
        return SavedMachineList(
            total=len(machines_response),
            machines=machines_response
        )
    except HTTPException:
        # Keep deliberate status codes (e.g. auth failures) instead of masking
        # them as a 500 below; matches the other saved-machine handlers.
        raise
    except Exception as e:
        logger.error("Failed to get saved machines", error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to retrieve saved machines: {str(e)}")
@app.post(
    "/api/machines/saved",
    tags=["Machines"],
    response_model=SavedMachineResponse,
    summary="Save machine",
    description="Create new saved machine entry with credentials"
)
async def create_saved_machine(
    machine: SavedMachineCreate,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Save a new machine in the calling user's profile.

    Only connection metadata is passed to ``saved_machines_db.create_machine``:
    name, hostname, port, lower-cased protocol, OS, description, tags and the
    favorite flag. No password is sent to the database.

    Args:
        machine: Payload describing the machine to save.
        request: Incoming request; the caller's username is read from it.
        credentials: Bearer credentials enforced by the ``security`` dependency.

    Returns:
        SavedMachineResponse for the newly created row.

    Raises:
        HTTPException: 400 for an unsupported protocol (and any other
            HTTPException raised inside the handler, re-raised unchanged);
            500 for any other failure.
    """
    # Bound before the try so the error path can always reference it.
    user_info = None
    try:
        user_info = get_current_user(request)
        user_id = user_info["username"]
        valid_protocols = ['rdp', 'ssh', 'vnc', 'telnet']
        protocol = machine.protocol.lower()
        if protocol not in valid_protocols:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid protocol. Must be one of: {', '.join(valid_protocols)}"
            )
        # Create machine in DB (passwords are NOT stored)
        created_machine = saved_machines_db.create_machine(
            user_id=user_id,
            name=machine.name,
            hostname=machine.hostname,
            port=machine.port,
            protocol=protocol,
            os=machine.os,
            description=machine.description,
            tags=machine.tags or [],
            is_favorite=machine.is_favorite
        )
        logger.info(
            "Saved machine created",
            machine_id=created_machine['id'],
            user_id=user_id,
            name=machine.name
        )
        return SavedMachineResponse(
            id=str(created_machine['id']),
            user_id=created_machine['user_id'],
            name=created_machine['name'],
            hostname=created_machine['hostname'],
            port=created_machine['port'],
            protocol=created_machine['protocol'],
            os=created_machine.get('os'),
            username=created_machine.get('username'),
            description=created_machine.get('description'),
            tags=created_machine.get('tags') or [],
            is_favorite=created_machine.get('is_favorite', False),
            created_at=created_machine['created_at'].isoformat(),
            updated_at=created_machine['updated_at'].isoformat(),
            last_connected_at=None
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to create saved machine",
                     error=str(e),
                     user_id=user_info.get("username") if user_info is not None else "unknown")
        raise HTTPException(status_code=500, detail=f"Failed to create saved machine: {str(e)}")
@app.get(
    "/api/machines/saved/{machine_id}",
    tags=["Machines"],
    response_model=SavedMachineResponse,
    summary="Get saved machine",
    description="Retrieve specific saved machine details"
)
async def get_saved_machine(
    machine_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return a single saved machine belonging to the calling user.

    The lookup is scoped by both ``machine_id`` and the caller's username, so
    a machine owned by someone else is reported as not found.

    Args:
        machine_id: Identifier of the saved machine.
        request: Incoming request; the caller's username is read from it.
        credentials: Bearer credentials enforced by the ``security`` dependency.

    Returns:
        SavedMachineResponse with the same fields the list, create and update
        endpoints return (including ``username``).

    Raises:
        HTTPException: 404 when the machine is not found for this user;
            500 for any unexpected failure.
    """
    try:
        user_info = get_current_user(request)
        user_id = user_info["username"]
        machine = saved_machines_db.get_machine_by_id(machine_id, user_id)
        if not machine:
            raise HTTPException(status_code=404, detail="Machine not found")
        last_connected = machine.get('last_connected_at')
        return SavedMachineResponse(
            id=str(machine['id']),
            user_id=machine['user_id'],
            name=machine['name'],
            hostname=machine['hostname'],
            port=machine['port'],
            protocol=machine['protocol'],
            os=machine.get('os'),
            # Previously omitted here although every other endpoint returns it.
            username=machine.get('username'),
            description=machine.get('description'),
            tags=machine.get('tags') or [],
            is_favorite=machine.get('is_favorite', False),
            created_at=machine['created_at'].isoformat(),
            updated_at=machine['updated_at'].isoformat(),
            last_connected_at=last_connected.isoformat() if last_connected else None
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get saved machine",
                     error=str(e),
                     machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve machine: {str(e)}")
@app.put(
    "/api/machines/saved/{machine_id}",
    tags=["Machines"],
    response_model=SavedMachineResponse,
    summary="Update saved machine",
    description="Modify machine configuration and credentials"
)
async def update_saved_machine(
    machine_id: str,
    machine: SavedMachineUpdate,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Partially update one of the calling user's saved machines.

    Only payload fields that are not ``None`` are forwarded to
    ``saved_machines_db.update_machine``. No password is passed to the
    database; credentials are requested at connection time instead.

    Args:
        machine_id: Identifier of the saved machine.
        machine: Partial update payload.
        request: Incoming request; the caller's username is read from it.
        credentials: Bearer credentials enforced by the ``security`` dependency.

    Returns:
        SavedMachineResponse reflecting the row returned after the update.

    Raises:
        HTTPException: 404 if the machine is missing for this user (before or
            after the update), 400 for an unsupported protocol, 500 for any
            other failure.
    """
    try:
        user_info = get_current_user(request)
        user_id = user_info["username"]
        existing = saved_machines_db.get_machine_by_id(machine_id, user_id)
        if not existing:
            raise HTTPException(status_code=404, detail="Machine not found")
        update_data = {}
        if machine.name is not None:
            update_data['name'] = machine.name
        if machine.hostname is not None:
            update_data['hostname'] = machine.hostname
        if machine.port is not None:
            update_data['port'] = machine.port
        if machine.protocol is not None:
            valid_protocols = ['rdp', 'ssh', 'vnc', 'telnet']
            protocol = machine.protocol.lower()
            if protocol not in valid_protocols:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid protocol. Must be one of: {', '.join(valid_protocols)}"
                )
            update_data['protocol'] = protocol
        if machine.os is not None:
            update_data['os'] = machine.os
        # Note: credentials are NOT stored, they are requested at connection time
        if machine.description is not None:
            update_data['description'] = machine.description
        if machine.tags is not None:
            update_data['tags'] = machine.tags
        if machine.is_favorite is not None:
            update_data['is_favorite'] = machine.is_favorite
        # Update in DB
        updated_machine = saved_machines_db.update_machine(machine_id, user_id, **update_data)
        if not updated_machine:
            raise HTTPException(status_code=404, detail="Machine not found after update")
        logger.info(
            "Saved machine updated",
            machine_id=machine_id,
            user_id=user_id,
            updated_fields=list(update_data.keys())
        )
        last_connected = updated_machine.get('last_connected_at')
        return SavedMachineResponse(
            id=str(updated_machine['id']),
            user_id=updated_machine['user_id'],
            name=updated_machine['name'],
            hostname=updated_machine['hostname'],
            port=updated_machine['port'],
            protocol=updated_machine['protocol'],
            os=updated_machine.get('os'),
            username=updated_machine.get('username'),
            description=updated_machine.get('description'),
            tags=updated_machine.get('tags') or [],
            is_favorite=updated_machine.get('is_favorite', False),
            created_at=updated_machine['created_at'].isoformat(),
            updated_at=updated_machine['updated_at'].isoformat(),
            last_connected_at=last_connected.isoformat() if last_connected else None
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to update saved machine",
                     error=str(e),
                     machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to update machine: {str(e)}")
@app.delete(
    "/api/machines/saved/{machine_id}",
    tags=["Machines"],
    summary="Delete saved machine",
    description="Remove saved machine from user profile"
)
async def delete_saved_machine(
    machine_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Remove one of the calling user's saved machines.

    The delete is scoped to the caller: ``saved_machines_db.delete_machine``
    receives both the machine id and the caller's username, and a falsy
    result is reported as 404. Per the original note, related connection
    history is removed by the database via CASCADE (not visible here).

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 404 when nothing was deleted; 500 on unexpected errors.
    """
    try:
        owner_id = get_current_user(request)["username"]
        if not saved_machines_db.delete_machine(machine_id, owner_id):
            raise HTTPException(status_code=404, detail="Machine not found")
        logger.info("Saved machine deleted", machine_id=machine_id, user_id=owner_id)
        return {"success": True, "message": "Machine deleted successfully"}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error("Failed to delete saved machine", error=str(exc), machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete machine: {str(exc)}")
@app.post(
    "/api/machines/saved/{machine_id}/connect",
    tags=["Machines"],
    summary="Connect to saved machine",
    description="Record connection attempt and update timestamp"
)
async def connect_to_saved_machine(
    machine_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Record that the caller connected to one of their saved machines.

    Looks the machine up for the calling user (404 if absent), then updates
    its last-connected timestamp. It also appends a successful entry to the
    connection history, tagged with the client IP (or "unknown").

    Returns:
        Dict with ``success``, ``message`` and the new ``history_id``.

    Raises:
        HTTPException: 404 when the machine is not found for this user;
            500 for any unexpected failure.
    """
    try:
        caller_id = get_current_user(request)["username"]
        source_ip = "unknown" if not request.client else request.client.host
        if not saved_machines_db.get_machine_by_id(machine_id, caller_id):
            raise HTTPException(status_code=404, detail="Machine not found")

        saved_machines_db.update_last_connected(machine_id, caller_id)
        entry_id = saved_machines_db.add_connection_history(
            user_id=caller_id,
            machine_id=machine_id,
            success=True,
            client_ip=source_ip
        )

        logger.info(
            "Connection to saved machine recorded",
            machine_id=machine_id,
            user_id=caller_id,
            history_id=entry_id
        )
        return {"success": True, "message": "Connection recorded", "history_id": entry_id}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error("Failed to record connection", error=str(exc), machine_id=machine_id)
        raise HTTPException(status_code=500, detail=f"Failed to record connection: {str(exc)}")
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8000, host="0.0.0.0")