diff --git a/Scripts/generate_test_report.py b/Scripts/generate_test_report.py index 16ddfc8..effc165 100644 --- a/Scripts/generate_test_report.py +++ b/Scripts/generate_test_report.py @@ -35,7 +35,13 @@ def __init__(self, trx_path, coverage_path=None): self.file_coverage = [] def parse_trx(self): - tree = ET.parse(self.trx_path) + # Create a secure XML parser that disables external entity processing + parser = ET.XMLParser() + parser.parser.DefaultHandler = lambda data: None + parser.parser.ExternalEntityRefHandler = lambda context, base, uri, notationName: False + parser.parser.EntityDeclHandler = lambda entityName, is_parameter_entity, value, base, systemId, notationName, publicId: False + + tree = ET.parse(self.trx_path, parser) root = tree.getroot() ns = {'t': 'http://microsoft.com/schemas/VisualStudio/TeamTest/2010'} @@ -72,7 +78,9 @@ def parse_trx(self): duration_str = result.get('duration', '0') duration = self._parse_duration(duration_str) - test_def = root.find(f".//t:UnitTest[@id='{test_id}']/t:TestMethod", ns) + # Sanitize test_id to prevent XPath injection + sanitized_test_id = self._sanitize_xml_attribute_value(test_id) + test_def = root.find(f".//t:UnitTest[@id='{sanitized_test_id}']/t:TestMethod", ns) class_name = test_def.get('className', '') if test_def is not None else '' parts = class_name.split(',')[0].rsplit('.', 1) @@ -113,7 +121,13 @@ def parse_coverage(self): if not self.coverage_path or not os.path.exists(self.coverage_path): return try: - tree = ET.parse(self.coverage_path) + # Create a secure XML parser that disables external entity processing + parser = ET.XMLParser() + parser.parser.DefaultHandler = lambda data: None + parser.parser.ExternalEntityRefHandler = lambda context, base, uri, notationName: False + parser.parser.EntityDeclHandler = lambda entityName, is_parameter_entity, value, base, systemId, notationName, publicId: False + + tree = ET.parse(self.coverage_path, parser) root = tree.getroot() self.coverage['lines_pct'] = float(root.get('line-rate', 0)) * 100 self.coverage['branches_pct'] = float(root.get('branch-rate', 0)) * 100 @@ -207,6 +221,44 @@ def _parse_condition_coverage(cond_str): return int(m.group(2)), int(m.group(3)) return 0, 0 + @staticmethod + def _sanitize_xml_attribute_value(value): + """Sanitize XML attribute value to prevent XPath injection.""" + if not value: + return "" + # Remove potentially dangerous characters that could be used in XPath injection + # Keep only alphanumeric, dash, underscore, and dot characters + sanitized = re.sub(r'[^a-zA-Z0-9\-_\.]', '', str(value)) + return sanitized + + @staticmethod + def _validate_output_path(output_path): + """Validate output path to prevent directory traversal attacks.""" + if not output_path: + raise ValueError("Output path cannot be empty") + + # Normalize the path to resolve any .. or . components + normalized_path = os.path.normpath(output_path) + + # Get absolute paths for comparison + abs_output_path = os.path.abspath(normalized_path) + current_dir = os.path.abspath(os.getcwd()) + + # Ensure the output file will be created in or under the current directory + if not abs_output_path.startswith(current_dir + os.sep) and abs_output_path != current_dir: + # Allow files in current directory or subdirectories only + raise ValueError(f"Invalid output path: {output_path}. Path traversal detected.") + + # Additional check: ensure no directory traversal patterns + if '..' in normalized_path or normalized_path.startswith('/'): + raise ValueError(f"Invalid output path: {output_path}. Path traversal patterns detected.") + + # Ensure it's an HTML file + if not normalized_path.lower().endswith('.html'): + raise ValueError(f"Output path must be an HTML file: {output_path}") + + return normalized_path + @staticmethod def _esc(text): if text is None: @@ -232,6 +284,9 @@ def _format_duration_display(self, seconds): return f"{h}h {m}m" def generate_html(self, output_path): + # Validate output path to prevent directory traversal attacks + validated_output_path = self._validate_output_path(output_path) + pass_rate = (self.results['passed'] / self.results['total'] * 100) if self.results['total'] > 0 else 0 by_file = {} @@ -249,9 +304,9 @@ def generate_html(self, output_path): html += self._html_scripts() html += "" - with open(output_path, 'w', encoding='utf-8') as f: + with open(validated_output_path, 'w', encoding='utf-8') as f: f.write(html) - return output_path + return validated_output_path def _html_head(self): return """