diff --git a/src/NatsDistributedCache/INatsCacheKeyEncoder.cs b/src/NatsDistributedCache/INatsCacheKeyEncoder.cs
new file mode 100644
index 0000000..a192014
--- /dev/null
+++ b/src/NatsDistributedCache/INatsCacheKeyEncoder.cs
@@ -0,0 +1,14 @@
+namespace CodeCargo.Nats.DistributedCache;
+
+///
+/// Encodes raw strings so they satisfy the NATS KV key rules.
+///
+public interface INatsCacheKeyEncoder
+{
+ ///
+ /// Encodes a raw string into a KV-legal key
+ ///
+ /// The raw string to encode
+ /// A KV-legal key
+ string Encode(string raw);
+}
diff --git a/src/NatsDistributedCache/NatsCache.Log.cs b/src/NatsDistributedCache/NatsCache.Log.cs
index da23903..0ffe146 100644
--- a/src/NatsDistributedCache/NatsCache.Log.cs
+++ b/src/NatsDistributedCache/NatsCache.Log.cs
@@ -4,21 +4,15 @@ namespace CodeCargo.Nats.DistributedCache;
public partial class NatsCache
{
- private void LogException(Exception exception) =>
- _logger.LogError(EventIds.Exception, exception, "Exception in NatsDistributedCache");
-
private void LogConnected(string bucketName) =>
_logger.LogInformation(EventIds.Connected, "Connected to NATS KV bucket {bucketName}", bucketName);
- private void LogUpdateFailed(string key) => _logger.LogDebug(
- EventIds.UpdateFailed,
- "Sliding expiration update failed for key {Key} due to optimistic concurrency control",
- key);
+ private void LogException(Exception exception) =>
+ _logger.LogError(EventIds.Exception, exception, "Exception in NatsDistributedCache");
private static class EventIds
{
public static readonly EventId Connected = new(100, nameof(Connected));
- public static readonly EventId UpdateFailed = new(101, nameof(UpdateFailed));
- public static readonly EventId Exception = new(102, nameof(Exception));
+ public static readonly EventId Exception = new(101, nameof(Exception));
}
}
diff --git a/src/NatsDistributedCache/NatsCache.cs b/src/NatsDistributedCache/NatsCache.cs
index f0b54dc..d737b34 100644
--- a/src/NatsDistributedCache/NatsCache.cs
+++ b/src/NatsDistributedCache/NatsCache.cs
@@ -43,6 +43,7 @@ public partial class NatsCache : IBufferDistributedCache
new(CacheEntryJsonContext.Default);
private readonly string _bucketName;
+ private readonly INatsCacheKeyEncoder _keyEncoder;
private readonly string _keyPrefix;
private readonly ILogger _logger;
private readonly INatsConnection _natsConnection;
@@ -50,8 +51,9 @@ public partial class NatsCache : IBufferDistributedCache
public NatsCache(
IOptions optionsAccessor,
- ILogger logger,
- INatsConnection natsConnection)
+ INatsConnection natsConnection,
+ ILogger? logger = null,
+ INatsCacheKeyEncoder? keyEncoder = null)
{
var options = optionsAccessor.Value;
_bucketName = !string.IsNullOrWhiteSpace(options.BucketName)
@@ -61,13 +63,9 @@ public NatsCache(
? string.Empty
: options.CacheKeyPrefix.TrimEnd('.');
_lazyKvStore = CreateLazyKvStore();
- _logger = logger;
_natsConnection = natsConnection;
- }
-
- public NatsCache(IOptions optionsAccessor, INatsConnection natsConnection)
- : this(optionsAccessor, NullLogger.Instance, natsConnection)
- {
+ _logger = logger ?? NullLogger.Instance;
+ _keyEncoder = keyEncoder ?? new NatsCacheKeyEncoder();
}
///
@@ -88,7 +86,8 @@ public async Task SetAsync(
try
{
// todo: remove cast after https://github.com/nats-io/nats.net/pull/852 is released
- await ((NatsKVStore)kvStore).PutAsync(GetPrefixedKey(key), entry, ttl ?? TimeSpan.Zero, CacheEntrySerializer, token)
+ await ((NatsKVStore)kvStore)
+ .PutAsync(GetEncodedKey(key), entry, ttl ?? TimeSpan.Zero, CacheEntrySerializer, token)
.ConfigureAwait(false);
}
catch (Exception ex)
@@ -125,14 +124,14 @@ public async Task RemoveAsync(string key, CancellationToken token = default) =>
///
public Task RefreshAsync(string key, CancellationToken token = default) =>
- GetAndRefreshAsync(key, getData: false, retry: true, token: token);
+ GetAndRefreshAsync(key, token: token);
///
public byte[]? Get(string key) => GetAsync(key).GetAwaiter().GetResult();
///
public Task GetAsync(string key, CancellationToken token = default) =>
- GetAndRefreshAsync(key, getData: true, retry: true, token: token);
+ GetAndRefreshAsync(key, token: token);
///
public bool TryGet(string key, IBufferWriter destination) =>
@@ -161,9 +160,6 @@ public async ValueTask TryGetAsync(
return false;
}
- // This is the method used by hybrid caching to determine if it should use the distributed instance
- internal virtual bool IsHybridCacheActive() => false;
-
private static TimeSpan? GetTtl(DistributedCacheEntryOptions options)
{
if (options.AbsoluteExpiration.HasValue && options.AbsoluteExpiration.Value <= DateTimeOffset.Now)
@@ -233,9 +229,10 @@ private static CacheEntry CreateCacheEntry(byte[] value, DistributedCacheEntryOp
return cacheEntry;
}
- private string GetPrefixedKey(string key) => string.IsNullOrEmpty(_keyPrefix)
- ? key
- : _keyPrefix + "." + key;
+ private string GetEncodedKey(string key) =>
+ string.IsNullOrEmpty(_keyPrefix)
+ ? _keyEncoder.Encode(key)
+ : _keyEncoder.Encode($"{_keyPrefix}.{key}");
private Lazy> CreateLazyKvStore() =>
new(async () =>
@@ -259,18 +256,14 @@ private Lazy> CreateLazyKvStore() =>
private Task GetKvStore() => _lazyKvStore.Value;
- private async Task GetAndRefreshAsync(
- string key,
- bool getData,
- bool retry,
- CancellationToken token = default)
+ private async Task GetAndRefreshAsync(string key, CancellationToken token)
{
+ var encodedKey = GetEncodedKey(key);
var kvStore = await GetKvStore().ConfigureAwait(false);
- var prefixedKey = GetPrefixedKey(key);
try
{
var natsResult = await kvStore
- .TryGetEntryAsync(prefixedKey, serializer: CacheEntrySerializer, cancellationToken: token)
+ .TryGetEntryAsync(encodedKey, serializer: CacheEntrySerializer, cancellationToken: token)
.ConfigureAwait(false);
if (!natsResult.Success)
{
@@ -292,19 +285,12 @@ private Lazy> CreateLazyKvStore() =>
return null;
}
- await UpdateEntryExpirationAsync(kvStore, prefixedKey, kvEntry, token).ConfigureAwait(false);
- return getData ? kvEntry.Value.Data : null;
+ await UpdateEntryExpirationAsync(kvEntry).ConfigureAwait(false);
+ return kvEntry.Value.Data;
}
- catch (NatsKVWrongLastRevisionException ex)
+ catch (NatsKVWrongLastRevisionException)
{
- // Optimistic concurrency control failed, someone else updated it
- LogUpdateFailed(key);
- if (retry)
- {
- return await GetAndRefreshAsync(key, getData, retry: false, token).ConfigureAwait(false);
- }
-
- LogException(ex);
+ // Someone else updated it; that's fine, we'll get the latest version next time
return null;
}
catch (Exception ex)
@@ -312,63 +298,52 @@ private Lazy> CreateLazyKvStore() =>
LogException(ex);
throw;
}
- }
- private async Task UpdateEntryExpirationAsync(
- INatsKVStore kvStore,
- string key,
- NatsKVEntry kvEntry,
- CancellationToken token)
- {
- if (kvEntry.Value?.SlidingExpirationTicks == null)
+ // Local Functions
+ async Task UpdateEntryExpirationAsync(NatsKVEntry kvEntry)
{
- return;
- }
+ if (kvEntry.Value?.SlidingExpirationTicks == null)
+ {
+ return;
+ }
- // If we have a sliding expiration, use it as the TTL
- var ttl = TimeSpan.FromTicks(kvEntry.Value.SlidingExpirationTicks.Value);
+ // If we have a sliding expiration, use it as the TTL
+ var ttl = TimeSpan.FromTicks(kvEntry.Value.SlidingExpirationTicks.Value);
- // If we also have an absolute expiration, make sure we don't exceed it
- if (kvEntry.Value.AbsoluteExpiration != null)
- {
- var remainingTime = kvEntry.Value.AbsoluteExpiration.Value - DateTimeOffset.Now;
-
- // Use the minimum of sliding window or remaining absolute time
- if (remainingTime > TimeSpan.Zero && remainingTime < ttl)
+ // If we also have an absolute expiration, make sure we don't exceed it
+ if (kvEntry.Value.AbsoluteExpiration != null)
{
- ttl = remainingTime;
+ var remainingTime = kvEntry.Value.AbsoluteExpiration.Value - DateTimeOffset.Now;
+
+ // Use the minimum of sliding window or remaining absolute time
+ if (remainingTime > TimeSpan.Zero && remainingTime < ttl)
+ {
+ ttl = remainingTime;
+ }
}
- }
- if (ttl > TimeSpan.Zero)
- {
- // Use optimistic concurrency control with the last revision
- try
+ if (ttl > TimeSpan.Zero)
{
+ // Use optimistic concurrency control with the last revision
// todo: remove cast after https://github.com/nats-io/nats.net/pull/852 is released
await ((NatsKVStore)kvStore).UpdateAsync(
- key,
+ encodedKey,
kvEntry.Value,
kvEntry.Revision,
ttl,
serializer: CacheEntrySerializer,
cancellationToken: token).ConfigureAwait(false);
}
- catch (NatsKVWrongLastRevisionException)
- {
- // Someone else updated it; that's fine, we'll get the latest version next time
- LogUpdateFailed(key.Replace(GetPrefixedKey(string.Empty), string.Empty));
- }
}
}
private async Task RemoveAsync(
- string key,
- NatsKVDeleteOpts? natsKvDeleteOpts = null,
- CancellationToken token = default)
+ string key,
+ NatsKVDeleteOpts? natsKvDeleteOpts = null,
+ CancellationToken token = default)
{
var kvStore = await GetKvStore().ConfigureAwait(false);
- await kvStore.DeleteAsync(GetPrefixedKey(key), natsKvDeleteOpts, cancellationToken: token)
+ await kvStore.DeleteAsync(GetEncodedKey(key), natsKvDeleteOpts, cancellationToken: token)
.ConfigureAwait(false);
}
}
diff --git a/src/NatsDistributedCache/NatsCacheKeyEncoder.cs b/src/NatsDistributedCache/NatsCacheKeyEncoder.cs
new file mode 100644
index 0000000..1032233
--- /dev/null
+++ b/src/NatsDistributedCache/NatsCacheKeyEncoder.cs
@@ -0,0 +1,70 @@
+using System.Text.RegularExpressions;
+
+namespace CodeCargo.Nats.DistributedCache;
+
+///
+/// URL-encoding implementation that keeps already allowed keys
+/// untouched and URL-encodes everything else. % characters in the
+/// final encoded output are replaced with = characters to conform
+/// to the NATS KV key rules.
+///
+public sealed partial class NatsCacheKeyEncoder : INatsCacheKeyEncoder
+{
+ ///
+ public string Encode(string raw)
+ {
+ if (string.IsNullOrEmpty(raw))
+ {
+ throw new ArgumentException("Key must not be null or empty.", nameof(raw));
+ }
+
+ if (ValidUnencodedKey(raw))
+ {
+ // already valid
+ return raw;
+ }
+
+ var encoded = Uri.EscapeDataString(raw);
+ encoded = encoded.Replace("~", "%7E");
+ if (encoded.StartsWith('.'))
+ {
+ encoded = "%2E" + encoded[1..];
+ }
+
+ if (encoded.EndsWith('.'))
+ {
+ encoded = encoded[..^1] + "%2E";
+ }
+
+ encoded = encoded.Replace('%', '=');
+ return encoded;
+ }
+
+ ///
+ public string Decode(string key)
+ {
+ if (string.IsNullOrEmpty(key))
+ {
+ throw new ArgumentException("Key must not be null or empty.", nameof(key));
+ }
+
+ if (!key.Contains('='))
+ {
+ // nothing to decode
+ return key;
+ }
+
+ var decoded = key.Replace('=', '%');
+ return Uri.UnescapeDataString(decoded);
+ }
+
+ private static bool ValidUnencodedKey(string rawKey) =>
+ !rawKey.StartsWith('.')
+ && !rawKey.EndsWith('.')
+ && ValidUnencodedKeyRegex().IsMatch(rawKey);
+
+ // Regex pattern to match valid NATS KV keys with = removed, since =
+ // is used instead of % to mark an encoded character sequence
+ [GeneratedRegex("^[-_.A-Za-z0-9]+$", RegexOptions.Compiled)]
+ private static partial Regex ValidUnencodedKeyRegex();
+}
diff --git a/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs b/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs
index e72f60b..4aca911 100644
--- a/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs
+++ b/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs
@@ -30,15 +30,13 @@ public static IServiceCollection AddNatsDistributedCache(
services.AddSingleton(sp =>
{
var optionsAccessor = sp.GetRequiredService>();
- var logger = sp.GetService>();
-
var natsConnection = connectionServiceKey == null
? sp.GetRequiredService()
: sp.GetRequiredKeyedService(connectionServiceKey);
+ var logger = sp.GetService>();
+ var keyEncoder = sp.GetService();
- return logger != null
- ? new NatsCache(optionsAccessor, logger, natsConnection)
- : new NatsCache(optionsAccessor, natsConnection);
+ return new NatsCache(optionsAccessor, natsConnection, logger: logger, keyEncoder: keyEncoder);
});
return services;
diff --git a/test/UnitTests/KeyEncoder/NatsCacheKeyEncoderTest.cs b/test/UnitTests/KeyEncoder/NatsCacheKeyEncoderTest.cs
new file mode 100644
index 0000000..625562b
--- /dev/null
+++ b/test/UnitTests/KeyEncoder/NatsCacheKeyEncoderTest.cs
@@ -0,0 +1,50 @@
+using System.Text.RegularExpressions;
+
+namespace CodeCargo.Nats.DistributedCache.UnitTests.KeyEncoder;
+
+public partial class NatsCacheKeyEncoderTest
+{
+ private readonly NatsCacheKeyEncoder _encoder = new();
+
+ [Theory]
+ [InlineData("orders.pending", "orders.pending")] // simple, already-legal ASCII
+ [InlineData(".leading", "=2Eleading")] // leading dot
+ [InlineData("trailing.", "trailing=2E")] // trailing dot
+ [InlineData(".both.", "=2Eboth=2E")] // leading + trailing dots
+ [InlineData("naïve.café", "na=C3=AFve.caf=C3=A9")] // Unicode with accented chars
+ [InlineData("spaces and #/+/!", "spaces=20and=20=23=2F=2B=2F=21")] // spaces + disallowed ASCII
+ [InlineData("emoji😀key", "emoji=F0=9F=98=80key")] // emoji inside key
+ [InlineData("prod.release-v1.*", "prod.release-v1.=2A")] // wildcard asterisk (disallowed ASCII)
+ [InlineData("~tilde~", "=7Etilde=7E")] // tilde is encoded
+ [InlineData("=equal=", "=3Dequal=3D")] // equal is encoded
+ [InlineData("....", "=2E..=2E")] // only leading + trailing dots encoded
+ public void EncodeDecode_RoundTrips_ValidKeys(string rawKey, string encodedKey)
+ {
+ // encode
+ var encoded = _encoder.Encode(rawKey);
+
+ // must match the KV character whitelist
+ Assert.True(ValidEncodedKey(encoded), "Encoded key must contain only allowed characters");
+
+ // must not start / end with a dot
+ Assert.False(encoded.StartsWith('.'), "Encoded key must not start with '.'");
+ Assert.False(encoded.EndsWith('.'), "Encoded key must not end with '.'");
+
+ // check encoded
+ Assert.Equal(encodedKey, encoded);
+
+ // decode
+ var decoded = _encoder.Decode(encoded);
+
+ // check decoded
+ Assert.Equal(rawKey, decoded);
+ }
+
+ private static bool ValidEncodedKey(string rawKey) =>
+ !rawKey.StartsWith('.')
+ && !rawKey.EndsWith('.')
+ && ValidEncodedKeyRegex().IsMatch(rawKey);
+
+ [GeneratedRegex("^[-_=.A-Za-z0-9]+$", RegexOptions.Compiled)]
+ private static partial Regex ValidEncodedKeyRegex();
+}