Skip to content

Commit

Permalink
Add Cancel API to Measurement Plug-In Client (#870)
Browse files Browse the repository at this point in the history
* feat : add cancel API to measurement plugin client
  • Loading branch information
Jotheeswaran-Nandagopal authored Sep 13, 2024
1 parent 6d75114 commit e7251a3
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ class ${class_name}:

grpc_channel_pool: An optional gRPC channel pool.
"""
self._initialization_lock = threading.Lock()
self._initialization_lock = threading.RLock()
self._service_class = ${service_class | repr}
self._grpc_channel_pool = grpc_channel_pool
self._discovery_client = discovery_client
self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: Optional[
Generator[v2_measurement_service_pb2.MeasureResponse, None, None]
] = None
self._configuration_metadata = ${configuration_metadata}
self._output_metadata = ${output_metadata}
if grpc_channel is not None:
Expand Down Expand Up @@ -142,6 +145,7 @@ class ${class_name}:
configuration_parameters=serialized_configuration
)

% if output_metadata:
def _deserialize_response(
self, response: v2_measurement_service_pb2.MeasureResponse
) -> Outputs:
Expand All @@ -157,6 +161,7 @@ class ${class_name}:
result[k - 1] = v
return Outputs._make(result)

% endif
def measure(
self,
${configuration_parameters_with_type_and_default_values}
Expand All @@ -166,12 +171,16 @@ class ${class_name}:
Returns:
Measurement outputs.
"""
parameter_values = [${measure_api_parameters}]
request = self._create_measure_request(parameter_values)

for response in self._get_stub().Measure(request):
result = self._deserialize_response(response)
stream_measure_response = self.stream_measure(
${measure_api_parameters}
)
for response in stream_measure_response:
% if output_metadata:
result = response
return result
% else:
pass
% endif

def stream_measure(
self,
Expand All @@ -183,7 +192,33 @@ class ${class_name}:
Stream of measurement outputs.
"""
parameter_values = [${measure_api_parameters}]
request = self._create_measure_request(parameter_values)

for response in self._get_stub().Measure(request):
yield self._deserialize_response(response)
with self._initialization_lock:
if self._measure_response is not None:
raise RuntimeError(
"A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance."
)
request = self._create_measure_request(parameter_values)
self._measure_response = self._get_stub().Measure(request)

try:
for response in self._measure_response:
% if output_metadata:
yield self._deserialize_response(response)
% else:
yield
% endif
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.CANCELLED:
_logger.debug("The measurement is canceled.")
raise
finally:
with self._initialization_lock:
self._measure_response = None

def cancel(self) -> bool:
"""Cancels the active measurement call."""
with self._initialization_lock:
if self._measure_response:
return self._measure_response.cancel()
else:
return False
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test___measurement_plugin_client___measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
expected_output = output_type(
float_out=0.05999999865889549,
double_array_out=[0.1, 0.2, 0.3],
Expand All @@ -40,7 +40,7 @@ def test___measurement_plugin_client___stream_measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
expected_output = output_type(
float_out=0.05999999865889549,
double_array_out=[0.1, 0.2, 0.3],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import concurrent.futures
import importlib.util
import pathlib
from types import ModuleType
from typing import Generator

import grpc
import pytest
from ni_measurement_plugin_sdk_service.measurement.service import MeasurementService

Expand All @@ -15,7 +17,7 @@ def test___measurement_plugin_client___measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
expected_output = output_type(
name="<Name>",
index=9,
Expand All @@ -32,7 +34,7 @@ def test___measurement_plugin_client___stream_measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
measurement_plugin_client = test_measurement_client_type()

response_iterator = measurement_plugin_client.stream_measure()
Expand All @@ -49,6 +51,64 @@ def test___measurement_plugin_client___stream_measure___returns_output(
assert responses == expected_output


def test___measurement_plugin_client___invoke_measure_from_two_threads___initiates_first_measure_and_rejects_second_measure(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

with pytest.raises(RuntimeError) as exc_info:
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
future_measure_1 = executor.submit(measurement_plugin_client.measure)
future_measure_2 = executor.submit(measurement_plugin_client.measure)
future_measure_1.result()
future_measure_2.result()

expected_error_message = "A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance."
assert expected_error_message in exc_info.value.args[0]


def test___non_streaming_measurement_execution___cancel___cancels_measurement(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

with pytest.raises(grpc.RpcError) as exc_info:
with concurrent.futures.ThreadPoolExecutor() as executor:
measure = executor.submit(measurement_plugin_client.measure)
measurement_plugin_client.cancel()
measure.result()

assert exc_info.value.code() == grpc.StatusCode.CANCELLED


def test___streaming_measurement_execution___cancel___cancels_measurement(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

with pytest.raises(grpc.RpcError) as exc_info:
with concurrent.futures.ThreadPoolExecutor() as executor:
measure = executor.submit(lambda: list(measurement_plugin_client.stream_measure()))
measurement_plugin_client.cancel()
measure.result()

assert exc_info.value.code() == grpc.StatusCode.CANCELLED


def test___measurement_client___cancel_without_measure___returns_false(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

is_canceled = measurement_plugin_client.cancel()

assert not is_canceled


@pytest.fixture(scope="module")
def measurement_client_directory(
tmp_path_factory: pytest.TempPathFactory,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Python Measurement Plug-In Client."""
"""Generated client API for the 'Non-Streaming Data Measurement (Py)' measurement plug-in."""

import logging
import threading
Expand Down Expand Up @@ -29,8 +29,8 @@
_V2_MEASUREMENT_SERVICE_INTERFACE = "ni.measurementlink.measurement.v2.MeasurementService"


class Output(NamedTuple):
"""Measurement result container."""
class Outputs(NamedTuple):
"""Outputs for the 'Non-Streaming Data Measurement (Py)' measurement plug-in."""

float_out: float
double_array_out: List[float]
Expand All @@ -46,7 +46,7 @@ class Output(NamedTuple):


class NonStreamingDataMeasurementClient:
"""Client to interact with the measurement plug-in."""
"""Client for the 'Non-Streaming Data Measurement (Py)' measurement plug-in."""

def __init__(
self,
Expand All @@ -64,11 +64,14 @@ def __init__(
grpc_channel_pool: An optional gRPC channel pool.
"""
self._initialization_lock = threading.Lock()
self._initialization_lock = threading.RLock()
self._service_class = "ni.tests.NonStreamingDataMeasurement_Python"
self._grpc_channel_pool = grpc_channel_pool
self._discovery_client = discovery_client
self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: Optional[
Generator[v2_measurement_service_pb2.MeasureResponse, None, None]
] = None
self._configuration_metadata = {
1: ParameterMetadata(
display_name="Float In",
Expand Down Expand Up @@ -366,7 +369,9 @@ def _create_measure_request(
configuration_parameters=serialized_configuration
)

def _deserialize_response(self, response: v2_measurement_service_pb2.MeasureResponse) -> Output:
def _deserialize_response(
self, response: v2_measurement_service_pb2.MeasureResponse
) -> Outputs:
if self._output_metadata:
result = [None] * max(self._output_metadata.keys())
else:
Expand All @@ -377,7 +382,7 @@ def _deserialize_response(self, response: v2_measurement_service_pb2.MeasureResp

for k, v in output_values.items():
result[k - 1] = v
return Output._make(result)
return Outputs._make(result)

def measure(
self,
Expand All @@ -391,13 +396,13 @@ def measure(
io_in: str = "resource",
io_array_in: List[str] = ["resource1", "resource2"],
integer_in: int = 10,
) -> Output:
"""Executes the Non-Streaming Data Measurement (Py).
) -> Outputs:
"""Perform a single measurement.
Returns:
Measurement output.
Measurement outputs.
"""
parameter_values = [
stream_measure_response = self.stream_measure(
float_in,
double_array_in,
bool_in,
Expand All @@ -408,11 +413,9 @@ def measure(
io_in,
io_array_in,
integer_in,
]
request = self._create_measure_request(parameter_values)

for response in self._get_stub().Measure(request):
result = self._deserialize_response(response)
)
for response in stream_measure_response:
result = response
return result

def stream_measure(
Expand All @@ -427,11 +430,11 @@ def stream_measure(
io_in: str = "resource",
io_array_in: List[str] = ["resource1", "resource2"],
integer_in: int = 10,
) -> Generator[Output, None, None]:
"""Executes the Non-Streaming Data Measurement (Py).
) -> Generator[Outputs, None, None]:
"""Perform a streaming measurement.
Returns:
Stream of measurement output.
Stream of measurement outputs.
"""
parameter_values = [
float_in,
Expand All @@ -445,7 +448,28 @@ def stream_measure(
io_array_in,
integer_in,
]
request = self._create_measure_request(parameter_values)
with self._initialization_lock:
if self._measure_response is not None:
raise RuntimeError(
"A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance."
)
request = self._create_measure_request(parameter_values)
self._measure_response = self._get_stub().Measure(request)
try:
for response in self._measure_response:
yield self._deserialize_response(response)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.CANCELLED:
_logger.debug("The measurement is canceled.")
raise
finally:
with self._initialization_lock:
self._measure_response = None

for response in self._get_stub().Measure(request):
yield self._deserialize_response(response)
def cancel(self) -> bool:
"""Cancels the active measurement call."""
with self._initialization_lock:
if self._measure_response:
return self._measure_response.cancel()
else:
return False

0 comments on commit e7251a3

Please sign in to comment.