Skip to content

Commit

Permalink
fix: add default sites
Browse files Browse the repository at this point in the history
  • Loading branch information
Jotheeswaran-Nandagopal committed Sep 23, 2024
1 parent 0a2d00e commit c8247ca
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ${class_name}:
if grpc_channel is not None:
self._stub = v2_measurement_service_pb2_grpc.MeasurementServiceStub(grpc_channel)
self._create_file_descriptor()
self._pin_map_context: Optional[PinMapContext] = None
self._pin_map_context: PinMapContext = PinMapContext(pin_map_id="", sites=[0])

@property
def pin_map_context(self) -> PinMapContext:
Expand All @@ -136,8 +136,7 @@ class ${class_name}:
@property
def sites(self) -> List[int]:
"""The sites where the measurement must be executed."""
if self._pin_map_context is not None:
return self._pin_map_context.sites
return self._pin_map_context.sites

@sites.setter
def sites(self, val: List[int]) -> None:
Expand Down Expand Up @@ -206,7 +205,7 @@ class ${class_name}:
)
return v2_measurement_service_pb2.MeasureRequest(
configuration_parameters=serialized_configuration,
pin_map_context=self._pin_map_context._to_grpc() if self._pin_map_context else None,
pin_map_context=self._pin_map_context._to_grpc(),
)

% if output_metadata:
Expand Down Expand Up @@ -298,10 +297,7 @@ class ${class_name}:
pin_map_path: Absolute path of the pin map file.
"""
pin_map_id = self._get_pin_map_client().update_pin_map(pin_map_path)
if self._pin_map_context is None:
self._pin_map_context = PinMapContext(pin_map_id=pin_map_id, sites=[0])
else:
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)

% if "from pathlib import Path" in built_in_import_modules:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __init__(
if grpc_channel is not None:
self._stub = v2_measurement_service_pb2_grpc.MeasurementServiceStub(grpc_channel)
self._create_file_descriptor()
self._pin_map_context: Optional[PinMapContext] = None
self._pin_map_context: PinMapContext = PinMapContext(pin_map_id="", sites=[0])

@property
def pin_map_context(self) -> PinMapContext:
Expand All @@ -438,8 +438,7 @@ def pin_map_context(self, val: PinMapContext) -> None:
@property
def sites(self) -> List[int]:
"""The sites where the measurement must be executed."""
if self._pin_map_context is not None:
return self._pin_map_context.sites
return self._pin_map_context.sites

@sites.setter
def sites(self, val: List[int]) -> None:
Expand Down Expand Up @@ -510,7 +509,7 @@ def _create_measure_request(
)
return v2_measurement_service_pb2.MeasureRequest(
configuration_parameters=serialized_configuration,
pin_map_context=self._pin_map_context._to_grpc() if self._pin_map_context else None,
pin_map_context=self._pin_map_context._to_grpc(),
)

def _deserialize_response(
Expand Down Expand Up @@ -667,10 +666,7 @@ def register_pin_map(self, pin_map_path: Path) -> None:
pin_map_path: Absolute path of the pin map file.
"""
pin_map_id = self._get_pin_map_client().update_pin_map(pin_map_path)
if self._pin_map_context is None:
self._pin_map_context = PinMapContext(pin_map_id=pin_map_id, sites=[0])
else:
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)


def _convert_paths_to_strings(parameter_values: Iterable[Any]) -> List[Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
if grpc_channel is not None:
self._stub = v2_measurement_service_pb2_grpc.MeasurementServiceStub(grpc_channel)
self._create_file_descriptor()
self._pin_map_context: Optional[PinMapContext] = None
self._pin_map_context: PinMapContext = PinMapContext(pin_map_id="", sites=[0])

@property
def pin_map_context(self) -> PinMapContext:
Expand All @@ -91,8 +91,7 @@ def pin_map_context(self, val: PinMapContext) -> None:
@property
def sites(self) -> List[int]:
"""The sites where the measurement must be executed."""
if self._pin_map_context is not None:
return self._pin_map_context.sites
return self._pin_map_context.sites

@sites.setter
def sites(self, val: List[int]) -> None:
Expand Down Expand Up @@ -163,7 +162,7 @@ def _create_measure_request(
)
return v2_measurement_service_pb2.MeasureRequest(
configuration_parameters=serialized_configuration,
pin_map_context=self._pin_map_context._to_grpc() if self._pin_map_context else None,
pin_map_context=self._pin_map_context._to_grpc(),
)

def measure(self, integer_in: int = 10) -> None:
Expand Down Expand Up @@ -216,7 +215,4 @@ def register_pin_map(self, pin_map_path: Path) -> None:
pin_map_path: Absolute path of the pin map file.
"""
pin_map_id = self._get_pin_map_client().update_pin_map(pin_map_path)
if self._pin_map_context is None:
self._pin_map_context = PinMapContext(pin_map_id=pin_map_id, sites=[0])
else:
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)

0 comments on commit c8247ca

Please sign in to comment.