diff --git a/application/cmd/cre_main.py b/application/cmd/cre_main.py index ead5a4281..2e24be47c 100644 --- a/application/cmd/cre_main.py +++ b/application/cmd/cre_main.py @@ -277,8 +277,8 @@ def register_standard( # calculate gap analysis populate_neo4j_db(db_connection_str) jobs = [] - pending_stadards = collection.standards() - for standard_name in pending_stadards: + pending_standards = collection.standards() + for standard_name in pending_standards: if standard_name == importing_name: continue @@ -293,9 +293,9 @@ def register_standard( jobs.append(forward_job) except exceptions.NoSuchJobError as nje: logger.error( - f"Could not find gap analysis job for for {importing_name} and {standard_name} putting {standard_name} back in the queue" + f"Could not find gap analysis job for {importing_name} and {standard_name} putting {standard_name} back in the queue" ) - pending_stadards.append(standard_name) + pending_standards.append(standard_name) bw_key = gap_analysis.make_resources_key([standard_name, importing_name]) if not collection.gap_analysis_exists(bw_key): @@ -308,9 +308,9 @@ def register_standard( jobs.append(backward_job) except exceptions.NoSuchJobError as nje: logger.error( - f"Could not find gap analysis job for for {importing_name} and {standard_name} putting {standard_name} back in the queue" + f"Could not find gap analysis job for {importing_name} and {standard_name} putting {standard_name} back in the queue" ) - pending_stadards.append(standard_name) + pending_standards.append(standard_name) redis.wait_for_jobs(jobs) conn.set(standard_hash, value="") diff --git a/application/database/db.py b/application/database/db.py index d4ac9b7e8..1d5fa38d4 100644 --- a/application/database/db.py +++ b/application/database/db.py @@ -981,23 +981,23 @@ def get_by_tags(self, tags: List[str]) -> List[cre_defs.Document]: cre_where_clause.append(sqla.and_(CRE.tags.like("%{}%".format(tag)))) nodes = Node.query.filter(*nodes_where_clause).all() or [] - for node in nodes: - node = self.get_nodes( - name=node.name, - section=node.section, - subsection=node.subsection, - version=node.version, - link=node.link, - ntype=node.ntype, - sectionID=node.section_id, + for db_node in nodes: + resolved = self.get_nodes( + name=db_node.name, + section=db_node.section, + subsection=db_node.subsection, + version=db_node.version, + link=db_node.link, + ntype=db_node.ntype, + sectionID=db_node.section_id, ) - if node: - documents.extend(node) + if resolved: + documents.extend(resolved) else: logger.fatal( - "db.get_node returned None for" + "get_nodes() returned no documents for " "Node %s:%s:%s that exists, BUG!" - % (node.name, node.section, node.section_id) + % (db_node.name, db_node.section, db_node.section_id) ) cres = CRE.query.filter(*cre_where_clause).all() or [] @@ -1582,7 +1582,6 @@ def add_node( ) -> Optional[Node]: if not node: raise ValueError(f"Node is None") - return None dbnode = dbNodeFromNode(node) if not dbnode: logger.warning(f"{node} could not be transformed to a DB object") diff --git a/application/tests/db_test.py b/application/tests/db_test.py index 1d13bd0be..98c5525c5 100644 --- a/application/tests/db_test.py +++ b/application/tests/db_test.py @@ -125,6 +125,26 @@ def test_get_by_tags(self) -> None: self.assertEqual(self.collection.get_by_tags([]), []) self.assertEqual(self.collection.get_by_tags(["this should not be a tag"]), []) + @patch("application.database.db.logger") + def test_get_by_tags_empty_nodes_regression(self, mock_logger) -> None: + """ + Simulate get_nodes returning an empty list to test the error path in get_by_tags. + See PR #837 regression fix. + """ + dbstandard = db.Node( + section="regression_section", + name="regression_name", + tags="regression_tag", + ntype=defs.Standard.__name__, + ) + self.collection.session.add(dbstandard) + self.collection.session.commit() + + with patch.object(self.collection, "get_nodes", return_value=[]): + res = self.collection.get_by_tags(["regression_tag"]) + self.assertEqual(res, []) + mock_logger.fatal.assert_called_once() + def test_get_standards_names(self) -> None: result = self.collection.get_node_names() expected = [("Standard", "BarStand"), ("Standard", "Unlinked")] diff --git a/application/utils/gap_analysis.py b/application/utils/gap_analysis.py index 9e3dab04d..64e547dba 100644 --- a/application/utils/gap_analysis.py +++ b/application/utils/gap_analysis.py @@ -41,9 +41,9 @@ def get_path_score(path): if step["relationship"] == "CONTAINS": penalty_type = f"CONTAINS_{get_relation_direction(step, previous_id)}" - pentalty = PENALTIES[penalty_type] - score += pentalty - step["score"] = pentalty + penalty = PENALTIES[penalty_type] + score += penalty + step["score"] = penalty previous_id = get_next_id(step, previous_id) return score diff --git a/application/web/web_main.py b/application/web/web_main.py index 29567470a..80e86faa9 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -17,7 +17,6 @@ from rq import job, exceptions from application.utils import spreadsheet_parsers -from application.utils import oscal_utils, redis from application.database import db from application.cmd import cre_main from application.defs import cre_defs as defs