diff --git a/src/Redis.OM/RedisCommands.cs b/src/Redis.OM/RedisCommands.cs index 96e9ea2a..2f10a2f4 100644 --- a/src/Redis.OM/RedisCommands.cs +++ b/src/Redis.OM/RedisCommands.cs @@ -774,8 +774,9 @@ public static async Task> HGetAllAsync(this IRed /// The key. /// The value. /// The storage type of the value. + /// The ttl for the key. /// The type of the value. - internal static void UnlinkAndSet(this IRedisConnection connection, string key, T value, StorageType storageType) + internal static void UnlinkAndSet(this IRedisConnection connection, string key, T value, StorageType storageType, TimeSpan? ttl) { _ = value ?? throw new ArgumentNullException(nameof(value)); if (storageType == StorageType.Json) @@ -791,6 +792,11 @@ internal static void UnlinkAndSet(this IRedisConnection connection, string ke { args.Add(pair.Key); args.Add(pair.Value); + if (ttl is not null) + { + args.Add("EXPIRE"); + args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture)); + } } connection.CreateAndEval(nameof(Scripts.UnlinkAndSetHash), new[] { key }, args.ToArray()); @@ -804,9 +810,10 @@ internal static void UnlinkAndSet(this IRedisConnection connection, string ke /// The key. /// The value. /// The storage type of the value. + /// The time to live for the key. /// The type of the value. /// A representing the asynchronous operation. - internal static async Task UnlinkAndSetAsync(this IRedisConnection connection, string key, T value, StorageType storageType) + internal static async Task UnlinkAndSetAsync(this IRedisConnection connection, string key, T value, StorageType storageType, TimeSpan? ttl) { _ = value ?? throw new ArgumentNullException(nameof(value)); if (storageType == StorageType.Json) @@ -822,6 +829,11 @@ internal static async Task UnlinkAndSetAsync(this IRedisConnection connection { args.Add(pair.Key); args.Add(pair.Value); + if (ttl is not null) + { + args.Add("EXPIRE"); + args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture)); + } } await connection.CreateAndEvalAsync(nameof(Scripts.UnlinkAndSetHash), new[] { key }, args.ToArray()); diff --git a/src/Redis.OM/Scripts.cs b/src/Redis.OM/Scripts.cs index 29813100..8bddf786 100644 --- a/src/Redis.OM/Scripts.cs +++ b/src/Redis.OM/Scripts.cs @@ -19,6 +19,8 @@ internal class Scripts if index>=0 then redis.call('JSON.ARRPOP', key, ARGV[i+1], index) end + elseif 'EXPIRE' == ARGV[i] then + redis.call('PEXPIRE', key, tonumber(ARGV[i+1])) else if 'DEL' == ARGV[i] then redis.call('JSON.DEL',key,ARGV[i+1]) @@ -38,6 +40,7 @@ internal class Scripts local num_fields_to_set = ARGV[1] local end_index = num_fields_to_set*2+1 local args = {} +local expire_time = -1 for i=2, end_index, 2 do args[i-1] = ARGV[i] args[i] = ARGV[i+1] @@ -49,9 +52,19 @@ internal class Scripts local second_op args = {} for i = end_index+1, num_args, 1 do - args[i-end_index] = ARGV[i] + if ARGV[i] == 'EXPIRE' then + expire_time = tonumber(ARGV[i+1]) + else + args[i-end_index] = ARGV[i] + end + end + + if table.getn(args) > 0 then + redis.call('HDEL',key,unpack(args)) end - redis.call('HDEL',key,unpack(args)) +end +if expire_time > -1 then + redis.call('PEXPIRE', key, expire_time) end "; @@ -69,11 +82,19 @@ local second_op local num_fields = ARGV[1] local end_index = num_fields * 2 + 1 local args = {} +local expire_time = -1 for i = 2, end_index, 2 do - args[i-1] = ARGV[i] - args[i] = ARGV[i+1] + if ARGV[i] == 'EXPIRE' then + expire_time = tonumber(ARGV[i+1]) + else + args[i-1] = ARGV[i] + args[i] = ARGV[i+1] + end end redis.call('HSET',KEYS[1],unpack(args)) +if expire_time > -1 then + redis.call('PEXPIRE', KEYS[1], expire_time) +end return 0 "; @@ -81,7 +102,14 @@ local second_op /// Unlinks a JSON object and sets the key again with a fresh new JSON object. /// internal const string UnlinkAndSendJson = @" -local expiry = tonumber(redis.call('PTTL', KEYS[1])) +local num_args = table.getn(ARGV) +local expiry = -1 +if num_args > 1 and 'EXPIRE' == ARGV[2] then + expiry = tonumber(ARGV[3]) +else + expiry = tonumber(redis.call('PTTL', KEYS[1])) +end + redis.call('UNLINK', KEYS[1]) redis.call('JSON.SET', KEYS[1], '.', ARGV[1]) if expiry > 0 then diff --git a/src/Redis.OM/Searching/IRedisCollection.cs b/src/Redis.OM/Searching/IRedisCollection.cs index 2fc123f9..d72388f1 100644 --- a/src/Redis.OM/Searching/IRedisCollection.cs +++ b/src/Redis.OM/Searching/IRedisCollection.cs @@ -158,6 +158,29 @@ public interface IRedisCollection : IOrderedQueryable, IAsyncEnumerable /// A representing the asynchronous operation. ValueTask UpdateAsync(IEnumerable items); + /// + /// Updates the provided item in Redis. Document must have a property marked with the . + /// + /// The item to update. + /// The updated ttl for the record. + void Update(T item, TimeSpan ttl); + + /// + /// Updates the provided item in Redis. Document must have a property marked with the . + /// + /// The item to update. + /// The updated ttl for the record. + /// A representing the asynchronous operation. + Task UpdateAsync(T item, TimeSpan ttl); + + /// + /// Updates the provided items in Redis. Document must have a property marked with the . + /// + /// The items to update. + /// The updated ttl for the record. + /// A representing the asynchronous operation. + ValueTask UpdateAsync(IEnumerable items, TimeSpan ttl); + /// /// Deletes the item from Redis. /// diff --git a/src/Redis.OM/Searching/RedisCollection.cs b/src/Redis.OM/Searching/RedisCollection.cs index c0a0bdc5..f9469872 100644 --- a/src/Redis.OM/Searching/RedisCollection.cs +++ b/src/Redis.OM/Searching/RedisCollection.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -144,69 +145,37 @@ public bool Any(Expression> expression) /// public void Update(T item) { - var key = item.GetKey(); - IList? diff; - var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff); - if (diffConstructed) - { - if (diff!.Any()) - { - var args = new List(); - var scriptName = diff!.First().Script; - foreach (var update in diff!) - { - args.AddRange(update.SerializeScriptArgs()); - } - - _connection.CreateAndEval(scriptName, new[] { key }, args.ToArray()); - } - } - else - { - _connection.UnlinkAndSet(key, item, StateManager.DocumentAttribute.StorageType); - } - - SaveToStateManager(key, item); + SendUpdate(item); } /// - public async Task UpdateAsync(T item) + public Task UpdateAsync(T item) { - var key = item.GetKey(); - IList? diff; - var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff); - if (diffConstructed) - { - if (diff!.Any()) - { - var args = new List(); - var scriptName = diff!.First().Script; - foreach (var update in diff!) - { - args.AddRange(update.SerializeScriptArgs()); - } + return SendUpdateAsync(item); + } - await _connection.CreateAndEvalAsync(scriptName, new[] { key }, args.ToArray()); - } - } - else - { - await _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType); - } + /// + public ValueTask UpdateAsync(IEnumerable items) + { + return SendUpdateAsync(items); + } - SaveToStateManager(key, item); + /// + public void Update(T item, TimeSpan ttl) + { + SendUpdate(item, ttl); } /// - public async ValueTask UpdateAsync(IEnumerable items) + public Task UpdateAsync(T item, TimeSpan ttl) { - var tasks = items.Select(UpdateAsyncNoSave); + return SendUpdateAsync(item, ttl); + } - await Task.WhenAll(tasks); - foreach (var kvp in tasks.Select(x => x.Result)) - { - SaveToStateManager(kvp.Key, kvp.Value); - } + /// + public ValueTask UpdateAsync(IEnumerable items, TimeSpan ttl) + { + return SendUpdateAsync(items, ttl); } /// @@ -774,7 +743,7 @@ private static MethodInfo GetMethodInfo(Func f, T1 unused) return _connection.GetAsync(key).AsTask(); } - private async Task> UpdateAsyncNoSave(T item) + private async Task> UpdateAsyncNoSave(T item, TimeSpan? ttl) { var key = item.GetKey(); IList? diff; @@ -790,12 +759,22 @@ private async Task> UpdateAsyncNoSave(T item) args.AddRange(update.SerializeScriptArgs()); } + if (ttl is not null) + { + args.Add("EXPIRE"); + args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture)); + } + await _connection.CreateAndEvalAsync(scriptName, new[] { key }, args.ToArray()); } + else if (ttl is not null) + { + await _connection.ExecuteAsync("PEXPIRE", key, ttl.Value.TotalMilliseconds); + } } else { - await _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType); + await _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType, ttl); } return new KeyValuePair(key, item); @@ -831,5 +810,97 @@ private void SaveToStateManager(string key, object value) } } } + + private void SendUpdate(T item, TimeSpan? ttl = null) + { + var key = item.GetKey(); + IList? diff; + var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff); + if (diffConstructed) + { + if (diff!.Any()) + { + var args = new List(); + var scriptName = diff!.First().Script; + foreach (var update in diff!) + { + args.AddRange(update.SerializeScriptArgs()); + } + + if (ttl is not null) + { + args.Add("EXPIRE"); + args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture)); + } + + _connection.CreateAndEval(scriptName, new[] { key }, args.ToArray()); + } + else if (ttl is not null) + { + _connection.Execute("PEXPIRE", key, ttl.Value.TotalMilliseconds); + } + } + else + { + _connection.UnlinkAndSet(key, item, StateManager.DocumentAttribute.StorageType, ttl); + } + + SaveToStateManager(key, item); + } + + private Task SendUpdateAsync(T item, TimeSpan? ttl = null) + { + var key = item.GetKey(); + IList? diff; + var diffConstructed = StateManager.TryDetectDifferencesSingle(key, item, out diff); + Task? task = null; + if (diffConstructed) + { + if (diff!.Any()) + { + var args = new List(); + var scriptName = diff!.First().Script; + foreach (var update in diff!) + { + args.AddRange(update.SerializeScriptArgs()); + } + + if (ttl is not null) + { + args.Add("EXPIRE"); + args.Add(ttl.Value.TotalMilliseconds.ToString(CultureInfo.InvariantCulture)); + } + + task = _connection.CreateAndEvalAsync(scriptName, new[] { key }, args.ToArray()); + } + else if (ttl is not null) + { + task = _connection.ExecuteAsync("PEXPIRE", key, ttl.Value.TotalMilliseconds); + } + } + else + { + task = _connection.UnlinkAndSetAsync(key, item, StateManager.DocumentAttribute.StorageType, ttl); + } + + SaveToStateManager(key, item); + if (task is null) + { + return Task.CompletedTask; + } + + return task; + } + + private async ValueTask SendUpdateAsync(IEnumerable items, TimeSpan? ttl = null) + { + var tasks = items.Select(x => UpdateAsyncNoSave(x, ttl)).ToArray(); + + await Task.WhenAll(tasks); + foreach (var kvp in tasks.Select(x => x.Result)) + { + SaveToStateManager(kvp.Key, kvp.Value); + } + } } } \ No newline at end of file diff --git a/test/Redis.OM.Unit.Tests/RediSearchTests/SearchFunctionalTests.cs b/test/Redis.OM.Unit.Tests/RediSearchTests/SearchFunctionalTests.cs index 3b720a14..9ca8f807 100644 --- a/test/Redis.OM.Unit.Tests/RediSearchTests/SearchFunctionalTests.cs +++ b/test/Redis.OM.Unit.Tests/RediSearchTests/SearchFunctionalTests.cs @@ -11,6 +11,7 @@ using Redis.OM.Modeling; using Redis.OM.Searching; using Redis.OM.Searching.Query; +using StackExchange.Redis; using Xunit; namespace Redis.OM.Unit.Tests.RediSearchTests @@ -349,6 +350,54 @@ public void TestUpdate() Assert.Equal(testP.Id, secondQueriedP.Id); } + [Fact] + public void TestUpdateWithTimeout() + { + var collection = new RedisCollection(_connection); + var testP = new Person { Name = "Steve", Age = 32 }; + TimeSpan ttl = TimeSpan.FromHours(1); + var key = collection.Insert(testP); + var queriedP = collection.FindById(key); + Assert.NotNull(queriedP); + queriedP.Age = 33; + collection.Update(queriedP, ttl); + var ttlOnKey = (long)_connection.Execute("PTTL", key); + var secondQueriedP = collection.FindById(key); + + Assert.NotNull(secondQueriedP); + Assert.InRange(ttlOnKey, ttl.TotalMilliseconds - 2000, ttl.TotalMilliseconds); + Assert.Equal(33, secondQueriedP.Age); + Assert.Equal(secondQueriedP.Id, queriedP.Id); + Assert.Equal(testP.Id, secondQueriedP.Id); + } + + [Fact] + public async Task TestUpdateWithTimeoutMulti() + { + var collection = new RedisCollection(_connection); + var testP = new Person { Name = "Steve", Age = 32 }; + var testP2 = new Person { Name = "Chris", Age = 37 }; + TimeSpan ttl = TimeSpan.FromMinutes(5); + var keys = (await collection.InsertAsync(new Person[] { testP, testP2 }, WhenKey.Always, ttl)).ToArray(); + var queriedP = await collection.FindByIdAsync(keys[0]); + var queriedP2 = await collection.FindByIdAsync(keys[1]); + Assert.NotNull(queriedP); + queriedP.Age = 33; + ttl = TimeSpan.FromHours(1); + await collection.UpdateAsync(new []{queriedP, queriedP2}, ttl); + var ttlOnKey1 = (long) await _connection.ExecuteAsync("PTTL", keys[0]); + var ttlOnKey2 = (long) await _connection.ExecuteAsync("PTTL", keys[1]); + var secondQueriedP = await collection.FindByIdAsync(keys[0]); + + Assert.NotNull(secondQueriedP); + Assert.InRange(ttlOnKey1, ttl.TotalMilliseconds - 2000, ttl.TotalMilliseconds); + Assert.InRange(ttlOnKey2, ttl.TotalMilliseconds - 2000, ttl.TotalMilliseconds); + Assert.Equal(33, secondQueriedP.Age); + Assert.Equal(secondQueriedP.Id, queriedP.Id); + Assert.Equal(testP.Id, secondQueriedP.Id); + } + + [Fact] public void TestUpdateNullCollection() { @@ -385,6 +434,29 @@ public async Task TestUpdateAsync() Assert.Equal(secondQueriedP.Id, queriedP.Id); Assert.Equal(testP.Id, secondQueriedP.Id); } + + [Fact] + public async Task TestUpdateWithTtlAsync() + { + var collection = new RedisCollection(_connection); + var testP = new Person { Name = "Steve", Age = 32 }; + var key = await collection.InsertAsync(testP); + var queriedP = await collection.FindByIdAsync(key); + Assert.NotNull(queriedP); + queriedP.Age = 33; + TimeSpan ttl = TimeSpan.FromHours(1); + await collection.UpdateAsync(queriedP, ttl); + + var ttlFromKey = (double) await _connection.ExecuteAsync("PTTL", key); + + var secondQueriedP = await collection.FindByIdAsync(key); + + Assert.InRange(ttlFromKey, ttl.TotalMilliseconds - 2000, ttl.TotalMilliseconds); + Assert.NotNull(secondQueriedP); + Assert.Equal(33, secondQueriedP.Age); + Assert.Equal(secondQueriedP.Id, queriedP.Id); + Assert.Equal(testP.Id, secondQueriedP.Id); + } [Fact] public async Task TestUpdateName() @@ -424,6 +496,28 @@ public async Task TestUpdateHashPerson() Assert.Equal(secondQueriedP.Id, queriedP.Id); Assert.Equal(testP.Id, secondQueriedP.Id); } + + [Fact] + public async Task TestUpdateHashWithTtl() + { + var collection = new RedisCollection(_connection); + var testP = new HashPerson { Name = "Steve", Age = 32 }; + var key = await collection.InsertAsync(testP); + var queriedP = await collection.FindByIdAsync(key); + Assert.NotNull(queriedP); + queriedP.Age = 33; + var ttl = TimeSpan.FromHours(1); + await collection.UpdateAsync(queriedP, ttl); + + var ttlFromKey = (double)await _connection.ExecuteAsync("PTTL", key); + var secondQueriedP = await collection.FindByIdAsync(key); + + Assert.InRange(ttlFromKey, ttl.TotalMilliseconds - 2000, ttl.TotalMilliseconds); + Assert.NotNull(secondQueriedP); + Assert.Equal(33, secondQueriedP.Age); + Assert.Equal(secondQueriedP.Id, queriedP.Id); + Assert.Equal(testP.Id, secondQueriedP.Id); + } [Fact] public async Task TestToListAsync()