Skip to content

Commit

Permalink
Add a lock for GRPC calls to prevent corruption and exceptions on gat…
Browse files Browse the repository at this point in the history
…eway restart.

Fixes ceph#255

Signed-off-by: Gil Bregman <[email protected]>
  • Loading branch information
gbregman committed Oct 11, 2023
1 parent d24a890 commit f8b147d
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 45 deletions.
114 changes: 100 additions & 14 deletions control/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ class GatewayService(pb2_grpc.GatewayServicer):
spdk_rpc_client: Client of SPDK RPC server
"""

def __init__(self, config, gateway_state, spdk_rpc_client) -> None:
def __init__(self, config, gateway_state, spdk_rpc_client, rpc_lock) -> None:
"""Constructor"""
self.logger = logging.getLogger(__name__)
ver = os.getenv("NVMEOF_VERSION")
if ver:
self.logger.info(f"Using NVMeoF gateway version {ver}")
self.config = config
self.logger.info(f"Using configuration file {config.filepath}")
self.rpc_lock = rpc_lock
self.gateway_state = gateway_state
self.spdk_rpc_client = spdk_rpc_client
self.gateway_name = self.config.get("gateway", "name")
Expand Down Expand Up @@ -91,7 +92,7 @@ def _alloc_cluster(self) -> str:
)
return name

def create_bdev(self, request, context=None):
def create_bdev_safe(self, request, context=None):
"""Creates a bdev from an RBD image."""

if not request.uuid:
Expand Down Expand Up @@ -132,13 +133,22 @@ def create_bdev(self, request, context=None):

return pb2.bdev(bdev_name=bdev_name, status=True)

def delete_bdev(self, request, context=None):
def create_bdev(self, request, context=None):
if context:
with self.rpc_lock:
return self.create_bdev_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.create_bdev_safe(request, context)

def delete_bdev_safe(self, request, context=None):
"""Deletes a bdev."""

self.logger.info(f"Received request to delete bdev {request.bdev_name}")
use_excep = None
req_get_subsystems = pb2.get_subsystems_req()
ret = self.get_subsystems(req_get_subsystems, context)
# We already hold the lock, so call the safe version, do not try lock again
ret = self.get_subsystems_safe(req_get_subsystems, context)
subsystems = json.loads(ret.subsystems)
for subsystem in subsystems:
for namespace in subsystem['namespaces']:
Expand All @@ -149,7 +159,7 @@ def delete_bdev(self, request, context=None):
self.logger.info(f"Will remove namespace {namespace['nsid']} from {subsystem['nqn']} as it is using bdev {request.bdev_name}")
try:
req_rm_ns = pb2.remove_namespace_req(subsystem_nqn=subsystem['nqn'], nsid=namespace['nsid'])
ret = self.remove_namespace(req_rm_ns, context)
ret = self.remove_namespace_safe(req_rm_ns, context)
self.logger.info(
f"Removed namespace {namespace['nsid']} from {subsystem['nqn']}: {ret.status}")
except Exception as ex:
Expand Down Expand Up @@ -191,7 +201,15 @@ def delete_bdev(self, request, context=None):

return pb2.req_status(status=ret)

def create_subsystem(self, request, context=None):
def delete_bdev(self, request, context=None):
if context:
with self.rpc_lock:
return self.delete_bdev_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.delete_bdev_safe(request, context)

def create_subsystem_safe(self, request, context=None):
"""Creates a subsystem."""

self.logger.info(
Expand Down Expand Up @@ -233,7 +251,15 @@ def create_subsystem(self, request, context=None):

return pb2.req_status(status=ret)

def delete_subsystem(self, request, context=None):
def create_subsystem(self, request, context=None):
if context:
with self.rpc_lock:
return self.create_subsystem_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.create_subsystem_safe(request, context)

def delete_subsystem_safe(self, request, context=None):
"""Deletes a subsystem."""

self.logger.info(
Expand Down Expand Up @@ -262,7 +288,15 @@ def delete_subsystem(self, request, context=None):

return pb2.req_status(status=ret)

def add_namespace(self, request, context=None):
def delete_subsystem(self, request, context=None):
if context:
with self.rpc_lock:
return self.delete_subsystem_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.delete_subsystem_safe(request, context)

def add_namespace_safe(self, request, context=None):
"""Adds a namespace to a subsystem."""

self.logger.info(f"Received request to add {request.bdev_name} to"
Expand Down Expand Up @@ -298,7 +332,15 @@ def add_namespace(self, request, context=None):

return pb2.nsid(nsid=nsid, status=True)

def remove_namespace(self, request, context=None):
def add_namespace(self, request, context=None):
if context:
with self.rpc_lock:
return self.add_namespace_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.add_namespace_safe(request, context)

def remove_namespace_safe(self, request, context=None):
"""Removes a namespace from a subsystem."""

self.logger.info(f"Received request to remove {request.nsid} from"
Expand Down Expand Up @@ -329,7 +371,15 @@ def remove_namespace(self, request, context=None):

return pb2.req_status(status=ret)

def add_host(self, request, context=None):
def remove_namespace(self, request, context=None):
if context:
with self.rpc_lock:
return self.remove_namespace_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.remove_namespace_safe(request, context)

def add_host_safe(self, request, context=None):
"""Adds a host to a subsystem."""

try:
Expand Down Expand Up @@ -373,7 +423,15 @@ def add_host(self, request, context=None):

return pb2.req_status(status=ret)

def remove_host(self, request, context=None):
def add_host(self, request, context=None):
if context:
with self.rpc_lock:
return self.add_host_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.add_host_safe(request, context)

def remove_host_safe(self, request, context=None):
"""Removes a host from a subsystem."""

try:
Expand Down Expand Up @@ -415,7 +473,15 @@ def remove_host(self, request, context=None):

return pb2.req_status(status=ret)

def create_listener(self, request, context=None):
def remove_host(self, request, context=None):
if context:
with self.rpc_lock:
return self.remove_host_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.remove_host_safe(request, context)

def create_listener_safe(self, request, context=None):
"""Creates a listener for a subsystem at a given IP/Port."""

ret = True
Expand Down Expand Up @@ -459,7 +525,15 @@ def create_listener(self, request, context=None):

return pb2.req_status(status=ret)

def delete_listener(self, request, context=None):
def create_listener(self, request, context=None):
if context:
with self.rpc_lock:
return self.create_listener_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.create_listener_safe(request, context)

def delete_listener_safe(self, request, context=None):
"""Deletes a listener from a subsystem at a given IP/Port."""

ret = True
Expand Down Expand Up @@ -502,7 +576,15 @@ def delete_listener(self, request, context=None):

return pb2.req_status(status=ret)

def get_subsystems(self, request, context):
def delete_listener(self, request, context=None):
if context:
with self.rpc_lock:
return self.delete_listener_safe(request, context)
else:
self.rpc_lock.raise_exception_if_not_locked()
return self.delete_listener_safe(request, context)

def get_subsystems_safe(self, request, context):
"""Gets subsystems."""

self.logger.info(f"Received request to get subsystems")
Expand All @@ -516,3 +598,7 @@ def get_subsystems(self, request, context):
return pb2.subsystems_info()

return pb2.subsystems_info(subsystems=json.dumps(ret))

def get_subsystems(self, request, context):
with self.rpc_lock:
return self.get_subsystems_safe(request, context)
96 changes: 65 additions & 31 deletions control/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import logging
import signal
import threading
from concurrent import futures
from google.protobuf import json_format

Expand Down Expand Up @@ -59,9 +60,37 @@ class GatewayServer:
discovery_pid: Subprocess running Ceph nvmeof discovery service
"""

class RPCGuard:
RPC_GUARD_LOCK_TIMEOUT = 300

def __init__(self, logger, timeout = None) -> None:
self.rpc_lock = threading.Lock()
self.lock_timeout = timeout
self.logger = logger

def __enter__(self):
rc = self.rpc_lock.acquire(True, self.lock_timeout)
if not rc:
self.logger.warning(f"Couldn't acquire lock after {self.lock_timeout} seconds, will try again")
rc = self.rpc_lock.acquire(True, self.lock_timeout)
if not rc:
self.logger.error(f"Failed to acquire lock for guarding RPC, will continue anyway")
return self

def __exit__(self, typ, value, traceback):
if self.rpc_lock.locked():
self.rpc_lock.release()
else:
self.logger.warning(f"Asked to release an unlocked RPC guard, ignore")

def raise_exception_if_not_locked(self):
if not self.rpc_lock.locked():
raise Exception("RPC guard is not locked like it should be")

def __init__(self, config):
self.logger = logging.getLogger(__name__)
self.config = config
self.rpc_lock = GatewayServer.RPCGuard(self.logger, GatewayServer.RPCGuard.RPC_GUARD_LOCK_TIMEOUT)
self.spdk_process = None
self.gateway_rpc = None
self.server = None
Expand Down Expand Up @@ -113,7 +142,7 @@ def serve(self):
gateway_state = GatewayStateHandler(self.config, local_state,
omap_state, self.gateway_rpc_caller)
self.gateway_rpc = GatewayService(self.config, gateway_state,
self.spdk_rpc_client)
self.spdk_rpc_client, self.rpc_lock)
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
pb2_grpc.add_GatewayServicer_to_server(self.gateway_rpc, self.server)

Expand Down Expand Up @@ -330,43 +359,48 @@ def gateway_rpc_caller(self, requests, is_add_req):
"""Passes RPC requests to gateway service."""
for key, val in requests.items():
if key.startswith(GatewayState.BDEV_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.create_bdev_req())
self.gateway_rpc.create_bdev(req)
else:
req = json_format.Parse(val,
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.create_bdev_req())
self.gateway_rpc.create_bdev(req)
else:
req = json_format.Parse(val,
pb2.delete_bdev_req(),
ignore_unknown_fields=True)
self.gateway_rpc.delete_bdev(req)
self.gateway_rpc.delete_bdev(req)
elif key.startswith(GatewayState.SUBSYSTEM_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.create_subsystem_req())
self.gateway_rpc.create_subsystem(req)
else:
req = json_format.Parse(val,
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.create_subsystem_req())
self.gateway_rpc.create_subsystem(req)
else:
req = json_format.Parse(val,
pb2.delete_subsystem_req(),
ignore_unknown_fields=True)
self.gateway_rpc.delete_subsystem(req)
self.gateway_rpc.delete_subsystem(req)
elif key.startswith(GatewayState.NAMESPACE_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.add_namespace_req())
self.gateway_rpc.add_namespace(req)
else:
req = json_format.Parse(val,
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.add_namespace_req())
self.gateway_rpc.add_namespace(req)
else:
req = json_format.Parse(val,
pb2.remove_namespace_req(),
ignore_unknown_fields=True)
self.gateway_rpc.remove_namespace(req)
self.gateway_rpc.remove_namespace(req)
elif key.startswith(GatewayState.HOST_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.add_host_req())
self.gateway_rpc.add_host(req)
else:
req = json_format.Parse(val, pb2.remove_host_req())
self.gateway_rpc.remove_host(req)
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.add_host_req())
self.gateway_rpc.add_host(req)
else:
req = json_format.Parse(val, pb2.remove_host_req())
self.gateway_rpc.remove_host(req)
elif key.startswith(GatewayState.LISTENER_PREFIX):
if is_add_req:
req = json_format.Parse(val, pb2.create_listener_req())
self.gateway_rpc.create_listener(req)
else:
req = json_format.Parse(val, pb2.delete_listener_req())
self.gateway_rpc.delete_listener(req)
with self.rpc_lock:
if is_add_req:
req = json_format.Parse(val, pb2.create_listener_req())
self.gateway_rpc.create_listener(req)
else:
req = json_format.Parse(val, pb2.delete_listener_req())
self.gateway_rpc.delete_listener(req)

0 comments on commit f8b147d

Please sign in to comment.