Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cancel API to Measurement Plug-In Client #870

Merged
merged 16 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ class ${class_name}:

grpc_channel_pool: An optional gRPC channel pool.
"""
self._initialization_lock = threading.Lock()
self._initialization_lock = threading.RLock()
self._service_class = ${service_class | repr}
self._grpc_channel_pool = grpc_channel_pool
self._discovery_client = discovery_client
self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: Optional[
Generator[v2_measurement_service_pb2.MeasureResponse, None, None]
] = None
self._configuration_metadata = ${configuration_metadata}
self._output_metadata = ${output_metadata}
if grpc_channel is not None:
Expand Down Expand Up @@ -142,6 +145,7 @@ class ${class_name}:
configuration_parameters=serialized_configuration
)

% if output_metadata:
def _deserialize_response(
self, response: v2_measurement_service_pb2.MeasureResponse
) -> Outputs:
Expand All @@ -157,6 +161,7 @@ class ${class_name}:
result[k - 1] = v
return Outputs._make(result)

% endif
def measure(
self,
${configuration_parameters_with_type_and_default_values}
Expand All @@ -166,12 +171,16 @@ class ${class_name}:
Returns:
Measurement outputs.
"""
parameter_values = [${measure_api_parameters}]
request = self._create_measure_request(parameter_values)

for response in self._get_stub().Measure(request):
result = self._deserialize_response(response)
stream_measure_response = self.stream_measure(
${measure_api_parameters}
)
for response in stream_measure_response:
% if output_metadata:
result = response
return result
% else:
pass
% endif

def stream_measure(
self,
Expand All @@ -183,7 +192,33 @@ class ${class_name}:
Stream of measurement outputs.
"""
parameter_values = [${measure_api_parameters}]
request = self._create_measure_request(parameter_values)

for response in self._get_stub().Measure(request):
yield self._deserialize_response(response)
with self._initialization_lock:
if self._measure_response is not None:
raise RuntimeError(
"A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance."
)
request = self._create_measure_request(parameter_values)
self._measure_response = self._get_stub().Measure(request)

try:
for response in self._measure_response:
% if output_metadata:
yield self._deserialize_response(response)
% else:
yield
% endif
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.CANCELLED:
_logger.debug("The measurement is canceled.")
raise
finally:
with self._initialization_lock:
self._measure_response = None

def cancel(self) -> bool:
Jotheeswaran-Nandagopal marked this conversation as resolved.
Show resolved Hide resolved
"""Cancels the active measure call."""
Jotheeswaran-Nandagopal marked this conversation as resolved.
Show resolved Hide resolved
with self._initialization_lock:
if self._measure_response:
return self._measure_response.cancel()
else:
return False
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test___measurement_plugin_client___measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
expected_output = output_type(
float_out=0.05999999865889549,
double_array_out=[0.1, 0.2, 0.3],
Expand All @@ -40,7 +40,7 @@ def test___measurement_plugin_client___stream_measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
expected_output = output_type(
float_out=0.05999999865889549,
double_array_out=[0.1, 0.2, 0.3],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import concurrent.futures
import importlib.util
import pathlib
from types import ModuleType
from typing import Generator

import grpc
import pytest
from ni_measurement_plugin_sdk_service.measurement.service import MeasurementService

Expand All @@ -15,7 +17,7 @@ def test___measurement_plugin_client___measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
expected_output = output_type(
name="<Name>",
index=9,
Expand All @@ -32,7 +34,7 @@ def test___measurement_plugin_client___stream_measure___returns_output(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
output_type = getattr(measurement_plugin_client_module, "Output")
output_type = getattr(measurement_plugin_client_module, "Outputs")
measurement_plugin_client = test_measurement_client_type()

response_iterator = measurement_plugin_client.stream_measure()
Expand All @@ -49,6 +51,64 @@ def test___measurement_plugin_client___stream_measure___returns_output(
assert responses == expected_output


def test___measurement_plugin_client___invoke_measure_from_two_threads___initiates_first_measure_and_rejects_second_measure(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

with pytest.raises(RuntimeError) as exc_info:
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
future_measure_1 = executor.submit(measurement_plugin_client.measure)
future_measure_2 = executor.submit(measurement_plugin_client.measure)
future_measure_1.result()
future_measure_2.result()

expected_error_message = "A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance."
assert expected_error_message in exc_info.value.args[0]


def test___non_streaming_measurement_execution___cancel___cancels_measurement(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

with pytest.raises(grpc.RpcError) as exc_info:
with concurrent.futures.ThreadPoolExecutor() as executor:
measure = executor.submit(measurement_plugin_client.measure)
measurement_plugin_client.cancel()
measure.result()

assert exc_info.value.code() == grpc.StatusCode.CANCELLED


def test___streaming_measurement_execution___cancel___cancels_measurement(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

with pytest.raises(grpc.RpcError) as exc_info:
with concurrent.futures.ThreadPoolExecutor() as executor:
measure = executor.submit(lambda: list(measurement_plugin_client.stream_measure()))
measurement_plugin_client.cancel()
measure.result()

assert exc_info.value.code() == grpc.StatusCode.CANCELLED


def test___measurement_client___cancel_without_measure___returns_false(
measurement_plugin_client_module: ModuleType,
) -> None:
test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement")
measurement_plugin_client = test_measurement_client_type()

is_canceled = measurement_plugin_client.cancel()

assert not is_canceled


@pytest.fixture(scope="module")
def measurement_client_directory(
tmp_path_factory: pytest.TempPathFactory,
Expand Down
Jotheeswaran-Nandagopal marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Python Measurement Plug-In Client."""
"""Generated client API for the 'Non-Streaming Data Measurement (Py)' measurement plug-in."""

import logging
import threading
Expand Down Expand Up @@ -29,8 +29,8 @@
_V2_MEASUREMENT_SERVICE_INTERFACE = "ni.measurementlink.measurement.v2.MeasurementService"


class Output(NamedTuple):
"""Measurement result container."""
class Outputs(NamedTuple):
"""Outputs for the 'Non-Streaming Data Measurement (Py)' measurement plug-in."""

float_out: float
double_array_out: List[float]
Expand All @@ -46,7 +46,7 @@ class Output(NamedTuple):


class NonStreamingDataMeasurementClient:
"""Client to interact with the measurement plug-in."""
"""Client for the 'Non-Streaming Data Measurement (Py)' measurement plug-in."""

def __init__(
self,
Expand All @@ -64,11 +64,14 @@ def __init__(

grpc_channel_pool: An optional gRPC channel pool.
"""
self._initialization_lock = threading.Lock()
self._initialization_lock = threading.RLock()
self._service_class = "ni.tests.NonStreamingDataMeasurement_Python"
self._grpc_channel_pool = grpc_channel_pool
self._discovery_client = discovery_client
self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: Optional[
Generator[v2_measurement_service_pb2.MeasureResponse, None, None]
] = None
self._configuration_metadata = {
1: ParameterMetadata(
display_name="Float In",
Expand Down Expand Up @@ -366,7 +369,9 @@ def _create_measure_request(
configuration_parameters=serialized_configuration
)

def _deserialize_response(self, response: v2_measurement_service_pb2.MeasureResponse) -> Output:
def _deserialize_response(
self, response: v2_measurement_service_pb2.MeasureResponse
) -> Outputs:
if self._output_metadata:
result = [None] * max(self._output_metadata.keys())
else:
Expand All @@ -377,7 +382,7 @@ def _deserialize_response(self, response: v2_measurement_service_pb2.MeasureResp

for k, v in output_values.items():
result[k - 1] = v
return Output._make(result)
return Outputs._make(result)

def measure(
self,
Expand All @@ -391,13 +396,13 @@ def measure(
io_in: str = "resource",
io_array_in: List[str] = ["resource1", "resource2"],
integer_in: int = 10,
) -> Output:
"""Executes the Non-Streaming Data Measurement (Py).
) -> Outputs:
"""Perform a single measurement.

Returns:
Measurement output.
Measurement outputs.
"""
parameter_values = [
stream_measure_response = self.stream_measure(
float_in,
double_array_in,
bool_in,
Expand All @@ -408,11 +413,9 @@ def measure(
io_in,
io_array_in,
integer_in,
]
request = self._create_measure_request(parameter_values)

for response in self._get_stub().Measure(request):
result = self._deserialize_response(response)
)
for response in stream_measure_response:
result = response
return result

def stream_measure(
Expand All @@ -427,11 +430,11 @@ def stream_measure(
io_in: str = "resource",
io_array_in: List[str] = ["resource1", "resource2"],
integer_in: int = 10,
) -> Generator[Output, None, None]:
"""Executes the Non-Streaming Data Measurement (Py).
) -> Generator[Outputs, None, None]:
"""Perform a streaming measurement.

Returns:
Stream of measurement output.
Stream of measurement outputs.
"""
parameter_values = [
float_in,
Expand All @@ -445,7 +448,28 @@ def stream_measure(
io_array_in,
integer_in,
]
request = self._create_measure_request(parameter_values)
with self._initialization_lock:
if self._measure_response is not None:
raise RuntimeError(
"A measurement is currently in progress. To make concurrent measurement requests, please create a new client instance."
)
request = self._create_measure_request(parameter_values)
self._measure_response = self._get_stub().Measure(request)
try:
for response in self._measure_response:
yield self._deserialize_response(response)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.CANCELLED:
_logger.debug("The measurement is canceled.")
raise
finally:
with self._initialization_lock:
self._measure_response = None

for response in self._get_stub().Measure(request):
yield self._deserialize_response(response)
def cancel(self) -> bool:
"""Cancels the active measure call."""
with self._initialization_lock:
if self._measure_response:
return self._measure_response.cancel()
else:
return False