Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/NatsDistributedCache/INatsCacheKeyEncoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
namespace CodeCargo.Nats.DistributedCache;

/// <summary>
/// Encodes raw strings so they satisfy the NATS KV key rules.
/// </summary>
public interface INatsCacheKeyEncoder
{
/// <summary>
/// Encodes a raw string into a KV-legal key
/// </summary>
/// <param name="raw">The raw string to encode</param>
/// <returns>A KV-legal key</returns>
string Encode(string raw);
}
12 changes: 3 additions & 9 deletions src/NatsDistributedCache/NatsCache.Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,15 @@ namespace CodeCargo.Nats.DistributedCache;

public partial class NatsCache
{
private void LogException(Exception exception) =>
_logger.LogError(EventIds.Exception, exception, "Exception in NatsDistributedCache");

private void LogConnected(string bucketName) =>
_logger.LogInformation(EventIds.Connected, "Connected to NATS KV bucket {bucketName}", bucketName);

private void LogUpdateFailed(string key) => _logger.LogDebug(
EventIds.UpdateFailed,
"Sliding expiration update failed for key {Key} due to optimistic concurrency control",
key);
private void LogException(Exception exception) =>
_logger.LogError(EventIds.Exception, exception, "Exception in NatsDistributedCache");

private static class EventIds
{
public static readonly EventId Connected = new(100, nameof(Connected));
public static readonly EventId UpdateFailed = new(101, nameof(UpdateFailed));
public static readonly EventId Exception = new(102, nameof(Exception));
public static readonly EventId Exception = new(101, nameof(Exception));
}
}
115 changes: 45 additions & 70 deletions src/NatsDistributedCache/NatsCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,17 @@ public partial class NatsCache : IBufferDistributedCache
new(CacheEntryJsonContext.Default);

private readonly string _bucketName;
private readonly INatsCacheKeyEncoder _keyEncoder;
private readonly string _keyPrefix;
private readonly ILogger _logger;
private readonly INatsConnection _natsConnection;
private Lazy<Task<INatsKVStore>> _lazyKvStore;

public NatsCache(
IOptions<NatsCacheOptions> optionsAccessor,
ILogger<NatsCache> logger,
INatsConnection natsConnection)
INatsConnection natsConnection,
ILogger<NatsCache>? logger = null,
INatsCacheKeyEncoder? keyEncoder = null)
{
var options = optionsAccessor.Value;
_bucketName = !string.IsNullOrWhiteSpace(options.BucketName)
Expand All @@ -61,13 +63,9 @@ public NatsCache(
? string.Empty
: options.CacheKeyPrefix.TrimEnd('.');
_lazyKvStore = CreateLazyKvStore();
_logger = logger;
_natsConnection = natsConnection;
}

public NatsCache(IOptions<NatsCacheOptions> optionsAccessor, INatsConnection natsConnection)
: this(optionsAccessor, NullLogger<NatsCache>.Instance, natsConnection)
{
_logger = logger ?? NullLogger<NatsCache>.Instance;
_keyEncoder = keyEncoder ?? new NatsCacheKeyEncoder();
}

/// <inheritdoc />
Expand All @@ -88,7 +86,8 @@ public async Task SetAsync(
try
{
// todo: remove cast after https://github.com/nats-io/nats.net/pull/852 is released
await ((NatsKVStore)kvStore).PutAsync(GetPrefixedKey(key), entry, ttl ?? TimeSpan.Zero, CacheEntrySerializer, token)
await ((NatsKVStore)kvStore)
.PutAsync(GetEncodedKey(key), entry, ttl ?? TimeSpan.Zero, CacheEntrySerializer, token)
.ConfigureAwait(false);
}
catch (Exception ex)
Expand Down Expand Up @@ -125,14 +124,14 @@ public async Task RemoveAsync(string key, CancellationToken token = default) =>

/// <inheritdoc />
public Task RefreshAsync(string key, CancellationToken token = default) =>
GetAndRefreshAsync(key, getData: false, retry: true, token: token);
GetAndRefreshAsync(key, token: token);

/// <inheritdoc />
public byte[]? Get(string key) => GetAsync(key).GetAwaiter().GetResult();

/// <inheritdoc />
public Task<byte[]?> GetAsync(string key, CancellationToken token = default) =>
GetAndRefreshAsync(key, getData: true, retry: true, token: token);
GetAndRefreshAsync(key, token: token);

/// <inheritdoc />
public bool TryGet(string key, IBufferWriter<byte> destination) =>
Expand Down Expand Up @@ -161,9 +160,6 @@ public async ValueTask<bool> TryGetAsync(
return false;
}

// This is the method used by hybrid caching to determine if it should use the distributed instance
internal virtual bool IsHybridCacheActive() => false;

private static TimeSpan? GetTtl(DistributedCacheEntryOptions options)
{
if (options.AbsoluteExpiration.HasValue && options.AbsoluteExpiration.Value <= DateTimeOffset.Now)
Expand Down Expand Up @@ -233,9 +229,10 @@ private static CacheEntry CreateCacheEntry(byte[] value, DistributedCacheEntryOp
return cacheEntry;
}

private string GetPrefixedKey(string key) => string.IsNullOrEmpty(_keyPrefix)
? key
: _keyPrefix + "." + key;
private string GetEncodedKey(string key) =>
string.IsNullOrEmpty(_keyPrefix)
? _keyEncoder.Encode(key)
: _keyEncoder.Encode($"{_keyPrefix}.{key}");

private Lazy<Task<INatsKVStore>> CreateLazyKvStore() =>
new(async () =>
Expand All @@ -259,18 +256,14 @@ private Lazy<Task<INatsKVStore>> CreateLazyKvStore() =>

private Task<INatsKVStore> GetKvStore() => _lazyKvStore.Value;

private async Task<byte[]?> GetAndRefreshAsync(
string key,
bool getData,
bool retry,
CancellationToken token = default)
private async Task<byte[]?> GetAndRefreshAsync(string key, CancellationToken token)
{
var encodedKey = GetEncodedKey(key);
var kvStore = await GetKvStore().ConfigureAwait(false);
var prefixedKey = GetPrefixedKey(key);
try
{
var natsResult = await kvStore
.TryGetEntryAsync(prefixedKey, serializer: CacheEntrySerializer, cancellationToken: token)
.TryGetEntryAsync(encodedKey, serializer: CacheEntrySerializer, cancellationToken: token)
.ConfigureAwait(false);
if (!natsResult.Success)
{
Expand All @@ -292,83 +285,65 @@ private Lazy<Task<INatsKVStore>> CreateLazyKvStore() =>
return null;
}

await UpdateEntryExpirationAsync(kvStore, prefixedKey, kvEntry, token).ConfigureAwait(false);
return getData ? kvEntry.Value.Data : null;
await UpdateEntryExpirationAsync(kvEntry).ConfigureAwait(false);
return kvEntry.Value.Data;
}
catch (NatsKVWrongLastRevisionException ex)
catch (NatsKVWrongLastRevisionException)
{
// Optimistic concurrency control failed, someone else updated it
LogUpdateFailed(key);
if (retry)
{
return await GetAndRefreshAsync(key, getData, retry: false, token).ConfigureAwait(false);
}

LogException(ex);
// Someone else updated it; that's fine, we'll get the latest version next time
return null;
}
catch (Exception ex)
{
LogException(ex);
throw;
}
}

private async Task UpdateEntryExpirationAsync(
INatsKVStore kvStore,
string key,
NatsKVEntry<CacheEntry> kvEntry,
CancellationToken token)
{
if (kvEntry.Value?.SlidingExpirationTicks == null)
// Local Functions
async Task UpdateEntryExpirationAsync(NatsKVEntry<CacheEntry> kvEntry)
{
return;
}
if (kvEntry.Value?.SlidingExpirationTicks == null)
{
return;
}

// If we have a sliding expiration, use it as the TTL
var ttl = TimeSpan.FromTicks(kvEntry.Value.SlidingExpirationTicks.Value);
// If we have a sliding expiration, use it as the TTL
var ttl = TimeSpan.FromTicks(kvEntry.Value.SlidingExpirationTicks.Value);

// If we also have an absolute expiration, make sure we don't exceed it
if (kvEntry.Value.AbsoluteExpiration != null)
{
var remainingTime = kvEntry.Value.AbsoluteExpiration.Value - DateTimeOffset.Now;

// Use the minimum of sliding window or remaining absolute time
if (remainingTime > TimeSpan.Zero && remainingTime < ttl)
// If we also have an absolute expiration, make sure we don't exceed it
if (kvEntry.Value.AbsoluteExpiration != null)
{
ttl = remainingTime;
var remainingTime = kvEntry.Value.AbsoluteExpiration.Value - DateTimeOffset.Now;

// Use the minimum of sliding window or remaining absolute time
if (remainingTime > TimeSpan.Zero && remainingTime < ttl)
{
ttl = remainingTime;
}
}
}

if (ttl > TimeSpan.Zero)
{
// Use optimistic concurrency control with the last revision
try
if (ttl > TimeSpan.Zero)
{
// Use optimistic concurrency control with the last revision
// todo: remove cast after https://github.com/nats-io/nats.net/pull/852 is released
await ((NatsKVStore)kvStore).UpdateAsync(
key,
encodedKey,
kvEntry.Value,
kvEntry.Revision,
ttl,
serializer: CacheEntrySerializer,
cancellationToken: token).ConfigureAwait(false);
}
catch (NatsKVWrongLastRevisionException)
{
// Someone else updated it; that's fine, we'll get the latest version next time
LogUpdateFailed(key.Replace(GetPrefixedKey(string.Empty), string.Empty));
}
}
}

private async Task RemoveAsync(
string key,
NatsKVDeleteOpts? natsKvDeleteOpts = null,
CancellationToken token = default)
string key,
NatsKVDeleteOpts? natsKvDeleteOpts = null,
CancellationToken token = default)
{
var kvStore = await GetKvStore().ConfigureAwait(false);
await kvStore.DeleteAsync(GetPrefixedKey(key), natsKvDeleteOpts, cancellationToken: token)
await kvStore.DeleteAsync(GetEncodedKey(key), natsKvDeleteOpts, cancellationToken: token)
.ConfigureAwait(false);
}
}
70 changes: 70 additions & 0 deletions src/NatsDistributedCache/NatsCacheKeyEncoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using System.Text.RegularExpressions;

namespace CodeCargo.Nats.DistributedCache;

/// <summary>
/// URL-encoding implementation that keeps already allowed keys
/// untouched and URL-encodes everything else. % characters in the
/// final encoded output are replaced with = characters to conform
/// to the NATS KV key rules.
/// </summary>
public sealed partial class NatsCacheKeyEncoder : INatsCacheKeyEncoder
{
/// <inheritdoc />
public string Encode(string raw)
{
if (string.IsNullOrEmpty(raw))
{
throw new ArgumentException("Key must not be null or empty.", nameof(raw));
}

if (ValidUnencodedKey(raw))
{
// already valid
return raw;
}

var encoded = Uri.EscapeDataString(raw);
encoded = encoded.Replace("~", "%7E");
if (encoded.StartsWith('.'))
{
encoded = "%2E" + encoded[1..];
}

if (encoded.EndsWith('.'))
{
encoded = encoded[..^1] + "%2E";
}

encoded = encoded.Replace('%', '=');
return encoded;
}

/// <inheritdoc />
public string Decode(string key)
{
if (string.IsNullOrEmpty(key))
{
throw new ArgumentException("Key must not be null or empty.", nameof(key));
}

if (!key.Contains('='))
{
// nothing to decode
return key;
}

var decoded = key.Replace('=', '%');
return Uri.UnescapeDataString(decoded);
}

private static bool ValidUnencodedKey(string rawKey) =>
!rawKey.StartsWith('.')
&& !rawKey.EndsWith('.')
&& ValidUnencodedKeyRegex().IsMatch(rawKey);

// Regex pattern to match valid NATS KV keys with = removed, since =
// is used instead of % to mark an encoded character sequence
[GeneratedRegex("^[-_.A-Za-z0-9]+$", RegexOptions.Compiled)]
private static partial Regex ValidUnencodedKeyRegex();
}
8 changes: 3 additions & 5 deletions src/NatsDistributedCache/NatsDistributedCacheExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@ public static IServiceCollection AddNatsDistributedCache(
services.AddSingleton<IDistributedCache>(sp =>
{
var optionsAccessor = sp.GetRequiredService<IOptions<NatsCacheOptions>>();
var logger = sp.GetService<ILogger<NatsCache>>();

var natsConnection = connectionServiceKey == null
? sp.GetRequiredService<INatsConnection>()
: sp.GetRequiredKeyedService<INatsConnection>(connectionServiceKey);
var logger = sp.GetService<ILogger<NatsCache>>();
var keyEncoder = sp.GetService<INatsCacheKeyEncoder>();

return logger != null
? new NatsCache(optionsAccessor, logger, natsConnection)
: new NatsCache(optionsAccessor, natsConnection);
return new NatsCache(optionsAccessor, natsConnection, logger: logger, keyEncoder: keyEncoder);
});

return services;
Expand Down
Loading