Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/Protocol.zig
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub const Error = error{
NotFound,
ValueTooLarge,
Exists,
InvalidKey,
} || std.Io.Reader.Error || std.Io.Reader.DelimiterError || std.Io.Writer.Error;

/// Returns true if the error is a protocol-level error where the connection
Expand All @@ -38,6 +39,7 @@ pub fn isResumable(err: anyerror) bool {
error.NotFound,
error.Exists,
error.ValueTooLarge,
error.InvalidKey,
=> true,
else => false,
};
Expand Down Expand Up @@ -67,7 +69,15 @@ pub const SetMode = enum {
prepend,
};

fn validateKey(key: []const u8) Error!void {
if (key.len == 0 or key.len > 250) return error.InvalidKey;
for (key) |c| {
if (c <= ' ' or c == 0x7f) return error.InvalidKey;
}
}

pub fn get(self: Protocol, key: []const u8, buf: []u8, opts: GetOpts) Error!?Info {
try validateKey(key);
// Build command: mg <key> v f c [T<ttl>]
try self.writer.print("mg {s} v f c", .{key});
if (opts.ttl) |ttl| {
Expand Down Expand Up @@ -100,6 +110,7 @@ pub fn get(self: Protocol, key: []const u8, buf: []u8, opts: GetOpts) Error!?Inf
// --- Set ---

pub fn set(self: Protocol, key: []const u8, value: []const u8, opts: SetOpts, mode: SetMode) Error!void {
try validateKey(key);
// Build command: ms <key> <size> [flags]
try self.writer.print("ms {s} {d}", .{ key, value.len });

Expand Down Expand Up @@ -145,6 +156,7 @@ pub fn set(self: Protocol, key: []const u8, value: []const u8, opts: SetOpts, mo
// --- Delete ---

pub fn delete(self: Protocol, key: []const u8) Error!void {
try validateKey(key);
try self.writer.print("md {s}\r\n", .{key});
try self.writer.flush();

Expand All @@ -169,6 +181,7 @@ pub fn decr(self: Protocol, key: []const u8, delta: u64) Error!u64 {
}

fn arithmetic(self: Protocol, key: []const u8, delta: u64, is_decr: bool) Error!u64 {
try validateKey(key);
// ma <key> v D<delta> [MD]
try self.writer.print("ma {s} v D{d}", .{ key, delta });
if (is_decr) {
Expand Down Expand Up @@ -197,6 +210,7 @@ fn arithmetic(self: Protocol, key: []const u8, delta: u64, is_decr: bool) Error!
// --- Touch ---

pub fn touch(self: Protocol, key: []const u8, ttl: u32) Error!void {
try validateKey(key);
try self.writer.print("mg {s} T{d}\r\n", .{ key, ttl });
try self.writer.flush();

Expand Down Expand Up @@ -355,3 +369,18 @@ test "parseResponse EX" {
const resp = try parseResponse("EX");
try std.testing.expectEqual(Response.exists, resp);
}

test "validateKey" {
try validateKey("valid_key");
try validateKey("a");
try validateKey("x" ** 250);

try std.testing.expectError(error.InvalidKey, validateKey(""));
try std.testing.expectError(error.InvalidKey, validateKey("x" ** 251));
try std.testing.expectError(error.InvalidKey, validateKey("has space"));
try std.testing.expectError(error.InvalidKey, validateKey("has\ttab"));
try std.testing.expectError(error.InvalidKey, validateKey("has\nnewline"));
try std.testing.expectError(error.InvalidKey, validateKey("has\r\nreturn"));
try std.testing.expectError(error.InvalidKey, validateKey("has\x00null"));
try std.testing.expectError(error.InvalidKey, validateKey("has\x7fdelete"));
}