Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions ingestify/infra/event_log/consumer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import logging
import time
from typing import Callable, Optional
from typing import Callable, List, Optional

from sqlalchemy import create_engine, select

from ingestify.domain.models.event.domain_event import DomainEvent

from .event_log import EventLog
from .tables import get_tables

OnEventHandler = Callable[[DomainEvent], None]
OnEventsHandler = Callable[[List[DomainEvent]], None]

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -77,7 +82,7 @@ def _update_cursor(self, conn, event_id: int) -> None:
)
conn.commit()

def _run_once(self, on_event: Callable, batch_size: int = 100) -> int:
def _run_once(self, on_event: OnEventHandler, batch_size: int = 100) -> int:
"""Returns number of events processed, or -1 if a processing error occurred."""
with self._engine.connect() as conn:
self._ensure_reader_state(conn)
Expand All @@ -102,7 +107,7 @@ def _run_once(self, on_event: Callable, batch_size: int = 100) -> int:

def run(
self,
on_event: Callable,
on_event: OnEventHandler,
poll_interval: Optional[int] = None,
batch_size: int = 100,
) -> int:
Expand All @@ -114,3 +119,58 @@ def run(
if poll_interval is None:
return 0
time.sleep(poll_interval)

def _run_batched_once(self, on_events: OnEventsHandler, batch_size: int) -> int:
"""Returns number of events processed, or -1 if a processing error occurred.

on_events receives the full list of DomainEvent instances for this
batch. The cursor advances to the last event's id only after
on_events returns without raising.
"""
with self._engine.connect() as conn:
self._ensure_reader_state(conn)
last_id = self._get_last_event_id(conn)

rows = self._event_log.fetch_batch(last_id, batch_size)
if not rows:
return 0

events = [event for _, event in rows]
try:
on_events(events)
except Exception:
logger.exception(
"Failed to process batch of %d events — cursor NOT advanced",
len(rows),
)
return -1

last_event_id = rows[-1][0]
with self._engine.connect() as conn:
self._update_cursor(conn, last_event_id)

return len(rows)

def run_batched(
self,
on_events: OnEventsHandler,
poll_interval: Optional[int] = None,
batch_size: int = 1000,
) -> int:
"""Consume events in batches.

on_events is called with a List[DomainEvent] per batch. The cursor
advances once per batch, not per event — letting callers parallelize
I/O-bound work within a batch (threads, asyncio, etc.) without
hitting the DB on every event.

Exit codes match run(): 0 success, 1 processing error.
"""
while True:
count = self._run_batched_once(on_events, batch_size)
if count < 0:
return 1
if count == 0:
if poll_interval is None:
return 0
time.sleep(poll_interval)
Loading