diff --git a/netsecgame/game/worlds/CYSTCoordinator.py b/netsecgame/game/worlds/CYSTCoordinator.py index 788eace2..721c21e2 100644 --- a/netsecgame/game/worlds/CYSTCoordinator.py +++ b/netsecgame/game/worlds/CYSTCoordinator.py @@ -8,7 +8,8 @@ import logging import argparse from pathlib import Path -from netsecgame.game_components import GameState, Action, ActionType, IP, Service ,Network +from typing import Tuple +from netsecgame.game_components import GameState, Action, ActionType, IP, Service ,Network, AgentRole from netsecgame.game.coordinator import GameCoordinator from cyst.api.configuration.network.node import NodeConfig @@ -36,7 +37,7 @@ def get_starting_position_from_cyst_config(cyst_objects): hosts.add(IP(str(interface.ip))) net_ip, net_mask = str(interface.net).split("/") networks.add(Network(net_ip,int(net_mask))) - starting_positions[f"{obj.id}.{active_service.name}"] = {"known_hosts":hosts, "known_networks":networks} + starting_positions[f"{obj.id}.{active_service.name}"] = {"known_hosts":hosts, "known_networks":networks} return starting_positions class CYSTCoordinator(GameCoordinator): @@ -52,7 +53,7 @@ def __init__(self, game_host:str, game_port:int, service_host:str, service_port: self._starting_positions = None self._availabe_cyst_agents = None - def get_cyst_id(self, agent_role:str): + def get_cyst_id(self, agent_role:AgentRole): """ Returns ID of the CYST agent based on the agent's role. """ @@ -62,12 +63,11 @@ def get_cyst_id(self, agent_role:str): cyst_id = None return cyst_id - async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState: + async def register_agent(self, agent_id:tuple, agent_role:AgentRole, agent_initial_view:dict, agent_win_condition_view:dict)->Tuple[GameState, GameState]: self.logger.debug(f"Registering agent {agent_id} in the world.") - agent_role = "Attacker" if not self._starting_positions: self._starting_positions = get_starting_position_from_cyst_config(self._cyst_objects) - self._availabe_cyst_agents = {"Attacker":set([k for k in self._starting_positions.keys()])} + self._availabe_cyst_agents = {AgentRole.Attacker:set([k for k in self._starting_positions.keys()])} async with self._agents_lock: cyst_id = self.get_cyst_id(agent_role) if cyst_id: @@ -76,9 +76,10 @@ async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_vie self._known_agent_roles[agent_id] = agent_role kh = self._starting_positions[cyst_id]["known_hosts"] kn = self._starting_positions[cyst_id]["known_networks"] - return GameState(controlled_hosts=kh, known_hosts=kh, known_networks=kn) + state = GameState(controlled_hosts=kh, known_hosts=kh, known_networks=kn) + return state, GameState() else: - return None + return None, None async def remove_agent(self, agent_id, agent_state:GameState)->bool: print(f"Removing agent {agent_id} from the CYST World") @@ -203,11 +204,12 @@ async def _execute_exfiltrate_data_action(self, agent_id:tuple, agent_state: Gam async def _execute_block_ip_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState: raise NotImplementedError - async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState: + async def reset_agent(self, agent_id:tuple, agent_role:AgentRole, agent_initial_view:dict, agent_win_condition_view:dict)->Tuple[GameState, GameState]: cyst_id = self._id_to_cystid[agent_id] kh = self._starting_positions[cyst_id]["known_hosts"] kn = self._starting_positions[cyst_id]["known_networks"] - return GameState(controlled_hosts=kh, known_hosts=kh, known_networks=kn) + state = GameState(controlled_hosts=kh, known_hosts=kh, known_networks=kn) + return state, GameState() async def reset(self)->bool: return True