From f8b147d2bc534b6be6cf696a95ca35518091193e Mon Sep 17 00:00:00 2001 From: Gil Bregman Date: Tue, 10 Oct 2023 09:08:10 -0400 Subject: [PATCH] Add a lock for GRPC calls to prevent corruption and exceptions on gateway restart. Fixes #255 Signed-off-by: Gil Bregman --- control/grpc.py | 114 ++++++++++++++++++++++++++++++++++++++++------ control/server.py | 96 +++++++++++++++++++++++++------------- 2 files changed, 165 insertions(+), 45 deletions(-) diff --git a/control/grpc.py b/control/grpc.py index f85c39b8e..69608b738 100644 --- a/control/grpc.py +++ b/control/grpc.py @@ -36,7 +36,7 @@ class GatewayService(pb2_grpc.GatewayServicer): spdk_rpc_client: Client of SPDK RPC server """ - def __init__(self, config, gateway_state, spdk_rpc_client) -> None: + def __init__(self, config, gateway_state, spdk_rpc_client, rpc_lock) -> None: """Constructor""" self.logger = logging.getLogger(__name__) ver = os.getenv("NVMEOF_VERSION") @@ -44,6 +44,7 @@ def __init__(self, config, gateway_state, spdk_rpc_client) -> None: self.logger.info(f"Using NVMeoF gateway version {ver}") self.config = config self.logger.info(f"Using configuration file {config.filepath}") + self.rpc_lock = rpc_lock self.gateway_state = gateway_state self.spdk_rpc_client = spdk_rpc_client self.gateway_name = self.config.get("gateway", "name") @@ -91,7 +92,7 @@ def _alloc_cluster(self) -> str: ) return name - def create_bdev(self, request, context=None): + def create_bdev_safe(self, request, context=None): """Creates a bdev from an RBD image.""" if not request.uuid: @@ -132,13 +133,22 @@ def create_bdev(self, request, context=None): return pb2.bdev(bdev_name=bdev_name, status=True) - def delete_bdev(self, request, context=None): + def create_bdev(self, request, context=None): + if context: + with self.rpc_lock: + return self.create_bdev_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.create_bdev_safe(request, context) + + def delete_bdev_safe(self, request, context=None): """Deletes a bdev.""" self.logger.info(f"Received request to delete bdev {request.bdev_name}") use_excep = None req_get_subsystems = pb2.get_subsystems_req() - ret = self.get_subsystems(req_get_subsystems, context) + # We already hold the lock, so call the safe version, do not try lock again + ret = self.get_subsystems_safe(req_get_subsystems, context) subsystems = json.loads(ret.subsystems) for subsystem in subsystems: for namespace in subsystem['namespaces']: @@ -149,7 +159,7 @@ def delete_bdev(self, request, context=None): self.logger.info(f"Will remove namespace {namespace['nsid']} from {subsystem['nqn']} as it is using bdev {request.bdev_name}") try: req_rm_ns = pb2.remove_namespace_req(subsystem_nqn=subsystem['nqn'], nsid=namespace['nsid']) - ret = self.remove_namespace(req_rm_ns, context) + ret = self.remove_namespace_safe(req_rm_ns, context) self.logger.info( f"Removed namespace {namespace['nsid']} from {subsystem['nqn']}: {ret.status}") except Exception as ex: @@ -191,7 +201,15 @@ def delete_bdev(self, request, context=None): return pb2.req_status(status=ret) - def create_subsystem(self, request, context=None): + def delete_bdev(self, request, context=None): + if context: + with self.rpc_lock: + return self.delete_bdev_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.delete_bdev_safe(request, context) + + def create_subsystem_safe(self, request, context=None): """Creates a subsystem.""" self.logger.info( @@ -233,7 +251,15 @@ def create_subsystem(self, request, context=None): return pb2.req_status(status=ret) - def delete_subsystem(self, request, context=None): + def create_subsystem(self, request, context=None): + if context: + with self.rpc_lock: + return self.create_subsystem_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.create_subsystem_safe(request, context) + + def delete_subsystem_safe(self, request, context=None): """Deletes a subsystem.""" self.logger.info( @@ -262,7 +288,15 @@ def delete_subsystem(self, request, context=None): return pb2.req_status(status=ret) - def add_namespace(self, request, context=None): + def delete_subsystem(self, request, context=None): + if context: + with self.rpc_lock: + return self.delete_subsystem_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.delete_subsystem_safe(request, context) + + def add_namespace_safe(self, request, context=None): """Adds a namespace to a subsystem.""" self.logger.info(f"Received request to add {request.bdev_name} to" @@ -298,7 +332,15 @@ def add_namespace(self, request, context=None): return pb2.nsid(nsid=nsid, status=True) - def remove_namespace(self, request, context=None): + def add_namespace(self, request, context=None): + if context: + with self.rpc_lock: + return self.add_namespace_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.add_namespace_safe(request, context) + + def remove_namespace_safe(self, request, context=None): """Removes a namespace from a subsystem.""" self.logger.info(f"Received request to remove {request.nsid} from" @@ -329,7 +371,15 @@ def remove_namespace(self, request, context=None): return pb2.req_status(status=ret) - def add_host(self, request, context=None): + def remove_namespace(self, request, context=None): + if context: + with self.rpc_lock: + return self.remove_namespace_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.remove_namespace_safe(request, context) + + def add_host_safe(self, request, context=None): """Adds a host to a subsystem.""" try: @@ -373,7 +423,15 @@ def add_host(self, request, context=None): return pb2.req_status(status=ret) - def remove_host(self, request, context=None): + def add_host(self, request, context=None): + if context: + with self.rpc_lock: + return self.add_host_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.add_host_safe(request, context) + + def remove_host_safe(self, request, context=None): """Removes a host from a subsystem.""" try: @@ -415,7 +473,15 @@ def remove_host(self, request, context=None): return pb2.req_status(status=ret) - def create_listener(self, request, context=None): + def remove_host(self, request, context=None): + if context: + with self.rpc_lock: + return self.remove_host_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.remove_host_safe(request, context) + + def create_listener_safe(self, request, context=None): """Creates a listener for a subsystem at a given IP/Port.""" ret = True @@ -459,7 +525,15 @@ def create_listener(self, request, context=None): return pb2.req_status(status=ret) - def delete_listener(self, request, context=None): + def create_listener(self, request, context=None): + if context: + with self.rpc_lock: + return self.create_listener_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.create_listener_safe(request, context) + + def delete_listener_safe(self, request, context=None): """Deletes a listener from a subsystem at a given IP/Port.""" ret = True @@ -502,7 +576,15 @@ def delete_listener(self, request, context=None): return pb2.req_status(status=ret) - def get_subsystems(self, request, context): + def delete_listener(self, request, context=None): + if context: + with self.rpc_lock: + return self.delete_listener_safe(request, context) + else: + self.rpc_lock.raise_exception_if_not_locked() + return self.delete_listener_safe(request, context) + + def get_subsystems_safe(self, request, context): """Gets subsystems.""" self.logger.info(f"Received request to get subsystems") @@ -516,3 +598,7 @@ def get_subsystems(self, request, context): return pb2.subsystems_info() return pb2.subsystems_info(subsystems=json.dumps(ret)) + + def get_subsystems(self, request, context): + with self.rpc_lock: + return self.get_subsystems_safe(request, context) diff --git a/control/server.py b/control/server.py index 7057c28eb..d5cf318b4 100644 --- a/control/server.py +++ b/control/server.py @@ -16,6 +16,7 @@ import json import logging import signal +import threading from concurrent import futures from google.protobuf import json_format @@ -59,9 +60,37 @@ class GatewayServer: discovery_pid: Subprocess running Ceph nvmeof discovery service """ + class RPCGuard: + RPC_GUARD_LOCK_TIMEOUT = 300 + + def __init__(self, logger, timeout = None) -> None: + self.rpc_lock = threading.Lock() + self.lock_timeout = timeout + self.logger = logger + + def __enter__(self): + rc = self.rpc_lock.acquire(True, self.lock_timeout) + if not rc: + self.logger.warning(f"Couldn't acquire lock after {self.lock_timeout} seconds, will try again") + rc = self.rpc_lock.acquire(True, self.lock_timeout) + if not rc: + self.logger.error(f"Failed to acquire lock for guarding RPC, will continue anyway") + return self + + def __exit__(self, typ, value, traceback): + if self.rpc_lock.locked(): + self.rpc_lock.release() + else: + self.logger.warning(f"Asked to release an unlocked RPC guard, ignore") + + def raise_exception_if_not_locked(self): + if not self.rpc_lock.locked(): + raise Exception("RPC guard is not locked like it should be") + def __init__(self, config): self.logger = logging.getLogger(__name__) self.config = config + self.rpc_lock = GatewayServer.RPCGuard(self.logger, GatewayServer.RPCGuard.RPC_GUARD_LOCK_TIMEOUT) self.spdk_process = None self.gateway_rpc = None self.server = None @@ -113,7 +142,7 @@ def serve(self): gateway_state = GatewayStateHandler(self.config, local_state, omap_state, self.gateway_rpc_caller) self.gateway_rpc = GatewayService(self.config, gateway_state, - self.spdk_rpc_client) + self.spdk_rpc_client, self.rpc_lock) self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) pb2_grpc.add_GatewayServicer_to_server(self.gateway_rpc, self.server) @@ -330,43 +359,48 @@ def gateway_rpc_caller(self, requests, is_add_req): """Passes RPC requests to gateway service.""" for key, val in requests.items(): if key.startswith(GatewayState.BDEV_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.create_bdev_req()) - self.gateway_rpc.create_bdev(req) - else: - req = json_format.Parse(val, + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.create_bdev_req()) + self.gateway_rpc.create_bdev(req) + else: + req = json_format.Parse(val, pb2.delete_bdev_req(), ignore_unknown_fields=True) - self.gateway_rpc.delete_bdev(req) + self.gateway_rpc.delete_bdev(req) elif key.startswith(GatewayState.SUBSYSTEM_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.create_subsystem_req()) - self.gateway_rpc.create_subsystem(req) - else: - req = json_format.Parse(val, + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.create_subsystem_req()) + self.gateway_rpc.create_subsystem(req) + else: + req = json_format.Parse(val, pb2.delete_subsystem_req(), ignore_unknown_fields=True) - self.gateway_rpc.delete_subsystem(req) + self.gateway_rpc.delete_subsystem(req) elif key.startswith(GatewayState.NAMESPACE_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.add_namespace_req()) - self.gateway_rpc.add_namespace(req) - else: - req = json_format.Parse(val, + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.add_namespace_req()) + self.gateway_rpc.add_namespace(req) + else: + req = json_format.Parse(val, pb2.remove_namespace_req(), ignore_unknown_fields=True) - self.gateway_rpc.remove_namespace(req) + self.gateway_rpc.remove_namespace(req) elif key.startswith(GatewayState.HOST_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.add_host_req()) - self.gateway_rpc.add_host(req) - else: - req = json_format.Parse(val, pb2.remove_host_req()) - self.gateway_rpc.remove_host(req) + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.add_host_req()) + self.gateway_rpc.add_host(req) + else: + req = json_format.Parse(val, pb2.remove_host_req()) + self.gateway_rpc.remove_host(req) elif key.startswith(GatewayState.LISTENER_PREFIX): - if is_add_req: - req = json_format.Parse(val, pb2.create_listener_req()) - self.gateway_rpc.create_listener(req) - else: - req = json_format.Parse(val, pb2.delete_listener_req()) - self.gateway_rpc.delete_listener(req) + with self.rpc_lock: + if is_add_req: + req = json_format.Parse(val, pb2.create_listener_req()) + self.gateway_rpc.create_listener(req) + else: + req = json_format.Parse(val, pb2.delete_listener_req()) + self.gateway_rpc.delete_listener(req)