diff --git a/src/Protocol.zig b/src/Protocol.zig index 4ffc6ec..83dedca 100644 --- a/src/Protocol.zig +++ b/src/Protocol.zig @@ -28,6 +28,7 @@ pub const Error = error{ NotFound, ValueTooLarge, Exists, + InvalidKey, } || std.Io.Reader.Error || std.Io.Reader.DelimiterError || std.Io.Writer.Error; /// Returns true if the error is a protocol-level error where the connection @@ -38,6 +39,7 @@ pub fn isResumable(err: anyerror) bool { error.NotFound, error.Exists, error.ValueTooLarge, + error.InvalidKey, => true, else => false, }; @@ -67,7 +69,15 @@ pub const SetMode = enum { prepend, }; +fn validateKey(key: []const u8) Error!void { + if (key.len == 0 or key.len > 250) return error.InvalidKey; + for (key) |c| { + if (c <= ' ' or c == 0x7f) return error.InvalidKey; + } +} + pub fn get(self: Protocol, key: []const u8, buf: []u8, opts: GetOpts) Error!?Info { + try validateKey(key); // Build command: mg v f c [T] try self.writer.print("mg {s} v f c", .{key}); if (opts.ttl) |ttl| { @@ -100,6 +110,7 @@ pub fn get(self: Protocol, key: []const u8, buf: []u8, opts: GetOpts) Error!?Inf // --- Set --- pub fn set(self: Protocol, key: []const u8, value: []const u8, opts: SetOpts, mode: SetMode) Error!void { + try validateKey(key); // Build command: ms [flags] try self.writer.print("ms {s} {d}", .{ key, value.len }); @@ -145,6 +156,7 @@ pub fn set(self: Protocol, key: []const u8, value: []const u8, opts: SetOpts, mo // --- Delete --- pub fn delete(self: Protocol, key: []const u8) Error!void { + try validateKey(key); try self.writer.print("md {s}\r\n", .{key}); try self.writer.flush(); @@ -169,6 +181,7 @@ pub fn decr(self: Protocol, key: []const u8, delta: u64) Error!u64 { } fn arithmetic(self: Protocol, key: []const u8, delta: u64, is_decr: bool) Error!u64 { + try validateKey(key); // ma v D [MD] try self.writer.print("ma {s} v D{d}", .{ key, delta }); if (is_decr) { @@ -197,6 +210,7 @@ fn arithmetic(self: Protocol, key: []const u8, delta: u64, is_decr: bool) Error! // --- Touch --- pub fn touch(self: Protocol, key: []const u8, ttl: u32) Error!void { + try validateKey(key); try self.writer.print("mg {s} T{d}\r\n", .{ key, ttl }); try self.writer.flush(); @@ -355,3 +369,18 @@ test "parseResponse EX" { const resp = try parseResponse("EX"); try std.testing.expectEqual(Response.exists, resp); } + +test "validateKey" { + try validateKey("valid_key"); + try validateKey("a"); + try validateKey("x" ** 250); + + try std.testing.expectError(error.InvalidKey, validateKey("")); + try std.testing.expectError(error.InvalidKey, validateKey("x" ** 251)); + try std.testing.expectError(error.InvalidKey, validateKey("has space")); + try std.testing.expectError(error.InvalidKey, validateKey("has\ttab")); + try std.testing.expectError(error.InvalidKey, validateKey("has\nnewline")); + try std.testing.expectError(error.InvalidKey, validateKey("has\r\nreturn")); + try std.testing.expectError(error.InvalidKey, validateKey("has\x00null")); + try std.testing.expectError(error.InvalidKey, validateKey("has\x7fdelete")); +}