init_guac
This commit is contained in:
18
guacamole_test_11_26/api/Dockerfile
Executable file
18
guacamole_test_11_26/api/Dockerfile
Executable file
@ -0,0 +1,18 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Копируем файл зависимостей
|
||||
COPY requirements.txt .
|
||||
|
||||
# Устанавливаем зависимости
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Копируем код приложения
|
||||
COPY . .
|
||||
|
||||
# Открываем порт
|
||||
EXPOSE 8000
|
||||
|
||||
# Запускаем приложение
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
35
guacamole_test_11_26/api/core/__init__.py
Executable file
35
guacamole_test_11_26/api/core/__init__.py
Executable file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
Core module for Remote Access API.
|
||||
|
||||
Provides:
|
||||
- Authentication and authorization (JWT, Guacamole integration)
|
||||
- Security features (CSRF, rate limiting, brute force protection)
|
||||
- Storage and session management (Redis, PostgreSQL)
|
||||
- Audit logging and WebSocket notifications
|
||||
- Role and permission system
|
||||
"""
|
||||
|
||||
from .guacamole_auth import GuacamoleAuthenticator
|
||||
from .models import (
|
||||
ConnectionRequest,
|
||||
ConnectionResponse,
|
||||
LoginRequest,
|
||||
LoginResponse,
|
||||
UserInfo,
|
||||
UserRole,
|
||||
)
|
||||
from .permissions import PermissionChecker
|
||||
from .utils import create_jwt_token, verify_jwt_token
|
||||
|
||||
__all__ = [
|
||||
"ConnectionRequest",
|
||||
"ConnectionResponse",
|
||||
"GuacamoleAuthenticator",
|
||||
"LoginRequest",
|
||||
"LoginResponse",
|
||||
"PermissionChecker",
|
||||
"UserInfo",
|
||||
"UserRole",
|
||||
"create_jwt_token",
|
||||
"verify_jwt_token",
|
||||
]
|
||||
380
guacamole_test_11_26/api/core/audit_logger.py
Executable file
380
guacamole_test_11_26/api/core/audit_logger.py
Executable file
@ -0,0 +1,380 @@
|
||||
"""
|
||||
Immutable audit logging with HMAC signatures.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class ImmutableAuditLogger:
|
||||
"""Immutable audit logger with HMAC signatures to prevent log tampering."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the immutable audit logger."""
|
||||
self.hmac_secret = os.getenv(
|
||||
"AUDIT_HMAC_SECRET", "default_audit_secret_change_me"
|
||||
)
|
||||
log_path_str = os.getenv(
|
||||
"AUDIT_LOG_PATH", "/var/log/remote_access_audit.log"
|
||||
)
|
||||
self.audit_log_path = Path(log_path_str)
|
||||
self.audit_log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.audit_logger = structlog.get_logger("audit")
|
||||
|
||||
logger.info(
|
||||
"Immutable audit logger initialized",
|
||||
audit_log_path=str(self.audit_log_path),
|
||||
)
|
||||
|
||||
def _generate_hmac_signature(self, data: str) -> str:
|
||||
"""
|
||||
Generate HMAC signature for data.
|
||||
|
||||
Args:
|
||||
data: Data to sign.
|
||||
|
||||
Returns:
|
||||
HMAC signature in hex format.
|
||||
"""
|
||||
return hmac.new(
|
||||
self.hmac_secret.encode("utf-8"),
|
||||
data.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
def _verify_hmac_signature(self, data: str, signature: str) -> bool:
|
||||
"""
|
||||
Verify HMAC signature.
|
||||
|
||||
Args:
|
||||
data: Data to verify.
|
||||
signature: Signature to verify.
|
||||
|
||||
Returns:
|
||||
True if signature is valid.
|
||||
"""
|
||||
expected_signature = self._generate_hmac_signature(data)
|
||||
return hmac.compare_digest(expected_signature, signature)
|
||||
|
||||
def log_security_event(
|
||||
self,
|
||||
event_type: str,
|
||||
client_ip: str,
|
||||
user_agent: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
severity: str = "info",
|
||||
username: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log security event with immutable record.
|
||||
|
||||
Args:
|
||||
event_type: Event type.
|
||||
client_ip: Client IP address.
|
||||
user_agent: Client user agent.
|
||||
details: Additional details.
|
||||
severity: Severity level.
|
||||
username: Username if applicable.
|
||||
|
||||
Returns:
|
||||
True if logging succeeded.
|
||||
"""
|
||||
try:
|
||||
event_data = {
|
||||
"event_type": "security_event",
|
||||
"security_event_type": event_type,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"severity": severity,
|
||||
"username": username,
|
||||
"details": details or {},
|
||||
}
|
||||
return self._write_immutable_log(event_data)
|
||||
except Exception as e:
|
||||
logger.error("Failed to log security event", error=str(e))
|
||||
return False
|
||||
|
||||
def log_audit_event(
|
||||
self,
|
||||
action: str,
|
||||
resource: str,
|
||||
client_ip: str,
|
||||
user_agent: Optional[str] = None,
|
||||
result: str = "success",
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log audit event with immutable record.
|
||||
|
||||
Args:
|
||||
action: Action performed.
|
||||
resource: Resource affected.
|
||||
client_ip: Client IP address.
|
||||
user_agent: Client user agent.
|
||||
result: Action result.
|
||||
details: Additional details.
|
||||
username: Username.
|
||||
|
||||
Returns:
|
||||
True if logging succeeded.
|
||||
"""
|
||||
try:
|
||||
event_data = {
|
||||
"event_type": "audit_event",
|
||||
"action": action,
|
||||
"resource": resource,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"result": result,
|
||||
"username": username,
|
||||
"details": details or {},
|
||||
}
|
||||
return self._write_immutable_log(event_data)
|
||||
except Exception as e:
|
||||
logger.error("Failed to log audit event", error=str(e))
|
||||
return False
|
||||
|
||||
def log_authentication_event(
|
||||
self,
|
||||
event_type: str,
|
||||
username: str,
|
||||
client_ip: str,
|
||||
success: bool,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log authentication event.
|
||||
|
||||
Args:
|
||||
event_type: Event type (login, logout, failed_login, etc.).
|
||||
username: Username.
|
||||
client_ip: Client IP address.
|
||||
success: Operation success status.
|
||||
details: Additional details.
|
||||
|
||||
Returns:
|
||||
True if logging succeeded.
|
||||
"""
|
||||
try:
|
||||
event_data = {
|
||||
"event_type": "authentication_event",
|
||||
"auth_event_type": event_type,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"username": username,
|
||||
"client_ip": client_ip,
|
||||
"success": success,
|
||||
"details": details or {},
|
||||
}
|
||||
return self._write_immutable_log(event_data)
|
||||
except Exception as e:
|
||||
logger.error("Failed to log authentication event", error=str(e))
|
||||
return False
|
||||
|
||||
def log_connection_event(
|
||||
self,
|
||||
event_type: str,
|
||||
connection_id: str,
|
||||
username: str,
|
||||
client_ip: str,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log connection event.
|
||||
|
||||
Args:
|
||||
event_type: Event type (created, deleted, expired, etc.).
|
||||
connection_id: Connection ID.
|
||||
username: Username.
|
||||
client_ip: Client IP address.
|
||||
details: Additional details.
|
||||
|
||||
Returns:
|
||||
True if logging succeeded.
|
||||
"""
|
||||
try:
|
||||
event_data = {
|
||||
"event_type": "connection_event",
|
||||
"connection_event_type": event_type,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"connection_id": connection_id,
|
||||
"username": username,
|
||||
"client_ip": client_ip,
|
||||
"details": details or {},
|
||||
}
|
||||
return self._write_immutable_log(event_data)
|
||||
except Exception as e:
|
||||
logger.error("Failed to log connection event", error=str(e))
|
||||
return False
|
||||
|
||||
def _write_immutable_log(self, event_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Write immutable log entry with HMAC signature.
|
||||
|
||||
Args:
|
||||
event_data: Event data.
|
||||
|
||||
Returns:
|
||||
True if write succeeded.
|
||||
"""
|
||||
try:
|
||||
json_data = json.dumps(event_data, ensure_ascii=False, sort_keys=True)
|
||||
signature = self._generate_hmac_signature(json_data)
|
||||
log_entry = {
|
||||
"data": event_data,
|
||||
"signature": signature,
|
||||
"log_timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
with self.audit_log_path.open("a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
|
||||
f.flush()
|
||||
|
||||
self.audit_logger.info(
|
||||
"Audit event logged",
|
||||
event_type=event_data.get("event_type"),
|
||||
signature=signature[:16] + "...",
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to write immutable log", error=str(e))
|
||||
return False
|
||||
|
||||
def verify_log_integrity(
|
||||
self, log_file_path: Optional[Union[str, Path]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify audit log integrity.
|
||||
|
||||
Args:
|
||||
log_file_path: Path to log file (defaults to main log file).
|
||||
|
||||
Returns:
|
||||
Integrity verification result.
|
||||
"""
|
||||
try:
|
||||
file_path = (
|
||||
Path(log_file_path) if log_file_path else self.audit_log_path
|
||||
)
|
||||
|
||||
if not file_path.exists():
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Log file does not exist",
|
||||
"file_path": str(file_path),
|
||||
}
|
||||
|
||||
valid_entries = 0
|
||||
invalid_entries = 0
|
||||
total_entries = 0
|
||||
|
||||
with file_path.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
total_entries += 1
|
||||
|
||||
try:
|
||||
log_entry = json.loads(line)
|
||||
|
||||
if "data" not in log_entry or "signature" not in log_entry:
|
||||
invalid_entries += 1
|
||||
continue
|
||||
|
||||
json_data = json.dumps(
|
||||
log_entry["data"], ensure_ascii=False, sort_keys=True
|
||||
)
|
||||
|
||||
if self._verify_hmac_signature(
|
||||
json_data, log_entry["signature"]
|
||||
):
|
||||
valid_entries += 1
|
||||
else:
|
||||
invalid_entries += 1
|
||||
except (json.JSONDecodeError, KeyError, ValueError):
|
||||
invalid_entries += 1
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_path": str(file_path),
|
||||
"total_entries": total_entries,
|
||||
"valid_entries": valid_entries,
|
||||
"invalid_entries": invalid_entries,
|
||||
"integrity_percentage": (
|
||||
(valid_entries / total_entries * 100) if total_entries > 0 else 0
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify log integrity", error=str(e))
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"file_path": str(log_file_path or self.audit_log_path),
|
||||
}
|
||||
|
||||
def get_audit_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get audit log statistics.
|
||||
|
||||
Returns:
|
||||
Audit log statistics.
|
||||
"""
|
||||
try:
|
||||
if not self.audit_log_path.exists():
|
||||
return {
|
||||
"status": "no_log_file",
|
||||
"file_path": str(self.audit_log_path),
|
||||
}
|
||||
|
||||
file_size = self.audit_log_path.stat().st_size
|
||||
event_types: Counter[str] = Counter()
|
||||
total_entries = 0
|
||||
|
||||
with self.audit_log_path.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
log_entry = json.loads(line)
|
||||
if (
|
||||
"data" in log_entry
|
||||
and "event_type" in log_entry["data"]
|
||||
):
|
||||
event_type = log_entry["data"]["event_type"]
|
||||
event_types[event_type] += 1
|
||||
total_entries += 1
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_path": str(self.audit_log_path),
|
||||
"file_size_bytes": file_size,
|
||||
"total_entries": total_entries,
|
||||
"event_types": dict(event_types),
|
||||
"hmac_secret_configured": bool(
|
||||
self.hmac_secret
|
||||
and self.hmac_secret != "default_audit_secret_change_me"
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Failed to get audit stats", error=str(e))
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# Global instance for use in API
|
||||
immutable_audit_logger = ImmutableAuditLogger()
|
||||
327
guacamole_test_11_26/api/core/brute_force_protection.py
Executable file
327
guacamole_test_11_26/api/core/brute_force_protection.py
Executable file
@ -0,0 +1,327 @@
|
||||
"""Brute-force protection for login endpoint."""
|
||||
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import structlog
|
||||
|
||||
from .rate_limiter import redis_rate_limiter
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Backoff constants
|
||||
MAX_BACKOFF_SECONDS = 300
|
||||
MIN_FAILED_ATTEMPTS_FOR_BACKOFF = 2
|
||||
EXPONENTIAL_BACKOFF_BASE = 2
|
||||
|
||||
# Default limits
|
||||
DEFAULT_MAX_LOGIN_ATTEMPTS_PER_IP = 5
|
||||
DEFAULT_MAX_LOGIN_ATTEMPTS_PER_USER = 10
|
||||
DEFAULT_LOGIN_WINDOW_MINUTES = 15
|
||||
DEFAULT_USER_LOCKOUT_MINUTES = 60
|
||||
|
||||
# Block types
|
||||
BLOCK_TYPE_RATE_LIMIT = "rate_limit"
|
||||
BLOCK_TYPE_IP_BLOCKED = "ip_blocked"
|
||||
BLOCK_TYPE_USER_LOCKED = "user_locked"
|
||||
BLOCK_TYPE_EXPONENTIAL_BACKOFF = "exponential_backoff"
|
||||
BLOCK_TYPE_ALLOWED = "allowed"
|
||||
BLOCK_TYPE_ERROR_FALLBACK = "error_fallback"
|
||||
|
||||
# Response messages
|
||||
MSG_RATE_LIMIT_EXCEEDED = "Rate limit exceeded"
|
||||
MSG_IP_BLOCKED = "Too many failed attempts from this IP"
|
||||
MSG_USER_LOCKED = "User account temporarily locked"
|
||||
MSG_LOGIN_ALLOWED = "Login allowed"
|
||||
MSG_LOGIN_ALLOWED_ERROR = "Login allowed (protection error)"
|
||||
|
||||
# Default failure reason
|
||||
DEFAULT_FAILURE_REASON = "invalid_credentials"
|
||||
|
||||
# Empty string for clearing
|
||||
EMPTY_USERNAME = ""
|
||||
EMPTY_IP = ""
|
||||
|
||||
|
||||
class BruteForceProtection:
|
||||
"""Protection against brute-force attacks on login endpoint."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize brute-force protection."""
|
||||
self.max_login_attempts_per_ip = DEFAULT_MAX_LOGIN_ATTEMPTS_PER_IP
|
||||
self.max_login_attempts_per_user = DEFAULT_MAX_LOGIN_ATTEMPTS_PER_USER
|
||||
self.login_window_minutes = DEFAULT_LOGIN_WINDOW_MINUTES
|
||||
self.user_lockout_minutes = DEFAULT_USER_LOCKOUT_MINUTES
|
||||
self.exponential_backoff_base = EXPONENTIAL_BACKOFF_BASE
|
||||
|
||||
def check_login_allowed(
|
||||
self, client_ip: str, username: str
|
||||
) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""
|
||||
Check if login is allowed for given IP and user.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, reason: str, details: Dict[str, Any]).
|
||||
"""
|
||||
try:
|
||||
allowed, headers = redis_rate_limiter.check_login_rate_limit(
|
||||
client_ip, username
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
return (
|
||||
False,
|
||||
MSG_RATE_LIMIT_EXCEEDED,
|
||||
{
|
||||
"type": BLOCK_TYPE_RATE_LIMIT,
|
||||
"client_ip": client_ip,
|
||||
"username": username,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
failed_counts = redis_rate_limiter.get_failed_login_count(
|
||||
client_ip, username, self.user_lockout_minutes
|
||||
)
|
||||
|
||||
if failed_counts["ip_failed_count"] >= self.max_login_attempts_per_ip:
|
||||
return (
|
||||
False,
|
||||
MSG_IP_BLOCKED,
|
||||
{
|
||||
"type": BLOCK_TYPE_IP_BLOCKED,
|
||||
"client_ip": client_ip,
|
||||
"failed_count": failed_counts["ip_failed_count"],
|
||||
"max_attempts": self.max_login_attempts_per_ip,
|
||||
"window_minutes": self.login_window_minutes,
|
||||
},
|
||||
)
|
||||
|
||||
if failed_counts["user_failed_count"] >= self.max_login_attempts_per_user:
|
||||
return (
|
||||
False,
|
||||
MSG_USER_LOCKED,
|
||||
{
|
||||
"type": BLOCK_TYPE_USER_LOCKED,
|
||||
"username": username,
|
||||
"failed_count": failed_counts["user_failed_count"],
|
||||
"max_attempts": self.max_login_attempts_per_user,
|
||||
"lockout_minutes": self.user_lockout_minutes,
|
||||
},
|
||||
)
|
||||
|
||||
backoff_seconds = self._calculate_backoff_time(
|
||||
client_ip, username, failed_counts
|
||||
)
|
||||
if backoff_seconds > 0:
|
||||
return (
|
||||
False,
|
||||
f"Please wait {backoff_seconds} seconds before next attempt",
|
||||
{
|
||||
"type": BLOCK_TYPE_EXPONENTIAL_BACKOFF,
|
||||
"wait_seconds": backoff_seconds,
|
||||
"client_ip": client_ip,
|
||||
"username": username,
|
||||
},
|
||||
)
|
||||
|
||||
return (
|
||||
True,
|
||||
MSG_LOGIN_ALLOWED,
|
||||
{
|
||||
"type": BLOCK_TYPE_ALLOWED,
|
||||
"client_ip": client_ip,
|
||||
"username": username,
|
||||
"failed_counts": failed_counts,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error checking login permission",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
return (
|
||||
True,
|
||||
MSG_LOGIN_ALLOWED_ERROR,
|
||||
{"type": BLOCK_TYPE_ERROR_FALLBACK, "error": str(e)},
|
||||
)
|
||||
|
||||
def record_failed_login(
|
||||
self,
|
||||
client_ip: str,
|
||||
username: str,
|
||||
failure_reason: str = DEFAULT_FAILURE_REASON,
|
||||
) -> None:
|
||||
"""
|
||||
Record failed login attempt.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username.
|
||||
failure_reason: Failure reason.
|
||||
"""
|
||||
try:
|
||||
redis_rate_limiter.record_failed_login(client_ip, username)
|
||||
|
||||
logger.warning(
|
||||
"Failed login attempt recorded",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
failure_reason=failure_reason,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to record failed login attempt",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def record_successful_login(self, client_ip: str, username: str) -> None:
|
||||
"""
|
||||
Record successful login (clear failed attempts).
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username.
|
||||
"""
|
||||
try:
|
||||
redis_rate_limiter.clear_failed_logins(client_ip, username)
|
||||
|
||||
logger.info(
|
||||
"Successful login recorded, failed attempts cleared",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to record successful login",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def _calculate_backoff_time(
|
||||
self, client_ip: str, username: str, failed_counts: Dict[str, int]
|
||||
) -> int:
|
||||
"""
|
||||
Calculate wait time for exponential backoff.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username.
|
||||
failed_counts: Failed attempt counts.
|
||||
|
||||
Returns:
|
||||
Wait time in seconds.
|
||||
"""
|
||||
try:
|
||||
max_failed = max(
|
||||
failed_counts["ip_failed_count"],
|
||||
failed_counts["user_failed_count"],
|
||||
)
|
||||
|
||||
if max_failed <= MIN_FAILED_ATTEMPTS_FOR_BACKOFF:
|
||||
return 0
|
||||
|
||||
backoff_seconds = min(
|
||||
self.exponential_backoff_base
|
||||
** (max_failed - MIN_FAILED_ATTEMPTS_FOR_BACKOFF),
|
||||
MAX_BACKOFF_SECONDS,
|
||||
)
|
||||
|
||||
return backoff_seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error calculating backoff time", error=str(e))
|
||||
return 0
|
||||
|
||||
def get_protection_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get brute-force protection statistics.
|
||||
|
||||
Returns:
|
||||
Protection statistics dictionary.
|
||||
"""
|
||||
try:
|
||||
rate_limit_stats = redis_rate_limiter.get_rate_limit_stats()
|
||||
|
||||
return {
|
||||
"max_login_attempts_per_ip": self.max_login_attempts_per_ip,
|
||||
"max_login_attempts_per_user": self.max_login_attempts_per_user,
|
||||
"login_window_minutes": self.login_window_minutes,
|
||||
"user_lockout_minutes": self.user_lockout_minutes,
|
||||
"exponential_backoff_base": self.exponential_backoff_base,
|
||||
"rate_limit_stats": rate_limit_stats,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get protection stats", error=str(e))
|
||||
return {"error": str(e)}
|
||||
|
||||
def force_unlock_user(self, username: str, unlocked_by: str) -> bool:
|
||||
"""
|
||||
Force unlock user (for administrators).
|
||||
|
||||
Args:
|
||||
username: Username to unlock.
|
||||
unlocked_by: Who unlocked the user.
|
||||
|
||||
Returns:
|
||||
True if unlock successful.
|
||||
"""
|
||||
try:
|
||||
redis_rate_limiter.clear_failed_logins(EMPTY_IP, username)
|
||||
|
||||
logger.info("User force unlocked", username=username, unlocked_by=unlocked_by)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to force unlock user",
|
||||
username=username,
|
||||
unlocked_by=unlocked_by,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
def force_unlock_ip(self, client_ip: str, unlocked_by: str) -> bool:
|
||||
"""
|
||||
Force unlock IP (for administrators).
|
||||
|
||||
Args:
|
||||
client_ip: IP address to unlock.
|
||||
unlocked_by: Who unlocked the IP.
|
||||
|
||||
Returns:
|
||||
True if unlock successful.
|
||||
"""
|
||||
try:
|
||||
redis_rate_limiter.clear_failed_logins(client_ip, EMPTY_USERNAME)
|
||||
|
||||
logger.info(
|
||||
"IP force unlocked", client_ip=client_ip, unlocked_by=unlocked_by
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to force unlock IP",
|
||||
client_ip=client_ip,
|
||||
unlocked_by=unlocked_by,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
brute_force_protection = BruteForceProtection()
|
||||
361
guacamole_test_11_26/api/core/csrf_protection.py
Executable file
361
guacamole_test_11_26/api/core/csrf_protection.py
Executable file
@ -0,0 +1,361 @@
|
||||
"""CSRF protection using Double Submit Cookie pattern."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, FrozenSet
|
||||
|
||||
import redis
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Redis configuration
|
||||
REDIS_DEFAULT_HOST = "redis"
|
||||
REDIS_DEFAULT_PORT = "6379"
|
||||
REDIS_DEFAULT_DB = 0
|
||||
|
||||
# Token configuration
|
||||
REDIS_KEY_PREFIX = "csrf:token:"
|
||||
CSRF_TOKEN_TTL_SECONDS = 3600
|
||||
TOKEN_SIZE_BYTES = 32
|
||||
SECRET_KEY_SIZE_BYTES = 32
|
||||
TOKEN_PARTS_COUNT = 3
|
||||
TOKEN_PREVIEW_LENGTH = 16
|
||||
SCAN_BATCH_SIZE = 100
|
||||
|
||||
# Redis TTL special values
|
||||
TTL_KEY_NOT_EXISTS = -2
|
||||
TTL_KEY_NO_EXPIRY = -1
|
||||
|
||||
# Protected HTTP methods
|
||||
PROTECTED_METHODS: FrozenSet[str] = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
# Excluded endpoints (no CSRF protection)
|
||||
EXCLUDED_ENDPOINTS: FrozenSet[str] = frozenset({
|
||||
"/auth/login",
|
||||
"/health",
|
||||
"/health/detailed",
|
||||
"/health/ready",
|
||||
"/health/live",
|
||||
"/health/routing",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
})
|
||||
|
||||
|
||||
class CSRFProtection:
|
||||
"""CSRF protection with Double Submit Cookie pattern."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize CSRF protection.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Redis connection fails.
|
||||
"""
|
||||
self._redis_client = redis.Redis(
|
||||
host=os.getenv("REDIS_HOST", REDIS_DEFAULT_HOST),
|
||||
port=int(os.getenv("REDIS_PORT", REDIS_DEFAULT_PORT)),
|
||||
password=os.getenv("REDIS_PASSWORD"),
|
||||
db=REDIS_DEFAULT_DB,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
self._csrf_token_ttl = CSRF_TOKEN_TTL_SECONDS
|
||||
self._token_size = TOKEN_SIZE_BYTES
|
||||
self._secret_key = secrets.token_bytes(SECRET_KEY_SIZE_BYTES)
|
||||
|
||||
try:
|
||||
self._redis_client.ping()
|
||||
logger.info("CSRF Protection connected to Redis successfully")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to Redis for CSRF", error=str(e))
|
||||
raise RuntimeError(f"Redis connection failed: {e}") from e
|
||||
|
||||
self._protected_methods: FrozenSet[str] = PROTECTED_METHODS
|
||||
self._excluded_endpoints: FrozenSet[str] = EXCLUDED_ENDPOINTS
|
||||
|
||||
def generate_csrf_token(self, user_id: str) -> str:
|
||||
"""
|
||||
Generate CSRF token for user.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
|
||||
Returns:
|
||||
CSRF token.
|
||||
"""
|
||||
try:
|
||||
random_bytes = secrets.token_bytes(self._token_size)
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
data_to_sign = f"{user_id}:{timestamp}:{random_bytes.hex()}"
|
||||
signature = hashlib.sha256(
|
||||
f"{data_to_sign}:{self._secret_key.hex()}".encode()
|
||||
).hexdigest()
|
||||
|
||||
csrf_token = f"{random_bytes.hex()}:{timestamp}:{signature}"
|
||||
|
||||
now = datetime.now()
|
||||
expires_at = now + timedelta(seconds=self._csrf_token_ttl)
|
||||
token_data = {
|
||||
"user_id": user_id,
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"used": False,
|
||||
}
|
||||
|
||||
redis_key = f"{REDIS_KEY_PREFIX}{csrf_token}"
|
||||
self._redis_client.setex(
|
||||
redis_key, self._csrf_token_ttl, json.dumps(token_data)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"CSRF token generated in Redis",
|
||||
user_id=user_id,
|
||||
token_preview=csrf_token[:TOKEN_PREVIEW_LENGTH] + "...",
|
||||
expires_at=expires_at.isoformat(),
|
||||
)
|
||||
|
||||
return csrf_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to generate CSRF token", user_id=user_id, error=str(e))
|
||||
raise
|
||||
|
||||
def validate_csrf_token(self, token: str, user_id: str) -> bool:
|
||||
"""
|
||||
Validate CSRF token.
|
||||
|
||||
Args:
|
||||
token: CSRF token.
|
||||
user_id: User ID.
|
||||
|
||||
Returns:
|
||||
True if token is valid.
|
||||
"""
|
||||
try:
|
||||
if not token or not user_id:
|
||||
return False
|
||||
|
||||
redis_key = f"{REDIS_KEY_PREFIX}{token}"
|
||||
token_json = self._redis_client.get(redis_key)
|
||||
|
||||
if not token_json:
|
||||
logger.warning(
|
||||
"CSRF token not found in Redis",
|
||||
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
|
||||
user_id=user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
expires_at = datetime.fromisoformat(token_data["expires_at"])
|
||||
if datetime.now() > expires_at:
|
||||
logger.warning(
|
||||
"CSRF token expired",
|
||||
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
|
||||
user_id=user_id,
|
||||
)
|
||||
self._redis_client.delete(redis_key)
|
||||
return False
|
||||
|
||||
if token_data["user_id"] != user_id:
|
||||
logger.warning(
|
||||
"CSRF token user mismatch",
|
||||
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
|
||||
expected_user=user_id,
|
||||
actual_user=token_data["user_id"],
|
||||
)
|
||||
return False
|
||||
|
||||
if not self._verify_token_signature(token, user_id):
|
||||
logger.warning(
|
||||
"CSRF token signature invalid",
|
||||
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
|
||||
user_id=user_id,
|
||||
)
|
||||
self._redis_client.delete(redis_key)
|
||||
return False
|
||||
|
||||
token_data["used"] = True
|
||||
ttl = self._redis_client.ttl(redis_key)
|
||||
if ttl > 0:
|
||||
self._redis_client.setex(redis_key, ttl, json.dumps(token_data))
|
||||
|
||||
logger.debug(
|
||||
"CSRF token validated successfully",
|
||||
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error validating CSRF token",
|
||||
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "..." if token else "none",
|
||||
user_id=user_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
def _verify_token_signature(self, token: str, user_id: str) -> bool:
|
||||
"""
|
||||
Verify token signature.
|
||||
|
||||
Args:
|
||||
token: CSRF token.
|
||||
user_id: User ID.
|
||||
|
||||
Returns:
|
||||
True if signature is valid.
|
||||
"""
|
||||
try:
|
||||
parts = token.split(":")
|
||||
if len(parts) != TOKEN_PARTS_COUNT:
|
||||
return False
|
||||
|
||||
random_hex, timestamp, signature = parts
|
||||
|
||||
data_to_sign = f"{user_id}:{timestamp}:{random_hex}"
|
||||
expected_signature = hashlib.sha256(
|
||||
f"{data_to_sign}:{self._secret_key.hex()}".encode()
|
||||
).hexdigest()
|
||||
|
||||
return signature == expected_signature
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def should_protect_endpoint(self, method: str, path: str) -> bool:
|
||||
"""
|
||||
Check if endpoint needs CSRF protection.
|
||||
|
||||
Args:
|
||||
method: HTTP method.
|
||||
path: Endpoint path.
|
||||
|
||||
Returns:
|
||||
True if CSRF protection is needed.
|
||||
"""
|
||||
if method not in self._protected_methods:
|
||||
return False
|
||||
|
||||
if path in self._excluded_endpoints:
|
||||
return False
|
||||
|
||||
for excluded_path in self._excluded_endpoints:
|
||||
if path.startswith(excluded_path):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def cleanup_expired_tokens(self) -> None:
|
||||
"""
|
||||
Clean up expired CSRF tokens from Redis.
|
||||
|
||||
Note: Redis automatically removes keys with expired TTL.
|
||||
"""
|
||||
try:
|
||||
pattern = f"{REDIS_KEY_PREFIX}*"
|
||||
keys = list(self._redis_client.scan_iter(match=pattern, count=SCAN_BATCH_SIZE))
|
||||
|
||||
cleaned_count = 0
|
||||
for key in keys:
|
||||
ttl = self._redis_client.ttl(key)
|
||||
if ttl == TTL_KEY_NOT_EXISTS:
|
||||
cleaned_count += 1
|
||||
elif ttl == TTL_KEY_NO_EXPIRY:
|
||||
self._redis_client.delete(key)
|
||||
cleaned_count += 1
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(
|
||||
"CSRF tokens cleanup completed",
|
||||
cleaned_count=cleaned_count,
|
||||
remaining_count=len(keys) - cleaned_count,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired CSRF tokens", error=str(e))
|
||||
|
||||
def get_csrf_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get CSRF token statistics from Redis.
|
||||
|
||||
Returns:
|
||||
Dictionary with CSRF statistics.
|
||||
"""
|
||||
try:
|
||||
pattern = f"{REDIS_KEY_PREFIX}*"
|
||||
keys = list(self._redis_client.scan_iter(match=pattern, count=SCAN_BATCH_SIZE))
|
||||
|
||||
active_tokens = 0
|
||||
used_tokens = 0
|
||||
|
||||
for key in keys:
|
||||
try:
|
||||
token_json = self._redis_client.get(key)
|
||||
if token_json:
|
||||
token_data = json.loads(token_json)
|
||||
active_tokens += 1
|
||||
if token_data.get("used", False):
|
||||
used_tokens += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {
|
||||
"total_tokens": len(keys),
|
||||
"active_tokens": active_tokens,
|
||||
"used_tokens": used_tokens,
|
||||
"token_ttl_seconds": self._csrf_token_ttl,
|
||||
"protected_methods": sorted(self._protected_methods),
|
||||
"excluded_endpoints": sorted(self._excluded_endpoints),
|
||||
"storage": "Redis",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Failed to get CSRF stats", error=str(e))
|
||||
return {"error": str(e), "storage": "Redis"}
|
||||
|
||||
def revoke_user_tokens(self, user_id: str) -> None:
|
||||
"""
|
||||
Revoke all user tokens from Redis.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
"""
|
||||
try:
|
||||
pattern = f"{REDIS_KEY_PREFIX}*"
|
||||
keys = list(self._redis_client.scan_iter(match=pattern, count=SCAN_BATCH_SIZE))
|
||||
|
||||
revoked_count = 0
|
||||
for key in keys:
|
||||
try:
|
||||
token_json = self._redis_client.get(key)
|
||||
if token_json:
|
||||
token_data = json.loads(token_json)
|
||||
if token_data.get("user_id") == user_id:
|
||||
self._redis_client.delete(key)
|
||||
revoked_count += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if revoked_count > 0:
|
||||
logger.info(
|
||||
"Revoked user CSRF tokens from Redis",
|
||||
user_id=user_id,
|
||||
count=revoked_count,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to revoke user CSRF tokens", user_id=user_id, error=str(e)
|
||||
)
|
||||
|
||||
|
||||
csrf_protection = CSRFProtection()
|
||||
485
guacamole_test_11_26/api/core/guacamole_auth.py
Executable file
485
guacamole_test_11_26/api/core/guacamole_auth.py
Executable file
@ -0,0 +1,485 @@
|
||||
"""Integration with Guacamole API for authentication and user management."""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import structlog
|
||||
|
||||
from .models import UserRole
|
||||
from .permissions import PermissionChecker
|
||||
from .session_storage import session_storage
|
||||
from .utils import create_jwt_token
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class GuacamoleAuthenticator:
|
||||
"""Class for authentication via Guacamole API."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize Guacamole authenticator.
|
||||
|
||||
Raises:
|
||||
ValueError: If system credentials are not set in environment variables.
|
||||
"""
|
||||
self.base_url = os.getenv("GUACAMOLE_URL", "http://guacamole:8080")
|
||||
self.session = requests.Session()
|
||||
|
||||
self._system_token: Optional[str] = None
|
||||
self._system_token_expires: Optional[datetime] = None
|
||||
|
||||
self._system_username = os.getenv("SYSTEM_ADMIN_USERNAME")
|
||||
self._system_password = os.getenv("SYSTEM_ADMIN_PASSWORD")
|
||||
|
||||
if not self._system_username or not self._system_password:
|
||||
raise ValueError(
|
||||
"SYSTEM_ADMIN_USERNAME and SYSTEM_ADMIN_PASSWORD environment "
|
||||
"variables are required. Set these in your .env or "
|
||||
"production.env file for security. Never use default "
|
||||
"credentials in production!"
|
||||
)
|
||||
|
||||
def get_system_token(self) -> str:
|
||||
"""
|
||||
Get system user token for administrative operations.
|
||||
|
||||
Returns:
|
||||
System user token.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If system user authentication fails.
|
||||
"""
|
||||
if (
|
||||
self._system_token is None
|
||||
or self._system_token_expires is None
|
||||
or self._system_token_expires <= datetime.now()
|
||||
):
|
||||
logger.debug("Refreshing system token", username=self._system_username)
|
||||
|
||||
auth_url = f"{self.base_url}/guacamole/api/tokens"
|
||||
auth_data = {
|
||||
"username": self._system_username,
|
||||
"password": self._system_password,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.session.post(auth_url, data=auth_data, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
auth_result = response.json()
|
||||
self._system_token = auth_result.get("authToken")
|
||||
|
||||
if not self._system_token:
|
||||
raise RuntimeError("No authToken in response")
|
||||
|
||||
self._system_token_expires = datetime.now() + timedelta(hours=7)
|
||||
|
||||
logger.info(
|
||||
"System token refreshed successfully",
|
||||
username=self._system_username,
|
||||
expires_at=self._system_token_expires.isoformat(),
|
||||
)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(
|
||||
"Failed to authenticate system user",
|
||||
username=self._system_username,
|
||||
error=str(e),
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to authenticate system user: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error during system authentication",
|
||||
username=self._system_username,
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
return self._system_token
|
||||
|
||||
def authenticate_user(
|
||||
self, username: str, password: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Authenticate user via Guacamole API.
|
||||
|
||||
Args:
|
||||
username: Username in Guacamole.
|
||||
password: Password in Guacamole.
|
||||
|
||||
Returns:
|
||||
Dictionary with user information or None if authentication fails.
|
||||
"""
|
||||
auth_url = f"{self.base_url}/guacamole/api/tokens"
|
||||
auth_data = {"username": username, "password": password}
|
||||
|
||||
try:
|
||||
logger.debug("Attempting user authentication", username=username)
|
||||
|
||||
response = self.session.post(auth_url, data=auth_data, timeout=10)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.info(
|
||||
"Authentication failed",
|
||||
username=username,
|
||||
status_code=response.status_code,
|
||||
response=response.text[:200],
|
||||
)
|
||||
return None
|
||||
|
||||
auth_result = response.json()
|
||||
auth_token = auth_result.get("authToken")
|
||||
|
||||
if not auth_token:
|
||||
logger.warning(
|
||||
"No authToken in successful response",
|
||||
username=username,
|
||||
response=auth_result,
|
||||
)
|
||||
return None
|
||||
|
||||
user_info = self.get_user_info(auth_token)
|
||||
if not user_info:
|
||||
logger.warning(
|
||||
"Failed to get user info after authentication", username=username
|
||||
)
|
||||
return None
|
||||
|
||||
system_permissions = user_info.get("systemPermissions", [])
|
||||
user_role = PermissionChecker.determine_role_from_permissions(
|
||||
system_permissions
|
||||
)
|
||||
|
||||
result = {
|
||||
"username": username,
|
||||
"auth_token": auth_token,
|
||||
"role": user_role.value,
|
||||
"permissions": system_permissions,
|
||||
"full_name": user_info.get("fullName"),
|
||||
"email": user_info.get("emailAddress"),
|
||||
"organization": user_info.get("organization"),
|
||||
"organizational_role": user_info.get("organizationalRole"),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"User authenticated successfully",
|
||||
username=username,
|
||||
role=user_role.value,
|
||||
permissions_count=len(system_permissions),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(
|
||||
"Network error during authentication", username=username, error=str(e)
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error during authentication", username=username, error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
def get_user_info(self, auth_token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get user information via Guacamole API.
|
||||
|
||||
Args:
|
||||
auth_token: User authentication token.
|
||||
|
||||
Returns:
|
||||
Dictionary with user information or None.
|
||||
"""
|
||||
user_url = f"{self.base_url}/guacamole/api/session/data/postgresql/self"
|
||||
headers = {"Guacamole-Token": auth_token}
|
||||
|
||||
try:
|
||||
response = self.session.get(user_url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
"Failed to get user info",
|
||||
status_code=response.status_code,
|
||||
response=response.text[:200],
|
||||
)
|
||||
return None
|
||||
|
||||
user_data = response.json()
|
||||
username = user_data.get("username")
|
||||
|
||||
if not username:
|
||||
logger.warning("No username in user info response")
|
||||
return None
|
||||
|
||||
permissions_url = (
|
||||
f"{self.base_url}/guacamole/api/session/data/postgresql/"
|
||||
f"users/{username}/permissions"
|
||||
)
|
||||
|
||||
try:
|
||||
perm_response = self.session.get(
|
||||
permissions_url, headers=headers, timeout=10
|
||||
)
|
||||
|
||||
if perm_response.status_code == 200:
|
||||
permissions_data = perm_response.json()
|
||||
system_permissions = permissions_data.get("systemPermissions", [])
|
||||
|
||||
logger.info(
|
||||
"System permissions retrieved",
|
||||
username=username,
|
||||
system_permissions=system_permissions,
|
||||
permissions_count=len(system_permissions),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to get user permissions",
|
||||
username=username,
|
||||
status_code=perm_response.status_code,
|
||||
response=perm_response.text[:200],
|
||||
)
|
||||
system_permissions = []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error getting user permissions", username=username, error=str(e)
|
||||
)
|
||||
system_permissions = []
|
||||
|
||||
user_data["systemPermissions"] = system_permissions
|
||||
|
||||
attributes = user_data.get("attributes", {})
|
||||
user_data.update(
|
||||
{
|
||||
"fullName": attributes.get("guac-full-name"),
|
||||
"emailAddress": attributes.get("guac-email-address"),
|
||||
"organization": attributes.get("guac-organization"),
|
||||
"organizationalRole": attributes.get("guac-organizational-role"),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"User info retrieved successfully",
|
||||
username=username,
|
||||
system_permissions=system_permissions,
|
||||
permissions_count=len(system_permissions),
|
||||
)
|
||||
|
||||
return user_data
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("Network error getting user info", error=str(e))
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error getting user info", error=str(e))
|
||||
return None
|
||||
|
||||
def create_jwt_for_user(self, user_info: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create JWT token for user with session storage.
|
||||
|
||||
Args:
|
||||
user_info: User information from authenticate_user.
|
||||
|
||||
Returns:
|
||||
JWT token.
|
||||
"""
|
||||
session_id = session_storage.create_session(
|
||||
user_info=user_info,
|
||||
guac_token=user_info["auth_token"],
|
||||
expires_in_minutes=int(
|
||||
os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "60")
|
||||
),
|
||||
)
|
||||
|
||||
return create_jwt_token(user_info, session_id)
|
||||
|
||||
def get_user_connections(self, auth_token: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get list of user connections.
|
||||
|
||||
Args:
|
||||
auth_token: User authentication token.
|
||||
|
||||
Returns:
|
||||
List of connections.
|
||||
"""
|
||||
connections_url = (
|
||||
f"{self.base_url}/guacamole/api/session/data/postgresql/connections"
|
||||
)
|
||||
headers = {"Guacamole-Token": auth_token}
|
||||
|
||||
try:
|
||||
response = self.session.get(connections_url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
"Failed to get user connections", status_code=response.status_code
|
||||
)
|
||||
return []
|
||||
|
||||
connections_data = response.json()
|
||||
|
||||
if isinstance(connections_data, dict):
|
||||
connections = list(connections_data.values())
|
||||
else:
|
||||
connections = connections_data
|
||||
|
||||
logger.debug("Retrieved user connections", count=len(connections))
|
||||
|
||||
return connections
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("Network error getting connections", error=str(e))
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error getting connections", error=str(e))
|
||||
return []
|
||||
|
||||
def create_connection_with_token(
|
||||
self, connection_config: Dict[str, Any], auth_token: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Create connection using user token.
|
||||
|
||||
Args:
|
||||
connection_config: Connection configuration.
|
||||
auth_token: User authentication token.
|
||||
|
||||
Returns:
|
||||
Information about created connection or None.
|
||||
"""
|
||||
create_url = (
|
||||
f"{self.base_url}/guacamole/api/session/data/postgresql/connections"
|
||||
)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Guacamole-Token": auth_token,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.session.post(
|
||||
create_url, headers=headers, json=connection_config, timeout=30
|
||||
)
|
||||
|
||||
if response.status_code not in [200, 201]:
|
||||
logger.error(
|
||||
"Failed to create connection",
|
||||
status_code=response.status_code,
|
||||
response=response.text[:500],
|
||||
)
|
||||
return None
|
||||
|
||||
created_connection = response.json()
|
||||
connection_id = created_connection.get("identifier")
|
||||
|
||||
if not connection_id:
|
||||
logger.error(
|
||||
"No connection ID in response", response=created_connection
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Connection created successfully",
|
||||
connection_id=connection_id,
|
||||
protocol=connection_config.get("protocol"),
|
||||
hostname=connection_config.get("parameters", {}).get("hostname"),
|
||||
)
|
||||
|
||||
return created_connection
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("Network error creating connection", error=str(e))
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error creating connection", error=str(e))
|
||||
return None
|
||||
|
||||
def delete_connection_with_token(
|
||||
self, connection_id: str, auth_token: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete connection using user token.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID to delete.
|
||||
auth_token: User authentication token.
|
||||
|
||||
Returns:
|
||||
True if deletion successful, False otherwise.
|
||||
"""
|
||||
delete_url = (
|
||||
f"{self.base_url}/guacamole/api/session/data/postgresql/"
|
||||
f"connections/{connection_id}"
|
||||
)
|
||||
headers = {"Guacamole-Token": auth_token}
|
||||
|
||||
try:
|
||||
response = self.session.delete(delete_url, headers=headers, timeout=10)
|
||||
|
||||
if response.status_code == 204:
|
||||
logger.info("Connection deleted successfully", connection_id=connection_id)
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
"Failed to delete connection",
|
||||
connection_id=connection_id,
|
||||
status_code=response.status_code,
|
||||
response=response.text[:200],
|
||||
)
|
||||
return False
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(
|
||||
"Network error deleting connection",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error deleting connection",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
def validate_token(self, auth_token: str) -> bool:
|
||||
"""
|
||||
Validate Guacamole token.
|
||||
|
||||
Args:
|
||||
auth_token: Token to validate.
|
||||
|
||||
Returns:
|
||||
True if token is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
user_info = self.get_user_info(auth_token)
|
||||
return user_info is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def refresh_user_token(
|
||||
self, username: str, current_token: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Refresh user token (if supported by Guacamole).
|
||||
|
||||
Args:
|
||||
username: Username.
|
||||
current_token: Current token.
|
||||
|
||||
Returns:
|
||||
New token or None.
|
||||
"""
|
||||
logger.debug(
|
||||
"Token refresh requested but not supported by Guacamole", username=username
|
||||
)
|
||||
return None
|
||||
474
guacamole_test_11_26/api/core/kms_provider.py
Executable file
474
guacamole_test_11_26/api/core/kms_provider.py
Executable file
@ -0,0 +1,474 @@
|
||||
"""Module for working with real KMS/HSM systems."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KMSProvider(ABC):
|
||||
"""Abstract class for KMS providers."""
|
||||
|
||||
@abstractmethod
|
||||
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
|
||||
"""
|
||||
Encrypt data using KMS.
|
||||
|
||||
Args:
|
||||
plaintext: Data to encrypt.
|
||||
key_id: Key identifier.
|
||||
|
||||
Returns:
|
||||
Encrypted data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
|
||||
"""
|
||||
Decrypt data using KMS.
|
||||
|
||||
Args:
|
||||
ciphertext: Encrypted data.
|
||||
key_id: Key identifier.
|
||||
|
||||
Returns:
|
||||
Decrypted data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_data_key(
|
||||
self, key_id: str, key_spec: str = "AES_256"
|
||||
) -> Dict[str, bytes]:
|
||||
"""
|
||||
Generate data encryption key.
|
||||
|
||||
Args:
|
||||
key_id: Key identifier.
|
||||
key_spec: Key specification.
|
||||
|
||||
Returns:
|
||||
Dictionary with 'plaintext' and 'ciphertext' keys.
|
||||
"""
|
||||
pass
|
||||
|
||||
class AWSKMSProvider(KMSProvider):
|
||||
"""AWS KMS provider."""
|
||||
|
||||
def __init__(self, region_name: str = "us-east-1") -> None:
|
||||
"""
|
||||
Initialize AWS KMS provider.
|
||||
|
||||
Args:
|
||||
region_name: AWS region name.
|
||||
"""
|
||||
self.kms_client = boto3.client("kms", region_name=region_name)
|
||||
self.region_name = region_name
|
||||
|
||||
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
|
||||
"""Encrypt data using AWS KMS."""
|
||||
try:
|
||||
response = self.kms_client.encrypt(KeyId=key_id, Plaintext=plaintext)
|
||||
return response["CiphertextBlob"]
|
||||
except ClientError as e:
|
||||
logger.error("AWS KMS encryption failed: %s", e)
|
||||
raise
|
||||
|
||||
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
|
||||
"""Decrypt data using AWS KMS."""
|
||||
try:
|
||||
response = self.kms_client.decrypt(
|
||||
CiphertextBlob=ciphertext, KeyId=key_id
|
||||
)
|
||||
return response["Plaintext"]
|
||||
except ClientError as e:
|
||||
logger.error("AWS KMS decryption failed: %s", e)
|
||||
raise
|
||||
|
||||
def generate_data_key(
|
||||
self, key_id: str, key_spec: str = "AES_256"
|
||||
) -> Dict[str, bytes]:
|
||||
"""Generate data encryption key."""
|
||||
try:
|
||||
response = self.kms_client.generate_data_key(
|
||||
KeyId=key_id, KeySpec=key_spec
|
||||
)
|
||||
return {
|
||||
"plaintext": response["Plaintext"],
|
||||
"ciphertext": response["CiphertextBlob"],
|
||||
}
|
||||
except ClientError as e:
|
||||
logger.error("AWS KMS data key generation failed: %s", e)
|
||||
raise
|
||||
|
||||
class GoogleCloudKMSProvider(KMSProvider):
|
||||
"""Google Cloud KMS provider."""
|
||||
|
||||
def __init__(self, project_id: str, location: str = "global") -> None:
|
||||
"""
|
||||
Initialize Google Cloud KMS provider.
|
||||
|
||||
Args:
|
||||
project_id: Google Cloud project ID.
|
||||
location: Key location.
|
||||
"""
|
||||
self.project_id = project_id
|
||||
self.location = location
|
||||
self.base_url = (
|
||||
f"https://cloudkms.googleapis.com/v1/projects/{project_id}"
|
||||
f"/locations/{location}"
|
||||
)
|
||||
|
||||
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
|
||||
"""Encrypt data using Google Cloud KMS."""
|
||||
try:
|
||||
url = (
|
||||
f"{self.base_url}/keyRings/default/cryptoKeys/{key_id}:encrypt"
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"plaintext": base64.b64encode(plaintext).decode()},
|
||||
headers={
|
||||
"Authorization": f"Bearer {self._get_access_token()}"
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return base64.b64decode(response.json()["ciphertext"])
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error("Google Cloud KMS encryption failed: %s", e)
|
||||
raise RuntimeError(
|
||||
f"Google Cloud KMS encryption failed: {e}"
|
||||
) from e
|
||||
|
||||
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
|
||||
"""Decrypt data using Google Cloud KMS."""
|
||||
try:
|
||||
url = (
|
||||
f"{self.base_url}/keyRings/default/cryptoKeys/{key_id}:decrypt"
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"ciphertext": base64.b64encode(ciphertext).decode()},
|
||||
headers={
|
||||
"Authorization": f"Bearer {self._get_access_token()}"
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return base64.b64decode(response.json()["plaintext"])
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error("Google Cloud KMS decryption failed: %s", e)
|
||||
raise RuntimeError(
|
||||
f"Google Cloud KMS decryption failed: {e}"
|
||||
) from e
|
||||
|
||||
def generate_data_key(
|
||||
self, key_id: str, key_spec: str = "AES_256"
|
||||
) -> Dict[str, bytes]:
|
||||
"""Generate data encryption key."""
|
||||
try:
|
||||
url = (
|
||||
f"{self.base_url}/keyRings/default/cryptoKeys/{key_id}"
|
||||
":generateDataKey"
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"keySpec": key_spec},
|
||||
headers={
|
||||
"Authorization": f"Bearer {self._get_access_token()}"
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return {
|
||||
"plaintext": base64.b64decode(data["plaintext"]),
|
||||
"ciphertext": base64.b64decode(data["ciphertext"]),
|
||||
}
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error("Google Cloud KMS data key generation failed: %s", e)
|
||||
raise RuntimeError(
|
||||
f"Google Cloud KMS data key generation failed: {e}"
|
||||
) from e
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
"""
|
||||
Get access token for Google Cloud API.
|
||||
|
||||
Note: In production, use service account or metadata server.
|
||||
"""
|
||||
return os.getenv("GOOGLE_CLOUD_ACCESS_TOKEN", "")
|
||||
|
||||
class YubiHSMProvider(KMSProvider):
|
||||
"""YubiHSM provider (hardware security module)."""
|
||||
|
||||
def __init__(self, hsm_url: str, auth_key_id: int) -> None:
|
||||
"""
|
||||
Initialize YubiHSM provider.
|
||||
|
||||
Args:
|
||||
hsm_url: YubiHSM URL.
|
||||
auth_key_id: Authentication key ID.
|
||||
"""
|
||||
self.hsm_url = hsm_url
|
||||
self.auth_key_id = auth_key_id
|
||||
|
||||
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
|
||||
"""Encrypt data using YubiHSM."""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.hsm_url}/api/v1/encrypt",
|
||||
json={"key_id": key_id, "plaintext": plaintext.hex()},
|
||||
headers={"Authorization": f"Bearer {self._get_hsm_token()}"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return bytes.fromhex(response.json()["ciphertext"])
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error("YubiHSM encryption failed: %s", e)
|
||||
raise RuntimeError(f"YubiHSM encryption failed: {e}") from e
|
||||
|
||||
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
|
||||
"""Decrypt data using YubiHSM."""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.hsm_url}/api/v1/decrypt",
|
||||
json={"key_id": key_id, "ciphertext": ciphertext.hex()},
|
||||
headers={"Authorization": f"Bearer {self._get_hsm_token()}"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return bytes.fromhex(response.json()["plaintext"])
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error("YubiHSM decryption failed: %s", e)
|
||||
raise RuntimeError(f"YubiHSM decryption failed: {e}") from e
|
||||
|
||||
def generate_data_key(
|
||||
self, key_id: str, key_spec: str = "AES_256"
|
||||
) -> Dict[str, bytes]:
|
||||
"""Generate data encryption key."""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.hsm_url}/api/v1/generate-data-key",
|
||||
json={"key_id": key_id, "key_spec": key_spec},
|
||||
headers={"Authorization": f"Bearer {self._get_hsm_token()}"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return {
|
||||
"plaintext": bytes.fromhex(data["plaintext"]),
|
||||
"ciphertext": bytes.fromhex(data["ciphertext"]),
|
||||
}
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error("YubiHSM data key generation failed: %s", e)
|
||||
raise RuntimeError(
|
||||
f"YubiHSM data key generation failed: {e}"
|
||||
) from e
|
||||
|
||||
def _get_hsm_token(self) -> str:
|
||||
"""
|
||||
Get token for YubiHSM.
|
||||
|
||||
Note: In production, use proper YubiHSM authentication.
|
||||
"""
|
||||
return os.getenv("YUBIHSM_TOKEN", "")
|
||||
|
||||
class SecureKeyManager:
|
||||
"""Key manager using real KMS/HSM systems."""
|
||||
|
||||
def __init__(self, kms_provider: KMSProvider, master_key_id: str) -> None:
|
||||
"""
|
||||
Initialize secure key manager.
|
||||
|
||||
Args:
|
||||
kms_provider: KMS provider instance.
|
||||
master_key_id: Master key identifier.
|
||||
"""
|
||||
self.kms_provider = kms_provider
|
||||
self.master_key_id = master_key_id
|
||||
self.key_cache: Dict[str, bytes] = {}
|
||||
|
||||
def encrypt_session_key(self, session_key: bytes, session_id: str) -> bytes:
|
||||
"""
|
||||
Encrypt session key using KMS/HSM.
|
||||
|
||||
Args:
|
||||
session_key: Session key to encrypt.
|
||||
session_id: Session ID for context.
|
||||
|
||||
Returns:
|
||||
Encrypted session key.
|
||||
"""
|
||||
try:
|
||||
encrypted_key = self.kms_provider.encrypt(
|
||||
session_key, self.master_key_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Session key encrypted with KMS/HSM",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"key_length": len(session_key),
|
||||
"encrypted_length": len(encrypted_key),
|
||||
},
|
||||
)
|
||||
|
||||
return encrypted_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to encrypt session key with KMS/HSM",
|
||||
extra={"session_id": session_id, "error": str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
def decrypt_session_key(
|
||||
self, encrypted_session_key: bytes, session_id: str
|
||||
) -> bytes:
|
||||
"""
|
||||
Decrypt session key using KMS/HSM.
|
||||
|
||||
Args:
|
||||
encrypted_session_key: Encrypted session key.
|
||||
session_id: Session ID for context.
|
||||
|
||||
Returns:
|
||||
Decrypted session key.
|
||||
"""
|
||||
try:
|
||||
decrypted_key = self.kms_provider.decrypt(
|
||||
encrypted_session_key, self.master_key_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Session key decrypted with KMS/HSM",
|
||||
extra={"session_id": session_id, "key_length": len(decrypted_key)},
|
||||
)
|
||||
|
||||
return decrypted_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to decrypt session key with KMS/HSM",
|
||||
extra={"session_id": session_id, "error": str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
def generate_encryption_key(self, session_id: str) -> Dict[str, bytes]:
|
||||
"""
|
||||
Generate encryption key using KMS/HSM.
|
||||
|
||||
Args:
|
||||
session_id: Session ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with 'plaintext' and 'ciphertext' keys.
|
||||
"""
|
||||
try:
|
||||
key_data = self.kms_provider.generate_data_key(self.master_key_id)
|
||||
|
||||
logger.info(
|
||||
"Encryption key generated with KMS/HSM",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"key_length": len(key_data["plaintext"]),
|
||||
},
|
||||
)
|
||||
|
||||
return key_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to generate encryption key with KMS/HSM",
|
||||
extra={"session_id": session_id, "error": str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class KMSProviderFactory:
|
||||
"""Factory for creating KMS providers."""
|
||||
|
||||
@staticmethod
|
||||
def create_provider(provider_type: str, **kwargs: Any) -> KMSProvider:
|
||||
"""
|
||||
Create KMS provider by type.
|
||||
|
||||
Args:
|
||||
provider_type: Provider type ('aws', 'gcp', 'yubihsm').
|
||||
**kwargs: Provider-specific arguments.
|
||||
|
||||
Returns:
|
||||
KMS provider instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If provider type is unsupported.
|
||||
"""
|
||||
if provider_type == "aws":
|
||||
return AWSKMSProvider(**kwargs)
|
||||
if provider_type == "gcp":
|
||||
return GoogleCloudKMSProvider(**kwargs)
|
||||
if provider_type == "yubihsm":
|
||||
return YubiHSMProvider(**kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported KMS provider: {provider_type}")
|
||||
|
||||
|
||||
def get_secure_key_manager() -> SecureKeyManager:
|
||||
"""
|
||||
Get configured secure key manager.
|
||||
|
||||
Returns:
|
||||
Configured SecureKeyManager instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If provider type is unsupported.
|
||||
"""
|
||||
provider_type = os.getenv("KMS_PROVIDER", "aws")
|
||||
master_key_id = os.getenv("KMS_MASTER_KEY_ID", "alias/session-keys")
|
||||
|
||||
if provider_type == "aws":
|
||||
provider = AWSKMSProvider(
|
||||
region_name=os.getenv("AWS_REGION", "us-east-1")
|
||||
)
|
||||
elif provider_type == "gcp":
|
||||
provider = GoogleCloudKMSProvider(
|
||||
project_id=os.getenv("GCP_PROJECT_ID", ""),
|
||||
location=os.getenv("GCP_LOCATION", "global"),
|
||||
)
|
||||
elif provider_type == "yubihsm":
|
||||
provider = YubiHSMProvider(
|
||||
hsm_url=os.getenv("YUBIHSM_URL", ""),
|
||||
auth_key_id=int(os.getenv("YUBIHSM_AUTH_KEY_ID", "0")),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported KMS provider: {provider_type}")
|
||||
|
||||
return SecureKeyManager(provider, master_key_id)
|
||||
|
||||
|
||||
secure_key_manager = get_secure_key_manager()
|
||||
286
guacamole_test_11_26/api/core/log_sanitizer.py
Executable file
286
guacamole_test_11_26/api/core/log_sanitizer.py
Executable file
@ -0,0 +1,286 @@
|
||||
"""Log sanitization for removing sensitive information from logs."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class LogSanitizer:
|
||||
"""Class for cleaning logs from sensitive information."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize LogSanitizer with sensitive fields and patterns."""
|
||||
self.sensitive_fields = {
|
||||
"password",
|
||||
"passwd",
|
||||
"pwd",
|
||||
"secret",
|
||||
"token",
|
||||
"key",
|
||||
"auth_token",
|
||||
"guac_token",
|
||||
"jwt_token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"api_key",
|
||||
"private_key",
|
||||
"encryption_key",
|
||||
"session_id",
|
||||
"cookie",
|
||||
"authorization",
|
||||
"credential",
|
||||
"credentials",
|
||||
"global_credentials",
|
||||
"machine_credentials",
|
||||
"ssh_password",
|
||||
"ssh_username",
|
||||
"credential_hash",
|
||||
"password_hash",
|
||||
"password_salt",
|
||||
"encrypted_password",
|
||||
}
|
||||
|
||||
self.sensitive_patterns = [
|
||||
r'password["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
|
||||
r'token["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
|
||||
r'key["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
|
||||
r'secret["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
|
||||
r'authorization["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
|
||||
]
|
||||
|
||||
self.jwt_pattern = re.compile(
|
||||
r"\b[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b"
|
||||
)
|
||||
self.api_key_pattern = re.compile(r"\b[A-Za-z0-9]{32,}\b")
|
||||
|
||||
def mask_sensitive_value(self, value: str, mask_char: str = "*") -> str:
|
||||
"""
|
||||
Mask sensitive value.
|
||||
|
||||
Args:
|
||||
value: Value to mask.
|
||||
mask_char: Character to use for masking.
|
||||
|
||||
Returns:
|
||||
Masked value.
|
||||
"""
|
||||
if not value or len(value) <= 4:
|
||||
return mask_char * 4
|
||||
|
||||
if len(value) <= 8:
|
||||
return value[:2] + mask_char * (len(value) - 4) + value[-2:]
|
||||
|
||||
return value[:4] + mask_char * (len(value) - 8) + value[-4:]
|
||||
|
||||
def sanitize_string(self, text: str) -> str:
|
||||
"""
|
||||
Clean string from sensitive information.
|
||||
|
||||
Args:
|
||||
text: Text to clean.
|
||||
|
||||
Returns:
|
||||
Cleaned text.
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return text
|
||||
|
||||
sanitized = text
|
||||
|
||||
sanitized = self.jwt_pattern.sub(
|
||||
lambda m: self.mask_sensitive_value(m.group(0)), sanitized
|
||||
)
|
||||
|
||||
sanitized = self.api_key_pattern.sub(
|
||||
lambda m: self.mask_sensitive_value(m.group(0)), sanitized
|
||||
)
|
||||
|
||||
for pattern in self.sensitive_patterns:
|
||||
sanitized = re.sub(
|
||||
pattern,
|
||||
lambda m: m.group(0).replace(
|
||||
m.group(1), self.mask_sensitive_value(m.group(1))
|
||||
),
|
||||
sanitized,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
return sanitized
|
||||
|
||||
def sanitize_dict(
|
||||
self, data: Dict[str, Any], max_depth: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Recursively clean dictionary from sensitive information.
|
||||
|
||||
Args:
|
||||
data: Dictionary to clean.
|
||||
max_depth: Maximum recursion depth.
|
||||
|
||||
Returns:
|
||||
Cleaned dictionary.
|
||||
"""
|
||||
if max_depth <= 0:
|
||||
return {"error": "max_depth_exceeded"}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
sanitized = {}
|
||||
|
||||
for key, value in data.items():
|
||||
key_lower = key.lower()
|
||||
is_sensitive_key = any(
|
||||
sensitive_field in key_lower
|
||||
for sensitive_field in self.sensitive_fields
|
||||
)
|
||||
|
||||
if is_sensitive_key:
|
||||
if isinstance(value, str):
|
||||
sanitized[key] = self.mask_sensitive_value(value)
|
||||
elif isinstance(value, (dict, list)):
|
||||
sanitized[key] = self.sanitize_value(value, max_depth - 1)
|
||||
else:
|
||||
sanitized[key] = "[MASKED]"
|
||||
else:
|
||||
sanitized[key] = self.sanitize_value(value, max_depth - 1)
|
||||
|
||||
return sanitized
|
||||
|
||||
def sanitize_value(self, value: Any, max_depth: int = 10) -> Any:
|
||||
"""
|
||||
Clean value of any type.
|
||||
|
||||
Args:
|
||||
value: Value to clean.
|
||||
max_depth: Maximum recursion depth.
|
||||
|
||||
Returns:
|
||||
Cleaned value.
|
||||
"""
|
||||
if max_depth <= 0:
|
||||
return "[max_depth_exceeded]"
|
||||
|
||||
if isinstance(value, str):
|
||||
return self.sanitize_string(value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
return self.sanitize_dict(value, max_depth)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [
|
||||
self.sanitize_value(item, max_depth - 1) for item in value
|
||||
]
|
||||
|
||||
if isinstance(value, (int, float, bool, type(None))):
|
||||
return value
|
||||
|
||||
return self.sanitize_string(str(value))
|
||||
|
||||
def sanitize_log_event(
|
||||
self, event_dict: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Clean log event from sensitive information.
|
||||
|
||||
Args:
|
||||
event_dict: Log event dictionary.
|
||||
|
||||
Returns:
|
||||
Cleaned event dictionary.
|
||||
"""
|
||||
try:
|
||||
sanitized_event = event_dict.copy()
|
||||
sanitized_event = self.sanitize_dict(sanitized_event)
|
||||
|
||||
special_fields = ["request_body", "response_body", "headers"]
|
||||
for field in special_fields:
|
||||
if field in sanitized_event:
|
||||
sanitized_event[field] = self.sanitize_value(
|
||||
sanitized_event[field]
|
||||
)
|
||||
|
||||
return sanitized_event
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sanitizing log event", error=str(e))
|
||||
return {
|
||||
"error": "sanitization_failed",
|
||||
"original_error": str(e),
|
||||
}
|
||||
|
||||
def sanitize_json_string(self, json_string: str) -> str:
|
||||
"""
|
||||
Clean JSON string from sensitive information.
|
||||
|
||||
Args:
|
||||
json_string: JSON string to clean.
|
||||
|
||||
Returns:
|
||||
Cleaned JSON string.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(json_string)
|
||||
sanitized_data = self.sanitize_value(data)
|
||||
return json.dumps(sanitized_data, ensure_ascii=False)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return self.sanitize_string(json_string)
|
||||
except Exception as e:
|
||||
logger.error("Error sanitizing JSON string", error=str(e))
|
||||
return json_string
|
||||
|
||||
def is_sensitive_field(self, field_name: str) -> bool:
|
||||
"""
|
||||
Check if field is sensitive.
|
||||
|
||||
Args:
|
||||
field_name: Field name.
|
||||
|
||||
Returns:
|
||||
True if field is sensitive.
|
||||
"""
|
||||
field_lower = field_name.lower()
|
||||
return any(
|
||||
sensitive_field in field_lower
|
||||
for sensitive_field in self.sensitive_fields
|
||||
)
|
||||
|
||||
def get_sanitization_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get sanitization statistics.
|
||||
|
||||
Returns:
|
||||
Sanitization statistics dictionary.
|
||||
"""
|
||||
return {
|
||||
"sensitive_fields_count": len(self.sensitive_fields),
|
||||
"sensitive_patterns_count": len(self.sensitive_patterns),
|
||||
"sensitive_fields": list(self.sensitive_fields),
|
||||
"jwt_pattern_active": bool(self.jwt_pattern),
|
||||
"api_key_pattern_active": bool(self.api_key_pattern),
|
||||
}
|
||||
|
||||
|
||||
log_sanitizer = LogSanitizer()
|
||||
|
||||
|
||||
def sanitize_log_processor(
|
||||
logger: Any, name: str, event_dict: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Processor for structlog for automatic log sanitization.
|
||||
|
||||
Args:
|
||||
logger: Logger instance.
|
||||
name: Logger name.
|
||||
event_dict: Event dictionary.
|
||||
|
||||
Returns:
|
||||
Sanitized event dictionary.
|
||||
"""
|
||||
return log_sanitizer.sanitize_log_event(event_dict)
|
||||
296
guacamole_test_11_26/api/core/middleware.py
Executable file
296
guacamole_test_11_26/api/core/middleware.py
Executable file
@ -0,0 +1,296 @@
|
||||
"""Authentication and authorization middleware."""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
import structlog
|
||||
|
||||
from .models import UserRole
|
||||
from .permissions import PermissionChecker
|
||||
from .session_storage import session_storage
|
||||
from .utils import extract_token_from_header, verify_jwt_token
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Public endpoints that don't require authentication
|
||||
PUBLIC_PATHS = {
|
||||
"/",
|
||||
"/api/health",
|
||||
"/api/docs",
|
||||
"/api/openapi.json",
|
||||
"/api/redoc",
|
||||
"/favicon.ico",
|
||||
"/api/auth/login",
|
||||
}
|
||||
|
||||
# Static file prefixes for FastAPI (Swagger UI)
|
||||
STATIC_PREFIXES = [
|
||||
"/static/",
|
||||
"/docs/",
|
||||
"/redoc/",
|
||||
"/api/static/",
|
||||
"/api/docs/",
|
||||
"/api/redoc/",
|
||||
]
|
||||
|
||||
|
||||
async def jwt_auth_middleware(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""
|
||||
Middleware for JWT token verification and user authentication.
|
||||
|
||||
Supports JWT token in Authorization header: Bearer <token>
|
||||
"""
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
if path in PUBLIC_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
if any(path.startswith(prefix) for prefix in STATIC_PREFIXES):
|
||||
return await call_next(request)
|
||||
|
||||
user_token: Optional[str] = None
|
||||
user_info: Optional[Dict[str, Any]] = None
|
||||
auth_method: Optional[str] = None
|
||||
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
jwt_token = extract_token_from_header(auth_header)
|
||||
if jwt_token:
|
||||
jwt_payload = verify_jwt_token(jwt_token)
|
||||
if jwt_payload:
|
||||
session_id = jwt_payload.get("session_id")
|
||||
if session_id:
|
||||
session_data = session_storage.get_session(session_id)
|
||||
if session_data:
|
||||
user_token = session_data.get("guac_token")
|
||||
else:
|
||||
logger.warning(
|
||||
"Session not found in Redis",
|
||||
session_id=session_id,
|
||||
username=jwt_payload.get("username"),
|
||||
)
|
||||
else:
|
||||
user_token = jwt_payload.get("guac_token")
|
||||
|
||||
user_info = {
|
||||
"username": jwt_payload["username"],
|
||||
"role": jwt_payload["role"],
|
||||
"permissions": jwt_payload.get("permissions", []),
|
||||
"full_name": jwt_payload.get("full_name"),
|
||||
"email": jwt_payload.get("email"),
|
||||
"organization": jwt_payload.get("organization"),
|
||||
"organizational_role": jwt_payload.get("organizational_role"),
|
||||
}
|
||||
auth_method = "jwt"
|
||||
|
||||
logger.debug(
|
||||
"JWT authentication successful",
|
||||
username=user_info["username"],
|
||||
role=user_info["role"],
|
||||
has_session=session_id is not None,
|
||||
has_token=user_token is not None,
|
||||
)
|
||||
|
||||
if not user_token or not user_info:
|
||||
logger.info(
|
||||
"Authentication required",
|
||||
path=path,
|
||||
method=request.method,
|
||||
client_ip=request.client.host if request.client else "unknown",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "Authentication required",
|
||||
"message": (
|
||||
"Provide JWT token in Authorization header. "
|
||||
"Get token via /auth/login"
|
||||
),
|
||||
"login_endpoint": "/auth/login",
|
||||
},
|
||||
)
|
||||
|
||||
user_role = UserRole(user_info["role"])
|
||||
allowed, reason = PermissionChecker.check_endpoint_access(
|
||||
user_role, request.method, path
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
"Access denied to endpoint",
|
||||
username=user_info["username"],
|
||||
role=user_info["role"],
|
||||
endpoint=f"{request.method} {path}",
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "Access denied",
|
||||
"message": reason,
|
||||
"required_role": "Higher privileges required",
|
||||
},
|
||||
)
|
||||
|
||||
request.state.user_token = user_token
|
||||
request.state.user_info = user_info
|
||||
request.state.auth_method = auth_method
|
||||
|
||||
logger.debug(
|
||||
"Authentication and authorization successful",
|
||||
username=user_info["username"],
|
||||
role=user_info["role"],
|
||||
auth_method=auth_method,
|
||||
endpoint=f"{request.method} {path}",
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
if hasattr(request.state, "user_info"):
|
||||
response.headers["X-User"] = request.state.user_info["username"]
|
||||
response.headers["X-User-Role"] = request.state.user_info["role"]
|
||||
response.headers["X-Auth-Method"] = request.state.auth_method
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error in auth middleware",
|
||||
error=str(e),
|
||||
path=path,
|
||||
method=request.method,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": "Internal server error",
|
||||
"message": "Authentication system error",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get current user information from request.state.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object.
|
||||
|
||||
Returns:
|
||||
User information dictionary or None.
|
||||
"""
|
||||
return getattr(request.state, "user_info", None)
|
||||
|
||||
|
||||
def get_current_user_token(request: Request) -> Optional[str]:
|
||||
"""
|
||||
Get current user token from request.state.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object.
|
||||
|
||||
Returns:
|
||||
User token string or None.
|
||||
"""
|
||||
return getattr(request.state, "user_token", None)
|
||||
|
||||
|
||||
def require_role(required_role: UserRole) -> Callable:
|
||||
"""
|
||||
Decorator to check user role.
|
||||
|
||||
Args:
|
||||
required_role: Required user role.
|
||||
|
||||
Returns:
|
||||
Function decorator.
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any:
|
||||
user_info = get_current_user(request)
|
||||
if not user_info:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Authentication required"
|
||||
)
|
||||
|
||||
user_role = UserRole(user_info["role"])
|
||||
permission = f"role_{required_role.value}"
|
||||
if not PermissionChecker.check_permission(user_role, permission):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Role {required_role.value} required",
|
||||
)
|
||||
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_permission(permission: str) -> Callable:
|
||||
"""
|
||||
Decorator to check specific permission.
|
||||
|
||||
Args:
|
||||
permission: Required permission string.
|
||||
|
||||
Returns:
|
||||
Function decorator.
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any:
|
||||
user_info = get_current_user(request)
|
||||
if not user_info:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Authentication required"
|
||||
)
|
||||
|
||||
user_role = UserRole(user_info["role"])
|
||||
if not PermissionChecker.check_permission(user_role, permission):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Permission '{permission}' required",
|
||||
)
|
||||
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
async def validate_connection_ownership(
|
||||
request: Request, connection_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check user permissions for connection management.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object.
|
||||
connection_id: Connection ID.
|
||||
|
||||
Returns:
|
||||
True if user can manage the connection, False otherwise.
|
||||
"""
|
||||
user_info = get_current_user(request)
|
||||
if not user_info:
|
||||
return False
|
||||
|
||||
user_role = UserRole(user_info["role"])
|
||||
|
||||
if PermissionChecker.can_delete_any_connection(user_role):
|
||||
return True
|
||||
|
||||
return True
|
||||
343
guacamole_test_11_26/api/core/models.py
Executable file
343
guacamole_test_11_26/api/core/models.py
Executable file
@ -0,0 +1,343 @@
|
||||
"""
|
||||
Pydantic models for authentication system.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""User roles in the system."""
|
||||
|
||||
GUEST = "GUEST"
|
||||
USER = "USER"
|
||||
ADMIN = "ADMIN"
|
||||
SUPER_ADMIN = "SUPER_ADMIN"
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""Authentication request."""
|
||||
|
||||
username: str = Field(..., description="Username in Guacamole")
|
||||
password: str = Field(..., description="Password in Guacamole")
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""Successful authentication response."""
|
||||
|
||||
access_token: str = Field(..., description="JWT access token")
|
||||
token_type: str = Field(default="bearer", description="Token type")
|
||||
expires_in: int = Field(..., description="Token lifetime in seconds")
|
||||
user_info: Dict[str, Any] = Field(..., description="User information")
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
"""User information."""
|
||||
|
||||
username: str = Field(..., description="Username")
|
||||
role: UserRole = Field(..., description="User role")
|
||||
permissions: List[str] = Field(
|
||||
default_factory=list, description="System permissions"
|
||||
)
|
||||
full_name: Optional[str] = Field(None, description="Full name")
|
||||
email: Optional[str] = Field(None, description="Email address")
|
||||
organization: Optional[str] = Field(None, description="Organization")
|
||||
organizational_role: Optional[str] = Field(None, description="Job title")
|
||||
|
||||
|
||||
class ConnectionRequest(BaseModel):
|
||||
"""Connection creation request.
|
||||
|
||||
Requires JWT token in Authorization header: Bearer <token>
|
||||
Get token via /auth/login
|
||||
"""
|
||||
|
||||
hostname: str = Field(..., description="IP address or hostname")
|
||||
protocol: str = Field(
|
||||
default="rdp", description="Connection protocol (rdp, vnc, ssh)"
|
||||
)
|
||||
username: Optional[str] = Field(
|
||||
None, description="Username for remote machine connection"
|
||||
)
|
||||
password: Optional[str] = Field(
|
||||
None,
|
||||
description="Encrypted password for remote machine connection (Base64 AES-256-GCM)",
|
||||
)
|
||||
port: Optional[int] = Field(
|
||||
None, description="Port (default used if not specified)"
|
||||
)
|
||||
ttl_minutes: Optional[int] = Field(
|
||||
default=60, description="Connection lifetime in minutes"
|
||||
)
|
||||
|
||||
enable_sftp: Optional[bool] = Field(
|
||||
default=True, description="Enable SFTP for SSH (file browser with drag'n'drop)"
|
||||
)
|
||||
sftp_root_directory: Optional[str] = Field(
|
||||
default="/", description="Root directory for SFTP (default: /)"
|
||||
)
|
||||
sftp_server_alive_interval: Optional[int] = Field(
|
||||
default=0, description="SFTP keep-alive interval in seconds (0 = disabled)"
|
||||
)
|
||||
|
||||
|
||||
class ConnectionResponse(BaseModel):
|
||||
"""Connection creation response."""
|
||||
|
||||
connection_id: str = Field(..., description="Created connection ID")
|
||||
connection_url: str = Field(..., description="URL to access connection")
|
||||
status: str = Field(..., description="Connection status")
|
||||
expires_at: str = Field(..., description="Connection expiration time")
|
||||
ttl_minutes: int = Field(..., description="TTL in minutes")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Token refresh request."""
|
||||
|
||||
refresh_token: str = Field(..., description="Refresh token")
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Logout request."""
|
||||
|
||||
token: str = Field(..., description="Token to revoke")
|
||||
|
||||
|
||||
class PermissionCheckRequest(BaseModel):
|
||||
"""Permission check request."""
|
||||
|
||||
action: str = Field(..., description="Action to check")
|
||||
resource: Optional[str] = Field(None, description="Resource (optional)")
|
||||
|
||||
|
||||
class PermissionCheckResponse(BaseModel):
|
||||
"""Permission check response."""
|
||||
|
||||
allowed: bool = Field(..., description="Whether action is allowed")
|
||||
reason: Optional[str] = Field(None, description="Denial reason (if applicable)")
|
||||
|
||||
|
||||
class SavedMachineCreate(BaseModel):
|
||||
"""Saved machine create/update request."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255, description="Machine name")
|
||||
hostname: str = Field(
|
||||
..., min_length=1, max_length=255, description="IP address or hostname"
|
||||
)
|
||||
port: int = Field(..., gt=0, lt=65536, description="Connection port")
|
||||
protocol: str = Field(
|
||||
..., description="Connection protocol (rdp, ssh, vnc, telnet)"
|
||||
)
|
||||
os: Optional[str] = Field(
|
||||
None,
|
||||
max_length=255,
|
||||
description="Operating system (e.g., Windows Server 2019, Ubuntu 22.04)",
|
||||
)
|
||||
description: Optional[str] = Field(None, description="Machine description")
|
||||
tags: Optional[List[str]] = Field(
|
||||
default_factory=list, description="Tags for grouping"
|
||||
)
|
||||
is_favorite: bool = Field(default=False, description="Favorite machine")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"name": "Production Web Server",
|
||||
"hostname": "192.168.1.100",
|
||||
"port": 3389,
|
||||
"protocol": "rdp",
|
||||
"os": "Windows Server 2019",
|
||||
"description": "Main production web server",
|
||||
"tags": ["production", "web"],
|
||||
"is_favorite": True,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class SavedMachineUpdate(BaseModel):
|
||||
"""Saved machine partial update request."""
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
hostname: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
port: Optional[int] = Field(None, gt=0, lt=65536)
|
||||
protocol: Optional[str] = None
|
||||
os: Optional[str] = Field(None, max_length=255, description="Operating system")
|
||||
description: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
is_favorite: Optional[bool] = None
|
||||
|
||||
|
||||
class SavedMachineResponse(BaseModel):
|
||||
"""Saved machine information response."""
|
||||
|
||||
id: str = Field(..., description="Machine UUID")
|
||||
user_id: str = Field(..., description="Owner user ID")
|
||||
name: str = Field(..., description="Machine name")
|
||||
hostname: str = Field(..., description="IP address or hostname")
|
||||
port: int = Field(..., description="Connection port")
|
||||
protocol: str = Field(..., description="Connection protocol")
|
||||
os: Optional[str] = Field(None, description="Operating system")
|
||||
description: Optional[str] = Field(None, description="Description")
|
||||
tags: List[str] = Field(default_factory=list, description="Tags")
|
||||
is_favorite: bool = Field(default=False, description="Favorite")
|
||||
created_at: str = Field(..., description="Creation date (ISO 8601)")
|
||||
updated_at: str = Field(..., description="Update date (ISO 8601)")
|
||||
last_connected_at: Optional[str] = Field(
|
||||
None, description="Last connection (ISO 8601)"
|
||||
)
|
||||
connection_stats: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Connection statistics"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SavedMachineList(BaseModel):
|
||||
"""Saved machines list."""
|
||||
|
||||
total: int = Field(..., description="Total number of machines")
|
||||
machines: List[SavedMachineResponse] = Field(..., description="List of machines")
|
||||
|
||||
|
||||
class ConnectionHistoryCreate(BaseModel):
|
||||
"""Connection history record creation request."""
|
||||
|
||||
machine_id: str = Field(..., description="Machine UUID")
|
||||
success: bool = Field(default=True, description="Successful connection")
|
||||
error_message: Optional[str] = Field(None, description="Error message")
|
||||
duration_seconds: Optional[int] = Field(
|
||||
None, description="Connection duration in seconds"
|
||||
)
|
||||
|
||||
|
||||
class ConnectionHistoryResponse(BaseModel):
|
||||
"""Connection history record information response."""
|
||||
|
||||
id: str = Field(..., description="Record UUID")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
machine_id: str = Field(..., description="Machine UUID")
|
||||
connected_at: str = Field(..., description="Connection time (ISO 8601)")
|
||||
disconnected_at: Optional[str] = Field(
|
||||
None, description="Disconnection time (ISO 8601)"
|
||||
)
|
||||
duration_seconds: Optional[int] = Field(None, description="Duration")
|
||||
success: bool = Field(..., description="Successful connection")
|
||||
error_message: Optional[str] = Field(None, description="Error message")
|
||||
client_ip: Optional[str] = Field(None, description="Client IP")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BulkHealthCheckRequest(BaseModel):
|
||||
"""Bulk machine availability check request."""
|
||||
|
||||
machine_ids: List[str] = Field(
|
||||
..., min_items=1, max_items=200, description="List of machine IDs to check"
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=5, ge=1, le=30, description="Timeout for each check in seconds"
|
||||
)
|
||||
check_port: bool = Field(default=True, description="Check connection port")
|
||||
|
||||
|
||||
class BulkHealthCheckResult(BaseModel):
|
||||
"""Single machine check result."""
|
||||
|
||||
machine_id: str = Field(..., description="Machine ID")
|
||||
machine_name: str = Field(..., description="Machine name")
|
||||
hostname: str = Field(..., description="Hostname/IP")
|
||||
status: str = Field(..., description="success, failed, timeout")
|
||||
available: bool = Field(..., description="Machine is available")
|
||||
response_time_ms: Optional[int] = Field(
|
||||
None, description="Response time in milliseconds"
|
||||
)
|
||||
error: Optional[str] = Field(None, description="Error message")
|
||||
checked_at: str = Field(..., description="Check time (ISO 8601)")
|
||||
|
||||
|
||||
class BulkHealthCheckResponse(BaseModel):
|
||||
"""Bulk availability check response."""
|
||||
|
||||
total: int = Field(..., description="Total number of machines")
|
||||
success: int = Field(..., description="Number of successful checks")
|
||||
failed: int = Field(..., description="Number of failed checks")
|
||||
available: int = Field(..., description="Number of available machines")
|
||||
unavailable: int = Field(..., description="Number of unavailable machines")
|
||||
results: List[BulkHealthCheckResult] = Field(..., description="Detailed results")
|
||||
execution_time_ms: int = Field(
|
||||
..., description="Total execution time in milliseconds"
|
||||
)
|
||||
started_at: str = Field(..., description="Start time (ISO 8601)")
|
||||
completed_at: str = Field(..., description="Completion time (ISO 8601)")
|
||||
|
||||
|
||||
class SSHCredentials(BaseModel):
|
||||
"""SSH credentials for machine."""
|
||||
|
||||
username: str = Field(..., min_length=1, max_length=255, description="SSH username")
|
||||
password: str = Field(
|
||||
..., min_length=1, description="SSH password (will be encrypted in transit)"
|
||||
)
|
||||
|
||||
|
||||
class BulkSSHCommandRequest(BaseModel):
|
||||
"""Bulk SSH command execution request."""
|
||||
|
||||
machine_ids: List[str] = Field(
|
||||
..., min_items=1, max_items=100, description="List of machine IDs"
|
||||
)
|
||||
machine_hostnames: Optional[Dict[str, str]] = Field(
|
||||
None,
|
||||
description="Optional hostname/IP for non-saved machines {machine_id: hostname}",
|
||||
)
|
||||
command: str = Field(
|
||||
..., min_length=1, max_length=500, description="SSH command to execute"
|
||||
)
|
||||
credentials_mode: str = Field(
|
||||
..., description="Credentials mode: 'global' (same for all), 'custom' (per-machine)"
|
||||
)
|
||||
global_credentials: Optional[SSHCredentials] = Field(
|
||||
None, description="Shared credentials for all machines (mode 'global')"
|
||||
)
|
||||
machine_credentials: Optional[Dict[str, SSHCredentials]] = Field(
|
||||
None, description="Individual credentials (mode 'custom')"
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=30, ge=5, le=300, description="Command execution timeout (seconds)"
|
||||
)
|
||||
|
||||
|
||||
class BulkSSHCommandResult(BaseModel):
|
||||
"""Single machine SSH command execution result."""
|
||||
|
||||
machine_id: str = Field(..., description="Machine ID")
|
||||
machine_name: str = Field(..., description="Machine name")
|
||||
hostname: str = Field(..., description="Hostname/IP")
|
||||
status: str = Field(..., description="success, failed, timeout, no_credentials")
|
||||
exit_code: Optional[int] = Field(None, description="Command exit code")
|
||||
stdout: Optional[str] = Field(None, description="Stdout output")
|
||||
stderr: Optional[str] = Field(None, description="Stderr output")
|
||||
error: Optional[str] = Field(None, description="Error message")
|
||||
execution_time_ms: Optional[int] = Field(
|
||||
None, description="Execution time in milliseconds"
|
||||
)
|
||||
executed_at: str = Field(..., description="Execution time (ISO 8601)")
|
||||
|
||||
|
||||
class BulkSSHCommandResponse(BaseModel):
|
||||
"""Bulk SSH command execution response."""
|
||||
|
||||
total: int = Field(..., description="Total number of machines")
|
||||
success: int = Field(..., description="Number of successful executions")
|
||||
failed: int = Field(..., description="Number of failed executions")
|
||||
results: List[BulkSSHCommandResult] = Field(..., description="Detailed results")
|
||||
execution_time_ms: int = Field(
|
||||
..., description="Total execution time in milliseconds"
|
||||
)
|
||||
command: str = Field(..., description="Executed command")
|
||||
started_at: str = Field(..., description="Start time (ISO 8601)")
|
||||
completed_at: str = Field(..., description="Completion time (ISO 8601)")
|
||||
292
guacamole_test_11_26/api/core/permissions.py
Executable file
292
guacamole_test_11_26/api/core/permissions.py
Executable file
@ -0,0 +1,292 @@
|
||||
"""Permission and role system for Remote Access API."""
|
||||
|
||||
from typing import Dict, FrozenSet, List, Optional, Tuple
|
||||
|
||||
import structlog
|
||||
|
||||
from .models import UserRole
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class PermissionChecker:
|
||||
"""User permission checker class."""
|
||||
|
||||
ROLE_MAPPING: Dict[str, UserRole] = {
|
||||
"ADMINISTER": UserRole.SUPER_ADMIN,
|
||||
"CREATE_USER": UserRole.ADMIN,
|
||||
"CREATE_CONNECTION": UserRole.USER,
|
||||
}
|
||||
|
||||
ROLE_PERMISSIONS: Dict[UserRole, FrozenSet[str]] = {
|
||||
UserRole.GUEST: frozenset({
|
||||
"view_own_connections",
|
||||
"view_own_profile"
|
||||
}),
|
||||
UserRole.USER: frozenset({
|
||||
"view_own_connections",
|
||||
"view_own_profile",
|
||||
"create_connections",
|
||||
"delete_own_connections",
|
||||
}),
|
||||
UserRole.ADMIN: frozenset({
|
||||
"view_own_connections",
|
||||
"view_own_profile",
|
||||
"create_connections",
|
||||
"delete_own_connections",
|
||||
"view_all_connections",
|
||||
"delete_any_connection",
|
||||
"view_system_stats",
|
||||
"view_system_metrics",
|
||||
}),
|
||||
UserRole.SUPER_ADMIN: frozenset({
|
||||
"view_own_connections",
|
||||
"view_own_profile",
|
||||
"create_connections",
|
||||
"delete_own_connections",
|
||||
"view_all_connections",
|
||||
"delete_any_connection",
|
||||
"view_system_stats",
|
||||
"view_system_metrics",
|
||||
"reset_system_stats",
|
||||
"manage_users",
|
||||
"view_system_logs",
|
||||
"change_system_config",
|
||||
}),
|
||||
}
|
||||
|
||||
ENDPOINT_PERMISSIONS = {
|
||||
"POST /connect": "create_connections",
|
||||
"GET /connections": "view_own_connections",
|
||||
"DELETE /connections": "delete_own_connections",
|
||||
"GET /stats": "view_system_stats",
|
||||
"GET /metrics": "view_system_metrics",
|
||||
"POST /stats/reset": "reset_system_stats",
|
||||
"GET /auth/profile": "view_own_profile",
|
||||
"GET /auth/permissions": "view_own_profile",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def determine_role_from_permissions(cls, guacamole_permissions: List[str]) -> UserRole:
|
||||
"""
|
||||
Determine user role based on Guacamole system permissions.
|
||||
|
||||
Args:
|
||||
guacamole_permissions: List of system permissions from Guacamole.
|
||||
|
||||
Returns:
|
||||
User role.
|
||||
"""
|
||||
for permission, role in cls.ROLE_MAPPING.items():
|
||||
if permission in guacamole_permissions:
|
||||
logger.debug(
|
||||
"Role determined from permission",
|
||||
permission=permission,
|
||||
role=role.value,
|
||||
all_permissions=guacamole_permissions,
|
||||
)
|
||||
return role
|
||||
|
||||
logger.debug(
|
||||
"No system permissions found, assigning GUEST role",
|
||||
permissions=guacamole_permissions,
|
||||
)
|
||||
return UserRole.GUEST
|
||||
|
||||
@classmethod
|
||||
def get_role_permissions(cls, role: UserRole) -> FrozenSet[str]:
|
||||
"""
|
||||
Get all permissions for a role.
|
||||
|
||||
Args:
|
||||
role: User role.
|
||||
|
||||
Returns:
|
||||
Frozen set of permissions.
|
||||
"""
|
||||
return cls.ROLE_PERMISSIONS.get(role, frozenset())
|
||||
|
||||
@classmethod
|
||||
def check_permission(cls, user_role: UserRole, permission: str) -> bool:
|
||||
"""
|
||||
Check if role has specific permission.
|
||||
|
||||
Args:
|
||||
user_role: User role.
|
||||
permission: Permission to check.
|
||||
|
||||
Returns:
|
||||
True if permission exists, False otherwise.
|
||||
"""
|
||||
role_permissions = cls.get_role_permissions(user_role)
|
||||
has_permission = permission in role_permissions
|
||||
|
||||
logger.debug(
|
||||
"Permission check",
|
||||
role=user_role.value,
|
||||
permission=permission,
|
||||
allowed=has_permission,
|
||||
)
|
||||
|
||||
return has_permission
|
||||
|
||||
@classmethod
|
||||
def check_endpoint_access(
|
||||
cls, user_role: UserRole, method: str, path: str
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check endpoint access.
|
||||
|
||||
Args:
|
||||
user_role: User role.
|
||||
method: HTTP method (GET, POST, DELETE, etc.).
|
||||
path: Endpoint path.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, reason: Optional[str]).
|
||||
"""
|
||||
endpoint_key = f"{method} {path}"
|
||||
|
||||
required_permission = cls.ENDPOINT_PERMISSIONS.get(endpoint_key)
|
||||
|
||||
if not required_permission:
|
||||
for pattern, permission in cls.ENDPOINT_PERMISSIONS.items():
|
||||
if cls._match_endpoint_pattern(endpoint_key, pattern):
|
||||
required_permission = permission
|
||||
break
|
||||
|
||||
if not required_permission:
|
||||
return True, None
|
||||
|
||||
has_permission = cls.check_permission(user_role, required_permission)
|
||||
|
||||
if not has_permission:
|
||||
reason = (
|
||||
f"Required permission '{required_permission}' "
|
||||
f"not granted to role '{user_role.value}'"
|
||||
)
|
||||
logger.info(
|
||||
"Endpoint access denied",
|
||||
role=user_role.value,
|
||||
endpoint=endpoint_key,
|
||||
required_permission=required_permission,
|
||||
reason=reason,
|
||||
)
|
||||
return False, reason
|
||||
|
||||
logger.debug(
|
||||
"Endpoint access granted",
|
||||
role=user_role.value,
|
||||
endpoint=endpoint_key,
|
||||
required_permission=required_permission,
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def _match_endpoint_pattern(cls, endpoint: str, pattern: str) -> bool:
|
||||
"""
|
||||
Check if endpoint matches pattern.
|
||||
|
||||
Args:
|
||||
endpoint: Endpoint to check (e.g., "DELETE /connections/123").
|
||||
pattern: Pattern (e.g., "DELETE /connections").
|
||||
|
||||
Returns:
|
||||
True if matches.
|
||||
"""
|
||||
if pattern.endswith("/connections"):
|
||||
base_pattern = pattern.replace("/connections", "/connections/")
|
||||
return endpoint.startswith(base_pattern)
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_connection_ownership(
|
||||
cls, user_role: UserRole, username: str, connection_owner: str
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Check connection management rights.
|
||||
|
||||
Args:
|
||||
user_role: User role.
|
||||
username: Username.
|
||||
connection_owner: Connection owner.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, reason: Optional[str]).
|
||||
"""
|
||||
if user_role in (UserRole.ADMIN, UserRole.SUPER_ADMIN):
|
||||
return True, None
|
||||
|
||||
if username == connection_owner:
|
||||
return True, None
|
||||
|
||||
reason = (
|
||||
f"User '{username}' cannot manage connection owned by '{connection_owner}'"
|
||||
)
|
||||
logger.info(
|
||||
"Connection ownership check failed",
|
||||
user=username,
|
||||
owner=connection_owner,
|
||||
role=user_role.value,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
return False, reason
|
||||
|
||||
@classmethod
|
||||
def can_view_all_connections(cls, user_role: UserRole) -> bool:
|
||||
"""
|
||||
Check if user can view all connections.
|
||||
|
||||
Args:
|
||||
user_role: User role.
|
||||
|
||||
Returns:
|
||||
True if can view all connections.
|
||||
"""
|
||||
return cls.check_permission(user_role, "view_all_connections")
|
||||
|
||||
@classmethod
|
||||
def can_delete_any_connection(cls, user_role: UserRole) -> bool:
|
||||
"""
|
||||
Check if user can delete any connection.
|
||||
|
||||
Args:
|
||||
user_role: User role.
|
||||
|
||||
Returns:
|
||||
True if can delete any connection.
|
||||
"""
|
||||
return cls.check_permission(user_role, "delete_any_connection")
|
||||
|
||||
@classmethod
|
||||
def get_user_permissions_list(cls, user_role: UserRole) -> List[str]:
|
||||
"""
|
||||
Get sorted list of user permissions for API response.
|
||||
|
||||
Args:
|
||||
user_role: User role.
|
||||
|
||||
Returns:
|
||||
Sorted list of permissions.
|
||||
"""
|
||||
permissions = cls.get_role_permissions(user_role)
|
||||
return sorted(permissions)
|
||||
|
||||
@classmethod
|
||||
def validate_role_hierarchy(
|
||||
cls, current_user_role: UserRole, target_user_role: UserRole
|
||||
) -> bool:
|
||||
"""
|
||||
Validate role hierarchy for user management.
|
||||
|
||||
Args:
|
||||
current_user_role: Current user role.
|
||||
target_user_role: Target user role.
|
||||
|
||||
Returns:
|
||||
True if current user can manage target user.
|
||||
"""
|
||||
return current_user_role == UserRole.SUPER_ADMIN
|
||||
259
guacamole_test_11_26/api/core/pki_certificate_verifier.py
Executable file
259
guacamole_test_11_26/api/core/pki_certificate_verifier.py
Executable file
@ -0,0 +1,259 @@
|
||||
"""
|
||||
Module for PKI/CA certificate handling for server key signature verification.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PKICertificateVerifier:
|
||||
"""PKI/CA certificate verifier for server key signature verification."""
|
||||
|
||||
def __init__(self, ca_cert_path: str, crl_urls: Optional[List[str]] = None) -> None:
|
||||
"""Initialize PKI certificate verifier.
|
||||
|
||||
Args:
|
||||
ca_cert_path: Path to CA certificate file.
|
||||
crl_urls: List of CRL URLs (optional).
|
||||
"""
|
||||
self.ca_cert_path = ca_cert_path
|
||||
self.crl_urls = crl_urls or []
|
||||
self.ca_cert = self._load_ca_certificate()
|
||||
self.cert_store = self._build_cert_store()
|
||||
|
||||
def _load_ca_certificate(self) -> x509.Certificate:
|
||||
"""Load CA certificate.
|
||||
|
||||
Returns:
|
||||
Loaded CA certificate.
|
||||
|
||||
Raises:
|
||||
Exception: If certificate cannot be loaded.
|
||||
"""
|
||||
try:
|
||||
with open(self.ca_cert_path, "rb") as f:
|
||||
ca_cert_data = f.read()
|
||||
return x509.load_pem_x509_certificate(ca_cert_data)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load CA certificate", extra={"error": str(e)})
|
||||
raise
|
||||
|
||||
def _build_cert_store(self) -> x509.CertificateStore:
|
||||
"""Build certificate store.
|
||||
|
||||
Returns:
|
||||
Certificate store with CA certificate.
|
||||
"""
|
||||
store = x509.CertificateStore()
|
||||
store.add_cert(self.ca_cert)
|
||||
return store
|
||||
|
||||
def verify_server_certificate(self, server_cert_pem: bytes) -> bool:
|
||||
"""
|
||||
Verify server certificate through PKI/CA.
|
||||
|
||||
Args:
|
||||
server_cert_pem: PEM-encoded server certificate.
|
||||
|
||||
Returns:
|
||||
True if certificate is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
server_cert = x509.load_pem_x509_certificate(server_cert_pem)
|
||||
|
||||
if not self._verify_certificate_chain(server_cert):
|
||||
logger.warning("Certificate chain verification failed")
|
||||
return False
|
||||
|
||||
if not self._check_certificate_revocation(server_cert):
|
||||
logger.warning("Certificate is revoked")
|
||||
return False
|
||||
|
||||
if not self._check_certificate_validity(server_cert):
|
||||
logger.warning("Certificate is expired or not yet valid")
|
||||
return False
|
||||
|
||||
logger.info("Server certificate verified successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Certificate verification error", extra={"error": str(e)})
|
||||
return False
|
||||
|
||||
def _verify_certificate_chain(self, server_cert: x509.Certificate) -> bool:
|
||||
"""Verify certificate chain.
|
||||
|
||||
Args:
|
||||
server_cert: Server certificate to verify.
|
||||
|
||||
Returns:
|
||||
True if chain is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
ca_public_key = self.ca_cert.public_key()
|
||||
ca_public_key.verify(
|
||||
server_cert.signature,
|
||||
server_cert.tbs_certificate_bytes,
|
||||
server_cert.signature_algorithm_oid,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Certificate chain verification failed", extra={"error": str(e)}
|
||||
)
|
||||
return False
|
||||
|
||||
def _check_certificate_revocation(self, server_cert: x509.Certificate) -> bool:
|
||||
"""Check certificate revocation via CRL.
|
||||
|
||||
Args:
|
||||
server_cert: Server certificate to check.
|
||||
|
||||
Returns:
|
||||
True if certificate is not revoked, False otherwise.
|
||||
"""
|
||||
try:
|
||||
crl_dps = server_cert.extensions.get_extension_for_oid(
|
||||
x509.ExtensionOID.CRL_DISTRIBUTION_POINTS
|
||||
).value
|
||||
|
||||
for crl_dp in crl_dps:
|
||||
for crl_url in crl_dp.full_name:
|
||||
if self._check_crl(server_cert, crl_url.value):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("CRL check failed", extra={"error": str(e)})
|
||||
return True
|
||||
|
||||
def _check_crl(self, server_cert: x509.Certificate, crl_url: str) -> bool:
|
||||
"""Check specific CRL for certificate revocation.
|
||||
|
||||
Args:
|
||||
server_cert: Server certificate to check.
|
||||
crl_url: CRL URL.
|
||||
|
||||
Returns:
|
||||
True if certificate is revoked, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(crl_url, timeout=10)
|
||||
if response.status_code == 200:
|
||||
crl_data = response.content
|
||||
crl = x509.load_der_x509_crl(crl_data)
|
||||
return server_cert.serial_number in [revoked.serial_number for revoked in crl]
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Failed to check CRL", extra={"crl_url": crl_url, "error": str(e)})
|
||||
return False
|
||||
|
||||
def _check_certificate_validity(self, server_cert: x509.Certificate) -> bool:
|
||||
"""Check certificate validity period.
|
||||
|
||||
Args:
|
||||
server_cert: Server certificate to check.
|
||||
|
||||
Returns:
|
||||
True if certificate is valid, False otherwise.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
return server_cert.not_valid_before <= now <= server_cert.not_valid_after
|
||||
|
||||
def extract_public_key_from_certificate(
|
||||
self, server_cert_pem: bytes
|
||||
) -> ed25519.Ed25519PublicKey:
|
||||
"""Extract public key from certificate.
|
||||
|
||||
Args:
|
||||
server_cert_pem: PEM-encoded server certificate.
|
||||
|
||||
Returns:
|
||||
Extracted Ed25519 public key.
|
||||
"""
|
||||
server_cert = x509.load_pem_x509_certificate(server_cert_pem)
|
||||
public_key = server_cert.public_key()
|
||||
if not isinstance(public_key, ed25519.Ed25519PublicKey):
|
||||
raise ValueError("Certificate does not contain Ed25519 public key")
|
||||
return public_key
|
||||
|
||||
class ServerCertificateManager:
|
||||
"""Server certificate manager."""
|
||||
|
||||
def __init__(self, pki_verifier: PKICertificateVerifier) -> None:
|
||||
"""Initialize server certificate manager.
|
||||
|
||||
Args:
|
||||
pki_verifier: PKI certificate verifier instance.
|
||||
"""
|
||||
self.pki_verifier = pki_verifier
|
||||
self.server_certificates: Dict[str, bytes] = {}
|
||||
|
||||
def get_server_certificate(self, server_hostname: str) -> Optional[bytes]:
|
||||
"""Get server certificate via TLS handshake.
|
||||
|
||||
Args:
|
||||
server_hostname: Server hostname.
|
||||
|
||||
Returns:
|
||||
PEM-encoded server certificate or None if failed.
|
||||
"""
|
||||
try:
|
||||
context = ssl.create_default_context()
|
||||
context.check_hostname = True
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
|
||||
with ssl.create_connection((server_hostname, 443)) as sock:
|
||||
with context.wrap_socket(sock, server_hostname=server_hostname) as ssock:
|
||||
cert_der = ssock.getpeercert_chain()[0]
|
||||
cert_pem = ssl.DER_cert_to_PEM_cert(cert_der)
|
||||
cert_bytes = cert_pem.encode()
|
||||
|
||||
if self.pki_verifier.verify_server_certificate(cert_bytes):
|
||||
self.server_certificates[server_hostname] = cert_bytes
|
||||
return cert_bytes
|
||||
else:
|
||||
logger.error(
|
||||
"Server certificate verification failed",
|
||||
extra={"server_hostname": server_hostname},
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get server certificate",
|
||||
extra={"server_hostname": server_hostname, "error": str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
def get_trusted_public_key(
|
||||
self, server_hostname: str
|
||||
) -> Optional[ed25519.Ed25519PublicKey]:
|
||||
"""Get trusted public key from server certificate.
|
||||
|
||||
Args:
|
||||
server_hostname: Server hostname.
|
||||
|
||||
Returns:
|
||||
Ed25519 public key or None if failed.
|
||||
"""
|
||||
cert_pem = self.get_server_certificate(server_hostname)
|
||||
if cert_pem:
|
||||
return self.pki_verifier.extract_public_key_from_certificate(cert_pem)
|
||||
return None
|
||||
|
||||
|
||||
pki_verifier = PKICertificateVerifier(
|
||||
ca_cert_path="/etc/ssl/certs/ca-certificates.crt",
|
||||
crl_urls=["http://crl.example.com/crl.pem"],
|
||||
)
|
||||
|
||||
certificate_manager = ServerCertificateManager(pki_verifier)
|
||||
342
guacamole_test_11_26/api/core/rate_limiter.py
Executable file
342
guacamole_test_11_26/api/core/rate_limiter.py
Executable file
@ -0,0 +1,342 @@
|
||||
"""Redis-based thread-safe rate limiting."""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import redis
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Redis connection constants
|
||||
REDIS_DEFAULT_HOST = "localhost"
|
||||
REDIS_DEFAULT_PORT = "6379"
|
||||
REDIS_DEFAULT_DB = "0"
|
||||
REDIS_SOCKET_TIMEOUT = 5
|
||||
|
||||
# Rate limiting constants
|
||||
DEFAULT_RATE_LIMIT_REQUESTS = 10
|
||||
DEFAULT_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
FAILED_LOGIN_RETENTION_SECONDS = 3600 # 1 hour
|
||||
LOGIN_RATE_LIMIT_REQUESTS = 5
|
||||
LOGIN_RATE_LIMIT_WINDOW_SECONDS = 900 # 15 minutes
|
||||
DEFAULT_FAILED_LOGIN_WINDOW_MINUTES = 60
|
||||
SECONDS_PER_MINUTE = 60
|
||||
|
||||
# Redis key prefixes
|
||||
RATE_LIMIT_KEY_PREFIX = "rate_limit:"
|
||||
FAILED_LOGINS_IP_PREFIX = "failed_logins:ip:"
|
||||
FAILED_LOGINS_USER_PREFIX = "failed_logins:user:"
|
||||
LOGIN_LIMIT_PREFIX = "login_limit:"
|
||||
|
||||
# Rate limit headers
|
||||
HEADER_RATE_LIMIT = "X-RateLimit-Limit"
|
||||
HEADER_RATE_LIMIT_WINDOW = "X-RateLimit-Window"
|
||||
HEADER_RATE_LIMIT_USED = "X-RateLimit-Used"
|
||||
HEADER_RATE_LIMIT_REMAINING = "X-RateLimit-Remaining"
|
||||
HEADER_RATE_LIMIT_RESET = "X-RateLimit-Reset"
|
||||
HEADER_RATE_LIMIT_STATUS = "X-RateLimit-Status"
|
||||
|
||||
|
||||
class RedisRateLimiter:
|
||||
"""Thread-safe Redis-based rate limiter with sliding window algorithm."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize Redis rate limiter."""
|
||||
self.redis_client = redis.Redis(
|
||||
host=os.getenv("REDIS_HOST", REDIS_DEFAULT_HOST),
|
||||
port=int(os.getenv("REDIS_PORT", REDIS_DEFAULT_PORT)),
|
||||
password=os.getenv("REDIS_PASSWORD"),
|
||||
db=int(os.getenv("REDIS_DB", REDIS_DEFAULT_DB)),
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=REDIS_SOCKET_TIMEOUT,
|
||||
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
|
||||
try:
|
||||
self.redis_client.ping()
|
||||
logger.info("Rate limiter Redis connection established")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error("Failed to connect to Redis for rate limiting", error=str(e))
|
||||
raise
|
||||
|
||||
def check_rate_limit(
|
||||
self,
|
||||
client_ip: str,
|
||||
requests_limit: int = DEFAULT_RATE_LIMIT_REQUESTS,
|
||||
window_seconds: int = DEFAULT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
) -> Tuple[bool, Dict[str, int]]:
|
||||
"""
|
||||
Check rate limit using sliding window algorithm.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
requests_limit: Maximum number of requests.
|
||||
window_seconds: Time window in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, headers: Dict[str, int]).
|
||||
"""
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
window_start = current_time - window_seconds
|
||||
|
||||
key = f"{RATE_LIMIT_KEY_PREFIX}{client_ip}"
|
||||
|
||||
lua_script = """
|
||||
local key = KEYS[1]
|
||||
local window_start = tonumber(ARGV[1])
|
||||
local current_time = tonumber(ARGV[2])
|
||||
local requests_limit = tonumber(ARGV[3])
|
||||
local window_seconds = tonumber(ARGV[4])
|
||||
|
||||
-- Remove old entries (outside window)
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
|
||||
|
||||
-- Count current requests
|
||||
local current_requests = redis.call('ZCARD', key)
|
||||
|
||||
-- Check limit
|
||||
if current_requests >= requests_limit then
|
||||
-- Return blocking information
|
||||
local oldest_request = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
|
||||
local reset_time = oldest_request[2] + window_seconds
|
||||
return {0, current_requests, reset_time}
|
||||
else
|
||||
-- Add current request
|
||||
redis.call('ZADD', key, current_time, current_time)
|
||||
redis.call('EXPIRE', key, window_seconds)
|
||||
|
||||
-- Count updated requests
|
||||
local new_count = redis.call('ZCARD', key)
|
||||
return {1, new_count, 0}
|
||||
end
|
||||
"""
|
||||
|
||||
result = self.redis_client.eval(
|
||||
lua_script, 1, key, window_start, current_time, requests_limit, window_seconds
|
||||
)
|
||||
|
||||
allowed = bool(result[0])
|
||||
current_requests = result[1]
|
||||
reset_time = result[2] if result[2] > 0 else 0
|
||||
|
||||
headers = {
|
||||
HEADER_RATE_LIMIT: requests_limit,
|
||||
HEADER_RATE_LIMIT_WINDOW: window_seconds,
|
||||
HEADER_RATE_LIMIT_USED: current_requests,
|
||||
HEADER_RATE_LIMIT_REMAINING: max(0, requests_limit - current_requests),
|
||||
}
|
||||
|
||||
if reset_time > 0:
|
||||
headers[HEADER_RATE_LIMIT_RESET] = reset_time
|
||||
|
||||
if allowed:
|
||||
logger.debug(
|
||||
"Rate limit check passed",
|
||||
client_ip=client_ip,
|
||||
current_requests=current_requests,
|
||||
limit=requests_limit,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limit exceeded",
|
||||
client_ip=client_ip,
|
||||
current_requests=current_requests,
|
||||
limit=requests_limit,
|
||||
reset_time=reset_time,
|
||||
)
|
||||
|
||||
return allowed, headers
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error(
|
||||
"Redis error during rate limit check", client_ip=client_ip, error=str(e)
|
||||
)
|
||||
return True, {
|
||||
HEADER_RATE_LIMIT: requests_limit,
|
||||
HEADER_RATE_LIMIT_WINDOW: window_seconds,
|
||||
HEADER_RATE_LIMIT_USED: 0,
|
||||
HEADER_RATE_LIMIT_REMAINING: requests_limit,
|
||||
HEADER_RATE_LIMIT_STATUS: "redis_error",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error during rate limit check", client_ip=client_ip, error=str(e)
|
||||
)
|
||||
return True, {
|
||||
HEADER_RATE_LIMIT: requests_limit,
|
||||
HEADER_RATE_LIMIT_WINDOW: window_seconds,
|
||||
HEADER_RATE_LIMIT_USED: 0,
|
||||
HEADER_RATE_LIMIT_REMAINING: requests_limit,
|
||||
HEADER_RATE_LIMIT_STATUS: "error",
|
||||
}
|
||||
|
||||
def check_login_rate_limit(
|
||||
self, client_ip: str, username: Optional[str] = None
|
||||
) -> Tuple[bool, Dict[str, int]]:
|
||||
"""
|
||||
Special rate limit for login endpoint.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username (optional).
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, headers: Dict[str, int]).
|
||||
"""
|
||||
allowed, headers = self.check_rate_limit(
|
||||
client_ip, LOGIN_RATE_LIMIT_REQUESTS, LOGIN_RATE_LIMIT_WINDOW_SECONDS
|
||||
)
|
||||
|
||||
if username and allowed:
|
||||
user_key = f"{LOGIN_LIMIT_PREFIX}{username}"
|
||||
user_allowed, user_headers = self.check_rate_limit(
|
||||
user_key, LOGIN_RATE_LIMIT_REQUESTS, LOGIN_RATE_LIMIT_WINDOW_SECONDS
|
||||
)
|
||||
|
||||
if not user_allowed:
|
||||
logger.warning(
|
||||
"Login rate limit exceeded for user",
|
||||
username=username,
|
||||
client_ip=client_ip,
|
||||
)
|
||||
return False, user_headers
|
||||
|
||||
return allowed, headers
|
||||
|
||||
def record_failed_login(self, client_ip: str, username: str) -> None:
|
||||
"""
|
||||
Record failed login attempt for brute-force attack tracking.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username.
|
||||
"""
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
|
||||
ip_key = f"{FAILED_LOGINS_IP_PREFIX}{client_ip}"
|
||||
self.redis_client.zadd(ip_key, {current_time: current_time})
|
||||
self.redis_client.expire(ip_key, FAILED_LOGIN_RETENTION_SECONDS)
|
||||
|
||||
user_key = f"{FAILED_LOGINS_USER_PREFIX}{username}"
|
||||
self.redis_client.zadd(user_key, {current_time: current_time})
|
||||
self.redis_client.expire(user_key, FAILED_LOGIN_RETENTION_SECONDS)
|
||||
|
||||
logger.debug("Failed login recorded", client_ip=client_ip, username=username)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to record failed login attempt",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def get_failed_login_count(
|
||||
self,
|
||||
client_ip: str,
|
||||
username: Optional[str] = None,
|
||||
window_minutes: int = DEFAULT_FAILED_LOGIN_WINDOW_MINUTES,
|
||||
) -> Dict[str, int]:
|
||||
"""
|
||||
Get count of failed login attempts.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username (optional).
|
||||
window_minutes: Time window in minutes.
|
||||
|
||||
Returns:
|
||||
Dictionary with failed login counts.
|
||||
"""
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
window_start = current_time - (window_minutes * SECONDS_PER_MINUTE)
|
||||
|
||||
result = {"ip_failed_count": 0, "user_failed_count": 0}
|
||||
|
||||
ip_key = f"{FAILED_LOGINS_IP_PREFIX}{client_ip}"
|
||||
ip_count = self.redis_client.zcount(ip_key, window_start, current_time)
|
||||
result["ip_failed_count"] = ip_count
|
||||
|
||||
if username:
|
||||
user_key = f"{FAILED_LOGINS_USER_PREFIX}{username}"
|
||||
user_count = self.redis_client.zcount(user_key, window_start, current_time)
|
||||
result["user_failed_count"] = user_count
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get failed login count",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
return {"ip_failed_count": 0, "user_failed_count": 0}
|
||||
|
||||
def clear_failed_logins(
|
||||
self, client_ip: str, username: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Clear failed login attempt records.
|
||||
|
||||
Args:
|
||||
client_ip: Client IP address.
|
||||
username: Username (optional).
|
||||
"""
|
||||
try:
|
||||
ip_key = f"{FAILED_LOGINS_IP_PREFIX}{client_ip}"
|
||||
self.redis_client.delete(ip_key)
|
||||
|
||||
if username:
|
||||
user_key = f"{FAILED_LOGINS_USER_PREFIX}{username}"
|
||||
self.redis_client.delete(user_key)
|
||||
|
||||
logger.debug(
|
||||
"Failed login records cleared", client_ip=client_ip, username=username
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to clear failed login records",
|
||||
client_ip=client_ip,
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def get_rate_limit_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get rate limiting statistics.
|
||||
|
||||
Returns:
|
||||
Rate limiting statistics dictionary.
|
||||
"""
|
||||
try:
|
||||
rate_limit_keys = self.redis_client.keys(f"{RATE_LIMIT_KEY_PREFIX}*")
|
||||
failed_login_keys = self.redis_client.keys(f"{FAILED_LOGINS_IP_PREFIX}*")
|
||||
|
||||
return {
|
||||
"active_rate_limits": len(rate_limit_keys),
|
||||
"failed_login_trackers": len(failed_login_keys),
|
||||
"redis_memory_usage": (
|
||||
self.redis_client.memory_usage(f"{RATE_LIMIT_KEY_PREFIX}*")
|
||||
+ self.redis_client.memory_usage(f"{FAILED_LOGINS_IP_PREFIX}*")
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get rate limit stats", error=str(e))
|
||||
return {
|
||||
"active_rate_limits": 0,
|
||||
"failed_login_trackers": 0,
|
||||
"redis_memory_usage": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
redis_rate_limiter = RedisRateLimiter()
|
||||
266
guacamole_test_11_26/api/core/redis_storage.py
Executable file
266
guacamole_test_11_26/api/core/redis_storage.py
Executable file
@ -0,0 +1,266 @@
|
||||
"""
|
||||
Redis Storage Helper for storing shared state in cluster.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class RedisConnectionStorage:
|
||||
"""
|
||||
Redis storage for active connections.
|
||||
|
||||
Supports cluster operation with automatic TTL.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize Redis connection storage."""
|
||||
self._redis_client = redis.Redis(
|
||||
host=os.getenv("REDIS_HOST", "redis"),
|
||||
port=int(os.getenv("REDIS_PORT", "6379")),
|
||||
password=os.getenv("REDIS_PASSWORD"),
|
||||
db=0,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
try:
|
||||
self._redis_client.ping()
|
||||
logger.info("Redis Connection Storage initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to Redis for connections", error=str(e))
|
||||
raise RuntimeError(f"Redis connection failed: {e}")
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_data: Dict[str, Any],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add connection to Redis.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID.
|
||||
connection_data: Connection data dictionary.
|
||||
ttl_seconds: TTL in seconds (None = no automatic expiration).
|
||||
"""
|
||||
try:
|
||||
redis_key = f"connection:active:{connection_id}"
|
||||
if ttl_seconds is not None:
|
||||
self._redis_client.setex(
|
||||
redis_key, ttl_seconds, json.dumps(connection_data)
|
||||
)
|
||||
logger.debug(
|
||||
"Connection added to Redis with TTL",
|
||||
connection_id=connection_id,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
else:
|
||||
self._redis_client.set(redis_key, json.dumps(connection_data))
|
||||
logger.debug(
|
||||
"Connection added to Redis without TTL",
|
||||
connection_id=connection_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to add connection to Redis",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
def get_connection(self, connection_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get connection from Redis.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID.
|
||||
|
||||
Returns:
|
||||
Connection data dictionary or None if not found.
|
||||
"""
|
||||
try:
|
||||
redis_key = f"connection:active:{connection_id}"
|
||||
conn_json = self._redis_client.get(redis_key)
|
||||
|
||||
if not conn_json:
|
||||
return None
|
||||
|
||||
return json.loads(conn_json)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get connection from Redis",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
return None
|
||||
|
||||
def update_connection(
|
||||
self, connection_id: str, update_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Update connection data.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID.
|
||||
update_data: Data to update (will be merged with existing data).
|
||||
"""
|
||||
try:
|
||||
redis_key = f"connection:active:{connection_id}"
|
||||
|
||||
conn_json = self._redis_client.get(redis_key)
|
||||
if not conn_json:
|
||||
logger.warning(
|
||||
"Cannot update non-existent connection",
|
||||
connection_id=connection_id,
|
||||
)
|
||||
return
|
||||
|
||||
conn_data = json.loads(conn_json)
|
||||
conn_data.update(update_data)
|
||||
|
||||
ttl = self._redis_client.ttl(redis_key)
|
||||
if ttl > 0:
|
||||
self._redis_client.setex(redis_key, ttl, json.dumps(conn_data))
|
||||
logger.debug("Connection updated in Redis", connection_id=connection_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update connection in Redis",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def delete_connection(self, connection_id: str) -> bool:
|
||||
"""
|
||||
Delete connection from Redis.
|
||||
|
||||
Args:
|
||||
connection_id: Connection ID.
|
||||
|
||||
Returns:
|
||||
True if connection was deleted, False otherwise.
|
||||
"""
|
||||
try:
|
||||
redis_key = f"connection:active:{connection_id}"
|
||||
result = self._redis_client.delete(redis_key)
|
||||
|
||||
if result > 0:
|
||||
logger.debug("Connection deleted from Redis", connection_id=connection_id)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to delete connection from Redis",
|
||||
connection_id=connection_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
def get_all_connections(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Get all active connections.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping connection_id to connection_data.
|
||||
"""
|
||||
try:
|
||||
pattern = "connection:active:*"
|
||||
keys = list(self._redis_client.scan_iter(match=pattern, count=100))
|
||||
|
||||
connections = {}
|
||||
for key in keys:
|
||||
try:
|
||||
conn_id = key.replace("connection:active:", "")
|
||||
conn_json = self._redis_client.get(key)
|
||||
if conn_json:
|
||||
connections[conn_id] = json.loads(conn_json)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return connections
|
||||
except Exception as e:
|
||||
logger.error("Failed to get all connections from Redis", error=str(e))
|
||||
return {}
|
||||
|
||||
def get_user_connections(self, username: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all user connections.
|
||||
|
||||
Args:
|
||||
username: Username.
|
||||
|
||||
Returns:
|
||||
List of user connections.
|
||||
"""
|
||||
try:
|
||||
all_connections = self.get_all_connections()
|
||||
user_connections = [
|
||||
conn_data
|
||||
for conn_data in all_connections.values()
|
||||
if conn_data.get("owner_username") == username
|
||||
]
|
||||
|
||||
return user_connections
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get user connections from Redis",
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
return []
|
||||
|
||||
def cleanup_expired_connections(self) -> int:
|
||||
"""
|
||||
Cleanup expired connections.
|
||||
|
||||
Returns:
|
||||
Number of removed connections.
|
||||
"""
|
||||
try:
|
||||
pattern = "connection:active:*"
|
||||
keys = list(self._redis_client.scan_iter(match=pattern, count=100))
|
||||
|
||||
cleaned_count = 0
|
||||
for key in keys:
|
||||
ttl = self._redis_client.ttl(key)
|
||||
if ttl == -2:
|
||||
cleaned_count += 1
|
||||
elif ttl == -1:
|
||||
self._redis_client.delete(key)
|
||||
cleaned_count += 1
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(
|
||||
"Connections cleanup completed", cleaned_count=cleaned_count
|
||||
)
|
||||
|
||||
return cleaned_count
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired connections", error=str(e))
|
||||
return 0
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get connection statistics.
|
||||
|
||||
Returns:
|
||||
Connection statistics dictionary.
|
||||
"""
|
||||
try:
|
||||
pattern = "connection:active:*"
|
||||
keys = list(self._redis_client.scan_iter(match=pattern, count=100))
|
||||
|
||||
return {"total_connections": len(keys), "storage": "Redis"}
|
||||
except Exception as e:
|
||||
logger.error("Failed to get connection stats", error=str(e))
|
||||
return {"error": str(e), "storage": "Redis"}
|
||||
|
||||
|
||||
redis_connection_storage = RedisConnectionStorage()
|
||||
|
||||
172
guacamole_test_11_26/api/core/replay_protection.py
Executable file
172
guacamole_test_11_26/api/core/replay_protection.py
Executable file
@ -0,0 +1,172 @@
|
||||
"""
|
||||
Module for nonce management and replay attack prevention.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
import redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
NONCE_TTL_SECONDS = 300 # 5 minutes TTL for nonce
|
||||
TIMESTAMP_TOLERANCE_SECONDS = 30 # 30 seconds tolerance for timestamp
|
||||
|
||||
class NonceManager:
|
||||
"""Nonce manager for replay attack prevention."""
|
||||
|
||||
def __init__(self, redis_client: redis.Redis) -> None:
|
||||
"""Initialize nonce manager.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance.
|
||||
"""
|
||||
self.redis = redis_client
|
||||
self.nonce_ttl = NONCE_TTL_SECONDS
|
||||
self.timestamp_tolerance = TIMESTAMP_TOLERANCE_SECONDS
|
||||
|
||||
def validate_nonce(
|
||||
self, client_nonce: bytes, timestamp: int, session_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Validate nonce uniqueness and timestamp validity.
|
||||
|
||||
Args:
|
||||
client_nonce: Nonce from client.
|
||||
timestamp: Timestamp from client.
|
||||
session_id: Session ID.
|
||||
|
||||
Returns:
|
||||
True if nonce is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if not self._validate_timestamp(timestamp):
|
||||
logger.warning(
|
||||
"Invalid timestamp",
|
||||
extra={"timestamp": timestamp, "session_id": session_id},
|
||||
)
|
||||
return False
|
||||
|
||||
nonce_key = self._create_nonce_key(client_nonce, session_id)
|
||||
nonce_hash = hashlib.sha256(client_nonce).hexdigest()[:16]
|
||||
|
||||
if self.redis.exists(nonce_key):
|
||||
logger.warning(
|
||||
"Nonce already used",
|
||||
extra={"session_id": session_id, "nonce_hash": nonce_hash},
|
||||
)
|
||||
return False
|
||||
|
||||
self.redis.setex(nonce_key, self.nonce_ttl, timestamp)
|
||||
|
||||
logger.info(
|
||||
"Nonce validated successfully",
|
||||
extra={"session_id": session_id, "nonce_hash": nonce_hash},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Nonce validation error",
|
||||
extra={"error": str(e), "session_id": session_id},
|
||||
)
|
||||
return False
|
||||
|
||||
def _validate_timestamp(self, timestamp: int) -> bool:
|
||||
"""Validate timestamp.
|
||||
|
||||
Args:
|
||||
timestamp: Timestamp in milliseconds.
|
||||
|
||||
Returns:
|
||||
True if timestamp is within tolerance, False otherwise.
|
||||
"""
|
||||
current_time = int(time.time() * 1000)
|
||||
time_diff = abs(current_time - timestamp)
|
||||
return time_diff <= (self.timestamp_tolerance * 1000)
|
||||
|
||||
def _create_nonce_key(self, client_nonce: bytes, session_id: str) -> str:
|
||||
"""Create unique key for nonce in Redis.
|
||||
|
||||
Args:
|
||||
client_nonce: Nonce from client.
|
||||
session_id: Session ID.
|
||||
|
||||
Returns:
|
||||
Redis key string.
|
||||
"""
|
||||
nonce_hash = hashlib.sha256(client_nonce).hexdigest()
|
||||
return f"nonce:{session_id}:{nonce_hash}"
|
||||
|
||||
def cleanup_expired_nonces(self) -> int:
|
||||
"""
|
||||
Cleanup expired nonces.
|
||||
|
||||
Redis automatically removes keys by TTL, but this method provides
|
||||
additional cleanup for keys without TTL.
|
||||
|
||||
Returns:
|
||||
Number of expired nonces removed.
|
||||
"""
|
||||
try:
|
||||
pattern = "nonce:*"
|
||||
keys = self.redis.keys(pattern)
|
||||
|
||||
expired_count = 0
|
||||
for key in keys:
|
||||
ttl = self.redis.ttl(key)
|
||||
if ttl == -1:
|
||||
self.redis.delete(key)
|
||||
expired_count += 1
|
||||
|
||||
logger.info(
|
||||
"Nonce cleanup completed",
|
||||
extra={"expired_count": expired_count, "total_keys": len(keys)},
|
||||
)
|
||||
|
||||
return expired_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Nonce cleanup error", extra={"error": str(e)})
|
||||
return 0
|
||||
|
||||
class ReplayProtection:
|
||||
"""Replay attack protection."""
|
||||
|
||||
def __init__(self, redis_client: redis.Redis) -> None:
|
||||
"""Initialize replay protection.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client instance.
|
||||
"""
|
||||
self.nonce_manager = NonceManager(redis_client)
|
||||
|
||||
def validate_request(
|
||||
self, client_nonce: bytes, timestamp: int, session_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Validate request for replay attacks.
|
||||
|
||||
Args:
|
||||
client_nonce: Nonce from client.
|
||||
timestamp: Timestamp from client.
|
||||
session_id: Session ID.
|
||||
|
||||
Returns:
|
||||
True if request is valid, False otherwise.
|
||||
"""
|
||||
return self.nonce_manager.validate_nonce(
|
||||
client_nonce, timestamp, session_id
|
||||
)
|
||||
|
||||
def cleanup(self) -> int:
|
||||
"""
|
||||
Cleanup expired nonces.
|
||||
|
||||
Returns:
|
||||
Number of expired nonces removed.
|
||||
"""
|
||||
return self.nonce_manager.cleanup_expired_nonces()
|
||||
401
guacamole_test_11_26/api/core/saved_machines_db.py
Executable file
401
guacamole_test_11_26/api/core/saved_machines_db.py
Executable file
@ -0,0 +1,401 @@
|
||||
"""
|
||||
Database operations for saved user machines.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import psycopg2
|
||||
import structlog
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from psycopg2.extensions import connection as Connection
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class SavedMachinesDB:
|
||||
"""PostgreSQL operations for saved machines."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize database configuration."""
|
||||
self.db_config = {
|
||||
"host": os.getenv("POSTGRES_HOST", "postgres"),
|
||||
"port": int(os.getenv("POSTGRES_PORT", "5432")),
|
||||
"database": os.getenv("POSTGRES_DB", "guacamole_db"),
|
||||
"user": os.getenv("POSTGRES_USER", "guacamole_user"),
|
||||
"password": os.getenv("POSTGRES_PASSWORD"),
|
||||
}
|
||||
|
||||
def _get_connection(self) -> Connection:
|
||||
"""Get database connection."""
|
||||
try:
|
||||
return psycopg2.connect(**self.db_config)
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to database", error=str(e))
|
||||
raise
|
||||
|
||||
def create_machine(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
hostname: str,
|
||||
port: int,
|
||||
protocol: str,
|
||||
os: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
is_favorite: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create new saved machine.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
name: Machine name.
|
||||
hostname: Machine hostname.
|
||||
port: Connection port.
|
||||
protocol: Connection protocol.
|
||||
os: Operating system (optional).
|
||||
description: Description (optional).
|
||||
tags: Tags list (optional).
|
||||
is_favorite: Whether machine is favorite.
|
||||
|
||||
Returns:
|
||||
Dictionary with created machine data including ID.
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
try:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
query = """
|
||||
INSERT INTO api.user_saved_machines
|
||||
(user_id, name, hostname, port, protocol, os,
|
||||
description, tags, is_favorite)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id, user_id, name, hostname, port, protocol, os,
|
||||
description, tags, is_favorite, created_at, updated_at,
|
||||
last_connected_at
|
||||
"""
|
||||
|
||||
cur.execute(
|
||||
query,
|
||||
(
|
||||
user_id,
|
||||
name,
|
||||
hostname,
|
||||
port,
|
||||
protocol,
|
||||
os,
|
||||
description,
|
||||
tags or [],
|
||||
is_favorite,
|
||||
),
|
||||
)
|
||||
|
||||
result = dict(cur.fetchone())
|
||||
conn.commit()
|
||||
|
||||
logger.info(
|
||||
"Saved machine created",
|
||||
machine_id=result["id"],
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error(
|
||||
"Failed to create saved machine", error=str(e), user_id=user_id
|
||||
)
|
||||
raise
|
||||
|
||||
def get_user_machines(
|
||||
self, user_id: str, include_stats: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all user machines.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
include_stats: Include connection statistics.
|
||||
|
||||
Returns:
|
||||
List of machine dictionaries.
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
if include_stats:
|
||||
query = """
|
||||
SELECT
|
||||
m.*,
|
||||
json_build_object(
|
||||
'total_connections', COALESCE(COUNT(h.id), 0),
|
||||
'last_connection', MAX(h.connected_at),
|
||||
'successful_connections',
|
||||
COALESCE(SUM(CASE WHEN h.success = TRUE THEN 1 ELSE 0 END), 0),
|
||||
'failed_connections',
|
||||
COALESCE(SUM(CASE WHEN h.success = FALSE THEN 1 ELSE 0 END), 0)
|
||||
) as connection_stats
|
||||
FROM api.user_saved_machines m
|
||||
LEFT JOIN api.connection_history h ON m.id = h.machine_id
|
||||
WHERE m.user_id = %s
|
||||
GROUP BY m.id
|
||||
ORDER BY m.is_favorite DESC, m.updated_at DESC
|
||||
"""
|
||||
else:
|
||||
query = """
|
||||
SELECT * FROM api.user_saved_machines
|
||||
WHERE user_id = %s
|
||||
ORDER BY is_favorite DESC, updated_at DESC
|
||||
"""
|
||||
|
||||
cur.execute(query, (user_id,))
|
||||
results = [dict(row) for row in cur.fetchall()]
|
||||
|
||||
logger.debug(
|
||||
"Retrieved user machines", user_id=user_id, count=len(results)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def get_machine_by_id(
|
||||
self, machine_id: str, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get machine by ID with owner verification.
|
||||
|
||||
Args:
|
||||
machine_id: Machine UUID.
|
||||
user_id: User ID for permission check.
|
||||
|
||||
Returns:
|
||||
Machine dictionary or None if not found.
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
query = """
|
||||
SELECT * FROM api.user_saved_machines
|
||||
WHERE id = %s AND user_id = %s
|
||||
"""
|
||||
|
||||
cur.execute(query, (machine_id, user_id))
|
||||
result = cur.fetchone()
|
||||
|
||||
return dict(result) if result else None
|
||||
|
||||
def update_machine(
|
||||
self, machine_id: str, user_id: str, **updates: Any
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Update machine.
|
||||
|
||||
Args:
|
||||
machine_id: Machine UUID.
|
||||
user_id: User ID for permission check.
|
||||
**updates: Fields to update.
|
||||
|
||||
Returns:
|
||||
Updated machine dictionary or None if not found.
|
||||
"""
|
||||
allowed_fields = {
|
||||
"name",
|
||||
"hostname",
|
||||
"port",
|
||||
"protocol",
|
||||
"os",
|
||||
"description",
|
||||
"tags",
|
||||
"is_favorite",
|
||||
}
|
||||
|
||||
updates_filtered = {
|
||||
k: v for k, v in updates.items() if k in allowed_fields and v is not None
|
||||
}
|
||||
|
||||
if not updates_filtered:
|
||||
return self.get_machine_by_id(machine_id, user_id)
|
||||
|
||||
set_clause = ", ".join([f"{k} = %s" for k in updates_filtered.keys()])
|
||||
values = list(updates_filtered.values()) + [machine_id, user_id]
|
||||
|
||||
query = f"""
|
||||
UPDATE api.user_saved_machines
|
||||
SET {set_clause}
|
||||
WHERE id = %s AND user_id = %s
|
||||
RETURNING id, user_id, name, hostname, port, protocol, os,
|
||||
description, tags, is_favorite, created_at, updated_at,
|
||||
last_connected_at
|
||||
"""
|
||||
|
||||
with self._get_connection() as conn:
|
||||
try:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
cur.execute(query, values)
|
||||
result = cur.fetchone()
|
||||
conn.commit()
|
||||
|
||||
if result:
|
||||
logger.info(
|
||||
"Saved machine updated",
|
||||
machine_id=machine_id,
|
||||
user_id=user_id,
|
||||
updated_fields=list(updates_filtered.keys()),
|
||||
)
|
||||
return dict(result)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error(
|
||||
"Failed to update machine",
|
||||
error=str(e),
|
||||
machine_id=machine_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
def delete_machine(self, machine_id: str, user_id: str) -> bool:
|
||||
"""
|
||||
Delete machine.
|
||||
|
||||
Args:
|
||||
machine_id: Machine UUID.
|
||||
user_id: User ID for permission check.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
query = """
|
||||
DELETE FROM api.user_saved_machines
|
||||
WHERE id = %s AND user_id = %s
|
||||
"""
|
||||
|
||||
cur.execute(query, (machine_id, user_id))
|
||||
deleted_count = cur.rowcount
|
||||
conn.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(
|
||||
"Saved machine deleted",
|
||||
machine_id=machine_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
"Machine not found for deletion",
|
||||
machine_id=machine_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error(
|
||||
"Failed to delete machine",
|
||||
error=str(e),
|
||||
machine_id=machine_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
def update_last_connected(self, machine_id: str, user_id: str) -> None:
|
||||
"""
|
||||
Update last connection time.
|
||||
|
||||
Args:
|
||||
machine_id: Machine UUID.
|
||||
user_id: User ID.
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
query = """
|
||||
UPDATE api.user_saved_machines
|
||||
SET last_connected_at = NOW()
|
||||
WHERE id = %s AND user_id = %s
|
||||
"""
|
||||
|
||||
cur.execute(query, (machine_id, user_id))
|
||||
conn.commit()
|
||||
|
||||
logger.debug(
|
||||
"Updated last_connected_at",
|
||||
machine_id=machine_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error("Failed to update last_connected", error=str(e))
|
||||
|
||||
def add_connection_history(
|
||||
self,
|
||||
user_id: str,
|
||||
machine_id: str,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
duration_seconds: Optional[int] = None,
|
||||
client_ip: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Add connection history record.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
machine_id: Machine ID.
|
||||
success: Whether connection was successful.
|
||||
error_message: Error message if failed (optional).
|
||||
duration_seconds: Connection duration in seconds (optional).
|
||||
client_ip: Client IP address (optional).
|
||||
|
||||
Returns:
|
||||
UUID of created record.
|
||||
"""
|
||||
with self._get_connection() as conn:
|
||||
try:
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
query = """
|
||||
INSERT INTO api.connection_history
|
||||
(user_id, machine_id, success, error_message, duration_seconds, client_ip)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
cur.execute(
|
||||
query,
|
||||
(
|
||||
user_id,
|
||||
machine_id,
|
||||
success,
|
||||
error_message,
|
||||
duration_seconds,
|
||||
client_ip,
|
||||
),
|
||||
)
|
||||
|
||||
result = cur.fetchone()
|
||||
conn.commit()
|
||||
|
||||
logger.info(
|
||||
"Connection history record created",
|
||||
machine_id=machine_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
)
|
||||
|
||||
return str(result["id"])
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logger.error("Failed to add connection history", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
saved_machines_db = SavedMachinesDB()
|
||||
|
||||
339
guacamole_test_11_26/api/core/session_storage.py
Executable file
339
guacamole_test_11_26/api/core/session_storage.py
Executable file
@ -0,0 +1,339 @@
|
||||
"""
|
||||
Redis-based session storage for Guacamole tokens.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class SessionStorage:
|
||||
"""Redis-based session storage for secure Guacamole token storage."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize Redis client and verify connection."""
|
||||
self.redis_client = redis.Redis(
|
||||
host=os.getenv("REDIS_HOST", "localhost"),
|
||||
port=int(os.getenv("REDIS_PORT", "6379")),
|
||||
password=os.getenv("REDIS_PASSWORD"),
|
||||
db=int(os.getenv("REDIS_DB", "0")),
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=5,
|
||||
socket_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
|
||||
try:
|
||||
self.redis_client.ping()
|
||||
logger.info("Redis connection established successfully")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error("Failed to connect to Redis", error=str(e))
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
user_info: Dict[str, Any],
|
||||
guac_token: str,
|
||||
expires_in_minutes: int = 60,
|
||||
) -> str:
|
||||
"""
|
||||
Create new session.
|
||||
|
||||
Args:
|
||||
user_info: User information dictionary.
|
||||
guac_token: Guacamole authentication token.
|
||||
expires_in_minutes: Session lifetime in minutes.
|
||||
|
||||
Returns:
|
||||
Unique session ID.
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
session_data = {
|
||||
"session_id": session_id,
|
||||
"user_info": user_info,
|
||||
"guac_token": guac_token,
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": (now + timedelta(minutes=expires_in_minutes)).isoformat(),
|
||||
"last_accessed": now.isoformat(),
|
||||
}
|
||||
|
||||
try:
|
||||
ttl_seconds = expires_in_minutes * 60
|
||||
self.redis_client.setex(
|
||||
f"session:{session_id}",
|
||||
ttl_seconds,
|
||||
json.dumps(session_data),
|
||||
)
|
||||
|
||||
self.redis_client.setex(
|
||||
f"user_session:{user_info['username']}",
|
||||
ttl_seconds,
|
||||
session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Session created successfully",
|
||||
session_id=session_id,
|
||||
username=user_info["username"],
|
||||
expires_in_minutes=expires_in_minutes,
|
||||
redis_key=f"session:{session_id}",
|
||||
has_guac_token=bool(guac_token),
|
||||
guac_token_length=len(guac_token) if guac_token else 0,
|
||||
)
|
||||
|
||||
return session_id
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error("Failed to create session", error=str(e))
|
||||
raise
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get session data.
|
||||
|
||||
Args:
|
||||
session_id: Session ID.
|
||||
|
||||
Returns:
|
||||
Session data or None if not found/expired.
|
||||
"""
|
||||
try:
|
||||
session_data = self.redis_client.get(f"session:{session_id}")
|
||||
|
||||
if not session_data:
|
||||
logger.debug("Session not found", session_id=session_id)
|
||||
return None
|
||||
|
||||
session = json.loads(session_data)
|
||||
session["last_accessed"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
ttl = self.redis_client.ttl(f"session:{session_id}")
|
||||
if ttl > 0:
|
||||
self.redis_client.setex(
|
||||
f"session:{session_id}",
|
||||
ttl,
|
||||
json.dumps(session),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Session retrieved successfully",
|
||||
session_id=session_id,
|
||||
username=session["user_info"]["username"],
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error("Failed to get session", session_id=session_id, error=str(e))
|
||||
return None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"Failed to decode session data", session_id=session_id, error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
def get_session_by_username(self, username: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get session by username.
|
||||
|
||||
Args:
|
||||
username: Username.
|
||||
|
||||
Returns:
|
||||
Session data or None.
|
||||
"""
|
||||
try:
|
||||
session_id = self.redis_client.get(f"user_session:{username}")
|
||||
|
||||
if not session_id:
|
||||
logger.debug("No active session for user", username=username)
|
||||
return None
|
||||
|
||||
return self.get_session(session_id)
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error(
|
||||
"Failed to get session by username", username=username, error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
def update_session(self, session_id: str, updates: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Update session data.
|
||||
|
||||
Args:
|
||||
session_id: Session ID.
|
||||
updates: Updates to apply.
|
||||
|
||||
Returns:
|
||||
True if update successful.
|
||||
"""
|
||||
try:
|
||||
session_data = self.redis_client.get(f"session:{session_id}")
|
||||
|
||||
if not session_data:
|
||||
logger.warning("Session not found for update", session_id=session_id)
|
||||
return False
|
||||
|
||||
session = json.loads(session_data)
|
||||
session.update(updates)
|
||||
session["last_accessed"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
ttl = self.redis_client.ttl(f"session:{session_id}")
|
||||
if ttl > 0:
|
||||
self.redis_client.setex(
|
||||
f"session:{session_id}",
|
||||
ttl,
|
||||
json.dumps(session),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Session updated successfully",
|
||||
session_id=session_id,
|
||||
updates=list(updates.keys()),
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning("Session expired during update", session_id=session_id)
|
||||
return False
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error("Failed to update session", session_id=session_id, error=str(e))
|
||||
return False
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"Failed to decode session data for update",
|
||||
session_id=session_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""
|
||||
Delete session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID.
|
||||
|
||||
Returns:
|
||||
True if deletion successful.
|
||||
"""
|
||||
try:
|
||||
session_data = self.redis_client.get(f"session:{session_id}")
|
||||
|
||||
if session_data:
|
||||
session = json.loads(session_data)
|
||||
username = session["user_info"]["username"]
|
||||
|
||||
self.redis_client.delete(f"session:{session_id}")
|
||||
self.redis_client.delete(f"user_session:{username}")
|
||||
|
||||
logger.info(
|
||||
"Session deleted successfully",
|
||||
session_id=session_id,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.debug("Session not found for deletion", session_id=session_id)
|
||||
return False
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error("Failed to delete session", session_id=session_id, error=str(e))
|
||||
return False
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"Failed to decode session data for deletion",
|
||||
session_id=session_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
def delete_user_sessions(self, username: str) -> int:
|
||||
"""
|
||||
Delete all user sessions.
|
||||
|
||||
Args:
|
||||
username: Username.
|
||||
|
||||
Returns:
|
||||
Number of deleted sessions.
|
||||
"""
|
||||
try:
|
||||
pattern = f"user_session:{username}"
|
||||
session_keys = self.redis_client.keys(pattern)
|
||||
|
||||
deleted_count = 0
|
||||
for key in session_keys:
|
||||
session_id = self.redis_client.get(key)
|
||||
if session_id and self.delete_session(session_id):
|
||||
deleted_count += 1
|
||||
|
||||
logger.info(
|
||||
"User sessions deleted", username=username, deleted_count=deleted_count
|
||||
)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error(
|
||||
"Failed to delete user sessions", username=username, error=str(e)
|
||||
)
|
||||
return 0
|
||||
|
||||
def cleanup_expired_sessions(self) -> int:
|
||||
"""
|
||||
Cleanup expired sessions.
|
||||
|
||||
Redis automatically removes keys by TTL, so this method is mainly
|
||||
for compatibility and potential logic extension.
|
||||
|
||||
Returns:
|
||||
Number of cleaned sessions (always 0, as Redis handles this automatically).
|
||||
"""
|
||||
logger.debug(
|
||||
"Expired sessions cleanup completed (Redis TTL handles this automatically)"
|
||||
)
|
||||
return 0
|
||||
|
||||
def get_session_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get session statistics.
|
||||
|
||||
Returns:
|
||||
Session statistics dictionary.
|
||||
"""
|
||||
try:
|
||||
session_keys = self.redis_client.keys("session:*")
|
||||
user_keys = self.redis_client.keys("user_session:*")
|
||||
|
||||
memory_usage = (
|
||||
self.redis_client.memory_usage("session:*") if session_keys else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"active_sessions": len(session_keys),
|
||||
"active_users": len(user_keys),
|
||||
"redis_memory_usage": memory_usage,
|
||||
}
|
||||
|
||||
except redis.RedisError as e:
|
||||
logger.error("Failed to get session stats", error=str(e))
|
||||
return {
|
||||
"active_sessions": 0,
|
||||
"active_users": 0,
|
||||
"redis_memory_usage": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
session_storage = SessionStorage()
|
||||
154
guacamole_test_11_26/api/core/signature_verifier.py
Executable file
154
guacamole_test_11_26/api/core/signature_verifier.py
Executable file
@ -0,0 +1,154 @@
|
||||
"""Module for verifying server key signatures with constant-time comparison."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ed25519 constants
|
||||
ED25519_SIGNATURE_LENGTH = 64
|
||||
ED25519_PUBLIC_KEY_LENGTH = 32
|
||||
DEFAULT_KEY_ID = "default"
|
||||
|
||||
|
||||
class SignatureVerifier:
|
||||
"""Signature verifier with constant-time comparison."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the signature verifier."""
|
||||
self.trusted_public_keys = self._load_trusted_keys()
|
||||
|
||||
def _load_trusted_keys(self) -> Dict[str, Optional[ed25519.Ed25519PublicKey]]:
|
||||
"""
|
||||
Load trusted public keys.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping key IDs to public keys.
|
||||
"""
|
||||
return {DEFAULT_KEY_ID: None}
|
||||
|
||||
def verify_server_key_signature(
|
||||
self,
|
||||
public_key_pem: bytes,
|
||||
signature: bytes,
|
||||
kid: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify server public key signature with constant-time comparison.
|
||||
|
||||
Args:
|
||||
public_key_pem: PEM-encoded public key.
|
||||
signature: Signature bytes.
|
||||
kid: Key ID for key selection (optional).
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
if len(signature) != ED25519_SIGNATURE_LENGTH:
|
||||
logger.warning(
|
||||
"Invalid signature length",
|
||||
extra={
|
||||
"expected": ED25519_SIGNATURE_LENGTH,
|
||||
"actual": len(signature),
|
||||
"kid": kid,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
public_key = serialization.load_pem_public_key(public_key_pem)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to load PEM public key",
|
||||
extra={"error": str(e), "kid": kid},
|
||||
)
|
||||
return False
|
||||
|
||||
if not isinstance(public_key, ed25519.Ed25519PublicKey):
|
||||
logger.warning(
|
||||
"Public key is not Ed25519",
|
||||
extra={"kid": kid},
|
||||
)
|
||||
return False
|
||||
|
||||
raw_public_key = public_key.public_bytes_raw()
|
||||
if len(raw_public_key) != ED25519_PUBLIC_KEY_LENGTH:
|
||||
logger.warning(
|
||||
"Invalid public key length",
|
||||
extra={
|
||||
"expected": ED25519_PUBLIC_KEY_LENGTH,
|
||||
"actual": len(raw_public_key),
|
||||
"kid": kid,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
trusted_key = self._get_trusted_key(kid)
|
||||
if not trusted_key:
|
||||
logger.error("No trusted key found", extra={"kid": kid})
|
||||
return False
|
||||
|
||||
try:
|
||||
trusted_key.verify(signature, public_key_pem)
|
||||
logger.info("Signature verification successful", extra={"kid": kid})
|
||||
return True
|
||||
except InvalidSignature:
|
||||
logger.warning("Signature verification failed", extra={"kid": kid})
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Signature verification error",
|
||||
extra={"error": str(e), "kid": kid},
|
||||
)
|
||||
return False
|
||||
|
||||
def _get_trusted_key(
|
||||
self, kid: Optional[str] = None
|
||||
) -> Optional[ed25519.Ed25519PublicKey]:
|
||||
"""
|
||||
Get trusted public key by kid.
|
||||
|
||||
Args:
|
||||
kid: Key ID (optional).
|
||||
|
||||
Returns:
|
||||
Trusted public key or None if not found.
|
||||
"""
|
||||
key_id = kid if kid else DEFAULT_KEY_ID
|
||||
return self.trusted_public_keys.get(key_id)
|
||||
|
||||
def add_trusted_key(self, kid: str, public_key_pem: bytes) -> bool:
|
||||
"""
|
||||
Add trusted public key.
|
||||
|
||||
Args:
|
||||
kid: Key ID.
|
||||
public_key_pem: PEM-encoded public key.
|
||||
|
||||
Returns:
|
||||
True if key was added successfully.
|
||||
"""
|
||||
try:
|
||||
public_key = serialization.load_pem_public_key(public_key_pem)
|
||||
if not isinstance(public_key, ed25519.Ed25519PublicKey):
|
||||
logger.error("Public key is not Ed25519", extra={"kid": kid})
|
||||
return False
|
||||
|
||||
self.trusted_public_keys[kid] = public_key
|
||||
logger.info("Trusted key added", extra={"kid": kid})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to add trusted key",
|
||||
extra={"error": str(e), "kid": kid},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
signature_verifier = SignatureVerifier()
|
||||
327
guacamole_test_11_26/api/core/ssrf_protection.py
Executable file
327
guacamole_test_11_26/api/core/ssrf_protection.py
Executable file
@ -0,0 +1,327 @@
|
||||
"""Enhanced SSRF attack protection with DNS pinning and rebinding prevention."""
|
||||
|
||||
# Standard library imports
|
||||
import ipaddress
|
||||
import socket
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
# Third-party imports
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class SSRFProtection:
|
||||
"""Enhanced SSRF attack protection with DNS pinning."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize SSRF protection with blocked IPs and networks."""
|
||||
self._dns_cache: Dict[str, Tuple[str, float, int]] = {}
|
||||
self._dns_cache_ttl = 300
|
||||
|
||||
self._blocked_ips: Set[str] = {
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
"0.0.0.0",
|
||||
"169.254.169.254",
|
||||
"10.0.0.1",
|
||||
"10.255.255.255",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
"192.168.0.1",
|
||||
"192.168.255.255",
|
||||
}
|
||||
|
||||
self._blocked_networks = [
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16",
|
||||
"224.0.0.0/4",
|
||||
"240.0.0.0/4",
|
||||
"172.17.0.0/16",
|
||||
"172.18.0.0/16",
|
||||
"172.19.0.0/16",
|
||||
"172.20.0.0/16",
|
||||
"172.21.0.0/16",
|
||||
"172.22.0.0/16",
|
||||
"172.23.0.0/16",
|
||||
"172.24.0.0/16",
|
||||
"172.25.0.0/16",
|
||||
"172.26.0.0/16",
|
||||
"172.27.0.0/16",
|
||||
"172.28.0.0/16",
|
||||
"172.29.0.0/16",
|
||||
"172.30.0.0/16",
|
||||
"172.31.0.0/16",
|
||||
]
|
||||
|
||||
self._allowed_networks: Dict[str, List[str]] = {
|
||||
"USER": ["10.0.0.0/8", "172.16.0.0/16", "192.168.1.0/24"],
|
||||
"ADMIN": [
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/16",
|
||||
"192.168.0.0/16",
|
||||
"203.0.113.0/24",
|
||||
],
|
||||
"SUPER_ADMIN": ["0.0.0.0/0"],
|
||||
}
|
||||
|
||||
def validate_host(
|
||||
self, hostname: str, user_role: str
|
||||
) -> Tuple[bool, str]:
|
||||
"""Validate host with enhanced SSRF protection.
|
||||
|
||||
Args:
|
||||
hostname: Hostname or IP address
|
||||
user_role: User role
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed: bool, reason: str)
|
||||
"""
|
||||
try:
|
||||
if not hostname or len(hostname) > 253:
|
||||
return False, f"Invalid hostname length: {hostname}"
|
||||
|
||||
suspicious_chars = [
|
||||
"..",
|
||||
"//",
|
||||
"\\",
|
||||
"<",
|
||||
">",
|
||||
'"',
|
||||
"'",
|
||||
"`",
|
||||
"\x00",
|
||||
]
|
||||
if any(char in hostname for char in suspicious_chars):
|
||||
return False, f"Suspicious characters in hostname: {hostname}"
|
||||
|
||||
if hostname.lower() in ("localhost", "127.0.0.1", "::1"):
|
||||
return False, f"Host {hostname} is blocked (localhost)"
|
||||
|
||||
resolved_ip = self._resolve_hostname_with_pinning(hostname)
|
||||
if not resolved_ip:
|
||||
return False, f"Cannot resolve hostname: {hostname}"
|
||||
|
||||
if resolved_ip in self._blocked_ips:
|
||||
return False, f"IP {resolved_ip} is in blocked list"
|
||||
|
||||
ip_addr = ipaddress.ip_address(resolved_ip)
|
||||
for blocked_network in self._blocked_networks:
|
||||
if ip_addr in ipaddress.ip_network(blocked_network):
|
||||
return (
|
||||
False,
|
||||
f"IP {resolved_ip} is in blocked network {blocked_network}",
|
||||
)
|
||||
|
||||
allowed_networks = self._allowed_networks.get(user_role, [])
|
||||
if not allowed_networks:
|
||||
return False, f"Role {user_role} has no allowed networks"
|
||||
|
||||
if user_role == "SUPER_ADMIN":
|
||||
return True, f"IP {resolved_ip} allowed for SUPER_ADMIN"
|
||||
|
||||
for allowed_network in allowed_networks:
|
||||
if ip_addr in ipaddress.ip_network(allowed_network):
|
||||
return (
|
||||
True,
|
||||
f"IP {resolved_ip} allowed in network {allowed_network}",
|
||||
)
|
||||
|
||||
return (
|
||||
False,
|
||||
f"IP {resolved_ip} not in any allowed network for role {user_role}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("SSRF validation error", hostname=hostname, error=str(e))
|
||||
return False, f"Error validating host: {str(e)}"
|
||||
|
||||
def _resolve_hostname_with_pinning(self, hostname: str) -> Optional[str]:
|
||||
"""DNS resolution with pinning to prevent rebinding attacks.
|
||||
|
||||
Args:
|
||||
hostname: Hostname to resolve
|
||||
|
||||
Returns:
|
||||
IP address or None if resolution failed
|
||||
"""
|
||||
try:
|
||||
cache_key = hostname.lower()
|
||||
if cache_key in self._dns_cache:
|
||||
cached_ip, timestamp, ttl = self._dns_cache[cache_key]
|
||||
|
||||
if time.time() - timestamp < ttl:
|
||||
logger.debug(
|
||||
"Using cached DNS resolution",
|
||||
hostname=hostname,
|
||||
ip=cached_ip,
|
||||
age_seconds=int(time.time() - timestamp),
|
||||
)
|
||||
return cached_ip
|
||||
del self._dns_cache[cache_key]
|
||||
|
||||
original_timeout = socket.getdefaulttimeout()
|
||||
socket.setdefaulttimeout(5)
|
||||
|
||||
try:
|
||||
ip1 = socket.gethostbyname(hostname)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
ip2 = socket.gethostbyname(hostname)
|
||||
|
||||
if ip1 != ip2:
|
||||
logger.warning(
|
||||
"DNS rebinding detected",
|
||||
hostname=hostname,
|
||||
ip1=ip1,
|
||||
ip2=ip2,
|
||||
)
|
||||
return None
|
||||
|
||||
if ip1 in ("127.0.0.1", "::1"):
|
||||
logger.warning(
|
||||
"DNS resolution returned localhost", hostname=hostname, ip=ip1
|
||||
)
|
||||
return None
|
||||
|
||||
self._dns_cache[cache_key] = (
|
||||
ip1,
|
||||
time.time(),
|
||||
self._dns_cache_ttl,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"DNS resolution successful", hostname=hostname, ip=ip1, cached=True
|
||||
)
|
||||
|
||||
return ip1
|
||||
|
||||
finally:
|
||||
socket.setdefaulttimeout(original_timeout)
|
||||
|
||||
except socket.gaierror as e:
|
||||
logger.warning("DNS resolution failed", hostname=hostname, error=str(e))
|
||||
return None
|
||||
except socket.timeout:
|
||||
logger.warning("DNS resolution timeout", hostname=hostname)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected DNS resolution error", hostname=hostname, error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
def validate_port(self, port: int) -> Tuple[bool, str]:
|
||||
"""Validate port number.
|
||||
|
||||
Args:
|
||||
port: Port number
|
||||
|
||||
Returns:
|
||||
Tuple of (valid: bool, reason: str)
|
||||
"""
|
||||
if not isinstance(port, int) or port < 1 or port > 65535:
|
||||
return False, f"Invalid port: {port}"
|
||||
|
||||
blocked_ports = {
|
||||
22,
|
||||
23,
|
||||
25,
|
||||
53,
|
||||
80,
|
||||
110,
|
||||
143,
|
||||
443,
|
||||
993,
|
||||
995,
|
||||
135,
|
||||
139,
|
||||
445,
|
||||
1433,
|
||||
1521,
|
||||
3306,
|
||||
5432,
|
||||
6379,
|
||||
3389,
|
||||
5900,
|
||||
5901,
|
||||
5902,
|
||||
5903,
|
||||
5904,
|
||||
5905,
|
||||
8080,
|
||||
8443,
|
||||
9090,
|
||||
9091,
|
||||
}
|
||||
|
||||
if port in blocked_ports:
|
||||
return False, f"Port {port} is blocked (system port)"
|
||||
|
||||
return True, f"Port {port} is valid"
|
||||
|
||||
def cleanup_expired_cache(self) -> None:
|
||||
"""Clean up expired DNS cache entries."""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key
|
||||
for key, (_, timestamp, ttl) in self._dns_cache.items()
|
||||
if current_time - timestamp > ttl
|
||||
]
|
||||
|
||||
for key in expired_keys:
|
||||
del self._dns_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.info(
|
||||
"Cleaned up expired DNS cache entries", count=len(expired_keys)
|
||||
)
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""Get DNS cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
current_time = time.time()
|
||||
active_entries = 0
|
||||
expired_entries = 0
|
||||
|
||||
for _, timestamp, ttl in self._dns_cache.values():
|
||||
if current_time - timestamp < ttl:
|
||||
active_entries += 1
|
||||
else:
|
||||
expired_entries += 1
|
||||
|
||||
return {
|
||||
"total_entries": len(self._dns_cache),
|
||||
"active_entries": active_entries,
|
||||
"expired_entries": expired_entries,
|
||||
"cache_ttl_seconds": self._dns_cache_ttl,
|
||||
"blocked_ips_count": len(self._blocked_ips),
|
||||
"blocked_networks_count": len(self._blocked_networks),
|
||||
}
|
||||
|
||||
def add_blocked_ip(self, ip: str) -> None:
|
||||
"""Add IP to blocked list.
|
||||
|
||||
Args:
|
||||
ip: IP address to block
|
||||
"""
|
||||
self._blocked_ips.add(ip)
|
||||
logger.info("Added IP to blocked list", ip=ip)
|
||||
|
||||
def remove_blocked_ip(self, ip: str) -> None:
|
||||
"""Remove IP from blocked list.
|
||||
|
||||
Args:
|
||||
ip: IP address to unblock
|
||||
"""
|
||||
self._blocked_ips.discard(ip)
|
||||
logger.info("Removed IP from blocked list", ip=ip)
|
||||
|
||||
|
||||
# Global instance for use in API
|
||||
ssrf_protection = SSRFProtection()
|
||||
263
guacamole_test_11_26/api/core/token_blacklist.py
Executable file
263
guacamole_test_11_26/api/core/token_blacklist.py
Executable file
@ -0,0 +1,263 @@
|
||||
"""Redis-based token blacklist for JWT token revocation."""
|
||||
|
||||
# Standard library imports
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# Third-party imports
|
||||
import redis
|
||||
import structlog
|
||||
|
||||
# Local imports
|
||||
from .session_storage import session_storage
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Redis configuration constants
|
||||
REDIS_SOCKET_TIMEOUT = 5
|
||||
REDIS_DEFAULT_HOST = "localhost"
|
||||
REDIS_DEFAULT_PORT = "6379"
|
||||
REDIS_DEFAULT_DB = "0"
|
||||
|
||||
# Blacklist constants
|
||||
BLACKLIST_KEY_PREFIX = "blacklist:"
|
||||
TOKEN_HASH_PREVIEW_LENGTH = 16
|
||||
DEFAULT_REVOCATION_REASON = "logout"
|
||||
DEFAULT_FORCE_LOGOUT_REASON = "force_logout"
|
||||
|
||||
|
||||
class TokenBlacklist:
|
||||
"""Redis-based blacklist for JWT token revocation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize token blacklist with Redis connection."""
|
||||
self.redis_client = redis.Redis(
|
||||
host=os.getenv("REDIS_HOST", REDIS_DEFAULT_HOST),
|
||||
port=int(os.getenv("REDIS_PORT", REDIS_DEFAULT_PORT)),
|
||||
password=os.getenv("REDIS_PASSWORD"),
|
||||
db=int(os.getenv("REDIS_DB", REDIS_DEFAULT_DB)),
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=REDIS_SOCKET_TIMEOUT,
|
||||
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
|
||||
try:
|
||||
self.redis_client.ping()
|
||||
logger.info("Token blacklist Redis connection established")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(
|
||||
"Failed to connect to Redis for token blacklist", error=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_token_hash(self, token: str) -> str:
|
||||
"""Get token hash for use as Redis key.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of token
|
||||
"""
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
def revoke_token(
|
||||
self,
|
||||
token: str,
|
||||
reason: str = DEFAULT_REVOCATION_REASON,
|
||||
revoked_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Revoke token (add to blacklist).
|
||||
|
||||
Args:
|
||||
token: JWT token to revoke
|
||||
reason: Revocation reason
|
||||
revoked_by: Username who revoked the token
|
||||
|
||||
Returns:
|
||||
True if token successfully revoked
|
||||
"""
|
||||
try:
|
||||
from .utils import get_token_expiry_info
|
||||
|
||||
expiry_info = get_token_expiry_info(token)
|
||||
|
||||
if not expiry_info:
|
||||
logger.warning(
|
||||
"Cannot revoke token: invalid or expired", reason=reason
|
||||
)
|
||||
return False
|
||||
|
||||
token_hash = self._get_token_hash(token)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
blacklist_data = {
|
||||
"token_hash": token_hash,
|
||||
"reason": reason,
|
||||
"revoked_at": now.isoformat(),
|
||||
"revoked_by": revoked_by,
|
||||
"expires_at": expiry_info["expires_at"],
|
||||
"username": expiry_info.get("username"),
|
||||
"token_type": expiry_info.get("token_type", "access"),
|
||||
}
|
||||
|
||||
expires_at = datetime.fromisoformat(expiry_info["expires_at"])
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
ttl_seconds = int((expires_at - now).total_seconds())
|
||||
|
||||
if ttl_seconds <= 0:
|
||||
logger.debug(
|
||||
"Token already expired, no need to blacklist",
|
||||
username=expiry_info.get("username"),
|
||||
)
|
||||
return True
|
||||
|
||||
self.redis_client.setex(
|
||||
f"{BLACKLIST_KEY_PREFIX}{token_hash}",
|
||||
ttl_seconds,
|
||||
json.dumps(blacklist_data),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Token revoked successfully",
|
||||
token_hash=token_hash[:TOKEN_HASH_PREVIEW_LENGTH] + "...",
|
||||
username=expiry_info.get("username"),
|
||||
reason=reason,
|
||||
revoked_by=revoked_by,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke token", error=str(e), reason=reason)
|
||||
return False
|
||||
|
||||
def is_token_revoked(self, token: str) -> bool:
|
||||
"""Check if token is revoked.
|
||||
|
||||
Args:
|
||||
token: JWT token to check
|
||||
|
||||
Returns:
|
||||
True if token is revoked
|
||||
"""
|
||||
try:
|
||||
token_hash = self._get_token_hash(token)
|
||||
blacklist_data = self.redis_client.get(f"{BLACKLIST_KEY_PREFIX}{token_hash}")
|
||||
|
||||
if blacklist_data:
|
||||
data = json.loads(blacklist_data)
|
||||
logger.debug(
|
||||
"Token is revoked",
|
||||
token_hash=token_hash[:TOKEN_HASH_PREVIEW_LENGTH] + "...",
|
||||
reason=data.get("reason"),
|
||||
revoked_at=data.get("revoked_at"),
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to check token revocation status", error=str(e)
|
||||
)
|
||||
return False
|
||||
|
||||
def revoke_user_tokens(
|
||||
self,
|
||||
username: str,
|
||||
reason: str = DEFAULT_FORCE_LOGOUT_REASON,
|
||||
revoked_by: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Revoke all user tokens.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
reason: Revocation reason
|
||||
revoked_by: Who revoked the tokens
|
||||
|
||||
Returns:
|
||||
Number of revoked tokens
|
||||
"""
|
||||
try:
|
||||
session = session_storage.get_session_by_username(username)
|
||||
|
||||
if not session:
|
||||
logger.debug(
|
||||
"No active session found for user", username=username
|
||||
)
|
||||
return 0
|
||||
|
||||
session_storage.delete_user_sessions(username)
|
||||
|
||||
logger.info(
|
||||
"User tokens revoked",
|
||||
username=username,
|
||||
reason=reason,
|
||||
revoked_by=revoked_by,
|
||||
)
|
||||
|
||||
return 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to revoke user tokens", username=username, error=str(e)
|
||||
)
|
||||
return 0
|
||||
|
||||
def get_blacklist_stats(self) -> Dict[str, Any]:
|
||||
"""Get blacklist statistics.
|
||||
|
||||
Returns:
|
||||
Statistics about revoked tokens
|
||||
"""
|
||||
try:
|
||||
blacklist_keys = self.redis_client.keys(f"{BLACKLIST_KEY_PREFIX}*")
|
||||
|
||||
reasons_count: Dict[str, int] = {}
|
||||
for key in blacklist_keys:
|
||||
data = self.redis_client.get(key)
|
||||
if data:
|
||||
blacklist_data = json.loads(data)
|
||||
reason = blacklist_data.get("reason", "unknown")
|
||||
reasons_count[reason] = reasons_count.get(reason, 0) + 1
|
||||
|
||||
return {
|
||||
"revoked_tokens": len(blacklist_keys),
|
||||
"reasons": reasons_count,
|
||||
"redis_memory_usage": (
|
||||
self.redis_client.memory_usage(f"{BLACKLIST_KEY_PREFIX}*")
|
||||
if blacklist_keys
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get blacklist stats", error=str(e))
|
||||
return {
|
||||
"revoked_tokens": 0,
|
||||
"reasons": {},
|
||||
"redis_memory_usage": 0,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def cleanup_expired_blacklist(self) -> int:
|
||||
"""Clean up expired blacklist entries.
|
||||
|
||||
Returns:
|
||||
Number of cleaned entries (always 0, Redis handles this automatically)
|
||||
"""
|
||||
logger.debug(
|
||||
"Expired blacklist cleanup completed (Redis TTL handles this automatically)"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
# Global instance for use in API
|
||||
token_blacklist = TokenBlacklist()
|
||||
362
guacamole_test_11_26/api/core/utils.py
Executable file
362
guacamole_test_11_26/api/core/utils.py
Executable file
@ -0,0 +1,362 @@
|
||||
"""Utilities for JWT token and session storage operations."""
|
||||
|
||||
# Standard library imports
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# Third-party imports
|
||||
import jwt
|
||||
import structlog
|
||||
|
||||
# Local imports
|
||||
from .session_storage import session_storage
|
||||
from .token_blacklist import token_blacklist
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# JWT configuration from environment variables
|
||||
JWT_SECRET_KEY = os.getenv(
|
||||
"JWT_SECRET_KEY",
|
||||
"your_super_secret_jwt_key_minimum_32_characters_long",
|
||||
)
|
||||
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES = int(
|
||||
os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "60")
|
||||
)
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS = int(
|
||||
os.getenv("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "7")
|
||||
)
|
||||
|
||||
|
||||
def create_jwt_token(
|
||||
user_info: Dict[str, Any], session_id: str, token_type: str = "access"
|
||||
) -> str:
|
||||
"""Create JWT token with session_id instead of Guacamole token.
|
||||
|
||||
Args:
|
||||
user_info: User information dictionary
|
||||
session_id: Session ID in Redis
|
||||
token_type: Token type ("access" or "refresh")
|
||||
|
||||
Returns:
|
||||
JWT token as string
|
||||
|
||||
Raises:
|
||||
Exception: If token creation fails
|
||||
"""
|
||||
try:
|
||||
if token_type == "refresh":
|
||||
expire_delta = timedelta(days=JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
else:
|
||||
expire_delta = timedelta(
|
||||
minutes=JWT_ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"username": user_info["username"],
|
||||
"role": user_info["role"],
|
||||
"permissions": user_info.get("permissions", []),
|
||||
"session_id": session_id,
|
||||
"token_type": token_type,
|
||||
"exp": now + expire_delta,
|
||||
"iat": now,
|
||||
"iss": "remote-access-api",
|
||||
}
|
||||
|
||||
optional_fields = [
|
||||
"full_name",
|
||||
"email",
|
||||
"organization",
|
||||
"organizational_role",
|
||||
]
|
||||
for field in optional_fields:
|
||||
if field in user_info:
|
||||
payload[field] = user_info[field]
|
||||
|
||||
token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
|
||||
logger.info(
|
||||
"JWT token created successfully",
|
||||
username=user_info["username"],
|
||||
token_type=token_type,
|
||||
session_id=session_id,
|
||||
expires_in_minutes=expire_delta.total_seconds() / 60,
|
||||
payload_keys=list(payload.keys()),
|
||||
token_prefix=token[:30] + "...",
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create JWT token",
|
||||
username=user_info.get("username", "unknown"),
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def verify_jwt_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify and decode JWT token with blacklist check.
|
||||
|
||||
Args:
|
||||
token: JWT token to verify
|
||||
|
||||
Returns:
|
||||
Decoded payload or None if token is invalid
|
||||
"""
|
||||
try:
|
||||
logger.debug("Starting JWT verification", token_prefix=token[:30] + "...")
|
||||
|
||||
if token_blacklist.is_token_revoked(token):
|
||||
logger.info("JWT token is revoked", token_prefix=token[:20] + "...")
|
||||
return None
|
||||
|
||||
logger.debug("Token not in blacklist, attempting decode")
|
||||
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
logger.info(
|
||||
"JWT decode successful",
|
||||
username=payload.get("username"),
|
||||
payload_keys=list(payload.keys()),
|
||||
has_session_id="session_id" in payload,
|
||||
session_id=payload.get("session_id", "NOT_FOUND"),
|
||||
)
|
||||
|
||||
required_fields = ["username", "role", "session_id", "exp", "iat"]
|
||||
for field in required_fields:
|
||||
if field not in payload:
|
||||
logger.warning(
|
||||
"JWT token missing required field",
|
||||
field=field,
|
||||
username=payload.get("username", "unknown"),
|
||||
available_fields=list(payload.keys()),
|
||||
)
|
||||
return None
|
||||
|
||||
logger.debug("All required fields present")
|
||||
|
||||
exp_timestamp = payload["exp"]
|
||||
current_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
if current_timestamp > exp_timestamp:
|
||||
logger.info(
|
||||
"JWT token expired",
|
||||
username=payload["username"],
|
||||
expired_at=datetime.fromtimestamp(
|
||||
exp_timestamp, tz=timezone.utc
|
||||
).isoformat(),
|
||||
current_time=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
"Token not expired, checking Redis session",
|
||||
session_id=payload["session_id"],
|
||||
)
|
||||
|
||||
session_data = session_storage.get_session(payload["session_id"])
|
||||
if not session_data:
|
||||
logger.warning(
|
||||
"Session not found for JWT token",
|
||||
username=payload["username"],
|
||||
session_id=payload["session_id"],
|
||||
possible_reasons=[
|
||||
"session expired in Redis",
|
||||
"session never created",
|
||||
"Redis connection issue",
|
||||
],
|
||||
)
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
"Session found in Redis",
|
||||
username=payload["username"],
|
||||
session_id=payload["session_id"],
|
||||
session_keys=list(session_data.keys()),
|
||||
)
|
||||
|
||||
if "guac_token" not in session_data:
|
||||
logger.error(
|
||||
"Session exists but missing guac_token",
|
||||
username=payload["username"],
|
||||
session_id=payload["session_id"],
|
||||
session_keys=list(session_data.keys()),
|
||||
)
|
||||
return None
|
||||
|
||||
payload["guac_token"] = session_data["guac_token"]
|
||||
|
||||
logger.info(
|
||||
"JWT token verified successfully",
|
||||
username=payload["username"],
|
||||
role=payload["role"],
|
||||
token_type=payload.get("token_type", "access"),
|
||||
session_id=payload["session_id"],
|
||||
guac_token_length=len(session_data["guac_token"]),
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.info(
|
||||
"JWT token expired (ExpiredSignatureError)",
|
||||
token_prefix=token[:20] + "...",
|
||||
)
|
||||
return None
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.warning(
|
||||
"Invalid JWT token (InvalidTokenError)",
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
token_prefix=token[:20] + "...",
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error verifying JWT token",
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
user_info: Dict[str, Any], session_id: str
|
||||
) -> str:
|
||||
"""Create refresh token.
|
||||
|
||||
Args:
|
||||
user_info: User information dictionary
|
||||
session_id: Session ID in Redis
|
||||
|
||||
Returns:
|
||||
Refresh token
|
||||
"""
|
||||
return create_jwt_token(user_info, session_id, token_type="refresh")
|
||||
|
||||
|
||||
def extract_token_from_header(
|
||||
authorization_header: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""Extract token from Authorization header.
|
||||
|
||||
Args:
|
||||
authorization_header: Authorization header value
|
||||
|
||||
Returns:
|
||||
JWT token or None
|
||||
"""
|
||||
if not authorization_header:
|
||||
return None
|
||||
|
||||
if not authorization_header.startswith("Bearer "):
|
||||
return None
|
||||
|
||||
return authorization_header.split(" ", 1)[1]
|
||||
|
||||
|
||||
def get_token_expiry_info(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get token expiration information.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Expiration information or None
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
exp_timestamp = payload.get("exp")
|
||||
iat_timestamp = payload.get("iat")
|
||||
|
||||
if not exp_timestamp:
|
||||
return None
|
||||
|
||||
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
|
||||
iat_datetime = (
|
||||
datetime.fromtimestamp(iat_timestamp, tz=timezone.utc)
|
||||
if iat_timestamp
|
||||
else None
|
||||
)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
return {
|
||||
"expires_at": exp_datetime.isoformat(),
|
||||
"issued_at": iat_datetime.isoformat() if iat_datetime else None,
|
||||
"expires_in_seconds": max(
|
||||
0, int((exp_datetime - current_time).total_seconds())
|
||||
),
|
||||
"is_expired": current_time > exp_datetime,
|
||||
"username": payload.get("username"),
|
||||
"token_type": payload.get("token_type", "access"),
|
||||
"session_id": payload.get("session_id"),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token expiry info", error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
def is_token_expired(token: str) -> bool:
|
||||
"""Check if token is expired.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
True if token is expired, False if valid
|
||||
"""
|
||||
expiry_info = get_token_expiry_info(token)
|
||||
return expiry_info["is_expired"] if expiry_info else True
|
||||
|
||||
|
||||
def revoke_token(
|
||||
token: str,
|
||||
reason: str = "logout",
|
||||
revoked_by: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Revoke token (add to blacklist).
|
||||
|
||||
Args:
|
||||
token: JWT token to revoke
|
||||
reason: Revocation reason
|
||||
revoked_by: Who revoked the token
|
||||
|
||||
Returns:
|
||||
True if token successfully revoked
|
||||
"""
|
||||
try:
|
||||
return token_blacklist.revoke_token(token, reason, revoked_by)
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke token", error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def revoke_user_tokens(
|
||||
username: str,
|
||||
reason: str = "force_logout",
|
||||
revoked_by: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Revoke all user tokens.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
reason: Revocation reason
|
||||
revoked_by: Who revoked the tokens
|
||||
|
||||
Returns:
|
||||
Number of revoked tokens
|
||||
"""
|
||||
try:
|
||||
return token_blacklist.revoke_user_tokens(
|
||||
username, reason, revoked_by
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to revoke user tokens", username=username, error=str(e)
|
||||
)
|
||||
return 0
|
||||
326
guacamole_test_11_26/api/core/websocket_manager.py
Executable file
326
guacamole_test_11_26/api/core/websocket_manager.py
Executable file
@ -0,0 +1,326 @@
|
||||
"""WebSocket Manager for real-time client notifications."""
|
||||
|
||||
# Standard library imports
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
# Third-party imports
|
||||
from fastapi import WebSocket
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""WebSocket connection manager for sending notifications to clients.
|
||||
|
||||
Supported events:
|
||||
- connection_expired: Connection expired
|
||||
- connection_deleted: Connection deleted manually
|
||||
- connection_will_expire: Connection will expire soon (5 min warning)
|
||||
- jwt_will_expire: JWT token will expire soon (5 min warning)
|
||||
- jwt_expired: JWT token expired
|
||||
- connection_extended: Connection TTL extended
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize WebSocket manager."""
|
||||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket, username: str) -> None:
|
||||
"""Connect a new client.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection (already accepted)
|
||||
username: Username
|
||||
"""
|
||||
async with self._lock:
|
||||
if username not in self.active_connections:
|
||||
self.active_connections[username] = set()
|
||||
self.active_connections[username].add(websocket)
|
||||
|
||||
logger.info(
|
||||
"WebSocket client connected",
|
||||
username=username,
|
||||
total_connections=len(
|
||||
self.active_connections.get(username, set())
|
||||
),
|
||||
)
|
||||
|
||||
async def disconnect(self, websocket: WebSocket, username: str) -> None:
|
||||
"""Disconnect a client.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
username: Username
|
||||
"""
|
||||
async with self._lock:
|
||||
if username in self.active_connections:
|
||||
self.active_connections[username].discard(websocket)
|
||||
|
||||
if not self.active_connections[username]:
|
||||
del self.active_connections[username]
|
||||
|
||||
logger.info(
|
||||
"WebSocket client disconnected",
|
||||
username=username,
|
||||
remaining_connections=len(
|
||||
self.active_connections.get(username, set())
|
||||
),
|
||||
)
|
||||
|
||||
async def send_to_user(
|
||||
self, username: str, message: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Send message to all WebSocket connections of a user.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
message: Dictionary with data to send
|
||||
"""
|
||||
if username not in self.active_connections:
|
||||
logger.debug(
|
||||
"No active WebSocket connections for user", username=username
|
||||
)
|
||||
return
|
||||
|
||||
connections = self.active_connections[username].copy()
|
||||
|
||||
disconnected = []
|
||||
for websocket in connections:
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
logger.debug(
|
||||
"Message sent via WebSocket",
|
||||
username=username,
|
||||
event_type=message.get("type"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to send WebSocket message",
|
||||
username=username,
|
||||
error=str(e),
|
||||
)
|
||||
disconnected.append(websocket)
|
||||
|
||||
if disconnected:
|
||||
async with self._lock:
|
||||
for ws in disconnected:
|
||||
self.active_connections[username].discard(ws)
|
||||
|
||||
if not self.active_connections[username]:
|
||||
del self.active_connections[username]
|
||||
|
||||
async def send_connection_expired(
|
||||
self,
|
||||
username: str,
|
||||
connection_id: str,
|
||||
hostname: str,
|
||||
protocol: str,
|
||||
) -> None:
|
||||
"""Notify about connection expiration.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
connection_id: Connection ID
|
||||
hostname: Machine hostname
|
||||
protocol: Connection protocol
|
||||
"""
|
||||
message = {
|
||||
"type": "connection_expired",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"connection_id": connection_id,
|
||||
"hostname": hostname,
|
||||
"protocol": protocol,
|
||||
"reason": "TTL expired",
|
||||
},
|
||||
}
|
||||
await self.send_to_user(username, message)
|
||||
|
||||
logger.info(
|
||||
"Connection expired notification sent",
|
||||
username=username,
|
||||
connection_id=connection_id,
|
||||
hostname=hostname,
|
||||
)
|
||||
|
||||
async def send_connection_deleted(
|
||||
self,
|
||||
username: str,
|
||||
connection_id: str,
|
||||
hostname: str,
|
||||
protocol: str,
|
||||
reason: str = "manual",
|
||||
) -> None:
|
||||
"""Notify about connection deletion.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
connection_id: Connection ID
|
||||
hostname: Machine hostname
|
||||
protocol: Connection protocol
|
||||
reason: Deletion reason (manual, expired, error)
|
||||
"""
|
||||
message = {
|
||||
"type": "connection_deleted",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"connection_id": connection_id,
|
||||
"hostname": hostname,
|
||||
"protocol": protocol,
|
||||
"reason": reason,
|
||||
},
|
||||
}
|
||||
await self.send_to_user(username, message)
|
||||
|
||||
logger.info(
|
||||
"Connection deleted notification sent",
|
||||
username=username,
|
||||
connection_id=connection_id,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
async def send_connection_will_expire(
|
||||
self,
|
||||
username: str,
|
||||
connection_id: str,
|
||||
hostname: str,
|
||||
protocol: str,
|
||||
minutes_remaining: int,
|
||||
) -> None:
|
||||
"""Warn about upcoming connection expiration.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
connection_id: Connection ID
|
||||
hostname: Machine hostname
|
||||
protocol: Connection protocol
|
||||
minutes_remaining: Minutes until expiration
|
||||
"""
|
||||
message = {
|
||||
"type": "connection_will_expire",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"connection_id": connection_id,
|
||||
"hostname": hostname,
|
||||
"protocol": protocol,
|
||||
"minutes_remaining": minutes_remaining,
|
||||
},
|
||||
}
|
||||
await self.send_to_user(username, message)
|
||||
|
||||
logger.info(
|
||||
"Connection expiration warning sent",
|
||||
username=username,
|
||||
connection_id=connection_id,
|
||||
minutes_remaining=minutes_remaining,
|
||||
)
|
||||
|
||||
async def send_jwt_will_expire(
|
||||
self, username: str, minutes_remaining: int
|
||||
) -> None:
|
||||
"""Warn about upcoming JWT token expiration.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
minutes_remaining: Minutes until expiration
|
||||
"""
|
||||
message = {
|
||||
"type": "jwt_will_expire",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"minutes_remaining": minutes_remaining,
|
||||
"action_required": "Please refresh your token or re-login",
|
||||
},
|
||||
}
|
||||
await self.send_to_user(username, message)
|
||||
|
||||
logger.info(
|
||||
"JWT expiration warning sent",
|
||||
username=username,
|
||||
minutes_remaining=minutes_remaining,
|
||||
)
|
||||
|
||||
async def send_jwt_expired(self, username: str) -> None:
|
||||
"""Notify about JWT token expiration.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
"""
|
||||
message = {
|
||||
"type": "jwt_expired",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {"action_required": "Please re-login"},
|
||||
}
|
||||
await self.send_to_user(username, message)
|
||||
|
||||
logger.info("JWT expired notification sent", username=username)
|
||||
|
||||
async def send_connection_extended(
|
||||
self,
|
||||
username: str,
|
||||
connection_id: str,
|
||||
hostname: str,
|
||||
new_expires_at: datetime,
|
||||
additional_minutes: int,
|
||||
) -> None:
|
||||
"""Notify about connection extension.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
connection_id: Connection ID
|
||||
hostname: Machine hostname
|
||||
new_expires_at: New expiration time
|
||||
additional_minutes: Minutes added
|
||||
"""
|
||||
message = {
|
||||
"type": "connection_extended",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"data": {
|
||||
"connection_id": connection_id,
|
||||
"hostname": hostname,
|
||||
"new_expires_at": new_expires_at.isoformat(),
|
||||
"additional_minutes": additional_minutes,
|
||||
},
|
||||
}
|
||||
await self.send_to_user(username, message)
|
||||
|
||||
logger.info(
|
||||
"Connection extension notification sent",
|
||||
username=username,
|
||||
connection_id=connection_id,
|
||||
additional_minutes=additional_minutes,
|
||||
)
|
||||
|
||||
def get_active_users(self) -> List[str]:
|
||||
"""Get list of users with active WebSocket connections.
|
||||
|
||||
Returns:
|
||||
List of usernames
|
||||
"""
|
||||
return list(self.active_connections.keys())
|
||||
|
||||
def get_connection_count(self, username: Optional[str] = None) -> int:
|
||||
"""Get count of active WebSocket connections.
|
||||
|
||||
Args:
|
||||
username: Username (if None, returns total count)
|
||||
|
||||
Returns:
|
||||
Number of connections
|
||||
"""
|
||||
if username:
|
||||
return len(self.active_connections.get(username, []))
|
||||
|
||||
return sum(
|
||||
len(connections)
|
||||
for connections in self.active_connections.values()
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
websocket_manager = WebSocketManager()
|
||||
|
||||
129
guacamole_test_11_26/api/get_signing_key.py
Executable file
129
guacamole_test_11_26/api/get_signing_key.py
Executable file
@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Utility to retrieve Ed25519 signing public key for client configuration.
|
||||
|
||||
This script outputs the public key in base64 format for adding to
|
||||
SignatureVerificationService.ts on the client side.
|
||||
|
||||
Usage:
|
||||
python get_signing_key.py
|
||||
"""
|
||||
|
||||
# Standard library imports
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from typing import Tuple
|
||||
|
||||
# Third-party imports
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
|
||||
def get_signing_public_key() -> Tuple[str, str]:
|
||||
"""Read signing public key from file.
|
||||
|
||||
Returns:
|
||||
Tuple of (PEM format string, base64 encoded string).
|
||||
|
||||
Raises:
|
||||
SystemExit: If key file not found or failed to load.
|
||||
"""
|
||||
key_file = os.getenv(
|
||||
"ED25519_SIGNING_KEY_PATH", "/app/secrets/ed25519_signing_key.pem"
|
||||
)
|
||||
|
||||
if not os.path.exists(key_file):
|
||||
print(
|
||||
f"ERROR: Signing key file not found: {key_file}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print("", file=sys.stderr)
|
||||
print("SOLUTION:", file=sys.stderr)
|
||||
print(
|
||||
"1. Start the API server first to generate the key:",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
" docker-compose up remote_access_api",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
"2. Or run this script inside the container:",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
" docker-compose exec remote_access_api python get_signing_key.py",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
with open(key_file, "rb") as f:
|
||||
private_key_pem = f.read()
|
||||
|
||||
private_key = serialization.load_pem_private_key(
|
||||
private_key_pem, password=None, backend=default_backend()
|
||||
)
|
||||
|
||||
public_key = private_key.public_key()
|
||||
|
||||
public_key_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
public_key_b64 = base64.b64encode(public_key_pem).decode("utf-8")
|
||||
|
||||
return public_key_pem.decode("utf-8"), public_key_b64
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"ERROR: Failed to load signing key: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main function to display signing public key."""
|
||||
print("=" * 80)
|
||||
print("Ed25519 Signing Public Key for Client Configuration")
|
||||
print("=" * 80)
|
||||
print("")
|
||||
|
||||
pem, base64_encoded = get_signing_public_key()
|
||||
|
||||
print("PEM Format:")
|
||||
print(pem)
|
||||
|
||||
print("Base64 Encoded (for client configuration):")
|
||||
print(base64_encoded)
|
||||
print("")
|
||||
|
||||
print("=" * 80)
|
||||
print("How to use:")
|
||||
print("=" * 80)
|
||||
print("")
|
||||
print("1. Copy the Base64 encoded key above")
|
||||
print("")
|
||||
print(
|
||||
"2. Update MachineControlCenter/src/renderer/services/SignatureVerificationService.ts:"
|
||||
)
|
||||
print("")
|
||||
print(" const TRUSTED_SIGNING_KEYS: Record<Environment, string> = {")
|
||||
print(f" production: '{base64_encoded}',")
|
||||
print(f" development: '{base64_encoded}',")
|
||||
print(f" local: '{base64_encoded}'")
|
||||
print(" };")
|
||||
print("")
|
||||
print("3. Rebuild the client application:")
|
||||
print(" cd MachineControlCenter")
|
||||
print(" npm run build")
|
||||
print("")
|
||||
print("=" * 80)
|
||||
print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
2903
guacamole_test_11_26/api/main.py
Executable file
2903
guacamole_test_11_26/api/main.py
Executable file
File diff suppressed because it is too large
Load Diff
13
guacamole_test_11_26/api/requirements.txt
Executable file
13
guacamole_test_11_26/api/requirements.txt
Executable file
@ -0,0 +1,13 @@
|
||||
fastapi==0.115.12
|
||||
uvicorn[standard]==0.32.1
|
||||
requests==2.32.3
|
||||
pydantic==2.5.0
|
||||
python-multipart==0.0.6
|
||||
structlog==23.2.0
|
||||
psutil==5.9.6
|
||||
python-dotenv==1.0.0
|
||||
PyJWT==2.8.0
|
||||
cryptography==43.0.3
|
||||
redis==5.0.1
|
||||
psycopg2-binary==2.9.9
|
||||
paramiko==3.4.0
|
||||
477
guacamole_test_11_26/api/routers.py
Executable file
477
guacamole_test_11_26/api/routers.py
Executable file
@ -0,0 +1,477 @@
|
||||
"""Bulk operations router for mass machine operations."""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List
|
||||
from uuid import UUID
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from core.middleware import get_current_user
|
||||
from core.models import (
|
||||
BulkHealthCheckRequest,
|
||||
BulkHealthCheckResponse,
|
||||
BulkHealthCheckResult,
|
||||
BulkSSHCommandRequest,
|
||||
BulkSSHCommandResponse,
|
||||
BulkSSHCommandResult,
|
||||
UserRole,
|
||||
)
|
||||
from core.permissions import PermissionChecker
|
||||
from core.saved_machines_db import saved_machines_db
|
||||
from core.audit_logger import immutable_audit_logger
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
security = HTTPBearer()
|
||||
|
||||
bulk_router = APIRouter(prefix="/api/bulk", tags=["Bulk Operations"])
|
||||
|
||||
|
||||
ROLE_HEALTH_CHECK_LIMITS = {
|
||||
UserRole.GUEST: 10,
|
||||
UserRole.USER: 50,
|
||||
UserRole.ADMIN: 200,
|
||||
UserRole.SUPER_ADMIN: 200,
|
||||
}
|
||||
|
||||
ROLE_SSH_COMMAND_LIMITS = {
|
||||
UserRole.GUEST: 0,
|
||||
UserRole.USER: 20,
|
||||
UserRole.ADMIN: 100,
|
||||
UserRole.SUPER_ADMIN: 100,
|
||||
}
|
||||
|
||||
|
||||
async def check_host_availability(
|
||||
hostname: str, port: int = 22, timeout: int = 5
|
||||
) -> tuple[bool, float | None, str | None]:
|
||||
"""Check if host is available via TCP connection."""
|
||||
start_time = time.time()
|
||||
try:
|
||||
reader, writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(hostname, port), timeout=timeout
|
||||
)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
return True, response_time, None
|
||||
except asyncio.TimeoutError:
|
||||
return False, None, "Connection timeout"
|
||||
except socket.gaierror:
|
||||
return False, None, "DNS resolution failed"
|
||||
except ConnectionRefusedError:
|
||||
return False, None, "Connection refused"
|
||||
except Exception as e:
|
||||
return False, None, f"Connection error: {str(e)}"
|
||||
|
||||
|
||||
@bulk_router.post(
|
||||
"/health-check",
|
||||
response_model=BulkHealthCheckResponse,
|
||||
summary="Bulk health check",
|
||||
description="Check availability of multiple machines in parallel"
|
||||
)
|
||||
async def bulk_health_check(
|
||||
request_data: BulkHealthCheckRequest,
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
):
|
||||
"""Bulk machine availability check with role-based limits."""
|
||||
user_info = get_current_user(request)
|
||||
if not user_info:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
username = user_info["username"]
|
||||
user_role = UserRole(user_info["role"])
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
max_machines = ROLE_HEALTH_CHECK_LIMITS.get(user_role, 10)
|
||||
machine_count = len(request_data.machine_ids)
|
||||
|
||||
if machine_count > max_machines:
|
||||
logger.warning(
|
||||
"Bulk health check limit exceeded",
|
||||
username=username,
|
||||
role=user_role.value,
|
||||
requested=machine_count,
|
||||
limit=max_machines,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Role {user_role.value} can check max {max_machines} machines at once",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bulk health check started",
|
||||
username=username,
|
||||
machine_count=machine_count,
|
||||
timeout=request_data.timeout,
|
||||
)
|
||||
|
||||
started_at = datetime.now(timezone.utc)
|
||||
start_time = time.time()
|
||||
|
||||
machines = []
|
||||
for machine_id in request_data.machine_ids:
|
||||
# Try to get from saved machines first (UUID format)
|
||||
try:
|
||||
UUID(machine_id)
|
||||
machine_dict = saved_machines_db.get_machine_by_id(machine_id, username)
|
||||
if machine_dict:
|
||||
# Convert dict to object with attributes for uniform access
|
||||
machine = SimpleNamespace(
|
||||
id=machine_dict['id'],
|
||||
name=machine_dict['name'],
|
||||
ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')),
|
||||
hostname=machine_dict.get('hostname', 'unknown'),
|
||||
)
|
||||
machines.append(machine)
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
# Not a UUID
|
||||
pass
|
||||
|
||||
logger.warning(
|
||||
"Machine not found or invalid UUID",
|
||||
username=username,
|
||||
machine_id=machine_id,
|
||||
)
|
||||
|
||||
async def check_machine(machine):
|
||||
checked_at = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
available, response_time, error = await check_host_availability(
|
||||
machine.ip, timeout=request_data.timeout
|
||||
)
|
||||
|
||||
return BulkHealthCheckResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="success" if available else "failed",
|
||||
available=available,
|
||||
response_time_ms=int(response_time) if response_time else None,
|
||||
error=error,
|
||||
checked_at=checked_at,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Health check error", machine_id=str(machine.id), error=str(e)
|
||||
)
|
||||
return BulkHealthCheckResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="failed",
|
||||
available=False,
|
||||
error=str(e),
|
||||
checked_at=checked_at,
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*[check_machine(m) for m in machines])
|
||||
|
||||
completed_at = datetime.now(timezone.utc)
|
||||
execution_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
success_count = sum(1 for r in results if r.status == "success")
|
||||
failed_count = len(results) - success_count
|
||||
available_count = sum(1 for r in results if r.available)
|
||||
unavailable_count = len(results) - available_count
|
||||
|
||||
immutable_audit_logger.log_security_event(
|
||||
event_type="bulk_health_check",
|
||||
client_ip=client_ip,
|
||||
user_agent=request.headers.get("user-agent", "unknown"),
|
||||
details={
|
||||
"machine_count": len(results),
|
||||
"available": available_count,
|
||||
"unavailable": unavailable_count,
|
||||
"execution_time_ms": execution_time_ms,
|
||||
},
|
||||
severity="info",
|
||||
username=username,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bulk health check completed",
|
||||
username=username,
|
||||
total=len(results),
|
||||
available=available_count,
|
||||
execution_time_ms=execution_time_ms,
|
||||
)
|
||||
|
||||
return BulkHealthCheckResponse(
|
||||
total=len(results),
|
||||
success=success_count,
|
||||
failed=failed_count,
|
||||
available=available_count,
|
||||
unavailable=unavailable_count,
|
||||
results=results,
|
||||
execution_time_ms=execution_time_ms,
|
||||
started_at=started_at.isoformat(),
|
||||
completed_at=completed_at.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@bulk_router.post(
|
||||
"/ssh-command",
|
||||
response_model=BulkSSHCommandResponse,
|
||||
summary="Bulk SSH command",
|
||||
description="Execute SSH commands on multiple machines in parallel"
|
||||
)
|
||||
async def bulk_ssh_command(
|
||||
request_data: BulkSSHCommandRequest,
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
):
|
||||
"""Bulk SSH command execution with role-based limits."""
|
||||
user_info = get_current_user(request)
|
||||
if not user_info:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
username = user_info["username"]
|
||||
user_role = UserRole(user_info["role"])
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
if user_role == UserRole.GUEST:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="GUEST role cannot execute SSH commands"
|
||||
)
|
||||
|
||||
max_machines = ROLE_SSH_COMMAND_LIMITS.get(user_role, 0)
|
||||
machine_count = len(request_data.machine_ids)
|
||||
|
||||
if machine_count > max_machines:
|
||||
logger.warning(
|
||||
"Bulk SSH command limit exceeded",
|
||||
username=username,
|
||||
role=user_role.value,
|
||||
requested=machine_count,
|
||||
limit=max_machines,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Role {user_role.value} can execute commands on max {max_machines} machines at once",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bulk SSH command started",
|
||||
username=username,
|
||||
machine_count=machine_count,
|
||||
command=request_data.command[:50],
|
||||
mode=request_data.credentials_mode,
|
||||
)
|
||||
|
||||
started_at = datetime.now(timezone.utc)
|
||||
start_time = time.time()
|
||||
|
||||
machines = []
|
||||
for machine_id in request_data.machine_ids:
|
||||
# Try to get from saved machines first (UUID format)
|
||||
try:
|
||||
UUID(machine_id)
|
||||
machine_dict = saved_machines_db.get_machine_by_id(machine_id, username)
|
||||
if machine_dict:
|
||||
# Convert dict to object with attributes for uniform access
|
||||
machine = SimpleNamespace(
|
||||
id=machine_dict['id'],
|
||||
name=machine_dict['name'],
|
||||
ip=machine_dict.get('hostname', machine_dict.get('ip', 'unknown')),
|
||||
hostname=machine_dict.get('hostname', 'unknown'),
|
||||
)
|
||||
machines.append(machine)
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
# Not a UUID, check if hostname provided
|
||||
pass
|
||||
|
||||
# Check if hostname provided for non-saved machine (mock machines)
|
||||
if request_data.machine_hostnames and machine_id in request_data.machine_hostnames:
|
||||
hostname = request_data.machine_hostnames[machine_id]
|
||||
# Create mock machine object for non-saved machines
|
||||
mock_machine = SimpleNamespace(
|
||||
id=machine_id,
|
||||
name=f'Mock-{machine_id}',
|
||||
ip=hostname,
|
||||
hostname=hostname,
|
||||
)
|
||||
machines.append(mock_machine)
|
||||
logger.info(
|
||||
"Using non-saved machine (mock)",
|
||||
username=username,
|
||||
machine_id=machine_id,
|
||||
hostname=hostname,
|
||||
)
|
||||
continue
|
||||
|
||||
logger.warning(
|
||||
"Machine not found and no hostname provided",
|
||||
username=username,
|
||||
machine_id=machine_id,
|
||||
)
|
||||
|
||||
semaphore = asyncio.Semaphore(10)
|
||||
|
||||
async def execute_command(machine):
|
||||
async with semaphore:
|
||||
executed_at = datetime.now(timezone.utc).isoformat()
|
||||
cmd_start = time.time()
|
||||
|
||||
try:
|
||||
ssh_username = None
|
||||
ssh_password = None
|
||||
|
||||
if request_data.credentials_mode == "global":
|
||||
if not request_data.global_credentials:
|
||||
return BulkSSHCommandResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="no_credentials",
|
||||
error="Global credentials not provided",
|
||||
executed_at=executed_at,
|
||||
)
|
||||
ssh_username = request_data.global_credentials.username
|
||||
ssh_password = request_data.global_credentials.password
|
||||
|
||||
else: # custom mode
|
||||
if not request_data.machine_credentials or str(
|
||||
machine.id
|
||||
) not in request_data.machine_credentials:
|
||||
return BulkSSHCommandResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="no_credentials",
|
||||
error="Custom credentials not provided for this machine",
|
||||
executed_at=executed_at,
|
||||
)
|
||||
creds = request_data.machine_credentials[str(machine.id)]
|
||||
ssh_username = creds.username
|
||||
ssh_password = creds.password
|
||||
|
||||
if not ssh_username or not ssh_password:
|
||||
return BulkSSHCommandResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="no_credentials",
|
||||
error="Credentials missing",
|
||||
executed_at=executed_at,
|
||||
)
|
||||
|
||||
import paramiko
|
||||
|
||||
ssh = paramiko.SSHClient()
|
||||
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: ssh.connect(
|
||||
machine.ip,
|
||||
username=ssh_username,
|
||||
password=ssh_password,
|
||||
timeout=request_data.timeout,
|
||||
),
|
||||
),
|
||||
timeout=request_data.timeout,
|
||||
)
|
||||
|
||||
stdin, stdout, stderr = ssh.exec_command(request_data.command)
|
||||
stdout_text = stdout.read().decode("utf-8", errors="ignore")
|
||||
stderr_text = stderr.read().decode("utf-8", errors="ignore")
|
||||
exit_code = stdout.channel.recv_exit_status()
|
||||
|
||||
ssh.close()
|
||||
|
||||
execution_time = int((time.time() - cmd_start) * 1000)
|
||||
|
||||
return BulkSSHCommandResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="success" if exit_code == 0 else "failed",
|
||||
exit_code=exit_code,
|
||||
stdout=stdout_text[:5000],
|
||||
stderr=stderr_text[:5000],
|
||||
execution_time_ms=execution_time,
|
||||
executed_at=executed_at,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return BulkSSHCommandResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="timeout",
|
||||
error="Command execution timeout",
|
||||
executed_at=executed_at,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"SSH command error",
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
return BulkSSHCommandResult(
|
||||
machine_id=str(machine.id),
|
||||
machine_name=machine.name,
|
||||
hostname=machine.ip,
|
||||
status="failed",
|
||||
error=str(e)[:500],
|
||||
executed_at=executed_at,
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*[execute_command(m) for m in machines])
|
||||
|
||||
completed_at = datetime.now(timezone.utc)
|
||||
execution_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
success_count = sum(1 for r in results if r.status == "success")
|
||||
failed_count = len(results) - success_count
|
||||
|
||||
immutable_audit_logger.log_security_event(
|
||||
event_type="bulk_ssh_command",
|
||||
client_ip=client_ip,
|
||||
user_agent=request.headers.get("user-agent", "unknown"),
|
||||
details={
|
||||
"machine_count": len(results),
|
||||
"command": request_data.command[:100],
|
||||
"credentials_mode": request_data.credentials_mode,
|
||||
"success": success_count,
|
||||
"failed": failed_count,
|
||||
"execution_time_ms": execution_time_ms,
|
||||
},
|
||||
severity="high",
|
||||
username=username,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bulk SSH command completed",
|
||||
username=username,
|
||||
total=len(results),
|
||||
success=success_count,
|
||||
failed=failed_count,
|
||||
execution_time_ms=execution_time_ms,
|
||||
)
|
||||
|
||||
return BulkSSHCommandResponse(
|
||||
total=len(results),
|
||||
success=success_count,
|
||||
failed=failed_count,
|
||||
results=results,
|
||||
execution_time_ms=execution_time_ms,
|
||||
command=request_data.command,
|
||||
started_at=started_at.isoformat(),
|
||||
completed_at=completed_at.isoformat(),
|
||||
)
|
||||
|
||||
143
guacamole_test_11_26/api/security_config.py
Executable file
143
guacamole_test_11_26/api/security_config.py
Executable file
@ -0,0 +1,143 @@
|
||||
"""
|
||||
Security configuration for Remote Access API.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from core.models import UserRole
|
||||
from core.ssrf_protection import ssrf_protection
|
||||
|
||||
|
||||
class SecurityConfig:
|
||||
"""Security configuration for the system."""
|
||||
|
||||
MAX_TTL_MINUTES = int(os.getenv("MAX_TTL_MINUTES", "480"))
|
||||
|
||||
MAX_CONNECTIONS_PER_USER = int(os.getenv("MAX_CONNECTIONS_PER_USER", "5"))
|
||||
|
||||
BLOCKED_HOSTS = {
|
||||
"127.0.0.1",
|
||||
"localhost",
|
||||
"0.0.0.0",
|
||||
"::1",
|
||||
"169.254.169.254",
|
||||
"metadata.google.internal",
|
||||
}
|
||||
|
||||
BLOCKED_NETWORKS = [
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16",
|
||||
"224.0.0.0/4",
|
||||
"240.0.0.0/4",
|
||||
"172.17.0.0/16",
|
||||
"172.18.0.0/16",
|
||||
"172.19.0.0/16",
|
||||
"172.20.0.0/16",
|
||||
"172.21.0.0/16",
|
||||
"172.22.0.0/16",
|
||||
"172.23.0.0/16",
|
||||
"172.24.0.0/16",
|
||||
"172.25.0.0/16",
|
||||
"172.26.0.0/16",
|
||||
"172.27.0.0/16",
|
||||
"172.28.0.0/16",
|
||||
"172.29.0.0/16",
|
||||
"172.30.0.0/16",
|
||||
"172.31.0.0/16",
|
||||
]
|
||||
|
||||
ROLE_ALLOWED_NETWORKS = {
|
||||
UserRole.GUEST: [],
|
||||
UserRole.USER: [
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/16",
|
||||
"192.168.1.0/24",
|
||||
],
|
||||
UserRole.ADMIN: [
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/16",
|
||||
"192.168.0.0/16",
|
||||
"203.0.113.0/24",
|
||||
],
|
||||
UserRole.SUPER_ADMIN: [
|
||||
"0.0.0.0/0",
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def is_host_allowed(
|
||||
cls, hostname: str, user_role: UserRole
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Check if host is allowed for the given role with enhanced SSRF protection.
|
||||
|
||||
Args:
|
||||
hostname: IP address or hostname.
|
||||
user_role: User role.
|
||||
|
||||
Returns:
|
||||
Tuple (allowed: bool, reason: str).
|
||||
"""
|
||||
return ssrf_protection.validate_host(hostname, user_role.value)
|
||||
|
||||
@classmethod
|
||||
def validate_ttl(cls, ttl_minutes: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
Validate connection TTL.
|
||||
|
||||
Args:
|
||||
ttl_minutes: Requested time-to-live in minutes.
|
||||
|
||||
Returns:
|
||||
Tuple (valid: bool, reason: str).
|
||||
"""
|
||||
if ttl_minutes <= 0:
|
||||
return False, "TTL must be positive"
|
||||
|
||||
if ttl_minutes > cls.MAX_TTL_MINUTES:
|
||||
return False, f"TTL cannot exceed {cls.MAX_TTL_MINUTES} minutes"
|
||||
|
||||
return True, "TTL is valid"
|
||||
|
||||
@classmethod
|
||||
def get_role_limits(cls, user_role: UserRole) -> Dict[str, Any]:
|
||||
"""
|
||||
Get limits for a role.
|
||||
|
||||
Args:
|
||||
user_role: User role.
|
||||
|
||||
Returns:
|
||||
Dictionary with limits.
|
||||
"""
|
||||
base_limits = {
|
||||
"max_ttl_minutes": cls.MAX_TTL_MINUTES,
|
||||
"max_connections": cls.MAX_CONNECTIONS_PER_USER,
|
||||
"allowed_networks": cls.ROLE_ALLOWED_NETWORKS.get(user_role, []),
|
||||
"can_create_connections": user_role != UserRole.GUEST,
|
||||
}
|
||||
|
||||
if user_role == UserRole.GUEST:
|
||||
base_limits.update(
|
||||
{
|
||||
"max_connections": 0,
|
||||
"max_ttl_minutes": 0,
|
||||
}
|
||||
)
|
||||
elif user_role == UserRole.USER:
|
||||
base_limits.update(
|
||||
{
|
||||
"max_connections": 3,
|
||||
"max_ttl_minutes": 240,
|
||||
}
|
||||
)
|
||||
elif user_role == UserRole.ADMIN:
|
||||
base_limits.update(
|
||||
{
|
||||
"max_connections": 10,
|
||||
"max_ttl_minutes": 480,
|
||||
}
|
||||
)
|
||||
|
||||
return base_limits
|
||||
5
guacamole_test_11_26/api/services/__init__.py
Executable file
5
guacamole_test_11_26/api/services/__init__.py
Executable file
@ -0,0 +1,5 @@
|
||||
"""Services package for system operations"""
|
||||
from .system_service import SystemService
|
||||
|
||||
__all__ = ['SystemService']
|
||||
|
||||
225
guacamole_test_11_26/api/services/system_service.py
Executable file
225
guacamole_test_11_26/api/services/system_service.py
Executable file
@ -0,0 +1,225 @@
|
||||
"""
|
||||
System Service Module
|
||||
|
||||
Provides system monitoring and health check functionality for the Remote Access API.
|
||||
Includes checks for database connectivity, daemon status, and system resources.
|
||||
"""
|
||||
|
||||
import socket
|
||||
import psutil
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class SystemService:
|
||||
"""Service for system health checks and monitoring"""
|
||||
|
||||
def __init__(self, service_start_time: Optional[datetime] = None):
|
||||
"""
|
||||
Initialize SystemService
|
||||
|
||||
Args:
|
||||
service_start_time: Service startup time for uptime calculation
|
||||
"""
|
||||
self.service_start_time = service_start_time or datetime.now()
|
||||
|
||||
@staticmethod
|
||||
def check_database_connection(guacamole_client: Any, guacamole_url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check Guacamole database connectivity
|
||||
|
||||
Args:
|
||||
guacamole_client: Guacamole client instance
|
||||
guacamole_url: Guacamole base URL
|
||||
|
||||
Returns:
|
||||
Status dictionary with connection state
|
||||
"""
|
||||
try:
|
||||
# Try to get system token (requires database access)
|
||||
token = guacamole_client.get_system_token()
|
||||
|
||||
if token:
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": "Database connection healthy"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Failed to obtain system token"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Database connection check failed", error=str(e))
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Database connection failed"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def check_guacd_daemon() -> Dict[str, Any]:
|
||||
"""
|
||||
Check if guacd daemon is running
|
||||
|
||||
Returns:
|
||||
Status dictionary with daemon state
|
||||
"""
|
||||
try:
|
||||
# Check if guacd is listening on default port 4822
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(2)
|
||||
result = sock.connect_ex(('localhost', 4822))
|
||||
sock.close()
|
||||
|
||||
if result == 0:
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": "guacd daemon is running",
|
||||
"port": 4822
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "guacd daemon is not accessible",
|
||||
"port": 4822
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("guacd daemon check failed", error=str(e))
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to check guacd daemon"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def check_system_resources() -> Dict[str, Any]:
|
||||
"""
|
||||
Check system resources (CPU, RAM, Disk)
|
||||
|
||||
Returns:
|
||||
Status dictionary with resource usage
|
||||
"""
|
||||
try:
|
||||
# CPU usage
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
|
||||
# Memory usage
|
||||
memory = psutil.virtual_memory()
|
||||
memory_percent = memory.percent
|
||||
|
||||
# Disk usage
|
||||
disk = psutil.disk_usage('/')
|
||||
disk_percent = disk.percent
|
||||
|
||||
# Determine overall status based on thresholds
|
||||
status = "ok"
|
||||
warnings = []
|
||||
|
||||
if cpu_percent > 90:
|
||||
status = "critical"
|
||||
warnings.append(f"CPU usage critical: {cpu_percent}%")
|
||||
elif cpu_percent > 80:
|
||||
status = "warning"
|
||||
warnings.append(f"CPU usage high: {cpu_percent}%")
|
||||
|
||||
if memory_percent > 90:
|
||||
status = "critical"
|
||||
warnings.append(f"Memory usage critical: {memory_percent}%")
|
||||
elif memory_percent > 80:
|
||||
if status == "ok":
|
||||
status = "warning"
|
||||
warnings.append(f"Memory usage high: {memory_percent}%")
|
||||
|
||||
if disk_percent > 90:
|
||||
status = "critical"
|
||||
warnings.append(f"Disk usage critical: {disk_percent}%")
|
||||
elif disk_percent > 80:
|
||||
if status == "ok":
|
||||
status = "warning"
|
||||
warnings.append(f"Disk usage high: {disk_percent}%")
|
||||
|
||||
result = {
|
||||
"status": status,
|
||||
"cpu_percent": round(cpu_percent, 2),
|
||||
"memory_percent": round(memory_percent, 2),
|
||||
"disk_percent": round(disk_percent, 2),
|
||||
"memory_available_gb": round(memory.available / (1024**3), 2),
|
||||
"disk_free_gb": round(disk.free / (1024**3), 2)
|
||||
}
|
||||
|
||||
if warnings:
|
||||
result["warnings"] = warnings
|
||||
|
||||
if status == "ok":
|
||||
result["message"] = "System resources healthy"
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("System resources check failed", error=str(e))
|
||||
return {
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": "Failed to check system resources"
|
||||
}
|
||||
|
||||
def get_system_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get system information (uptime, version, etc.)
|
||||
|
||||
Returns:
|
||||
Dictionary with system information
|
||||
"""
|
||||
try:
|
||||
uptime_seconds = int((datetime.now() - self.service_start_time).total_seconds())
|
||||
|
||||
return {
|
||||
"uptime_seconds": uptime_seconds,
|
||||
"uptime_formatted": self._format_uptime(uptime_seconds),
|
||||
"python_version": f"{psutil.PROCFS_PATH if hasattr(psutil, 'PROCFS_PATH') else 'N/A'}",
|
||||
"cpu_count": psutil.cpu_count(),
|
||||
"boot_time": datetime.fromtimestamp(psutil.boot_time()).isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get system info", error=str(e))
|
||||
return {
|
||||
"error": str(e),
|
||||
"message": "Failed to retrieve system information"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_uptime(seconds: int) -> str:
|
||||
"""
|
||||
Format uptime seconds to human-readable string
|
||||
|
||||
Args:
|
||||
seconds: Uptime in seconds
|
||||
|
||||
Returns:
|
||||
Formatted uptime string
|
||||
"""
|
||||
days = seconds // 86400
|
||||
hours = (seconds % 86400) // 3600
|
||||
minutes = (seconds % 3600) // 60
|
||||
secs = seconds % 60
|
||||
|
||||
parts = []
|
||||
if days > 0:
|
||||
parts.append(f"{days}d")
|
||||
if hours > 0:
|
||||
parts.append(f"{hours}h")
|
||||
if minutes > 0:
|
||||
parts.append(f"{minutes}m")
|
||||
if secs > 0 or not parts:
|
||||
parts.append(f"{secs}s")
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user