diff --git a/.dockerignore b/.dockerignore index f17bf4d6..9e7e9fa5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,4 @@ +.coverage .git .github .gitignore @@ -16,10 +17,9 @@ netsecgame.egg-info/ notebooks/ NetSecGameAgents/ site/ -tests/ trajectories/ readme_images/ tests/ *trajectories*.json -README.md +README*.md mkdocs.yml \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 591e5f55..76852c36 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,26 +1,20 @@ # Use an official Python 3.12 runtime as a parent image -FROM python:3.12.10-slim +FROM python:3.12.12-slim-bookworm # Set the working directory in the container ENV DESTINATION_DIR=/netsecgame +WORKDIR ${DESTINATION_DIR} +# Copy the source code FIRST so pip has access to pyproject.toml +COPY . ${DESTINATION_DIR}/ -# Install system dependencies +# The "Single Layer" Trick: Install tools, build app, purge tools RUN apt-get update && \ - apt-get install -y --no-install-recommends \ - git \ - build-essential \ - && rm -rf /var/lib/apt/lists/* -RUN pip install --upgrade pip - -COPY . ${DESTINATION_DIR}/ - -# Set the working directory in the container -WORKDIR ${DESTINATION_DIR} - -# Install any necessary Python dependencies -# If a requirements.txt file is in the repository -RUN if [ -f pyproject.toml ]; then pip install .[server] ; fi + apt-get install -y --no-install-recommends build-essential && \ + pip install --no-cache-dir --upgrade pip && \ + if [ -f pyproject.toml ]; then pip install --no-cache-dir .[server] ; fi && \ + apt-get purge -y --auto-remove build-essential && \ + rm -rf /var/lib/apt/lists/* ARG GAME_MODULE="netsecgame.game.worlds.NetSecGame" # Pass the build argument to an environment variable so CMD can use it @@ -29,8 +23,8 @@ ENV ENV_GAME_MODULE=$GAME_MODULE # Expose the port the coordinator will run on EXPOSE 9000 -# Run the Python script when the container launches (with default arguments --task_config=netsecenv_conf.yaml --game_port=9000 --game_host=0.0.0.0) +# Run the Python script when the container launches ENTRYPOINT ["sh", "-c", "exec python3 -m ${ENV_GAME_MODULE} --task_config=netsecenv_conf.yaml --game_port=9000 --game_host=0.0.0.0 \"$@\"", "--"] # Default command arguments (can be overridden at runtime) -CMD ["--debug_level=INFO"] +CMD ["--debug_level=INFO"] \ No newline at end of file diff --git a/NetSecGameAgents b/NetSecGameAgents index 085692da..cb7a74f8 160000 --- a/NetSecGameAgents +++ b/NetSecGameAgents @@ -1 +1 @@ -Subproject commit 085692da0d85635bfa7343d8900f8621bbd132e9 +Subproject commit cb7a74f8c6a64ac9cbf2df0a5acfff7b73b980aa diff --git a/examples/example_task_configuration.yaml b/examples/example_task_configuration.yaml index 7522856e..e043ab95 100644 --- a/examples/example_task_configuration.yaml +++ b/examples/example_task_configuration.yaml @@ -10,7 +10,6 @@ coordinator: max_steps: 50 goal: description: "Exfiltrate data from Samba server to remote C&C server (213.47.23.195)." - is_any_part_of_goal_random: True known_networks: [] known_hosts: [] controlled_hosts: [] @@ -20,7 +19,7 @@ coordinator: start_position: # Defined starting position of the attacker known_networks: [] known_hosts: [] - controlled_hosts: [213.47.23.195, random] # + controlled_hosts: [213.47.23.195, 192.168.1.1] # known_services: {} known_data: {} known_blocks: {} @@ -28,7 +27,6 @@ coordinator: Defender: goal: description: "Block all attackers" - is_any_part_of_goal_random: False known_networks: [] known_hosts: [] controlled_hosts: [] @@ -48,7 +46,7 @@ coordinator: env: scenario: 'two_networks_tiny' # use the smallest topology for this example use_global_defender: False # Do not use global SIEM Defender - use_dynamic_addresses: False # Do not randomize IP addresses + use_dynamic_addresses: True # Do not randomize IP addresses use_firewall: True # Use firewall save_trajectories: False # Do not store trajectories required_players: 1 diff --git a/netsecgame/agents/base_agent.py b/netsecgame/agents/base_agent.py index 1578f870..95142a45 100644 --- a/netsecgame/agents/base_agent.py +++ b/netsecgame/agents/base_agent.py @@ -3,7 +3,8 @@ import logging import socket import json -from abc import ABC +from abc import ABC +from typing import Optional, Tuple, Dict, Any from netsecgame.game_components import Action, GameState, Observation, ActionType, GameStatus, AgentInfo, ProtocolConfig, AgentRole @@ -55,7 +56,7 @@ def role(self)->str: def logger(self)->logging.Logger: return self._logger - def make_step(self, action: Action) -> Observation | None: + def make_step(self, action: Action) -> Optional[Observation]: """ Executes a single step in the environment by sending the agent's action to the server and receiving the resulting observation. @@ -75,7 +76,7 @@ def make_step(self, action: Action) -> Observation | None: else: return None - def communicate(self, data:Action)-> tuple: + def communicate(self, data:Action)-> Tuple[GameStatus, Dict[str, Any], Optional[str]]: """ Exchanges data with the server and returns the server's response. This method sends an `Action` object to the server and waits for a response. @@ -102,7 +103,7 @@ def _send_data(socket, msg:str)->None: self._logger.error(f'Exception in _send_data(): {e}') raise e - def _receive_data(socket)->tuple: + def _receive_data(socket)->Tuple[GameStatus, Dict[str, Any], Optional[str]]: """ Receive data from server """ @@ -138,7 +139,7 @@ def _receive_data(socket)->tuple: _send_data(self._socket, data) return _receive_data(self._socket) - def register(self)->Observation | None: + def register(self)->Optional[Observation]: """ Method for registering agent to the game server. Classname is used as agent name and the role is based on the 'role' argument. @@ -162,19 +163,28 @@ def register(self)->Observation | None: except Exception as e: self._logger.error(f'Exception in register(): {e}') - def request_game_reset(self, request_trajectory=False, randomize_topology=True, randomize_topology_seed=None) -> Observation|None: - """ - Requests a game reset from the server. Optionally requests a trajectory and/or topology randomization. + def request_game_reset( + self, + request_trajectory: bool = False, + randomize_topology: bool = False, + seed: Optional[int] = None + ) -> Optional[Observation]: + """Request a game reset from the server. Args: - request_trajectory (bool): If True, requests the server to provide a trajectory of the last episode. - randomize_topology (bool): If True, requests the server to randomize the network topology for the next episode. Defaults to True. - randomize_topology_seed (int): If provided, requests the server to use this seed for randomizing the network topology. Defaults to None. + request_trajectory: If True, requests the server to provide a + trajectory of the last episode. + randomize_topology: If True, requests the server to randomize the + network topology for the next episode. Defaults to False. + seed: If provided, requests the server to use this seed for + randomizing the environment. Required if randomize_topology is True. Returns: - Observation: The initial observation after the reset if successful, None otherwise. + The initial observation after the reset if successful, None otherwise. """ + if seed is None and randomize_topology: + raise ValueError("Topology randomization without seed is not supported.") self._logger.debug("Requesting game reset") - status, observation_dict, message = self.communicate(Action(ActionType.ResetGame, parameters={"request_trajectory": request_trajectory, "randomize_topology": randomize_topology})) - if status: + status, observation_dict, message = self.communicate(Action(ActionType.ResetGame, parameters={"request_trajectory": request_trajectory, "randomize_topology": randomize_topology, "seed": seed})) + if status is GameStatus.RESET_DONE: self._logger.debug('\tReset successful') return Observation(GameState.from_dict(observation_dict["state"]), observation_dict["reward"], observation_dict["end"], message) else: diff --git a/netsecgame/game/config_parser.py b/netsecgame/game/config_parser.py index 9683cfea..efc1f126 100644 --- a/netsecgame/game/config_parser.py +++ b/netsecgame/game/config_parser.py @@ -3,13 +3,12 @@ # Author: Ondrej Lukas, ondrej.lukas@aic.fel.cvut.cz import yaml -# This is used so the agent can see the environment and game components -import importlib -from netsecgame.game_components import IP, Data, Network, Service import netaddr import logging from random import randint from typing import Optional +from netsecgame.game_components import IP, Data, Network, Service +from netsecgame.game.scenarios import SCENARIO_REGISTRY class ConfigParser(): """ @@ -386,25 +385,15 @@ def get_scenario(self): """ Get the scenario config objects based on the configuration. Only import objects that are selected via importlib. """ - allowed_names = { - "scenario1" : "netsecgame.game.scenarios.scenario_configuration", - "scenario1_small" : "netsecgame.game.scenarios.smaller_scenario_configuration", - "scenario1_tiny" : "netsecgame.game.scenarios.tiny_scenario_configuration", - "one_network": "netsecgame.game.scenarios.one_net", - "three_net_scenario": "netsecgame.game.scenarios.three_net_scenario", - "two_networks": "netsecgame.game.scenarios.two_nets", # same as scenario1 - "two_networks_small": "netsecgame.game.scenarios.two_nets_small", # same as scenario1_small - "two_networks_tiny": "netsecgame.game.scenarios.two_nets_tiny", # same as scenario1_small - - } scenario_name = self.config['env']['scenario'] # make sure to validate the input - if scenario_name not in allowed_names: - raise ValueError(f"Unsupported scenario: {scenario_name}") + if scenario_name not in SCENARIO_REGISTRY: + raise ValueError( + f"Unsupported scenario: {scenario_name}. " + f"Available scenarios: {list(SCENARIO_REGISTRY.keys())}" + ) - # import the correct module - module = importlib.import_module(allowed_names[scenario_name]) - return module.configuration_objects + return SCENARIO_REGISTRY[scenario_name] def get_seed(self, whom): """ @@ -419,11 +408,13 @@ def get_randomize_goal_every_episode(self, default_value: bool = False) -> bool: """ Get if the randomization should be done only once or at the beginning of every episode """ + # TODO Remove in future try: randomize_goal_every_episode = self.config["coordinator"]["agents"]["attackers"]["goal"]["is_any_part_of_goal_random"] except KeyError: # Option is not in the configuration - default to FALSE randomize_goal_every_episode = default_value + raise DeprecationWarning("This function is deprecated.") return randomize_goal_every_episode def get_use_firewall(self, default_value: bool = False)->bool: diff --git a/netsecgame/game/coordinator.py b/netsecgame/game/coordinator.py index faef03a6..5860d938 100644 --- a/netsecgame/game/coordinator.py +++ b/netsecgame/game/coordinator.py @@ -104,6 +104,7 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por # reset request per agent_addr (bool) self._reset_requests = {} self._randomize_topology_requests = {} + self._reset_seed_requests = {} self._agent_status = {} self._episode_ends = {} self._agent_observations = {} @@ -447,8 +448,14 @@ async def _process_reset_game_action(self, agent_addr: tuple, reset_action:Actio async with self._reset_lock: # add reset request for this agent self._reset_requests[agent_addr] = True + # get the seed for the reset (default None - no change to the rng) + self._reset_seed_requests[agent_addr] = reset_action.parameters.get("seed", None) + self.logger.debug(f"Agent {agent_addr} requested reset with seed {self._reset_seed_requests[agent_addr]}") + # record topology randomization request + # - ONLY consider agents that submitted seed # register if the agent wants to randomize the topology - self._randomize_topology_requests[agent_addr] = reset_action.parameters.get("randomize_topology", True) + if self._reset_seed_requests[agent_addr] is not None: + self._randomize_topology_requests[agent_addr] = reset_action.parameters.get("randomize_topology", True) if all(self._reset_requests.values()): # all agents want reset - reset the world self.logger.debug(f"All agents requested reset, setting the event") @@ -559,36 +566,37 @@ async def _assign_rewards_episode_end(self): asyncio.create_task(self.shutdown_flag.wait())], return_when=asyncio.FIRST_COMPLETED, ) - # Check if shutdown_flag was set + # Check if shutdown_flag was set if self.shutdown_flag.is_set(): self.logger.debug("\tExiting reward assignment task.") break - self.logger.info("Episode finished. Assigning final rewards to agents.") - async with self._agents_lock: - attackers = [a for a,(_, a_role) in self.agents.items() if a_role.lower() == "attacker"] - defenders = [a for a,(_, a_role) in self.agents.items() if a_role.lower() == "defender"] - successful_attack = False - # award attackers - for agent in attackers: - self.logger.debug(f"Processing reward for agent {agent}") - if self._agent_status[agent] is AgentStatus.Success: - self._agent_rewards[agent] += self._rewards["success"] - successful_attack = True - else: - self._agent_rewards[agent] += self._rewards["fail"] - - # award defenders - for agent in defenders: - self.logger.debug(f"Processing reward for agent {agent}") - if not successful_attack: - self._agent_rewards[agent] += self._rewards["success"] - self._agent_status[agent] = AgentStatus.Success - else: - self._agent_rewards[agent] += self._rewards["fail"] - self._agent_status[agent] = AgentStatus.Fail - # dicrease the reward for false positives - self.logger.debug(f"Processing false positives for agent {agent}: {self._agent_false_positives[agent]}") - self._agent_rewards[agent] += self._agent_false_positives[agent] * self._rewards["false_positive"] + if len(self.agents) > 0: + self.logger.info("Episode finished. Assigning final rewards to agents.") + async with self._agents_lock: + attackers = [a for a,(_, a_role) in self.agents.items() if a_role.lower() == "attacker"] + defenders = [a for a,(_, a_role) in self.agents.items() if a_role.lower() == "defender"] + successful_attack = False + # award attackers + for agent in attackers: + self.logger.debug(f"Processing reward for agent {agent}") + if self._agent_status[agent] is AgentStatus.Success: + self._agent_rewards[agent] += self._rewards["success"] + successful_attack = True + else: + self._agent_rewards[agent] += self._rewards["fail"] + + # award defenders + for agent in defenders: + self.logger.debug(f"Processing reward for agent {agent}") + if not successful_attack: + self._agent_rewards[agent] += self._rewards["success"] + self._agent_status[agent] = AgentStatus.Success + else: + self._agent_rewards[agent] += self._rewards["fail"] + self._agent_status[agent] = AgentStatus.Fail + # dicrease the reward for false positives + self.logger.debug(f"Processing false positives for agent {agent}: {self._agent_false_positives[agent]}") + self._agent_rewards[agent] += self._agent_false_positives[agent] * self._rewards["false_positive"] # clear the episode end event self._episode_end_event.clear() # notify all waiting agents @@ -596,6 +604,53 @@ async def _assign_rewards_episode_end(self): self._episode_rewards_condition.notify_all() self.logger.info("\tReward assignment task stopped.") + + async def _handle_invalid_reset(self, error_msg:str): + """Task that handles invalid reset""" + self.logger.error(error_msg) + for agent in self.agents: + async with self._agents_lock: + output_message_dict = { + "to_agent": agent, + "status": str(GameStatus.BAD_REQUEST), + "observation": None, + "message": {"message": error_msg}, + } + response_msg_json = convert_msg_dict_to_json(output_message_dict) + await self._agent_response_queues[agent].put(response_msg_json) + self.shutdown_flag.set() + + + async def _handle_valid_reset(self, seed: Optional[int], topology_change: Optional[bool]): + """Task that handles valid reset""" + self.logger.info(f"Resetting game to initial state with seed: {seed} and topology change: {topology_change}") + # reset the game + await self.reset(seed=seed, topology_change=topology_change) + for agent in self.agents: + if self.config_manager.get_store_trajectories(): + async with self._agents_lock: + self._store_trajectory_to_file(agent) + self.logger.debug(f"Resetting agent {agent}") + agent_role = self.agents[agent][1] + # reset the agent in the world + new_state, new_goal_state = await self.reset_agent(agent, agent_role, self._starting_positions_per_role[agent_role], self._win_conditions_per_role[agent_role]) + new_observation = Observation(new_state, 0, False, {}) + async with self._agents_lock: + self._agent_states[agent] = new_state + self._agent_goal_states[agent] = new_goal_state + self._agent_observations[agent] = new_observation + self._episode_ends[agent] = False + self._reset_requests[agent] = False + self._randomize_topology_requests.pop(agent, None) + self._reset_seed_requests.pop(agent, None) + self._agent_rewards[agent] = 0 + self._agent_steps[agent] = 0 + self._agent_false_positives[agent] = 0 + if self.agents[agent][1].lower() == "attacker": + self._agent_status[agent] = AgentStatus.PlayingWithTimeout + else: + self._agent_status[agent] = AgentStatus.Playing + async def _reset_game(self): """Task that waits for all agents to request resets""" self.logger.debug("Starting task for game reset handelling.") @@ -610,32 +665,41 @@ async def _reset_game(self): if self.shutdown_flag.is_set(): self.logger.debug("\tExiting reset_game task.") break - # wait until episode is finished by all agents - self.logger.info("Resetting game to initial state.") - await self.reset() - for agent in self.agents: - if self.config_manager.get_store_trajectories(): - async with self._agents_lock: - self._store_trajectory_to_file(agent) - self.logger.debug(f"Resetting agent {agent}") - agent_role = self.agents[agent][1] - # reset the agent in the world - new_state, new_goal_state = await self.reset_agent(agent, agent_role, self._starting_positions_per_role[agent_role], self._win_conditions_per_role[agent_role]) - new_observation = Observation(new_state, 0, False, {}) - async with self._agents_lock: - self._agent_states[agent] = new_state - self._agent_goal_states[agent] = new_goal_state - self._agent_observations[agent] = new_observation - self._episode_ends[agent] = False - self._reset_requests[agent] = False - self._randomize_topology_requests[agent] = False - self._agent_rewards[agent] = 0 - self._agent_steps[agent] = 0 - self._agent_false_positives[agent] = 0 - if self.agents[agent][1].lower() == "attacker": - self._agent_status[agent] = AgentStatus.PlayingWithTimeout - else: - self._agent_status[agent] = AgentStatus.Playing + if len(self.agents) > 0: + # verify that all agents agreed on the seed (or sent None) + valid_seeding = False + valid_topology_change = False + non_none_seeds = [seed for seed in self._reset_seed_requests.values() if seed is not None] + if len(non_none_seeds) == 0: # no agent wants to change the seed + seed = None + valid_seeding = True + elif len(set(non_none_seeds)) == 1: # all agents agree on the seed + seed = non_none_seeds[0] + valid_seeding = True + else: # agents disagree on the seed + seed = None + # verify that all agents agreed on the topology change (or sent None) + valid_seed_agents = [agent for agent in self.agents if self._reset_seed_requests[agent] is not None] + valid_topology_requests = [self._randomize_topology_requests[agent] for agent in valid_seed_agents] + if len(set(valid_topology_requests)) == 1: # all valid agents agree on the topology change + valid_topology_change = True + topology_change = valid_topology_requests[0] + else: # agents disagree on the topology change + valid_topology_change = False + topology_change = None + + if valid_seeding and valid_topology_change: + await self._handle_valid_reset(seed, topology_change) + self._reset_event.clear() + # notify all waiting agents + async with self._reset_done_condition: + self._reset_done_condition.notify_all() + elif not valid_seeding: + await self._handle_invalid_reset("Agents disagree on the seed. Undefined state. Stopping the game") + self._reset_event.clear() + elif not valid_topology_change: + await self._handle_invalid_reset("Agents disagree on the topology change. Undefined state. Stopping the game") + self._reset_event.clear() self._reset_event.clear() # notify all waiting agents async with self._reset_done_condition: @@ -700,9 +764,11 @@ async def _remove_agent_from_game(self, agent_addr): agent_info["topology_reset_request"] = self._randomize_topology_requests.pop(agent_addr, False) # remove agent from reset requests agent_info["reset_request"] = self._reset_requests.pop(agent_addr) + agent_info["reset_seed"] = self._reset_seed_requests.pop(agent_addr, None) # check if this agent was not preventing reset if all(self._reset_requests.values()): - self._reset_event.set() + if len(self.agents) > 0: + self._reset_event.set() agent_info["episode_end"] = self._episode_ends.pop(agent_addr) #check if this agent was not preventing episode end if all(self._episode_ends.values()): @@ -724,10 +790,13 @@ async def step(self, agent_id:tuple, agent_state:GameState, action:Action): """ raise NotImplementedError - async def reset(self)->bool: + async def reset(self, seed:Optional[int]=None)->bool: """ Domain specific method of the environment. Creates the initial state of the agent. Must be implemented by the domain specific environment. + + Args: + seed (int, optional): Seed for the random number generator. Defaults to None. """ raise NotImplementedError diff --git a/netsecgame/game/scenarios/__init__.py b/netsecgame/game/scenarios/__init__.py index e69de29b..459ebed7 100644 --- a/netsecgame/game/scenarios/__init__.py +++ b/netsecgame/game/scenarios/__init__.py @@ -0,0 +1,18 @@ +from . import ( + scenario_configuration, + smaller_scenario_configuration, + tiny_scenario_configuration, + one_net, + three_net_scenario, + two_nets, +) + +# Static Registry +SCENARIO_REGISTRY = { + "scenario1": scenario_configuration.configuration_objects, + "scenario1_small": smaller_scenario_configuration.configuration_objects, + "scenario1_tiny": tiny_scenario_configuration.configuration_objects, + "one_network": one_net.configuration_objects, + "three_net_scenario": three_net_scenario.configuration_objects, + "two_networks": two_nets.configuration_objects, +} \ No newline at end of file diff --git a/netsecgame/game/worlds/NetSecGame.py b/netsecgame/game/worlds/NetSecGame.py index 94b4c9d1..684b6039 100644 --- a/netsecgame/game/worlds/NetSecGame.py +++ b/netsecgame/game/worlds/NetSecGame.py @@ -9,7 +9,7 @@ import json from faker import Faker from pathlib import Path -from typing import Iterable, Any +from typing import Iterable, Any, Set, Dict, Optional from collections import defaultdict from netsecgame.game_components import GameState, Action, ActionType, IP, Network, Data, Service, AgentRole @@ -18,6 +18,18 @@ from netsecgame.utils.utils import get_logging_level +def state_parts_deep_copy(state:GameState)->tuple: + """ + Deep copies the relevant parts of the GameState. + """ + new_nets = copy.deepcopy(state.known_networks) + new_known_h = copy.deepcopy(state.known_hosts) + new_controlled_h = copy.deepcopy(state.controlled_hosts) + new_services = copy.deepcopy(state.known_services) + new_data = copy.deepcopy(state.known_data) + new_blocked = copy.deepcopy(state.known_blocks) + return new_nets, new_known_h, new_controlled_h, new_services, new_data, new_blocked + class NetSecGame(GameCoordinator): def __init__(self, game_host, game_port, task_config:str, seed=None): @@ -39,13 +51,28 @@ def __init__(self, game_host, game_port, task_config:str, seed=None): self._network_mapping = {} self._ip_mapping = {} - - np.random.seed(seed) - random.seed(seed) + # Set the random seed + self._set_random_seed(seed) + + def _set_random_seed(self, seed)->None: + """ + Sets the random seed for the environment. + + Args: + seed (int): The random seed to set. + """ self._seed = seed - self.logger.info(f'Setting env seed to {seed}') + if seed is not None: + np.random.seed(seed) + random.seed(seed) + # if faker is used, seed it too + if hasattr(self, '_faker_object'): + Faker.seed(seed) + self.logger.info(f'Setting env seed to {seed}') + else: + self.logger.warning("No seed provided, using random seed") - def _initialize(self): + def _initialize(self)->None: """ Initializes the NetSecGame environment. @@ -72,7 +99,7 @@ def _initialize(self): else: self.logger.error("CYST configuration not loaded, cannot initialize the environment!") - def _get_hosts_from_view(self, view_hosts:Iterable, allowed_hosts=None)->set[IP]: + def _get_hosts_from_view(self, view_hosts:Iterable, allowed_hosts=None)->Set[IP]: """ Parses view and translates all keywords. Produces set of controlled host (IP) Args: @@ -106,7 +133,7 @@ def _get_hosts_from_view(self, view_hosts:Iterable, allowed_hosts=None)->set[IP] self.logger.error(f"Unsupported value encountered in view_hosts: {host}") return hosts - def _get_services_from_view(self, view_known_services:dict)->dict: + def _get_services_from_view(self, view_known_services:dict)->Dict[IP, Set[Service]]: """ Parses view and translates all keywords. Produces dict of known services {IP: set(Service)} @@ -145,7 +172,7 @@ def _get_services_from_view(self, view_known_services:dict)->dict: # re-map all IPs based on current mapping in self._ip_mapping return known_services - def _get_data_from_view(self, view_known_data:dict, keyword_scope:str="host", exclude_types=["log"])->dict: + def _get_data_from_view(self, view_known_data:dict, keyword_scope:str="host", exclude_types=["log"])->Dict[IP, Set[Data]]: """ Parses view and translates all keywords. Produces dict of known data {IP: set(Data)} @@ -196,7 +223,7 @@ def _get_data_from_view(self, view_known_data:dict, keyword_scope:str="host", ex # re-map all IPs based on current mapping in self._ip_mapping return known_data - def _get_networks_from_view(self, view_known_networks:Iterable)->set[Network]: + def _get_networks_from_view(self, view_known_networks:Iterable)->Set[Network]: """ Parses view and translates all keywords. Produces set of known networks (Network). Args: @@ -477,13 +504,18 @@ def process_firewall()->dict: self.logger.info(f"\tintitial self._ip_mapping: {self._ip_mapping}") self.logger.info("CYST configuration processed successfully") - def _dynamic_ip_change(self, max_attempts:int=10)-> None: + def _dynamic_ip_change(self, max_attempts:int=10, seed=None)-> None: """ Changes the IP and network addresses in the environment + Args: + max_attempts (int, optional): Maximum number of attempts to find a valid mapping. Defaults to 10. + seed (int, optional): Seed for random number generator. Defaults to None. + Returns: + None """ self.logger.info("Changing IP and Network addresses in the environment") # find a new IP and network mapping - mapping_nets, mapping_ips = self._create_new_network_mapping(max_attempts) + mapping_nets, mapping_ips = self._create_new_network_mapping(max_attempts, seed=seed) # update ALL data structure in the environment with the new mappings @@ -600,71 +632,138 @@ def replacer(match): for ip, mapping in self._ip_mapping.items(): self._ip_mapping[ip] = mapping_ips[mapping] self.logger.debug(f"self._ip_mapping: {self._ip_mapping}") - - def _create_new_network_mapping(self, max_attempts:int=10)->tuple: - """ Method that generates random IP and Network addreses - while following the topology loaded in the environment. - All internal data structures are updated with the newly generated addresses.""" + + def _create_new_network_mapping(self, max_attempts: int = 10, seed=None) -> tuple[Dict[Network, Network], Dict[IP, IP]]: + """ + Generates new network addresses (preserving relative distance between networks) + and maps host IPs by preserving their relative offset within the subnet. + """ + #self.logger.info(f"Generating new network and IP address mapping with seed {seed} (max attempts: {max_attempts})") + + # # setup random generators + # if seed is not None: + # fake = Faker() + # fake.seed_instance(seed) + # rng = random.Random(seed) + # else: + # fake = self._faker_object + # rng = random fake = self._faker_object + rng = random + + mapping_nets = {} mapping_ips = {} - # generate mapping for networks + + # sort networks for deterministic processing (order should be deterministic in Python 3.7+ but we enforce it) + sorted_networks = sorted(self._networks.keys(), key=str) + + # generate network mappings (Preserves distance between private networks) private_nets = [] - for net in self._networks.keys(): + for net in sorted_networks: if netaddr.IPNetwork(str(net)).ip.is_private(): private_nets.append(net) else: mapping_nets[net] = Network(fake.ipv4_public(), net.mask) - # for private networks, we want to keep the distances among them - private_nets_sorted = sorted(private_nets) - valid_valid_network_mapping = False + # Private Network logic + valid_network_mapping = False counter_iter = 0 - while not valid_valid_network_mapping: + + while not valid_network_mapping: try: - # find the new lowest networks - new_base = netaddr.IPNetwork(f"{fake.ipv4_private()}/{private_nets_sorted[0].mask}") - # store its new mapping - mapping_nets[private_nets[0]] = Network(str(new_base.network), private_nets_sorted[0].mask) - base = netaddr.IPNetwork(str(private_nets_sorted[0])) - is_private_net_checks = [] - for i in range(1,len(private_nets_sorted)): - current = netaddr.IPNetwork(str(private_nets_sorted[i])) - # find the distance before mapping - diff_ip = current.ip - base.ip - # find the new mapping - new_net_addr = netaddr.IPNetwork(str(mapping_nets[private_nets_sorted[0]])).ip + diff_ip - # evaluate if its still a private network - is_private_net_checks.append(new_net_addr.is_private()) - # store the new mapping - mapping_nets[private_nets_sorted[i]] = Network(str(new_net_addr), private_nets_sorted[i].mask) - if False not in is_private_net_checks: # verify that ALL new networks are still in the private ranges - valid_valid_network_mapping = True - except IndexError as e: - self.logger.info(f"Dynamic address sampling failed, re-trying. {e}") - counter_iter +=1 + # Pick a random start for the first private network + new_base = netaddr.IPNetwork(f"{fake.ipv4_private()}/{private_nets[0].mask}") + mapping_nets[private_nets[0]] = Network(str(new_base.network), private_nets[0].mask) + + base_orig = netaddr.IPNetwork(str(private_nets[0])) + checks = [] + + for i in range(1, len(private_nets)): + current_orig = netaddr.IPNetwork(str(private_nets[i])) + # Calculate distance between Network A and Network B + diff = current_orig.ip - base_orig.ip + + # Apply distance to new base + new_net_ip = netaddr.IPNetwork(str(mapping_nets[private_nets[0]])).ip + diff + + checks.append(new_net_ip.is_private()) + mapping_nets[private_nets[i]] = Network(str(new_net_ip), private_nets[i].mask) + + if all(checks): + valid_network_mapping = True + except IndexError: + counter_iter += 1 if counter_iter > max_attempts: - self.logger.error(f"Dynamic address failed more than {max_attempts} times - stopping.") + self.logger.error(f"Failed to generate valid network mapping in {max_attempts} attempts - exiting.") exit(-1) - # Invalid IP address boundary - self.logger.info(f"New network mapping:{mapping_nets}") - - # genereate mapping for ips: - for net,ips in self._networks.items(): - ip_list = list(netaddr.IPNetwork(str(mapping_nets[net])))[1:] - # remove broadcast and network ip from the list - random.shuffle(ip_list) - for i,ip in enumerate(ips): - mapping_ips[ip] = IP(str(ip_list[i])) - # Always add keywords 'random' and 'all_local' 'all_attackers' to the mapping - mapping_ips['random'] = 'random' - mapping_ips['all_local'] = 'all_local' - mapping_ips['all_attackers'] = 'all_attackers' - - self.logger.info(f"Mapping IPs done:{mapping_ips}") + + self.logger.info(f"New network mapping: {mapping_nets}") + + # 4. MAP IPS (Preserves distance/offset within subnet) + for net in sorted_networks: + if net not in mapping_nets: continue + + orig_net_obj = netaddr.IPNetwork(str(net)) + new_net_obj = netaddr.IPNetwork(str(mapping_nets[net])) + + # Prepare fallback pool (deterministic shuffle) just in case an offset fails + # We exclude .0 and .255 explicitly from the list + fallback_pool = list(new_net_obj)[1:-1] + rng.shuffle(fallback_pool) + + # Sort hosts for deterministic processing order + hosts = self._networks[net] + sorted_hosts = sorted(hosts, key=lambda x: repr(x)) + + for host in sorted_hosts: + try: + old_host_ip = netaddr.IPAddress(str(host)) + + # Calculate Offset: (Host IP) - (Network Address) + # e.g. 192.168.1.55 - 192.168.1.0 = 55 + offset = old_host_ip - orig_net_obj.network + + # Apply Offset to New Network + # e.g. 10.0.0.0 + 55 = 10.0.0.55 + new_host_ip = new_net_obj.network + offset + + # Verify validity: + # 1. Must be inside the new subnet (cidr check) + # 2. Must not be the Network Address (.0) or Broadcast (.255) + if (new_host_ip in new_net_obj and + new_host_ip != new_net_obj.network and + new_host_ip != new_net_obj.broadcast): + + mapping_ips[host] = IP(str(new_host_ip)) + + # Optimization: If this IP happens to be in our fallback pool, + # remove it so fallback logic doesn't re-assign it later. + # (Checking efficient sets is faster, but list remove is safe here for small subnets) + if new_host_ip in fallback_pool: + fallback_pool.remove(new_host_ip) + else: + raise ValueError("Offset calculated invalid IP") + + except (ValueError, TypeError, netaddr.AddrFormatError): + # Fallback Strategy: Assign next available random IP from the pool + # This handles edge cases or weird topology mismatches gracefully + if fallback_pool: + safe_ip = fallback_pool.pop(0) # Take first available from shuffled pool + mapping_ips[host] = IP(str(safe_ip)) + self.logger.warning(f"Offset failed for {host}, assigned fallback {safe_ip}") + else: + self.logger.error(f"Subnet exhausted for {net}") + + # Static mappings + mapping_ips['random'] = 'random' + mapping_ips['all_local'] = 'all_local' + mapping_ips['all_attackers'] = 'all_attackers' + + self.logger.info(f"Mapping IPs done: {mapping_ips}") return mapping_nets, mapping_ips - def _get_services_from_host(self, host_ip:str, controlled_hosts:set)-> set: + def _get_services_from_host(self, host_ip:str, controlled_hosts:set)-> Set[Service]: """ Returns set of Service tuples from given hostIP """ @@ -681,7 +780,7 @@ def _get_services_from_host(self, host_ip:str, controlled_hosts:set)-> set: self.logger.debug("\tServices not found because target IP does not exists.") return found_services - def _get_networks_from_host(self, host_ip)->set: + def _get_networks_from_host(self, host_ip)->Set[Network]: """ Returns set of IPs the host has access to """ @@ -691,7 +790,7 @@ def _get_networks_from_host(self, host_ip)->set: networks.add(net) return networks - def _get_data_in_host(self, host_ip:str, controlled_hosts:set)->set: + def _get_data_in_host(self, host_ip:str, controlled_hosts:set)->Set[Data]: """ Returns set of Data tuples from given host IP Check if the host is in the list of controlled hosts @@ -775,15 +874,6 @@ def _record_false_positive(self, src_ip:IP, dst_ip:IP, agent_id:tuple)-> None: else: self.logger.debug(f"False positive for blocking {src_host} -> {dst_host} caused by the system configuration.") - def _state_parts_deep_copy(self, current:GameState)->tuple: - next_nets = copy.deepcopy(current.known_networks) - next_known_h = copy.deepcopy(current.known_hosts) - next_controlled_h = copy.deepcopy(current.controlled_hosts) - next_services = copy.deepcopy(current.known_services) - next_data = copy.deepcopy(current.known_data) - next_blocked = copy.deepcopy(current.known_blocks) - return next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked - def _firewall_check(self, src_ip:IP, dst_ip:IP)->bool: """Checks if firewall allows connection from 'src_ip to ''dst_ip'""" try: @@ -796,7 +886,7 @@ def _execute_scan_network_action(self, current_state:GameState, action:Action, a """ Executes the ScanNetwork action in the environment """ - next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) + next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = state_parts_deep_copy(current_state) self.logger.debug(f"\t\tScanning {action.parameters['target_network']}") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts: new_ips = set() @@ -819,7 +909,7 @@ def _execute_find_services_action(self, current_state:GameState, action:Action, """ Executes the FindServices action in the environment """ - next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) + next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = state_parts_deep_copy(current_state) self.logger.debug(f"\t\tSearching for services in {action.parameters['target_host']}") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts: if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']): @@ -846,7 +936,7 @@ def _execute_find_data_action(self, current:GameState, action:Action, agent_id:t """ Executes the FindData action in the environment """ - next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current) + next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = state_parts_deep_copy(current) self.logger.debug(f"\t\tSearching for data in {action.parameters['target_host']}") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current.controlled_hosts: if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']): @@ -877,7 +967,7 @@ def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action """ Executes the ExfiltrateData action in the environment """ - next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) + next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = state_parts_deep_copy(current_state) self.logger.info(f"\t\tAttempting to Exfiltrate {action.parameters['data']} from {action.parameters['source_host']} to {action.parameters['target_host']}") # Is the target host controlled? if action.parameters["target_host"] in current_state.controlled_hosts: @@ -924,7 +1014,7 @@ def _execute_exploit_service_action(self, current_state:GameState, action:Action """ Executes the ExploitService action in the environment """ - next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) + next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = state_parts_deep_copy(current_state) # We don't check if the target is a known_host because it can be a blind attempt to attack self.logger.info(f"\t\tAttempting to ExploitService in '{action.parameters['target_host']}':'{action.parameters['target_service']}'") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts: @@ -974,7 +1064,7 @@ def _execute_block_ip_action(self, current_state:GameState, action:Action, agent - Add the rule to the FW list - Update the state """ - next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) + next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = state_parts_deep_copy(current_state) # Is the src in the controlled hosts? if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts: # Is the target in the controlled hosts? @@ -1033,7 +1123,7 @@ def _execute_block_ip_action(self, current_state:GameState, action:Action, agent self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _get_all_local_ips(self)->set: + def _get_all_local_ips(self)->Set[IP]: local_ips = set() for net, ips in self._networks.items(): if netaddr.IPNetwork(str(net)).ip.is_private(): @@ -1042,7 +1132,14 @@ def _get_all_local_ips(self)->set: self.logger.info(f"\t\t\tLocal ips: {local_ips}") return local_ips - def update_log_file(self, known_data:set, action, target_host:IP): + def update_log_file(self, known_data:set, action, target_host:IP)->None: + """ + Updates the log file in the target host. + Args: + known_data (set): Set of known data. + action (Action): Action to be recorded. + target_host (IP): Target host. + """ hostaname = self._ip_to_hostname[target_host] self.logger.debug(f"Updating log file in host {hostaname}") try: @@ -1077,21 +1174,19 @@ async def reset_agent(self, agent_id, agent_role:AgentRole, agent_initial_view:d goal_state = self._create_goal_state_from_view(agent_win_condition_view) return game_state, goal_state - async def reset(self)->bool: + async def reset(self, seed:Optional[int]=None, topology_change:Optional[bool]=None)->bool: """ Function to reset the state of the game and prepare for a new episode """ # write all steps in the episode replay buffer in the file self.logger.info('--- Reseting NSG Environment to its initial state ---') - # change IPs if needed - # This is done ONLY if it is (i) enabled in the task config and (ii) all agents requested it - if self.config_manager.get_use_dynamic_ips(): - if all(self._randomize_topology_requests.values()): - self.logger.info("All agents requested reset with randomized topology.") - self._dynamic_ip_change() - else: - self.logger.info("Not all agents requested a topology randomization. Keeping the current one.") + if seed is not None: + self._set_random_seed(seed) + + if self.config_manager.get_use_dynamic_ips(): #topology change is allowed + if topology_change: # agents agree on topology change + self._dynamic_ip_change(seed=seed) # reset self._data to orignal state self._data = copy.deepcopy(self._data_original) # reset self._data_content to orignal state diff --git a/netsecgame/game_components.py b/netsecgame/game_components.py index 2dd9eddd..b0e0a661 100755 --- a/netsecgame/game_components.py +++ b/netsecgame/game_components.py @@ -406,6 +406,10 @@ def as_dict(self) -> Dict[str, Any]: params[k] = v elif isinstance(v, bool): # Handle boolean values params[k] = v + elif isinstance(v, int): # Handle integer values + params[k] = v + elif v is None: + params[k] = None else: params[k] = str(v) return {"action_type": str(self.action_type), "parameters": params} @@ -463,6 +467,13 @@ def from_dict(cls, data_dict: Dict[str, Any]) -> Action: params[k] = v else: params[k] = ast.literal_eval(v) + case "seed": + if isinstance(v, int): + params[k] = v + elif v is None or v == "None": + params[k] = None + else: + raise ValueError(f"Unsupported value in {k}: {v}") case _: raise ValueError(f"Unsupported value in {k}: {v}") return cls(action_type=action_type, parameters=params) @@ -920,12 +931,4 @@ class ProtocolConfig: BUFFER_SIZE (int): Buffer size for messages. """ END_OF_MESSAGE: bytes = b"EOF" - BUFFER_SIZE: int = 8192 - -if __name__ == "__main__": - role_str = AgentRole.Attacker.to_string() - role = AgentRole.from_string(role_str) - action = Action(ActionType.JoinGame, parameters={"agent_info": {"role": role, "name": "TestAgent"}}) - print(action) - print(action.to_json()) - print(action.from_json(action.to_json())) \ No newline at end of file + BUFFER_SIZE: int = 8192 \ No newline at end of file diff --git a/netsecgame/utils/utils.py b/netsecgame/utils/utils.py index e0190eee..f4cd3f48 100644 --- a/netsecgame/utils/utils.py +++ b/netsecgame/utils/utils.py @@ -7,7 +7,7 @@ import json import logging import os -from typing import Optional +from typing import Optional, Set, List, Dict, Any, Tuple # --- Third-Party Imports --- import jsonlines @@ -55,7 +55,7 @@ def get_str_hash(string, hash_func='sha256'): hash_algorithm.update(string.encode('utf-8')) return hash_algorithm.hexdigest() -def read_replay_buffer_from_csv(csvfile:str)->list: +def read_replay_buffer_from_csv(csvfile:str)->List[Tuple[GameState, Action, float, GameState, bool]]: """ Function to read steps from a CSV file and restore the objects in the replay buffer. @@ -104,7 +104,7 @@ def state_as_ordered_string(state:GameState)->str: ret += "}" return ret -def observation_as_dict(observation: Observation) -> dict: +def observation_as_dict(observation: Observation) -> Dict[str, Any]: """ Generates dict representation of a given Observation object. Acts as the single source of truth for the structure. @@ -130,7 +130,7 @@ def observation_to_str(observation: Observation) -> str: logging.getLogger(__name__).error(f"Error in encoding observation '{observation}' to JSON string: {e}") raise e -def observation_from_dict(data: dict) -> Observation: +def observation_from_dict(data: Dict[str, Any]) -> Observation: """ Reconstructs an Observation object from a dictionary representation. @@ -240,13 +240,13 @@ def read_trajectories_from_jsonl(filepath:str)->list: """ raise NotImplementedError("This function is not yet implemented.") -def generate_valid_actions(state: GameState, include_blocks=False)->list: +def generate_valid_actions(state: GameState, include_blocks=False)->Set[Action]: """Function that generates a list of all valid actions in a given GameState Args: state (GameState): The current game state. include_blocks (bool): Whether to include BlockIP actions. Defaults to False. Returns: - list: A list of valid Action objects. + set: A set of valid Action objects. """ valid_actions = set() def is_fw_blocked(state, src_ip, dst_ip)->bool: @@ -293,7 +293,7 @@ def is_fw_blocked(state, src_ip, dst_ip)->bool: if not is_fw_blocked(state, source_host,target_host): for blocked_ip in state.known_hosts: valid_actions.add(Action(ActionType.BlockIP, {"target_host":target_host, "source_host":source_host, "blocked_host":blocked_ip})) - return list(valid_actions) + return valid_actions if __name__ == "__main__": state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)}, diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py new file mode 100644 index 00000000..3543bff7 --- /dev/null +++ b/tests/agents/test_base_agent.py @@ -0,0 +1,175 @@ +import pytest +import json +import socket +from unittest.mock import patch, MagicMock + +from netsecgame.agents.base_agent import BaseAgent +from netsecgame.game_components import Action, ActionType, GameStatus, Observation, GameState, AgentRole, ProtocolConfig + +class TestAgent(BaseAgent): + """A concrete implementation of BaseAgent for testing.""" + def __init__(self, host, port, role): + super().__init__(host, port, role) + +@pytest.fixture +def mock_socket(): + with patch('socket.socket') as mock_sock_class: + mock_sock_instance = MagicMock() + mock_sock_class.return_value = mock_sock_instance + yield mock_sock_instance + +@pytest.fixture +def agent(mock_socket): + return TestAgent('localhost', 5000, AgentRole.Attacker) + +def test_initialization_success(mock_socket): + agent = TestAgent('localhost', 5000, AgentRole.Attacker) + assert agent._connection_details == ('localhost', 5000) + assert agent.role == AgentRole.Attacker + assert agent.socket == mock_socket + mock_socket.connect.assert_called_once_with(('localhost', 5000)) + +def test_initialization_failure(): + with patch('socket.socket') as mock_sock_class: + mock_sock_instance = MagicMock() + mock_sock_instance.connect.side_effect = socket.error("Connection refused") + mock_sock_class.return_value = mock_sock_instance + + agent = TestAgent('localhost', 5000, AgentRole.Attacker) + assert getattr(agent, "sock", None) is None + +def test_terminate_connection(agent, mock_socket): + assert agent.socket is not None + agent.terminate_connection() + mock_socket.close.assert_called_once() + assert agent.socket is None + +def test_del_closes_connection(mock_socket): + agent = TestAgent('localhost', 5000, AgentRole.Attacker) + agent.__del__() + mock_socket.close.assert_called_once() + +@patch('netsecgame.agents.base_agent.GameState.from_dict') +def test_communicate_success(mock_from_dict, agent, mock_socket): + action = Action(ActionType.JoinGame, parameters={}) + + # Mock response from server + response_data = { + "status": "GameStatus.CREATED", + "observation": {"state": {}, "reward": 0, "end": False, "info": {}}, + "message": "Success" + } + encoded_response = json.dumps(response_data).encode() + ProtocolConfig.END_OF_MESSAGE + mock_socket.recv.side_effect = [encoded_response] + + status, observation, message = agent.communicate(action) + + # Verify sending + mock_socket.sendall.assert_called_once() + sent_data = mock_socket.sendall.call_args[0][0] + assert sent_data == action.to_json().encode() + + # Verify receiving and parsing + assert status == GameStatus.CREATED + assert observation["reward"] == 0 + assert observation["end"] is False + assert message == "Success" + +def test_communicate_invalid_action(agent): + with pytest.raises(ValueError): + agent.communicate("not_an_action") + +def test_communicate_incomplete_response(agent, mock_socket): + action = Action(ActionType.JoinGame, parameters={}) + # Response without EOF marker + mock_socket.recv.side_effect = [b"incomplete data", b""] + + with pytest.raises(ConnectionError, match="Unfinished connection."): + agent.communicate(action) + +@patch('netsecgame.agents.base_agent.GameState.from_dict') +def test_register_success(mock_from_dict, agent): + mock_state = MagicMock(spec=GameState) + mock_from_dict.return_value = mock_state + + observation_dict = { + "state": {}, + "reward": 0, + "end": False, + "info": {} + } + with patch.object(agent, 'communicate', return_value=(GameStatus.CREATED, observation_dict, "Registered")) as mock_communicate: + observation = agent.register() + + mock_communicate.assert_called_once() + action_sent = mock_communicate.call_args[0][0] + assert action_sent.action_type == ActionType.JoinGame + assert action_sent.parameters["agent_info"].name == "TestAgent" + assert action_sent.parameters["agent_info"].role == AgentRole.Attacker.value + + assert isinstance(observation, Observation) + assert observation.reward == 0 + assert observation.end is False + +def test_register_failure(agent): + with patch.object(agent, 'communicate', return_value=(GameStatus.BAD_REQUEST, {}, "Failed")): + observation = agent.register() + assert observation is None + +@patch('netsecgame.agents.base_agent.GameState.from_dict') +def test_make_step_success(mock_from_dict, agent): + mock_state = MagicMock(spec=GameState) + mock_from_dict.return_value = mock_state + + action = Action(ActionType.ScanNetwork, parameters={}) + observation_dict = { + "state": {}, + "reward": 10, + "end": True, + "info": {"msg": "found"} + } + with patch.object(agent, 'communicate', return_value=(GameStatus.OK, observation_dict, "Step ok")): + observation = agent.make_step(action) + + assert isinstance(observation, Observation) + assert observation.reward == 10 + assert observation.end is True + assert observation.info == {"msg": "found"} + +def test_make_step_failure(agent): + action = Action(ActionType.ScanNetwork, parameters={}) + with patch.object(agent, 'communicate', return_value=(GameStatus.BAD_REQUEST, {}, "Step failed")): + observation = agent.make_step(action) + assert observation is None + +@patch('netsecgame.agents.base_agent.GameState.from_dict') +def test_request_game_reset_success(mock_from_dict, agent): + mock_state = MagicMock(spec=GameState) + mock_from_dict.return_value = mock_state + + observation_dict = { + "state": {}, + "reward": 0, + "end": False, + "info": {} + } + with patch.object(agent, 'communicate', return_value=(GameStatus.RESET_DONE, observation_dict, "Reset ok")) as mock_communicate: + observation = agent.request_game_reset(request_trajectory=True, randomize_topology=False, seed=42) + + mock_communicate.assert_called_once() + action_sent = mock_communicate.call_args[0][0] + assert action_sent.action_type == ActionType.ResetGame + assert action_sent.parameters["request_trajectory"] is True + assert action_sent.parameters["randomize_topology"] is False + assert action_sent.parameters["seed"] == 42 + + assert isinstance(observation, Observation) + +def test_request_game_reset_failure(agent): + with patch.object(agent, 'communicate', return_value=(None, {}, "Reset failed")): + observation = agent.request_game_reset() + assert observation is None + +def test_request_game_reset_missing_seed(agent): + with pytest.raises(ValueError, match="Topology randomization without seed is not supported."): + agent.request_game_reset(randomize_topology=True) diff --git a/tests/game/test_agent_server.py b/tests/game/test_agent_server.py index 233e3d3d..130ee947 100644 --- a/tests/game/test_agent_server.py +++ b/tests/game/test_agent_server.py @@ -3,8 +3,8 @@ import pytest from unittest.mock import AsyncMock, MagicMock from contextlib import suppress -from netsecgame.game.coordinator import AgentServer from netsecgame.game_components import Action, ActionType, ProtocolConfig +from netsecgame.game.agent_server import AgentServer # ----------------------- # Fixtures diff --git a/tests/game/test_config_parser.py b/tests/game/test_config_parser.py new file mode 100644 index 00000000..cd77631d --- /dev/null +++ b/tests/game/test_config_parser.py @@ -0,0 +1,282 @@ +import pytest +from unittest.mock import patch, mock_open, MagicMock + +# 1. Define the mocks you need for THIS file +MOCK_MODULES = { + 'aiohttp': MagicMock(), + 'cyst': MagicMock(), + 'cyst.api': MagicMock(), + 'cyst.api.environment': MagicMock(), + 'cyst.api.environment.environment': MagicMock(), + 'faker': MagicMock() +} + +# 2. Use a fixture to safely inject and clean up the mocks +@pytest.fixture(scope="module", autouse=True) +def isolate_mocks(): + """ + Safely injects mocks into sys.modules only for the duration of this module. + Once the tests in this file finish, patch.dict automatically restores the original sys.modules. + """ + with patch.dict('sys.modules', MOCK_MODULES): + yield # The tests run here + +# 3. Standard imports +# Because you implemented the Lazy Registry earlier, importing ConfigParser +# here is safe and won't prematurely trigger real 'cyst' imports. +from netsecgame.game.config_parser import ConfigParser +from netsecgame.game.scenarios import SCENARIO_REGISTRY +from netsecgame.game_components import IP, Data, Network, Service + +# --- Mock Configurations --- + +VALID_CONFIG = { + "env": { + "actions": { + "test_action": {"prob_success": 0.5} + }, + "rewards": { + "step": -1, + "success": 100 + }, + "use_dynamic_addresses": True, + "save_trajectories": True, + "scenario": "scenario1", + "use_firewall": True, + "use_global_defender": True, + "required_players": 2, + }, + "coordinator": { + "agents": { + "Attacker": { + "max_steps": 50, + "goal": { + "description": "Compromise host 10.0.0.1", + "known_networks": ["10.0.0.0/24"], + "known_hosts": ["10.0.0.1", "random", "all_local"], + "controlled_hosts": ["10.0.0.2"], + "known_services": { + "10.0.0.1": [["ssh", "tcp", "22", True], "random"] + }, + "known_blocks": { + "10.0.0.1": ["10.0.0.2"], + "10.0.0.3": "all_attackers" + }, + "known_data": { + "10.0.0.1": [["Admin", "Password"], "random"] + } + }, + "start_position": { + "known_networks": [], + "known_hosts": [], + "controlled_hosts": ["random"], + "known_services": {}, + "known_data": {} + } + }, + "Defender": { + "max_steps": {}, # to trigger TypeError fallback + "goal": { + # empty goal for defaults + }, + "start_position": { + "known_networks": [], + "known_hosts": [], + "controlled_hosts": ["192.168.1.1"], + "known_services": {}, + "known_data": {} + } + } + } + }, + "random_entity": { + "random_seed": 42 + }, + "random_entity_str": { + "random_seed": "random" + } +} + +@pytest.fixture +def parser(): + return ConfigParser(config_dict=VALID_CONFIG) + +@pytest.fixture +def empty_parser(): + return ConfigParser(config_dict={"empty_dummy": True}) + +# --- Tests --- + +def test_initialization(): + # Test valid dict + cp = ConfigParser(config_dict={"key": "value"}) + assert cp.config == {"key": "value"} + + # Test missing both file and dict raises error via log, but creates object + with patch('logging.Logger.error') as mock_log: + cp = ConfigParser() + mock_log.assert_called_once_with("You must provide either the configuration file or a dictionary with the configuration!") + + # Test file reading + mock_yaml = "key: value\n" + with patch('builtins.open', mock_open(read_data=mock_yaml)): + cp = ConfigParser(task_config_file="dummy.yaml") + assert cp.config == {"key": "value"} + + # Test file reading error + with patch('builtins.open', mock_open()) as mocked_file, patch('logging.Logger.error') as mock_log: + mocked_file.side_effect = IOError("File not found") + cp = ConfigParser(task_config_file="dummy.yaml") + mock_log.assert_called_once() + assert not hasattr(cp, 'config') + +def test_read_env_action_data(parser, empty_parser): + assert parser.read_env_action_data("test_action") == 0.5 + assert parser.read_env_action_data("unknown_action") == 1 + assert empty_parser.read_env_action_data("test_action") == 1 + +def test_get_simple_values(parser, empty_parser): + # Firewall + assert parser.get_use_firewall() is True + assert empty_parser.get_use_firewall(default_value=False) is False + + # Dynamic Addresses + assert parser.get_use_dynamic_addresses() is True + assert empty_parser.get_use_dynamic_addresses(default_value=False) is False + + # Global Defender + assert parser.get_use_global_defender() is True + assert empty_parser.get_use_global_defender(default_value=False) is False + + # Required Num Players + assert parser.get_required_num_players() == 2 + assert empty_parser.get_required_num_players(default_value=1) == 1 + + # Store trajectories + assert parser.get_store_trajectories() is True + assert empty_parser.get_store_trajectories(default_value=False) is False + +def test_get_rewards(parser, empty_parser): + rewards = parser.get_rewards(["step", "success", "fail"], default_value=0) + assert rewards["step"] == -1 + assert rewards["success"] == 100 + assert rewards["fail"] == 0 # Default fallback + + empty_rewards = empty_parser.get_rewards(["step"], default_value=5) + assert empty_rewards["step"] == 5 + +def test_get_max_steps(parser): + assert parser.get_max_steps("Attacker") == 50 + assert parser.get_max_steps("Defender") is None # Triggered TypeError handling + assert parser.get_max_steps("Unknown") is None # Triggered KeyError handling + +def test_get_goal_description(parser): + assert parser.get_goal_description("Attacker") == "Compromise host 10.0.0.1" + assert parser.get_goal_description("Defender") == "" + assert parser.get_goal_description("Benign") == "" + + with pytest.raises(ValueError, match="Unsupported agent role"): + parser.get_goal_description("UnknownRole") + +def test_validate_goal_description(parser): + with patch('logging.Logger.warning') as mock_warn: + # 10.0.0.2 is in controlled_hosts but missing from the desc text + parser.validate_goal_description("Attacker", "Compromise host 10.0.0.1") + mock_warn.assert_called_once() + assert "Controlled Host: 10.0.0.2" in mock_warn.call_args[0][0] + +def test_read_agents_known_networks(parser): + networks = parser.read_agents_known_networks("Attacker", "goal") + assert len(networks) == 1 + net = list(networks)[0] + assert isinstance(net, Network) + assert net.ip == "10.0.0.0" + assert net.mask == 24 + +def test_read_agents_known_hosts(parser): + hosts = parser.read_agents_known_hosts("Attacker", "goal") + assert len(hosts) == 3 + assert IP("10.0.0.1") in hosts + assert "random" in hosts + assert "all_local" in hosts + +def test_read_agents_controlled_hosts(parser): + hosts = parser.read_agents_controlled_hosts("Attacker", "goal") + assert len(hosts) == 1 + assert IP("10.0.0.2") in hosts + +def test_read_agents_known_services(parser): + services = parser.read_agents_known_services("Attacker", "goal") + assert IP("10.0.0.1") in services + srv_list = services[IP("10.0.0.1")] + assert len(srv_list) == 2 + assert isinstance(srv_list[0], Service) + assert srv_list[0].name == "ssh" + assert srv_list[0].type == "tcp" + assert srv_list[0].version == "22" + assert srv_list[0].is_local is True + assert srv_list[1] == "random" + +def test_read_agents_known_blocks(parser): + blocks = parser.read_agents_known_blocks("Attacker", "goal") + assert IP("10.0.0.1") in blocks + assert list(blocks[IP("10.0.0.1")]) == [IP("10.0.0.2")] # it stores map iterator, resolve to list for assertion + assert blocks[IP("10.0.0.3")] == "all_attackers" + +def test_read_agents_known_data(parser): + data_dict = parser.read_agents_known_data("Attacker", "goal") + assert IP("10.0.0.1") in data_dict + data_set = data_dict[IP("10.0.0.1")] + + assert len(data_set) == 2 + assert "random" in data_set + + # Find the Data object + data_obj = next(d for d in data_set if isinstance(d, Data)) + assert data_obj.owner == "Admin" + assert data_obj.id == "Password" + +def test_get_start_position(parser): + pos = parser.get_start_position("Attacker") + assert "random" in pos["controlled_hosts"] + assert len(pos["known_networks"]) == 0 + + assert parser.get_start_position("Defender")["controlled_hosts"].pop() == IP("192.168.1.1") + + benign_pos = parser.get_start_position("Benign") + assert benign_pos["controlled_hosts"] == ["random", "random", "random"] + + with pytest.raises(ValueError): + parser.get_start_position("Unknown") + +def test_get_win_conditions(parser): + win = parser.get_win_conditions("Attacker") + assert "random" in win["known_hosts"] + + benign_win = parser.get_win_conditions("Benign") + assert len(benign_win["known_networks"]) == 0 + assert IP("1.1.1.1") in benign_win["known_data"] + + with pytest.raises(ValueError): + parser.get_win_conditions("Unknown") + +def test_get_scenario(parser): + test_scenario_name = "scenario1" + parser.config = {'env': {'scenario': test_scenario_name}} + result = parser.get_scenario() + + assert result is not None + assert result == SCENARIO_REGISTRY[test_scenario_name] + +def test_get_scenario_invalid(empty_parser): + empty_parser.config = {"env": {"scenario": "unsupported_scenario"}} + with pytest.raises(ValueError, match="Unsupported scenario"): + empty_parser.get_scenario() + +def test_get_seed(parser): + assert parser.get_seed("random_entity") == 42 + + # Assuming randint won't return exactly -1 unless we mock it, we just check it is an int + seed = parser.get_seed("random_entity_str") + assert isinstance(seed, int) + assert 0 <= seed <= 100 diff --git a/tests/game/test_configuration_manager.py b/tests/game/test_configuration_manager.py new file mode 100644 index 00000000..aef85d09 --- /dev/null +++ b/tests/game/test_configuration_manager.py @@ -0,0 +1,187 @@ +import pytest +import sys +from unittest.mock import patch, MagicMock, AsyncMock + +# Mock out dependencies that might not be installed in the test environment +sys.modules['aiohttp'] = MagicMock() +sys.modules['cyst'] = MagicMock() +sys.modules['cyst.api'] = MagicMock() +sys.modules['cyst.api.environment'] = MagicMock() +sys.modules['cyst.api.environment.environment'] = MagicMock() +sys.modules['faker'] = MagicMock() + +from netsecgame.game.configuration_manager import ConfigurationManager +from netsecgame.game_components import AgentRole + +@pytest.fixture +def manager_local(): + return ConfigurationManager(task_config_file="dummy.yaml") + +@pytest.fixture +def manager_remote(): + return ConfigurationManager(service_host="localhost", service_port=8080) + +@pytest.fixture +def manager_both(): + return ConfigurationManager(task_config_file="dummy.yaml", service_host="localhost", service_port=8080) + +@pytest.fixture +def manager_none(): + return ConfigurationManager() + +def test_initialization(): + cm = ConfigurationManager(task_config_file="dummy.yaml", service_host="localhost", service_port=8080) + assert cm._task_config_file == "dummy.yaml" + assert cm._service_host == "localhost" + assert cm._service_port == 8080 + assert cm._parser is None + assert cm._cyst_objects is None + +import asyncio + +def test_load_no_source(manager_none): + with pytest.raises(ValueError, match="Task configuration source not specified"): + asyncio.run(manager_none.load()) + +@patch('netsecgame.game.configuration_manager.ConfigParser') +def test_load_local(mock_config_parser, manager_local): + mock_parser_instance = MagicMock() + mock_config_parser.return_value = mock_parser_instance + mock_parser_instance.get_scenario.return_value = {"cyst": "objects"} + + asyncio.run(manager_local.load()) + + mock_config_parser.assert_called_once_with(task_config_file="dummy.yaml") + assert manager_local._parser == mock_parser_instance + assert manager_local._cyst_objects == {"cyst": "objects"} + assert manager_local._config_file_hash is not None + +@patch('netsecgame.game.configuration_manager.get_str_hash') +@patch('netsecgame.game.configuration_manager.Environment') +@patch('netsecgame.game.configuration_manager.ConfigParser') +def test_load_remote_success(mock_config_parser, mock_environment, mock_get_str_hash, manager_remote): + mock_env_instance = MagicMock() + mock_environment.create.return_value = mock_env_instance + mock_env_instance.configuration.general.load_configuration.return_value = {"cyst": "objects"} + + mock_parser_instance = MagicMock() + mock_config_parser.return_value = mock_parser_instance + + mock_get_str_hash.return_value = "mocked_hash" + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json.return_value = {"key": "value"} + mock_response.__aenter__.return_value = mock_response + + mock_session = MagicMock() + mock_session.get.return_value = mock_response + mock_session.__aenter__.return_value = mock_session + + with patch('netsecgame.game.configuration_manager.ClientSession', return_value=mock_session): + asyncio.run(manager_remote.load()) + + assert manager_remote._cyst_objects == {"cyst": "objects"} + mock_config_parser.assert_called_once_with(config_dict={"key": "value"}) + assert manager_remote._parser == mock_parser_instance + +@patch('netsecgame.game.configuration_manager.ConfigParser') +def test_load_remote_failure_with_fallback(mock_config_parser, manager_both): + mock_parser_instance = MagicMock() + mock_config_parser.return_value = mock_parser_instance + mock_parser_instance.get_scenario.return_value = {"cyst": "objects_local"} + + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.__aenter__.return_value = mock_response + + mock_session = MagicMock() + mock_session.get.return_value = mock_response + mock_session.__aenter__.return_value = mock_session + + with patch('netsecgame.game.configuration_manager.ClientSession', return_value=mock_session): + asyncio.run(manager_both.load()) + + # It should fall back to local configuration + mock_config_parser.assert_called_once_with(task_config_file="dummy.yaml") + assert manager_both._cyst_objects == {"cyst": "objects_local"} + +def test_load_remote_failure_no_fallback(manager_remote): + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.__aenter__.return_value = mock_response + + mock_session = MagicMock() + mock_session.get.return_value = mock_response + mock_session.__aenter__.return_value = mock_session + + with patch('netsecgame.game.configuration_manager.ClientSession', return_value=mock_session): + with pytest.raises(RuntimeError, match="Remote configuration fetch failed"): + asyncio.run(manager_remote.load()) + +def test_accessors_without_load(manager_local): + with pytest.raises(RuntimeError, match="Configuration not loaded."): + manager_local.get_starting_position("Attacker") + + with pytest.raises(RuntimeError, match="Configuration not loaded."): + manager_local.get_win_conditions("Attacker") + + with pytest.raises(RuntimeError, match="Configuration not loaded."): + manager_local.get_max_steps("Attacker") + + with pytest.raises(RuntimeError, match="Configuration not loaded."): + manager_local.get_use_dynamic_ips() + +@pytest.fixture +def loaded_manager(): + cm = ConfigurationManager(task_config_file="dummy.yaml") + cm._parser = MagicMock() + cm._cyst_objects = {"cyst": "data"} + cm._config_file_hash = "hash123" + return cm + +def test_get_cyst_objects(loaded_manager): + assert loaded_manager.get_cyst_objects() == {"cyst": "data"} + +def test_get_config_hash(loaded_manager): + assert loaded_manager.get_config_hash() == "hash123" + +def test_get_starting_position(loaded_manager): + loaded_manager._parser.get_start_position.return_value = {"pos": (0, 0)} + assert loaded_manager.get_starting_position("Attacker") == {"pos": (0, 0)} + loaded_manager._parser.get_start_position.assert_called_once_with(agent_role="Attacker") + +def test_get_use_firewall(loaded_manager): + loaded_manager._parser.get_use_firewall.return_value = True + assert loaded_manager.get_use_firewall() is True + loaded_manager._parser.get_use_firewall.assert_called_once() + +def test_get_required_num_players(loaded_manager): + loaded_manager._parser.get_required_num_players.return_value = 2 + assert loaded_manager.get_required_num_players() == 2 + loaded_manager._parser.get_required_num_players.assert_called_once() + +def test_get_all_starting_positions(loaded_manager): + def mock_get_start(agent_role): + if agent_role == AgentRole.Attacker: + return {"network": "10.0.0.0"} + raise KeyError + + loaded_manager._parser.get_start_position.side_effect = mock_get_start + + result = loaded_manager.get_all_starting_positions() + assert result[AgentRole.Attacker] == {"network": "10.0.0.0"} + assert result[AgentRole.Defender] == {} + assert result[AgentRole.Benign] == {} + +def test_get_all_max_steps(loaded_manager): + def mock_get_steps(agent_role): + if agent_role == AgentRole.Attacker: + return 100 + return None + + loaded_manager._parser.get_max_steps.side_effect = mock_get_steps + + result = loaded_manager.get_all_max_steps() + assert result[AgentRole.Attacker] == 100 + assert result[AgentRole.Defender] is None diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 069d782c..63cdd270 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -119,7 +119,7 @@ def test_get_logging_level(): def test_generate_valid_actions(sample_gamestate): actions = generate_valid_actions(sample_gamestate, include_blocks=True) - assert isinstance(actions, list) + assert isinstance(actions, set) assert len(actions) > 0 # Check for specific expected actions based on sample state # Controlled host is 10.0.0.1