Skip to content

Commit

Permalink
Add a lock for GRPC calls to prevent corruption and exceptions on gat…
Browse files Browse the repository at this point in the history
…eway restart.

Fixes #255

Signed-off-by: Gil Bregman <[email protected]>
  • Loading branch information
gbregman committed Oct 15, 2023
1 parent d24a890 commit 052374e
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-container.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
strategy:
fail-fast: false
matrix:
test: ["cli", "state", "multi_gateway", "server"]
test: ["cli", "state", "multi_gateway", "server", "grpc"]
runs-on: ubuntu-latest
env:
HUGEPAGES: 512 # for multi gateway test, approx 256 per gateway instance
Expand Down
97 changes: 84 additions & 13 deletions control/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import random
import logging
import os
import threading

import spdk.rpc.bdev as rpc_bdev
import spdk.rpc.nvmf as rpc_nvmf
Expand All @@ -36,6 +37,29 @@ class GatewayService(pb2_grpc.GatewayServicer):
spdk_rpc_client: Client of SPDK RPC server
"""

class RPCGuard:
RPC_GUARD_LOCK_TIMEOUT = 300 # Default timeout waiting for the lock to be acquired, in seconds

def __init__(self, logger, timeout = None) -> None:
self.rpc_lock = threading.Lock()
self.lock_timeout = timeout if timeout != None else self.RPC_GUARD_LOCK_TIMEOUT
self.logger = logger

def __enter__(self):
rc = self.rpc_lock.acquire(True, self.lock_timeout)
if not rc:
self.logger.warning(f"Couldn't acquire lock after {self.lock_timeout} seconds, will try again")
rc = self.rpc_lock.acquire(True, self.lock_timeout)
if not rc:
self.logger.error(f"Failed to acquire lock for guarding RPC, will continue anyway")
return self

def __exit__(self, typ, value, traceback):
if self.rpc_lock.locked():
self.rpc_lock.release()
else:
self.logger.warning(f"Asked to release an unlocked RPC guard, ignore")

def __init__(self, config, gateway_state, spdk_rpc_client) -> None:
"""Constructor"""
self.logger = logging.getLogger(__name__)
Expand All @@ -44,6 +68,7 @@ def __init__(self, config, gateway_state, spdk_rpc_client) -> None:
self.logger.info(f"Using NVMeoF gateway version {ver}")
self.config = config
self.logger.info(f"Using configuration file {config.filepath}")
self.rpc_lock = GatewayService.RPCGuard(self.logger)
self.gateway_state = gateway_state
self.spdk_rpc_client = spdk_rpc_client
self.gateway_name = self.config.get("gateway", "name")
Expand Down Expand Up @@ -91,7 +116,7 @@ def _alloc_cluster(self) -> str:
)
return name

def create_bdev(self, request, context=None):
def create_bdev_safe(self, request, context=None):
"""Creates a bdev from an RBD image."""

if not request.uuid:
Expand Down Expand Up @@ -132,13 +157,18 @@ def create_bdev(self, request, context=None):

return pb2.bdev(bdev_name=bdev_name, status=True)

def delete_bdev(self, request, context=None):
def create_bdev(self, request, context=None):
with self.rpc_lock:
return self.create_bdev_safe(request, context)

def delete_bdev_safe(self, request, context=None):
"""Deletes a bdev."""

self.logger.info(f"Received request to delete bdev {request.bdev_name}")
use_excep = None
req_get_subsystems = pb2.get_subsystems_req()
ret = self.get_subsystems(req_get_subsystems, context)
# We already hold the lock, so call the safe version, do not try lock again
ret = self.get_subsystems_safe(req_get_subsystems, context)
subsystems = json.loads(ret.subsystems)
for subsystem in subsystems:
for namespace in subsystem['namespaces']:
Expand All @@ -149,7 +179,8 @@ def delete_bdev(self, request, context=None):
self.logger.info(f"Will remove namespace {namespace['nsid']} from {subsystem['nqn']} as it is using bdev {request.bdev_name}")
try:
req_rm_ns = pb2.remove_namespace_req(subsystem_nqn=subsystem['nqn'], nsid=namespace['nsid'])
ret = self.remove_namespace(req_rm_ns, context)
# We already hold the lock, so call the safe version, do not try lock again
ret = self.remove_namespace_safe(req_rm_ns, context)
self.logger.info(
f"Removed namespace {namespace['nsid']} from {subsystem['nqn']}: {ret.status}")
except Exception as ex:
Expand Down Expand Up @@ -191,7 +222,11 @@ def delete_bdev(self, request, context=None):

return pb2.req_status(status=ret)

def create_subsystem(self, request, context=None):
def delete_bdev(self, request, context=None):
with self.rpc_lock:
return self.delete_bdev_safe(request, context)

def create_subsystem_safe(self, request, context=None):
"""Creates a subsystem."""

self.logger.info(
Expand Down Expand Up @@ -233,7 +268,11 @@ def create_subsystem(self, request, context=None):

return pb2.req_status(status=ret)

def delete_subsystem(self, request, context=None):
def create_subsystem(self, request, context=None):
with self.rpc_lock:
return self.create_subsystem_safe(request, context)

def delete_subsystem_safe(self, request, context=None):
"""Deletes a subsystem."""

self.logger.info(
Expand Down Expand Up @@ -262,7 +301,11 @@ def delete_subsystem(self, request, context=None):

return pb2.req_status(status=ret)

def add_namespace(self, request, context=None):
def delete_subsystem(self, request, context=None):
with self.rpc_lock:
return self.delete_subsystem_safe(request, context)

def add_namespace_safe(self, request, context=None):
"""Adds a namespace to a subsystem."""

self.logger.info(f"Received request to add {request.bdev_name} to"
Expand Down Expand Up @@ -298,7 +341,11 @@ def add_namespace(self, request, context=None):

return pb2.nsid(nsid=nsid, status=True)

def remove_namespace(self, request, context=None):
def add_namespace(self, request, context=None):
with self.rpc_lock:
return self.add_namespace_safe(request, context)

def remove_namespace_safe(self, request, context=None):
"""Removes a namespace from a subsystem."""

self.logger.info(f"Received request to remove {request.nsid} from"
Expand Down Expand Up @@ -329,7 +376,11 @@ def remove_namespace(self, request, context=None):

return pb2.req_status(status=ret)

def add_host(self, request, context=None):
def remove_namespace(self, request, context=None):
with self.rpc_lock:
return self.remove_namespace_safe(request, context)

def add_host_safe(self, request, context=None):
"""Adds a host to a subsystem."""

try:
Expand Down Expand Up @@ -373,7 +424,11 @@ def add_host(self, request, context=None):

return pb2.req_status(status=ret)

def remove_host(self, request, context=None):
def add_host(self, request, context=None):
with self.rpc_lock:
return self.add_host_safe(request, context)

def remove_host_safe(self, request, context=None):
"""Removes a host from a subsystem."""

try:
Expand Down Expand Up @@ -415,7 +470,11 @@ def remove_host(self, request, context=None):

return pb2.req_status(status=ret)

def create_listener(self, request, context=None):
def remove_host(self, request, context=None):
with self.rpc_lock:
return self.remove_host_safe(request, context)

def create_listener_safe(self, request, context=None):
"""Creates a listener for a subsystem at a given IP/Port."""

ret = True
Expand Down Expand Up @@ -459,7 +518,11 @@ def create_listener(self, request, context=None):

return pb2.req_status(status=ret)

def delete_listener(self, request, context=None):
def create_listener(self, request, context=None):
with self.rpc_lock:
return self.create_listener_safe(request, context)

def delete_listener_safe(self, request, context=None):
"""Deletes a listener from a subsystem at a given IP/Port."""

ret = True
Expand Down Expand Up @@ -502,7 +565,11 @@ def delete_listener(self, request, context=None):

return pb2.req_status(status=ret)

def get_subsystems(self, request, context):
def delete_listener(self, request, context=None):
with self.rpc_lock:
return self.delete_listener_safe(request, context)

def get_subsystems_safe(self, request, context):
"""Gets subsystems."""

self.logger.info(f"Received request to get subsystems")
Expand All @@ -516,3 +583,7 @@ def get_subsystems(self, request, context):
return pb2.subsystems_info()

return pb2.subsystems_info(subsystems=json.dumps(ret))

def get_subsystems(self, request, context):
with self.rpc_lock:
return self.get_subsystems_safe(request, context)
49 changes: 49 additions & 0 deletions tests/test_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import time
from control.server import GatewayServer
from control.cli import main as cli
import logging
import warnings

# Set up a logger
logger = logging.getLogger(__name__)
image = "mytestdevimage"
pool = "rbd"
bdev_prefix = "Ceph0"
subsystem_prefix = "nqn.2016-06.io.spdk:cnode"
created_resource_count = 300
get_subsys_count = 100

def create_resource_by_index(i):
bdev = f"{bdev_prefix}_{i}"
cli(["create_bdev", "-i", image, "-p", pool, "-b", bdev])
subsystem = f"{subsystem_prefix}{i}"
cli(["create_subsystem", "-n", subsystem ])
cli(["add_namespace", "-n", subsystem, "-b", bdev])

@pytest.mark.filterwarnings("error::pytest.PytestUnhandledThreadExceptionWarning")
def test_create_get_subsys(caplog, config):
with GatewayServer(config) as gateway:
time.sleep(1)
gateway.serve()

for i in range(created_resource_count):
create_resource_by_index(i)
assert "Failed" not in caplog.text

gateway.server.stop(grace=1)

time.sleep(2)
caplog.clear()

# restart the gateway here
with GatewayServer(config) as gateway:
time.sleep(1)
gateway.serve()

for i in range(get_subsys_count):
cli(["get_subsystems"])
assert "Exception" not in caplog.text
time.sleep(1)

time.sleep(10)

0 comments on commit 052374e

Please sign in to comment.