diff --git a/TCT/trapi.py b/TCT/trapi.py index 9bad9b9..467988d 100644 --- a/TCT/trapi.py +++ b/TCT/trapi.py @@ -12,7 +12,7 @@ # TODO: incorporate object ids into the method. def build_query(subject_ids:list[str], object_categories:list[str], predicates:list[str], - return_json:bool=True, + return_json:bool=False, object_ids=None, subject_categories=None): """ This constructs a query json for use with TRAPI. Queries are of the form [subject_ids]-[predicates]-[object_categories]. @@ -33,7 +33,7 @@ def build_query(subject_ids:list[str], A list of predicates that we are interested in. Example: ["biolink:positively_correlated_with", "biolink:physically_interacts_with"]. return_json - If true, returns a json string; if false, returns a dict. + If true, returns a json string; if false, returns a dict (default). object_ids None by default @@ -42,7 +42,7 @@ def build_query(subject_ids:list[str], Returns ------- - A json string + A dict (default) or json string if return_json=True Examples -------- @@ -94,7 +94,7 @@ def process_result(result:dict): """ -def query(url:str, query:str): +def query(url:str, query:dict): """ Queries a single TRAPI endpoint. @@ -102,8 +102,8 @@ def query(url:str, query:str): ------ url : str The URL for the API endpoint. - query : str - A JSON string representing the query, as produced by build_query + query : dict + A dict representing the query, as produced by build_query Returns ------- @@ -111,11 +111,16 @@ def query(url:str, query:str): Examples -------- - >>> query = build_query(['NCBIGene:3845'], ['biolink:Gene'], ['biolink:physically_interacts_with']) - >>> response = query(url, query) + >>> query_dict = build_query(['NCBIGene:3845'], ['biolink:Gene'], ['biolink:physically_interacts_with']) + >>> response = query(url, query_dict) >>> print(response) """ - # example: 1. get APIs, 2. get APIs that have the target object and subject types, and the target predicates. 3. build the query and run the query. + if isinstance(query, str): + raise TypeError( + "query must be a dict, not a JSON string. " + "Use build_query(...) without return_json=True, " + "or pass json.loads(query) instead." + ) response = requests.post(url, json=query) if response.status_code == 200: # TODO diff --git a/tests/test_trapi.py b/tests/test_trapi.py new file mode 100644 index 0000000..d3a131f --- /dev/null +++ b/tests/test_trapi.py @@ -0,0 +1,84 @@ +"""Tests for TCT.trapi module, focused on the predicate_query branch changes: +- build_query defaults to returning a dict (return_json=False) +- query() accepts a dict instead of a JSON string +""" + +import json +import inspect + +import pytest + +from TCT.trapi import build_query, query + + +# Test data +EXAMPLE_QUERIES = [ + { + 'subject_ids': ['NCBIGene:3845'], + 'object_categories': ['biolink:Gene'], + 'predicates': ['biolink:physically_interacts_with'], + }, + { + 'subject_ids': ['NCBIGene:3845'], + 'object_categories': ['biolink:Gene'], + 'predicates': [ + 'biolink:positively_correlated_with', + 'biolink:physically_interacts_with', + ], + }, +] + + +def test_build_query_returns_dict_by_default(): + q = EXAMPLE_QUERIES[0] + result = build_query(q['subject_ids'], q['object_categories'], q['predicates']) + assert isinstance(result, dict) + + +def test_build_query_default_is_explicitly_false(): + sig = inspect.signature(build_query) + assert sig.parameters['return_json'].default is False + + +def test_build_query_returns_json_string_when_requested(): + q = EXAMPLE_QUERIES[0] + result = build_query(q['subject_ids'], q['object_categories'], q['predicates'], return_json=True) + assert isinstance(result, str) + parsed = json.loads(result) + assert isinstance(parsed, dict) + + +def test_build_query_dict_and_json_are_equivalent(): + q = EXAMPLE_QUERIES[0] + dict_result = build_query(q['subject_ids'], q['object_categories'], q['predicates'], return_json=False) + json_result = build_query(q['subject_ids'], q['object_categories'], q['predicates'], return_json=True) + assert dict_result == json.loads(json_result) + + +@pytest.mark.parametrize("example_query", EXAMPLE_QUERIES) +def test_build_query_structure(example_query): + result = build_query( + example_query['subject_ids'], + example_query['object_categories'], + example_query['predicates'], + ) + assert 'message' in result + qg = result['message']['query_graph'] + assert qg['edges']['e00']['predicates'] == example_query['predicates'] + assert qg['edges']['e00']['subject'] == 'n00' + assert qg['edges']['e00']['object'] == 'n01' + assert qg['nodes']['n00']['ids'] == example_query['subject_ids'] + assert qg['nodes']['n01']['categories'] == example_query['object_categories'] + + +def test_query_signature_expects_dict(): + sig = inspect.signature(query) + assert sig.parameters['query'].annotation is dict + + +def test_build_query_output_matches_query_input_type(): + """build_query's default output type should match what query() expects.""" + q = EXAMPLE_QUERIES[0] + result = build_query(q['subject_ids'], q['object_categories'], q['predicates']) + sig = inspect.signature(query) + assert isinstance(result, sig.parameters['query'].annotation)