Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

service: Add SessionManagementClient to MeasurementService and rework initialization code #386

Merged
merged 12 commits into from
Sep 20, 2023
Merged
28 changes: 19 additions & 9 deletions ni_measurementlink_service/_internal/discovery_client.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
""" Contains API to register and un-register measurement service with discovery service.
"""
"""Client for accessing the MeasurementLink discovery service."""
import json
import logging
import os
import pathlib
import subprocess
import sys
import threading
import time
import typing
from typing import Any, Dict, Optional

import grpc
from deprecation import deprecated

from ni_measurementlink_service._channelpool import GrpcChannelPool
from ni_measurementlink_service._internal.stubs.ni.measurementlink.discovery.v1 import (
discovery_service_pb2,
discovery_service_pb2_grpc,
)
from ni_measurementlink_service._loggers import ClientLogger
from ni_measurementlink_service.measurement.info import MeasurementInfo, ServiceInfo

if sys.platform == "win32":
Expand Down Expand Up @@ -58,13 +58,20 @@ class DiscoveryClient:
"""Client for accessing the MeasurementLink discovery service."""

def __init__(
self, stub: Optional[discovery_service_pb2_grpc.DiscoveryServiceStub] = None
self,
stub: Optional[discovery_service_pb2_grpc.DiscoveryServiceStub] = None,
*,
grpc_channel_pool: Optional[GrpcChannelPool] = None,
) -> None:
"""Initialize the discovery client.

Args:
stub: An optional discovery service gRPC stub for unit testing.

grpc_channel_pool: An optional gRPC channel pool (recommended).
"""
self._initialization_lock = threading.Lock()
self._grpc_channel_pool = grpc_channel_pool
self._stub = stub
self._registration_id = ""

Expand All @@ -88,11 +95,14 @@ def stub(self) -> discovery_service_pb2_grpc.DiscoveryServiceStub:

def _get_stub(self) -> discovery_service_pb2_grpc.DiscoveryServiceStub:
if self._stub is None:
address = _get_discovery_service_address()
channel = grpc.insecure_channel(address)
if ClientLogger.is_enabled():
channel = grpc.intercept_channel(channel, ClientLogger())
self._stub = discovery_service_pb2_grpc.DiscoveryServiceStub(channel)
with self._initialization_lock:
if self._grpc_channel_pool is None:
_logger.debug("Creating unshared GrpcChannelPool.")
self._grpc_channel_pool = GrpcChannelPool()
if self._stub is None:
address = _get_discovery_service_address()
channel = self._grpc_channel_pool.get_channel(address)
self._stub = discovery_service_pb2_grpc.DiscoveryServiceStub(channel)
return self._stub

@deprecated(deprecated_in="1.2.0-dev2", details="Use register_service instead.")
Expand Down
115 changes: 90 additions & 25 deletions ni_measurementlink_service/measurement/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import sys
import threading
from enum import Enum, EnumMeta
from os import path
from pathlib import Path
Expand All @@ -22,14 +23,18 @@
)

import grpc
from deprecation import deprecated
from google.protobuf.descriptor import EnumDescriptor

from ni_measurementlink_service import _datatypeinfo
from ni_measurementlink_service._channelpool import ( # re-export
GrpcChannelPool as GrpcChannelPool,
)
from ni_measurementlink_service._internal import grpc_servicer
from ni_measurementlink_service._internal.discovery_client import DiscoveryClient
from ni_measurementlink_service._internal.discovery_client import (
DiscoveryClient,
ServiceLocation,
)
from ni_measurementlink_service._internal.parameter import (
metadata as parameter_metadata,
)
Expand Down Expand Up @@ -108,12 +113,6 @@ class MeasurementService:
measure_function (Callable): Registered measurement function.

context (MeasurementContext): Accessor for context-local state.

discovery_client (DiscoveryClient): Client for accessing the MeasurementLink discovery
service.

channel_pool (GrpcChannelPool): Pool of gRPC channels used by the service.

"""

def __init__(
Expand Down Expand Up @@ -183,10 +182,54 @@ def convert_value_to_str(value: object) -> str:

self.configuration_parameter_list: List[Any] = []
self.output_parameter_list: List[Any] = []
self.grpc_service = GrpcService()
self.context: MeasurementContext = MeasurementContext()
self.channel_pool: GrpcChannelPool = GrpcChannelPool()
self.discovery_client: DiscoveryClient = DiscoveryClient()
self.context = MeasurementContext()
bkeryan marked this conversation as resolved.
Show resolved Hide resolved

self._initialization_lock = threading.RLock()
bkeryan marked this conversation as resolved.
Show resolved Hide resolved
self._channel_pool: Optional[GrpcChannelPool] = None
self._discovery_client: Optional[DiscoveryClient] = None
self._grpc_service: Optional[GrpcService] = None

@property
def channel_pool(self) -> GrpcChannelPool:
"""Pool of gRPC channels used by the service."""
if self._channel_pool is None:
with self._initialization_lock:
if self._channel_pool is None:
self._channel_pool = GrpcChannelPool()
return self._channel_pool

@property
def discovery_client(self) -> DiscoveryClient:
"""Client for accessing the MeasurementLink discovery service."""
if self._discovery_client is None:
with self._initialization_lock:
if self._discovery_client is None:
self._discovery_client = DiscoveryClient(grpc_channel_pool=self.channel_pool)
return self._discovery_client

@property
@deprecated(
deprecated_in="1.3.0-dev0",
details="This property should not be public and will be removed in a later release.",
)
def grpc_service(self) -> Optional[GrpcService]:
"""The gRPC service object. This is a private implementation detail."""
return self._grpc_service

@property
def service_location(self) -> ServiceLocation:
"""The location of the service on the network."""
with self._initialization_lock:
if self._grpc_service is None:
raise RuntimeError(
"Measurement service not running. Call host_service() before querying the service_location."
)

return ServiceLocation(
location="localhost",
insecure_port=self._grpc_service.port,
ssl_authenticated_port="",
)

def register_measurement(self, measurement_function: _F) -> _F:
"""Register a function as the measurement function for a measurement service.
Expand Down Expand Up @@ -328,7 +371,7 @@ def _output(func: _F) -> _F:
return _output

def host_service(self) -> MeasurementService:
"""Host the registered measurement method as gRPC measurement service.
"""Host the registered measurement method as a gRPC measurement service.

Returns
-------
Expand All @@ -340,16 +383,23 @@ def host_service(self) -> MeasurementService:
Exception: If register measurement methods not available.

"""
if self.measure_function is None:
raise Exception("Error, must register measurement method.")
self.grpc_service.start(
self.measurement_info,
self.service_info,
self.configuration_parameter_list,
self.output_parameter_list,
self.measure_function,
)
return self
with self._initialization_lock:
if self.measure_function is None:
raise RuntimeError(
"Measurement method not registered. Use the register_measurement decorator to register it."
)
if self._grpc_service is not None:
raise RuntimeError("Measurement service already running.")

self._grpc_service = GrpcService(self.discovery_client)
self._grpc_service.start(
self.measurement_info,
self.service_info,
self.configuration_parameter_list,
self.output_parameter_list,
self.measure_function,
)
return self

def _make_annotations_dict(
self,
Expand Down Expand Up @@ -396,9 +446,24 @@ def _is_protobuf_enum(self, enum_type: SupportedEnumType) -> TypeGuard[_EnumType
return isinstance(getattr(enum_type, "DESCRIPTOR", None), EnumDescriptor)

def close_service(self) -> None:
"""Close the Service after un-registering with discovery service and cleanups."""
self.grpc_service.stop()
self.channel_pool.close()
"""Stop the gRPC measurement service.

This method stops the gRPC server, unregisters with the discovery service, and cleans up
the cached discovery client and gRPC channel pool.

After calling close_service(), you may call host_service() again.

Exiting the measurement service's runtime context automatically calls close_service().
"""
with self._initialization_lock:
if self._grpc_service is not None:
self._grpc_service.stop()
if self._channel_pool is not None:
self._channel_pool.close()

self._grpc_service = None
self._channel_pool = None
self._discovery_client = None

def __enter__(self: Self) -> Self:
"""Enter the runtime context related to the measurement service."""
Expand Down
79 changes: 62 additions & 17 deletions ni_measurementlink_service/session_management.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Contains methods related to managing driver sessions."""
"""Client for accessing the MeasurementLink session management service."""
from __future__ import annotations

import abc
import logging
import sys
import threading
import warnings
from functools import cached_property
from types import TracebackType
Expand All @@ -21,6 +23,8 @@
import grpc
from deprecation import DeprecatedWarning

from ni_measurementlink_service._channelpool import GrpcChannelPool
from ni_measurementlink_service._internal.discovery_client import DiscoveryClient
from ni_measurementlink_service._internal.stubs import session_pb2
from ni_measurementlink_service._internal.stubs.ni.measurementlink import (
pin_map_context_pb2,
Expand All @@ -36,6 +40,8 @@
else:
from typing_extensions import Self

_logger = logging.getLogger(__name__)

GRPC_SERVICE_INTERFACE_NAME = "ni.measurementlink.sessionmanagement.v1.SessionManagementService"
GRPC_SERVICE_CLASS = "ni.measurementlink.sessionmanagement.v1.SessionManagementService"

Expand Down Expand Up @@ -249,13 +255,57 @@ def __getattr__(name: str) -> Any:


class SessionManagementClient(object):
"""Class that manages driver sessions."""
"""Client for accessing the MeasurementLink session management service."""

def __init__(self, *, grpc_channel: grpc.Channel) -> None:
"""Initialize session manangement client."""
self._client: session_management_service_pb2_grpc.SessionManagementServiceStub = (
session_management_service_pb2_grpc.SessionManagementServiceStub(grpc_channel)
)
def __init__(
self,
*,
discovery_client: Optional[DiscoveryClient] = None,
grpc_channel: Optional[grpc.Channel] = None,
grpc_channel_pool: Optional[GrpcChannelPool] = None,
) -> None:
"""Initialize session management client.

Args:
discovery_client: An optional discovery client (recommended).

grpc_channel: An optional session management gRPC channel.

grpc_channel_pool: An optional gRPC channel pool (recommended).
"""
self._initialization_lock = threading.Lock()
self._discovery_client = discovery_client
self._grpc_channel_pool = grpc_channel_pool
self._stub: Optional[
session_management_service_pb2_grpc.SessionManagementServiceStub
] = None

if grpc_channel is not None:
self._stub = session_management_service_pb2_grpc.SessionManagementServiceStub(
grpc_channel
)

def _get_stub(self) -> session_management_service_pb2_grpc.SessionManagementServiceStub:
if self._stub is None:
with self._initialization_lock:
if self._grpc_channel_pool is None:
_logger.debug("Creating unshared GrpcChannelPool.")
self._grpc_channel_pool = GrpcChannelPool()
if self._discovery_client is None:
_logger.debug("Creating unshared DiscoveryClient.")
self._discovery_client = DiscoveryClient(
grpc_channel_pool=self._grpc_channel_pool
)
if self._stub is None:
service_location = self._discovery_client.resolve_service(
provided_interface=GRPC_SERVICE_INTERFACE_NAME,
service_class=GRPC_SERVICE_CLASS,
)
channel = self._grpc_channel_pool.get_channel(service_location.insecure_address)
self._stub = session_management_service_pb2_grpc.SessionManagementServiceStub(
channel
)
return self._stub

def reserve_session(
self,
Expand Down Expand Up @@ -391,18 +441,15 @@ def _reserve_sessions(
timeout_in_ms = -1
request.timeout_in_milliseconds = timeout_in_ms

response: session_management_service_pb2.ReserveSessionsResponse = (
self._client.ReserveSessions(request)
)

response = self._get_stub().ReserveSessions(request)
return response.sessions

def _unreserve_sessions(
self, session_info: Iterable[session_management_service_pb2.SessionInformation]
) -> None:
"""Unreserves sessions so they can be accessed by other clients."""
request = session_management_service_pb2.UnreserveSessionsRequest(sessions=session_info)
self._client.UnreserveSessions(request)
self._get_stub().UnreserveSessions(request)

def register_sessions(self, session_info: Iterable[SessionInformation]) -> None:
"""Register the sessions with the Session Manager.
Expand Down Expand Up @@ -439,7 +486,7 @@ def register_sessions(self, session_info: Iterable[SessionInformation]) -> None:
for info in session_info
)
)
self._client.RegisterSessions(request)
self._get_stub().RegisterSessions(request)

def unregister_sessions(self, session_info: Iterable[SessionInformation]) -> None:
"""Unregisters the sessions from the Session Manager.
Expand Down Expand Up @@ -471,7 +518,7 @@ def unregister_sessions(self, session_info: Iterable[SessionInformation]) -> Non
for info in session_info
)
)
self._client.UnregisterSessions(request)
self._get_stub().UnregisterSessions(request)

def reserve_all_registered_sessions(
self, instrument_type_id: Optional[str] = None, timeout: Optional[float] = None
Expand Down Expand Up @@ -514,9 +561,7 @@ def reserve_all_registered_sessions(
timeout_in_ms = -1
request.timeout_in_milliseconds = timeout_in_ms

response: session_management_service_pb2.ReserveAllRegisteredSessionsResponse = (
self._client.ReserveAllRegisteredSessions(request)
)
response = self._get_stub().ReserveAllRegisteredSessions(request)
return MultiSessionReservation(session_manager=self, session_info=response.sessions)


Expand Down
Loading
Loading