Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions Scripts/generate_test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def __init__(self, trx_path, coverage_path=None):
self.file_coverage = []

def parse_trx(self):
tree = ET.parse(self.trx_path)
# Create a secure XML parser that disables external entity processing
parser = ET.XMLParser()
parser.parser.DefaultHandler = lambda data: None
parser.parser.ExternalEntityRefHandler = lambda context, base, uri, notationName: False
parser.parser.EntityDeclHandler = lambda entityName, is_parameter_entity, value, base, systemId, notationName, publicId: False

tree = ET.parse(self.trx_path, parser)
root = tree.getroot()
ns = {'t': 'http://microsoft.com/schemas/VisualStudio/TeamTest/2010'}

Expand Down Expand Up @@ -72,7 +78,9 @@ def parse_trx(self):
duration_str = result.get('duration', '0')
duration = self._parse_duration(duration_str)

test_def = root.find(f".//t:UnitTest[@id='{test_id}']/t:TestMethod", ns)
# Sanitize test_id to prevent XPath injection
sanitized_test_id = self._sanitize_xml_attribute_value(test_id)
test_def = root.find(f".//t:UnitTest[@id='{sanitized_test_id}']/t:TestMethod", ns)
class_name = test_def.get('className', '') if test_def is not None else ''

parts = class_name.split(',')[0].rsplit('.', 1)
Expand Down Expand Up @@ -113,7 +121,13 @@ def parse_coverage(self):
if not self.coverage_path or not os.path.exists(self.coverage_path):
return
try:
tree = ET.parse(self.coverage_path)
# Create a secure XML parser that disables external entity processing
parser = ET.XMLParser()
parser.parser.DefaultHandler = lambda data: None
parser.parser.ExternalEntityRefHandler = lambda context, base, uri, notationName: False
parser.parser.EntityDeclHandler = lambda entityName, is_parameter_entity, value, base, systemId, notationName, publicId: False

tree = ET.parse(self.coverage_path, parser)
root = tree.getroot()
self.coverage['lines_pct'] = float(root.get('line-rate', 0)) * 100
self.coverage['branches_pct'] = float(root.get('branch-rate', 0)) * 100
Expand Down Expand Up @@ -207,6 +221,44 @@ def _parse_condition_coverage(cond_str):
return int(m.group(2)), int(m.group(3))
return 0, 0

@staticmethod
def _sanitize_xml_attribute_value(value):
"""Sanitize XML attribute value to prevent XPath injection."""
if not value:
return ""
# Remove potentially dangerous characters that could be used in XPath injection
# Keep only alphanumeric, dash, underscore, and dot characters
sanitized = re.sub(r'[^a-zA-Z0-9\-_\.]', '', str(value))
return sanitized

@staticmethod
def _validate_output_path(output_path):
"""Validate output path to prevent directory traversal attacks."""
if not output_path:
raise ValueError("Output path cannot be empty")

# Normalize the path to resolve any .. or . components
normalized_path = os.path.normpath(output_path)

# Get absolute paths for comparison
abs_output_path = os.path.abspath(normalized_path)
current_dir = os.path.abspath(os.getcwd())

# Ensure the output file will be created in or under the current directory
if not abs_output_path.startswith(current_dir + os.sep) and abs_output_path != current_dir:
# Allow files in current directory or subdirectories only
raise ValueError(f"Invalid output path: {output_path}. Path traversal detected.")

# Additional check: ensure no directory traversal patterns
if '..' in normalized_path or normalized_path.startswith('/'):
raise ValueError(f"Invalid output path: {output_path}. Path traversal patterns detected.")

# Ensure it's an HTML file
if not normalized_path.lower().endswith('.html'):
raise ValueError(f"Output path must be an HTML file: {output_path}")

return normalized_path

@staticmethod
def _esc(text):
if text is None:
Expand All @@ -232,6 +284,9 @@ def _format_duration_display(self, seconds):
return f"{h}h {m}m"

def generate_html(self, output_path):
# Validate output path to prevent directory traversal attacks
validated_output_path = self._validate_output_path(output_path)

pass_rate = (self.results['passed'] / self.results['total'] * 100) if self.results['total'] > 0 else 0

by_file = {}
Expand All @@ -249,9 +304,9 @@ def generate_html(self, output_path):
html += self._html_scripts()
html += "</div></body></html>"

with open(output_path, 'w', encoding='utf-8') as f:
with open(validated_output_path, 'w', encoding='utf-8') as f:
f.write(html)
return output_path
return validated_output_path

def _html_head(self):
return """<!DOCTYPE html>
Expand Down
Loading