Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loading module from command line #679

Merged
merged 9 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion libs/host/Configuration/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,10 @@ internal sealed class Options
[Option("extension-bin-paths", Separator = ',', Required = false, HelpText = "List of directories on server from which custom command binaries can be loaded by admin users")]
public IEnumerable<string> ExtensionBinPaths { get; set; }

[ModuleFilePathValidation(true, true, false)]
[Option("loadmodulecs", Separator = ',', Required = false, HelpText = "List of modules to be loaded")]
public IEnumerable<string> LoadModuleCS { get; set; }

[Option("extension-allow-unsigned", Required = false, HelpText = "Allow loading custom commands from digitally unsigned assemblies (not recommended)")]
public bool? ExtensionAllowUnsignedAssemblies { get; set; }

Expand Down Expand Up @@ -653,7 +657,8 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null)
ExtensionBinPaths = ExtensionBinPaths?.ToArray(),
ExtensionAllowUnsignedAssemblies = ExtensionAllowUnsignedAssemblies.GetValueOrDefault(),
IndexResizeFrequencySecs = IndexResizeFrequencySecs,
IndexResizeThreshold = IndexResizeThreshold
IndexResizeThreshold = IndexResizeThreshold,
LoadModuleCS = LoadModuleCS
};
}

Expand Down
35 changes: 35 additions & 0 deletions libs/host/Configuration/OptionsValidators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,41 @@ protected override ValidationResult IsValid(object value, ValidationContext vali
}
}

[AttributeUsage(AttributeTargets.Property)]
internal class ModuleFilePathValidationAttribute : FilePathValidationAttribute
{
internal ModuleFilePathValidationAttribute(bool fileMustExist, bool directoryMustExist, bool isRequired, string[] acceptedFileExtensions = null) : base(fileMustExist, directoryMustExist, isRequired, acceptedFileExtensions)
{
}

protected override ValidationResult IsValid(object value, ValidationContext validationContext)
{
if (TryInitialValidation<IEnumerable<string>>(value, validationContext, out var initValidationResult, out var filePaths))
return initValidationResult;

var errorSb = new StringBuilder();
var isValid = true;
foreach (var filePathArg in filePaths)
{
var filePath = filePathArg.Split(' ')[0];
var result = base.IsValid(filePath, validationContext);
if (result != null && result != ValidationResult.Success)
{
isValid = false;
errorSb.AppendLine(result.ErrorMessage);
}
}

if (!isValid)
{
var errorMessage = $"Error(s) validating one or more file paths:{Environment.NewLine}{errorSb}";
return new ValidationResult(errorMessage, [validationContext.MemberName]);
}

return ValidationResult.Success;
}
}

/// <summary>
/// Validation logic for a string representing an IP address (either IPv4 or IPv6)
/// </summary>
Expand Down
28 changes: 28 additions & 0 deletions libs/host/GarnetServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using Garnet.cluster;
using Garnet.common;
Expand Down Expand Up @@ -211,6 +213,32 @@ private void InitializeServer()
Store = new StoreApi(storeWrapper);

server.Register(WireFormat.ASCII, Provider);

LoadModules(customCommandManager);
}

private void LoadModules(CustomCommandManager customCommandManager)
{
if (opts.LoadModuleCS == null)
return;

foreach (var moduleCS in opts.LoadModuleCS)
{
var moduleCSData = moduleCS.Split(' ', StringSplitOptions.RemoveEmptyEntries);
if (moduleCSData.Length < 1)
continue;

var modulePath = moduleCSData[0];
var moduleArgs = moduleCSData.Length > 1 ? moduleCSData.Skip(1).ToArray() : [];
if (ModuleUtils.LoadAssemblies([modulePath], null, true, out var loadedAssemblies, out var errorMsg))
{
ModuleRegistrar.Instance.LoadModule(customCommandManager, loadedAssemblies.ToList()[0], moduleArgs, logger, out errorMsg);
}
else
{
logger?.LogError("Module {0} failed to load with error {1}", modulePath, Encoding.UTF8.GetString(errorMsg));
}
}
}

private void CreateMainStore(IClusterFactory clusterFactory, out string checkpointDir)
Expand Down
3 changes: 3 additions & 0 deletions libs/host/defaults.conf
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,7 @@

/* Overflow bucket count over total index size in percentage to trigger index resize */
"IndexResizeThreshold": 50,

/* List of module paths to be loaded at startup */
"LoadModuleCS": null
}
8 changes: 4 additions & 4 deletions libs/server/Module/ModuleRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using System.Reflection;
using Microsoft.Extensions.Logging;

namespace Garnet.server.Module
namespace Garnet.server
{
/// <summary>
/// Abstract base class that all Garnet modules must inherit from.
Expand Down Expand Up @@ -171,11 +171,11 @@ public ModuleActionStatus RegisterProcedure(string name, CustomProcedure customS
}
}

internal sealed class ModuleRegistrar
public sealed class ModuleRegistrar
{
private static readonly Lazy<ModuleRegistrar> lazy = new Lazy<ModuleRegistrar>(() => new ModuleRegistrar());

internal static ModuleRegistrar Instance { get { return lazy.Value; } }
public static ModuleRegistrar Instance { get { return lazy.Value; } }

private ModuleRegistrar()
{
Expand All @@ -184,7 +184,7 @@ private ModuleRegistrar()

private readonly ConcurrentDictionary<string, ModuleLoadContext> modules;

internal bool LoadModule(CustomCommandManager customCommandManager, Assembly loadedAssembly, string[] moduleArgs, ILogger logger, out ReadOnlySpan<byte> errorMessage)
public bool LoadModule(CustomCommandManager customCommandManager, Assembly loadedAssembly, string[] moduleArgs, ILogger logger, out ReadOnlySpan<byte> errorMessage)
{
errorMessage = default;

Expand Down
82 changes: 82 additions & 0 deletions libs/server/Module/ModuleUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using Garnet.common;

namespace Garnet.server
{
public class ModuleUtils
{
public static bool LoadAssemblies(
IEnumerable<string> binaryPaths,
string[] allowedExtensionPaths,
bool allowUnsignedAssemblies,
out IEnumerable<Assembly> loadedAssemblies,
out ReadOnlySpan<byte> errorMessage)
{
loadedAssemblies = null;
errorMessage = default;

// Get all binary file paths from inputs binary paths
if (!FileUtils.TryGetFiles(binaryPaths, out var files, out _, [".dll", ".exe"], SearchOption.AllDirectories))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_GETTING_BINARY_FILES;
return false;
}

// Check that all binary files are contained in allowed binary paths
var binaryFiles = files.ToArray();
if (allowedExtensionPaths != null)
{
if (binaryFiles.Any(f =>
allowedExtensionPaths.All(p => !FileUtils.IsFileInDirectory(f, p))))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_BINARY_FILES_NOT_IN_ALLOWED_PATHS;
return false;
}
}

// If necessary, check that all assemblies are digitally signed
if (!allowUnsignedAssemblies)
{
foreach (var filePath in files)
{
using var fs = File.OpenRead(filePath);
using var peReader = new PEReader(fs);

var metadataReader = peReader.GetMetadataReader();
var assemblyPublicKeyHandle = metadataReader.GetAssemblyDefinition().PublicKey;

if (assemblyPublicKeyHandle.IsNil)
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_ASSEMBLY_NOT_SIGNED;
return false;
}

var publicKeyBytes = metadataReader.GetBlobBytes(assemblyPublicKeyHandle);
if (publicKeyBytes == null || publicKeyBytes.Length == 0)
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_ASSEMBLY_NOT_SIGNED;
badrishc marked this conversation as resolved.
Show resolved Hide resolved
return false;
}
}
}

// Get all assemblies from binary files
if (!FileUtils.TryLoadAssemblies(binaryFiles, out loadedAssemblies, out _))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_LOADING_ASSEMBLIES;
return false;
}

return true;
}
}
}
54 changes: 4 additions & 50 deletions libs/server/Resp/AdminCommands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using Garnet.common;
using Garnet.server.Custom;
using Garnet.server.Module;

namespace Garnet.server
{
Expand Down Expand Up @@ -136,52 +134,6 @@ private bool NetworkMonitor()
return true;
}

private bool LoadAssemblies(IEnumerable<string> binaryPaths, out IEnumerable<Assembly> loadedAssemblies, out ReadOnlySpan<byte> errorMessage)
{
loadedAssemblies = null;
errorMessage = default;

// Get all binary file paths from inputs binary paths
if (!FileUtils.TryGetFiles(binaryPaths, out var files, out _, [".dll", ".exe"],
SearchOption.AllDirectories))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_GETTING_BINARY_FILES;
return false;
}

// Check that all binary files are contained in allowed binary paths
var binaryFiles = files.ToArray();
if (binaryFiles.Any(f =>
storeWrapper.serverOptions.ExtensionBinPaths.All(p => !FileUtils.IsFileInDirectory(f, p))))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_BINARY_FILES_NOT_IN_ALLOWED_PATHS;
return false;
}

// Get all assemblies from binary files
if (!FileUtils.TryLoadAssemblies(binaryFiles, out loadedAssemblies, out _))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_LOADING_ASSEMBLIES;
return false;
}

// If necessary, check that all assemblies are digitally signed
if (!storeWrapper.serverOptions.ExtensionAllowUnsignedAssemblies)
{
foreach (var loadedAssembly in loadedAssemblies)
{
var publicKey = loadedAssembly.GetName().GetPublicKey();
if (publicKey == null || publicKey.Length == 0)
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_ASSEMBLY_NOT_SIGNED;
return false;
}
}
}

return true;
}

/// <summary>
/// Register all custom commands / transactions
/// </summary>
Expand Down Expand Up @@ -231,7 +183,8 @@ private bool TryRegisterCustomCommands(
}
}

if (!LoadAssemblies(binaryPaths, out var loadedAssemblies, out errorMessage))
if (!ModuleUtils.LoadAssemblies(binaryPaths, storeWrapper.serverOptions.ExtensionBinPaths,
storeWrapper.serverOptions.ExtensionAllowUnsignedAssemblies, out var loadedAssemblies, out errorMessage))
return false;

foreach (var c in classNameToRegisterArgs.Keys)
Expand Down Expand Up @@ -488,7 +441,8 @@ private bool NetworkModuleLoad(CustomCommandManager customCommandManager)
for (var i = 0; i < moduleArgs.Length; i++)
moduleArgs[i] = parseState.GetArgSliceByRef(i + 1).ToString();

if (LoadAssemblies([modulePath], out var loadedAssemblies, out var errorMsg))
if (ModuleUtils.LoadAssemblies([modulePath], storeWrapper.serverOptions.ExtensionBinPaths,
storeWrapper.serverOptions.ExtensionAllowUnsignedAssemblies, out var loadedAssemblies, out var errorMsg))
{
Debug.Assert(loadedAssemblies != null);
var assembliesList = loadedAssemblies.ToList();
Expand Down
3 changes: 3 additions & 0 deletions libs/server/Servers/GarnetServerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.IO;
using Garnet.server.Auth.Settings;
using Garnet.server.TLS;
Expand Down Expand Up @@ -368,6 +369,8 @@ public class GarnetServerOptions : ServerOptions
/// </summary>
public bool ExtensionAllowUnsignedAssemblies;

public IEnumerable<string> LoadModuleCS;

/// <summary>
/// Constructor
/// </summary>
Expand Down
1 change: 0 additions & 1 deletion playground/GarnetJSON/Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license.

using Garnet.server;
using Garnet.server.Module;
using Microsoft.Extensions.Logging;

namespace GarnetJSON
Expand Down
1 change: 0 additions & 1 deletion playground/SampleModule/SampleModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using Garnet;
using Garnet.server;
using Garnet.server.Module;
using Microsoft.Extensions.Logging;

namespace SampleModule
Expand Down
4 changes: 3 additions & 1 deletion test/Garnet.test/GarnetServerConfigTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void ImportExportConfigLocal()
// No import path, include command line args, export to file
// Check values from command line override values from defaults.conf
static string GetFullExtensionBinPath(string testProjectName) => Path.GetFullPath(testProjectName, TestUtils.RootTestsProjectPath);
var args = new string[] { "--config-export-path", configPath, "-p", "4m", "-m", "128m", "-s", "2g", "--recover", "--port", "53", "--reviv-obj-bin-record-count", "2", "--reviv-fraction", "0.5", "--extension-bin-paths", $"{GetFullExtensionBinPath("Garnet.test")},{GetFullExtensionBinPath("Garnet.test.cluster")}" };
var args = new string[] { "--config-export-path", configPath, "-p", "4m", "-m", "128m", "-s", "2g", "--recover", "--port", "53", "--reviv-obj-bin-record-count", "2", "--reviv-fraction", "0.5", "--extension-bin-paths", $"{GetFullExtensionBinPath("Garnet.test")},{GetFullExtensionBinPath("Garnet.test.cluster")}", "--loadmodulecs", $"{Assembly.GetExecutingAssembly().Location}" };
parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out options, out invalidOptions);
ClassicAssert.IsTrue(parseSuccessful);
ClassicAssert.AreEqual(invalidOptions.Count, 0);
Expand All @@ -95,6 +95,8 @@ public void ImportExportConfigLocal()
ClassicAssert.IsTrue(options.Recover);
ClassicAssert.IsTrue(File.Exists(configPath));
ClassicAssert.AreEqual(2, options.ExtensionBinPaths.Count());
ClassicAssert.AreEqual(1, options.LoadModuleCS.Count());
ClassicAssert.AreEqual(Assembly.GetExecutingAssembly().Location, options.LoadModuleCS.First());

// Import from previous export command, no command line args
// Check values from import path override values from default.conf
Expand Down
Loading
Loading