Skip to content

Authentication

PyStrands provides a flexible authentication system through the on_connection_request callback. This allows you to validate incoming WebSocket connections before they're accepted.

Connection Request Flow

When a WebSocket client attempts to connect, the following flow occurs:

1. Client connects to ws://broker/room-name
2. Go Broker → Python: connection_request message
3. Python validates headers, URL, IP address
4. Python → Go Broker: accept/reject response
5. Go Broker accepts or rejects the WebSocket

The ConnectionRequestContext

The on_connection_request method receives a ConnectionRequestContext object containing:

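| Attribute | Type | Description |
|---|---|---|
| `headers` | `Dict[str, List[str]]` | HTTP headers from the WebSocket upgrade request; each header name maps to a list of values |
| `url` | `str` | URL path the client connected to (may include a query string — see Query Parameter Authentication) |
| `remote_addr` | `str` | Client's IP address |
| `context` | `Context` | The context object to modify for accepted connections |
| `accepted` | `bool` | Defaults to `True`; set to `False` to reject |

The examples on this page signal their decision by returning `True` (accept) or `False` (reject) from `on_connection_request`.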

The context object contains:

| Attribute | Type | Description |
|---|---|---|
| `client_id` | `str` | Auto-generated unique client ID |
| `room_id` | `str` | Room assignment (modify this!) |
| `metadata` | `dict` | Custom data storage for auth info, user details, etc. |

Basic Authentication Example

Async Client

import asyncio
import hmac
from pystrands import AsyncPyStrandsClient

# Placeholder credential for this example -- load real secrets from config or
# the environment rather than hard-coding them.
EXPECTED_AUTHORIZATION = "Bearer valid-token"


class AuthenticatedBackend(AsyncPyStrandsClient):
    """Backend that only accepts clients presenting a known Bearer token."""

    async def on_connection_request(self, request):
        """Validate the upgrade request; return True to accept, False to reject."""
        # Headers map each name to a list of values. ``or [""]`` covers both a
        # missing header and an empty value list (``[][0]`` would raise).
        auth_header = (request.headers.get("Authorization") or [""])[0]

        # Constant-time comparison so response timing does not reveal how much
        # of the token matched; comparing bytes also avoids the TypeError that
        # compare_digest raises for non-ASCII str input.
        if not hmac.compare_digest(auth_header.encode(), EXPECTED_AUTHORIZATION.encode()):
            print(f"Rejecting connection from {request.remote_addr}: invalid token")
            return False  # Reject the connection

        # Set room based on URL path, falling back to a default for a bare "/"
        request.context.room_id = request.url.strip("/") or "lobby"

        # Store user info in metadata for later use
        request.context.metadata = {
            "user_id": "user-123",
            "role": "premium",
            "joined_at": "2024-01-15T10:30:00Z",
        }

        print(f"Accepted connection: {request.context.client_id}")
        return True  # Accept the connection

    async def on_message(self, message, context):
        """Log the message with the sender's role, then echo it to their room."""
        # Access metadata set during authentication
        user_role = context.metadata.get("role", "guest")
        print(f"[{user_role}] {context.client_id}: {message}")

        await self.send_room_message(context.room_id, f"echo: {message}")


client = AuthenticatedBackend(host="localhost", port=8081)
asyncio.run(client.run_forever())

Sync Client

import hmac
from pystrands import PyStrandsClient

# Placeholder API keys for this example -- load real keys from config or the
# environment rather than hard-coding them.
VALID_API_KEYS = ("secret-key-1", "secret-key-2")


class AuthenticatedBackend(PyStrandsClient):
    """Synchronous backend that accepts clients sending a known X-API-Key header."""

    def on_connection_request(self, request):
        """Return True to accept the connection, False to reject it."""
        # Check for an API key in the X-API-Key header. ``or [""]`` covers both
        # a missing header and an empty value list.
        api_key = (request.headers.get("X-API-Key") or [""])[0]

        # Compare against every key in constant time (bytes, so non-ASCII input
        # cannot raise); building the full list avoids short-circuit timing.
        matches = [
            hmac.compare_digest(api_key.encode(), key.encode())
            for key in VALID_API_KEYS
        ]
        if not any(matches):
            return False

        # Assign room (default "lobby" for a bare "/") and metadata
        request.context.room_id = request.url.strip("/") or "lobby"
        request.context.metadata = {"api_key": api_key}

        return True

    def on_message(self, message, context):
        """Echo each message back to the sender's room."""
        self.send_room_message(context.room_id, f"echo: {message}")


client = AuthenticatedBackend(host="localhost", port=8081)
client.run_forever()

JWT Authentication

For production applications, you might use JWT tokens:

import asyncio
import os

import jwt
from pystrands import AsyncPyStrandsClient

# Read the signing secret from the environment instead of hard-coding it;
# a missing variable fails loudly at startup rather than using a weak default.
JWT_SECRET = os.environ["JWT_SECRET"]

BEARER_PREFIX = "Bearer "


class JWTBackend(AsyncPyStrandsClient):
    """Backend that authenticates clients with an HS256-signed Bearer JWT."""

    async def on_connection_request(self, request):
        """Return True to accept a request carrying a valid JWT, False otherwise."""
        # Get token from Authorization header; ``or [""]`` covers both a
        # missing header and an empty value list.
        auth_header = (request.headers.get("Authorization") or [""])[0]

        if not auth_header.startswith(BEARER_PREFIX):
            return False

        token = auth_header[len(BEARER_PREFIX):]  # Remove "Bearer " prefix

        try:
            # Verify signature/expiry. Requiring "sub" makes a token without it
            # fail here (MissingRequiredClaimError, an InvalidTokenError) rather
            # than crash the handler with KeyError further down.
            payload = jwt.decode(
                token,
                JWT_SECRET,
                algorithms=["HS256"],
                options={"require": ["sub"]},
            )
        except jwt.ExpiredSignatureError:
            print(f"Token expired from {request.remote_addr}")
            return False
        except jwt.InvalidTokenError as e:
            print(f"Invalid token from {request.remote_addr}: {e}")
            return False

        # Set room from URL or JWT claim. Note the URL path wins, so any
        # authenticated user can pick any room; check payload["room"] here
        # if rooms must be restricted per user.
        request.context.room_id = request.url.strip("/") or payload.get("room", "default")

        # Store user info in metadata
        request.context.metadata = {
            "user_id": payload["sub"],
            "username": payload.get("username"),
            "permissions": payload.get("permissions", []),
        }

        return True

    async def on_message(self, message, context):
        """Broadcast the message to the sender's room, prefixed with their username."""
        user = context.metadata.get("username", "anonymous")
        await self.send_room_message(context.room_id, f"{user}: {message}")


client = JWTBackend(host="localhost", port=8081)
asyncio.run(client.run_forever())

JWT Library

Install PyJWT with `pip install pyjwt` (the package is imported as `jwt`).

IP-Based Access Control

Restrict connections by IP address:

# Whitelist: when non-empty, only these addresses may connect.
ALLOWED_IPS = {"192.168.1.100", "10.0.0.50"}
# Blacklist: always refused, checked before the whitelist.
BLOCKED_IPS = {"192.168.1.200"}

async def on_connection_request(self, request):
    """Gate connections on the client address, routing survivors to one room.

    Returns True to accept, False to reject. The blacklist is checked first;
    an empty ALLOWED_IPS disables the whitelist check entirely.
    """
    client_ip = request.remote_addr

    if client_ip in BLOCKED_IPS:
        reason = f"Blocked connection from {client_ip}"
    elif ALLOWED_IPS and client_ip not in ALLOWED_IPS:
        reason = f"Connection from {client_ip} not in whitelist"
    else:
        request.context.room_id = "restricted-room"
        return True

    print(reason)
    return False

Query Parameter Authentication

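Browser WebSocket clients cannot set custom request headers, so passing a token in the URL query string is sometimes the only option. Keep in mind that URLs (query strings included) are often recorded in proxy and server access logs, so prefer short-lived tokens here: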

import hmac
from urllib.parse import parse_qs, urlparse

# Placeholder tokens for this example -- load real ones from config or the
# environment. (Previously undefined, which made validate_token raise NameError.)
VALID_TOKENS = {"abc123"}

async def on_connection_request(self, request):
    """Authenticate via a ``token`` query parameter and route by URL path.

    Returns True to accept, False to reject.
    """
    # Parse URL query parameters
    # URL format: ws://broker/room-name?token=abc123
    parsed = urlparse(request.url)
    params = parse_qs(parsed.query)

    # parse_qs never yields an empty list, so [0] is safe once the key exists.
    token = params.get("token", [None])[0]

    if not self.validate_token(token):
        return False

    # Room is the path, not including query params; default for a bare "/"
    request.context.room_id = parsed.path.strip("/") or "lobby"
    request.context.metadata = {"token": token}

    return True

def validate_token(self, token):
    """Return True if ``token`` is one of VALID_TOKENS.

    Uses constant-time comparison on bytes (so non-ASCII input cannot make
    compare_digest raise) and treats a missing/empty token as invalid.
    """
    if not token:
        return False
    candidate = token.encode()
    # Full list (not a generator) so every token is compared regardless of matches.
    return any([hmac.compare_digest(candidate, known.encode()) for known in VALID_TOKENS])

Rate Limiting

Implement rate limiting to prevent abuse:

import asyncio
import time
from collections import defaultdict, deque

from pystrands import AsyncPyStrandsClient

class RateLimitedBackend(AsyncPyStrandsClient):
    """Backend that caps connection attempts per client IP in a sliding window."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # IP -> deque of monotonic timestamps of recent recorded attempts, oldest first
        self.connection_attempts = defaultdict(deque)
        self.rate_limit = 5  # connections allowed per window
        self.rate_window = 60  # window length in seconds
        self._last_sweep = time.monotonic()

    def _forget_idle_ips(self, now):
        """Drop IPs with no attempt inside the window, sweeping at most once per window.

        Without this, every IP ever seen would keep an entry forever.
        """
        if now - self._last_sweep < self.rate_window:
            return
        self._last_sweep = now
        idle = [
            ip for ip, stamps in self.connection_attempts.items()
            if not stamps or now - stamps[-1] >= self.rate_window
        ]
        for ip in idle:
            del self.connection_attempts[ip]

    async def on_connection_request(self, request):
        """Reject an IP that already hit ``rate_limit`` attempts within ``rate_window``."""
        client_ip = request.remote_addr
        # monotonic() is unaffected by wall-clock adjustments (NTP, manual changes).
        now = time.monotonic()
        self._forget_idle_ips(now)

        # Clean old entries: oldest timestamps sit at the left of the deque.
        attempts = self.connection_attempts[client_ip]
        while attempts and now - attempts[0] >= self.rate_window:
            attempts.popleft()

        # Check rate limit
        if len(attempts) >= self.rate_limit:
            print(f"Rate limit exceeded for {client_ip}")
            return False

        # Record this attempt
        attempts.append(now)

        # Continue with normal auth...
        request.context.room_id = request.url.strip("/")
        return True

Rejection Best Practices

When rejecting connections:

  1. Log rejections — For security monitoring and debugging
  2. Don't expose too much info — Avoid revealing why a connection was rejected (prevents information leakage)
  3. Rate limit your rejections — Don't let attackers flood your logs
  4. Use appropriate metadata — Store enough info in context.metadata for your application logic

No Custom Error Messages

Currently, rejected connections receive a standard WebSocket close. If you need to communicate rejection reasons to clients, implement a separate endpoint for pre-validation.

Accessing Metadata in Handlers

After authentication, access the metadata in all handlers:

async def on_new_connection(self, context):
    """Announce a newly accepted client using metadata stored during auth."""
    meta = context.metadata
    user = meta.get("username", "anonymous")
    print(f"User {user} connected to {context.room_id}")

async def on_message(self, message, context):
    """Handle a message, first refusing senders without ``send_message`` permission."""
    meta = context.metadata
    user_id = meta.get("user_id")  # available for per-user handling below
    may_send = "send_message" in meta.get("permissions", [])

    # Check permissions before processing
    if not may_send:
        denial = "You don't have permission to send messages"
        await self.send_private_message(context.client_id, denial)
        return

    # Process message...

async def on_disconnect(self, context):
    """Log a departing client, naming it "unknown" if auth stored no username."""
    meta = context.metadata
    user = meta.get("username", "unknown")
    print(f"User {user} disconnected")