Skip to main content
"""
WebSocket Reconnect
===================

Tests reconnect behavior for a running workflow: initial subscription, disconnection, reconnect, and missed-event catch-up.
"""

import asyncio
import json
from typing import Optional

# ---------------------------------------------------------------------------
# Setup
# ---------------------------------------------------------------------------
# `websockets` is a third-party dependency; stop with an install hint instead of
# an ImportError traceback when it is missing.
try:
    import websockets
except ImportError:
    print("websockets library not installed. Install with: uv pip install websockets")
    # Raise SystemExit directly: the bare `exit()` helper is injected by the
    # `site` module and is not available under `python -S` or in frozen builds.
    raise SystemExit(1)


# ---------------------------------------------------------------------------
# Define Helpers
# ---------------------------------------------------------------------------
def parse_sse_message(message: str) -> dict:
    """Decode the JSON payload of a message that may be wrapped in SSE framing.

    When the message contains ``data:`` lines, the data of the first event is
    decoded. Otherwise the whole message is decoded as JSON.

    Args:
        message: Raw message text. ``bytes``/``bytearray`` input is accepted as
            well and decoded as UTF-8 first.

    Returns:
        The value produced by ``json.loads`` for the selected payload.

    Raises:
        json.JSONDecodeError: If the selected payload is not valid JSON.
        UnicodeDecodeError: If binary input is not valid UTF-8.
    """
    if isinstance(message, (bytes, bytearray)):
        message = message.decode("utf-8")

    data_lines = []
    # Split on "\n" only: str.splitlines() would also split on characters such
    # as U+2028, which may legally appear unescaped inside a JSON string.
    for raw_line in message.strip().split("\n"):
        line = raw_line.rstrip("\r")  # tolerate CRLF framing
        if not line:
            if data_lines:
                break  # a blank line ends the first event
            continue
        if line.startswith("data:"):
            value = line[len("data:"):]
            # SSE framing allows a single optional space after the colon.
            data_lines.append(value[1:] if value.startswith(" ") else value)

    if len(data_lines) > 1:
        # Consecutive data lines form one payload joined by newlines.
        try:
            return json.loads("\n".join(data_lines))
        except json.JSONDecodeError:
            # Not a single multi-line payload; fall back to the first data
            # line on its own, as the single-line path below does.
            pass
    if data_lines and data_lines[0]:
        return json.loads(data_lines[0])
    return json.loads(message)


# ---------------------------------------------------------------------------
# Create WebSocket Tester
# ---------------------------------------------------------------------------
class WorkflowWebSocketTester:
    """Drive a two-phase reconnect test against a workflow WebSocket endpoint.

    Phase 1 connects, sends a ``start-workflow`` action and reads events until
    the workflow finishes or ``max_initial_events`` events have arrived, then
    drops the connection. Phase 2 opens a new connection, sends a ``reconnect``
    action carrying the ``run_id`` and ``last_event_index`` seen so far, and
    reads the messages that follow. Every parsed message from both phases is
    kept in ``received_events`` and summarised at the end.
    """

    # Event names that end the read loop in either phase.
    TERMINAL_EVENTS = ("WorkflowCompleted", "WorkflowError")

    def __init__(
        self,
        ws_url: str = "ws://localhost:7777/workflows/ws",
        workflow_id: str = "content-creation-workflow",
        session_id: str = "test-session-123",
        max_initial_events: int = 20,
        event_timeout: float = 30.0,
    ):
        """Store connection settings and reset the collected state.

        Args:
            ws_url: WebSocket endpoint to connect to.
            workflow_id: Workflow id sent with both the start and reconnect actions.
            session_id: Session id sent with both the start and reconnect actions.
            max_initial_events: Number of phase-1 messages after which the
                disconnect is simulated.
            event_timeout: Maximum seconds to wait for each phase-2 message.
        """
        self.ws_url = ws_url
        self.workflow_id = workflow_id
        self.session_id = session_id
        self.max_initial_events = max_initial_events
        self.event_timeout = event_timeout

        self.run_id: Optional[str] = None
        self.last_event_index: Optional[int] = None
        self.received_events: list = []

        # Phase-2 bookkeeping: workflow events seen since reconnecting, and the
        # number of missed events announced by the server's catch_up message.
        self._reconnect_events_seen = 0
        self._missed_events_count = 0

    async def test_workflow_execution_with_reconnection(self) -> None:
        """Run phase 1, pause for 3 seconds, run phase 2, then print the summary."""
        print("\n" + "=" * 80)
        print("WebSocket Reconnection Test")
        print("=" * 80)

        print("\nPhase 1: Starting workflow and receiving initial events...")
        await self._phase1_start_workflow()

        print("\nSimulating user leaving page for 3 seconds...")
        await asyncio.sleep(3)

        print("\nPhase 2: Reconnecting to workflow...")
        await self._phase2_reconnect()

        print("\nTest completed")
        self._print_summary()

    def _record_message(self, data: dict) -> None:
        """Store a parsed message and update ``run_id`` / ``last_event_index`` from it."""
        if "run_id" in data and not self.run_id:
            self.run_id = data["run_id"]
        if "event_index" in data:
            self.last_event_index = data["event_index"]
        self.received_events.append(data)

    async def _iter_messages(self, websocket, timeout):
        """Yield parsed messages from ``websocket``, bounding the wait for each one.

        Raises:
            asyncio.TimeoutError: If no message arrives within ``timeout`` seconds.
        """
        # Drive the connection's own async iterator so that the end of the
        # stream is handled exactly as a plain ``async for`` would handle it.
        stream = websocket.__aiter__()
        while True:
            try:
                raw = await asyncio.wait_for(stream.__anext__(), timeout=timeout)
            except StopAsyncIteration:
                return
            yield parse_sse_message(raw)

    async def _phase1_start_workflow(self) -> None:
        """Connect, start the workflow and read events until it ends or the cap is hit.

        Raises:
            Exception: Any error is printed and re-raised to the caller.
        """
        try:
            async with websockets.connect(self.ws_url) as websocket:
                print(f"[OK] Connected to {self.ws_url}")

                # The first message is read and printed before any action is sent.
                response = await websocket.recv()
                data = parse_sse_message(response)
                print(f"[OK] Server: {data.get('message', 'Connected')}")

                print("\nSending: start-workflow action")
                await websocket.send(
                    json.dumps(
                        {
                            "action": "start-workflow",
                            "workflow_id": self.workflow_id,
                            "message": "Research and create content plan for AI agents",
                            "session_id": self.session_id,
                        }
                    )
                )

                event_count = 0

                print("\nReceiving initial events:")
                async for message in websocket:
                    data = parse_sse_message(message)
                    event_type = data.get("event")

                    self._record_message(data)
                    event_count += 1

                    event_index = data.get("event_index", "N/A")
                    print(
                        f"  [{event_count}] event_index={event_index}, event={event_type}"
                    )

                    if event_type in self.TERMINAL_EVENTS:
                        print(
                            f"\nWorkflow finished during initial connection: {event_type}"
                        )
                        break

                    if event_count >= self.max_initial_events:
                        # Leaving the ``async with`` block closes the socket,
                        # which is the simulated disconnect.
                        print(
                            f"\nSimulating disconnect after {event_count} events "
                            f"(last_event_index={self.last_event_index})"
                        )
                        break

        except Exception as e:
            print(f"Error in Phase 1: {e}")
            raise

    def _handle_reconnect_message(self, data: dict, message_number: int) -> bool:
        """Record and print one phase-2 message.

        Args:
            data: Parsed message.
            message_number: 1-based position of the message in phase 2, counting
                control messages as well; used only for display.

        Returns:
            True when the message is a terminal workflow event and reading
            should stop, False otherwise.
        """
        event_type = data.get("event")
        self._record_message(data)

        if event_type == "catch_up":
            # ``or 0`` also covers an explicit null count.
            self._missed_events_count = data.get("missed_events") or 0
            print(f"catch_up: {self._missed_events_count} missed events")
            print(
                f"status={data.get('status')}, current_event_count={data.get('current_event_count')}"
            )
            return False
        if event_type == "replay":
            print(
                f"replay: status={data.get('status')}, total_events={data.get('total_events')}"
            )
            print(f"message={data.get('message')}")
            return False
        if event_type == "subscribed":
            print(f"subscribed: status={data.get('status')}")
            print(f"current_event_count={data.get('current_event_count')}")
            print("\nNow listening for NEW events as workflow continues...")
            return False
        if event_type == "error":
            print(f"ERROR: {data.get('error', 'Unknown error')}")
            print(f"Full data: {data}")
            return False

        # Only workflow events count towards the announced number of missed
        # events; counting the control messages above would label the last
        # missed event(s) as NEW.
        self._reconnect_events_seen += 1
        is_missed = self._reconnect_events_seen <= self._missed_events_count
        marker = "MISSED" if is_missed else "NEW"
        event_index = data.get("event_index", "N/A")
        print(
            f"  [{message_number}] {marker} event_index={event_index}, event={event_type}"
        )

        if event_type in self.TERMINAL_EVENTS:
            print(f"\nWorkflow finished: {event_type}")
            return True
        return False

    async def _phase2_reconnect(self) -> None:
        """Reconnect with the stored ``run_id`` and read the messages that follow.

        Returns immediately when phase 1 did not capture a ``run_id``. A wait of
        more than ``event_timeout`` seconds for a message ends the phase with a
        printed notice instead of an error.

        Raises:
            Exception: Any error other than a timeout is printed and re-raised.
        """
        if not self.run_id:
            print("No run_id found, cannot reconnect")
            return

        self._reconnect_events_seen = 0
        self._missed_events_count = 0

        try:
            async with websockets.connect(self.ws_url) as websocket:
                print(f"[OK] Reconnected to {self.ws_url}")

                response = await websocket.recv()
                data = parse_sse_message(response)
                print(f"[OK] Server: {data.get('message', 'Connected')}")

                print(
                    f"\nSending: reconnect action (run_id={self.run_id}, "
                    f"last_event_index={self.last_event_index})"
                )
                await websocket.send(
                    json.dumps(
                        {
                            "action": "reconnect",
                            "run_id": self.run_id,
                            "last_event_index": self.last_event_index,
                            "workflow_id": self.workflow_id,
                            "session_id": self.session_id,
                        }
                    )
                )

                print("\nReceiving events after reconnection:")
                message_number = 0
                async for data in self._iter_messages(websocket, self.event_timeout):
                    message_number += 1
                    if self._handle_reconnect_message(data, message_number):
                        break

                print("\nWebSocket connection closed (workflow may have completed)")

        except asyncio.TimeoutError:
            print(
                f"\nTimeout waiting for events ({self.event_timeout:g}s). "
                "Workflow may still be running."
            )
        except Exception as e:
            print(f"Error in Phase 2: {e}")
            raise

    def _print_summary(self) -> None:
        """Print totals, a per-event-type breakdown and an ``event_index`` gap check."""
        print("\n" + "=" * 80)
        print("Test Summary")
        print("=" * 80)
        print(f"Run ID: {self.run_id}")
        print(f"Last Event Index: {self.last_event_index}")
        print(f"Total Events Received: {len(self.received_events)}")

        event_types = {}
        for event in self.received_events:
            # ``or`` also maps an explicit null event name to "unknown"; a None
            # key cannot be sorted together with string keys below.
            event_type = event.get("event") or "unknown"
            event_types[event_type] = event_types.get(event_type, 0) + 1

        print("\nEvent Type Breakdown:")
        for event_type, count in sorted(event_types.items()):
            print(f"  {event_type}: {count}")

        print("\nEvent Index Validation:")
        # Keep integer indices only; min()/range() below cannot handle None.
        event_indices = [
            e["event_index"]
            for e in self.received_events
            if isinstance(e.get("event_index"), int)
        ]
        if event_indices:
            lowest = min(event_indices)
            highest = max(event_indices)
            print(f"  First event_index: {lowest}")
            print(f"  Last event_index: {highest}")
            print(f"  Total with event_index: {len(event_indices)}")

            # Any index between the lowest and highest seen that never arrived.
            gaps = set(range(lowest, highest + 1)) - set(event_indices)
            if gaps:
                print(f"Gaps in event_index: {sorted(gaps)}")
            else:
                print("No gaps in event_index (all events received)")
        else:
            print("No events with event_index found")

        print("=" * 80)


# ---------------------------------------------------------------------------
# Run Workflow
# ---------------------------------------------------------------------------
async def main() -> None:
    """Print the prerequisites, wait two seconds, then run the reconnection test.

    A refused connection prints a hint about starting the server; any other
    failure is printed together with its traceback. Neither is re-raised.
    """
    intro_lines = (
        "\nStarting WebSocket Reconnection Test",
        "Prerequisites:",
        "  1. AgentOS server should be running at http://localhost:7777",
        "  2. Run: python cookbook/agent_os/workflow/basic_workflow.py",
        "\nStarting test in 2 seconds...",
    )
    for intro_line in intro_lines:
        print(intro_line)
    await asyncio.sleep(2)

    runner = WorkflowWebSocketTester()
    try:
        await runner.test_workflow_execution_with_reconnection()
    except ConnectionRefusedError:
        print("\nConnection refused. Is the AgentOS server running?")
        print("  Start it with: python cookbook/agent_os/workflow/basic_workflow.py")
    except Exception as exc:
        # Imported lazily: only needed on the failure path.
        import traceback

        print(f"\nTest failed: {exc}")
        traceback.print_exc()


# Run the test only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())

Run the Example

# Clone and set up the repo
git clone https://github.com/agno-agi/agno.git
cd agno/cookbook/04_workflows/06_advanced_concepts/long_running

# Create and activate virtual environment
./scripts/demo_setup.sh
source .venvs/demo/bin/activate

python websocket_reconnect.py