diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index afc09b3e9d..77d6fab94c 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -20,11 +20,11 @@ import logging from typing import Any from typing import Optional +from typing import overload from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import select -from sqlalchemy import text from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError from sqlalchemy.ext.asyncio import async_sessionmaker @@ -98,37 +98,85 @@ def __init__(self, version: str): class DatabaseSessionService(BaseSessionService): """A session service that uses a database for storage.""" - def __init__(self, db_url: str, **kwargs: Any): - """Initializes the database session service with a database URL.""" - # 1. Create DB engine for db connection + @overload + def __init__( + self, + db_url: str, + **kwargs: Any, + ) -> None: + """Initializes the database session service with a database URL. + + Args: + db_url: Database URL string for creating a new engine. + **kwargs: Additional keyword arguments passed to create_async_engine. + """ + + @overload + def __init__( + self, + *, + db_engine: AsyncEngine, + ) -> None: + """Initializes the database session service with an existing SQLAlchemy AsyncEngine. + + Args: + db_engine: Existing SQLAlchemy AsyncEngine instance to use. + """ + + def __init__( + self, + db_url: Optional[str] = None, + db_engine: Optional[AsyncEngine] = None, + **kwargs: Any, + ) -> None: + """Initializes the database session service. + + Args: + db_url: Database URL string for creating a new engine. Mutually exclusive + with db_engine. + db_engine: Existing AsyncEngine instance. Mutually exclusive with db_url. + **kwargs: Additional keyword arguments passed to create_async_engine when + db_url is provided. Ignored when db_engine is provided. + + Raises: + ValueError: If neither or both db_url and db_engine are provided, or if + engine creation fails. + """ + if (db_url is None) == (db_engine is None): + raise ValueError( + "Exactly one of 'db_url' or 'db_engine' must be provided." + ) + + # 1. Create or use provided DB engine for db connection # 2. Create all tables based on schema # 3. Initialize all properties - try: - engine_kwargs = dict(kwargs) - url = make_url(db_url) - if url.get_backend_name() == "sqlite" and url.database == ":memory:": - engine_kwargs.setdefault("poolclass", StaticPool) - connect_args = dict(engine_kwargs.get("connect_args", {})) - connect_args.setdefault("check_same_thread", False) - engine_kwargs["connect_args"] = connect_args - - db_engine = create_async_engine(db_url, **engine_kwargs) - if db_engine.dialect.name == "sqlite": - # Set sqlite pragma to enable foreign keys constraints - event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma) - - except Exception as e: - if isinstance(e, ArgumentError): - raise ValueError( - f"Invalid database URL format or argument '{db_url}'." - ) from e - if isinstance(e, ImportError): + if db_engine is None: + try: + engine_kwargs = dict(kwargs) + url = make_url(db_url) + if url.get_backend_name() == "sqlite" and url.database == ":memory:": + engine_kwargs.setdefault("poolclass", StaticPool) + connect_args = dict(engine_kwargs.get("connect_args", {})) + connect_args.setdefault("check_same_thread", False) + engine_kwargs["connect_args"] = connect_args + + db_engine = create_async_engine(db_url, **engine_kwargs) + if db_engine.dialect.name == "sqlite": + # Set sqlite pragma to enable foreign keys constraints + event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma) + + except Exception as e: + if isinstance(e, ArgumentError): + raise ValueError( + f"Invalid database URL format or argument '{db_url}'." + ) from e + if isinstance(e, ImportError): + raise ValueError( + f"Database related module not found for URL '{db_url}'." + ) from e raise ValueError( - f"Database related module not found for URL '{db_url}'." + f"Failed to create database engine for URL '{db_url}'" ) from e - raise ValueError( - f"Failed to create database engine for URL '{db_url}'" - ) from e # Get the local timezone local_timezone = get_localzone() diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 30eb15678b..4fef8a40d7 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -26,12 +26,13 @@ from google.adk.sessions.sqlite_session_service import SqliteSessionService from google.genai import types import pytest +from sqlalchemy.ext.asyncio import create_async_engine class SessionServiceType(enum.Enum): - IN_MEMORY = 'IN_MEMORY' - DATABASE = 'DATABASE' - SQLITE = 'SQLITE' + IN_MEMORY = "IN_MEMORY" + DATABASE = "DATABASE" + SQLITE = "SQLITE" def get_session_service( @@ -40,9 +41,10 @@ def get_session_service( ): """Creates a session service for testing.""" if service_type == SessionServiceType.DATABASE: - return DatabaseSessionService('sqlite+aiosqlite:///:memory:') + # Using positional argument to test backward compatibility + return DatabaseSessionService("sqlite+aiosqlite:///:memory:") if service_type == SessionServiceType.SQLITE: - return SqliteSessionService(str(tmp_path / 'sqlite.db')) + return SqliteSessionService(str(tmp_path / "sqlite.db")) return InMemorySessionService() @@ -67,13 +69,13 @@ async def test_sqlite_session_service_accepts_sqlite_urls( ): monkeypatch.chdir(tmp_path) - service = SqliteSessionService('sqlite+aiosqlite:///./sessions.db') - await service.create_session(app_name='app', user_id='user') - assert (tmp_path / 'sessions.db').exists() + service = SqliteSessionService("sqlite+aiosqlite:///./sessions.db") + await service.create_session(app_name="app", user_id="user") + assert (tmp_path / "sessions.db").exists() - service = SqliteSessionService('sqlite:///./sessions2.db') - await service.create_session(app_name='app', user_id='user') - assert (tmp_path / 'sessions2.db').exists() + service = SqliteSessionService("sqlite:///./sessions2.db") + await service.create_session(app_name="app", user_id="user") + assert (tmp_path / "sessions2.db").exists() @pytest.mark.asyncio @@ -81,38 +83,38 @@ async def test_sqlite_session_service_preserves_uri_query_parameters( tmp_path, monkeypatch ): monkeypatch.chdir(tmp_path) - db_path = tmp_path / 'readonly.db' + db_path = tmp_path / "readonly.db" with sqlite3.connect(db_path) as conn: - conn.execute('CREATE TABLE IF NOT EXISTS t (id INTEGER)') + conn.execute("CREATE TABLE IF NOT EXISTS t (id INTEGER)") conn.commit() - service = SqliteSessionService(f'sqlite+aiosqlite:///{db_path}?mode=ro') + service = SqliteSessionService(f"sqlite+aiosqlite:///{db_path}?mode=ro") # `mode=ro` opens the DB read-only; schema creation should fail. - with pytest.raises(sqlite3.OperationalError, match=r'readonly'): - await service.create_session(app_name='app', user_id='user') + with pytest.raises(sqlite3.OperationalError, match=r"readonly"): + await service.create_session(app_name="app", user_id="user") @pytest.mark.asyncio async def test_sqlite_session_service_accepts_absolute_sqlite_urls(tmp_path): - abs_db_path = tmp_path / 'absolute.db' - abs_url = 'sqlite+aiosqlite:////' + str(abs_db_path).lstrip('/') + abs_db_path = tmp_path / "absolute.db" + abs_url = "sqlite+aiosqlite:////" + str(abs_db_path).lstrip("/") service = SqliteSessionService(abs_url) - await service.create_session(app_name='app', user_id='user') + await service.create_session(app_name="app", user_id="user") assert abs_db_path.exists() @pytest.mark.asyncio async def test_get_empty_session(session_service): assert not await session_service.get_session( - app_name='my_app', user_id='test_user', session_id='123' + app_name="my_app", user_id="test_user", session_id="123" ) @pytest.mark.asyncio async def test_create_get_session(session_service): - app_name = 'my_app' - user_id = 'test_user' - state = {'key': 'value'} + app_name = "my_app" + user_id = "test_user" + state = {"key": "value"} session = await session_service.create_session( app_name=app_name, user_id=user_id, state=state @@ -150,16 +152,16 @@ async def test_create_get_session(session_service): @pytest.mark.asyncio async def test_create_and_list_sessions(session_service): - app_name = 'my_app' - user_id = 'test_user' + app_name = "my_app" + user_id = "test_user" - session_ids = ['session' + str(i) for i in range(5)] + session_ids = ["session" + str(i) for i in range(5)] for session_id in session_ids: await session_service.create_session( app_name=app_name, user_id=user_id, session_id=session_id, - state={'key': 'value' + session_id}, + state={"key": "value" + session_id}, ) list_sessions_response = await session_service.list_sessions( @@ -169,32 +171,32 @@ async def test_create_and_list_sessions(session_service): assert len(sessions) == len(session_ids) assert {s.id for s in sessions} == set(session_ids) for session in sessions: - assert session.state == {'key': 'value' + session.id} + assert session.state == {"key": "value" + session.id} @pytest.mark.asyncio async def test_list_sessions_all_users(session_service): - app_name = 'my_app' - user_id_1 = 'user1' - user_id_2 = 'user2' + app_name = "my_app" + user_id_1 = "user1" + user_id_2 = "user2" await session_service.create_session( app_name=app_name, user_id=user_id_1, - session_id='session1a', - state={'key': 'value1a'}, + session_id="session1a", + state={"key": "value1a"}, ) await session_service.create_session( app_name=app_name, user_id=user_id_1, - session_id='session1b', - state={'key': 'value1b'}, + session_id="session1b", + state={"key": "value1b"}, ) await session_service.create_session( app_name=app_name, user_id=user_id_2, - session_id='session2a', - state={'key': 'value2a'}, + session_id="session2a", + state={"key": "value2a"}, ) # List sessions for user1 - should contain merged state @@ -204,8 +206,8 @@ async def test_list_sessions_all_users(session_service): sessions_1 = list_sessions_response_1.sessions assert len(sessions_1) == 2 sessions_1_map = {s.id: s for s in sessions_1} - assert sessions_1_map['session1a'].state == {'key': 'value1a'} - assert sessions_1_map['session1b'].state == {'key': 'value1b'} + assert sessions_1_map["session1a"].state == {"key": "value1a"} + assert sessions_1_map["session1b"].state == {"key": "value1b"} # List sessions for user2 - should contain merged state list_sessions_response_2 = await session_service.list_sessions( @@ -213,8 +215,8 @@ async def test_list_sessions_all_users(session_service): ) sessions_2 = list_sessions_response_2.sessions assert len(sessions_2) == 1 - assert sessions_2[0].id == 'session2a' - assert sessions_2[0].state == {'key': 'value2a'} + assert sessions_2[0].id == "session2a" + assert sessions_2[0].state == {"key": "value2a"} # List sessions for all users - should contain merged state list_sessions_response_all = await session_service.list_sessions( @@ -223,150 +225,150 @@ async def test_list_sessions_all_users(session_service): sessions_all = list_sessions_response_all.sessions assert len(sessions_all) == 3 sessions_all_map = {s.id: s for s in sessions_all} - assert sessions_all_map['session1a'].state == {'key': 'value1a'} - assert sessions_all_map['session1b'].state == {'key': 'value1b'} - assert sessions_all_map['session2a'].state == {'key': 'value2a'} + assert sessions_all_map["session1a"].state == {"key": "value1a"} + assert sessions_all_map["session1b"].state == {"key": "value1b"} + assert sessions_all_map["session2a"].state == {"key": "value2a"} @pytest.mark.asyncio async def test_app_state_is_shared_by_all_users_of_app(session_service): - app_name = 'my_app' + app_name = "my_app" # User 1 creates a session, establishing app:k1 session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1', state={'app:k1': 'v1'} + app_name=app_name, user_id="u1", session_id="s1", state={"app:k1": "v1"} ) # User 1 appends an event to session1, establishing app:k2 event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'app:k2': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"app:k2": "v2"}), ) await session_service.append_event(session=session1, event=event) # User 2 creates a new session session2, it should see app:k1 and app:k2 session2 = await session_service.create_session( - app_name=app_name, user_id='u2', session_id='s2' + app_name=app_name, user_id="u2", session_id="s2" ) - assert session2.state == {'app:k1': 'v1', 'app:k2': 'v2'} + assert session2.state == {"app:k1": "v1", "app:k2": "v2"} # If we get session session1 again, it should also see both session1_got = await session_service.get_session( - app_name=app_name, user_id='u1', session_id='s1' + app_name=app_name, user_id="u1", session_id="s1" ) - assert session1_got.state.get('app:k1') == 'v1' - assert session1_got.state.get('app:k2') == 'v2' + assert session1_got.state.get("app:k1") == "v1" + assert session1_got.state.get("app:k2") == "v2" @pytest.mark.asyncio async def test_user_state_is_shared_only_by_user_sessions(session_service): - app_name = 'my_app' + app_name = "my_app" # User 1 creates a session, establishing user:k1 for user 1 session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1', state={'user:k1': 'v1'} + app_name=app_name, user_id="u1", session_id="s1", state={"user:k1": "v1"} ) # User 1 appends an event to session1, establishing user:k2 for user 1 event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'user:k2': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"user:k2": "v2"}), ) await session_service.append_event(session=session1, event=event) # Another session for User 1 should see user:k1 and user:k2 session1b = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1b' + app_name=app_name, user_id="u1", session_id="s1b" ) - assert session1b.state == {'user:k1': 'v1', 'user:k2': 'v2'} + assert session1b.state == {"user:k1": "v1", "user:k2": "v2"} # A session for User 2 should NOT see user:k1 or user:k2 session2 = await session_service.create_session( - app_name=app_name, user_id='u2', session_id='s2' + app_name=app_name, user_id="u2", session_id="s2" ) assert session2.state == {} @pytest.mark.asyncio async def test_session_state_is_not_shared(session_service): - app_name = 'my_app' + app_name = "my_app" # User 1 creates a session session1, establishing sk1 only for session1 session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1', state={'sk1': 'v1'} + app_name=app_name, user_id="u1", session_id="s1", state={"sk1": "v1"} ) # User 1 appends an event to session1, establishing sk2 only for session1 event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'sk2': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"sk2": "v2"}), ) await session_service.append_event(session=session1, event=event) # Getting session1 should show sk1 and sk2 session1_got = await session_service.get_session( - app_name=app_name, user_id='u1', session_id='s1' + app_name=app_name, user_id="u1", session_id="s1" ) - assert session1_got.state.get('sk1') == 'v1' - assert session1_got.state.get('sk2') == 'v2' + assert session1_got.state.get("sk1") == "v1" + assert session1_got.state.get("sk2") == "v2" # Creating another session session1b for User 1 should NOT see sk1 or sk2 session1b = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1b' + app_name=app_name, user_id="u1", session_id="s1b" ) assert session1b.state == {} @pytest.mark.asyncio async def test_temp_state_is_not_persisted_in_state_or_events(session_service): - app_name = 'my_app' - user_id = 'u1' + app_name = "my_app" + user_id = "u1" session = await session_service.create_session( - app_name=app_name, user_id=user_id, session_id='s1' + app_name=app_name, user_id=user_id, session_id="s1" ) event = Event( - invocation_id='inv1', - author='user', - actions=EventActions(state_delta={'temp:k1': 'v1', 'sk': 'v2'}), + invocation_id="inv1", + author="user", + actions=EventActions(state_delta={"temp:k1": "v1", "sk": "v2"}), ) await session_service.append_event(session=session, event=event) # Refetch session and check state and event session_got = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id='s1' + app_name=app_name, user_id=user_id, session_id="s1" ) # Check session state does not contain temp keys - assert session_got.state.get('sk') == 'v2' - assert 'temp:k1' not in session_got.state + assert session_got.state.get("sk") == "v2" + assert "temp:k1" not in session_got.state # Check event as stored in session does not contain temp keys in state_delta - assert 'temp:k1' not in session_got.events[0].actions.state_delta - assert session_got.events[0].actions.state_delta.get('sk') == 'v2' + assert "temp:k1" not in session_got.events[0].actions.state_delta + assert session_got.events[0].actions.state_delta.get("sk") == "v2" @pytest.mark.asyncio async def test_get_session_respects_user_id(session_service): - app_name = 'my_app' + app_name = "my_app" # u1 creates session 's1' and adds an event session1 = await session_service.create_session( - app_name=app_name, user_id='u1', session_id='s1' + app_name=app_name, user_id="u1", session_id="s1" ) - event = Event(invocation_id='inv1', author='user') + event = Event(invocation_id="inv1", author="user") await session_service.append_event(session1, event) # u2 creates a session with the same session_id 's1' await session_service.create_session( - app_name=app_name, user_id='u2', session_id='s1' + app_name=app_name, user_id="u2", session_id="s1" ) # Check that getting s1 for u2 returns u2's session (with no events) # not u1's session. session2_got = await session_service.get_session( - app_name=app_name, user_id='u2', session_id='s1' + app_name=app_name, user_id="u2", session_id="s1" ) - assert session2_got.user_id == 'u2' + assert session2_got.user_id == "u2" assert len(session2_got.events) == 0 @pytest.mark.asyncio async def test_create_session_with_existing_id_raises_error(session_service): - app_name = 'my_app' - user_id = 'test_user' - session_id = 'existing_session' + app_name = "my_app" + user_id = "test_user" + session_id = "existing_session" # Create the first session await session_service.create_session( @@ -386,25 +388,25 @@ async def test_create_session_with_existing_id_raises_error(session_service): @pytest.mark.asyncio async def test_append_event_bytes(session_service): - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" session = await session_service.create_session( app_name=app_name, user_id=user_id ) test_content = types.Content( - role='user', + role="user", parts=[ - types.Part.from_bytes(data=b'test_image_data', mime_type='image/png'), + types.Part.from_bytes(data=b"test_image_data", mime_type="image/png"), ], ) test_grounding_metadata = types.GroundingMetadata( - search_entry_point=types.SearchEntryPoint(sdk_blob=b'test_sdk_blob') + search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") ) event = Event( - invocation_id='invocation', - author='user', + invocation_id="invocation", + author="user", content=test_content, grounding_metadata=test_grounding_metadata, ) @@ -423,43 +425,43 @@ async def test_append_event_bytes(session_service): @pytest.mark.asyncio async def test_append_event_complete(session_service): - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" session = await session_service.create_session( app_name=app_name, user_id=user_id ) event = Event( - invocation_id='invocation', - author='user', - content=types.Content(role='user', parts=[types.Part(text='test_text')]), + invocation_id="invocation", + author="user", + content=types.Content(role="user", parts=[types.Part(text="test_text")]), turn_complete=True, partial=False, actions=EventActions( artifact_delta={ - 'file': 0, + "file": 0, }, - transfer_to_agent='agent', + transfer_to_agent="agent", escalate=True, ), - long_running_tool_ids={'tool1'}, - error_code='error_code', - error_message='error_message', + long_running_tool_ids={"tool1"}, + error_code="error_code", + error_message="error_message", interrupted=True, grounding_metadata=types.GroundingMetadata( - web_search_queries=['query1'], + web_search_queries=["query1"], ), usage_metadata=types.GenerateContentResponseUsageMetadata( prompt_token_count=1, candidates_token_count=1, total_token_count=2 ), citation_metadata=types.CitationMetadata(), - custom_metadata={'custom_key': 'custom_value'}, + custom_metadata={"custom_key": "custom_value"}, input_transcription=types.Transcription( - text='input transcription', + text="input transcription", finished=True, ), output_transcription=types.Transcription( - text='output transcription', + text="output transcription", finished=True, ), ) @@ -475,8 +477,8 @@ async def test_append_event_complete(session_service): @pytest.mark.asyncio async def test_session_last_update_time_updates_on_event(session_service): - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" session = await session_service.create_session( app_name=app_name, user_id=user_id @@ -485,8 +487,8 @@ async def test_session_last_update_time_updates_on_event(session_service): event_timestamp = original_update_time + 10 event = Event( - invocation_id='invocation', - author='user', + invocation_id="invocation", + author="user", timestamp=event_timestamp, ) await session_service.append_event(session=session, event=event) @@ -509,18 +511,18 @@ async def test_append_event_to_stale_session(): service_type=SessionServiceType.DATABASE ) - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" current_time = datetime.now().astimezone(timezone.utc).timestamp() original_session = await session_service.create_session( app_name=app_name, user_id=user_id ) event1 = Event( - invocation_id='inv1', - author='user', + invocation_id="inv1", + author="user", timestamp=current_time + 1, - actions=EventActions(state_delta={'sk1': 'v1'}), + actions=EventActions(state_delta={"sk1": "v1"}), ) await session_service.append_event(original_session, event1) @@ -528,24 +530,24 @@ async def test_append_event_to_stale_session(): app_name=app_name, user_id=user_id, session_id=original_session.id ) event2 = Event( - invocation_id='inv2', - author='user', + invocation_id="inv2", + author="user", timestamp=current_time + 2, - actions=EventActions(state_delta={'sk2': 'v2'}), + actions=EventActions(state_delta={"sk2": "v2"}), ) await session_service.append_event(updated_session, event2) # original_session is now stale assert original_session.last_update_time < updated_session.last_update_time assert len(original_session.events) == 1 - assert 'sk2' not in original_session.state + assert "sk2" not in original_session.state # Appending another event to stale original_session event3 = Event( - invocation_id='inv3', - author='user', + invocation_id="inv3", + author="user", timestamp=current_time + 3, - actions=EventActions(state_delta={'sk3': 'v3'}), + actions=EventActions(state_delta={"sk3": "v3"}), ) await session_service.append_event(original_session, event3) @@ -555,27 +557,27 @@ async def test_append_event_to_stale_session(): app_name=app_name, user_id=user_id, session_id=original_session.id ) assert len(session_final.events) == 3 - assert session_final.state.get('sk1') == 'v1' - assert session_final.state.get('sk2') == 'v2' - assert session_final.state.get('sk3') == 'v3' + assert session_final.state.get("sk1") == "v1" + assert session_final.state.get("sk2") == "v2" + assert session_final.state.get("sk3") == "v3" assert [e.invocation_id for e in session_final.events] == [ - 'inv1', - 'inv2', - 'inv3', + "inv1", + "inv2", + "inv3", ] @pytest.mark.asyncio -async def test_get_session_with_config(session_service): - app_name = 'my_app' - user_id = 'user' +async def test_get_session_with_config_filters(session_service): + app_name = "my_app" + user_id = "user" num_test_events = 5 session = await session_service.create_session( app_name=app_name, user_id=user_id ) for i in range(1, num_test_events + 1): - event = Event(author='user', timestamp=i) + event = Event(author="user", timestamp=i) await session_service.append_event(session, event) # No config, expect all events to be returned. @@ -627,12 +629,12 @@ async def test_get_session_with_config(session_service): @pytest.mark.asyncio async def test_partial_events_are_not_persisted(session_service): - app_name = 'my_app' - user_id = 'user' + app_name = "my_app" + user_id = "user" session = await session_service.create_session( app_name=app_name, user_id=user_id ) - event = Event(author='user', partial=True) + event = Event(author="user", partial=True) await session_service.append_event(session, event) # Check in-memory session @@ -642,3 +644,80 @@ async def test_partial_events_are_not_persisted(session_service): app_name=app_name, user_id=user_id, session_id=session.id ) assert len(session_got.events) == 0 + + +@pytest.mark.asyncio +async def test_database_session_service_with_db_url(): + """Test DatabaseSessionService initialization with db_url.""" + # Test db_url as positional argument + service = DatabaseSessionService("sqlite+aiosqlite:///:memory:") + app_name = "test_app" + user_id = "test_user" + + # Create and retrieve a session + session = await service.create_session( + app_name=app_name, user_id=user_id, state={"key": "value"} + ) + assert session.app_name == app_name + assert session.user_id == user_id + assert session.state == {"key": "value"} + + # Let's check that we can retrieve it + retrieved = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert retrieved == session + + # test db_url as keyword argument + service2 = DatabaseSessionService(db_url="sqlite+aiosqlite:///:memory:") + session2 = await service2.create_session( + app_name=app_name, user_id=user_id, state={"key": "value2"} + ) + assert session2.state == {"key": "value2"} + + +@pytest.mark.asyncio +async def test_database_session_service_with_db_engine(): + """Test DatabaseSessionService initialization with db_engine.""" + # Create an engine manually + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + # Create service with db_engine + service = DatabaseSessionService(db_engine=engine) + app_name = "test_app" + user_id = "test_user" + + # Create and retrieve a session + session = await service.create_session( + app_name=app_name, user_id=user_id, state={"key": "value"} + ) + assert session.app_name == app_name + assert session.user_id == user_id + assert session.state == {"key": "value"} + + # Let's check that we can retrieve it + retrieved = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert retrieved == session + + +@pytest.mark.asyncio +async def test_database_session_service_requires_one_argument(): + """Test that DatabaseSessionService requires exactly one of db_url or db_engine.""" + # Neither argument provided + with pytest.raises( + ValueError, + match="Exactly one of 'db_url' or 'db_engine' must be provided", + ): + DatabaseSessionService() + + # Both arguments provided + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + with pytest.raises( + ValueError, + match="Exactly one of 'db_url' or 'db_engine' must be provided", + ): + DatabaseSessionService( + db_url="sqlite+aiosqlite:///:memory:", db_engine=engine + )