init_guac

This commit is contained in:
root
2025-11-25 09:58:37 +03:00
parent 68c8f0e80d
commit 9d5bdd57a7
57 changed files with 18272 additions and 0 deletions

View File

@ -0,0 +1,35 @@
"""
Core module for Remote Access API.
Provides:
- Authentication and authorization (JWT, Guacamole integration)
- Security features (CSRF, rate limiting, brute force protection)
- Storage and session management (Redis, PostgreSQL)
- Audit logging and WebSocket notifications
- Role and permission system
"""
from .guacamole_auth import GuacamoleAuthenticator
from .models import (
ConnectionRequest,
ConnectionResponse,
LoginRequest,
LoginResponse,
UserInfo,
UserRole,
)
from .permissions import PermissionChecker
from .utils import create_jwt_token, verify_jwt_token
__all__ = [
"ConnectionRequest",
"ConnectionResponse",
"GuacamoleAuthenticator",
"LoginRequest",
"LoginResponse",
"PermissionChecker",
"UserInfo",
"UserRole",
"create_jwt_token",
"verify_jwt_token",
]

View File

@ -0,0 +1,380 @@
"""
Immutable audit logging with HMAC signatures.
"""
import hashlib
import hmac
import json
import os
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
import structlog
logger = structlog.get_logger(__name__)
class ImmutableAuditLogger:
"""Immutable audit logger with HMAC signatures to prevent log tampering."""
def __init__(self) -> None:
"""Initialize the immutable audit logger."""
self.hmac_secret = os.getenv(
"AUDIT_HMAC_SECRET", "default_audit_secret_change_me"
)
log_path_str = os.getenv(
"AUDIT_LOG_PATH", "/var/log/remote_access_audit.log"
)
self.audit_log_path = Path(log_path_str)
self.audit_log_path.parent.mkdir(parents=True, exist_ok=True)
self.audit_logger = structlog.get_logger("audit")
logger.info(
"Immutable audit logger initialized",
audit_log_path=str(self.audit_log_path),
)
def _generate_hmac_signature(self, data: str) -> str:
"""
Generate HMAC signature for data.
Args:
data: Data to sign.
Returns:
HMAC signature in hex format.
"""
return hmac.new(
self.hmac_secret.encode("utf-8"),
data.encode("utf-8"),
hashlib.sha256,
).hexdigest()
def _verify_hmac_signature(self, data: str, signature: str) -> bool:
"""
Verify HMAC signature.
Args:
data: Data to verify.
signature: Signature to verify.
Returns:
True if signature is valid.
"""
expected_signature = self._generate_hmac_signature(data)
return hmac.compare_digest(expected_signature, signature)
def log_security_event(
self,
event_type: str,
client_ip: str,
user_agent: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
severity: str = "info",
username: Optional[str] = None,
) -> bool:
"""
Log security event with immutable record.
Args:
event_type: Event type.
client_ip: Client IP address.
user_agent: Client user agent.
details: Additional details.
severity: Severity level.
username: Username if applicable.
Returns:
True if logging succeeded.
"""
try:
event_data = {
"event_type": "security_event",
"security_event_type": event_type,
"timestamp": datetime.now(timezone.utc).isoformat(),
"client_ip": client_ip,
"user_agent": user_agent or "unknown",
"severity": severity,
"username": username,
"details": details or {},
}
return self._write_immutable_log(event_data)
except Exception as e:
logger.error("Failed to log security event", error=str(e))
return False
def log_audit_event(
self,
action: str,
resource: str,
client_ip: str,
user_agent: Optional[str] = None,
result: str = "success",
details: Optional[Dict[str, Any]] = None,
username: Optional[str] = None,
) -> bool:
"""
Log audit event with immutable record.
Args:
action: Action performed.
resource: Resource affected.
client_ip: Client IP address.
user_agent: Client user agent.
result: Action result.
details: Additional details.
username: Username.
Returns:
True if logging succeeded.
"""
try:
event_data = {
"event_type": "audit_event",
"action": action,
"resource": resource,
"timestamp": datetime.now(timezone.utc).isoformat(),
"client_ip": client_ip,
"user_agent": user_agent or "unknown",
"result": result,
"username": username,
"details": details or {},
}
return self._write_immutable_log(event_data)
except Exception as e:
logger.error("Failed to log audit event", error=str(e))
return False
def log_authentication_event(
self,
event_type: str,
username: str,
client_ip: str,
success: bool,
details: Optional[Dict[str, Any]] = None,
) -> bool:
"""
Log authentication event.
Args:
event_type: Event type (login, logout, failed_login, etc.).
username: Username.
client_ip: Client IP address.
success: Operation success status.
details: Additional details.
Returns:
True if logging succeeded.
"""
try:
event_data = {
"event_type": "authentication_event",
"auth_event_type": event_type,
"timestamp": datetime.now(timezone.utc).isoformat(),
"username": username,
"client_ip": client_ip,
"success": success,
"details": details or {},
}
return self._write_immutable_log(event_data)
except Exception as e:
logger.error("Failed to log authentication event", error=str(e))
return False
def log_connection_event(
self,
event_type: str,
connection_id: str,
username: str,
client_ip: str,
details: Optional[Dict[str, Any]] = None,
) -> bool:
"""
Log connection event.
Args:
event_type: Event type (created, deleted, expired, etc.).
connection_id: Connection ID.
username: Username.
client_ip: Client IP address.
details: Additional details.
Returns:
True if logging succeeded.
"""
try:
event_data = {
"event_type": "connection_event",
"connection_event_type": event_type,
"timestamp": datetime.now(timezone.utc).isoformat(),
"connection_id": connection_id,
"username": username,
"client_ip": client_ip,
"details": details or {},
}
return self._write_immutable_log(event_data)
except Exception as e:
logger.error("Failed to log connection event", error=str(e))
return False
def _write_immutable_log(self, event_data: Dict[str, Any]) -> bool:
"""
Write immutable log entry with HMAC signature.
Args:
event_data: Event data.
Returns:
True if write succeeded.
"""
try:
json_data = json.dumps(event_data, ensure_ascii=False, sort_keys=True)
signature = self._generate_hmac_signature(json_data)
log_entry = {
"data": event_data,
"signature": signature,
"log_timestamp": datetime.now(timezone.utc).isoformat(),
}
with self.audit_log_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
f.flush()
self.audit_logger.info(
"Audit event logged",
event_type=event_data.get("event_type"),
signature=signature[:16] + "...",
)
return True
except Exception as e:
logger.error("Failed to write immutable log", error=str(e))
return False
def verify_log_integrity(
self, log_file_path: Optional[Union[str, Path]] = None
) -> Dict[str, Any]:
"""
Verify audit log integrity.
Args:
log_file_path: Path to log file (defaults to main log file).
Returns:
Integrity verification result.
"""
try:
file_path = (
Path(log_file_path) if log_file_path else self.audit_log_path
)
if not file_path.exists():
return {
"status": "error",
"message": "Log file does not exist",
"file_path": str(file_path),
}
valid_entries = 0
invalid_entries = 0
total_entries = 0
with file_path.open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
total_entries += 1
try:
log_entry = json.loads(line)
if "data" not in log_entry or "signature" not in log_entry:
invalid_entries += 1
continue
json_data = json.dumps(
log_entry["data"], ensure_ascii=False, sort_keys=True
)
if self._verify_hmac_signature(
json_data, log_entry["signature"]
):
valid_entries += 1
else:
invalid_entries += 1
except (json.JSONDecodeError, KeyError, ValueError):
invalid_entries += 1
return {
"status": "success",
"file_path": str(file_path),
"total_entries": total_entries,
"valid_entries": valid_entries,
"invalid_entries": invalid_entries,
"integrity_percentage": (
(valid_entries / total_entries * 100) if total_entries > 0 else 0
),
}
except Exception as e:
logger.error("Failed to verify log integrity", error=str(e))
return {
"status": "error",
"message": str(e),
"file_path": str(log_file_path or self.audit_log_path),
}
def get_audit_stats(self) -> Dict[str, Any]:
"""
Get audit log statistics.
Returns:
Audit log statistics.
"""
try:
if not self.audit_log_path.exists():
return {
"status": "no_log_file",
"file_path": str(self.audit_log_path),
}
file_size = self.audit_log_path.stat().st_size
event_types: Counter[str] = Counter()
total_entries = 0
with self.audit_log_path.open("r", encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
try:
log_entry = json.loads(line)
if (
"data" in log_entry
and "event_type" in log_entry["data"]
):
event_type = log_entry["data"]["event_type"]
event_types[event_type] += 1
total_entries += 1
except (json.JSONDecodeError, KeyError):
continue
return {
"status": "success",
"file_path": str(self.audit_log_path),
"file_size_bytes": file_size,
"total_entries": total_entries,
"event_types": dict(event_types),
"hmac_secret_configured": bool(
self.hmac_secret
and self.hmac_secret != "default_audit_secret_change_me"
),
}
except Exception as e:
logger.error("Failed to get audit stats", error=str(e))
return {"status": "error", "message": str(e)}
# Global instance for use in API
immutable_audit_logger = ImmutableAuditLogger()

View File

@ -0,0 +1,327 @@
"""Brute-force protection for login endpoint."""
from typing import Any, Dict, Tuple
import structlog
from .rate_limiter import redis_rate_limiter
logger = structlog.get_logger(__name__)
# Backoff constants
MAX_BACKOFF_SECONDS = 300
MIN_FAILED_ATTEMPTS_FOR_BACKOFF = 2
EXPONENTIAL_BACKOFF_BASE = 2
# Default limits
DEFAULT_MAX_LOGIN_ATTEMPTS_PER_IP = 5
DEFAULT_MAX_LOGIN_ATTEMPTS_PER_USER = 10
DEFAULT_LOGIN_WINDOW_MINUTES = 15
DEFAULT_USER_LOCKOUT_MINUTES = 60
# Block types
BLOCK_TYPE_RATE_LIMIT = "rate_limit"
BLOCK_TYPE_IP_BLOCKED = "ip_blocked"
BLOCK_TYPE_USER_LOCKED = "user_locked"
BLOCK_TYPE_EXPONENTIAL_BACKOFF = "exponential_backoff"
BLOCK_TYPE_ALLOWED = "allowed"
BLOCK_TYPE_ERROR_FALLBACK = "error_fallback"
# Response messages
MSG_RATE_LIMIT_EXCEEDED = "Rate limit exceeded"
MSG_IP_BLOCKED = "Too many failed attempts from this IP"
MSG_USER_LOCKED = "User account temporarily locked"
MSG_LOGIN_ALLOWED = "Login allowed"
MSG_LOGIN_ALLOWED_ERROR = "Login allowed (protection error)"
# Default failure reason
DEFAULT_FAILURE_REASON = "invalid_credentials"
# Empty string for clearing
EMPTY_USERNAME = ""
EMPTY_IP = ""
class BruteForceProtection:
"""Protection against brute-force attacks on login endpoint."""
def __init__(self) -> None:
"""Initialize brute-force protection."""
self.max_login_attempts_per_ip = DEFAULT_MAX_LOGIN_ATTEMPTS_PER_IP
self.max_login_attempts_per_user = DEFAULT_MAX_LOGIN_ATTEMPTS_PER_USER
self.login_window_minutes = DEFAULT_LOGIN_WINDOW_MINUTES
self.user_lockout_minutes = DEFAULT_USER_LOCKOUT_MINUTES
self.exponential_backoff_base = EXPONENTIAL_BACKOFF_BASE
def check_login_allowed(
self, client_ip: str, username: str
) -> Tuple[bool, str, Dict[str, Any]]:
"""
Check if login is allowed for given IP and user.
Args:
client_ip: Client IP address.
username: Username.
Returns:
Tuple of (allowed: bool, reason: str, details: Dict[str, Any]).
"""
try:
allowed, headers = redis_rate_limiter.check_login_rate_limit(
client_ip, username
)
if not allowed:
return (
False,
MSG_RATE_LIMIT_EXCEEDED,
{
"type": BLOCK_TYPE_RATE_LIMIT,
"client_ip": client_ip,
"username": username,
"headers": headers,
},
)
failed_counts = redis_rate_limiter.get_failed_login_count(
client_ip, username, self.user_lockout_minutes
)
if failed_counts["ip_failed_count"] >= self.max_login_attempts_per_ip:
return (
False,
MSG_IP_BLOCKED,
{
"type": BLOCK_TYPE_IP_BLOCKED,
"client_ip": client_ip,
"failed_count": failed_counts["ip_failed_count"],
"max_attempts": self.max_login_attempts_per_ip,
"window_minutes": self.login_window_minutes,
},
)
if failed_counts["user_failed_count"] >= self.max_login_attempts_per_user:
return (
False,
MSG_USER_LOCKED,
{
"type": BLOCK_TYPE_USER_LOCKED,
"username": username,
"failed_count": failed_counts["user_failed_count"],
"max_attempts": self.max_login_attempts_per_user,
"lockout_minutes": self.user_lockout_minutes,
},
)
backoff_seconds = self._calculate_backoff_time(
client_ip, username, failed_counts
)
if backoff_seconds > 0:
return (
False,
f"Please wait {backoff_seconds} seconds before next attempt",
{
"type": BLOCK_TYPE_EXPONENTIAL_BACKOFF,
"wait_seconds": backoff_seconds,
"client_ip": client_ip,
"username": username,
},
)
return (
True,
MSG_LOGIN_ALLOWED,
{
"type": BLOCK_TYPE_ALLOWED,
"client_ip": client_ip,
"username": username,
"failed_counts": failed_counts,
},
)
except Exception as e:
logger.error(
"Error checking login permission",
client_ip=client_ip,
username=username,
error=str(e),
)
return (
True,
MSG_LOGIN_ALLOWED_ERROR,
{"type": BLOCK_TYPE_ERROR_FALLBACK, "error": str(e)},
)
def record_failed_login(
self,
client_ip: str,
username: str,
failure_reason: str = DEFAULT_FAILURE_REASON,
) -> None:
"""
Record failed login attempt.
Args:
client_ip: Client IP address.
username: Username.
failure_reason: Failure reason.
"""
try:
redis_rate_limiter.record_failed_login(client_ip, username)
logger.warning(
"Failed login attempt recorded",
client_ip=client_ip,
username=username,
failure_reason=failure_reason,
)
except Exception as e:
logger.error(
"Failed to record failed login attempt",
client_ip=client_ip,
username=username,
error=str(e),
)
def record_successful_login(self, client_ip: str, username: str) -> None:
"""
Record successful login (clear failed attempts).
Args:
client_ip: Client IP address.
username: Username.
"""
try:
redis_rate_limiter.clear_failed_logins(client_ip, username)
logger.info(
"Successful login recorded, failed attempts cleared",
client_ip=client_ip,
username=username,
)
except Exception as e:
logger.error(
"Failed to record successful login",
client_ip=client_ip,
username=username,
error=str(e),
)
def _calculate_backoff_time(
self, client_ip: str, username: str, failed_counts: Dict[str, int]
) -> int:
"""
Calculate wait time for exponential backoff.
Args:
client_ip: Client IP address.
username: Username.
failed_counts: Failed attempt counts.
Returns:
Wait time in seconds.
"""
try:
max_failed = max(
failed_counts["ip_failed_count"],
failed_counts["user_failed_count"],
)
if max_failed <= MIN_FAILED_ATTEMPTS_FOR_BACKOFF:
return 0
backoff_seconds = min(
self.exponential_backoff_base
** (max_failed - MIN_FAILED_ATTEMPTS_FOR_BACKOFF),
MAX_BACKOFF_SECONDS,
)
return backoff_seconds
except Exception as e:
logger.error("Error calculating backoff time", error=str(e))
return 0
def get_protection_stats(self) -> Dict[str, Any]:
"""
Get brute-force protection statistics.
Returns:
Protection statistics dictionary.
"""
try:
rate_limit_stats = redis_rate_limiter.get_rate_limit_stats()
return {
"max_login_attempts_per_ip": self.max_login_attempts_per_ip,
"max_login_attempts_per_user": self.max_login_attempts_per_user,
"login_window_minutes": self.login_window_minutes,
"user_lockout_minutes": self.user_lockout_minutes,
"exponential_backoff_base": self.exponential_backoff_base,
"rate_limit_stats": rate_limit_stats,
}
except Exception as e:
logger.error("Failed to get protection stats", error=str(e))
return {"error": str(e)}
def force_unlock_user(self, username: str, unlocked_by: str) -> bool:
"""
Force unlock user (for administrators).
Args:
username: Username to unlock.
unlocked_by: Who unlocked the user.
Returns:
True if unlock successful.
"""
try:
redis_rate_limiter.clear_failed_logins(EMPTY_IP, username)
logger.info("User force unlocked", username=username, unlocked_by=unlocked_by)
return True
except Exception as e:
logger.error(
"Failed to force unlock user",
username=username,
unlocked_by=unlocked_by,
error=str(e),
)
return False
def force_unlock_ip(self, client_ip: str, unlocked_by: str) -> bool:
"""
Force unlock IP (for administrators).
Args:
client_ip: IP address to unlock.
unlocked_by: Who unlocked the IP.
Returns:
True if unlock successful.
"""
try:
redis_rate_limiter.clear_failed_logins(client_ip, EMPTY_USERNAME)
logger.info(
"IP force unlocked", client_ip=client_ip, unlocked_by=unlocked_by
)
return True
except Exception as e:
logger.error(
"Failed to force unlock IP",
client_ip=client_ip,
unlocked_by=unlocked_by,
error=str(e),
)
return False
brute_force_protection = BruteForceProtection()

View File

@ -0,0 +1,361 @@
"""CSRF protection using Double Submit Cookie pattern."""
import hashlib
import json
import os
import secrets
import time
from datetime import datetime, timedelta
from typing import Any, Dict, FrozenSet
import redis
import structlog
logger = structlog.get_logger(__name__)
# Redis configuration
REDIS_DEFAULT_HOST = "redis"
REDIS_DEFAULT_PORT = "6379"
REDIS_DEFAULT_DB = 0
# Token configuration
REDIS_KEY_PREFIX = "csrf:token:"
CSRF_TOKEN_TTL_SECONDS = 3600
TOKEN_SIZE_BYTES = 32
SECRET_KEY_SIZE_BYTES = 32
TOKEN_PARTS_COUNT = 3
TOKEN_PREVIEW_LENGTH = 16
SCAN_BATCH_SIZE = 100
# Redis TTL special values
TTL_KEY_NOT_EXISTS = -2
TTL_KEY_NO_EXPIRY = -1
# Protected HTTP methods
PROTECTED_METHODS: FrozenSet[str] = frozenset({"POST", "PUT", "DELETE", "PATCH"})
# Excluded endpoints (no CSRF protection)
EXCLUDED_ENDPOINTS: FrozenSet[str] = frozenset({
"/auth/login",
"/health",
"/health/detailed",
"/health/ready",
"/health/live",
"/health/routing",
"/metrics",
"/docs",
"/openapi.json",
})
class CSRFProtection:
"""CSRF protection with Double Submit Cookie pattern."""
def __init__(self) -> None:
"""
Initialize CSRF protection.
Raises:
RuntimeError: If Redis connection fails.
"""
self._redis_client = redis.Redis(
host=os.getenv("REDIS_HOST", REDIS_DEFAULT_HOST),
port=int(os.getenv("REDIS_PORT", REDIS_DEFAULT_PORT)),
password=os.getenv("REDIS_PASSWORD"),
db=REDIS_DEFAULT_DB,
decode_responses=True,
)
self._csrf_token_ttl = CSRF_TOKEN_TTL_SECONDS
self._token_size = TOKEN_SIZE_BYTES
self._secret_key = secrets.token_bytes(SECRET_KEY_SIZE_BYTES)
try:
self._redis_client.ping()
logger.info("CSRF Protection connected to Redis successfully")
except Exception as e:
logger.error("Failed to connect to Redis for CSRF", error=str(e))
raise RuntimeError(f"Redis connection failed: {e}") from e
self._protected_methods: FrozenSet[str] = PROTECTED_METHODS
self._excluded_endpoints: FrozenSet[str] = EXCLUDED_ENDPOINTS
def generate_csrf_token(self, user_id: str) -> str:
"""
Generate CSRF token for user.
Args:
user_id: User ID.
Returns:
CSRF token.
"""
try:
random_bytes = secrets.token_bytes(self._token_size)
timestamp = str(int(time.time()))
data_to_sign = f"{user_id}:{timestamp}:{random_bytes.hex()}"
signature = hashlib.sha256(
f"{data_to_sign}:{self._secret_key.hex()}".encode()
).hexdigest()
csrf_token = f"{random_bytes.hex()}:{timestamp}:{signature}"
now = datetime.now()
expires_at = now + timedelta(seconds=self._csrf_token_ttl)
token_data = {
"user_id": user_id,
"created_at": now.isoformat(),
"expires_at": expires_at.isoformat(),
"used": False,
}
redis_key = f"{REDIS_KEY_PREFIX}{csrf_token}"
self._redis_client.setex(
redis_key, self._csrf_token_ttl, json.dumps(token_data)
)
logger.debug(
"CSRF token generated in Redis",
user_id=user_id,
token_preview=csrf_token[:TOKEN_PREVIEW_LENGTH] + "...",
expires_at=expires_at.isoformat(),
)
return csrf_token
except Exception as e:
logger.error("Failed to generate CSRF token", user_id=user_id, error=str(e))
raise
def validate_csrf_token(self, token: str, user_id: str) -> bool:
"""
Validate CSRF token.
Args:
token: CSRF token.
user_id: User ID.
Returns:
True if token is valid.
"""
try:
if not token or not user_id:
return False
redis_key = f"{REDIS_KEY_PREFIX}{token}"
token_json = self._redis_client.get(redis_key)
if not token_json:
logger.warning(
"CSRF token not found in Redis",
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
user_id=user_id,
)
return False
token_data = json.loads(token_json)
expires_at = datetime.fromisoformat(token_data["expires_at"])
if datetime.now() > expires_at:
logger.warning(
"CSRF token expired",
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
user_id=user_id,
)
self._redis_client.delete(redis_key)
return False
if token_data["user_id"] != user_id:
logger.warning(
"CSRF token user mismatch",
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
expected_user=user_id,
actual_user=token_data["user_id"],
)
return False
if not self._verify_token_signature(token, user_id):
logger.warning(
"CSRF token signature invalid",
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
user_id=user_id,
)
self._redis_client.delete(redis_key)
return False
token_data["used"] = True
ttl = self._redis_client.ttl(redis_key)
if ttl > 0:
self._redis_client.setex(redis_key, ttl, json.dumps(token_data))
logger.debug(
"CSRF token validated successfully",
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "...",
user_id=user_id,
)
return True
except Exception as e:
logger.error(
"Error validating CSRF token",
token_preview=token[:TOKEN_PREVIEW_LENGTH] + "..." if token else "none",
user_id=user_id,
error=str(e),
)
return False
def _verify_token_signature(self, token: str, user_id: str) -> bool:
"""
Verify token signature.
Args:
token: CSRF token.
user_id: User ID.
Returns:
True if signature is valid.
"""
try:
parts = token.split(":")
if len(parts) != TOKEN_PARTS_COUNT:
return False
random_hex, timestamp, signature = parts
data_to_sign = f"{user_id}:{timestamp}:{random_hex}"
expected_signature = hashlib.sha256(
f"{data_to_sign}:{self._secret_key.hex()}".encode()
).hexdigest()
return signature == expected_signature
except Exception:
return False
def should_protect_endpoint(self, method: str, path: str) -> bool:
"""
Check if endpoint needs CSRF protection.
Args:
method: HTTP method.
path: Endpoint path.
Returns:
True if CSRF protection is needed.
"""
if method not in self._protected_methods:
return False
if path in self._excluded_endpoints:
return False
for excluded_path in self._excluded_endpoints:
if path.startswith(excluded_path):
return False
return True
def cleanup_expired_tokens(self) -> None:
"""
Clean up expired CSRF tokens from Redis.
Note: Redis automatically removes keys with expired TTL.
"""
try:
pattern = f"{REDIS_KEY_PREFIX}*"
keys = list(self._redis_client.scan_iter(match=pattern, count=SCAN_BATCH_SIZE))
cleaned_count = 0
for key in keys:
ttl = self._redis_client.ttl(key)
if ttl == TTL_KEY_NOT_EXISTS:
cleaned_count += 1
elif ttl == TTL_KEY_NO_EXPIRY:
self._redis_client.delete(key)
cleaned_count += 1
if cleaned_count > 0:
logger.info(
"CSRF tokens cleanup completed",
cleaned_count=cleaned_count,
remaining_count=len(keys) - cleaned_count,
)
except Exception as e:
logger.error("Failed to cleanup expired CSRF tokens", error=str(e))
def get_csrf_stats(self) -> Dict[str, Any]:
"""
Get CSRF token statistics from Redis.
Returns:
Dictionary with CSRF statistics.
"""
try:
pattern = f"{REDIS_KEY_PREFIX}*"
keys = list(self._redis_client.scan_iter(match=pattern, count=SCAN_BATCH_SIZE))
active_tokens = 0
used_tokens = 0
for key in keys:
try:
token_json = self._redis_client.get(key)
if token_json:
token_data = json.loads(token_json)
active_tokens += 1
if token_data.get("used", False):
used_tokens += 1
except Exception:
continue
return {
"total_tokens": len(keys),
"active_tokens": active_tokens,
"used_tokens": used_tokens,
"token_ttl_seconds": self._csrf_token_ttl,
"protected_methods": sorted(self._protected_methods),
"excluded_endpoints": sorted(self._excluded_endpoints),
"storage": "Redis",
}
except Exception as e:
logger.error("Failed to get CSRF stats", error=str(e))
return {"error": str(e), "storage": "Redis"}
def revoke_user_tokens(self, user_id: str) -> None:
"""
Revoke all user tokens from Redis.
Args:
user_id: User ID.
"""
try:
pattern = f"{REDIS_KEY_PREFIX}*"
keys = list(self._redis_client.scan_iter(match=pattern, count=SCAN_BATCH_SIZE))
revoked_count = 0
for key in keys:
try:
token_json = self._redis_client.get(key)
if token_json:
token_data = json.loads(token_json)
if token_data.get("user_id") == user_id:
self._redis_client.delete(key)
revoked_count += 1
except Exception:
continue
if revoked_count > 0:
logger.info(
"Revoked user CSRF tokens from Redis",
user_id=user_id,
count=revoked_count,
)
except Exception as e:
logger.error(
"Failed to revoke user CSRF tokens", user_id=user_id, error=str(e)
)
csrf_protection = CSRFProtection()

View File

@ -0,0 +1,485 @@
"""Integration with Guacamole API for authentication and user management."""
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
import requests
import structlog
from .models import UserRole
from .permissions import PermissionChecker
from .session_storage import session_storage
from .utils import create_jwt_token
logger = structlog.get_logger(__name__)
class GuacamoleAuthenticator:
"""Class for authentication via Guacamole API."""
def __init__(self) -> None:
"""
Initialize Guacamole authenticator.
Raises:
ValueError: If system credentials are not set in environment variables.
"""
self.base_url = os.getenv("GUACAMOLE_URL", "http://guacamole:8080")
self.session = requests.Session()
self._system_token: Optional[str] = None
self._system_token_expires: Optional[datetime] = None
self._system_username = os.getenv("SYSTEM_ADMIN_USERNAME")
self._system_password = os.getenv("SYSTEM_ADMIN_PASSWORD")
if not self._system_username or not self._system_password:
raise ValueError(
"SYSTEM_ADMIN_USERNAME and SYSTEM_ADMIN_PASSWORD environment "
"variables are required. Set these in your .env or "
"production.env file for security. Never use default "
"credentials in production!"
)
def get_system_token(self) -> str:
"""
Get system user token for administrative operations.
Returns:
System user token.
Raises:
RuntimeError: If system user authentication fails.
"""
if (
self._system_token is None
or self._system_token_expires is None
or self._system_token_expires <= datetime.now()
):
logger.debug("Refreshing system token", username=self._system_username)
auth_url = f"{self.base_url}/guacamole/api/tokens"
auth_data = {
"username": self._system_username,
"password": self._system_password,
}
try:
response = self.session.post(auth_url, data=auth_data, timeout=10)
response.raise_for_status()
auth_result = response.json()
self._system_token = auth_result.get("authToken")
if not self._system_token:
raise RuntimeError("No authToken in response")
self._system_token_expires = datetime.now() + timedelta(hours=7)
logger.info(
"System token refreshed successfully",
username=self._system_username,
expires_at=self._system_token_expires.isoformat(),
)
except requests.exceptions.RequestException as e:
logger.error(
"Failed to authenticate system user",
username=self._system_username,
error=str(e),
)
raise RuntimeError(
f"Failed to authenticate system user: {e}"
) from e
except Exception as e:
logger.error(
"Unexpected error during system authentication",
username=self._system_username,
error=str(e),
)
raise
return self._system_token
def authenticate_user(
self, username: str, password: str
) -> Optional[Dict[str, Any]]:
"""
Authenticate user via Guacamole API.
Args:
username: Username in Guacamole.
password: Password in Guacamole.
Returns:
Dictionary with user information or None if authentication fails.
"""
auth_url = f"{self.base_url}/guacamole/api/tokens"
auth_data = {"username": username, "password": password}
try:
logger.debug("Attempting user authentication", username=username)
response = self.session.post(auth_url, data=auth_data, timeout=10)
if response.status_code != 200:
logger.info(
"Authentication failed",
username=username,
status_code=response.status_code,
response=response.text[:200],
)
return None
auth_result = response.json()
auth_token = auth_result.get("authToken")
if not auth_token:
logger.warning(
"No authToken in successful response",
username=username,
response=auth_result,
)
return None
user_info = self.get_user_info(auth_token)
if not user_info:
logger.warning(
"Failed to get user info after authentication", username=username
)
return None
system_permissions = user_info.get("systemPermissions", [])
user_role = PermissionChecker.determine_role_from_permissions(
system_permissions
)
result = {
"username": username,
"auth_token": auth_token,
"role": user_role.value,
"permissions": system_permissions,
"full_name": user_info.get("fullName"),
"email": user_info.get("emailAddress"),
"organization": user_info.get("organization"),
"organizational_role": user_info.get("organizationalRole"),
}
logger.info(
"User authenticated successfully",
username=username,
role=user_role.value,
permissions_count=len(system_permissions),
)
return result
except requests.exceptions.RequestException as e:
logger.error(
"Network error during authentication", username=username, error=str(e)
)
return None
except Exception as e:
logger.error(
"Unexpected error during authentication", username=username, error=str(e)
)
return None
def get_user_info(self, auth_token: str) -> Optional[Dict[str, Any]]:
"""
Get user information via Guacamole API.
Args:
auth_token: User authentication token.
Returns:
Dictionary with user information or None.
"""
user_url = f"{self.base_url}/guacamole/api/session/data/postgresql/self"
headers = {"Guacamole-Token": auth_token}
try:
response = self.session.get(user_url, headers=headers, timeout=10)
if response.status_code != 200:
logger.warning(
"Failed to get user info",
status_code=response.status_code,
response=response.text[:200],
)
return None
user_data = response.json()
username = user_data.get("username")
if not username:
logger.warning("No username in user info response")
return None
permissions_url = (
f"{self.base_url}/guacamole/api/session/data/postgresql/"
f"users/{username}/permissions"
)
try:
perm_response = self.session.get(
permissions_url, headers=headers, timeout=10
)
if perm_response.status_code == 200:
permissions_data = perm_response.json()
system_permissions = permissions_data.get("systemPermissions", [])
logger.info(
"System permissions retrieved",
username=username,
system_permissions=system_permissions,
permissions_count=len(system_permissions),
)
else:
logger.warning(
"Failed to get user permissions",
username=username,
status_code=perm_response.status_code,
response=perm_response.text[:200],
)
system_permissions = []
except Exception as e:
logger.warning(
"Error getting user permissions", username=username, error=str(e)
)
system_permissions = []
user_data["systemPermissions"] = system_permissions
attributes = user_data.get("attributes", {})
user_data.update(
{
"fullName": attributes.get("guac-full-name"),
"emailAddress": attributes.get("guac-email-address"),
"organization": attributes.get("guac-organization"),
"organizationalRole": attributes.get("guac-organizational-role"),
}
)
logger.info(
"User info retrieved successfully",
username=username,
system_permissions=system_permissions,
permissions_count=len(system_permissions),
)
return user_data
except requests.exceptions.RequestException as e:
logger.error("Network error getting user info", error=str(e))
return None
except Exception as e:
logger.error("Unexpected error getting user info", error=str(e))
return None
def create_jwt_for_user(self, user_info: Dict[str, Any]) -> str:
"""
Create JWT token for user with session storage.
Args:
user_info: User information from authenticate_user.
Returns:
JWT token.
"""
session_id = session_storage.create_session(
user_info=user_info,
guac_token=user_info["auth_token"],
expires_in_minutes=int(
os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "60")
),
)
return create_jwt_token(user_info, session_id)
def get_user_connections(self, auth_token: str) -> List[Dict[str, Any]]:
"""
Get list of user connections.
Args:
auth_token: User authentication token.
Returns:
List of connections.
"""
connections_url = (
f"{self.base_url}/guacamole/api/session/data/postgresql/connections"
)
headers = {"Guacamole-Token": auth_token}
try:
response = self.session.get(connections_url, headers=headers, timeout=10)
if response.status_code != 200:
logger.warning(
"Failed to get user connections", status_code=response.status_code
)
return []
connections_data = response.json()
if isinstance(connections_data, dict):
connections = list(connections_data.values())
else:
connections = connections_data
logger.debug("Retrieved user connections", count=len(connections))
return connections
except requests.exceptions.RequestException as e:
logger.error("Network error getting connections", error=str(e))
return []
except Exception as e:
logger.error("Unexpected error getting connections", error=str(e))
return []
def create_connection_with_token(
self, connection_config: Dict[str, Any], auth_token: str
) -> Optional[Dict[str, Any]]:
"""
Create connection using user token.
Args:
connection_config: Connection configuration.
auth_token: User authentication token.
Returns:
Information about created connection or None.
"""
create_url = (
f"{self.base_url}/guacamole/api/session/data/postgresql/connections"
)
headers = {
"Content-Type": "application/json",
"Guacamole-Token": auth_token,
}
try:
response = self.session.post(
create_url, headers=headers, json=connection_config, timeout=30
)
if response.status_code not in [200, 201]:
logger.error(
"Failed to create connection",
status_code=response.status_code,
response=response.text[:500],
)
return None
created_connection = response.json()
connection_id = created_connection.get("identifier")
if not connection_id:
logger.error(
"No connection ID in response", response=created_connection
)
return None
logger.info(
"Connection created successfully",
connection_id=connection_id,
protocol=connection_config.get("protocol"),
hostname=connection_config.get("parameters", {}).get("hostname"),
)
return created_connection
except requests.exceptions.RequestException as e:
logger.error("Network error creating connection", error=str(e))
return None
except Exception as e:
logger.error("Unexpected error creating connection", error=str(e))
return None
def delete_connection_with_token(
self, connection_id: str, auth_token: str
) -> bool:
"""
Delete connection using user token.
Args:
connection_id: Connection ID to delete.
auth_token: User authentication token.
Returns:
True if deletion successful, False otherwise.
"""
delete_url = (
f"{self.base_url}/guacamole/api/session/data/postgresql/"
f"connections/{connection_id}"
)
headers = {"Guacamole-Token": auth_token}
try:
response = self.session.delete(delete_url, headers=headers, timeout=10)
if response.status_code == 204:
logger.info("Connection deleted successfully", connection_id=connection_id)
return True
logger.warning(
"Failed to delete connection",
connection_id=connection_id,
status_code=response.status_code,
response=response.text[:200],
)
return False
except requests.exceptions.RequestException as e:
logger.error(
"Network error deleting connection",
connection_id=connection_id,
error=str(e),
)
return False
except Exception as e:
logger.error(
"Unexpected error deleting connection",
connection_id=connection_id,
error=str(e),
)
return False
def validate_token(self, auth_token: str) -> bool:
"""
Validate Guacamole token.
Args:
auth_token: Token to validate.
Returns:
True if token is valid, False otherwise.
"""
try:
user_info = self.get_user_info(auth_token)
return user_info is not None
except Exception:
return False
def refresh_user_token(
self, username: str, current_token: str
) -> Optional[str]:
"""
Refresh user token (if supported by Guacamole).
Args:
username: Username.
current_token: Current token.
Returns:
New token or None.
"""
logger.debug(
"Token refresh requested but not supported by Guacamole", username=username
)
return None

View File

@ -0,0 +1,474 @@
"""Module for working with real KMS/HSM systems."""
import base64
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict
import boto3
import requests
from botocore.exceptions import ClientError
logger = logging.getLogger(__name__)
class KMSProvider(ABC):
"""Abstract class for KMS providers."""
@abstractmethod
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
"""
Encrypt data using KMS.
Args:
plaintext: Data to encrypt.
key_id: Key identifier.
Returns:
Encrypted data.
"""
pass
@abstractmethod
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
"""
Decrypt data using KMS.
Args:
ciphertext: Encrypted data.
key_id: Key identifier.
Returns:
Decrypted data.
"""
pass
@abstractmethod
def generate_data_key(
self, key_id: str, key_spec: str = "AES_256"
) -> Dict[str, bytes]:
"""
Generate data encryption key.
Args:
key_id: Key identifier.
key_spec: Key specification.
Returns:
Dictionary with 'plaintext' and 'ciphertext' keys.
"""
pass
class AWSKMSProvider(KMSProvider):
"""AWS KMS provider."""
def __init__(self, region_name: str = "us-east-1") -> None:
"""
Initialize AWS KMS provider.
Args:
region_name: AWS region name.
"""
self.kms_client = boto3.client("kms", region_name=region_name)
self.region_name = region_name
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
"""Encrypt data using AWS KMS."""
try:
response = self.kms_client.encrypt(KeyId=key_id, Plaintext=plaintext)
return response["CiphertextBlob"]
except ClientError as e:
logger.error("AWS KMS encryption failed: %s", e)
raise
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
"""Decrypt data using AWS KMS."""
try:
response = self.kms_client.decrypt(
CiphertextBlob=ciphertext, KeyId=key_id
)
return response["Plaintext"]
except ClientError as e:
logger.error("AWS KMS decryption failed: %s", e)
raise
def generate_data_key(
self, key_id: str, key_spec: str = "AES_256"
) -> Dict[str, bytes]:
"""Generate data encryption key."""
try:
response = self.kms_client.generate_data_key(
KeyId=key_id, KeySpec=key_spec
)
return {
"plaintext": response["Plaintext"],
"ciphertext": response["CiphertextBlob"],
}
except ClientError as e:
logger.error("AWS KMS data key generation failed: %s", e)
raise
class GoogleCloudKMSProvider(KMSProvider):
"""Google Cloud KMS provider."""
def __init__(self, project_id: str, location: str = "global") -> None:
"""
Initialize Google Cloud KMS provider.
Args:
project_id: Google Cloud project ID.
location: Key location.
"""
self.project_id = project_id
self.location = location
self.base_url = (
f"https://cloudkms.googleapis.com/v1/projects/{project_id}"
f"/locations/{location}"
)
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
"""Encrypt data using Google Cloud KMS."""
try:
url = (
f"{self.base_url}/keyRings/default/cryptoKeys/{key_id}:encrypt"
)
response = requests.post(
url,
json={"plaintext": base64.b64encode(plaintext).decode()},
headers={
"Authorization": f"Bearer {self._get_access_token()}"
},
timeout=30,
)
response.raise_for_status()
return base64.b64decode(response.json()["ciphertext"])
except requests.RequestException as e:
logger.error("Google Cloud KMS encryption failed: %s", e)
raise RuntimeError(
f"Google Cloud KMS encryption failed: {e}"
) from e
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
"""Decrypt data using Google Cloud KMS."""
try:
url = (
f"{self.base_url}/keyRings/default/cryptoKeys/{key_id}:decrypt"
)
response = requests.post(
url,
json={"ciphertext": base64.b64encode(ciphertext).decode()},
headers={
"Authorization": f"Bearer {self._get_access_token()}"
},
timeout=30,
)
response.raise_for_status()
return base64.b64decode(response.json()["plaintext"])
except requests.RequestException as e:
logger.error("Google Cloud KMS decryption failed: %s", e)
raise RuntimeError(
f"Google Cloud KMS decryption failed: {e}"
) from e
def generate_data_key(
self, key_id: str, key_spec: str = "AES_256"
) -> Dict[str, bytes]:
"""Generate data encryption key."""
try:
url = (
f"{self.base_url}/keyRings/default/cryptoKeys/{key_id}"
":generateDataKey"
)
response = requests.post(
url,
json={"keySpec": key_spec},
headers={
"Authorization": f"Bearer {self._get_access_token()}"
},
timeout=30,
)
response.raise_for_status()
data = response.json()
return {
"plaintext": base64.b64decode(data["plaintext"]),
"ciphertext": base64.b64decode(data["ciphertext"]),
}
except requests.RequestException as e:
logger.error("Google Cloud KMS data key generation failed: %s", e)
raise RuntimeError(
f"Google Cloud KMS data key generation failed: {e}"
) from e
def _get_access_token(self) -> str:
"""
Get access token for Google Cloud API.
Note: In production, use service account or metadata server.
"""
return os.getenv("GOOGLE_CLOUD_ACCESS_TOKEN", "")
class YubiHSMProvider(KMSProvider):
"""YubiHSM provider (hardware security module)."""
def __init__(self, hsm_url: str, auth_key_id: int) -> None:
"""
Initialize YubiHSM provider.
Args:
hsm_url: YubiHSM URL.
auth_key_id: Authentication key ID.
"""
self.hsm_url = hsm_url
self.auth_key_id = auth_key_id
def encrypt(self, plaintext: bytes, key_id: str) -> bytes:
"""Encrypt data using YubiHSM."""
try:
response = requests.post(
f"{self.hsm_url}/api/v1/encrypt",
json={"key_id": key_id, "plaintext": plaintext.hex()},
headers={"Authorization": f"Bearer {self._get_hsm_token()}"},
timeout=30,
)
response.raise_for_status()
return bytes.fromhex(response.json()["ciphertext"])
except requests.RequestException as e:
logger.error("YubiHSM encryption failed: %s", e)
raise RuntimeError(f"YubiHSM encryption failed: {e}") from e
def decrypt(self, ciphertext: bytes, key_id: str) -> bytes:
"""Decrypt data using YubiHSM."""
try:
response = requests.post(
f"{self.hsm_url}/api/v1/decrypt",
json={"key_id": key_id, "ciphertext": ciphertext.hex()},
headers={"Authorization": f"Bearer {self._get_hsm_token()}"},
timeout=30,
)
response.raise_for_status()
return bytes.fromhex(response.json()["plaintext"])
except requests.RequestException as e:
logger.error("YubiHSM decryption failed: %s", e)
raise RuntimeError(f"YubiHSM decryption failed: {e}") from e
def generate_data_key(
self, key_id: str, key_spec: str = "AES_256"
) -> Dict[str, bytes]:
"""Generate data encryption key."""
try:
response = requests.post(
f"{self.hsm_url}/api/v1/generate-data-key",
json={"key_id": key_id, "key_spec": key_spec},
headers={"Authorization": f"Bearer {self._get_hsm_token()}"},
timeout=30,
)
response.raise_for_status()
data = response.json()
return {
"plaintext": bytes.fromhex(data["plaintext"]),
"ciphertext": bytes.fromhex(data["ciphertext"]),
}
except requests.RequestException as e:
logger.error("YubiHSM data key generation failed: %s", e)
raise RuntimeError(
f"YubiHSM data key generation failed: {e}"
) from e
def _get_hsm_token(self) -> str:
"""
Get token for YubiHSM.
Note: In production, use proper YubiHSM authentication.
"""
return os.getenv("YUBIHSM_TOKEN", "")
class SecureKeyManager:
"""Key manager using real KMS/HSM systems."""
def __init__(self, kms_provider: KMSProvider, master_key_id: str) -> None:
"""
Initialize secure key manager.
Args:
kms_provider: KMS provider instance.
master_key_id: Master key identifier.
"""
self.kms_provider = kms_provider
self.master_key_id = master_key_id
self.key_cache: Dict[str, bytes] = {}
def encrypt_session_key(self, session_key: bytes, session_id: str) -> bytes:
"""
Encrypt session key using KMS/HSM.
Args:
session_key: Session key to encrypt.
session_id: Session ID for context.
Returns:
Encrypted session key.
"""
try:
encrypted_key = self.kms_provider.encrypt(
session_key, self.master_key_id
)
logger.info(
"Session key encrypted with KMS/HSM",
extra={
"session_id": session_id,
"key_length": len(session_key),
"encrypted_length": len(encrypted_key),
},
)
return encrypted_key
except Exception as e:
logger.error(
"Failed to encrypt session key with KMS/HSM",
extra={"session_id": session_id, "error": str(e)},
)
raise
def decrypt_session_key(
self, encrypted_session_key: bytes, session_id: str
) -> bytes:
"""
Decrypt session key using KMS/HSM.
Args:
encrypted_session_key: Encrypted session key.
session_id: Session ID for context.
Returns:
Decrypted session key.
"""
try:
decrypted_key = self.kms_provider.decrypt(
encrypted_session_key, self.master_key_id
)
logger.info(
"Session key decrypted with KMS/HSM",
extra={"session_id": session_id, "key_length": len(decrypted_key)},
)
return decrypted_key
except Exception as e:
logger.error(
"Failed to decrypt session key with KMS/HSM",
extra={"session_id": session_id, "error": str(e)},
)
raise
def generate_encryption_key(self, session_id: str) -> Dict[str, bytes]:
"""
Generate encryption key using KMS/HSM.
Args:
session_id: Session ID.
Returns:
Dictionary with 'plaintext' and 'ciphertext' keys.
"""
try:
key_data = self.kms_provider.generate_data_key(self.master_key_id)
logger.info(
"Encryption key generated with KMS/HSM",
extra={
"session_id": session_id,
"key_length": len(key_data["plaintext"]),
},
)
return key_data
except Exception as e:
logger.error(
"Failed to generate encryption key with KMS/HSM",
extra={"session_id": session_id, "error": str(e)},
)
raise
class KMSProviderFactory:
"""Factory for creating KMS providers."""
@staticmethod
def create_provider(provider_type: str, **kwargs: Any) -> KMSProvider:
"""
Create KMS provider by type.
Args:
provider_type: Provider type ('aws', 'gcp', 'yubihsm').
**kwargs: Provider-specific arguments.
Returns:
KMS provider instance.
Raises:
ValueError: If provider type is unsupported.
"""
if provider_type == "aws":
return AWSKMSProvider(**kwargs)
if provider_type == "gcp":
return GoogleCloudKMSProvider(**kwargs)
if provider_type == "yubihsm":
return YubiHSMProvider(**kwargs)
raise ValueError(f"Unsupported KMS provider: {provider_type}")
def get_secure_key_manager() -> SecureKeyManager:
"""
Get configured secure key manager.
Returns:
Configured SecureKeyManager instance.
Raises:
ValueError: If provider type is unsupported.
"""
provider_type = os.getenv("KMS_PROVIDER", "aws")
master_key_id = os.getenv("KMS_MASTER_KEY_ID", "alias/session-keys")
if provider_type == "aws":
provider = AWSKMSProvider(
region_name=os.getenv("AWS_REGION", "us-east-1")
)
elif provider_type == "gcp":
provider = GoogleCloudKMSProvider(
project_id=os.getenv("GCP_PROJECT_ID", ""),
location=os.getenv("GCP_LOCATION", "global"),
)
elif provider_type == "yubihsm":
provider = YubiHSMProvider(
hsm_url=os.getenv("YUBIHSM_URL", ""),
auth_key_id=int(os.getenv("YUBIHSM_AUTH_KEY_ID", "0")),
)
else:
raise ValueError(f"Unsupported KMS provider: {provider_type}")
return SecureKeyManager(provider, master_key_id)
secure_key_manager = get_secure_key_manager()

View File

@ -0,0 +1,286 @@
"""Log sanitization for removing sensitive information from logs."""
import json
import re
from typing import Any, Dict
import structlog
logger = structlog.get_logger(__name__)
class LogSanitizer:
"""Class for cleaning logs from sensitive information."""
def __init__(self) -> None:
"""Initialize LogSanitizer with sensitive fields and patterns."""
self.sensitive_fields = {
"password",
"passwd",
"pwd",
"secret",
"token",
"key",
"auth_token",
"guac_token",
"jwt_token",
"access_token",
"refresh_token",
"api_key",
"private_key",
"encryption_key",
"session_id",
"cookie",
"authorization",
"credential",
"credentials",
"global_credentials",
"machine_credentials",
"ssh_password",
"ssh_username",
"credential_hash",
"password_hash",
"password_salt",
"encrypted_password",
}
self.sensitive_patterns = [
r'password["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
r'token["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
r'key["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
r'secret["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
r'authorization["\']?\s*[:=]\s*["\']?([^"\'\\s]+)["\']?',
]
self.jwt_pattern = re.compile(
r"\b[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\b"
)
self.api_key_pattern = re.compile(r"\b[A-Za-z0-9]{32,}\b")
def mask_sensitive_value(self, value: str, mask_char: str = "*") -> str:
"""
Mask sensitive value.
Args:
value: Value to mask.
mask_char: Character to use for masking.
Returns:
Masked value.
"""
if not value or len(value) <= 4:
return mask_char * 4
if len(value) <= 8:
return value[:2] + mask_char * (len(value) - 4) + value[-2:]
return value[:4] + mask_char * (len(value) - 8) + value[-4:]
def sanitize_string(self, text: str) -> str:
"""
Clean string from sensitive information.
Args:
text: Text to clean.
Returns:
Cleaned text.
"""
if not isinstance(text, str):
return text
sanitized = text
sanitized = self.jwt_pattern.sub(
lambda m: self.mask_sensitive_value(m.group(0)), sanitized
)
sanitized = self.api_key_pattern.sub(
lambda m: self.mask_sensitive_value(m.group(0)), sanitized
)
for pattern in self.sensitive_patterns:
sanitized = re.sub(
pattern,
lambda m: m.group(0).replace(
m.group(1), self.mask_sensitive_value(m.group(1))
),
sanitized,
flags=re.IGNORECASE,
)
return sanitized
def sanitize_dict(
self, data: Dict[str, Any], max_depth: int = 10
) -> Dict[str, Any]:
"""
Recursively clean dictionary from sensitive information.
Args:
data: Dictionary to clean.
max_depth: Maximum recursion depth.
Returns:
Cleaned dictionary.
"""
if max_depth <= 0:
return {"error": "max_depth_exceeded"}
if not isinstance(data, dict):
return data
sanitized = {}
for key, value in data.items():
key_lower = key.lower()
is_sensitive_key = any(
sensitive_field in key_lower
for sensitive_field in self.sensitive_fields
)
if is_sensitive_key:
if isinstance(value, str):
sanitized[key] = self.mask_sensitive_value(value)
elif isinstance(value, (dict, list)):
sanitized[key] = self.sanitize_value(value, max_depth - 1)
else:
sanitized[key] = "[MASKED]"
else:
sanitized[key] = self.sanitize_value(value, max_depth - 1)
return sanitized
def sanitize_value(self, value: Any, max_depth: int = 10) -> Any:
"""
Clean value of any type.
Args:
value: Value to clean.
max_depth: Maximum recursion depth.
Returns:
Cleaned value.
"""
if max_depth <= 0:
return "[max_depth_exceeded]"
if isinstance(value, str):
return self.sanitize_string(value)
if isinstance(value, dict):
return self.sanitize_dict(value, max_depth)
if isinstance(value, list):
return [
self.sanitize_value(item, max_depth - 1) for item in value
]
if isinstance(value, (int, float, bool, type(None))):
return value
return self.sanitize_string(str(value))
def sanitize_log_event(
self, event_dict: Dict[str, Any]
) -> Dict[str, Any]:
"""
Clean log event from sensitive information.
Args:
event_dict: Log event dictionary.
Returns:
Cleaned event dictionary.
"""
try:
sanitized_event = event_dict.copy()
sanitized_event = self.sanitize_dict(sanitized_event)
special_fields = ["request_body", "response_body", "headers"]
for field in special_fields:
if field in sanitized_event:
sanitized_event[field] = self.sanitize_value(
sanitized_event[field]
)
return sanitized_event
except Exception as e:
logger.error("Error sanitizing log event", error=str(e))
return {
"error": "sanitization_failed",
"original_error": str(e),
}
def sanitize_json_string(self, json_string: str) -> str:
"""
Clean JSON string from sensitive information.
Args:
json_string: JSON string to clean.
Returns:
Cleaned JSON string.
"""
try:
data = json.loads(json_string)
sanitized_data = self.sanitize_value(data)
return json.dumps(sanitized_data, ensure_ascii=False)
except json.JSONDecodeError:
return self.sanitize_string(json_string)
except Exception as e:
logger.error("Error sanitizing JSON string", error=str(e))
return json_string
def is_sensitive_field(self, field_name: str) -> bool:
"""
Check if field is sensitive.
Args:
field_name: Field name.
Returns:
True if field is sensitive.
"""
field_lower = field_name.lower()
return any(
sensitive_field in field_lower
for sensitive_field in self.sensitive_fields
)
def get_sanitization_stats(self) -> Dict[str, Any]:
"""
Get sanitization statistics.
Returns:
Sanitization statistics dictionary.
"""
return {
"sensitive_fields_count": len(self.sensitive_fields),
"sensitive_patterns_count": len(self.sensitive_patterns),
"sensitive_fields": list(self.sensitive_fields),
"jwt_pattern_active": bool(self.jwt_pattern),
"api_key_pattern_active": bool(self.api_key_pattern),
}
log_sanitizer = LogSanitizer()
def sanitize_log_processor(
logger: Any, name: str, event_dict: Dict[str, Any]
) -> Dict[str, Any]:
"""
Processor for structlog for automatic log sanitization.
Args:
logger: Logger instance.
name: Logger name.
event_dict: Event dictionary.
Returns:
Sanitized event dictionary.
"""
return log_sanitizer.sanitize_log_event(event_dict)

View File

@ -0,0 +1,296 @@
"""Authentication and authorization middleware."""
from typing import Any, Awaitable, Callable, Dict, Optional
from fastapi import HTTPException, Request, Response
from fastapi.responses import JSONResponse
import structlog
from .models import UserRole
from .permissions import PermissionChecker
from .session_storage import session_storage
from .utils import extract_token_from_header, verify_jwt_token
logger = structlog.get_logger(__name__)
# Public endpoints that don't require authentication
PUBLIC_PATHS = {
"/",
"/api/health",
"/api/docs",
"/api/openapi.json",
"/api/redoc",
"/favicon.ico",
"/api/auth/login",
}
# Static file prefixes for FastAPI (Swagger UI)
STATIC_PREFIXES = [
"/static/",
"/docs/",
"/redoc/",
"/api/static/",
"/api/docs/",
"/api/redoc/",
]
async def jwt_auth_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""
Middleware for JWT token verification and user authentication.
Supports JWT token in Authorization header: Bearer <token>
"""
if request.method == "OPTIONS":
return await call_next(request)
path = request.url.path
if path in PUBLIC_PATHS:
return await call_next(request)
if any(path.startswith(prefix) for prefix in STATIC_PREFIXES):
return await call_next(request)
user_token: Optional[str] = None
user_info: Optional[Dict[str, Any]] = None
auth_method: Optional[str] = None
try:
auth_header = request.headers.get("Authorization")
if auth_header:
jwt_token = extract_token_from_header(auth_header)
if jwt_token:
jwt_payload = verify_jwt_token(jwt_token)
if jwt_payload:
session_id = jwt_payload.get("session_id")
if session_id:
session_data = session_storage.get_session(session_id)
if session_data:
user_token = session_data.get("guac_token")
else:
logger.warning(
"Session not found in Redis",
session_id=session_id,
username=jwt_payload.get("username"),
)
else:
user_token = jwt_payload.get("guac_token")
user_info = {
"username": jwt_payload["username"],
"role": jwt_payload["role"],
"permissions": jwt_payload.get("permissions", []),
"full_name": jwt_payload.get("full_name"),
"email": jwt_payload.get("email"),
"organization": jwt_payload.get("organization"),
"organizational_role": jwt_payload.get("organizational_role"),
}
auth_method = "jwt"
logger.debug(
"JWT authentication successful",
username=user_info["username"],
role=user_info["role"],
has_session=session_id is not None,
has_token=user_token is not None,
)
if not user_token or not user_info:
logger.info(
"Authentication required",
path=path,
method=request.method,
client_ip=request.client.host if request.client else "unknown",
)
return JSONResponse(
status_code=401,
content={
"error": "Authentication required",
"message": (
"Provide JWT token in Authorization header. "
"Get token via /auth/login"
),
"login_endpoint": "/auth/login",
},
)
user_role = UserRole(user_info["role"])
allowed, reason = PermissionChecker.check_endpoint_access(
user_role, request.method, path
)
if not allowed:
logger.warning(
"Access denied to endpoint",
username=user_info["username"],
role=user_info["role"],
endpoint=f"{request.method} {path}",
reason=reason,
)
return JSONResponse(
status_code=403,
content={
"error": "Access denied",
"message": reason,
"required_role": "Higher privileges required",
},
)
request.state.user_token = user_token
request.state.user_info = user_info
request.state.auth_method = auth_method
logger.debug(
"Authentication and authorization successful",
username=user_info["username"],
role=user_info["role"],
auth_method=auth_method,
endpoint=f"{request.method} {path}",
)
response = await call_next(request)
if hasattr(request.state, "user_info"):
response.headers["X-User"] = request.state.user_info["username"]
response.headers["X-User-Role"] = request.state.user_info["role"]
response.headers["X-Auth-Method"] = request.state.auth_method
return response
except HTTPException:
raise
except Exception as e:
logger.error(
"Unexpected error in auth middleware",
error=str(e),
path=path,
method=request.method,
)
return JSONResponse(
status_code=500,
content={
"error": "Internal server error",
"message": "Authentication system error",
},
)
def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
"""
Get current user information from request.state.
Args:
request: FastAPI Request object.
Returns:
User information dictionary or None.
"""
return getattr(request.state, "user_info", None)
def get_current_user_token(request: Request) -> Optional[str]:
"""
Get current user token from request.state.
Args:
request: FastAPI Request object.
Returns:
User token string or None.
"""
return getattr(request.state, "user_token", None)
def require_role(required_role: UserRole) -> Callable:
"""
Decorator to check user role.
Args:
required_role: Required user role.
Returns:
Function decorator.
"""
def decorator(func: Callable) -> Callable:
async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any:
user_info = get_current_user(request)
if not user_info:
raise HTTPException(
status_code=401, detail="Authentication required"
)
user_role = UserRole(user_info["role"])
permission = f"role_{required_role.value}"
if not PermissionChecker.check_permission(user_role, permission):
raise HTTPException(
status_code=403,
detail=f"Role {required_role.value} required",
)
return await func(request, *args, **kwargs)
return wrapper
return decorator
def require_permission(permission: str) -> Callable:
"""
Decorator to check specific permission.
Args:
permission: Required permission string.
Returns:
Function decorator.
"""
def decorator(func: Callable) -> Callable:
async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any:
user_info = get_current_user(request)
if not user_info:
raise HTTPException(
status_code=401, detail="Authentication required"
)
user_role = UserRole(user_info["role"])
if not PermissionChecker.check_permission(user_role, permission):
raise HTTPException(
status_code=403,
detail=f"Permission '{permission}' required",
)
return await func(request, *args, **kwargs)
return wrapper
return decorator
async def validate_connection_ownership(
request: Request, connection_id: str
) -> bool:
"""
Check user permissions for connection management.
Args:
request: FastAPI Request object.
connection_id: Connection ID.
Returns:
True if user can manage the connection, False otherwise.
"""
user_info = get_current_user(request)
if not user_info:
return False
user_role = UserRole(user_info["role"])
if PermissionChecker.can_delete_any_connection(user_role):
return True
return True

View File

@ -0,0 +1,343 @@
"""
Pydantic models for authentication system.
"""
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class UserRole(str, Enum):
"""User roles in the system."""
GUEST = "GUEST"
USER = "USER"
ADMIN = "ADMIN"
SUPER_ADMIN = "SUPER_ADMIN"
class LoginRequest(BaseModel):
"""Authentication request."""
username: str = Field(..., description="Username in Guacamole")
password: str = Field(..., description="Password in Guacamole")
class LoginResponse(BaseModel):
"""Successful authentication response."""
access_token: str = Field(..., description="JWT access token")
token_type: str = Field(default="bearer", description="Token type")
expires_in: int = Field(..., description="Token lifetime in seconds")
user_info: Dict[str, Any] = Field(..., description="User information")
class UserInfo(BaseModel):
"""User information."""
username: str = Field(..., description="Username")
role: UserRole = Field(..., description="User role")
permissions: List[str] = Field(
default_factory=list, description="System permissions"
)
full_name: Optional[str] = Field(None, description="Full name")
email: Optional[str] = Field(None, description="Email address")
organization: Optional[str] = Field(None, description="Organization")
organizational_role: Optional[str] = Field(None, description="Job title")
class ConnectionRequest(BaseModel):
"""Connection creation request.
Requires JWT token in Authorization header: Bearer <token>
Get token via /auth/login
"""
hostname: str = Field(..., description="IP address or hostname")
protocol: str = Field(
default="rdp", description="Connection protocol (rdp, vnc, ssh)"
)
username: Optional[str] = Field(
None, description="Username for remote machine connection"
)
password: Optional[str] = Field(
None,
description="Encrypted password for remote machine connection (Base64 AES-256-GCM)",
)
port: Optional[int] = Field(
None, description="Port (default used if not specified)"
)
ttl_minutes: Optional[int] = Field(
default=60, description="Connection lifetime in minutes"
)
enable_sftp: Optional[bool] = Field(
default=True, description="Enable SFTP for SSH (file browser with drag'n'drop)"
)
sftp_root_directory: Optional[str] = Field(
default="/", description="Root directory for SFTP (default: /)"
)
sftp_server_alive_interval: Optional[int] = Field(
default=0, description="SFTP keep-alive interval in seconds (0 = disabled)"
)
class ConnectionResponse(BaseModel):
"""Connection creation response."""
connection_id: str = Field(..., description="Created connection ID")
connection_url: str = Field(..., description="URL to access connection")
status: str = Field(..., description="Connection status")
expires_at: str = Field(..., description="Connection expiration time")
ttl_minutes: int = Field(..., description="TTL in minutes")
class RefreshTokenRequest(BaseModel):
"""Token refresh request."""
refresh_token: str = Field(..., description="Refresh token")
class LogoutRequest(BaseModel):
"""Logout request."""
token: str = Field(..., description="Token to revoke")
class PermissionCheckRequest(BaseModel):
"""Permission check request."""
action: str = Field(..., description="Action to check")
resource: Optional[str] = Field(None, description="Resource (optional)")
class PermissionCheckResponse(BaseModel):
"""Permission check response."""
allowed: bool = Field(..., description="Whether action is allowed")
reason: Optional[str] = Field(None, description="Denial reason (if applicable)")
class SavedMachineCreate(BaseModel):
"""Saved machine create/update request."""
name: str = Field(..., min_length=1, max_length=255, description="Machine name")
hostname: str = Field(
..., min_length=1, max_length=255, description="IP address or hostname"
)
port: int = Field(..., gt=0, lt=65536, description="Connection port")
protocol: str = Field(
..., description="Connection protocol (rdp, ssh, vnc, telnet)"
)
os: Optional[str] = Field(
None,
max_length=255,
description="Operating system (e.g., Windows Server 2019, Ubuntu 22.04)",
)
description: Optional[str] = Field(None, description="Machine description")
tags: Optional[List[str]] = Field(
default_factory=list, description="Tags for grouping"
)
is_favorite: bool = Field(default=False, description="Favorite machine")
class Config:
json_schema_extra = {
"example": {
"name": "Production Web Server",
"hostname": "192.168.1.100",
"port": 3389,
"protocol": "rdp",
"os": "Windows Server 2019",
"description": "Main production web server",
"tags": ["production", "web"],
"is_favorite": True,
}
}
class SavedMachineUpdate(BaseModel):
"""Saved machine partial update request."""
name: Optional[str] = Field(None, min_length=1, max_length=255)
hostname: Optional[str] = Field(None, min_length=1, max_length=255)
port: Optional[int] = Field(None, gt=0, lt=65536)
protocol: Optional[str] = None
os: Optional[str] = Field(None, max_length=255, description="Operating system")
description: Optional[str] = None
tags: Optional[List[str]] = None
is_favorite: Optional[bool] = None
class SavedMachineResponse(BaseModel):
"""Saved machine information response."""
id: str = Field(..., description="Machine UUID")
user_id: str = Field(..., description="Owner user ID")
name: str = Field(..., description="Machine name")
hostname: str = Field(..., description="IP address or hostname")
port: int = Field(..., description="Connection port")
protocol: str = Field(..., description="Connection protocol")
os: Optional[str] = Field(None, description="Operating system")
description: Optional[str] = Field(None, description="Description")
tags: List[str] = Field(default_factory=list, description="Tags")
is_favorite: bool = Field(default=False, description="Favorite")
created_at: str = Field(..., description="Creation date (ISO 8601)")
updated_at: str = Field(..., description="Update date (ISO 8601)")
last_connected_at: Optional[str] = Field(
None, description="Last connection (ISO 8601)"
)
connection_stats: Optional[Dict[str, Any]] = Field(
None, description="Connection statistics"
)
class Config:
from_attributes = True
class SavedMachineList(BaseModel):
"""Saved machines list."""
total: int = Field(..., description="Total number of machines")
machines: List[SavedMachineResponse] = Field(..., description="List of machines")
class ConnectionHistoryCreate(BaseModel):
"""Connection history record creation request."""
machine_id: str = Field(..., description="Machine UUID")
success: bool = Field(default=True, description="Successful connection")
error_message: Optional[str] = Field(None, description="Error message")
duration_seconds: Optional[int] = Field(
None, description="Connection duration in seconds"
)
class ConnectionHistoryResponse(BaseModel):
"""Connection history record information response."""
id: str = Field(..., description="Record UUID")
user_id: str = Field(..., description="User ID")
machine_id: str = Field(..., description="Machine UUID")
connected_at: str = Field(..., description="Connection time (ISO 8601)")
disconnected_at: Optional[str] = Field(
None, description="Disconnection time (ISO 8601)"
)
duration_seconds: Optional[int] = Field(None, description="Duration")
success: bool = Field(..., description="Successful connection")
error_message: Optional[str] = Field(None, description="Error message")
client_ip: Optional[str] = Field(None, description="Client IP")
class Config:
from_attributes = True
class BulkHealthCheckRequest(BaseModel):
"""Bulk machine availability check request."""
machine_ids: List[str] = Field(
..., min_items=1, max_items=200, description="List of machine IDs to check"
)
timeout: int = Field(
default=5, ge=1, le=30, description="Timeout for each check in seconds"
)
check_port: bool = Field(default=True, description="Check connection port")
class BulkHealthCheckResult(BaseModel):
"""Single machine check result."""
machine_id: str = Field(..., description="Machine ID")
machine_name: str = Field(..., description="Machine name")
hostname: str = Field(..., description="Hostname/IP")
status: str = Field(..., description="success, failed, timeout")
available: bool = Field(..., description="Machine is available")
response_time_ms: Optional[int] = Field(
None, description="Response time in milliseconds"
)
error: Optional[str] = Field(None, description="Error message")
checked_at: str = Field(..., description="Check time (ISO 8601)")
class BulkHealthCheckResponse(BaseModel):
"""Bulk availability check response."""
total: int = Field(..., description="Total number of machines")
success: int = Field(..., description="Number of successful checks")
failed: int = Field(..., description="Number of failed checks")
available: int = Field(..., description="Number of available machines")
unavailable: int = Field(..., description="Number of unavailable machines")
results: List[BulkHealthCheckResult] = Field(..., description="Detailed results")
execution_time_ms: int = Field(
..., description="Total execution time in milliseconds"
)
started_at: str = Field(..., description="Start time (ISO 8601)")
completed_at: str = Field(..., description="Completion time (ISO 8601)")
class SSHCredentials(BaseModel):
"""SSH credentials for machine."""
username: str = Field(..., min_length=1, max_length=255, description="SSH username")
password: str = Field(
..., min_length=1, description="SSH password (will be encrypted in transit)"
)
class BulkSSHCommandRequest(BaseModel):
"""Bulk SSH command execution request."""
machine_ids: List[str] = Field(
..., min_items=1, max_items=100, description="List of machine IDs"
)
machine_hostnames: Optional[Dict[str, str]] = Field(
None,
description="Optional hostname/IP for non-saved machines {machine_id: hostname}",
)
command: str = Field(
..., min_length=1, max_length=500, description="SSH command to execute"
)
credentials_mode: str = Field(
..., description="Credentials mode: 'global' (same for all), 'custom' (per-machine)"
)
global_credentials: Optional[SSHCredentials] = Field(
None, description="Shared credentials for all machines (mode 'global')"
)
machine_credentials: Optional[Dict[str, SSHCredentials]] = Field(
None, description="Individual credentials (mode 'custom')"
)
timeout: int = Field(
default=30, ge=5, le=300, description="Command execution timeout (seconds)"
)
class BulkSSHCommandResult(BaseModel):
"""Single machine SSH command execution result."""
machine_id: str = Field(..., description="Machine ID")
machine_name: str = Field(..., description="Machine name")
hostname: str = Field(..., description="Hostname/IP")
status: str = Field(..., description="success, failed, timeout, no_credentials")
exit_code: Optional[int] = Field(None, description="Command exit code")
stdout: Optional[str] = Field(None, description="Stdout output")
stderr: Optional[str] = Field(None, description="Stderr output")
error: Optional[str] = Field(None, description="Error message")
execution_time_ms: Optional[int] = Field(
None, description="Execution time in milliseconds"
)
executed_at: str = Field(..., description="Execution time (ISO 8601)")
class BulkSSHCommandResponse(BaseModel):
"""Bulk SSH command execution response."""
total: int = Field(..., description="Total number of machines")
success: int = Field(..., description="Number of successful executions")
failed: int = Field(..., description="Number of failed executions")
results: List[BulkSSHCommandResult] = Field(..., description="Detailed results")
execution_time_ms: int = Field(
..., description="Total execution time in milliseconds"
)
command: str = Field(..., description="Executed command")
started_at: str = Field(..., description="Start time (ISO 8601)")
completed_at: str = Field(..., description="Completion time (ISO 8601)")

View File

@ -0,0 +1,292 @@
"""Permission and role system for Remote Access API."""
from typing import Dict, FrozenSet, List, Optional, Tuple
import structlog
from .models import UserRole
logger = structlog.get_logger(__name__)
class PermissionChecker:
"""User permission checker class."""
ROLE_MAPPING: Dict[str, UserRole] = {
"ADMINISTER": UserRole.SUPER_ADMIN,
"CREATE_USER": UserRole.ADMIN,
"CREATE_CONNECTION": UserRole.USER,
}
ROLE_PERMISSIONS: Dict[UserRole, FrozenSet[str]] = {
UserRole.GUEST: frozenset({
"view_own_connections",
"view_own_profile"
}),
UserRole.USER: frozenset({
"view_own_connections",
"view_own_profile",
"create_connections",
"delete_own_connections",
}),
UserRole.ADMIN: frozenset({
"view_own_connections",
"view_own_profile",
"create_connections",
"delete_own_connections",
"view_all_connections",
"delete_any_connection",
"view_system_stats",
"view_system_metrics",
}),
UserRole.SUPER_ADMIN: frozenset({
"view_own_connections",
"view_own_profile",
"create_connections",
"delete_own_connections",
"view_all_connections",
"delete_any_connection",
"view_system_stats",
"view_system_metrics",
"reset_system_stats",
"manage_users",
"view_system_logs",
"change_system_config",
}),
}
ENDPOINT_PERMISSIONS = {
"POST /connect": "create_connections",
"GET /connections": "view_own_connections",
"DELETE /connections": "delete_own_connections",
"GET /stats": "view_system_stats",
"GET /metrics": "view_system_metrics",
"POST /stats/reset": "reset_system_stats",
"GET /auth/profile": "view_own_profile",
"GET /auth/permissions": "view_own_profile",
}
@classmethod
def determine_role_from_permissions(cls, guacamole_permissions: List[str]) -> UserRole:
"""
Determine user role based on Guacamole system permissions.
Args:
guacamole_permissions: List of system permissions from Guacamole.
Returns:
User role.
"""
for permission, role in cls.ROLE_MAPPING.items():
if permission in guacamole_permissions:
logger.debug(
"Role determined from permission",
permission=permission,
role=role.value,
all_permissions=guacamole_permissions,
)
return role
logger.debug(
"No system permissions found, assigning GUEST role",
permissions=guacamole_permissions,
)
return UserRole.GUEST
@classmethod
def get_role_permissions(cls, role: UserRole) -> FrozenSet[str]:
"""
Get all permissions for a role.
Args:
role: User role.
Returns:
Frozen set of permissions.
"""
return cls.ROLE_PERMISSIONS.get(role, frozenset())
@classmethod
def check_permission(cls, user_role: UserRole, permission: str) -> bool:
"""
Check if role has specific permission.
Args:
user_role: User role.
permission: Permission to check.
Returns:
True if permission exists, False otherwise.
"""
role_permissions = cls.get_role_permissions(user_role)
has_permission = permission in role_permissions
logger.debug(
"Permission check",
role=user_role.value,
permission=permission,
allowed=has_permission,
)
return has_permission
@classmethod
def check_endpoint_access(
cls, user_role: UserRole, method: str, path: str
) -> Tuple[bool, Optional[str]]:
"""
Check endpoint access.
Args:
user_role: User role.
method: HTTP method (GET, POST, DELETE, etc.).
path: Endpoint path.
Returns:
Tuple of (allowed: bool, reason: Optional[str]).
"""
endpoint_key = f"{method} {path}"
required_permission = cls.ENDPOINT_PERMISSIONS.get(endpoint_key)
if not required_permission:
for pattern, permission in cls.ENDPOINT_PERMISSIONS.items():
if cls._match_endpoint_pattern(endpoint_key, pattern):
required_permission = permission
break
if not required_permission:
return True, None
has_permission = cls.check_permission(user_role, required_permission)
if not has_permission:
reason = (
f"Required permission '{required_permission}' "
f"not granted to role '{user_role.value}'"
)
logger.info(
"Endpoint access denied",
role=user_role.value,
endpoint=endpoint_key,
required_permission=required_permission,
reason=reason,
)
return False, reason
logger.debug(
"Endpoint access granted",
role=user_role.value,
endpoint=endpoint_key,
required_permission=required_permission,
)
return True, None
@classmethod
def _match_endpoint_pattern(cls, endpoint: str, pattern: str) -> bool:
"""
Check if endpoint matches pattern.
Args:
endpoint: Endpoint to check (e.g., "DELETE /connections/123").
pattern: Pattern (e.g., "DELETE /connections").
Returns:
True if matches.
"""
if pattern.endswith("/connections"):
base_pattern = pattern.replace("/connections", "/connections/")
return endpoint.startswith(base_pattern)
return False
@classmethod
def check_connection_ownership(
cls, user_role: UserRole, username: str, connection_owner: str
) -> Tuple[bool, Optional[str]]:
"""
Check connection management rights.
Args:
user_role: User role.
username: Username.
connection_owner: Connection owner.
Returns:
Tuple of (allowed: bool, reason: Optional[str]).
"""
if user_role in (UserRole.ADMIN, UserRole.SUPER_ADMIN):
return True, None
if username == connection_owner:
return True, None
reason = (
f"User '{username}' cannot manage connection owned by '{connection_owner}'"
)
logger.info(
"Connection ownership check failed",
user=username,
owner=connection_owner,
role=user_role.value,
reason=reason,
)
return False, reason
@classmethod
def can_view_all_connections(cls, user_role: UserRole) -> bool:
"""
Check if user can view all connections.
Args:
user_role: User role.
Returns:
True if can view all connections.
"""
return cls.check_permission(user_role, "view_all_connections")
@classmethod
def can_delete_any_connection(cls, user_role: UserRole) -> bool:
"""
Check if user can delete any connection.
Args:
user_role: User role.
Returns:
True if can delete any connection.
"""
return cls.check_permission(user_role, "delete_any_connection")
@classmethod
def get_user_permissions_list(cls, user_role: UserRole) -> List[str]:
"""
Get sorted list of user permissions for API response.
Args:
user_role: User role.
Returns:
Sorted list of permissions.
"""
permissions = cls.get_role_permissions(user_role)
return sorted(permissions)
@classmethod
def validate_role_hierarchy(
cls, current_user_role: UserRole, target_user_role: UserRole
) -> bool:
"""
Validate role hierarchy for user management.
Args:
current_user_role: Current user role.
target_user_role: Target user role.
Returns:
True if current user can manage target user.
"""
return current_user_role == UserRole.SUPER_ADMIN

View File

@ -0,0 +1,259 @@
"""
Module for PKI/CA certificate handling for server key signature verification.
"""
import logging
import ssl
from datetime import datetime, timezone
from typing import Dict, List, Optional
import requests
from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import ed25519
logger = logging.getLogger(__name__)
class PKICertificateVerifier:
"""PKI/CA certificate verifier for server key signature verification."""
def __init__(self, ca_cert_path: str, crl_urls: Optional[List[str]] = None) -> None:
"""Initialize PKI certificate verifier.
Args:
ca_cert_path: Path to CA certificate file.
crl_urls: List of CRL URLs (optional).
"""
self.ca_cert_path = ca_cert_path
self.crl_urls = crl_urls or []
self.ca_cert = self._load_ca_certificate()
self.cert_store = self._build_cert_store()
def _load_ca_certificate(self) -> x509.Certificate:
"""Load CA certificate.
Returns:
Loaded CA certificate.
Raises:
Exception: If certificate cannot be loaded.
"""
try:
with open(self.ca_cert_path, "rb") as f:
ca_cert_data = f.read()
return x509.load_pem_x509_certificate(ca_cert_data)
except Exception as e:
logger.error("Failed to load CA certificate", extra={"error": str(e)})
raise
def _build_cert_store(self) -> x509.CertificateStore:
"""Build certificate store.
Returns:
Certificate store with CA certificate.
"""
store = x509.CertificateStore()
store.add_cert(self.ca_cert)
return store
def verify_server_certificate(self, server_cert_pem: bytes) -> bool:
"""
Verify server certificate through PKI/CA.
Args:
server_cert_pem: PEM-encoded server certificate.
Returns:
True if certificate is valid, False otherwise.
"""
try:
server_cert = x509.load_pem_x509_certificate(server_cert_pem)
if not self._verify_certificate_chain(server_cert):
logger.warning("Certificate chain verification failed")
return False
if not self._check_certificate_revocation(server_cert):
logger.warning("Certificate is revoked")
return False
if not self._check_certificate_validity(server_cert):
logger.warning("Certificate is expired or not yet valid")
return False
logger.info("Server certificate verified successfully")
return True
except Exception as e:
logger.error("Certificate verification error", extra={"error": str(e)})
return False
def _verify_certificate_chain(self, server_cert: x509.Certificate) -> bool:
"""Verify certificate chain.
Args:
server_cert: Server certificate to verify.
Returns:
True if chain is valid, False otherwise.
"""
try:
ca_public_key = self.ca_cert.public_key()
ca_public_key.verify(
server_cert.signature,
server_cert.tbs_certificate_bytes,
server_cert.signature_algorithm_oid,
)
return True
except Exception as e:
logger.error(
"Certificate chain verification failed", extra={"error": str(e)}
)
return False
def _check_certificate_revocation(self, server_cert: x509.Certificate) -> bool:
"""Check certificate revocation via CRL.
Args:
server_cert: Server certificate to check.
Returns:
True if certificate is not revoked, False otherwise.
"""
try:
crl_dps = server_cert.extensions.get_extension_for_oid(
x509.ExtensionOID.CRL_DISTRIBUTION_POINTS
).value
for crl_dp in crl_dps:
for crl_url in crl_dp.full_name:
if self._check_crl(server_cert, crl_url.value):
return False
return True
except Exception as e:
logger.warning("CRL check failed", extra={"error": str(e)})
return True
def _check_crl(self, server_cert: x509.Certificate, crl_url: str) -> bool:
"""Check specific CRL for certificate revocation.
Args:
server_cert: Server certificate to check.
crl_url: CRL URL.
Returns:
True if certificate is revoked, False otherwise.
"""
try:
response = requests.get(crl_url, timeout=10)
if response.status_code == 200:
crl_data = response.content
crl = x509.load_der_x509_crl(crl_data)
return server_cert.serial_number in [revoked.serial_number for revoked in crl]
return False
except Exception as e:
logger.warning("Failed to check CRL", extra={"crl_url": crl_url, "error": str(e)})
return False
def _check_certificate_validity(self, server_cert: x509.Certificate) -> bool:
"""Check certificate validity period.
Args:
server_cert: Server certificate to check.
Returns:
True if certificate is valid, False otherwise.
"""
now = datetime.now(timezone.utc)
return server_cert.not_valid_before <= now <= server_cert.not_valid_after
def extract_public_key_from_certificate(
self, server_cert_pem: bytes
) -> ed25519.Ed25519PublicKey:
"""Extract public key from certificate.
Args:
server_cert_pem: PEM-encoded server certificate.
Returns:
Extracted Ed25519 public key.
"""
server_cert = x509.load_pem_x509_certificate(server_cert_pem)
public_key = server_cert.public_key()
if not isinstance(public_key, ed25519.Ed25519PublicKey):
raise ValueError("Certificate does not contain Ed25519 public key")
return public_key
class ServerCertificateManager:
"""Server certificate manager."""
def __init__(self, pki_verifier: PKICertificateVerifier) -> None:
"""Initialize server certificate manager.
Args:
pki_verifier: PKI certificate verifier instance.
"""
self.pki_verifier = pki_verifier
self.server_certificates: Dict[str, bytes] = {}
def get_server_certificate(self, server_hostname: str) -> Optional[bytes]:
"""Get server certificate via TLS handshake.
Args:
server_hostname: Server hostname.
Returns:
PEM-encoded server certificate or None if failed.
"""
try:
context = ssl.create_default_context()
context.check_hostname = True
context.verify_mode = ssl.CERT_REQUIRED
with ssl.create_connection((server_hostname, 443)) as sock:
with context.wrap_socket(sock, server_hostname=server_hostname) as ssock:
cert_der = ssock.getpeercert_chain()[0]
cert_pem = ssl.DER_cert_to_PEM_cert(cert_der)
cert_bytes = cert_pem.encode()
if self.pki_verifier.verify_server_certificate(cert_bytes):
self.server_certificates[server_hostname] = cert_bytes
return cert_bytes
else:
logger.error(
"Server certificate verification failed",
extra={"server_hostname": server_hostname},
)
return None
except Exception as e:
logger.error(
"Failed to get server certificate",
extra={"server_hostname": server_hostname, "error": str(e)},
)
return None
def get_trusted_public_key(
self, server_hostname: str
) -> Optional[ed25519.Ed25519PublicKey]:
"""Get trusted public key from server certificate.
Args:
server_hostname: Server hostname.
Returns:
Ed25519 public key or None if failed.
"""
cert_pem = self.get_server_certificate(server_hostname)
if cert_pem:
return self.pki_verifier.extract_public_key_from_certificate(cert_pem)
return None
pki_verifier = PKICertificateVerifier(
ca_cert_path="/etc/ssl/certs/ca-certificates.crt",
crl_urls=["http://crl.example.com/crl.pem"],
)
certificate_manager = ServerCertificateManager(pki_verifier)

View File

@ -0,0 +1,342 @@
"""Redis-based thread-safe rate limiting."""
import os
import time
from typing import Any, Dict, Optional, Tuple
import redis
import structlog
logger = structlog.get_logger(__name__)
# Redis connection constants
REDIS_DEFAULT_HOST = "localhost"
REDIS_DEFAULT_PORT = "6379"
REDIS_DEFAULT_DB = "0"
REDIS_SOCKET_TIMEOUT = 5
# Rate limiting constants
DEFAULT_RATE_LIMIT_REQUESTS = 10
DEFAULT_RATE_LIMIT_WINDOW_SECONDS = 60
FAILED_LOGIN_RETENTION_SECONDS = 3600 # 1 hour
LOGIN_RATE_LIMIT_REQUESTS = 5
LOGIN_RATE_LIMIT_WINDOW_SECONDS = 900 # 15 minutes
DEFAULT_FAILED_LOGIN_WINDOW_MINUTES = 60
SECONDS_PER_MINUTE = 60
# Redis key prefixes
RATE_LIMIT_KEY_PREFIX = "rate_limit:"
FAILED_LOGINS_IP_PREFIX = "failed_logins:ip:"
FAILED_LOGINS_USER_PREFIX = "failed_logins:user:"
LOGIN_LIMIT_PREFIX = "login_limit:"
# Rate limit headers
HEADER_RATE_LIMIT = "X-RateLimit-Limit"
HEADER_RATE_LIMIT_WINDOW = "X-RateLimit-Window"
HEADER_RATE_LIMIT_USED = "X-RateLimit-Used"
HEADER_RATE_LIMIT_REMAINING = "X-RateLimit-Remaining"
HEADER_RATE_LIMIT_RESET = "X-RateLimit-Reset"
HEADER_RATE_LIMIT_STATUS = "X-RateLimit-Status"
class RedisRateLimiter:
"""Thread-safe Redis-based rate limiter with sliding window algorithm."""
def __init__(self) -> None:
"""Initialize Redis rate limiter."""
self.redis_client = redis.Redis(
host=os.getenv("REDIS_HOST", REDIS_DEFAULT_HOST),
port=int(os.getenv("REDIS_PORT", REDIS_DEFAULT_PORT)),
password=os.getenv("REDIS_PASSWORD"),
db=int(os.getenv("REDIS_DB", REDIS_DEFAULT_DB)),
decode_responses=True,
socket_connect_timeout=REDIS_SOCKET_TIMEOUT,
socket_timeout=REDIS_SOCKET_TIMEOUT,
retry_on_timeout=True,
)
try:
self.redis_client.ping()
logger.info("Rate limiter Redis connection established")
except redis.ConnectionError as e:
logger.error("Failed to connect to Redis for rate limiting", error=str(e))
raise
def check_rate_limit(
self,
client_ip: str,
requests_limit: int = DEFAULT_RATE_LIMIT_REQUESTS,
window_seconds: int = DEFAULT_RATE_LIMIT_WINDOW_SECONDS,
) -> Tuple[bool, Dict[str, int]]:
"""
Check rate limit using sliding window algorithm.
Args:
client_ip: Client IP address.
requests_limit: Maximum number of requests.
window_seconds: Time window in seconds.
Returns:
Tuple of (allowed: bool, headers: Dict[str, int]).
"""
try:
current_time = int(time.time())
window_start = current_time - window_seconds
key = f"{RATE_LIMIT_KEY_PREFIX}{client_ip}"
lua_script = """
local key = KEYS[1]
local window_start = tonumber(ARGV[1])
local current_time = tonumber(ARGV[2])
local requests_limit = tonumber(ARGV[3])
local window_seconds = tonumber(ARGV[4])
-- Remove old entries (outside window)
redis.call('ZREMRANGEBYSCORE', key, '-inf', window_start)
-- Count current requests
local current_requests = redis.call('ZCARD', key)
-- Check limit
if current_requests >= requests_limit then
-- Return blocking information
local oldest_request = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local reset_time = oldest_request[2] + window_seconds
return {0, current_requests, reset_time}
else
-- Add current request
redis.call('ZADD', key, current_time, current_time)
redis.call('EXPIRE', key, window_seconds)
-- Count updated requests
local new_count = redis.call('ZCARD', key)
return {1, new_count, 0}
end
"""
result = self.redis_client.eval(
lua_script, 1, key, window_start, current_time, requests_limit, window_seconds
)
allowed = bool(result[0])
current_requests = result[1]
reset_time = result[2] if result[2] > 0 else 0
headers = {
HEADER_RATE_LIMIT: requests_limit,
HEADER_RATE_LIMIT_WINDOW: window_seconds,
HEADER_RATE_LIMIT_USED: current_requests,
HEADER_RATE_LIMIT_REMAINING: max(0, requests_limit - current_requests),
}
if reset_time > 0:
headers[HEADER_RATE_LIMIT_RESET] = reset_time
if allowed:
logger.debug(
"Rate limit check passed",
client_ip=client_ip,
current_requests=current_requests,
limit=requests_limit,
)
else:
logger.warning(
"Rate limit exceeded",
client_ip=client_ip,
current_requests=current_requests,
limit=requests_limit,
reset_time=reset_time,
)
return allowed, headers
except redis.RedisError as e:
logger.error(
"Redis error during rate limit check", client_ip=client_ip, error=str(e)
)
return True, {
HEADER_RATE_LIMIT: requests_limit,
HEADER_RATE_LIMIT_WINDOW: window_seconds,
HEADER_RATE_LIMIT_USED: 0,
HEADER_RATE_LIMIT_REMAINING: requests_limit,
HEADER_RATE_LIMIT_STATUS: "redis_error",
}
except Exception as e:
logger.error(
"Unexpected error during rate limit check", client_ip=client_ip, error=str(e)
)
return True, {
HEADER_RATE_LIMIT: requests_limit,
HEADER_RATE_LIMIT_WINDOW: window_seconds,
HEADER_RATE_LIMIT_USED: 0,
HEADER_RATE_LIMIT_REMAINING: requests_limit,
HEADER_RATE_LIMIT_STATUS: "error",
}
def check_login_rate_limit(
self, client_ip: str, username: Optional[str] = None
) -> Tuple[bool, Dict[str, int]]:
"""
Special rate limit for login endpoint.
Args:
client_ip: Client IP address.
username: Username (optional).
Returns:
Tuple of (allowed: bool, headers: Dict[str, int]).
"""
allowed, headers = self.check_rate_limit(
client_ip, LOGIN_RATE_LIMIT_REQUESTS, LOGIN_RATE_LIMIT_WINDOW_SECONDS
)
if username and allowed:
user_key = f"{LOGIN_LIMIT_PREFIX}{username}"
user_allowed, user_headers = self.check_rate_limit(
user_key, LOGIN_RATE_LIMIT_REQUESTS, LOGIN_RATE_LIMIT_WINDOW_SECONDS
)
if not user_allowed:
logger.warning(
"Login rate limit exceeded for user",
username=username,
client_ip=client_ip,
)
return False, user_headers
return allowed, headers
def record_failed_login(self, client_ip: str, username: str) -> None:
"""
Record failed login attempt for brute-force attack tracking.
Args:
client_ip: Client IP address.
username: Username.
"""
try:
current_time = int(time.time())
ip_key = f"{FAILED_LOGINS_IP_PREFIX}{client_ip}"
self.redis_client.zadd(ip_key, {current_time: current_time})
self.redis_client.expire(ip_key, FAILED_LOGIN_RETENTION_SECONDS)
user_key = f"{FAILED_LOGINS_USER_PREFIX}{username}"
self.redis_client.zadd(user_key, {current_time: current_time})
self.redis_client.expire(user_key, FAILED_LOGIN_RETENTION_SECONDS)
logger.debug("Failed login recorded", client_ip=client_ip, username=username)
except Exception as e:
logger.error(
"Failed to record failed login attempt",
client_ip=client_ip,
username=username,
error=str(e),
)
def get_failed_login_count(
self,
client_ip: str,
username: Optional[str] = None,
window_minutes: int = DEFAULT_FAILED_LOGIN_WINDOW_MINUTES,
) -> Dict[str, int]:
"""
Get count of failed login attempts.
Args:
client_ip: Client IP address.
username: Username (optional).
window_minutes: Time window in minutes.
Returns:
Dictionary with failed login counts.
"""
try:
current_time = int(time.time())
window_start = current_time - (window_minutes * SECONDS_PER_MINUTE)
result = {"ip_failed_count": 0, "user_failed_count": 0}
ip_key = f"{FAILED_LOGINS_IP_PREFIX}{client_ip}"
ip_count = self.redis_client.zcount(ip_key, window_start, current_time)
result["ip_failed_count"] = ip_count
if username:
user_key = f"{FAILED_LOGINS_USER_PREFIX}{username}"
user_count = self.redis_client.zcount(user_key, window_start, current_time)
result["user_failed_count"] = user_count
return result
except Exception as e:
logger.error(
"Failed to get failed login count",
client_ip=client_ip,
username=username,
error=str(e),
)
return {"ip_failed_count": 0, "user_failed_count": 0}
def clear_failed_logins(
self, client_ip: str, username: Optional[str] = None
) -> None:
"""
Clear failed login attempt records.
Args:
client_ip: Client IP address.
username: Username (optional).
"""
try:
ip_key = f"{FAILED_LOGINS_IP_PREFIX}{client_ip}"
self.redis_client.delete(ip_key)
if username:
user_key = f"{FAILED_LOGINS_USER_PREFIX}{username}"
self.redis_client.delete(user_key)
logger.debug(
"Failed login records cleared", client_ip=client_ip, username=username
)
except Exception as e:
logger.error(
"Failed to clear failed login records",
client_ip=client_ip,
username=username,
error=str(e),
)
def get_rate_limit_stats(self) -> Dict[str, Any]:
"""
Get rate limiting statistics.
Returns:
Rate limiting statistics dictionary.
"""
try:
rate_limit_keys = self.redis_client.keys(f"{RATE_LIMIT_KEY_PREFIX}*")
failed_login_keys = self.redis_client.keys(f"{FAILED_LOGINS_IP_PREFIX}*")
return {
"active_rate_limits": len(rate_limit_keys),
"failed_login_trackers": len(failed_login_keys),
"redis_memory_usage": (
self.redis_client.memory_usage(f"{RATE_LIMIT_KEY_PREFIX}*")
+ self.redis_client.memory_usage(f"{FAILED_LOGINS_IP_PREFIX}*")
),
}
except Exception as e:
logger.error("Failed to get rate limit stats", error=str(e))
return {
"active_rate_limits": 0,
"failed_login_trackers": 0,
"redis_memory_usage": 0,
"error": str(e),
}
redis_rate_limiter = RedisRateLimiter()

View File

@ -0,0 +1,266 @@
"""
Redis Storage Helper for storing shared state in cluster.
"""
import json
import os
from typing import Any, Dict, List, Optional
import redis
import structlog
logger = structlog.get_logger(__name__)
class RedisConnectionStorage:
"""
Redis storage for active connections.
Supports cluster operation with automatic TTL.
"""
def __init__(self) -> None:
"""Initialize Redis connection storage."""
self._redis_client = redis.Redis(
host=os.getenv("REDIS_HOST", "redis"),
port=int(os.getenv("REDIS_PORT", "6379")),
password=os.getenv("REDIS_PASSWORD"),
db=0,
decode_responses=True,
)
try:
self._redis_client.ping()
logger.info("Redis Connection Storage initialized successfully")
except Exception as e:
logger.error("Failed to connect to Redis for connections", error=str(e))
raise RuntimeError(f"Redis connection failed: {e}")
def add_connection(
self,
connection_id: str,
connection_data: Dict[str, Any],
ttl_seconds: Optional[int] = None,
) -> None:
"""
Add connection to Redis.
Args:
connection_id: Connection ID.
connection_data: Connection data dictionary.
ttl_seconds: TTL in seconds (None = no automatic expiration).
"""
try:
redis_key = f"connection:active:{connection_id}"
if ttl_seconds is not None:
self._redis_client.setex(
redis_key, ttl_seconds, json.dumps(connection_data)
)
logger.debug(
"Connection added to Redis with TTL",
connection_id=connection_id,
ttl_seconds=ttl_seconds,
)
else:
self._redis_client.set(redis_key, json.dumps(connection_data))
logger.debug(
"Connection added to Redis without TTL",
connection_id=connection_id,
)
except Exception as e:
logger.error(
"Failed to add connection to Redis",
connection_id=connection_id,
error=str(e),
)
raise
def get_connection(self, connection_id: str) -> Optional[Dict[str, Any]]:
"""
Get connection from Redis.
Args:
connection_id: Connection ID.
Returns:
Connection data dictionary or None if not found.
"""
try:
redis_key = f"connection:active:{connection_id}"
conn_json = self._redis_client.get(redis_key)
if not conn_json:
return None
return json.loads(conn_json)
except Exception as e:
logger.error(
"Failed to get connection from Redis",
connection_id=connection_id,
error=str(e),
)
return None
def update_connection(
self, connection_id: str, update_data: Dict[str, Any]
) -> None:
"""
Update connection data.
Args:
connection_id: Connection ID.
update_data: Data to update (will be merged with existing data).
"""
try:
redis_key = f"connection:active:{connection_id}"
conn_json = self._redis_client.get(redis_key)
if not conn_json:
logger.warning(
"Cannot update non-existent connection",
connection_id=connection_id,
)
return
conn_data = json.loads(conn_json)
conn_data.update(update_data)
ttl = self._redis_client.ttl(redis_key)
if ttl > 0:
self._redis_client.setex(redis_key, ttl, json.dumps(conn_data))
logger.debug("Connection updated in Redis", connection_id=connection_id)
except Exception as e:
logger.error(
"Failed to update connection in Redis",
connection_id=connection_id,
error=str(e),
)
def delete_connection(self, connection_id: str) -> bool:
"""
Delete connection from Redis.
Args:
connection_id: Connection ID.
Returns:
True if connection was deleted, False otherwise.
"""
try:
redis_key = f"connection:active:{connection_id}"
result = self._redis_client.delete(redis_key)
if result > 0:
logger.debug("Connection deleted from Redis", connection_id=connection_id)
return True
return False
except Exception as e:
logger.error(
"Failed to delete connection from Redis",
connection_id=connection_id,
error=str(e),
)
return False
def get_all_connections(self) -> Dict[str, Dict[str, Any]]:
"""
Get all active connections.
Returns:
Dictionary mapping connection_id to connection_data.
"""
try:
pattern = "connection:active:*"
keys = list(self._redis_client.scan_iter(match=pattern, count=100))
connections = {}
for key in keys:
try:
conn_id = key.replace("connection:active:", "")
conn_json = self._redis_client.get(key)
if conn_json:
connections[conn_id] = json.loads(conn_json)
except Exception:
continue
return connections
except Exception as e:
logger.error("Failed to get all connections from Redis", error=str(e))
return {}
def get_user_connections(self, username: str) -> List[Dict[str, Any]]:
"""
Get all user connections.
Args:
username: Username.
Returns:
List of user connections.
"""
try:
all_connections = self.get_all_connections()
user_connections = [
conn_data
for conn_data in all_connections.values()
if conn_data.get("owner_username") == username
]
return user_connections
except Exception as e:
logger.error(
"Failed to get user connections from Redis",
username=username,
error=str(e),
)
return []
def cleanup_expired_connections(self) -> int:
"""
Cleanup expired connections.
Returns:
Number of removed connections.
"""
try:
pattern = "connection:active:*"
keys = list(self._redis_client.scan_iter(match=pattern, count=100))
cleaned_count = 0
for key in keys:
ttl = self._redis_client.ttl(key)
if ttl == -2:
cleaned_count += 1
elif ttl == -1:
self._redis_client.delete(key)
cleaned_count += 1
if cleaned_count > 0:
logger.info(
"Connections cleanup completed", cleaned_count=cleaned_count
)
return cleaned_count
except Exception as e:
logger.error("Failed to cleanup expired connections", error=str(e))
return 0
def get_stats(self) -> Dict[str, Any]:
"""
Get connection statistics.
Returns:
Connection statistics dictionary.
"""
try:
pattern = "connection:active:*"
keys = list(self._redis_client.scan_iter(match=pattern, count=100))
return {"total_connections": len(keys), "storage": "Redis"}
except Exception as e:
logger.error("Failed to get connection stats", error=str(e))
return {"error": str(e), "storage": "Redis"}
redis_connection_storage = RedisConnectionStorage()

View File

@ -0,0 +1,172 @@
"""
Module for nonce management and replay attack prevention.
"""
import hashlib
import logging
import time
import redis
logger = logging.getLogger(__name__)
# Constants
NONCE_TTL_SECONDS = 300 # 5 minutes TTL for nonce
TIMESTAMP_TOLERANCE_SECONDS = 30 # 30 seconds tolerance for timestamp
class NonceManager:
"""Nonce manager for replay attack prevention."""
def __init__(self, redis_client: redis.Redis) -> None:
"""Initialize nonce manager.
Args:
redis_client: Redis client instance.
"""
self.redis = redis_client
self.nonce_ttl = NONCE_TTL_SECONDS
self.timestamp_tolerance = TIMESTAMP_TOLERANCE_SECONDS
def validate_nonce(
self, client_nonce: bytes, timestamp: int, session_id: str
) -> bool:
"""
Validate nonce uniqueness and timestamp validity.
Args:
client_nonce: Nonce from client.
timestamp: Timestamp from client.
session_id: Session ID.
Returns:
True if nonce is valid, False otherwise.
"""
try:
if not self._validate_timestamp(timestamp):
logger.warning(
"Invalid timestamp",
extra={"timestamp": timestamp, "session_id": session_id},
)
return False
nonce_key = self._create_nonce_key(client_nonce, session_id)
nonce_hash = hashlib.sha256(client_nonce).hexdigest()[:16]
if self.redis.exists(nonce_key):
logger.warning(
"Nonce already used",
extra={"session_id": session_id, "nonce_hash": nonce_hash},
)
return False
self.redis.setex(nonce_key, self.nonce_ttl, timestamp)
logger.info(
"Nonce validated successfully",
extra={"session_id": session_id, "nonce_hash": nonce_hash},
)
return True
except Exception as e:
logger.error(
"Nonce validation error",
extra={"error": str(e), "session_id": session_id},
)
return False
def _validate_timestamp(self, timestamp: int) -> bool:
"""Validate timestamp.
Args:
timestamp: Timestamp in milliseconds.
Returns:
True if timestamp is within tolerance, False otherwise.
"""
current_time = int(time.time() * 1000)
time_diff = abs(current_time - timestamp)
return time_diff <= (self.timestamp_tolerance * 1000)
def _create_nonce_key(self, client_nonce: bytes, session_id: str) -> str:
"""Create unique key for nonce in Redis.
Args:
client_nonce: Nonce from client.
session_id: Session ID.
Returns:
Redis key string.
"""
nonce_hash = hashlib.sha256(client_nonce).hexdigest()
return f"nonce:{session_id}:{nonce_hash}"
def cleanup_expired_nonces(self) -> int:
"""
Cleanup expired nonces.
Redis automatically removes keys by TTL, but this method provides
additional cleanup for keys without TTL.
Returns:
Number of expired nonces removed.
"""
try:
pattern = "nonce:*"
keys = self.redis.keys(pattern)
expired_count = 0
for key in keys:
ttl = self.redis.ttl(key)
if ttl == -1:
self.redis.delete(key)
expired_count += 1
logger.info(
"Nonce cleanup completed",
extra={"expired_count": expired_count, "total_keys": len(keys)},
)
return expired_count
except Exception as e:
logger.error("Nonce cleanup error", extra={"error": str(e)})
return 0
class ReplayProtection:
"""Replay attack protection."""
def __init__(self, redis_client: redis.Redis) -> None:
"""Initialize replay protection.
Args:
redis_client: Redis client instance.
"""
self.nonce_manager = NonceManager(redis_client)
def validate_request(
self, client_nonce: bytes, timestamp: int, session_id: str
) -> bool:
"""
Validate request for replay attacks.
Args:
client_nonce: Nonce from client.
timestamp: Timestamp from client.
session_id: Session ID.
Returns:
True if request is valid, False otherwise.
"""
return self.nonce_manager.validate_nonce(
client_nonce, timestamp, session_id
)
def cleanup(self) -> int:
"""
Cleanup expired nonces.
Returns:
Number of expired nonces removed.
"""
return self.nonce_manager.cleanup_expired_nonces()

View File

@ -0,0 +1,401 @@
"""
Database operations for saved user machines.
"""
import os
from typing import Any, Dict, List, Optional
import psycopg2
import structlog
from psycopg2.extras import RealDictCursor
from psycopg2.extensions import connection as Connection
logger = structlog.get_logger(__name__)
class SavedMachinesDB:
"""PostgreSQL operations for saved machines."""
def __init__(self) -> None:
"""Initialize database configuration."""
self.db_config = {
"host": os.getenv("POSTGRES_HOST", "postgres"),
"port": int(os.getenv("POSTGRES_PORT", "5432")),
"database": os.getenv("POSTGRES_DB", "guacamole_db"),
"user": os.getenv("POSTGRES_USER", "guacamole_user"),
"password": os.getenv("POSTGRES_PASSWORD"),
}
def _get_connection(self) -> Connection:
"""Get database connection."""
try:
return psycopg2.connect(**self.db_config)
except Exception as e:
logger.error("Failed to connect to database", error=str(e))
raise
def create_machine(
self,
user_id: str,
name: str,
hostname: str,
port: int,
protocol: str,
os: Optional[str] = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
is_favorite: bool = False,
) -> Dict[str, Any]:
"""
Create new saved machine.
Args:
user_id: User ID.
name: Machine name.
hostname: Machine hostname.
port: Connection port.
protocol: Connection protocol.
os: Operating system (optional).
description: Description (optional).
tags: Tags list (optional).
is_favorite: Whether machine is favorite.
Returns:
Dictionary with created machine data including ID.
"""
with self._get_connection() as conn:
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
query = """
INSERT INTO api.user_saved_machines
(user_id, name, hostname, port, protocol, os,
description, tags, is_favorite)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id, user_id, name, hostname, port, protocol, os,
description, tags, is_favorite, created_at, updated_at,
last_connected_at
"""
cur.execute(
query,
(
user_id,
name,
hostname,
port,
protocol,
os,
description,
tags or [],
is_favorite,
),
)
result = dict(cur.fetchone())
conn.commit()
logger.info(
"Saved machine created",
machine_id=result["id"],
user_id=user_id,
name=name,
)
return result
except Exception as e:
conn.rollback()
logger.error(
"Failed to create saved machine", error=str(e), user_id=user_id
)
raise
def get_user_machines(
self, user_id: str, include_stats: bool = False
) -> List[Dict[str, Any]]:
"""
Get all user machines.
Args:
user_id: User ID.
include_stats: Include connection statistics.
Returns:
List of machine dictionaries.
"""
with self._get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
if include_stats:
query = """
SELECT
m.*,
json_build_object(
'total_connections', COALESCE(COUNT(h.id), 0),
'last_connection', MAX(h.connected_at),
'successful_connections',
COALESCE(SUM(CASE WHEN h.success = TRUE THEN 1 ELSE 0 END), 0),
'failed_connections',
COALESCE(SUM(CASE WHEN h.success = FALSE THEN 1 ELSE 0 END), 0)
) as connection_stats
FROM api.user_saved_machines m
LEFT JOIN api.connection_history h ON m.id = h.machine_id
WHERE m.user_id = %s
GROUP BY m.id
ORDER BY m.is_favorite DESC, m.updated_at DESC
"""
else:
query = """
SELECT * FROM api.user_saved_machines
WHERE user_id = %s
ORDER BY is_favorite DESC, updated_at DESC
"""
cur.execute(query, (user_id,))
results = [dict(row) for row in cur.fetchall()]
logger.debug(
"Retrieved user machines", user_id=user_id, count=len(results)
)
return results
def get_machine_by_id(
self, machine_id: str, user_id: str
) -> Optional[Dict[str, Any]]:
"""
Get machine by ID with owner verification.
Args:
machine_id: Machine UUID.
user_id: User ID for permission check.
Returns:
Machine dictionary or None if not found.
"""
with self._get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
query = """
SELECT * FROM api.user_saved_machines
WHERE id = %s AND user_id = %s
"""
cur.execute(query, (machine_id, user_id))
result = cur.fetchone()
return dict(result) if result else None
def update_machine(
self, machine_id: str, user_id: str, **updates: Any
) -> Optional[Dict[str, Any]]:
"""
Update machine.
Args:
machine_id: Machine UUID.
user_id: User ID for permission check.
**updates: Fields to update.
Returns:
Updated machine dictionary or None if not found.
"""
allowed_fields = {
"name",
"hostname",
"port",
"protocol",
"os",
"description",
"tags",
"is_favorite",
}
updates_filtered = {
k: v for k, v in updates.items() if k in allowed_fields and v is not None
}
if not updates_filtered:
return self.get_machine_by_id(machine_id, user_id)
set_clause = ", ".join([f"{k} = %s" for k in updates_filtered.keys()])
values = list(updates_filtered.values()) + [machine_id, user_id]
query = f"""
UPDATE api.user_saved_machines
SET {set_clause}
WHERE id = %s AND user_id = %s
RETURNING id, user_id, name, hostname, port, protocol, os,
description, tags, is_favorite, created_at, updated_at,
last_connected_at
"""
with self._get_connection() as conn:
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(query, values)
result = cur.fetchone()
conn.commit()
if result:
logger.info(
"Saved machine updated",
machine_id=machine_id,
user_id=user_id,
updated_fields=list(updates_filtered.keys()),
)
return dict(result)
return None
except Exception as e:
conn.rollback()
logger.error(
"Failed to update machine",
error=str(e),
machine_id=machine_id,
user_id=user_id,
)
raise
def delete_machine(self, machine_id: str, user_id: str) -> bool:
"""
Delete machine.
Args:
machine_id: Machine UUID.
user_id: User ID for permission check.
Returns:
True if deleted, False if not found.
"""
with self._get_connection() as conn:
try:
with conn.cursor() as cur:
query = """
DELETE FROM api.user_saved_machines
WHERE id = %s AND user_id = %s
"""
cur.execute(query, (machine_id, user_id))
deleted_count = cur.rowcount
conn.commit()
if deleted_count > 0:
logger.info(
"Saved machine deleted",
machine_id=machine_id,
user_id=user_id,
)
return True
logger.warning(
"Machine not found for deletion",
machine_id=machine_id,
user_id=user_id,
)
return False
except Exception as e:
conn.rollback()
logger.error(
"Failed to delete machine",
error=str(e),
machine_id=machine_id,
user_id=user_id,
)
raise
def update_last_connected(self, machine_id: str, user_id: str) -> None:
"""
Update last connection time.
Args:
machine_id: Machine UUID.
user_id: User ID.
"""
with self._get_connection() as conn:
try:
with conn.cursor() as cur:
query = """
UPDATE api.user_saved_machines
SET last_connected_at = NOW()
WHERE id = %s AND user_id = %s
"""
cur.execute(query, (machine_id, user_id))
conn.commit()
logger.debug(
"Updated last_connected_at",
machine_id=machine_id,
user_id=user_id,
)
except Exception as e:
conn.rollback()
logger.error("Failed to update last_connected", error=str(e))
def add_connection_history(
self,
user_id: str,
machine_id: str,
success: bool = True,
error_message: Optional[str] = None,
duration_seconds: Optional[int] = None,
client_ip: Optional[str] = None,
) -> str:
"""
Add connection history record.
Args:
user_id: User ID.
machine_id: Machine ID.
success: Whether connection was successful.
error_message: Error message if failed (optional).
duration_seconds: Connection duration in seconds (optional).
client_ip: Client IP address (optional).
Returns:
UUID of created record.
"""
with self._get_connection() as conn:
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
query = """
INSERT INTO api.connection_history
(user_id, machine_id, success, error_message, duration_seconds, client_ip)
VALUES (%s, %s, %s, %s, %s, %s)
RETURNING id
"""
cur.execute(
query,
(
user_id,
machine_id,
success,
error_message,
duration_seconds,
client_ip,
),
)
result = cur.fetchone()
conn.commit()
logger.info(
"Connection history record created",
machine_id=machine_id,
user_id=user_id,
success=success,
)
return str(result["id"])
except Exception as e:
conn.rollback()
logger.error("Failed to add connection history", error=str(e))
raise
saved_machines_db = SavedMachinesDB()

View File

@ -0,0 +1,339 @@
"""
Redis-based session storage for Guacamole tokens.
"""
import json
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
import redis
import structlog
logger = structlog.get_logger(__name__)
class SessionStorage:
"""Redis-based session storage for secure Guacamole token storage."""
def __init__(self) -> None:
"""Initialize Redis client and verify connection."""
self.redis_client = redis.Redis(
host=os.getenv("REDIS_HOST", "localhost"),
port=int(os.getenv("REDIS_PORT", "6379")),
password=os.getenv("REDIS_PASSWORD"),
db=int(os.getenv("REDIS_DB", "0")),
decode_responses=True,
socket_connect_timeout=5,
socket_timeout=5,
retry_on_timeout=True,
)
try:
self.redis_client.ping()
logger.info("Redis connection established successfully")
except redis.ConnectionError as e:
logger.error("Failed to connect to Redis", error=str(e))
raise
def create_session(
self,
user_info: Dict[str, Any],
guac_token: str,
expires_in_minutes: int = 60,
) -> str:
"""
Create new session.
Args:
user_info: User information dictionary.
guac_token: Guacamole authentication token.
expires_in_minutes: Session lifetime in minutes.
Returns:
Unique session ID.
"""
session_id = str(uuid.uuid4())
now = datetime.now(timezone.utc)
session_data = {
"session_id": session_id,
"user_info": user_info,
"guac_token": guac_token,
"created_at": now.isoformat(),
"expires_at": (now + timedelta(minutes=expires_in_minutes)).isoformat(),
"last_accessed": now.isoformat(),
}
try:
ttl_seconds = expires_in_minutes * 60
self.redis_client.setex(
f"session:{session_id}",
ttl_seconds,
json.dumps(session_data),
)
self.redis_client.setex(
f"user_session:{user_info['username']}",
ttl_seconds,
session_id,
)
logger.info(
"Session created successfully",
session_id=session_id,
username=user_info["username"],
expires_in_minutes=expires_in_minutes,
redis_key=f"session:{session_id}",
has_guac_token=bool(guac_token),
guac_token_length=len(guac_token) if guac_token else 0,
)
return session_id
except redis.RedisError as e:
logger.error("Failed to create session", error=str(e))
raise
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
Get session data.
Args:
session_id: Session ID.
Returns:
Session data or None if not found/expired.
"""
try:
session_data = self.redis_client.get(f"session:{session_id}")
if not session_data:
logger.debug("Session not found", session_id=session_id)
return None
session = json.loads(session_data)
session["last_accessed"] = datetime.now(timezone.utc).isoformat()
ttl = self.redis_client.ttl(f"session:{session_id}")
if ttl > 0:
self.redis_client.setex(
f"session:{session_id}",
ttl,
json.dumps(session),
)
logger.debug(
"Session retrieved successfully",
session_id=session_id,
username=session["user_info"]["username"],
)
return session
except redis.RedisError as e:
logger.error("Failed to get session", session_id=session_id, error=str(e))
return None
except json.JSONDecodeError as e:
logger.error(
"Failed to decode session data", session_id=session_id, error=str(e)
)
return None
def get_session_by_username(self, username: str) -> Optional[Dict[str, Any]]:
"""
Get session by username.
Args:
username: Username.
Returns:
Session data or None.
"""
try:
session_id = self.redis_client.get(f"user_session:{username}")
if not session_id:
logger.debug("No active session for user", username=username)
return None
return self.get_session(session_id)
except redis.RedisError as e:
logger.error(
"Failed to get session by username", username=username, error=str(e)
)
return None
def update_session(self, session_id: str, updates: Dict[str, Any]) -> bool:
"""
Update session data.
Args:
session_id: Session ID.
updates: Updates to apply.
Returns:
True if update successful.
"""
try:
session_data = self.redis_client.get(f"session:{session_id}")
if not session_data:
logger.warning("Session not found for update", session_id=session_id)
return False
session = json.loads(session_data)
session.update(updates)
session["last_accessed"] = datetime.now(timezone.utc).isoformat()
ttl = self.redis_client.ttl(f"session:{session_id}")
if ttl > 0:
self.redis_client.setex(
f"session:{session_id}",
ttl,
json.dumps(session),
)
logger.debug(
"Session updated successfully",
session_id=session_id,
updates=list(updates.keys()),
)
return True
else:
logger.warning("Session expired during update", session_id=session_id)
return False
except redis.RedisError as e:
logger.error("Failed to update session", session_id=session_id, error=str(e))
return False
except json.JSONDecodeError as e:
logger.error(
"Failed to decode session data for update",
session_id=session_id,
error=str(e),
)
return False
def delete_session(self, session_id: str) -> bool:
"""
Delete session.
Args:
session_id: Session ID.
Returns:
True if deletion successful.
"""
try:
session_data = self.redis_client.get(f"session:{session_id}")
if session_data:
session = json.loads(session_data)
username = session["user_info"]["username"]
self.redis_client.delete(f"session:{session_id}")
self.redis_client.delete(f"user_session:{username}")
logger.info(
"Session deleted successfully",
session_id=session_id,
username=username,
)
return True
else:
logger.debug("Session not found for deletion", session_id=session_id)
return False
except redis.RedisError as e:
logger.error("Failed to delete session", session_id=session_id, error=str(e))
return False
except json.JSONDecodeError as e:
logger.error(
"Failed to decode session data for deletion",
session_id=session_id,
error=str(e),
)
return False
def delete_user_sessions(self, username: str) -> int:
"""
Delete all user sessions.
Args:
username: Username.
Returns:
Number of deleted sessions.
"""
try:
pattern = f"user_session:{username}"
session_keys = self.redis_client.keys(pattern)
deleted_count = 0
for key in session_keys:
session_id = self.redis_client.get(key)
if session_id and self.delete_session(session_id):
deleted_count += 1
logger.info(
"User sessions deleted", username=username, deleted_count=deleted_count
)
return deleted_count
except redis.RedisError as e:
logger.error(
"Failed to delete user sessions", username=username, error=str(e)
)
return 0
def cleanup_expired_sessions(self) -> int:
"""
Cleanup expired sessions.
Redis automatically removes keys by TTL, so this method is mainly
for compatibility and potential logic extension.
Returns:
Number of cleaned sessions (always 0, as Redis handles this automatically).
"""
logger.debug(
"Expired sessions cleanup completed (Redis TTL handles this automatically)"
)
return 0
def get_session_stats(self) -> Dict[str, Any]:
"""
Get session statistics.
Returns:
Session statistics dictionary.
"""
try:
session_keys = self.redis_client.keys("session:*")
user_keys = self.redis_client.keys("user_session:*")
memory_usage = (
self.redis_client.memory_usage("session:*") if session_keys else 0
)
return {
"active_sessions": len(session_keys),
"active_users": len(user_keys),
"redis_memory_usage": memory_usage,
}
except redis.RedisError as e:
logger.error("Failed to get session stats", error=str(e))
return {
"active_sessions": 0,
"active_users": 0,
"redis_memory_usage": 0,
"error": str(e),
}
session_storage = SessionStorage()

View File

@ -0,0 +1,154 @@
"""Module for verifying server key signatures with constant-time comparison."""
import logging
from typing import Dict, Optional
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
logger = logging.getLogger(__name__)
# Ed25519 constants
ED25519_SIGNATURE_LENGTH = 64
ED25519_PUBLIC_KEY_LENGTH = 32
DEFAULT_KEY_ID = "default"
class SignatureVerifier:
"""Signature verifier with constant-time comparison."""
def __init__(self) -> None:
"""Initialize the signature verifier."""
self.trusted_public_keys = self._load_trusted_keys()
def _load_trusted_keys(self) -> Dict[str, Optional[ed25519.Ed25519PublicKey]]:
"""
Load trusted public keys.
Returns:
Dictionary mapping key IDs to public keys.
"""
return {DEFAULT_KEY_ID: None}
def verify_server_key_signature(
self,
public_key_pem: bytes,
signature: bytes,
kid: Optional[str] = None,
) -> bool:
"""
Verify server public key signature with constant-time comparison.
Args:
public_key_pem: PEM-encoded public key.
signature: Signature bytes.
kid: Key ID for key selection (optional).
Returns:
True if signature is valid, False otherwise.
"""
try:
if len(signature) != ED25519_SIGNATURE_LENGTH:
logger.warning(
"Invalid signature length",
extra={
"expected": ED25519_SIGNATURE_LENGTH,
"actual": len(signature),
"kid": kid,
},
)
return False
try:
public_key = serialization.load_pem_public_key(public_key_pem)
except Exception as e:
logger.warning(
"Failed to load PEM public key",
extra={"error": str(e), "kid": kid},
)
return False
if not isinstance(public_key, ed25519.Ed25519PublicKey):
logger.warning(
"Public key is not Ed25519",
extra={"kid": kid},
)
return False
raw_public_key = public_key.public_bytes_raw()
if len(raw_public_key) != ED25519_PUBLIC_KEY_LENGTH:
logger.warning(
"Invalid public key length",
extra={
"expected": ED25519_PUBLIC_KEY_LENGTH,
"actual": len(raw_public_key),
"kid": kid,
},
)
return False
trusted_key = self._get_trusted_key(kid)
if not trusted_key:
logger.error("No trusted key found", extra={"kid": kid})
return False
try:
trusted_key.verify(signature, public_key_pem)
logger.info("Signature verification successful", extra={"kid": kid})
return True
except InvalidSignature:
logger.warning("Signature verification failed", extra={"kid": kid})
return False
except Exception as e:
logger.error(
"Signature verification error",
extra={"error": str(e), "kid": kid},
)
return False
def _get_trusted_key(
self, kid: Optional[str] = None
) -> Optional[ed25519.Ed25519PublicKey]:
"""
Get trusted public key by kid.
Args:
kid: Key ID (optional).
Returns:
Trusted public key or None if not found.
"""
key_id = kid if kid else DEFAULT_KEY_ID
return self.trusted_public_keys.get(key_id)
def add_trusted_key(self, kid: str, public_key_pem: bytes) -> bool:
"""
Add trusted public key.
Args:
kid: Key ID.
public_key_pem: PEM-encoded public key.
Returns:
True if key was added successfully.
"""
try:
public_key = serialization.load_pem_public_key(public_key_pem)
if not isinstance(public_key, ed25519.Ed25519PublicKey):
logger.error("Public key is not Ed25519", extra={"kid": kid})
return False
self.trusted_public_keys[kid] = public_key
logger.info("Trusted key added", extra={"kid": kid})
return True
except Exception as e:
logger.error(
"Failed to add trusted key",
extra={"error": str(e), "kid": kid},
)
return False
signature_verifier = SignatureVerifier()

View File

@ -0,0 +1,327 @@
"""Enhanced SSRF attack protection with DNS pinning and rebinding prevention."""
# Standard library imports
import ipaddress
import socket
import time
from typing import Any, Dict, List, Optional, Set, Tuple
# Third-party imports
import structlog
logger = structlog.get_logger(__name__)
class SSRFProtection:
"""Enhanced SSRF attack protection with DNS pinning."""
def __init__(self) -> None:
"""Initialize SSRF protection with blocked IPs and networks."""
self._dns_cache: Dict[str, Tuple[str, float, int]] = {}
self._dns_cache_ttl = 300
self._blocked_ips: Set[str] = {
"127.0.0.1",
"::1",
"0.0.0.0",
"169.254.169.254",
"10.0.0.1",
"10.255.255.255",
"172.16.0.1",
"172.31.255.255",
"192.168.0.1",
"192.168.255.255",
}
self._blocked_networks = [
"127.0.0.0/8",
"169.254.0.0/16",
"224.0.0.0/4",
"240.0.0.0/4",
"172.17.0.0/16",
"172.18.0.0/16",
"172.19.0.0/16",
"172.20.0.0/16",
"172.21.0.0/16",
"172.22.0.0/16",
"172.23.0.0/16",
"172.24.0.0/16",
"172.25.0.0/16",
"172.26.0.0/16",
"172.27.0.0/16",
"172.28.0.0/16",
"172.29.0.0/16",
"172.30.0.0/16",
"172.31.0.0/16",
]
self._allowed_networks: Dict[str, List[str]] = {
"USER": ["10.0.0.0/8", "172.16.0.0/16", "192.168.1.0/24"],
"ADMIN": [
"10.0.0.0/8",
"172.16.0.0/16",
"192.168.0.0/16",
"203.0.113.0/24",
],
"SUPER_ADMIN": ["0.0.0.0/0"],
}
def validate_host(
self, hostname: str, user_role: str
) -> Tuple[bool, str]:
"""Validate host with enhanced SSRF protection.
Args:
hostname: Hostname or IP address
user_role: User role
Returns:
Tuple of (allowed: bool, reason: str)
"""
try:
if not hostname or len(hostname) > 253:
return False, f"Invalid hostname length: {hostname}"
suspicious_chars = [
"..",
"//",
"\\",
"<",
">",
'"',
"'",
"`",
"\x00",
]
if any(char in hostname for char in suspicious_chars):
return False, f"Suspicious characters in hostname: {hostname}"
if hostname.lower() in ("localhost", "127.0.0.1", "::1"):
return False, f"Host {hostname} is blocked (localhost)"
resolved_ip = self._resolve_hostname_with_pinning(hostname)
if not resolved_ip:
return False, f"Cannot resolve hostname: {hostname}"
if resolved_ip in self._blocked_ips:
return False, f"IP {resolved_ip} is in blocked list"
ip_addr = ipaddress.ip_address(resolved_ip)
for blocked_network in self._blocked_networks:
if ip_addr in ipaddress.ip_network(blocked_network):
return (
False,
f"IP {resolved_ip} is in blocked network {blocked_network}",
)
allowed_networks = self._allowed_networks.get(user_role, [])
if not allowed_networks:
return False, f"Role {user_role} has no allowed networks"
if user_role == "SUPER_ADMIN":
return True, f"IP {resolved_ip} allowed for SUPER_ADMIN"
for allowed_network in allowed_networks:
if ip_addr in ipaddress.ip_network(allowed_network):
return (
True,
f"IP {resolved_ip} allowed in network {allowed_network}",
)
return (
False,
f"IP {resolved_ip} not in any allowed network for role {user_role}",
)
except Exception as e:
logger.error("SSRF validation error", hostname=hostname, error=str(e))
return False, f"Error validating host: {str(e)}"
def _resolve_hostname_with_pinning(self, hostname: str) -> Optional[str]:
"""DNS resolution with pinning to prevent rebinding attacks.
Args:
hostname: Hostname to resolve
Returns:
IP address or None if resolution failed
"""
try:
cache_key = hostname.lower()
if cache_key in self._dns_cache:
cached_ip, timestamp, ttl = self._dns_cache[cache_key]
if time.time() - timestamp < ttl:
logger.debug(
"Using cached DNS resolution",
hostname=hostname,
ip=cached_ip,
age_seconds=int(time.time() - timestamp),
)
return cached_ip
del self._dns_cache[cache_key]
original_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(5)
try:
ip1 = socket.gethostbyname(hostname)
time.sleep(0.1)
ip2 = socket.gethostbyname(hostname)
if ip1 != ip2:
logger.warning(
"DNS rebinding detected",
hostname=hostname,
ip1=ip1,
ip2=ip2,
)
return None
if ip1 in ("127.0.0.1", "::1"):
logger.warning(
"DNS resolution returned localhost", hostname=hostname, ip=ip1
)
return None
self._dns_cache[cache_key] = (
ip1,
time.time(),
self._dns_cache_ttl,
)
logger.info(
"DNS resolution successful", hostname=hostname, ip=ip1, cached=True
)
return ip1
finally:
socket.setdefaulttimeout(original_timeout)
except socket.gaierror as e:
logger.warning("DNS resolution failed", hostname=hostname, error=str(e))
return None
except socket.timeout:
logger.warning("DNS resolution timeout", hostname=hostname)
return None
except Exception as e:
logger.error(
"Unexpected DNS resolution error", hostname=hostname, error=str(e)
)
return None
def validate_port(self, port: int) -> Tuple[bool, str]:
"""Validate port number.
Args:
port: Port number
Returns:
Tuple of (valid: bool, reason: str)
"""
if not isinstance(port, int) or port < 1 or port > 65535:
return False, f"Invalid port: {port}"
blocked_ports = {
22,
23,
25,
53,
80,
110,
143,
443,
993,
995,
135,
139,
445,
1433,
1521,
3306,
5432,
6379,
3389,
5900,
5901,
5902,
5903,
5904,
5905,
8080,
8443,
9090,
9091,
}
if port in blocked_ports:
return False, f"Port {port} is blocked (system port)"
return True, f"Port {port} is valid"
def cleanup_expired_cache(self) -> None:
"""Clean up expired DNS cache entries."""
current_time = time.time()
expired_keys = [
key
for key, (_, timestamp, ttl) in self._dns_cache.items()
if current_time - timestamp > ttl
]
for key in expired_keys:
del self._dns_cache[key]
if expired_keys:
logger.info(
"Cleaned up expired DNS cache entries", count=len(expired_keys)
)
def get_cache_stats(self) -> Dict[str, Any]:
"""Get DNS cache statistics.
Returns:
Dictionary with cache statistics
"""
current_time = time.time()
active_entries = 0
expired_entries = 0
for _, timestamp, ttl in self._dns_cache.values():
if current_time - timestamp < ttl:
active_entries += 1
else:
expired_entries += 1
return {
"total_entries": len(self._dns_cache),
"active_entries": active_entries,
"expired_entries": expired_entries,
"cache_ttl_seconds": self._dns_cache_ttl,
"blocked_ips_count": len(self._blocked_ips),
"blocked_networks_count": len(self._blocked_networks),
}
def add_blocked_ip(self, ip: str) -> None:
"""Add IP to blocked list.
Args:
ip: IP address to block
"""
self._blocked_ips.add(ip)
logger.info("Added IP to blocked list", ip=ip)
def remove_blocked_ip(self, ip: str) -> None:
"""Remove IP from blocked list.
Args:
ip: IP address to unblock
"""
self._blocked_ips.discard(ip)
logger.info("Removed IP from blocked list", ip=ip)
# Global instance for use in API
ssrf_protection = SSRFProtection()

View File

@ -0,0 +1,263 @@
"""Redis-based token blacklist for JWT token revocation."""
# Standard library imports
import hashlib
import json
import os
from datetime import datetime, timezone
from typing import Any, Dict, Optional
# Third-party imports
import redis
import structlog
# Local imports
from .session_storage import session_storage
logger = structlog.get_logger(__name__)
# Redis configuration constants
REDIS_SOCKET_TIMEOUT = 5
REDIS_DEFAULT_HOST = "localhost"
REDIS_DEFAULT_PORT = "6379"
REDIS_DEFAULT_DB = "0"
# Blacklist constants
BLACKLIST_KEY_PREFIX = "blacklist:"
TOKEN_HASH_PREVIEW_LENGTH = 16
DEFAULT_REVOCATION_REASON = "logout"
DEFAULT_FORCE_LOGOUT_REASON = "force_logout"
class TokenBlacklist:
"""Redis-based blacklist for JWT token revocation."""
def __init__(self) -> None:
"""Initialize token blacklist with Redis connection."""
self.redis_client = redis.Redis(
host=os.getenv("REDIS_HOST", REDIS_DEFAULT_HOST),
port=int(os.getenv("REDIS_PORT", REDIS_DEFAULT_PORT)),
password=os.getenv("REDIS_PASSWORD"),
db=int(os.getenv("REDIS_DB", REDIS_DEFAULT_DB)),
decode_responses=True,
socket_connect_timeout=REDIS_SOCKET_TIMEOUT,
socket_timeout=REDIS_SOCKET_TIMEOUT,
retry_on_timeout=True,
)
try:
self.redis_client.ping()
logger.info("Token blacklist Redis connection established")
except redis.ConnectionError as e:
logger.error(
"Failed to connect to Redis for token blacklist", error=str(e)
)
raise
def _get_token_hash(self, token: str) -> str:
"""Get token hash for use as Redis key.
Args:
token: JWT token
Returns:
SHA-256 hash of token
"""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def revoke_token(
self,
token: str,
reason: str = DEFAULT_REVOCATION_REASON,
revoked_by: Optional[str] = None,
) -> bool:
"""Revoke token (add to blacklist).
Args:
token: JWT token to revoke
reason: Revocation reason
revoked_by: Username who revoked the token
Returns:
True if token successfully revoked
"""
try:
from .utils import get_token_expiry_info
expiry_info = get_token_expiry_info(token)
if not expiry_info:
logger.warning(
"Cannot revoke token: invalid or expired", reason=reason
)
return False
token_hash = self._get_token_hash(token)
now = datetime.now(timezone.utc)
blacklist_data = {
"token_hash": token_hash,
"reason": reason,
"revoked_at": now.isoformat(),
"revoked_by": revoked_by,
"expires_at": expiry_info["expires_at"],
"username": expiry_info.get("username"),
"token_type": expiry_info.get("token_type", "access"),
}
expires_at = datetime.fromisoformat(expiry_info["expires_at"])
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
ttl_seconds = int((expires_at - now).total_seconds())
if ttl_seconds <= 0:
logger.debug(
"Token already expired, no need to blacklist",
username=expiry_info.get("username"),
)
return True
self.redis_client.setex(
f"{BLACKLIST_KEY_PREFIX}{token_hash}",
ttl_seconds,
json.dumps(blacklist_data),
)
logger.info(
"Token revoked successfully",
token_hash=token_hash[:TOKEN_HASH_PREVIEW_LENGTH] + "...",
username=expiry_info.get("username"),
reason=reason,
revoked_by=revoked_by,
ttl_seconds=ttl_seconds,
)
return True
except Exception as e:
logger.error("Failed to revoke token", error=str(e), reason=reason)
return False
def is_token_revoked(self, token: str) -> bool:
"""Check if token is revoked.
Args:
token: JWT token to check
Returns:
True if token is revoked
"""
try:
token_hash = self._get_token_hash(token)
blacklist_data = self.redis_client.get(f"{BLACKLIST_KEY_PREFIX}{token_hash}")
if blacklist_data:
data = json.loads(blacklist_data)
logger.debug(
"Token is revoked",
token_hash=token_hash[:TOKEN_HASH_PREVIEW_LENGTH] + "...",
reason=data.get("reason"),
revoked_at=data.get("revoked_at"),
)
return True
return False
except Exception as e:
logger.error(
"Failed to check token revocation status", error=str(e)
)
return False
def revoke_user_tokens(
self,
username: str,
reason: str = DEFAULT_FORCE_LOGOUT_REASON,
revoked_by: Optional[str] = None,
) -> int:
"""Revoke all user tokens.
Args:
username: Username
reason: Revocation reason
revoked_by: Who revoked the tokens
Returns:
Number of revoked tokens
"""
try:
session = session_storage.get_session_by_username(username)
if not session:
logger.debug(
"No active session found for user", username=username
)
return 0
session_storage.delete_user_sessions(username)
logger.info(
"User tokens revoked",
username=username,
reason=reason,
revoked_by=revoked_by,
)
return 1
except Exception as e:
logger.error(
"Failed to revoke user tokens", username=username, error=str(e)
)
return 0
def get_blacklist_stats(self) -> Dict[str, Any]:
"""Get blacklist statistics.
Returns:
Statistics about revoked tokens
"""
try:
blacklist_keys = self.redis_client.keys(f"{BLACKLIST_KEY_PREFIX}*")
reasons_count: Dict[str, int] = {}
for key in blacklist_keys:
data = self.redis_client.get(key)
if data:
blacklist_data = json.loads(data)
reason = blacklist_data.get("reason", "unknown")
reasons_count[reason] = reasons_count.get(reason, 0) + 1
return {
"revoked_tokens": len(blacklist_keys),
"reasons": reasons_count,
"redis_memory_usage": (
self.redis_client.memory_usage(f"{BLACKLIST_KEY_PREFIX}*")
if blacklist_keys
else 0
),
}
except Exception as e:
logger.error("Failed to get blacklist stats", error=str(e))
return {
"revoked_tokens": 0,
"reasons": {},
"redis_memory_usage": 0,
"error": str(e),
}
def cleanup_expired_blacklist(self) -> int:
"""Clean up expired blacklist entries.
Returns:
Number of cleaned entries (always 0, Redis handles this automatically)
"""
logger.debug(
"Expired blacklist cleanup completed (Redis TTL handles this automatically)"
)
return 0
# Global instance for use in API
token_blacklist = TokenBlacklist()

View File

@ -0,0 +1,362 @@
"""Utilities for JWT token and session storage operations."""
# Standard library imports
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
# Third-party imports
import jwt
import structlog
# Local imports
from .session_storage import session_storage
from .token_blacklist import token_blacklist
logger = structlog.get_logger(__name__)
# JWT configuration from environment variables
JWT_SECRET_KEY = os.getenv(
"JWT_SECRET_KEY",
"your_super_secret_jwt_key_minimum_32_characters_long",
)
JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
JWT_ACCESS_TOKEN_EXPIRE_MINUTES = int(
os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "60")
)
JWT_REFRESH_TOKEN_EXPIRE_DAYS = int(
os.getenv("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "7")
)
def create_jwt_token(
user_info: Dict[str, Any], session_id: str, token_type: str = "access"
) -> str:
"""Create JWT token with session_id instead of Guacamole token.
Args:
user_info: User information dictionary
session_id: Session ID in Redis
token_type: Token type ("access" or "refresh")
Returns:
JWT token as string
Raises:
Exception: If token creation fails
"""
try:
if token_type == "refresh":
expire_delta = timedelta(days=JWT_REFRESH_TOKEN_EXPIRE_DAYS)
else:
expire_delta = timedelta(
minutes=JWT_ACCESS_TOKEN_EXPIRE_MINUTES
)
now = datetime.now(timezone.utc)
payload = {
"username": user_info["username"],
"role": user_info["role"],
"permissions": user_info.get("permissions", []),
"session_id": session_id,
"token_type": token_type,
"exp": now + expire_delta,
"iat": now,
"iss": "remote-access-api",
}
optional_fields = [
"full_name",
"email",
"organization",
"organizational_role",
]
for field in optional_fields:
if field in user_info:
payload[field] = user_info[field]
token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
logger.info(
"JWT token created successfully",
username=user_info["username"],
token_type=token_type,
session_id=session_id,
expires_in_minutes=expire_delta.total_seconds() / 60,
payload_keys=list(payload.keys()),
token_prefix=token[:30] + "...",
)
return token
except Exception as e:
logger.error(
"Failed to create JWT token",
username=user_info.get("username", "unknown"),
error=str(e),
)
raise
def verify_jwt_token(token: str) -> Optional[Dict[str, Any]]:
"""Verify and decode JWT token with blacklist check.
Args:
token: JWT token to verify
Returns:
Decoded payload or None if token is invalid
"""
try:
logger.debug("Starting JWT verification", token_prefix=token[:30] + "...")
if token_blacklist.is_token_revoked(token):
logger.info("JWT token is revoked", token_prefix=token[:20] + "...")
return None
logger.debug("Token not in blacklist, attempting decode")
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
logger.info(
"JWT decode successful",
username=payload.get("username"),
payload_keys=list(payload.keys()),
has_session_id="session_id" in payload,
session_id=payload.get("session_id", "NOT_FOUND"),
)
required_fields = ["username", "role", "session_id", "exp", "iat"]
for field in required_fields:
if field not in payload:
logger.warning(
"JWT token missing required field",
field=field,
username=payload.get("username", "unknown"),
available_fields=list(payload.keys()),
)
return None
logger.debug("All required fields present")
exp_timestamp = payload["exp"]
current_timestamp = datetime.now(timezone.utc).timestamp()
if current_timestamp > exp_timestamp:
logger.info(
"JWT token expired",
username=payload["username"],
expired_at=datetime.fromtimestamp(
exp_timestamp, tz=timezone.utc
).isoformat(),
current_time=datetime.now(timezone.utc).isoformat(),
)
return None
logger.debug(
"Token not expired, checking Redis session",
session_id=payload["session_id"],
)
session_data = session_storage.get_session(payload["session_id"])
if not session_data:
logger.warning(
"Session not found for JWT token",
username=payload["username"],
session_id=payload["session_id"],
possible_reasons=[
"session expired in Redis",
"session never created",
"Redis connection issue",
],
)
return None
logger.debug(
"Session found in Redis",
username=payload["username"],
session_id=payload["session_id"],
session_keys=list(session_data.keys()),
)
if "guac_token" not in session_data:
logger.error(
"Session exists but missing guac_token",
username=payload["username"],
session_id=payload["session_id"],
session_keys=list(session_data.keys()),
)
return None
payload["guac_token"] = session_data["guac_token"]
logger.info(
"JWT token verified successfully",
username=payload["username"],
role=payload["role"],
token_type=payload.get("token_type", "access"),
session_id=payload["session_id"],
guac_token_length=len(session_data["guac_token"]),
)
return payload
except jwt.ExpiredSignatureError:
logger.info(
"JWT token expired (ExpiredSignatureError)",
token_prefix=token[:20] + "...",
)
return None
except jwt.InvalidTokenError as e:
logger.warning(
"Invalid JWT token (InvalidTokenError)",
error=str(e),
error_type=type(e).__name__,
token_prefix=token[:20] + "...",
)
return None
except Exception as e:
logger.error(
"Unexpected error verifying JWT token",
error=str(e),
error_type=type(e).__name__,
)
return None
def create_refresh_token(
user_info: Dict[str, Any], session_id: str
) -> str:
"""Create refresh token.
Args:
user_info: User information dictionary
session_id: Session ID in Redis
Returns:
Refresh token
"""
return create_jwt_token(user_info, session_id, token_type="refresh")
def extract_token_from_header(
authorization_header: Optional[str],
) -> Optional[str]:
"""Extract token from Authorization header.
Args:
authorization_header: Authorization header value
Returns:
JWT token or None
"""
if not authorization_header:
return None
if not authorization_header.startswith("Bearer "):
return None
return authorization_header.split(" ", 1)[1]
def get_token_expiry_info(token: str) -> Optional[Dict[str, Any]]:
"""Get token expiration information.
Args:
token: JWT token
Returns:
Expiration information or None
"""
try:
payload = jwt.decode(token, options={"verify_signature": False})
exp_timestamp = payload.get("exp")
iat_timestamp = payload.get("iat")
if not exp_timestamp:
return None
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
iat_datetime = (
datetime.fromtimestamp(iat_timestamp, tz=timezone.utc)
if iat_timestamp
else None
)
current_time = datetime.now(timezone.utc)
return {
"expires_at": exp_datetime.isoformat(),
"issued_at": iat_datetime.isoformat() if iat_datetime else None,
"expires_in_seconds": max(
0, int((exp_datetime - current_time).total_seconds())
),
"is_expired": current_time > exp_datetime,
"username": payload.get("username"),
"token_type": payload.get("token_type", "access"),
"session_id": payload.get("session_id"),
}
except Exception as e:
logger.error("Failed to get token expiry info", error=str(e))
return None
def is_token_expired(token: str) -> bool:
"""Check if token is expired.
Args:
token: JWT token
Returns:
True if token is expired, False if valid
"""
expiry_info = get_token_expiry_info(token)
return expiry_info["is_expired"] if expiry_info else True
def revoke_token(
token: str,
reason: str = "logout",
revoked_by: Optional[str] = None,
) -> bool:
"""Revoke token (add to blacklist).
Args:
token: JWT token to revoke
reason: Revocation reason
revoked_by: Who revoked the token
Returns:
True if token successfully revoked
"""
try:
return token_blacklist.revoke_token(token, reason, revoked_by)
except Exception as e:
logger.error("Failed to revoke token", error=str(e))
return False
def revoke_user_tokens(
username: str,
reason: str = "force_logout",
revoked_by: Optional[str] = None,
) -> int:
"""Revoke all user tokens.
Args:
username: Username
reason: Revocation reason
revoked_by: Who revoked the tokens
Returns:
Number of revoked tokens
"""
try:
return token_blacklist.revoke_user_tokens(
username, reason, revoked_by
)
except Exception as e:
logger.error(
"Failed to revoke user tokens", username=username, error=str(e)
)
return 0

View File

@ -0,0 +1,326 @@
"""WebSocket Manager for real-time client notifications."""
# Standard library imports
import asyncio
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
# Third-party imports
from fastapi import WebSocket
import structlog
logger = structlog.get_logger(__name__)
class WebSocketManager:
"""WebSocket connection manager for sending notifications to clients.
Supported events:
- connection_expired: Connection expired
- connection_deleted: Connection deleted manually
- connection_will_expire: Connection will expire soon (5 min warning)
- jwt_will_expire: JWT token will expire soon (5 min warning)
- jwt_expired: JWT token expired
- connection_extended: Connection TTL extended
"""
def __init__(self) -> None:
"""Initialize WebSocket manager."""
self.active_connections: Dict[str, Set[WebSocket]] = {}
self._lock = asyncio.Lock()
async def connect(self, websocket: WebSocket, username: str) -> None:
"""Connect a new client.
Args:
websocket: WebSocket connection (already accepted)
username: Username
"""
async with self._lock:
if username not in self.active_connections:
self.active_connections[username] = set()
self.active_connections[username].add(websocket)
logger.info(
"WebSocket client connected",
username=username,
total_connections=len(
self.active_connections.get(username, set())
),
)
async def disconnect(self, websocket: WebSocket, username: str) -> None:
"""Disconnect a client.
Args:
websocket: WebSocket connection
username: Username
"""
async with self._lock:
if username in self.active_connections:
self.active_connections[username].discard(websocket)
if not self.active_connections[username]:
del self.active_connections[username]
logger.info(
"WebSocket client disconnected",
username=username,
remaining_connections=len(
self.active_connections.get(username, set())
),
)
async def send_to_user(
self, username: str, message: Dict[str, Any]
) -> None:
"""Send message to all WebSocket connections of a user.
Args:
username: Username
message: Dictionary with data to send
"""
if username not in self.active_connections:
logger.debug(
"No active WebSocket connections for user", username=username
)
return
connections = self.active_connections[username].copy()
disconnected = []
for websocket in connections:
try:
await websocket.send_json(message)
logger.debug(
"Message sent via WebSocket",
username=username,
event_type=message.get("type"),
)
except Exception as e:
logger.error(
"Failed to send WebSocket message",
username=username,
error=str(e),
)
disconnected.append(websocket)
if disconnected:
async with self._lock:
for ws in disconnected:
self.active_connections[username].discard(ws)
if not self.active_connections[username]:
del self.active_connections[username]
async def send_connection_expired(
self,
username: str,
connection_id: str,
hostname: str,
protocol: str,
) -> None:
"""Notify about connection expiration.
Args:
username: Username
connection_id: Connection ID
hostname: Machine hostname
protocol: Connection protocol
"""
message = {
"type": "connection_expired",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"connection_id": connection_id,
"hostname": hostname,
"protocol": protocol,
"reason": "TTL expired",
},
}
await self.send_to_user(username, message)
logger.info(
"Connection expired notification sent",
username=username,
connection_id=connection_id,
hostname=hostname,
)
async def send_connection_deleted(
self,
username: str,
connection_id: str,
hostname: str,
protocol: str,
reason: str = "manual",
) -> None:
"""Notify about connection deletion.
Args:
username: Username
connection_id: Connection ID
hostname: Machine hostname
protocol: Connection protocol
reason: Deletion reason (manual, expired, error)
"""
message = {
"type": "connection_deleted",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"connection_id": connection_id,
"hostname": hostname,
"protocol": protocol,
"reason": reason,
},
}
await self.send_to_user(username, message)
logger.info(
"Connection deleted notification sent",
username=username,
connection_id=connection_id,
reason=reason,
)
async def send_connection_will_expire(
self,
username: str,
connection_id: str,
hostname: str,
protocol: str,
minutes_remaining: int,
) -> None:
"""Warn about upcoming connection expiration.
Args:
username: Username
connection_id: Connection ID
hostname: Machine hostname
protocol: Connection protocol
minutes_remaining: Minutes until expiration
"""
message = {
"type": "connection_will_expire",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"connection_id": connection_id,
"hostname": hostname,
"protocol": protocol,
"minutes_remaining": minutes_remaining,
},
}
await self.send_to_user(username, message)
logger.info(
"Connection expiration warning sent",
username=username,
connection_id=connection_id,
minutes_remaining=minutes_remaining,
)
async def send_jwt_will_expire(
self, username: str, minutes_remaining: int
) -> None:
"""Warn about upcoming JWT token expiration.
Args:
username: Username
minutes_remaining: Minutes until expiration
"""
message = {
"type": "jwt_will_expire",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"minutes_remaining": minutes_remaining,
"action_required": "Please refresh your token or re-login",
},
}
await self.send_to_user(username, message)
logger.info(
"JWT expiration warning sent",
username=username,
minutes_remaining=minutes_remaining,
)
async def send_jwt_expired(self, username: str) -> None:
"""Notify about JWT token expiration.
Args:
username: Username
"""
message = {
"type": "jwt_expired",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {"action_required": "Please re-login"},
}
await self.send_to_user(username, message)
logger.info("JWT expired notification sent", username=username)
async def send_connection_extended(
self,
username: str,
connection_id: str,
hostname: str,
new_expires_at: datetime,
additional_minutes: int,
) -> None:
"""Notify about connection extension.
Args:
username: Username
connection_id: Connection ID
hostname: Machine hostname
new_expires_at: New expiration time
additional_minutes: Minutes added
"""
message = {
"type": "connection_extended",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"connection_id": connection_id,
"hostname": hostname,
"new_expires_at": new_expires_at.isoformat(),
"additional_minutes": additional_minutes,
},
}
await self.send_to_user(username, message)
logger.info(
"Connection extension notification sent",
username=username,
connection_id=connection_id,
additional_minutes=additional_minutes,
)
def get_active_users(self) -> List[str]:
"""Get list of users with active WebSocket connections.
Returns:
List of usernames
"""
return list(self.active_connections.keys())
def get_connection_count(self, username: Optional[str] = None) -> int:
"""Get count of active WebSocket connections.
Args:
username: Username (if None, returns total count)
Returns:
Number of connections
"""
if username:
return len(self.active_connections.get(username, []))
return sum(
len(connections)
for connections in self.active_connections.values()
)
# Singleton instance
websocket_manager = WebSocketManager()