diff --git a/src/NatsDistributedCache/INatsCacheKeyEncoder.cs b/src/NatsDistributedCache/INatsCacheKeyEncoder.cs new file mode 100644 index 0000000..a192014 --- /dev/null +++ b/src/NatsDistributedCache/INatsCacheKeyEncoder.cs @@ -0,0 +1,14 @@ +namespace CodeCargo.Nats.DistributedCache; + +/// +/// Encodes raw strings so they satisfy the NATS KV key rules. +/// +public interface INatsCacheKeyEncoder +{ + /// + /// Encodes a raw string into a KV-legal key + /// + /// The raw string to encode + /// A KV-legal key + string Encode(string raw); +} diff --git a/src/NatsDistributedCache/NatsCache.Log.cs b/src/NatsDistributedCache/NatsCache.Log.cs index da23903..0ffe146 100644 --- a/src/NatsDistributedCache/NatsCache.Log.cs +++ b/src/NatsDistributedCache/NatsCache.Log.cs @@ -4,21 +4,15 @@ namespace CodeCargo.Nats.DistributedCache; public partial class NatsCache { - private void LogException(Exception exception) => - _logger.LogError(EventIds.Exception, exception, "Exception in NatsDistributedCache"); - private void LogConnected(string bucketName) => _logger.LogInformation(EventIds.Connected, "Connected to NATS KV bucket {bucketName}", bucketName); - private void LogUpdateFailed(string key) => _logger.LogDebug( - EventIds.UpdateFailed, - "Sliding expiration update failed for key {Key} due to optimistic concurrency control", - key); + private void LogException(Exception exception) => + _logger.LogError(EventIds.Exception, exception, "Exception in NatsDistributedCache"); private static class EventIds { public static readonly EventId Connected = new(100, nameof(Connected)); - public static readonly EventId UpdateFailed = new(101, nameof(UpdateFailed)); - public static readonly EventId Exception = new(102, nameof(Exception)); + public static readonly EventId Exception = new(101, nameof(Exception)); } } diff --git a/src/NatsDistributedCache/NatsCache.cs b/src/NatsDistributedCache/NatsCache.cs index f0b54dc..d737b34 100644 --- a/src/NatsDistributedCache/NatsCache.cs +++ b/src/NatsDistributedCache/NatsCache.cs @@ -43,6 +43,7 @@ public partial class NatsCache : IBufferDistributedCache new(CacheEntryJsonContext.Default); private readonly string _bucketName; + private readonly INatsCacheKeyEncoder _keyEncoder; private readonly string _keyPrefix; private readonly ILogger _logger; private readonly INatsConnection _natsConnection; @@ -50,8 +51,9 @@ public partial class NatsCache : IBufferDistributedCache public NatsCache( IOptions optionsAccessor, - ILogger logger, - INatsConnection natsConnection) + INatsConnection natsConnection, + ILogger? logger = null, + INatsCacheKeyEncoder? keyEncoder = null) { var options = optionsAccessor.Value; _bucketName = !string.IsNullOrWhiteSpace(options.BucketName) @@ -61,13 +63,9 @@ public NatsCache( ? string.Empty : options.CacheKeyPrefix.TrimEnd('.'); _lazyKvStore = CreateLazyKvStore(); - _logger = logger; _natsConnection = natsConnection; - } - - public NatsCache(IOptions optionsAccessor, INatsConnection natsConnection) - : this(optionsAccessor, NullLogger.Instance, natsConnection) - { + _logger = logger ?? NullLogger.Instance; + _keyEncoder = keyEncoder ?? new NatsCacheKeyEncoder(); } /// @@ -88,7 +86,8 @@ public async Task SetAsync( try { // todo: remove cast after https://github.com/nats-io/nats.net/pull/852 is released - await ((NatsKVStore)kvStore).PutAsync(GetPrefixedKey(key), entry, ttl ?? TimeSpan.Zero, CacheEntrySerializer, token) + await ((NatsKVStore)kvStore) + .PutAsync(GetEncodedKey(key), entry, ttl ?? TimeSpan.Zero, CacheEntrySerializer, token) .ConfigureAwait(false); } catch (Exception ex) @@ -125,14 +124,14 @@ public async Task RemoveAsync(string key, CancellationToken token = default) => /// public Task RefreshAsync(string key, CancellationToken token = default) => - GetAndRefreshAsync(key, getData: false, retry: true, token: token); + GetAndRefreshAsync(key, token: token); /// public byte[]? Get(string key) => GetAsync(key).GetAwaiter().GetResult(); /// public Task GetAsync(string key, CancellationToken token = default) => - GetAndRefreshAsync(key, getData: true, retry: true, token: token); + GetAndRefreshAsync(key, token: token); /// public bool TryGet(string key, IBufferWriter destination) => @@ -161,9 +160,6 @@ public async ValueTask TryGetAsync( return false; } - // This is the method used by hybrid caching to determine if it should use the distributed instance - internal virtual bool IsHybridCacheActive() => false; - private static TimeSpan? GetTtl(DistributedCacheEntryOptions options) { if (options.AbsoluteExpiration.HasValue && options.AbsoluteExpiration.Value <= DateTimeOffset.Now) @@ -233,9 +229,10 @@ private static CacheEntry CreateCacheEntry(byte[] value, DistributedCacheEntryOp return cacheEntry; } - private string GetPrefixedKey(string key) => string.IsNullOrEmpty(_keyPrefix) - ? key - : _keyPrefix + "." + key; + private string GetEncodedKey(string key) => + string.IsNullOrEmpty(_keyPrefix) + ? _keyEncoder.Encode(key) + : _keyEncoder.Encode($"{_keyPrefix}.{key}"); private Lazy> CreateLazyKvStore() => new(async () => @@ -259,18 +256,14 @@ private Lazy> CreateLazyKvStore() => private Task GetKvStore() => _lazyKvStore.Value; - private async Task GetAndRefreshAsync( - string key, - bool getData, - bool retry, - CancellationToken token = default) + private async Task GetAndRefreshAsync(string key, CancellationToken token) { + var encodedKey = GetEncodedKey(key); var kvStore = await GetKvStore().ConfigureAwait(false); - var prefixedKey = GetPrefixedKey(key); try { var natsResult = await kvStore - .TryGetEntryAsync(prefixedKey, serializer: CacheEntrySerializer, cancellationToken: token) + .TryGetEntryAsync(encodedKey, serializer: CacheEntrySerializer, cancellationToken: token) .ConfigureAwait(false); if (!natsResult.Success) { @@ -292,19 +285,12 @@ private Lazy> CreateLazyKvStore() => return null; } - await UpdateEntryExpirationAsync(kvStore, prefixedKey, kvEntry, token).ConfigureAwait(false); - return getData ? kvEntry.Value.Data : null; + await UpdateEntryExpirationAsync(kvEntry).ConfigureAwait(false); + return kvEntry.Value.Data; } - catch (NatsKVWrongLastRevisionException ex) + catch (NatsKVWrongLastRevisionException) { - // Optimistic concurrency control failed, someone else updated it - LogUpdateFailed(key); - if (retry) - { - return await GetAndRefreshAsync(key, getData, retry: false, token).ConfigureAwait(false); - } - - LogException(ex); + // Someone else updated it; that's fine, we'll get the latest version next time return null; } catch (Exception ex) @@ -312,63 +298,52 @@ private Lazy> CreateLazyKvStore() => LogException(ex); throw; } - } - private async Task UpdateEntryExpirationAsync( - INatsKVStore kvStore, - string key, - NatsKVEntry kvEntry, - CancellationToken token) - { - if (kvEntry.Value?.SlidingExpirationTicks == null) + // Local Functions + async Task UpdateEntryExpirationAsync(NatsKVEntry kvEntry) { - return; - } + if (kvEntry.Value?.SlidingExpirationTicks == null) + { + return; + } - // If we have a sliding expiration, use it as the TTL - var ttl = TimeSpan.FromTicks(kvEntry.Value.SlidingExpirationTicks.Value); + // If we have a sliding expiration, use it as the TTL + var ttl = TimeSpan.FromTicks(kvEntry.Value.SlidingExpirationTicks.Value); - // If we also have an absolute expiration, make sure we don't exceed it - if (kvEntry.Value.AbsoluteExpiration != null) - { - var remainingTime = kvEntry.Value.AbsoluteExpiration.Value - DateTimeOffset.Now; - - // Use the minimum of sliding window or remaining absolute time - if (remainingTime > TimeSpan.Zero && remainingTime < ttl) + // If we also have an absolute expiration, make sure we don't exceed it + if (kvEntry.Value.AbsoluteExpiration != null) { - ttl = remainingTime; + var remainingTime = kvEntry.Value.AbsoluteExpiration.Value - DateTimeOffset.Now; + + // Use the minimum of sliding window or remaining absolute time + if (remainingTime > TimeSpan.Zero && remainingTime < ttl) + { + ttl = remainingTime; + } } - } - if (ttl > TimeSpan.Zero) - { - // Use optimistic concurrency control with the last revision - try + if (ttl > TimeSpan.Zero) { + // Use optimistic concurrency control with the last revision // todo: remove cast after https://github.com/nats-io/nats.net/pull/852 is released await ((NatsKVStore)kvStore).UpdateAsync( - key, + encodedKey, kvEntry.Value, kvEntry.Revision, ttl, serializer: CacheEntrySerializer, cancellationToken: token).ConfigureAwait(false); } - catch (NatsKVWrongLastRevisionException) - { - // Someone else updated it; that's fine, we'll get the latest version next time - LogUpdateFailed(key.Replace(GetPrefixedKey(string.Empty), string.Empty)); - } } } private async Task RemoveAsync( - string key, - NatsKVDeleteOpts? natsKvDeleteOpts = null, - CancellationToken token = default) + string key, + NatsKVDeleteOpts? natsKvDeleteOpts = null, + CancellationToken token = default) { var kvStore = await GetKvStore().ConfigureAwait(false); - await kvStore.DeleteAsync(GetPrefixedKey(key), natsKvDeleteOpts, cancellationToken: token) + await kvStore.DeleteAsync(GetEncodedKey(key), natsKvDeleteOpts, cancellationToken: token) .ConfigureAwait(false); } } diff --git a/src/NatsDistributedCache/NatsCacheKeyEncoder.cs b/src/NatsDistributedCache/NatsCacheKeyEncoder.cs new file mode 100644 index 0000000..1032233 --- /dev/null +++ b/src/NatsDistributedCache/NatsCacheKeyEncoder.cs @@ -0,0 +1,70 @@ +using System.Text.RegularExpressions; + +namespace CodeCargo.Nats.DistributedCache; + +/// +/// URL-encoding implementation that keeps already allowed keys +/// untouched and URL-encodes everything else. % characters in the +/// final encoded output are replaced with = characters to conform +/// to the NATS KV key rules. +/// +public sealed partial class NatsCacheKeyEncoder : INatsCacheKeyEncoder +{ + /// + public string Encode(string raw) + { + if (string.IsNullOrEmpty(raw)) + { + throw new ArgumentException("Key must not be null or empty.", nameof(raw)); + } + + if (ValidUnencodedKey(raw)) + { + // already valid + return raw; + } + + var encoded = Uri.EscapeDataString(raw); + encoded = encoded.Replace("~", "%7E"); + if (encoded.StartsWith('.')) + { + encoded = "%2E" + encoded[1..]; + } + + if (encoded.EndsWith('.')) + { + encoded = encoded[..^1] + "%2E"; + } + + encoded = encoded.Replace('%', '='); + return encoded; + } + + /// + public string Decode(string key) + { + if (string.IsNullOrEmpty(key)) + { + throw new ArgumentException("Key must not be null or empty.", nameof(key)); + } + + if (!key.Contains('=')) + { + // nothing to decode + return key; + } + + var decoded = key.Replace('=', '%'); + return Uri.UnescapeDataString(decoded); + } + + private static bool ValidUnencodedKey(string rawKey) => + !rawKey.StartsWith('.') + && !rawKey.EndsWith('.') + && ValidUnencodedKeyRegex().IsMatch(rawKey); + + // Regex pattern to match valid NATS KV keys with = removed, since = + // is used instead of % to mark an encoded character sequence + [GeneratedRegex("^[-_.A-Za-z0-9]+$", RegexOptions.Compiled)] + private static partial Regex ValidUnencodedKeyRegex(); +} diff --git a/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs b/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs index e72f60b..4aca911 100644 --- a/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs +++ b/src/NatsDistributedCache/NatsDistributedCacheExtensions.cs @@ -30,15 +30,13 @@ public static IServiceCollection AddNatsDistributedCache( services.AddSingleton(sp => { var optionsAccessor = sp.GetRequiredService>(); - var logger = sp.GetService>(); - var natsConnection = connectionServiceKey == null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(connectionServiceKey); + var logger = sp.GetService>(); + var keyEncoder = sp.GetService(); - return logger != null - ? new NatsCache(optionsAccessor, logger, natsConnection) - : new NatsCache(optionsAccessor, natsConnection); + return new NatsCache(optionsAccessor, natsConnection, logger: logger, keyEncoder: keyEncoder); }); return services; diff --git a/test/UnitTests/KeyEncoder/NatsCacheKeyEncoderTest.cs b/test/UnitTests/KeyEncoder/NatsCacheKeyEncoderTest.cs new file mode 100644 index 0000000..625562b --- /dev/null +++ b/test/UnitTests/KeyEncoder/NatsCacheKeyEncoderTest.cs @@ -0,0 +1,50 @@ +using System.Text.RegularExpressions; + +namespace CodeCargo.Nats.DistributedCache.UnitTests.KeyEncoder; + +public partial class NatsCacheKeyEncoderTest +{ + private readonly NatsCacheKeyEncoder _encoder = new(); + + [Theory] + [InlineData("orders.pending", "orders.pending")] // simple, already-legal ASCII + [InlineData(".leading", "=2Eleading")] // leading dot + [InlineData("trailing.", "trailing=2E")] // trailing dot + [InlineData(".both.", "=2Eboth=2E")] // leading + trailing dots + [InlineData("naïve.café", "na=C3=AFve.caf=C3=A9")] // Unicode with accented chars + [InlineData("spaces and #/+/!", "spaces=20and=20=23=2F=2B=2F=21")] // spaces + disallowed ASCII + [InlineData("emoji😀key", "emoji=F0=9F=98=80key")] // emoji inside key + [InlineData("prod.release-v1.*", "prod.release-v1.=2A")] // wildcard asterisk (disallowed ASCII) + [InlineData("~tilde~", "=7Etilde=7E")] // tilde is encoded + [InlineData("=equal=", "=3Dequal=3D")] // equal is encoded + [InlineData("....", "=2E..=2E")] // only leading + trailing dots encoded + public void EncodeDecode_RoundTrips_ValidKeys(string rawKey, string encodedKey) + { + // encode + var encoded = _encoder.Encode(rawKey); + + // must match the KV character whitelist + Assert.True(ValidEncodedKey(encoded), "Encoded key must contain only allowed characters"); + + // must not start / end with a dot + Assert.False(encoded.StartsWith('.'), "Encoded key must not start with '.'"); + Assert.False(encoded.EndsWith('.'), "Encoded key must not end with '.'"); + + // check encoded + Assert.Equal(encodedKey, encoded); + + // decode + var decoded = _encoder.Decode(encoded); + + // check decoded + Assert.Equal(rawKey, decoded); + } + + private static bool ValidEncodedKey(string rawKey) => + !rawKey.StartsWith('.') + && !rawKey.EndsWith('.') + && ValidEncodedKeyRegex().IsMatch(rawKey); + + [GeneratedRegex("^[-_=.A-Za-z0-9]+$", RegexOptions.Compiled)] + private static partial Regex ValidEncodedKeyRegex(); +}