Skip to content

Commit

Permalink
Adding more logs
Browse files Browse the repository at this point in the history
  • Loading branch information
abuabraham-ttd committed Dec 11, 2024
1 parent 62cc490 commit 5de70be
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
26 changes: 16 additions & 10 deletions scripts/aws/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,20 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from confidential_compute import ConfidentialCompute, ConfidentialComputeConfig, SecretNotFoundException


class AWSConfidentialComputeConfig(ConfidentialComputeConfig):
enclave_memory_mb: int
enclave_cpu_count: int


class EC2(ConfidentialCompute):

def __init__(self):
super().__init__()
self.aws_metadata = "169.254.169.254"

def __get_aws_token(self) -> str:
"""Fetches a temporary AWS EC2 metadata token."""
try:
token_url = "http://169.254.169.254/latest/api/token"
token_url = f"http://{self.aws_metadata}/latest/api/token"
response = requests.put(
token_url, headers={"X-aws-ec2-metadata-token-ttl-seconds": "3600"}, timeout=2
)
Expand All @@ -43,7 +42,7 @@ def __get_aws_token(self) -> str:
def __get_current_region(self) -> str:
"""Fetches the current AWS region from EC2 instance metadata."""
token = self.__get_aws_token()
metadata_url = "http://169.254.169.254/latest/dynamic/instance-identity/document"
metadata_url = f"http://{self.aws_metadata}/latest/dynamic/instance-identity/document"
headers = {"X-aws-ec2-metadata-token": token}
try:
response = requests.get(metadata_url, headers=headers, timeout=2)
Expand All @@ -55,12 +54,14 @@ def __get_current_region(self) -> str:
def __validate_aws_specific_config(self, secret):
if "enclave_memory_mb" in secret or "enclave_cpu_count" in secret:
max_capacity = self.__get_max_capacity()

for key in ["enclave_memory_mb", "enclave_cpu_count"]:
if int(secret.get(key, 0)) > max_capacity.get(key):
raise ValueError(f"{key} value ({secret.get(key, 0)}) exceeds the maximum allowed ({max_capacity.get(key)}).")


def _get_secret(self, secret_identifier: str) -> AWSConfidentialComputeConfig:
"""Fetches a secret value from AWS Secrets Manager."""
"""Fetches a secret value from AWS Secrets Manager and adds defaults"""

def add_defaults(configs: Dict[str, any]) -> AWSConfidentialComputeConfig:
"""Adds default values to configuration if missing."""
Expand All @@ -71,6 +72,7 @@ def add_defaults(configs: Dict[str, any]) -> AWSConfidentialComputeConfig:
return configs

region = self.__get_current_region()
print(f"Running in {region}")
try:
client = boto3.client("secretsmanager", region_name=region)
except Exception as e:
Expand Down Expand Up @@ -124,7 +126,7 @@ def __run_socks_proxy(self) -> None:
def __get_secret_name_from_userdata(self) -> str:
"""Extracts the secret name from EC2 user data."""
token = self.__get_aws_token()
user_data_url = "http://169.254.169.254/latest/user-data"
user_data_url = f"http://{self.aws_metadata}/latest/user-data"
response = requests.get(user_data_url, headers={"X-aws-ec2-metadata-token": token})
user_data = response.text

Expand All @@ -137,17 +139,15 @@ def __get_secret_name_from_userdata(self) -> str:
return match.group(1) if match else default_name

def _setup_auxiliaries(self) -> None:
"""Sets up the necessary auxiliary services and configuration."""
self.configs = self._get_secret(self.__get_secret_name_from_userdata())
self.validate_configuration()
"""Sets up the vsock tunnel, socks proxy and flask server"""
log_level = 3 if self.configs["debug_mode"] else 1
self.__setup_vsockproxy(log_level)
self.__run_config_server()
self.__run_socks_proxy()
time.sleep(5) #TODO: Change to while loop if required.

def _validate_auxiliaries(self) -> None:
"""Validates auxiliary services."""
"""Validates connection to flask server direct and through socks proxy."""
proxy = "socks5://127.0.0.1:3306"
config_url = "http://127.0.0.1:27015/getConfig"
try:
Expand All @@ -161,9 +161,14 @@ def _validate_auxiliaries(self) -> None:
response.raise_for_status()
except requests.RequestException as e:
raise RuntimeError(f"Cannot connect to config server via SOCKS proxy: {e}")
print("Connectivity check to config server passes")

def run_compute(self) -> None:
"""Main execution flow for confidential compute."""
secret_manager_key = self.__get_secret_name_from_userdata()
self.configs = self._get_secret(secret_manager_key)
print(f"Fetched configs from {secret_manager_key}")
self.validate_configuration()
self._setup_auxiliaries()
self._validate_auxiliaries()
command = [
Expand All @@ -175,6 +180,7 @@ def run_compute(self) -> None:
"--enclave-name", "uid2operator"
]
if self.configs["debug_mode"]:
print("Running in debug_mode")
command += ["--debug-mode", "--attach-console"]
self.run_command(command)

Expand Down
9 changes: 8 additions & 1 deletion scripts/confidential_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def validate_operator_key():
raise ValueError(
f"Operator key does not match the expected environment ({expected_env})."
)
return True
print("Validated operator key matches environment")
else:
print("Skipping operator key validation")

def validate_url(url_key, environment):
"""URL should include environment except in prod"""
Expand All @@ -48,6 +50,8 @@ def validate_url(url_key, environment):
raise ValueError(
f"{url_key} is invalid. Ensure {self.configs[url_key]} follows HTTPS, and doesn't have any path specified."
)
print(f"Validated {self.configs[url_key]} matches other config parameters")


def validate_connectivity() -> None:
""" Validates that the core and opt-out URLs are accessible."""
Expand All @@ -56,8 +60,10 @@ def validate_connectivity() -> None:
optout_url = self.configs["optout_base_url"]
core_ip = socket.gethostbyname(urlparse(core_url).netloc)
requests.get(core_url, timeout=5)
print(f"Validated connectivity to {core_url}")
optout_ip = socket.gethostbyname(urlparse(optout_url).netloc)
requests.get(optout_url, timeout=5)
print(f"Validated connectivity to {optout_url}")
except (requests.ConnectionError, requests.Timeout) as e:
raise Exception(
f"Failed to reach required URLs. Consider enabling {core_ip}, {optout_ip} in the egress firewall."
Expand All @@ -79,6 +85,7 @@ def validate_connectivity() -> None:
validate_url("optout_base_url", environment)
validate_operator_key()
validate_connectivity()
print("Completed static validation of confidential compute config values")


@abstractmethod
Expand Down

0 comments on commit 5de70be

Please sign in to comment.