Source code for pybragerone.gateway

"""Gateway: WS connect → modules.connect → listen → prime.

Maintains the WS connection and emits ParamUpdate events on the EventBus.
Does not contain heavy logic (such as mapping) internally — this is the role of ParamStore/HA.
"""

from __future__ import annotations

import asyncio
import logging
import time
from asyncio import CancelledError, Event, Task, TaskGroup, create_task, gather, sleep
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from typing import Any, Protocol

from .api import BragerOneApiClient, RealtimeManager, ServerConfig
from .models.events import EventBus, ParamUpdate

LOG = logging.getLogger(__name__)

# Callback signatures
ParametersCb = Callable[[str, dict[str, Any]], Awaitable[None] | None]  # (event_name, payload)
SnapshotCb = Callable[[dict[str, Any]], Awaitable[None] | None]
GenericCb = Callable[[str, Any], Awaitable[None] | None]


class ApiClient(Protocol):
    """Protocol for the HTTP client used by the gateway.

    This makes the gateway easy to test by allowing a lightweight fake.
    Only the members the gateway actually touches are declared here.
    """

    @property
    def access_token(self) -> str:
        """Return the access token; the gateway passes it to the WS client it creates."""
        raise NotImplementedError

    async def modules_connect(
        self,
        wsid_ns: str,
        modules: list[str],
        group_id: int | None = None,
        engine_sid: str | None = None,
    ) -> bool:
        """Bind a WS session to the given modules.

        The gateway calls this with the Socket.IO namespace SID, its module
        list, its object ID as ``group_id`` and the engine SID, and logs the
        returned flag.
        """
        raise NotImplementedError

    async def modules_parameters_prime(self, modules: list[str], *, return_data: bool = False) -> tuple[int, Any] | bool:
        """Fetch the parameters prime for ``modules``.

        The gateway always passes ``return_data=True`` and handles a
        ``(status, data)`` tuple; the ``bool`` form is presumably returned
        when ``return_data`` is false (confirm against the implementation).
        """
        raise NotImplementedError

    async def modules_activity_quantity_prime(self, modules: list[str], *, return_data: bool = False) -> tuple[int, Any] | bool:
        """Fetch the activity-quantity prime for ``modules``.

        The gateway always passes ``return_data=True`` and handles a
        ``(status, data)`` tuple; the ``bool`` form is presumably returned
        when ``return_data`` is false (confirm against the implementation).
        """
        raise NotImplementedError

    async def close(self) -> None:
        """Release the client's resources; called by the gateway on stop when it owns the client."""
        raise NotImplementedError
class RealtimeManagerClient(Protocol):
    """Protocol for the WS client used by the gateway."""

    @property
    def group_id(self) -> int | None:
        """Return the group ID; the gateway sets it to its object ID before subscribing."""
        raise NotImplementedError

    @group_id.setter
    def group_id(self, group_id: int | None) -> None:
        raise NotImplementedError

    def on_event(self, handler: Any) -> None:
        """Register the event handler; the gateway passes a ``(event_name, payload)`` callable."""
        raise NotImplementedError

    async def connect(self) -> None:
        """Open the WS connection."""
        raise NotImplementedError

    async def disconnect(self) -> None:
        """Close the WS connection."""
        raise NotImplementedError

    def add_on_connected(self, cb: Callable[[], None | Awaitable[None]]) -> None:
        """Register a callback; the gateway uses it to re-bind and re-prime after a reconnect."""
        raise NotImplementedError

    def sid(self) -> str | None:
        """Return the namespace SID of the current session, or ``None`` when unavailable."""
        raise NotImplementedError

    def engine_sid(self) -> str | None:
        """Return the engine SID of the current session, or ``None`` when unavailable."""
        raise NotImplementedError

    async def subscribe(self, modules: list[str]) -> None:
        """Subscribe the WS session to the streams of ``modules``."""
        raise NotImplementedError
[docs] class BragerOneGateway: """High-level orchestrator for BragerOne realtime data. Flow: 1) ensure_auth (proactive/reactive refresh in HTTP client) 2) Socket.IO connect → modules.connect (binding WS with DEV) 3) subscribe to streams (parameters, activity) 4) "prime" (REST snapshot of parameters + activity quantities) 5) EventBus emits ParamUpdate for consumers (ParamStore/HA/CLI) """ def __init__( self, *, api: ApiClient, object_id: int, modules: Iterable[str], ws: RealtimeManagerClient | None = None, owns_api: bool = False, ) -> None: """Initialize the gateway but do not start it yet. Args: api: Authenticated API client. The gateway uses it for module binding and primes. object_id: BragerOne object/group ID. modules: Modules to subscribe. ws: Optional WS client instance (useful for testing). owns_api: If True, the gateway closes the API client on :meth:`stop`. """ self.object_id = int(object_id) self.modules = sorted(set(modules)) self.api: ApiClient = api self.ws: RealtimeManagerClient | None = ws self.bus = EventBus() self._owns_api = owns_api self._tasks: set[Task[Any]] = set() self._started = False # (optional) diagnostic signals self._prime_done = Event() self._prime_seq: int | None = None self._first_snapshot = Event() # callbacks (optional backward compatibility) self._on_parameters_change: list[ParametersCb] = [] self._on_snapshot: list[SnapshotCb] = [] self._on_any: list[GenericCb] = []
[docs] @classmethod async def from_credentials( cls, *, email: str, password: str, object_id: int, modules: Iterable[str], server: ServerConfig | None = None, ws: RealtimeManagerClient | None = None, api: BragerOneApiClient | None = None, ) -> BragerOneGateway: """Create a gateway from credentials. This is a convenience helper for CLI/examples. Args: email: BragerOne account email. password: BragerOne account password. object_id: BragerOne object/group ID. modules: Modules to subscribe. server: Optional server/platform configuration (e.g. TiSConnect). ws: Optional WS client instance (testing). api: Optional API client instance (testing/customization). Returns: An initialized gateway (not started). """ owned_api = api is None api_client = api or BragerOneApiClient(server=server) await api_client.ensure_auth(email, password) return cls(api=api_client, object_id=object_id, modules=modules, ws=ws, owns_api=owned_api)
# ------------------------- Public API -------------------------
[docs] def on_parameters_change(self, cb: ParametersCb) -> None: """Register callback for `app:modules:parameters:change`.""" self._on_parameters_change.append(cb)
[docs] def on_snapshot(self, cb: SnapshotCb) -> None: """Register callback for `snapshot` event (full state-like payload).""" self._on_snapshot.append(cb)
[docs] def on_any(self, cb: GenericCb) -> None: """Register callback for *any* WS event for diagnostics.""" self._on_any.append(cb)
[docs] async def start(self) -> None: """Start the whole flow (idempotent).""" if self._started: return self._started = True started_at = time.monotonic() # 1) WS connect if self.ws is None: if isinstance(self.api, BragerOneApiClient): self.ws = RealtimeManager( token=self.api.access_token, origin=self.api.one_base, referer=f"{self.api.one_base}/", io_base=self.api.io_base, ) else: self.ws = RealtimeManager(token=self.api.access_token) ws = self.ws if ws is None: raise RuntimeError("RealtimeManager is not initialized") ws.on_event(self._ws_dispatch) await ws.connect() ws.add_on_connected(self.resubscribe) # in case of reconnect ws_connected_at = time.monotonic() # 3) modules.connect binds the current WS session with modules sid_ns = ws.sid() sid_engine = ws.engine_sid() if not sid_ns: raise RuntimeError("No namespace SID after connecting to WS (Socket.IO).") ok = await self.api.modules_connect(sid_ns, self.modules, group_id=self.object_id, engine_sid=sid_engine) LOG.info("modules.connect: %s (ns_sid=%s, engine_sid=%s)", ok, sid_ns, sid_engine) modules_connected_at = time.monotonic() # 4) WS subscribe + PRIME via REST (in parallel) ws.group_id = self.object_id await ws.subscribe(self.modules) subscribed_at = time.monotonic() ok_params, ok_act = await self._prime_with_retry() primed_at = time.monotonic() LOG.debug("prime injected: parameters=%s activity=%s", ok_params, ok_act) LOG.info( "Gateway started: object_id=%s, modules=%s", self.object_id, ",".join(self.modules), ) LOG.debug( "Gateway startup timings: total=%.3fs ws_connect=%.3fs modules_connect=%.3fs subscribe=%.3fs prime=%.3fs", primed_at - started_at, ws_connected_at - started_at, modules_connected_at - ws_connected_at, subscribed_at - modules_connected_at, primed_at - subscribed_at, )
[docs] async def stop(self) -> None: """Gracefully stop the gateway: drop WS and release HTTP resources.""" self._started = False # Cancel background tasks first (callbacks / bus injectors) try: await self._cancel_all_tasks() except Exception: LOG.exception("Error while canceling background tasks") # 1) disconnect WS try: if self.ws is not None: await self.ws.disconnect() except asyncio.CancelledError: # do not propagate CancelledError during shutdown pass except Exception: LOG.exception("Error while disconnecting WS") # 2) close the HTTP client (if the gateway manages it) try: if self._owns_api: await self.api.close() except Exception: LOG.exception("Error while closing ApiClient")
async def __aenter__(self) -> BragerOneGateway: """Async context manager enter.""" await self.start() return self async def __aexit__(self, *exc: Any) -> None: """Async context manager exit.""" await self.stop()
[docs] async def resubscribe(self) -> None: """Call after WS reconnect to re-bind modules + prime again.""" ws = self.ws if ws is None: return sid_ns = ws.sid() sid_engine = ws.engine_sid() if not sid_ns: return ok = await self.api.modules_connect(sid_ns, self.modules, group_id=self.object_id, engine_sid=sid_engine) LOG.info("modules.connect (resub): %s", ok) await ws.subscribe(self.modules) okp, oka = await self._prime_with_retry() LOG.debug("prime after resubscribe: parameters=%s activity=%s", okp, oka)
[docs] async def wait_for_prime(self, timeout: float | None = None) -> bool: """Wait until the latest prime pass is finished. Args: timeout: Optional timeout in seconds. When ``None``, waits indefinitely. Returns: ``True`` if prime completion event was observed, ``False`` on timeout. """ if self._prime_done.is_set(): return True try: if timeout is None: await self._prime_done.wait() else: await asyncio.wait_for(self._prime_done.wait(), timeout=timeout) return True except TimeoutError: return False
# ------------------------- PRIME & ingest ------------------------- async def _prime(self) -> tuple[bool, bool]: """Fetch initial state via REST (/modules/parameters + /modules/activity/quantity).""" ok_params = False ok_act = False async with TaskGroup() as tg: """Fetch parameters and activity quantities in parallel.""" t_params = tg.create_task( self.api.modules_parameters_prime(self.modules, return_data=True), name="gateway.api.modules_parameters_prime", ) t_act = tg.create_task( self.api.modules_activity_quantity_prime(self.modules, return_data=True), name="gateway.api.modules_activity_quantity_prime", ) res1 = t_params.result() if isinstance(res1, tuple) and len(res1) == 2: st1, data1 = res1 if st1 in (200, 204) and isinstance(data1, dict): await self.ingest_prime_parameters(data1) ok_params = True res2 = t_act.result() if isinstance(res2, tuple) and len(res2) == 2: st2, data2 = res2 if st2 in (200, 204): await self.ingest_activity_quantity(data2 if isinstance(data2, dict) else None) ok_act = True self._prime_seq = self.bus.last_seq() self._prime_done.set() return ok_params, ok_act async def _prime_with_retry(self, tries: int = 3) -> tuple[bool, bool]: """Retry prime a few times with exponential backoff.""" delay = 0.25 for attempt in range(tries): attempt_started = time.monotonic() okp, oka = await self._prime() LOG.debug( "Prime attempt %s/%s finished in %.3fs (parameters=%s activity=%s)", attempt + 1, tries, time.monotonic() - attempt_started, okp, oka, ) if okp: # we care mainly about parameters return okp, oka await sleep(delay) delay = min(delay * 2.0, 2.0) return False, False
[docs] async def ingest_prime_parameters(self, data: dict[str, Any]) -> None: """Treat /modules/parameters prime as "cold snapshot" and publish all pairs.""" pairs = list(self.flatten_parameters(data, source="prime")) async def _pub_all() -> None: for upd in pairs: await self.bus.publish(upd) await _pub_all()
[docs] async def ingest_activity_quantity(self, data: dict[str, Any] | None) -> None: """Ingest /modules/activity/quantity prime (optional).""" if isinstance(data, dict): LOG.debug("activityQuantity: %s", data.get("activityQuantity"))
# ------------------------- WS dispatch ------------------------- async def _invoke_list(self, cbs: list[Callable[..., Any]], *args: Any, **kwargs: Any) -> None: for cb in list(cbs): try: res = cb(*args, **kwargs) if asyncio.iscoroutine(res): await res except Exception: LOG.exception("Callback error") def _ws_dispatch(self, event_name: str, payload: Any) -> Awaitable[None] | None: # Any-listeners (diagnostics) if self._on_any: self._spawn( self._invoke_list(self._on_any, event_name, payload), name="gateway.on_any", ) # snapshot if event_name == "snapshot" and isinstance(payload, dict): pairs = list(self.flatten_parameters(payload, source="snapshot")) async def _pub_all() -> None: for upd in pairs: await self.bus.publish(upd) self._spawn(_pub_all(), name="gateway.publish_snapshot") if self._on_snapshot: self._spawn( self._invoke_list(self._on_snapshot, payload), name="gateway.on_snapshot", ) self._first_snapshot.set() return None # parameters:change if event_name.endswith("parameters:change") and isinstance(payload, dict): pairs = list(self.flatten_parameters(payload, source="ws")) async def _pub_all() -> None: for upd in pairs: await self.bus.publish(upd) self._spawn(_pub_all(), name="gateway.publish_parameters_change") if self._on_parameters_change: self._spawn( self._invoke_list(self._on_parameters_change, event_name, payload), name="gateway.on_parameters_change", ) return None # ------------------------- Helpers -------------------------
[docs] def flatten_parameters(self, payload: dict[str, Any], *, source: str = "unknown") -> list[ParamUpdate]: """Convert WS/REST parameter payload into ParamUpdate events.""" out: list[ParamUpdate] = [] for devid, pools in payload.items(): if not isinstance(pools, dict): continue for pool, entries in pools.items(): if not isinstance(entries, dict): continue for chan_idx, body in entries.items(): if not isinstance(chan_idx, str) or len(chan_idx) < 2: continue chan = chan_idx[0] try: idx = int(chan_idx[1:]) except ValueError: continue val: Any | None = None meta: dict[str, Any] = {} if isinstance(body, dict): if "value" in body: val = body["value"] meta = {k: v for k, v in body.items() if k != "value"} else: val = body meta["_source"] = source out.append( ParamUpdate( devid=str(devid), pool=str(pool), chan=chan, idx=idx, value=val, meta=meta, ) ) return out
def _spawn(self, coro: Coroutine[Any, Any, Any], *, name: str | None = None) -> Task[Any]: """Start a background task, keep reference, and log exceptions.""" t = create_task(coro, name=name) self._tasks.add(t) def _finalizer(task: Task[Any]) -> None: try: _ = task.result() except CancelledError: pass except Exception: LOG.exception("Background task failed: %s", task.get_name() or "<unnamed>") finally: self._tasks.discard(task) t.add_done_callback(_finalizer) return t async def _cancel_all_tasks(self) -> None: """Cancel all tracked tasks and wait for completion.""" if not self._tasks: return for t in list(self._tasks): t.cancel() await gather(*self._tasks, return_exceptions=True) self._tasks.clear()