From e379865ec5a3f56cbae2e1971b7ee654c9144519 Mon Sep 17 00:00:00 2001 From: John Gemignani Date: Thu, 12 Feb 2026 11:17:37 -0800 Subject: [PATCH] Update python-driver security and formatting Note: This PR was created with AI tools and a human. - Add parameterized query construction using psycopg.sql to prevent SQL injection in all Cypher execution paths (age.py, networkx/lib.py) - Replace all %-format and f-string SQL in networkx/lib.py with sql.Identifier() for schema/table names and sql.Literal() for values - Add validate_graph_name() with AGE-aligned VALID_GRAPH_NAME regex: start with letter/underscore, allow dots and hyphens in middle positions, end with letter/digit/underscore, min 3 chars, max 63 chars - Add validate_identifier() with strict VALID_IDENTIFIER regex for labels, column names, and SQL types (no dots or hyphens) - Add validation calls to all networkx/lib.py entry points: graph names validated on entry, labels validated before SQL construction - Add _validate_column() to sanitize column specifications in buildCypher() - Fix exception constructors (AgeNotSet, GraphNotFound, GraphAlreadyExists) to always call super().__init__() with a meaningful default message so that str(exception) never returns an empty string - Add InvalidGraphName and InvalidIdentifier exception classes with structured name/reason/context fields - Fix builder.py: change erroneous 'return Exception(...)' to 'raise ValueError(...)' for unknown float expressions - Fix copy-paste docstring in create_elabel() ('create_vlabels' -> 'create_elabels') - Remove unused 'from psycopg.adapt import Loader' import in age.py - Add design documentation in source explaining: - VALID_GRAPH_NAME regex uses '*' (not '+') intentionally so that the min-length check fires first with a clear error message - buildCypher uses string concatenation (not sql.Identifier) because column specs are pre-validated 'name type' pairs that don't map to sql.Identifier(); graphName and cypherStmt are NOT embedded - Update test_networkx.py GraphNotFound assertion to use assertIn() instead of assertEqual() to match the improved exception messages - Strip Windows carriage returns (^M) from 7 source files - Fix requirements.txt: convert from UTF-16LE+BOM+CRLF to clean UTF-8+LF, move --no-binary flag from requirements.txt to CI workflow pip command - Upgrade actions/setup-python from v4 (deprecated) to v5 in CI workflow - Add 46 security unit tests in test_security.py covering: - Graph name validation (AGE naming rules, injection, edge cases) - SQL identifier validation (labels, columns, types) - Column spec sanitization - buildCypher injection prevention - Exception constructor correctness (str() never empty) - Add test_security.py to CI pipeline (python-driver.yaml) - pip-audit: 0 known vulnerabilities in all dependencies modified: .github/workflows/python-driver.yaml modified: drivers/python/age/VERSION.py modified: drivers/python/age/__init__.py modified: drivers/python/age/age.py modified: drivers/python/age/builder.py modified: drivers/python/age/exceptions.py modified: drivers/python/age/models.py modified: drivers/python/age/networkx/lib.py modified: drivers/python/requirements.txt modified: drivers/python/setup.py modified: drivers/python/test_agtypes.py modified: drivers/python/test_networkx.py new file: drivers/python/test_security.py --- .github/workflows/python-driver.yaml | 5 +- drivers/python/age/VERSION.py | 44 +- drivers/python/age/__init__.py | 80 ++-- drivers/python/age/age.py | 611 ++++++++++++++++----------- drivers/python/age/builder.py | 420 +++++++++--------- drivers/python/age/exceptions.py | 59 ++- drivers/python/age/models.py | 586 ++++++++++++------------- drivers/python/age/networkx/lib.py | 97 +++-- drivers/python/requirements.txt | Bin 176 -> 59 bytes drivers/python/setup.py | 44 +- drivers/python/test_agtypes.py | 264 ++++++------ drivers/python/test_networkx.py | 2 +- drivers/python/test_security.py | 274 ++++++++++++ 13 files changed, 1481 insertions(+), 1005 deletions(-) create mode 100644 drivers/python/test_security.py diff --git a/.github/workflows/python-driver.yaml b/.github/workflows/python-driver.yaml index 4dad14638..16ccface4 100644 --- a/.github/workflows/python-driver.yaml +++ b/.github/workflows/python-driver.yaml @@ -22,14 +22,14 @@ jobs: run: docker compose up -d - name: Set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.12' - name: Install pre-requisites run: | sudo apt-get install python3-dev libpq-dev - pip install -r requirements.txt + pip install --no-binary psycopg -r requirements.txt - name: Build run: | @@ -40,3 +40,4 @@ jobs: python test_age_py.py -db "postgres" -u "postgres" -pass "agens" python test_networkx.py -db "postgres" -u "postgres" -pass "agens" python -m unittest -v test_agtypes.py + python -m unittest -v test_security.py diff --git a/drivers/python/age/VERSION.py b/drivers/python/age/VERSION.py index 3b014ea5b..5136181ae 100644 --- a/drivers/python/age/VERSION.py +++ b/drivers/python/age/VERSION.py @@ -1,22 +1,22 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - - -VER_MAJOR = 1 -VER_MINOR = 0 -VER_MICRO = 0 - -VERSION = '.'.join([str(VER_MAJOR),str(VER_MINOR),str(VER_MICRO)]) +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + + +VER_MAJOR = 1 +VER_MINOR = 0 +VER_MICRO = 0 + +VERSION = '.'.join([str(VER_MAJOR),str(VER_MINOR),str(VER_MICRO)]) diff --git a/drivers/python/age/__init__.py b/drivers/python/age/__init__.py index fd50135af..685f0fe74 100644 --- a/drivers/python/age/__init__.py +++ b/drivers/python/age/__init__.py @@ -1,40 +1,40 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import psycopg.conninfo as conninfo -from . import age -from .age import * -from .models import * -from .builder import ResultHandler, DummyResultHandler, parseAgeValue, newResultHandler -from . import VERSION - -def version(): - return VERSION.VERSION - - -def connect(dsn=None, graph=None, connection_factory=None, cursor_factory=ClientCursor, load_from_plugins=False, - **kwargs): - - dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs) - - ag = Age() - ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory, cursor_factory=cursor_factory, - load_from_plugins=load_from_plugins, **kwargs) - return ag - -# Dummy ResultHandler -rawPrinter = DummyResultHandler() - -__name__="age" +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import psycopg.conninfo as conninfo +from . import age +from .age import * +from .models import * +from .builder import ResultHandler, DummyResultHandler, parseAgeValue, newResultHandler +from . import VERSION + +def version(): + return VERSION.VERSION + + +def connect(dsn=None, graph=None, connection_factory=None, cursor_factory=ClientCursor, load_from_plugins=False, + **kwargs): + + dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs) + + ag = Age() + ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory, cursor_factory=cursor_factory, + load_from_plugins=load_from_plugins, **kwargs) + return ag + +# Dummy ResultHandler +rawPrinter = DummyResultHandler() + +__name__="age" diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index b1aa82158..fad1f27b1 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -1,236 +1,375 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import re -import psycopg -from psycopg.types import TypeInfo -from psycopg.adapt import Loader -from psycopg import sql -from psycopg.client_cursor import ClientCursor -from .exceptions import * -from .builder import parseAgeValue - - -_EXCEPTION_NoConnection = NoConnection() -_EXCEPTION_GraphNotSet = GraphNotSet() - -WHITESPACE = re.compile(r'\s') - - -class AgeDumper(psycopg.adapt.Dumper): - def dump(self, obj: Any) -> bytes | bytearray | memoryview: - pass - - -class AgeLoader(psycopg.adapt.Loader): - def load(self, data: bytes | bytearray | memoryview) -> Any | None: - if isinstance(data, memoryview): - data_bytes = data.tobytes() - else: - data_bytes = data - - return parseAgeValue(data_bytes.decode('utf-8')) - - -def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False): - with conn.cursor() as cursor: - if load_from_plugins: - cursor.execute("LOAD '$libdir/plugins/age';") - else: - cursor.execute("LOAD 'age';") - - cursor.execute("SET search_path = ag_catalog, '$user', public;") - - ag_info = TypeInfo.fetch(conn, 'agtype') - - if not ag_info: - raise AgeNotSet() - - conn.adapters.register_loader(ag_info.oid, AgeLoader) - conn.adapters.register_loader(ag_info.array_oid, AgeLoader) - - # Check graph exists - if graphName != None: - checkGraphCreated(conn, graphName) - -# Create the graph, if it does not exist -def checkGraphCreated(conn:psycopg.connection, graphName:str): - with conn.cursor() as cursor: - cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=sql.Literal(graphName))) - if cursor.fetchone()[0] == 0: - cursor.execute(sql.SQL("SELECT create_graph({graphName});").format(graphName=sql.Literal(graphName))) - conn.commit() - - -def deleteGraph(conn:psycopg.connection, graphName:str): - with conn.cursor() as cursor: - cursor.execute(sql.SQL("SELECT drop_graph({graphName}, true);").format(graphName=sql.Literal(graphName))) - conn.commit() - - -def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str: - if graphName == None: - raise _EXCEPTION_GraphNotSet - - columnExp=[] - if columns != None and len(columns) > 0: - for col in columns: - if col.strip() == '': - continue - elif WHITESPACE.search(col) != None: - columnExp.append(col) - else: - columnExp.append(col + " agtype") - else: - columnExp.append('v agtype') - - stmtArr = [] - stmtArr.append("SELECT * from cypher(NULL,NULL) as (") - stmtArr.append(','.join(columnExp)) - stmtArr.append(");") - return "".join(stmtArr) - -def execSql(conn:psycopg.connection, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor : - if conn == None or conn.closed: - raise _EXCEPTION_NoConnection - - cursor = conn.cursor() - try: - cursor.execute(stmt, params) - if commit: - conn.commit() - - return cursor - except SyntaxError as cause: - conn.rollback() - raise cause - except Exception as cause: - conn.rollback() - raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause) - - -def querySql(conn:psycopg.connection, stmt:str, params:tuple=None) -> psycopg.cursor : - return execSql(conn, stmt, False, params) - -# Execute cypher statement and return cursor. -# If cypher statement changes data (create, set, remove), -# You must commit session(ag.commit()) -# (Otherwise the execution cannot make any effect.) -def execCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : - if conn == None or conn.closed: - raise _EXCEPTION_NoConnection - - cursor = conn.cursor() - #clean up the string for mogrification - cypherStmt = cypherStmt.replace("\n", "") - cypherStmt = cypherStmt.replace("\t", "") - cypher = str(cursor.mogrify(cypherStmt, params)) - cypher = cypher.strip() - - preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})" - - cursor = conn.cursor() - try: - cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher))) - except SyntaxError as cause: - conn.rollback() - raise cause - except Exception as cause: - conn.rollback() - raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + preparedStmt +")", cause) - - stmt = buildCypher(graphName, cypher, cols) - - cursor = conn.cursor() - try: - cursor.execute(stmt) - return cursor - except SyntaxError as cause: - conn.rollback() - raise cause - except Exception as cause: - conn.rollback() - raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause) - - -def cypher(cursor:psycopg.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : - #clean up the string for mogrification - cypherStmt = cypherStmt.replace("\n", "") - cypherStmt = cypherStmt.replace("\t", "") - cypher = str(cursor.mogrify(cypherStmt, params)) - cypher = cypher.strip() - - preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})" - cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher))) - - stmt = buildCypher(graphName, cypher, cols) - cursor.execute(stmt) - - -# def execCypherWithReturn(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : -# stmt = buildCypher(graphName, cypherStmt, columns) -# return execSql(conn, stmt, False, params) - -# def queryCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : -# return execCypherWithReturn(conn, graphName, cypherStmt, columns, params) - - -class Age: - def __init__(self): - self.connection = None # psycopg connection] - self.graphName = None - - # Connect to PostgreSQL Server and establish session and type extension environment. - def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor, - load_from_plugins:bool=False, **kwargs): - conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs) - setUpAge(conn, graph, load_from_plugins) - self.connection = conn - self.graphName = graph - return self - - def close(self): - self.connection.close() - - def setGraph(self, graph:str): - checkGraphCreated(self.connection, graph) - self.graphName = graph - return self - - def commit(self): - self.connection.commit() - - def rollback(self): - self.connection.rollback() - - def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : - return execCypher(self.connection, self.graphName, cypherStmt, cols=cols, params=params) - - def cypher(self, cursor:psycopg.cursor, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : - return cypher(cursor, self.graphName, cypherStmt, cols=cols, params=params) - - # def execSql(self, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor : - # return execSql(self.connection, stmt, commit, params) - - - # def execCypher(self, cypherStmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor : - # return execCypher(self.connection, self.graphName, cypherStmt, commit, params) - - # def execCypherWithReturn(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : - # return execCypherWithReturn(self.connection, self.graphName, cypherStmt, columns, params) - - # def queryCypher(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : - # return queryCypher(self.connection, self.graphName, cypherStmt, columns, params) - +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +import psycopg +from psycopg.types import TypeInfo +from psycopg import sql +from psycopg.client_cursor import ClientCursor +from .exceptions import * +from .builder import parseAgeValue + + +_EXCEPTION_NoConnection = NoConnection() +_EXCEPTION_GraphNotSet = GraphNotSet() + +WHITESPACE = re.compile(r'\s') + +# Valid AGE graph name pattern aligned with Apache AGE's internal validation +# and Neo4j/openCypher naming conventions. +# Start: letter or underscore +# Middle: letter, digit, underscore, dot, or hyphen +# End: letter, digit, or underscore +# +# Design note: The middle segment uses `*` (not `+`) intentionally. +# This makes the regex match names as short as 2 characters at the +# regex level. However, validate_graph_name() checks MIN_GRAPH_NAME_LENGTH +# *before* applying this regex, so 2-character names are rejected with a +# clear "must be at least 3 characters" error rather than a confusing +# regex-mismatch error. This ordering gives users actionable feedback. +VALID_GRAPH_NAME = re.compile(r'^[A-Za-z_][A-Za-z0-9_.\-]*[A-Za-z0-9_]$') +MIN_GRAPH_NAME_LENGTH = 3 + +# Valid SQL identifier for labels, column names, and types. +# Stricter than graph names — no dots or hyphens. +VALID_IDENTIFIER = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$') +MAX_IDENTIFIER_LENGTH = 63 + + +def validate_graph_name(graph_name: str) -> None: + """Validate that a graph name conforms to Apache AGE's naming rules. + + Graph names must: + - Be at least 3 characters and at most 63 characters + - Start with a letter or underscore + - Contain only letters, digits, underscores, dots, and hyphens + - End with a letter, digit, or underscore + + This aligns with AGE's internal validation and Neo4j/openCypher + naming conventions. + + Args: + graph_name: The graph name to validate. + + Raises: + InvalidGraphName: If the graph name is invalid. + """ + if not graph_name or not isinstance(graph_name, str): + raise InvalidGraphName( + str(graph_name), + "Graph name must be a non-empty string." + ) + if len(graph_name) < MIN_GRAPH_NAME_LENGTH: + raise InvalidGraphName( + graph_name, + f"Graph names must be at least {MIN_GRAPH_NAME_LENGTH} characters." + ) + if len(graph_name) > MAX_IDENTIFIER_LENGTH: + raise InvalidGraphName( + graph_name, + f"Must not exceed {MAX_IDENTIFIER_LENGTH} characters " + "(PostgreSQL name limit)." + ) + if not VALID_GRAPH_NAME.match(graph_name): + raise InvalidGraphName( + graph_name, + "Graph names must start with a letter or underscore, " + "may contain letters, digits, underscores, dots, and hyphens, " + "and must end with a letter, digit, or underscore." + ) + + +def validate_identifier(name: str, context: str = "identifier") -> None: + """Validate that a name is a safe SQL identifier for labels, columns, or types. + + This follows stricter rules than graph names — only letters, digits, + and underscores are permitted (no dots or hyphens). + + Args: + name: The identifier to validate. + context: What the identifier represents (for error messages). + + Raises: + InvalidIdentifier: If the identifier is invalid. + """ + if not name or not isinstance(name, str): + raise InvalidIdentifier( + str(name), + f"{context} must be a non-empty string." + ) + if len(name) > MAX_IDENTIFIER_LENGTH: + raise InvalidIdentifier( + name, + f"{context} must not exceed {MAX_IDENTIFIER_LENGTH} characters." + ) + if not VALID_IDENTIFIER.match(name): + raise InvalidIdentifier( + name, + f"{context} must start with a letter or underscore " + "and contain only letters, digits, and underscores." + ) + + +class AgeDumper(psycopg.adapt.Dumper): + def dump(self, obj: Any) -> bytes | bytearray | memoryview: + pass + + +class AgeLoader(psycopg.adapt.Loader): + def load(self, data: bytes | bytearray | memoryview) -> Any | None: + if isinstance(data, memoryview): + data_bytes = data.tobytes() + else: + data_bytes = data + + return parseAgeValue(data_bytes.decode('utf-8')) + + +def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False): + with conn.cursor() as cursor: + if load_from_plugins: + cursor.execute("LOAD '$libdir/plugins/age';") + else: + cursor.execute("LOAD 'age';") + + cursor.execute("SET search_path = ag_catalog, '$user', public;") + + ag_info = TypeInfo.fetch(conn, 'agtype') + + if not ag_info: + raise AgeNotSet( + "AGE agtype type not found. Ensure the AGE extension is " + "installed and loaded in the current database. " + "Run CREATE EXTENSION age; first." + ) + + conn.adapters.register_loader(ag_info.oid, AgeLoader) + conn.adapters.register_loader(ag_info.array_oid, AgeLoader) + + # Check graph exists + if graphName != None: + checkGraphCreated(conn, graphName) + +# Create the graph, if it does not exist +def checkGraphCreated(conn:psycopg.connection, graphName:str): + validate_graph_name(graphName) + with conn.cursor() as cursor: + cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=sql.Literal(graphName))) + if cursor.fetchone()[0] == 0: + cursor.execute(sql.SQL("SELECT create_graph({graphName});").format(graphName=sql.Literal(graphName))) + conn.commit() + + +def deleteGraph(conn:psycopg.connection, graphName:str): + validate_graph_name(graphName) + with conn.cursor() as cursor: + cursor.execute(sql.SQL("SELECT drop_graph({graphName}, true);").format(graphName=sql.Literal(graphName))) + conn.commit() + + +def _validate_column(col: str) -> str: + """Validate and normalize a column specification for use in SQL. + + Accepts either a plain column name (e.g. 'v') or a name with type + (e.g. 'v agtype'). Validates each component to prevent SQL injection. + + Args: + col: Column specification string. + + Returns: + Normalized column specification, or empty string if blank. + + Raises: + InvalidIdentifier: If any component is invalid. + """ + col = col.strip() + if not col: + return '' + + if WHITESPACE.search(col): + parts = col.split() + if len(parts) != 2: + raise InvalidIdentifier( + col, + "Column specification must be 'name' or 'name type'." + ) + name, type_name = parts + validate_identifier(name, "Column name") + validate_identifier(type_name, "Column type") + return f"{name} {type_name}" + else: + validate_identifier(col, "Column name") + return f"{col} agtype" + + +def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str: + if graphName == None: + raise _EXCEPTION_GraphNotSet + + columnExp=[] + if columns != None and len(columns) > 0: + for col in columns: + validated = _validate_column(col) + if validated: + columnExp.append(validated) + else: + columnExp.append('v agtype') + + # Design note: String concatenation is used here instead of + # psycopg.sql.Identifier() because column specifications are + # "name type" pairs (e.g. "v agtype") that don't map directly to + # sql.Identifier(). Each component has already been validated by + # _validate_column() → validate_identifier(), which restricts + # names to ^[A-Za-z_][A-Za-z0-9_]*$ and max 63 chars. The + # graphName and cypherStmt are NOT embedded here — this template + # only contains the validated column list and static SQL keywords. + stmtArr = [] + stmtArr.append("SELECT * from cypher(NULL,NULL) as (") + stmtArr.append(','.join(columnExp)) + stmtArr.append(");") + return "".join(stmtArr) + +def execSql(conn:psycopg.connection, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor : + if conn == None or conn.closed: + raise _EXCEPTION_NoConnection + + cursor = conn.cursor() + try: + cursor.execute(stmt, params) + if commit: + conn.commit() + + return cursor + except SyntaxError as cause: + conn.rollback() + raise cause + except Exception as cause: + conn.rollback() + raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause) + + +def querySql(conn:psycopg.connection, stmt:str, params:tuple=None) -> psycopg.cursor : + return execSql(conn, stmt, False, params) + +# Execute cypher statement and return cursor. +# If cypher statement changes data (create, set, remove), +# You must commit session(ag.commit()) +# (Otherwise the execution cannot make any effect.) +def execCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : + if conn == None or conn.closed: + raise _EXCEPTION_NoConnection + + cursor = conn.cursor() + #clean up the string for mogrification + cypherStmt = cypherStmt.replace("\n", "") + cypherStmt = cypherStmt.replace("\t", "") + cypher = str(cursor.mogrify(cypherStmt, params)) + cypher = cypher.strip() + + preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})" + + cursor = conn.cursor() + try: + cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher))) + except SyntaxError as cause: + conn.rollback() + raise cause + except Exception as cause: + conn.rollback() + raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + preparedStmt +")", cause) + + stmt = buildCypher(graphName, cypher, cols) + + cursor = conn.cursor() + try: + cursor.execute(stmt) + return cursor + except SyntaxError as cause: + conn.rollback() + raise cause + except Exception as cause: + conn.rollback() + raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause) + + +def cypher(cursor:psycopg.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : + #clean up the string for mogrification + cypherStmt = cypherStmt.replace("\n", "") + cypherStmt = cypherStmt.replace("\t", "") + cypher = str(cursor.mogrify(cypherStmt, params)) + cypher = cypher.strip() + + preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})" + cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher))) + + stmt = buildCypher(graphName, cypher, cols) + cursor.execute(stmt) + + +# def execCypherWithReturn(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : +# stmt = buildCypher(graphName, cypherStmt, columns) +# return execSql(conn, stmt, False, params) + +# def queryCypher(conn:psycopg.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : +# return execCypherWithReturn(conn, graphName, cypherStmt, columns, params) + + +class Age: + def __init__(self): + self.connection = None # psycopg connection] + self.graphName = None + + # Connect to PostgreSQL Server and establish session and type extension environment. + def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor, + load_from_plugins:bool=False, **kwargs): + conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs) + setUpAge(conn, graph, load_from_plugins) + self.connection = conn + self.graphName = graph + return self + + def close(self): + self.connection.close() + + def setGraph(self, graph:str): + checkGraphCreated(self.connection, graph) + self.graphName = graph + return self + + def commit(self): + self.connection.commit() + + def rollback(self): + self.connection.rollback() + + def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : + return execCypher(self.connection, self.graphName, cypherStmt, cols=cols, params=params) + + def cypher(self, cursor:psycopg.cursor, cypherStmt:str, cols:list=None, params:tuple=None) -> psycopg.cursor : + return cypher(cursor, self.graphName, cypherStmt, cols=cols, params=params) + + # def execSql(self, stmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor : + # return execSql(self.connection, stmt, commit, params) + + + # def execCypher(self, cypherStmt:str, commit:bool=False, params:tuple=None) -> psycopg.cursor : + # return execCypher(self.connection, self.graphName, cypherStmt, commit, params) + + # def execCypherWithReturn(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : + # return execCypherWithReturn(self.connection, self.graphName, cypherStmt, columns, params) + + # def queryCypher(self, cypherStmt:str, columns:list=None , params:tuple=None) -> psycopg.cursor : + # return queryCypher(self.connection, self.graphName, cypherStmt, columns, params) + diff --git a/drivers/python/age/builder.py b/drivers/python/age/builder.py index a3815b829..f1e7a2ce8 100644 --- a/drivers/python/age/builder.py +++ b/drivers/python/age/builder.py @@ -1,210 +1,210 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from . import gen -from .gen.AgtypeLexer import AgtypeLexer -from .gen.AgtypeParser import AgtypeParser -from .gen.AgtypeVisitor import AgtypeVisitor -from .models import * -from .exceptions import * -from antlr4 import InputStream, CommonTokenStream, ParserRuleContext -from antlr4.tree.Tree import TerminalNode -from decimal import Decimal - -resultHandler = None - -class ResultHandler: - def parse(ageData): - pass - -def newResultHandler(query=""): - resultHandler = Antlr4ResultHandler(None, query) - return resultHandler - -def parseAgeValue(value, cursor=None): - if value is None: - return None - - global resultHandler - if (resultHandler == None): - resultHandler = Antlr4ResultHandler(None) - try: - return resultHandler.parse(value) - except Exception as ex: - raise AGTypeError(value, ex) - - -class Antlr4ResultHandler(ResultHandler): - def __init__(self, vertexCache, query=None): - self.lexer = AgtypeLexer() - self.parser = AgtypeParser(None) - self.visitor = ResultVisitor(vertexCache) - - def parse(self, ageData): - if not ageData: - return None - # print("Parse::", ageData) - - self.lexer.inputStream = InputStream(ageData) - self.parser.setTokenStream(CommonTokenStream(self.lexer)) - self.parser.reset() - tree = self.parser.agType() - parsed = tree.accept(self.visitor) - return parsed - - -# print raw result String -class DummyResultHandler(ResultHandler): - def parse(self, ageData): - print(ageData) - -# default agType visitor -class ResultVisitor(AgtypeVisitor): - vertexCache = None - - def __init__(self, cache) -> None: - super().__init__() - self.vertexCache = cache - - - def visitAgType(self, ctx:AgtypeParser.AgTypeContext): - agVal = ctx.agValue() - if agVal != None: - obj = ctx.agValue().accept(self) - return obj - - return None - - def visitAgValue(self, ctx:AgtypeParser.AgValueContext): - annoCtx = ctx.typeAnnotation() - valueCtx = ctx.value() - - if annoCtx is not None: - annoCtx.accept(self) - anno = annoCtx.IDENT().getText() - return self.handleAnnotatedValue(anno, valueCtx) - else: - return valueCtx.accept(self) - - - # Visit a parse tree produced by AgtypeParser#StringValue. - def visitStringValue(self, ctx:AgtypeParser.StringValueContext): - return ctx.STRING().getText().strip('"') - - - # Visit a parse tree produced by AgtypeParser#IntegerValue. - def visitIntegerValue(self, ctx:AgtypeParser.IntegerValueContext): - return int(ctx.INTEGER().getText()) - - # Visit a parse tree produced by AgtypeParser#floatLiteral. - def visitFloatLiteral(self, ctx:AgtypeParser.FloatLiteralContext): - c = ctx.getChild(0) - tp = c.symbol.type - text = ctx.getText() - if tp == AgtypeParser.RegularFloat: - return float(text) - elif tp == AgtypeParser.ExponentFloat: - return float(text) - else: - if text == 'NaN': - return float('nan') - elif text == '-Infinity': - return float('-inf') - elif text == 'Infinity': - return float('inf') - else: - return Exception("Unknown float expression:"+text) - - - # Visit a parse tree produced by AgtypeParser#TrueBoolean. - def visitTrueBoolean(self, ctx:AgtypeParser.TrueBooleanContext): - return True - - - # Visit a parse tree produced by AgtypeParser#FalseBoolean. - def visitFalseBoolean(self, ctx:AgtypeParser.FalseBooleanContext): - return False - - - # Visit a parse tree produced by AgtypeParser#NullValue. - def visitNullValue(self, ctx:AgtypeParser.NullValueContext): - return None - - - # Visit a parse tree produced by AgtypeParser#obj. - def visitObj(self, ctx:AgtypeParser.ObjContext): - obj = dict() - for c in ctx.getChildren(): - if isinstance(c, AgtypeParser.PairContext): - namVal = self.visitPair(c) - name = namVal[0] - valCtx = namVal[1] - val = valCtx.accept(self) - obj[name] = val - return obj - - - # Visit a parse tree produced by AgtypeParser#pair. - def visitPair(self, ctx:AgtypeParser.PairContext): - self.visitChildren(ctx) - return (ctx.STRING().getText().strip('"') , ctx.agValue()) - - - # Visit a parse tree produced by AgtypeParser#array. - def visitArray(self, ctx:AgtypeParser.ArrayContext): - li = list() - for c in ctx.getChildren(): - if not isinstance(c, TerminalNode): - val = c.accept(self) - li.append(val) - return li - - def handleAnnotatedValue(self, anno:str, ctx:ParserRuleContext): - if anno == "numeric": - return Decimal(ctx.getText()) - elif anno == "vertex": - dict = ctx.accept(self) - vid = dict["id"] - vertex = None - if self.vertexCache != None and vid in self.vertexCache : - vertex = self.vertexCache[vid] - else: - vertex = Vertex() - vertex.id = dict["id"] - vertex.label = dict["label"] - vertex.properties = dict["properties"] - - if self.vertexCache != None: - self.vertexCache[vid] = vertex - - return vertex - - elif anno == "edge": - edge = Edge() - dict = ctx.accept(self) - edge.id = dict["id"] - edge.label = dict["label"] - edge.end_id = dict["end_id"] - edge.start_id = dict["start_id"] - edge.properties = dict["properties"] - - return edge - - elif anno == "path": - arr = ctx.accept(self) - path = Path(arr) - - return path - - return ctx.accept(self) +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from . import gen +from .gen.AgtypeLexer import AgtypeLexer +from .gen.AgtypeParser import AgtypeParser +from .gen.AgtypeVisitor import AgtypeVisitor +from .models import * +from .exceptions import * +from antlr4 import InputStream, CommonTokenStream, ParserRuleContext +from antlr4.tree.Tree import TerminalNode +from decimal import Decimal + +resultHandler = None + +class ResultHandler: + def parse(ageData): + pass + +def newResultHandler(query=""): + resultHandler = Antlr4ResultHandler(None, query) + return resultHandler + +def parseAgeValue(value, cursor=None): + if value is None: + return None + + global resultHandler + if (resultHandler == None): + resultHandler = Antlr4ResultHandler(None) + try: + return resultHandler.parse(value) + except Exception as ex: + raise AGTypeError(value, ex) + + +class Antlr4ResultHandler(ResultHandler): + def __init__(self, vertexCache, query=None): + self.lexer = AgtypeLexer() + self.parser = AgtypeParser(None) + self.visitor = ResultVisitor(vertexCache) + + def parse(self, ageData): + if not ageData: + return None + # print("Parse::", ageData) + + self.lexer.inputStream = InputStream(ageData) + self.parser.setTokenStream(CommonTokenStream(self.lexer)) + self.parser.reset() + tree = self.parser.agType() + parsed = tree.accept(self.visitor) + return parsed + + +# print raw result String +class DummyResultHandler(ResultHandler): + def parse(self, ageData): + print(ageData) + +# default agType visitor +class ResultVisitor(AgtypeVisitor): + vertexCache = None + + def __init__(self, cache) -> None: + super().__init__() + self.vertexCache = cache + + + def visitAgType(self, ctx:AgtypeParser.AgTypeContext): + agVal = ctx.agValue() + if agVal != None: + obj = ctx.agValue().accept(self) + return obj + + return None + + def visitAgValue(self, ctx:AgtypeParser.AgValueContext): + annoCtx = ctx.typeAnnotation() + valueCtx = ctx.value() + + if annoCtx is not None: + annoCtx.accept(self) + anno = annoCtx.IDENT().getText() + return self.handleAnnotatedValue(anno, valueCtx) + else: + return valueCtx.accept(self) + + + # Visit a parse tree produced by AgtypeParser#StringValue. + def visitStringValue(self, ctx:AgtypeParser.StringValueContext): + return ctx.STRING().getText().strip('"') + + + # Visit a parse tree produced by AgtypeParser#IntegerValue. + def visitIntegerValue(self, ctx:AgtypeParser.IntegerValueContext): + return int(ctx.INTEGER().getText()) + + # Visit a parse tree produced by AgtypeParser#floatLiteral. + def visitFloatLiteral(self, ctx:AgtypeParser.FloatLiteralContext): + c = ctx.getChild(0) + tp = c.symbol.type + text = ctx.getText() + if tp == AgtypeParser.RegularFloat: + return float(text) + elif tp == AgtypeParser.ExponentFloat: + return float(text) + else: + if text == 'NaN': + return float('nan') + elif text == '-Infinity': + return float('-inf') + elif text == 'Infinity': + return float('inf') + else: + raise ValueError("Unknown float expression: " + text) + + + # Visit a parse tree produced by AgtypeParser#TrueBoolean. + def visitTrueBoolean(self, ctx:AgtypeParser.TrueBooleanContext): + return True + + + # Visit a parse tree produced by AgtypeParser#FalseBoolean. + def visitFalseBoolean(self, ctx:AgtypeParser.FalseBooleanContext): + return False + + + # Visit a parse tree produced by AgtypeParser#NullValue. + def visitNullValue(self, ctx:AgtypeParser.NullValueContext): + return None + + + # Visit a parse tree produced by AgtypeParser#obj. + def visitObj(self, ctx:AgtypeParser.ObjContext): + obj = dict() + for c in ctx.getChildren(): + if isinstance(c, AgtypeParser.PairContext): + namVal = self.visitPair(c) + name = namVal[0] + valCtx = namVal[1] + val = valCtx.accept(self) + obj[name] = val + return obj + + + # Visit a parse tree produced by AgtypeParser#pair. + def visitPair(self, ctx:AgtypeParser.PairContext): + self.visitChildren(ctx) + return (ctx.STRING().getText().strip('"') , ctx.agValue()) + + + # Visit a parse tree produced by AgtypeParser#array. + def visitArray(self, ctx:AgtypeParser.ArrayContext): + li = list() + for c in ctx.getChildren(): + if not isinstance(c, TerminalNode): + val = c.accept(self) + li.append(val) + return li + + def handleAnnotatedValue(self, anno:str, ctx:ParserRuleContext): + if anno == "numeric": + return Decimal(ctx.getText()) + elif anno == "vertex": + dict = ctx.accept(self) + vid = dict["id"] + vertex = None + if self.vertexCache != None and vid in self.vertexCache : + vertex = self.vertexCache[vid] + else: + vertex = Vertex() + vertex.id = dict["id"] + vertex.label = dict["label"] + vertex.properties = dict["properties"] + + if self.vertexCache != None: + self.vertexCache[vid] = vertex + + return vertex + + elif anno == "edge": + edge = Edge() + dict = ctx.accept(self) + edge.id = dict["id"] + edge.label = dict["label"] + edge.end_id = dict["end_id"] + edge.start_id = dict["start_id"] + edge.properties = dict["properties"] + + return edge + + elif anno == "path": + arr = ctx.accept(self) + path = Path(arr) + + return path + + return ctx.accept(self) diff --git a/drivers/python/age/exceptions.py b/drivers/python/age/exceptions.py index 3aa94f4b8..18292cc08 100644 --- a/drivers/python/age/exceptions.py +++ b/drivers/python/age/exceptions.py @@ -16,39 +16,74 @@ from psycopg.errors import * class AgeNotSet(Exception): - def __init__(self, name): + def __init__(self, name=None): self.name = name + super().__init__(name or 'AGE extension is not set.') - def __repr__(self) : + def __repr__(self): return 'AGE extension is not set.' class GraphNotFound(Exception): - def __init__(self, name): + def __init__(self, name=None): self.name = name + super().__init__(f'Graph[{name}] does not exist.' if name else 'Graph does not exist.') - def __repr__(self) : - return 'Graph[' + self.name + '] does not exist.' + def __repr__(self): + if self.name: + return 'Graph[' + self.name + '] does not exist.' + return 'Graph does not exist.' class GraphAlreadyExists(Exception): - def __init__(self, name): + def __init__(self, name=None): self.name = name + super().__init__(f'Graph[{name}] already exists.' if name else 'Graph already exists.') - def __repr__(self) : - return 'Graph[' + self.name + '] already exists.' + def __repr__(self): + if self.name: + return 'Graph[' + self.name + '] already exists.' + return 'Graph already exists.' + + +class InvalidGraphName(Exception): + """Raised when a graph name contains invalid characters.""" + def __init__(self, name, reason=None): + self.name = name + self.reason = reason + msg = f"Invalid graph name: '{name}'." + if reason: + msg += f" {reason}" + super().__init__(msg) + + def __repr__(self): + return f"InvalidGraphName('{self.name}')" + + +class InvalidIdentifier(Exception): + """Raised when an identifier (column, label, etc.) is invalid.""" + def __init__(self, name, context=None): + self.name = name + self.context = context + msg = f"Invalid identifier: '{name}'." + if context: + msg += f" {context}" + super().__init__(msg) + + def __repr__(self): + return f"InvalidIdentifier('{self.name}')" class GraphNotSet(Exception): - def __repr__(self) : + def __repr__(self): return 'Graph name is not set.' class NoConnection(Exception): - def __repr__(self) : + def __repr__(self): return 'No Connection' class NoCursor(Exception): - def __repr__(self) : + def __repr__(self): return 'No Cursor' class SqlExecutionError(Exception): @@ -57,7 +92,7 @@ def __init__(self, msg, cause): self.cause = cause super().__init__(msg, cause) - def __repr__(self) : + def __repr__(self): return 'SqlExecution [' + self.msg + ']' class AGTypeError(Exception): diff --git a/drivers/python/age/models.py b/drivers/python/age/models.py index aee1b7599..6d9095485 100644 --- a/drivers/python/age/models.py +++ b/drivers/python/age/models.py @@ -1,294 +1,294 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import json -from io import StringIO - - -TP_NONE = 0 -TP_VERTEX = 1 -TP_EDGE = 2 -TP_PATH = 3 - - -class Graph(): - def __init__(self, stmt=None) -> None: - self.statement = stmt - self.rows = list() - self.vertices = dict() - - def __iter__(self): - return self.rows.__iter__() - - def __len__(self): - return self.rows.__len__() - - def __getitem__(self,index): - return self.rows[index] - - def size(self): - return self.rows.__len__() - - def append(self, agObj): - self.rows.append(agObj) - - def getVertices(self): - return self.vertices - - def getVertex(self, id): - if id in self.vertices: - return self.vertices[id] - else: - return None - -class AGObj: - @property - def gtype(self): - return TP_NONE - - -class Path(AGObj): - entities = [] - def __init__(self, entities=None) -> None: - self.entities = entities - - @property - def gtype(self): - return TP_PATH - - def __iter__(self): - return self.entities.__iter__() - - def __len__(self): - return self.entities.__len__() - - def __getitem__(self,index): - return self.entities[index] - - def size(self): - return self.entities.__len__() - - def append(self, agObj:AGObj ): - self.entities.append(agObj) - - def __str__(self) -> str: - return self.toString() - - def __repr__(self) -> str: - return self.toString() - - def toString(self) -> str: - buf = StringIO() - buf.write("[") - max = len(self.entities) - idx = 0 - while idx < max: - if idx > 0: - buf.write(",") - self.entities[idx]._toString(buf) - idx += 1 - buf.write("]::PATH") - - return buf.getvalue() - - def toJson(self) -> str: - buf = StringIO() - buf.write("{\"gtype\": \"path\", \"elements\": [") - - max = len(self.entities) - idx = 0 - while idx < max: - if idx > 0: - buf.write(",") - self.entities[idx]._toJson(buf) - idx += 1 - buf.write("]}") - - return buf.getvalue() - - - - -class Vertex(AGObj): - def __init__(self, id=None, label=None, properties=None) -> None: - self.id = id - self.label = label - self.properties = properties - - @property - def gtype(self): - return TP_VERTEX - - def __setitem__(self,name, value): - self.properties[name]=value - - def __getitem__(self,name): - if name in self.properties: - return self.properties[name] - else: - return None - - def __str__(self) -> str: - return self.toString() - - def __repr__(self) -> str: - return self.toString() - - def toString(self) -> str: - return nodeToString(self) - - def _toString(self, buf): - _nodeToString(self, buf) - - def toJson(self) -> str: - return nodeToJson(self) - - def _toJson(self, buf): - _nodeToJson(self, buf) - - -class Edge(AGObj): - def __init__(self, id=None, label=None, properties=None) -> None: - self.id = id - self.label = label - self.start_id = None - self.end_id = None - self.properties = properties - - @property - def gtype(self): - return TP_EDGE - - def __setitem__(self,name, value): - self.properties[name]=value - - def __getitem__(self,name): - if name in self.properties: - return self.properties[name] - else: - return None - - def __str__(self) -> str: - return self.toString() - - def __repr__(self) -> str: - return self.toString() - - def extraStrFormat(node, buf): - if node.start_id != None: - buf.write(", start_id:") - buf.write(str(node.start_id)) - - if node.end_id != None: - buf.write(", end_id:") - buf.write(str(node.end_id)) - - - def toString(self) -> str: - return nodeToString(self, Edge.extraStrFormat) - - def _toString(self, buf): - _nodeToString(self, buf, Edge.extraStrFormat) - - def extraJsonFormat(node, buf): - if node.start_id != None: - buf.write(", \"start_id\": \"") - buf.write(str(node.start_id)) - buf.write("\"") - - if node.end_id != None: - buf.write(", \"end_id\": \"") - buf.write(str(node.end_id)) - buf.write("\"") - - def toJson(self) -> str: - return nodeToJson(self, Edge.extraJsonFormat) - - def _toJson(self, buf): - _nodeToJson(self, buf, Edge.extraJsonFormat) - - -def nodeToString(node, extraFormatter=None): - buf = StringIO() - _nodeToString(node,buf,extraFormatter=extraFormatter) - return buf.getvalue() - - -def _nodeToString(node, buf, extraFormatter=None): - buf.write("{") - if node.label != None: - buf.write("label:") - buf.write(node.label) - - if node.id != None: - buf.write(", id:") - buf.write(str(node.id)) - - if node.properties != None: - buf.write(", properties:{") - prop_list = [] - for k, v in node.properties.items(): - prop_list.append(f"{k}: {str(v)}") - - # Join properties with comma and write to buffer - buf.write(", ".join(prop_list)) - buf.write("}") - - if extraFormatter != None: - extraFormatter(node, buf) - - if node.gtype == TP_VERTEX: - buf.write("}::VERTEX") - if node.gtype == TP_EDGE: - buf.write("}::EDGE") - - -def nodeToJson(node, extraFormatter=None): - buf = StringIO() - _nodeToJson(node, buf, extraFormatter=extraFormatter) - return buf.getvalue() - - -def _nodeToJson(node, buf, extraFormatter=None): - buf.write("{\"gtype\": ") - if node.gtype == TP_VERTEX: - buf.write("\"vertex\", ") - if node.gtype == TP_EDGE: - buf.write("\"edge\", ") - - if node.label != None: - buf.write("\"label\":\"") - buf.write(node.label) - buf.write("\"") - - if node.id != None: - buf.write(", \"id\":") - buf.write(str(node.id)) - - if extraFormatter != None: - extraFormatter(node, buf) - - if node.properties != None: - buf.write(", \"properties\":{") - - prop_list = [] - for k, v in node.properties.items(): - prop_list.append(f"\"{k}\": \"{str(v)}\"") - - # Join properties with comma and write to buffer - buf.write(", ".join(prop_list)) - buf.write("}") - buf.write("}") +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from io import StringIO + + +TP_NONE = 0 +TP_VERTEX = 1 +TP_EDGE = 2 +TP_PATH = 3 + + +class Graph(): + def __init__(self, stmt=None) -> None: + self.statement = stmt + self.rows = list() + self.vertices = dict() + + def __iter__(self): + return self.rows.__iter__() + + def __len__(self): + return self.rows.__len__() + + def __getitem__(self,index): + return self.rows[index] + + def size(self): + return self.rows.__len__() + + def append(self, agObj): + self.rows.append(agObj) + + def getVertices(self): + return self.vertices + + def getVertex(self, id): + if id in self.vertices: + return self.vertices[id] + else: + return None + +class AGObj: + @property + def gtype(self): + return TP_NONE + + +class Path(AGObj): + entities = [] + def __init__(self, entities=None) -> None: + self.entities = entities + + @property + def gtype(self): + return TP_PATH + + def __iter__(self): + return self.entities.__iter__() + + def __len__(self): + return self.entities.__len__() + + def __getitem__(self,index): + return self.entities[index] + + def size(self): + return self.entities.__len__() + + def append(self, agObj:AGObj ): + self.entities.append(agObj) + + def __str__(self) -> str: + return self.toString() + + def __repr__(self) -> str: + return self.toString() + + def toString(self) -> str: + buf = StringIO() + buf.write("[") + max = len(self.entities) + idx = 0 + while idx < max: + if idx > 0: + buf.write(",") + self.entities[idx]._toString(buf) + idx += 1 + buf.write("]::PATH") + + return buf.getvalue() + + def toJson(self) -> str: + buf = StringIO() + buf.write("{\"gtype\": \"path\", \"elements\": [") + + max = len(self.entities) + idx = 0 + while idx < max: + if idx > 0: + buf.write(",") + self.entities[idx]._toJson(buf) + idx += 1 + buf.write("]}") + + return buf.getvalue() + + + + +class Vertex(AGObj): + def __init__(self, id=None, label=None, properties=None) -> None: + self.id = id + self.label = label + self.properties = properties + + @property + def gtype(self): + return TP_VERTEX + + def __setitem__(self,name, value): + self.properties[name]=value + + def __getitem__(self,name): + if name in self.properties: + return self.properties[name] + else: + return None + + def __str__(self) -> str: + return self.toString() + + def __repr__(self) -> str: + return self.toString() + + def toString(self) -> str: + return nodeToString(self) + + def _toString(self, buf): + _nodeToString(self, buf) + + def toJson(self) -> str: + return nodeToJson(self) + + def _toJson(self, buf): + _nodeToJson(self, buf) + + +class Edge(AGObj): + def __init__(self, id=None, label=None, properties=None) -> None: + self.id = id + self.label = label + self.start_id = None + self.end_id = None + self.properties = properties + + @property + def gtype(self): + return TP_EDGE + + def __setitem__(self,name, value): + self.properties[name]=value + + def __getitem__(self,name): + if name in self.properties: + return self.properties[name] + else: + return None + + def __str__(self) -> str: + return self.toString() + + def __repr__(self) -> str: + return self.toString() + + def extraStrFormat(node, buf): + if node.start_id != None: + buf.write(", start_id:") + buf.write(str(node.start_id)) + + if node.end_id != None: + buf.write(", end_id:") + buf.write(str(node.end_id)) + + + def toString(self) -> str: + return nodeToString(self, Edge.extraStrFormat) + + def _toString(self, buf): + _nodeToString(self, buf, Edge.extraStrFormat) + + def extraJsonFormat(node, buf): + if node.start_id != None: + buf.write(", \"start_id\": \"") + buf.write(str(node.start_id)) + buf.write("\"") + + if node.end_id != None: + buf.write(", \"end_id\": \"") + buf.write(str(node.end_id)) + buf.write("\"") + + def toJson(self) -> str: + return nodeToJson(self, Edge.extraJsonFormat) + + def _toJson(self, buf): + _nodeToJson(self, buf, Edge.extraJsonFormat) + + +def nodeToString(node, extraFormatter=None): + buf = StringIO() + _nodeToString(node,buf,extraFormatter=extraFormatter) + return buf.getvalue() + + +def _nodeToString(node, buf, extraFormatter=None): + buf.write("{") + if node.label != None: + buf.write("label:") + buf.write(node.label) + + if node.id != None: + buf.write(", id:") + buf.write(str(node.id)) + + if node.properties != None: + buf.write(", properties:{") + prop_list = [] + for k, v in node.properties.items(): + prop_list.append(f"{k}: {str(v)}") + + # Join properties with comma and write to buffer + buf.write(", ".join(prop_list)) + buf.write("}") + + if extraFormatter != None: + extraFormatter(node, buf) + + if node.gtype == TP_VERTEX: + buf.write("}::VERTEX") + if node.gtype == TP_EDGE: + buf.write("}::EDGE") + + +def nodeToJson(node, extraFormatter=None): + buf = StringIO() + _nodeToJson(node, buf, extraFormatter=extraFormatter) + return buf.getvalue() + + +def _nodeToJson(node, buf, extraFormatter=None): + buf.write("{\"gtype\": ") + if node.gtype == TP_VERTEX: + buf.write("\"vertex\", ") + if node.gtype == TP_EDGE: + buf.write("\"edge\", ") + + if node.label != None: + buf.write("\"label\":\"") + buf.write(node.label) + buf.write("\"") + + if node.id != None: + buf.write(", \"id\":") + buf.write(str(node.id)) + + if extraFormatter != None: + extraFormatter(node, buf) + + if node.properties != None: + buf.write(", \"properties\":{") + + prop_list = [] + for k, v in node.properties.items(): + prop_list.append(f"\"{k}\": \"{str(v)}\"") + + # Join properties with comma and write to buffer + buf.write(", ".join(prop_list)) + buf.write("}") + buf.write("}") \ No newline at end of file diff --git a/drivers/python/age/networkx/lib.py b/drivers/python/age/networkx/lib.py index 308658620..5df761eae 100644 --- a/drivers/python/age/networkx/lib.py +++ b/drivers/python/age/networkx/lib.py @@ -20,17 +20,18 @@ from psycopg import sql from typing import Dict, Any, List, Set from age.models import Vertex, Edge, Path +from age.age import validate_graph_name, validate_identifier def checkIfGraphNameExistInAGE(connection: psycopg.connect, graphName: str): """Check if the age graph exists""" + validate_graph_name(graphName) with connection.cursor() as cursor: - cursor.execute(sql.SQL(""" - SELECT count(*) - FROM ag_catalog.ag_graph - WHERE name='%s' - """ % (graphName))) + cursor.execute( + sql.SQL("SELECT count(*) FROM ag_catalog.ag_graph WHERE name={gn}") + .format(gn=sql.Literal(graphName)) + ) if cursor.fetchone()[0] == 0: raise GraphNotFound(graphName) @@ -38,11 +39,13 @@ def checkIfGraphNameExistInAGE(connection: psycopg.connect, def getOidOfGraph(connection: psycopg.connect, graphName: str) -> int: """Returns oid of a graph""" + validate_graph_name(graphName) try: with connection.cursor() as cursor: - cursor.execute(sql.SQL(""" - SELECT graphid FROM ag_catalog.ag_graph WHERE name='%s' ; - """ % (graphName))) + cursor.execute( + sql.SQL("SELECT graphid FROM ag_catalog.ag_graph WHERE name={gn}") + .format(gn=sql.Literal(graphName)) + ) oid = cursor.fetchone()[0] return oid except Exception as e: @@ -56,7 +59,9 @@ def get_vlabel(connection: psycopg.connect, try: with connection.cursor() as cursor: cursor.execute( - """SELECT name FROM ag_catalog.ag_label WHERE kind='v' AND graph=%s;""" % oid) + sql.SQL("SELECT name FROM ag_catalog.ag_label WHERE kind='v' AND graph={oid}") + .format(oid=sql.Literal(oid)) + ) for row in cursor: node_label_list.append(row[0]) @@ -69,18 +74,19 @@ def create_vlabel(connection: psycopg.connect, graphName: str, node_label_list: List): """create_vlabels from list if not exist""" + validate_graph_name(graphName) try: node_label_set = set(get_vlabel(connection, graphName)) - crete_label_statement = '' for label in node_label_list: if label in node_label_set: continue - crete_label_statement += """SELECT create_vlabel('%s','%s');\n""" % ( - graphName, label) - if crete_label_statement != '': + validate_identifier(label, "Vertex label") with connection.cursor() as cursor: - cursor.execute(crete_label_statement) - connection.commit() + cursor.execute( + sql.SQL("SELECT create_vlabel({gn},{lbl})") + .format(gn=sql.Literal(graphName), lbl=sql.Literal(label)) + ) + connection.commit() except Exception as e: raise Exception(e) @@ -92,7 +98,9 @@ def get_elabel(connection: psycopg.connect, try: with connection.cursor() as cursor: cursor.execute( - """SELECT name FROM ag_catalog.ag_label WHERE kind='e' AND graph=%s;""" % oid) + sql.SQL("SELECT name FROM ag_catalog.ag_label WHERE kind='e' AND graph={oid}") + .format(oid=sql.Literal(oid)) + ) for row in cursor: edge_label_list.append(row[0]) except Exception as ex: @@ -103,19 +111,20 @@ def get_elabel(connection: psycopg.connect, def create_elabel(connection: psycopg.connect, graphName: str, edge_label_list: List): - """create_vlabels from list if not exist""" + """create_elabels from list if not exist""" + validate_graph_name(graphName) try: edge_label_set = set(get_elabel(connection, graphName)) - crete_label_statement = '' for label in edge_label_list: if label in edge_label_set: continue - crete_label_statement += """SELECT create_elabel('%s','%s');\n""" % ( - graphName, label) - if crete_label_statement != '': + validate_identifier(label, "Edge label") with connection.cursor() as cursor: - cursor.execute(crete_label_statement) - connection.commit() + cursor.execute( + sql.SQL("SELECT create_elabel({gn},{lbl})") + .format(gn=sql.Literal(graphName), lbl=sql.Literal(label)) + ) + connection.commit() except Exception as e: raise Exception(e) @@ -171,6 +180,7 @@ def getEdgeLabelListAfterPreprocessing(G: nx.DiGraph): def addAllNodesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGraph, node_label_list: Set): """Add all node to AGE""" + validate_graph_name(graphName) try: queue_data = {label: [] for label in node_label_list} id_data = {} @@ -180,8 +190,11 @@ def addAllNodesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGrap queue_data[data['label']].append((json_string,)) for label, rows in queue_data.items(): - table_name = """%s."%s" """ % (graphName, label) - insert_query = f"INSERT INTO {table_name} (properties) VALUES (%s) RETURNING id" + validate_identifier(label, "Node label") + insert_query = sql.SQL("INSERT INTO {schema}.{table} (properties) VALUES (%s) RETURNING id").format( + schema=sql.Identifier(graphName), + table=sql.Identifier(label) + ) cursor = connection.cursor() cursor.executemany(insert_query, rows, returning=True) ids = [] @@ -205,6 +218,7 @@ def addAllNodesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGrap def addAllEdgesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGraph, edge_label_list: Set): """Add all edge to AGE""" + validate_graph_name(graphName) try: queue_data = {label: [] for label in edge_label_list} for u, v, data in G.edges(data=True): @@ -213,8 +227,11 @@ def addAllEdgesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGrap (G.nodes[u]['properties']['__gid__'], G.nodes[v]['properties']['__gid__'], json_string,)) for label, rows in queue_data.items(): - table_name = """%s."%s" """ % (graphName, label) - insert_query = f"INSERT INTO {table_name} (start_id,end_id,properties) VALUES (%s, %s, %s)" + validate_identifier(label, "Edge label") + insert_query = sql.SQL("INSERT INTO {schema}.{table} (start_id,end_id,properties) VALUES (%s, %s, %s)").format( + schema=sql.Identifier(graphName), + table=sql.Identifier(label) + ) cursor = connection.cursor() cursor.executemany(insert_query, rows) connection.commit() @@ -225,14 +242,19 @@ def addAllEdgesIntoAGE(connection: psycopg.connect, graphName: str, G: nx.DiGrap def addAllNodesIntoNetworkx(connection: psycopg.connect, graphName: str, G: nx.DiGraph): """Add all nodes to Networkx""" + validate_graph_name(graphName) node_label_list = get_vlabel(connection, graphName) try: for label in node_label_list: + validate_identifier(label, "Node label") with connection.cursor() as cursor: - cursor.execute(""" - SELECT id, CAST(properties AS VARCHAR) - FROM %s."%s"; - """ % (graphName, label)) + cursor.execute( + sql.SQL("SELECT id, CAST(properties AS VARCHAR) FROM {schema}.{table}") + .format( + schema=sql.Identifier(graphName), + table=sql.Identifier(label) + ) + ) rows = cursor.fetchall() for row in rows: G.add_node(int(row[0]), label=label, @@ -243,14 +265,19 @@ def addAllNodesIntoNetworkx(connection: psycopg.connect, graphName: str, G: nx.D def addAllEdgesIntoNetworkx(connection: psycopg.connect, graphName: str, G: nx.DiGraph): """Add All edges to Networkx""" + validate_graph_name(graphName) try: edge_label_list = get_elabel(connection, graphName) for label in edge_label_list: + validate_identifier(label, "Edge label") with connection.cursor() as cursor: - cursor.execute(""" - SELECT start_id, end_id, CAST(properties AS VARCHAR) - FROM %s."%s"; - """ % (graphName, label)) + cursor.execute( + sql.SQL("SELECT start_id, end_id, CAST(properties AS VARCHAR) FROM {schema}.{table}") + .format( + schema=sql.Identifier(graphName), + table=sql.Identifier(label) + ) + ) rows = cursor.fetchall() for row in rows: G.add_edge(int(row[0]), int( diff --git a/drivers/python/requirements.txt b/drivers/python/requirements.txt index b0593b79218c7fc2ac7cff48c8fcabb87b24fd10..449d38c673c0d668086c21e2481caa9693dca840 100644 GIT binary patch literal 59 zcmXRYu1wA^Nasq-E6FJ`(JiPf$;i($)-5W{E6L1FwY4?TGc?pQ|;&(A65 O%1bRN&o9cZ-~s@K1{F5| literal 176 zcmY+6K?=e^5CrQi_=kKTf|ygl$3zJljGJY%qWOHZHaBS)da9@AyGCXfu1rL3RMaZC z)m#{K9m%|+)s3pv|9AH6%mUdo(b$YOGIzfOPVR}PM^$F&&+_b5bWUoN N6dpGImLwj0_yN-MAWQ%N diff --git a/drivers/python/setup.py b/drivers/python/setup.py index d0eed26be..853f1006a 100644 --- a/drivers/python/setup.py +++ b/drivers/python/setup.py @@ -1,22 +1,22 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This setup.py is maintained for backward compatibility. -# All package configuration is in pyproject.toml. For installation, -# use: pip install . - -from setuptools import setup - -setup() +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This setup.py is maintained for backward compatibility. +# All package configuration is in pyproject.toml. For installation, +# use: pip install . + +from setuptools import setup + +setup() diff --git a/drivers/python/test_agtypes.py b/drivers/python/test_agtypes.py index 69bbbc298..4e9752e61 100644 --- a/drivers/python/test_agtypes.py +++ b/drivers/python/test_agtypes.py @@ -1,132 +1,132 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import unittest -from decimal import Decimal -import math -import age - - -class TestAgtype(unittest.TestCase): - resultHandler = None - - def __init__(self, methodName: str) -> None: - super().__init__(methodName=methodName) - self.resultHandler = age.newResultHandler() - - def parse(self, exp): - return self.resultHandler.parse(exp) - - def test_scalar(self): - print("\nTesting Scalar Value Parsing. Result : ", end='') - - mapStr = '{"name": "Smith", "num":123, "yn":true, "bigInt":123456789123456789123456789123456789::numeric}' - arrStr = '["name", "Smith", "num", 123, "yn", true, 123456789123456789123456789123456789.8888::numeric]' - strStr = '"abcd"' - intStr = '1234' - floatStr = '1234.56789' - floatStr2 = '6.45161290322581e+46' - numericStr1 = '12345678901234567890123456789123456789.789::numeric' - numericStr2 = '12345678901234567890123456789123456789::numeric' - boolStr = 'true' - nullStr = '' - nanStr = "NaN" - infpStr = "Infinity" - infnStr = "-Infinity" - - mapVal = self.parse(mapStr) - arrVal = self.parse(arrStr) - str = self.parse(strStr) - intVal = self.parse(intStr) - floatVal = self.parse(floatStr) - floatVal2 = self.parse(floatStr2) - bigFloat = self.parse(numericStr1) - bigInt = self.parse(numericStr2) - boolVal = self.parse(boolStr) - nullVal = self.parse(nullStr) - nanVal = self.parse(nanStr) - infpVal = self.parse(infpStr) - infnVal = self.parse(infnStr) - - self.assertEqual(mapVal, {'name': 'Smith', 'num': 123, 'yn': True, 'bigInt': Decimal( - '123456789123456789123456789123456789')}) - self.assertEqual(arrVal, ["name", "Smith", "num", 123, "yn", True, Decimal( - "123456789123456789123456789123456789.8888")]) - self.assertEqual(str, "abcd") - self.assertEqual(intVal, 1234) - self.assertEqual(floatVal, 1234.56789) - self.assertEqual(floatVal2, 6.45161290322581e+46) - self.assertEqual(bigFloat, Decimal( - "12345678901234567890123456789123456789.789")) - self.assertEqual(bigInt, Decimal( - "12345678901234567890123456789123456789")) - self.assertEqual(boolVal, True) - self.assertTrue(math.isnan(nanVal)) - self.assertTrue(math.isinf(infpVal)) - self.assertTrue(math.isinf(infnVal)) - - def test_vertex(self): - - print("\nTesting vertex Parsing. Result : ", end='') - - vertexExp = '''{"id": 2251799813685425, "label": "Person", - "properties": {"name": "Smith", "numInt":123, "numFloat": 384.23424, - "bigInt":123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789::numeric, - "bigFloat":123456789123456789123456789123456789.12345::numeric, - "yn":true, "nullVal": null}}::vertex''' - - vertex = self.parse(vertexExp) - self.assertEqual(vertex.id, 2251799813685425) - self.assertEqual(vertex.label, "Person") - self.assertEqual(vertex["name"], "Smith") - self.assertEqual(vertex["numInt"], 123) - self.assertEqual(vertex["numFloat"], 384.23424) - self.assertEqual(vertex["bigInt"], Decimal( - "123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789")) - self.assertEqual(vertex["bigFloat"], Decimal( - "123456789123456789123456789123456789.12345")) - self.assertEqual(vertex["yn"], True) - self.assertEqual(vertex["nullVal"], None) - - def test_path(self): - - print("\nTesting Path Parsing. Result : ", end='') - - pathExp = '''[{"id": 2251799813685425, "label": "Person", "properties": {"name": "Smith"}}::vertex, - {"id": 2533274790396576, "label": "workWith", "end_id": 2251799813685425, "start_id": 2251799813685424, - "properties": {"weight": 3, "bigFloat":123456789123456789123456789.12345::numeric}}::edge, - {"id": 2251799813685424, "label": "Person", "properties": {"name": "Joe"}}::vertex]::path''' - - path = self.parse(pathExp) - vertexStart = path[0] - edge = path[1] - vertexEnd = path[2] - self.assertEqual(vertexStart.id, 2251799813685425) - self.assertEqual(vertexStart.label, "Person") - self.assertEqual(vertexStart["name"], "Smith") - - self.assertEqual(edge.id, 2533274790396576) - self.assertEqual(edge.label, "workWith") - self.assertEqual(edge["weight"], 3) - self.assertEqual(edge["bigFloat"], Decimal( - "123456789123456789123456789.12345")) - - self.assertEqual(vertexEnd.id, 2251799813685424) - self.assertEqual(vertexEnd.label, "Person") - self.assertEqual(vertexEnd["name"], "Joe") - - -if __name__ == '__main__': - unittest.main() +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from decimal import Decimal +import math +import age + + +class TestAgtype(unittest.TestCase): + resultHandler = None + + def __init__(self, methodName: str) -> None: + super().__init__(methodName=methodName) + self.resultHandler = age.newResultHandler() + + def parse(self, exp): + return self.resultHandler.parse(exp) + + def test_scalar(self): + print("\nTesting Scalar Value Parsing. Result : ", end='') + + mapStr = '{"name": "Smith", "num":123, "yn":true, "bigInt":123456789123456789123456789123456789::numeric}' + arrStr = '["name", "Smith", "num", 123, "yn", true, 123456789123456789123456789123456789.8888::numeric]' + strStr = '"abcd"' + intStr = '1234' + floatStr = '1234.56789' + floatStr2 = '6.45161290322581e+46' + numericStr1 = '12345678901234567890123456789123456789.789::numeric' + numericStr2 = '12345678901234567890123456789123456789::numeric' + boolStr = 'true' + nullStr = '' + nanStr = "NaN" + infpStr = "Infinity" + infnStr = "-Infinity" + + mapVal = self.parse(mapStr) + arrVal = self.parse(arrStr) + str = self.parse(strStr) + intVal = self.parse(intStr) + floatVal = self.parse(floatStr) + floatVal2 = self.parse(floatStr2) + bigFloat = self.parse(numericStr1) + bigInt = self.parse(numericStr2) + boolVal = self.parse(boolStr) + nullVal = self.parse(nullStr) + nanVal = self.parse(nanStr) + infpVal = self.parse(infpStr) + infnVal = self.parse(infnStr) + + self.assertEqual(mapVal, {'name': 'Smith', 'num': 123, 'yn': True, 'bigInt': Decimal( + '123456789123456789123456789123456789')}) + self.assertEqual(arrVal, ["name", "Smith", "num", 123, "yn", True, Decimal( + "123456789123456789123456789123456789.8888")]) + self.assertEqual(str, "abcd") + self.assertEqual(intVal, 1234) + self.assertEqual(floatVal, 1234.56789) + self.assertEqual(floatVal2, 6.45161290322581e+46) + self.assertEqual(bigFloat, Decimal( + "12345678901234567890123456789123456789.789")) + self.assertEqual(bigInt, Decimal( + "12345678901234567890123456789123456789")) + self.assertEqual(boolVal, True) + self.assertTrue(math.isnan(nanVal)) + self.assertTrue(math.isinf(infpVal)) + self.assertTrue(math.isinf(infnVal)) + + def test_vertex(self): + + print("\nTesting vertex Parsing. Result : ", end='') + + vertexExp = '''{"id": 2251799813685425, "label": "Person", + "properties": {"name": "Smith", "numInt":123, "numFloat": 384.23424, + "bigInt":123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789::numeric, + "bigFloat":123456789123456789123456789123456789.12345::numeric, + "yn":true, "nullVal": null}}::vertex''' + + vertex = self.parse(vertexExp) + self.assertEqual(vertex.id, 2251799813685425) + self.assertEqual(vertex.label, "Person") + self.assertEqual(vertex["name"], "Smith") + self.assertEqual(vertex["numInt"], 123) + self.assertEqual(vertex["numFloat"], 384.23424) + self.assertEqual(vertex["bigInt"], Decimal( + "123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789123456789")) + self.assertEqual(vertex["bigFloat"], Decimal( + "123456789123456789123456789123456789.12345")) + self.assertEqual(vertex["yn"], True) + self.assertEqual(vertex["nullVal"], None) + + def test_path(self): + + print("\nTesting Path Parsing. Result : ", end='') + + pathExp = '''[{"id": 2251799813685425, "label": "Person", "properties": {"name": "Smith"}}::vertex, + {"id": 2533274790396576, "label": "workWith", "end_id": 2251799813685425, "start_id": 2251799813685424, + "properties": {"weight": 3, "bigFloat":123456789123456789123456789.12345::numeric}}::edge, + {"id": 2251799813685424, "label": "Person", "properties": {"name": "Joe"}}::vertex]::path''' + + path = self.parse(pathExp) + vertexStart = path[0] + edge = path[1] + vertexEnd = path[2] + self.assertEqual(vertexStart.id, 2251799813685425) + self.assertEqual(vertexStart.label, "Person") + self.assertEqual(vertexStart["name"], "Smith") + + self.assertEqual(edge.id, 2533274790396576) + self.assertEqual(edge.label, "workWith") + self.assertEqual(edge["weight"], 3) + self.assertEqual(edge["bigFloat"], Decimal( + "123456789123456789123456789.12345")) + + self.assertEqual(vertexEnd.id, 2251799813685424) + self.assertEqual(vertexEnd.label, "Person") + self.assertEqual(vertexEnd["name"], "Joe") + + +if __name__ == '__main__': + unittest.main() diff --git a/drivers/python/test_networkx.py b/drivers/python/test_networkx.py index 310d2cf5e..dbaaf8664 100644 --- a/drivers/python/test_networkx.py +++ b/drivers/python/test_networkx.py @@ -224,7 +224,7 @@ def test_existing_graph(self): with self.assertRaises(GraphNotFound) as context: age_to_networkx(ag.connection, graphName=non_existing_graph) # Check the raised exception has the expected error message - self.assertEqual(str(context.exception), non_existing_graph) + self.assertIn(non_existing_graph, str(context.exception)) class TestNetworkxToAGE(unittest.TestCase): diff --git a/drivers/python/test_security.py b/drivers/python/test_security.py new file mode 100644 index 000000000..55347868e --- /dev/null +++ b/drivers/python/test_security.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Security tests for the Apache AGE Python driver. + +Tests input validation, SQL injection prevention, and exception handling. +""" + +import unittest +from age.age import ( + validate_graph_name, + validate_identifier, + buildCypher, + _validate_column, +) +from age.exceptions import ( + AgeNotSet, + GraphNotFound, + GraphAlreadyExists, + GraphNotSet, + InvalidGraphName, + InvalidIdentifier, +) + + +class TestGraphNameValidation(unittest.TestCase): + """Test validate_graph_name rejects dangerous inputs.""" + + def test_rejects_empty_string(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('') + + def test_rejects_none(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name(None) + + def test_rejects_non_string(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name(123) + + def test_rejects_digit_start(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('123graph') + + def test_rejects_sql_injection_drop_table(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name("'; DROP TABLE ag_graph; --") + + def test_rejects_sql_injection_semicolon(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name("test'); DROP TABLE users; --") + + def test_rejects_sql_injection_select(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name("graph; SELECT * FROM pg_shadow") + + def test_accepts_hyphenated_graph_name(self): + # AGE allows hyphens in middle positions of graph names. + validate_graph_name('my-graph') + + def test_rejects_space(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('my graph') + + def test_accepts_dotted_graph_name(self): + # AGE allows dots in middle positions of graph names. + validate_graph_name('my.graph') + + def test_rejects_dollar(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('my$graph') + + def test_rejects_exceeding_63_chars(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('a' * 64) + + def test_accepts_valid_names(self): + # These should NOT raise + validate_graph_name('my_graph') + validate_graph_name('MyGraph') + validate_graph_name('_pr_ivate') + validate_graph_name('graph123') + validate_graph_name('my-graph') + validate_graph_name('my.graph') + validate_graph_name('a-b.c_d') + validate_graph_name('abc') + validate_graph_name('a' * 63) + + def test_rejects_shorter_than_3_chars(self): + # AGE requires minimum 3 character graph names. + with self.assertRaises(InvalidGraphName): + validate_graph_name('a') + with self.assertRaises(InvalidGraphName): + validate_graph_name('ab') + + def test_rejects_name_ending_with_hyphen(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('graph-') + + def test_rejects_name_ending_with_dot(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('graph.') + + def test_rejects_name_starting_with_hyphen(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('-graph') + + def test_rejects_name_starting_with_dot(self): + with self.assertRaises(InvalidGraphName): + validate_graph_name('.graph') + + def test_error_message_contains_name(self): + try: + validate_graph_name("bad;name") + self.fail("Expected InvalidGraphName") + except InvalidGraphName as e: + self.assertIn("bad;name", str(e)) + self.assertIn("Invalid graph name", str(e)) + + +class TestIdentifierValidation(unittest.TestCase): + """Test validate_identifier rejects dangerous inputs.""" + + def test_rejects_empty_string(self): + with self.assertRaises(InvalidIdentifier): + validate_identifier('') + + def test_rejects_none(self): + with self.assertRaises(InvalidIdentifier): + validate_identifier(None) + + def test_rejects_sql_injection(self): + with self.assertRaises(InvalidIdentifier): + validate_identifier("Person'; DROP TABLE--") + + def test_rejects_special_chars(self): + with self.assertRaises(InvalidIdentifier): + validate_identifier("col; DROP TABLE") + + def test_accepts_valid_identifiers(self): + validate_identifier('Person') + validate_identifier('KNOWS') + validate_identifier('_internal') + validate_identifier('col1') + + def test_error_includes_context(self): + try: + validate_identifier("bad;name", "Column name") + self.fail("Expected InvalidIdentifier") + except InvalidIdentifier as e: + self.assertIn("Column name", str(e)) + + +class TestColumnValidation(unittest.TestCase): + """Test _validate_column prevents injection through column specs.""" + + def test_plain_column_name(self): + self.assertEqual(_validate_column('v'), 'v agtype') + + def test_column_with_type(self): + self.assertEqual(_validate_column('n agtype'), 'n agtype') + + def test_empty_column(self): + self.assertEqual(_validate_column(''), '') + self.assertEqual(_validate_column(' '), '') + + def test_rejects_injection_in_column_name(self): + with self.assertRaises(InvalidIdentifier): + _validate_column("v); DROP TABLE ag_graph; --") + + def test_rejects_injection_in_column_type(self): + with self.assertRaises(InvalidIdentifier): + _validate_column("v agtype); DROP TABLE") + + def test_rejects_three_part_column(self): + with self.assertRaises(InvalidIdentifier): + _validate_column("a b c") + + def test_rejects_semicolon_in_name(self): + with self.assertRaises(InvalidIdentifier): + _validate_column("col;") + + +class TestBuildCypher(unittest.TestCase): + """Test buildCypher validates columns and rejects injection.""" + + def test_default_column(self): + result = buildCypher('test_graph', 'MATCH (n) RETURN n', None) + self.assertIn('v agtype', result) + + def test_single_column(self): + result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n']) + self.assertIn('n agtype', result) + + def test_typed_column(self): + result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n agtype']) + self.assertIn('n agtype', result) + + def test_multiple_columns(self): + result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['a', 'b']) + self.assertIn('a agtype', result) + self.assertIn('b agtype', result) + + def test_rejects_injection_in_column(self): + with self.assertRaises(InvalidIdentifier): + buildCypher('test_graph', 'MATCH (n) RETURN n', + ["v); DROP TABLE ag_graph;--"]) + + def test_rejects_none_graph_name(self): + with self.assertRaises(GraphNotSet): + buildCypher(None, 'MATCH (n) RETURN n', None) + + +class TestExceptionConstructors(unittest.TestCase): + """Test that exception constructors work correctly.""" + + def test_age_not_set_no_args(self): + """AgeNotSet() must work without arguments (previously crashed).""" + e = AgeNotSet() + self.assertIsNone(e.name) + self.assertIn('not set', repr(e)) + + def test_age_not_set_with_message(self): + e = AgeNotSet("custom message") + self.assertEqual(e.name, "custom message") + + def test_graph_not_found_no_args(self): + e = GraphNotFound() + self.assertIsNone(e.name) + self.assertIn('does not exist', repr(e)) + + def test_graph_not_found_with_name(self): + e = GraphNotFound("test_graph") + self.assertEqual(e.name, "test_graph") + self.assertIn('test_graph', repr(e)) + + def test_graph_already_exists_no_args(self): + e = GraphAlreadyExists() + self.assertIsNone(e.name) + self.assertIn('already exists', repr(e)) + + def test_graph_already_exists_with_name(self): + e = GraphAlreadyExists("test_graph") + self.assertEqual(e.name, "test_graph") + self.assertIn('test_graph', repr(e)) + + def test_invalid_graph_name_fields(self): + e = InvalidGraphName("bad;name", "must be valid") + self.assertEqual(e.name, "bad;name") + self.assertEqual(e.reason, "must be valid") + self.assertIn("bad;name", str(e)) + self.assertIn("must be valid", str(e)) + + def test_invalid_identifier_fields(self): + e = InvalidIdentifier("col;drop", "Column name") + self.assertEqual(e.name, "col;drop") + self.assertEqual(e.context, "Column name") + self.assertIn("col;drop", str(e)) + + +if __name__ == '__main__': + unittest.main()