Skip to content

Commit

Permalink
Merge pull request #241 from stratosphereips/cyst-integration
Browse files Browse the repository at this point in the history
Add blocking, Optional global defender and bug fixes
  • Loading branch information
ondrej-lukas authored Oct 24, 2024
2 parents d4222f7 + ab9dced commit 6d2bd21
Show file tree
Hide file tree
Showing 20 changed files with 1,231 additions and 700 deletions.
18 changes: 18 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
.git
.github
.gitignore
.gitmodules
.pytest_cache
.ruff_cache
.vscode
docs/
figures/
mlruns/
tests/
trajectories/
NetSecGameAgents/
notebooks/
readme_images/
*trajectories*.json
README.md
29 changes: 29 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Use an official Python 3.12 runtime as a parent image
FROM python:3.12-slim

# Directory inside the image that will hold the project files
ENV DESTINATION_DIR=/aidojo


# Install system dependencies (git and a compiler toolchain) and remove the
# apt package lists in the same layer so they do not end up in the image
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        git \
        build-essential \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip without keeping the download cache in the image layer
# (same flag as the requirements install below)
RUN pip install --no-cache-dir --upgrade pip

# Copy the build context (filtered by .dockerignore) into the image
COPY . ${DESTINATION_DIR}/

# Set the working directory in the container
WORKDIR ${DESTINATION_DIR}

# Install any necessary Python dependencies
# If a requirements.txt file is in the repository
RUN if [ -f requirements.txt ]; then pip install --no-cache-dir -r requirements.txt; fi

# Change the server ip to 0.0.0.0 so the coordinator listens on all interfaces
# and is reachable from outside the container
RUN sed -i 's/"host": "127.0.0.1"/"host": "0.0.0.0"/' coordinator.conf

# Run the coordinator when the container launches
CMD ["python3", "coordinator.py"]
130 changes: 66 additions & 64 deletions README.md

Large diffs are not rendered by default.

85 changes: 67 additions & 18 deletions coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
import json
import asyncio
from datetime import datetime
from env.network_security_game import NetworkSecurityEnvironment
from env.worlds.network_security_game import NetworkSecurityEnvironment
from env.worlds.network_security_game_real_world import NetworkSecurityEnvironmentRealWorld
from env.worlds.aidojo_world import AIDojoWorld
from env.game_components import Action, Observation, ActionType, GameStatus, GameState
from utils.utils import observation_as_dict, get_logging_level
from pathlib import Path
import os
import signal
from env.global_defender import stochastic_with_threshold

class AIDojo:
def __init__(self, host: str, port: int, net_sec_config: str, world_type) -> None:
Expand Down Expand Up @@ -90,7 +93,14 @@ def __init__(self, actions_queue, answers_queue, max_connections):
self.max_connections = max_connections
self.current_connections = 0
self.logger = logging.getLogger("AIDojo-Server")
self._stop = False

def close(self) -> None:
    """
    Request the server to stop.

    Logs the request and raises the ``_stop`` flag, which the per-agent
    receive loop checks before waiting for the next message.
    """
    self.logger.info("Stopping server")
    self._stop = True

async def handle_new_agent(self, reader, writer):
async def send_data_to_agent(writer, data: str) -> None:
"""
Expand All @@ -113,7 +123,7 @@ async def send_data_to_agent(writer, data: str) -> None:
try:
addr = writer.get_extra_info("peername")
self.logger.info(f"New agent connected: {addr}")
while True:
while not self._stop:
data = await reader.read(500)
raw_message = data.decode().strip()
if len(raw_message):
Expand Down Expand Up @@ -151,23 +161,36 @@ async def send_data_to_agent(writer, data: str) -> None:
finally:
# Decrement the count of current connections
self.current_connections -= 1
writer.close()
return

# Makes the instance itself awaitable with a (reader, writer) pair: the call is
# forwarded unchanged to handle_new_agent.
# NOTE(review): presumably this lets the object be passed as the per-connection
# callback of the asyncio server - confirm at the call site.
async def __call__(self, reader, writer):
await self.handle_new_agent(reader, writer)


class Coordinator:
def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles, world_type="netsecenv"):
# communication channels for asyncio
self._actions_queue = actions_queue
self._answers_queue = answers_queue
self.ALLOWED_ROLES = allowed_roles
self.logger = logging.getLogger("AIDojo-Coordinator")
self._world = NetworkSecurityEnvironment(net_sec_config)
# world definition
match world_type:
case "netsecenv":
self._world = NetworkSecurityEnvironment(net_sec_config)
case "netsecenv-real-world":
self._world = NetworkSecurityEnvironmentRealWorld(net_sec_config)
case _:
self._world = AIDojoWorld(net_sec_config)
self.world_type = world_type



self._starting_positions_per_role = self._get_starting_position_per_role()
self._win_conditions_per_role = self._get_win_condition_per_role()
self._goal_description_per_role = self._get_goal_description_per_role()
self._steps_limit = self._world.task_config.get_max_steps()

self._use_global_defender = self._world.task_config.get_use_global_defender()
# player information
self.agents = {}
# step counter per agent_addr (int)
Expand All @@ -182,13 +205,15 @@ def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles,
# goal reach status per agent_addr (bool)
self._agent_goal_reached = {}
self._agent_episode_ends = {}
self._agent_detected = {}
# trajectories per agent_addr
self._agent_trajectories = {}

@property
def episode_end(self)->bool:
# Terminate episode if at least one player wins or reaches the timeout
# NOTE(review): as written, the `any(...)` return below exits first, so the
# debug log and the `all(...)` return after it are unreachable. This looks like
# the removed and the added line of a diff shown together - confirm which one
# is intended. If `all(...)` is the intended version, the comment above
# ("at least one player") is stale, and an empty `_agent_episode_ends` would
# evaluate to True because all() of an empty iterable is True.
return any(self._agent_episode_ends.values())
self.logger.debug(f"End evaluation: {self._agent_episode_ends.values()}")
return all(self._agent_episode_ends.values())

def convert_msg_dict_to_json(self, msg_dict)->str:
try:
Expand Down Expand Up @@ -237,7 +262,7 @@ async def run(self):
self.logger.info(f"Coordinator received from RESET request from agent {agent_addr}")
if all(self._reset_requests.values()):
# should we discard the queue here?
self.logger.info(f"All agents requested reset, action_q:{self._actions_queue.empty()}, answers_q{self._answers_queue.empty()}")
self.logger.info(f"All agents requested reset, action_q:{self._actions_queue.empty()}, answers_q:{self._answers_queue.empty()}")
self._world.reset()
self._get_goal_description_per_role()
self._get_win_condition_per_role()
Expand Down Expand Up @@ -265,7 +290,7 @@ async def run(self):
except asyncio.CancelledError:
self.logger.info("\tTerminating by CancelledError")
except Exception as e:
self.logger.error(f"Exception in main_coordinator(): {e}")
self.logger.error(f"Exception in Class coordinator(): {e}")
raise e

def _initialize_new_player(self, agent_addr:tuple, agent_name:str, agent_role:str) -> Observation:
Expand All @@ -280,8 +305,9 @@ def _initialize_new_player(self, agent_addr:tuple, agent_name:str, agent_role:st
self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role]
self._agent_states[agent_addr] = self._world.create_state_from_view(self._agent_starting_position[agent_addr])
self._agent_goal_reached[agent_addr] = self._goal_reached(agent_addr)
self._agent_detected[agent_addr] = self._check_detection(agent_addr, None)
self._agent_episode_ends[agent_addr] = False
if self._world.task_config.get_store_trajectories():
if self._world.task_config.get_store_trajectories() or self._use_global_defender:
self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}")
return Observation(self._agent_states[agent_addr], 0, False, {})
Expand Down Expand Up @@ -324,7 +350,7 @@ def _get_win_condition_per_role(self)-> dict:
win_conditions = {}
for agent_role in self.ALLOWED_ROLES:
try:
win_conditions[agent_role] = self._world.re_map_goal_dict(
win_conditions[agent_role] = self._world.update_goal_dict(
self._world.task_config.get_win_conditions(agent_role=agent_role)
)
except KeyError:
Expand Down Expand Up @@ -394,10 +420,11 @@ def _create_response_to_reset_game_action(self, agent_addr: tuple) -> dict:
f"Coordinator responding to RESET request from agent {agent_addr}"
)
# store trajectory in file if needed
self._store_trajectory_to_file(agent_addr)
if self._world.task_config.get_store_trajectories():
self._store_trajectory_to_file(agent_addr)
new_observation = Observation(self._agent_states[agent_addr], 0, self.episode_end, {})
# reset trajectory
self._reset_trajectory(agent_addr)
self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
output_message_dict = {
"to_agent": agent_addr,
"status": str(GameStatus.OK),
Expand Down Expand Up @@ -458,9 +485,11 @@ def _process_generic_action(self, agent_addr: tuple, action: Action) -> dict:

current_state = self._agent_states[agent_addr]
# Build new Observation for the agent
self._agent_states[agent_addr] = self._world.step(current_state, action, agent_addr, self.world_type)
self._agent_states[agent_addr] = self._world.step(current_state, action, agent_addr)
self._agent_goal_reached[agent_addr] = self._goal_reached(agent_addr)

self._agent_detected[agent_addr] = self._check_detection(agent_addr, action)

reward = self._world._rewards["step"]
obs_info = {}
end_reason = None
Expand All @@ -473,6 +502,11 @@ def _process_generic_action(self, agent_addr: tuple, action: Action) -> dict:
self._agent_episode_ends[agent_addr] = True
obs_info = {"end_reason": "max_steps"}
end_reason = "max_steps"
elif self._agent_detected[agent_addr]:
reward += self._world._rewards["detection"]
self._agent_episode_ends[agent_addr] = True
obs_info = {"end_reason": "max_steps"}

# record step in trajectory
self._add_step_to_trajectory(agent_addr, action, reward,self._agent_states[agent_addr], end_reason)
new_observation = Observation(self._agent_states[agent_addr], reward, self.episode_end, info=obs_info)
Expand Down Expand Up @@ -522,10 +556,12 @@ def _goal_reached(self, agent_addr:tuple)->bool:
self.logger.info(f"Goal check for {agent_addr}({self.agents[agent_addr][1]})")
agents_state = self._agent_states[agent_addr]
agent_role = self.agents[agent_addr][1]
win_condition = self._world.re_map_goal_dict(self._win_conditions_per_role[agent_role])
win_condition = self._world.update_goal_dict(self._win_conditions_per_role[agent_role])
goal_check = self._check_goal(agents_state, win_condition)
if goal_check:
self.logger.info("\tGoal reached!")
else:
self.logger.info("\tGoal not reached!")
return goal_check

def _check_goal(self, state:GameState, goal_conditions:dict)->bool:
Expand Down Expand Up @@ -556,11 +592,24 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
goal_reached["controlled_hosts"] = set(goal_conditions["controlled_hosts"]) <= set(state.controlled_hosts)
goal_reached["services"] = goal_dict_satistfied(goal_conditions["known_services"], state.known_services)
goal_reached["data"] = goal_dict_satistfied(goal_conditions["known_data"], state.known_data)
goal_reached["known_blocks"] = goal_dict_satistfied(goal_conditions["known_blocks"], state.known_blocks)
self.logger.debug(f"\t{goal_reached}")
return all(goal_reached.values())


__version__ = "v0.2.1"
def _check_detection(self, agent_addr:tuple, last_action:Action)->bool:
    """
    Decide whether the agent's last action was detected by the global defender.

    Detection is evaluated only when an action was actually played and the
    global defender is enabled (``self._use_global_defender``); in every other
    case the agent is reported as not detected.

    :param agent_addr: address tuple identifying the agent
    :param last_action: action that was just played, or None when there is no
        action yet (e.g. when a new player is being initialized)
    :return: True if the global defender flagged the action, False otherwise
    """
    self.logger.info(f"Detection check for {agent_addr}({self.agents[agent_addr][1]})")
    detection = False
    # None is the "no action" sentinel; compare explicitly so the result does not
    # depend on how Action objects evaluate in a boolean context.
    if last_action is not None and self._use_global_defender:
        self.logger.warning("Global defender - ONLY use for backward compatibility!")
        # A missing trajectory, or one without recorded actions, counts as an
        # empty action history instead of raising KeyError.
        episode_actions = self._agent_trajectories.get(agent_addr, {}).get("actions", [])
        detection = stochastic_with_threshold(last_action, episode_actions)
    if detection:
        self.logger.info("\tDetected!")
    else:
        self.logger.info("\tNot detected!")
    return detection
__version__ = "v0.2.2"


if __name__ == "__main__":
Expand Down Expand Up @@ -601,7 +650,7 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
action="store",
required=False,
type=str,
default="WARNING",
default="INFO",
)

args = parser.parse_args()
Expand Down Expand Up @@ -641,4 +690,4 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
# Create AI Dojo
ai_dojo = AIDojo(host, port, task_config_file, world_type)
# Run it!
ai_dojo.run()
ai_dojo.run()
1 change: 1 addition & 0 deletions docs/Components.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ GameState is an object that represents a view of the NetSecGame environment in a
- `known_services`: Dictionary of services that the agent is aware of.
The dictionary format: {`IP`: {`Service`}} where [IP](#ip) object is a key and the value is a set of [Service](#service) objects located in the `IP`.
- `known_data`: Dictionary of data instances that the agent is aware of. The dictionary format: {`IP`: {`Data`}} where [IP](#ip) object is a key and the value is a set of [Data](#data) objects located in the `IP`.
- `known_blocks`: Dictionary of firewall blocks that the agent is aware of. The dictionary format: {`target_IP`: {`blocked_IP`, `blocked_IP`}}, where `target_IP` is the [IP](#ip) on which the firewall rule was applied (usually a router) and each `blocked_IP` is an IP address that is blocked. For now, a block applies in both the input and output directions simultaneously.


## Actions
Expand Down
Loading

0 comments on commit 6d2bd21

Please sign in to comment.