diff --git a/CHANGELOG.md b/CHANGELOG.md index ebc8fa0..c1ef885 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ # 0.28.2 2026-02-20 -- Fix `Stream#get_string` +- Fix `Stream#read` # 0.28.1 2026-02-20 @@ -169,9 +169,9 @@ # 2025-06-03 Version 0.12 -- Add buffer, maxlen params to `Stream#get_line` -- Add buffer param to `Stream#get_string` -- Remove `Stream#resp_get_line`, `Stream#resp_get_string` methods +- Add buffer, maxlen params to `Stream#read_line` +- Add buffer param to `Stream#read` +- Remove `Stream#resp_read_line`, `Stream#resp_read` methods # 2025-06-02 Version 0.11.1 diff --git a/README.md b/README.md index c985750..9d1e93c 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ implementation that allows integration with the entire Ruby ecosystem. - Excellent performance characteristics for concurrent I/O-bound applications. - `Fiber::Scheduler` implementation to automatically integrate with the Ruby ecosystem in a transparent fashion. -- Read streams with automatic buffer management. +- [Connection](#connections) class with automatic buffer management for reading. - Optimized I/O for encrypted SSL connections. ## Design @@ -286,64 +286,70 @@ fiber = Fiber.schedule do end ``` -## Read Streams +## Connections -A UringMachine stream is used to efficiently read from a socket or other file -descriptor. Streams are ideal for implementing the read side of protocols, and -provide an API that is useful for both line-based protocols and binary -(frame-based) protocols. +`UringMachine::Connection` is a class designed for efficiently read from and +write to a socket or other file descriptor. Connections are ideal for +implementing the read side of protocols, and provide an API that is useful for +both line-based protocols and binary (frame-based) protocols. -A stream is associated with a UringMachine instance and a target file descriptor -(see also [stream modes](#stream-modes) below). Behind the scenes, streams take -advantage of io_uring's registered buffers feature, and more recently, the -introduction of [incremental buffer +A connection is associated with a UringMachine instance and a target file +descriptor (or SSL socket, see also [connection modes](#connection-modes) +below). Behind the scenes, connections take advantage of io_uring's registered +buffers feature, and more recently, the introduction of [incremental buffer consumption](https://github.com/axboe/liburing/wiki/What's-new-with-io_uring-in-6.11-and-6.12#incremental-provided-buffer-consumption). -When streams are used, UringMachine automatically manages the buffers it +When connections are used, UringMachine automatically manages the buffers it provides to the kernel, maximizing buffer reuse and minimizing allocations. UringMachine also responds to stress conditions (increased incoming traffic) by automatically provisioning additional buffers. -To create a stream for a given fd, use `UM#stream`: +To create a connection for a given fd, use `UM#connection`: ```ruby -stream = machine.stream(fd) +conn = machine.connection(fd) -# you can also provide a block that will be passed the stream instance: -machine.stream(fd) { |s| do_something_with(s) } +# you can also provide a block that will be passed the connection instance: +machine.connection(fd) { |c| do_something_with(c) } -# you can also instantiate a stream directly: -stream = UM::Stream.new(machine, fd) +# you can also instantiate a connection directly: +conn = UM::Connection.new(machine, fd) ``` -The following API is used to interact with the stream: +The following API is used to interact with the connection: ```ruby # Read until a newline character is encountered: -line = stream.get_line(0) +line = conn.read_line(0) # Read line with a maximum length of 13 bytes: -line = stream.get_line(13) +line = conn.read_line(13) # Read all data: -buf = stream.get_string(0) +buf = conn.read(0) # Read exactly 13 bytes: -buf = stream.get_string(13) +buf = conn.read(13) # Read up to 13 bytes: -buf = stream.get_string(-13) +buf = conn.read(-13) + +# Read continuously until EOF +conn.read_each { |data| ... } # Skip 3 bytes: -stream.skip(3) +conn.skip(3) + +# Write +conn.write('foo', 'bar', 'baz') ``` Here's an example of a how a basic HTTP request parser might be implemented -using a stream: +using a connection: ```ruby -def parse_http_request_headers(stream) - request_line = stream.get_line(0) +def parse_http_request_headers(conn) + request_line = conn.read_line(0) m = request_line.match(REQUEST_LINE_RE) return nil if !m @@ -354,7 +360,7 @@ def parse_http_request_headers(stream) } while true - line = stream.get_line(0) + line = conn.read_line(0) break if !line || line.empty? m = line.match(HEADER_RE) @@ -364,24 +370,26 @@ def parse_http_request_headers(stream) end ``` -### Stream modes +### Connection modes -Stream modes allow streams to be transport agnostic. Currently streams support -three modes: +Connection modes allow connections to be transport agnostic. Currently +connections support three modes: -- `:bp_read` - use the buffer pool, read data using multishot read +- `:fd` - use the buffer pool, read data using multishot read (this is the default mode). -- `:bp_recv` - use the buffer pool, read data using multishot recv. +- `:socket` - use the buffer pool, read data using multishot recv. - `:ssl` - read from an `SSLSocket` object. -The mode is specified as an additional argument to `Stream.new`: +The mode is specified as an additional argument to `Connection.new`: ```ruby -# stream using recv: -stream = machine.stream(fd, :bp_recv) +# using recv/send: +conn = machine.connection(fd, :socket) -# stream on an SSL socket: -stream = machine.stream(ssl, :ssl) +# SSL I/O: +conn = machine.connection(ssl, :ssl) +# or simply: +conn = machine.connection(ssl) ``` ## Performance diff --git a/TODO.md b/TODO.md index 0d7fa68..947862a 100644 --- a/TODO.md +++ b/TODO.md @@ -17,39 +17,44 @@ What if instead of `Stream` we had something called `Link`, which serves for both reading and writing: ```ruby -link = machine.link(fd) -while l = link.read_line - link.write(l, '\n') +conn = machine.connection(fd) +while l = conn.read_line + conn.write(l, '\n') end # or: -buf = link.read(42) +buf = conn.read(42) ``` RESP: ```ruby -link.resp_write(['foo', 'bar']) -reply = link.resp_read +conn.resp_write(['foo', 'bar']) +reply = conn.resp_read ``` HTTP: ```ruby -r = link.http_read_request -link.http_write_response({ ':status' => 200 }, 'foo') +r = conn.http_read_request +conn.http_write_response({ ':status' => 200 }, 'foo') # or: -link.http_write_request({ ':method' => 'GET', ':path' => '/foo' }, nil) +conn.http_write_request({ ':method' => 'GET', ':path' => '/foo' }, nil) ``` Plan of action: - Rename methods: - - rename `#get_line` to `#read_line` - - rename `#get_string` to `#read` - - rename `#get_to_delim` to `#read_to_delim` - - rename `#resp_decode` to `#resp_read` -- Rename `Stream` to `Link` + - [v] rename `#read_line` to `#read_line` + - [v] rename `#read` to `#read` + - [v] rename `#read_to_delim` to `#read_to_delim` + - [v] rename `#each` to `#read_each` + - [v] rename `#resp_decode` to `#resp_read` +- Rename modes: + - [v] :fd to :fd + - [v] :socket to :socket + - [v] auto detect SSL +- Rename `Stream` to `Connection` - Add methods: - `#write(*bufs)` - `#resp_write(obj)` diff --git a/benchmark/gets.rb b/benchmark/gets.rb index f1da956..9d2b913 100644 --- a/benchmark/gets.rb +++ b/benchmark/gets.rb @@ -34,16 +34,16 @@ def um_read end end -@fd_stream = @machine.open('/dev/random', UM::O_RDONLY) -@stream = UM::Stream.new(@machine, @fd_stream) -def um_stream_get_line - @stream.get_line(0) +@fd_connection = @machine.open('/dev/random', UM::O_RDONLY) +@conn = UM::Connection.new(@machine, @fd_connection) +def um_connection_read_line + @conn.read_line(0) end Benchmark.ips do |x| - x.report('IO#gets') { io_gets } - x.report('UM#read+buf') { um_read } - x.report('UM::Stream') { um_stream_get_line } + x.report('IO#gets') { io_gets } + x.report('UM#read+buf') { um_read } + x.report('UM::Connection') { um_connection_read_line } x.compare!(order: :baseline) end diff --git a/benchmark/gets_concurrent.rb b/benchmark/gets_concurrent.rb index 77a967d..3b38b0d 100644 --- a/benchmark/gets_concurrent.rb +++ b/benchmark/gets_concurrent.rb @@ -83,40 +83,40 @@ def um_read stop_server end -@total_stream = 0 -def um_stream_do +@total_connection = 0 +def um_connection_do # fd = @machine.open('/dev/random', UM::O_RDONLY) fd = @machine.socket(UM::AF_INET, UM::SOCK_STREAM, 0, 0) @machine.connect(fd, '127.0.0.1', 1234) - stream = UM::Stream.new(@machine, fd) - N.times { @total_stream += stream.get_line(0)&.bytesize || 0 } + conn = UM::Connection.new(@machine, fd) + N.times { @total_connection += conn.read_line(0)&.bytesize || 0 } rescue => e p e p e.backtrace ensure - stream.clear + conn.clear @machine.close(fd) end -def um_stream +def um_connection start_server ff = C.times.map { @machine.snooze - @machine.spin { um_stream_do } + @machine.spin { um_connection_do } } @machine.await(ff) - pp total: @total_stream + pp total: @total_connection ensure stop_server end p(C:, N:) -um_stream +um_connection pp @machine.metrics exit Benchmark.bm do it.report('Thread/IO#gets') { io_gets } it.report('Fiber/UM#read+buf') { um_read } - it.report('Fiber/UM::Stream') { um_stream } + it.report('Fiber/UM::Stream') { um_connection } end diff --git a/benchmark/http_parse.rb b/benchmark/http_parse.rb index 480ea6c..84b5dbb 100644 --- a/benchmark/http_parse.rb +++ b/benchmark/http_parse.rb @@ -65,7 +65,7 @@ def parse_http_parser RE_REQUEST_LINE = /^([a-z]+)\s+([^\s]+)\s+(http\/1\.1)/i RE_HEADER_LINE = /^([a-z0-9\-]+)\:\s+(.+)/i -def get_line(fd, sio, buffer) +def read_line(fd, sio, buffer) while true line = sio.gets(chomp: true) return line if line @@ -76,7 +76,7 @@ def get_line(fd, sio, buffer) end def get_request_line(fd, sio, buffer) - line = get_line(fd, sio, buffer) + line = read_line(fd, sio, buffer) m = line.match(RE_REQUEST_LINE) return nil if !m @@ -96,7 +96,7 @@ def parse_headers(fd) return nil if !headers while true - line = get_line(fd, sio, buffer) + line = read_line(fd, sio, buffer) break if line.empty? m = line.match(RE_HEADER_LINE) @@ -129,15 +129,15 @@ def parse_http_stringio ($machine.close(wfd) rescue nil) if wfd end -def stream_parse_headers(fd) - stream = UM::Stream.new($machine, fd) +def connection_parse_headers(fd) + conn = UM::Connection.new($machine, fd) buf = String.new(capacity: 65536) - headers = stream_get_request_line(stream, buf) + headers = connection_get_request_line(conn, buf) return nil if !headers while true - line = stream.get_line(0) + line = conn.read_line(0) break if line.empty? m = line.match(RE_HEADER_LINE) @@ -149,8 +149,8 @@ def stream_parse_headers(fd) headers end -def stream_get_request_line(stream, buf) - line = stream.get_line(0) +def connection_get_request_line(conn, buf) + line = conn.read_line(0) m = line.match(RE_REQUEST_LINE) return nil if !m @@ -162,12 +162,12 @@ def stream_get_request_line(stream, buf) } end -def parse_http_stream +def parse_http_connection rfd, wfd = UM.pipe queue = UM::Queue.new $machine.spin do - headers = stream_parse_headers(rfd) + headers = connection_parse_headers(rfd) $machine.push(queue, headers) rescue Exception => e p e @@ -188,7 +188,7 @@ def compare_allocs p( alloc_http_parser: alloc_count { x.times { parse_http_parser } }, alloc_stringio: alloc_count { x.times { parse_http_stringio } }, - alloc_stream: alloc_count { x.times { parse_http_stream } } + alloc_connection: alloc_count { x.times { parse_http_connection } } ) ensure GC.enable @@ -213,8 +213,8 @@ def benchmark x.config(:time => 5, :warmup => 3) x.report("http_parser") { parse_http_parser } - x.report("stringio") { parse_http_stringio } - x.report("stream") { parse_http_stream } + x.report("stringio") { parse_http_stringio } + x.report("connection") { parse_http_connection } x.compare! end diff --git a/benchmark/http_server_accept_queue.rb b/benchmark/http_server_accept_queue.rb index 48b87e1..881f17a 100644 --- a/benchmark/http_server_accept_queue.rb +++ b/benchmark/http_server_accept_queue.rb @@ -12,8 +12,8 @@ RE_REQUEST_LINE = /^([a-z]+)\s+([^\s]+)\s+(http\/[0-9\.]{1,3})/i RE_HEADER_LINE = /^([a-z0-9\-]+)\:\s+(.+)/i -def stream_get_request_line(stream, buf) - line = stream.get_line(buf, 0) +def connection_get_request_line(conn, buf) + line = conn.read_line(0) m = line&.match(RE_REQUEST_LINE) return nil if !m @@ -26,12 +26,12 @@ def stream_get_request_line(stream, buf) class InvalidHeadersError < StandardError; end -def get_headers(stream, buf) - headers = stream_get_request_line(stream, buf) +def get_headers(conn, buf) + headers = connection_get_request_line(conn, buf) return nil if !headers while true - line = stream.get_line(buf, 0) + line = conn.read_line(0) break if line.empty? m = line.match(RE_HEADER_LINE) @@ -51,17 +51,21 @@ def send_response(machine, fd) end def handle_connection(machine, fd) - stream = UM::Stream.new(machine, fd) + conn = UM::Connection.new(machine, fd) buf = String.new(capacity: 65536) while true - headers = get_headers(stream, buf) + headers = get_headers(conn, buf) break if !headers send_response(machine, fd) end rescue InvalidHeadersError, SystemCallError => e # ignore +rescue => e + p e + p e.backtrace + exit! ensure machine.close_async(fd) end diff --git a/benchmark/http_server_multi_accept.rb b/benchmark/http_server_multi_accept.rb index 2956617..69f549c 100644 --- a/benchmark/http_server_multi_accept.rb +++ b/benchmark/http_server_multi_accept.rb @@ -12,8 +12,8 @@ RE_REQUEST_LINE = /^([a-z]+)\s+([^\s]+)\s+(http\/[0-9\.]{1,3})/i RE_HEADER_LINE = /^([a-z0-9\-]+)\:\s+(.+)/i -def stream_get_request_line(stream, buf) - line = stream.get_line(buf, 0) +def connection_get_request_line(conn, buf) + line = conn.read_line(0) m = line&.match(RE_REQUEST_LINE) return nil if !m @@ -26,12 +26,12 @@ def stream_get_request_line(stream, buf) class InvalidHeadersError < StandardError; end -def get_headers(stream, buf) - headers = stream_get_request_line(stream, buf) +def get_headers(conn, buf) + headers = connection_get_request_line(conn, buf) return nil if !headers while true - line = stream.get_line(buf, 0) + line = conn.read_line(0) break if line.empty? m = line.match(RE_HEADER_LINE) @@ -51,11 +51,11 @@ def send_response(machine, fd) end def handle_connection(machine, fd) - stream = UM::Stream.new(machine, fd) + conn = UM::Connection.new(machine, fd) buf = String.new(capacity: 65536) while true - headers = get_headers(stream, buf) + headers = get_headers(conn, buf) break if !headers send_response(machine, fd) diff --git a/benchmark/http_server_multi_ractor.rb b/benchmark/http_server_multi_ractor.rb index d0c43ef..ad14593 100644 --- a/benchmark/http_server_multi_ractor.rb +++ b/benchmark/http_server_multi_ractor.rb @@ -12,8 +12,8 @@ RE_REQUEST_LINE = /^([a-z]+)\s+([^\s]+)\s+(http\/[0-9\.]{1,3})/i RE_HEADER_LINE = /^([a-z0-9\-]+)\:\s+(.+)/i -def stream_get_request_line(stream, buf) - line = stream.get_line(buf, 0) +def connection_get_request_line(conn, buf) + line = conn.read_line(0) m = line&.match(RE_REQUEST_LINE) return nil if !m @@ -26,12 +26,12 @@ def stream_get_request_line(stream, buf) class InvalidHeadersError < StandardError; end -def get_headers(stream, buf) - headers = stream_get_request_line(stream, buf) +def get_headers(conn, buf) + headers = connection_get_request_line(conn, buf) return nil if !headers while true - line = stream.get_line(buf, 0) + line = conn.read_line(0) break if line.empty? m = line.match(RE_HEADER_LINE) @@ -53,11 +53,11 @@ def send_response(machine, fd) def handle_connection(machine, fd) machine.setsockopt(fd, UM::IPPROTO_TCP, UM::TCP_NODELAY, true) - stream = UM::Stream.new(machine, fd) + conn = UM::Connection.new(machine, fd) buf = String.new(capacity: 65536) while true - headers = get_headers(stream, buf) + headers = get_headers(conn, buf) break if !headers send_response(machine, fd) diff --git a/benchmark/http_server_single_thread.rb b/benchmark/http_server_single_thread.rb index f766aec..2aa3ae5 100644 --- a/benchmark/http_server_single_thread.rb +++ b/benchmark/http_server_single_thread.rb @@ -12,8 +12,8 @@ RE_REQUEST_LINE = /^([a-z]+)\s+([^\s]+)\s+(http\/[0-9\.]{1,3})/i RE_HEADER_LINE = /^([a-z0-9\-]+)\:\s+(.+)/i -def stream_get_request_line(stream, buf) - line = stream.get_line(buf, 0) +def connection_get_request_line(conn, buf) + line = conn.read_line(0) m = line&.match(RE_REQUEST_LINE) return nil if !m @@ -26,12 +26,12 @@ def stream_get_request_line(stream, buf) class InvalidHeadersError < StandardError; end -def get_headers(stream, buf) - headers = stream_get_request_line(stream, buf) +def get_headers(conn, buf) + headers = connection_get_request_line(conn, buf) return nil if !headers while true - line = stream.get_line(buf, 0) + line = conn.read_line(0) break if line.empty? m = line.match(RE_HEADER_LINE) @@ -52,11 +52,11 @@ def send_response(machine, fd) def handle_connection(machine, fd) machine.setsockopt(fd, UM::IPPROTO_TCP, UM::TCP_NODELAY, true) - stream = UM::Stream.new(machine, fd) + conn = UM::Connection.new(machine, fd) buf = String.new(capacity: 65536) while true - headers = get_headers(stream, buf) + headers = get_headers(conn, buf) break if !headers send_response(machine, fd) diff --git a/benchmark/openssl.rb b/benchmark/openssl.rb index 4802457..0dfa9e9 100644 --- a/benchmark/openssl.rb +++ b/benchmark/openssl.rb @@ -76,12 +76,12 @@ @um.ssl_set_bio(@ssl_um) @ssl_um.connect -@ssl_stream = OpenSSL::SSL::SSLSocket.open("127.0.0.1", port) -@ssl_stream.sync_close = true -@um.ssl_set_bio(@ssl_stream) -@ssl_stream.connect +@ssl_conn = OpenSSL::SSL::SSLSocket.open("127.0.0.1", port) +@ssl_conn.sync_close = true +@um.ssl_set_bio(@ssl_conn) +@ssl_conn.connect -@stream = @um.stream(@ssl_stream, :ssl) +@conn = @um.connection(@ssl_conn, :ssl) @msg = 'abc' * 1000 @msg_newline = @msg + "\n" @@ -91,15 +91,15 @@ def do_io(ssl) ssl.gets end -def do_io_stream(ssl, um, stream) +def do_io_connection(ssl, um, conn) um.ssl_write(ssl, @msg_newline, 0) - stream.get_line(0) + conn.read_line(0) end Benchmark.ips do |x| x.report('stock') { do_io(@ssl_stock) } x.report('UM BIO') { do_io(@ssl_um) } - x.report('UM Stream') { do_io_stream(@ssl_stream, @um, @stream) } + x.report('UM Stream') { do_io_connection(@ssl_conn, @um, @conn) } x.compare!(order: :baseline) end diff --git a/docs/design/buffer_pool.md b/docs/design/buffer_pool.md index 745ee29..4931520 100644 --- a/docs/design/buffer_pool.md +++ b/docs/design/buffer_pool.md @@ -78,7 +78,7 @@ buffers, to using managed buffers from the buffer pool. ```ruby machine.stream_recv(fd) do |stream| loop do - line = stream.get_line(max: 60) + line = stream.read_line(max: 60) if (size = parse_size(line)) data = stream.read(size) process_data(data) diff --git a/examples/fiber_concurrency_io.rb b/examples/fiber_concurrency_io.rb index 2ae4ad8..eaadbd9 100644 --- a/examples/fiber_concurrency_io.rb +++ b/examples/fiber_concurrency_io.rb @@ -5,7 +5,7 @@ def fiber_switch while true next_fiber, value = @runqueue.shift return next_fiber.transfer value if next_fiber - + process_events end end diff --git a/examples/fiber_concurrency_naive.rb b/examples/fiber_concurrency_naive.rb index ef64b5d..8d091e5 100644 --- a/examples/fiber_concurrency_naive.rb +++ b/examples/fiber_concurrency_naive.rb @@ -15,7 +15,7 @@ @fiber2 = Fiber.new { last = Time.now - + # sleep while Time.now < last + 5 @fiber1.transfer diff --git a/examples/fiber_concurrency_runqueue.rb b/examples/fiber_concurrency_runqueue.rb index 2e86031..b6e237b 100644 --- a/examples/fiber_concurrency_runqueue.rb +++ b/examples/fiber_concurrency_runqueue.rb @@ -22,7 +22,7 @@ def fiber_switch @runqueue << Fiber.new { last = Time.now - + # sleep while Time.now < last + 10 fiber_switch diff --git a/examples/pg.rb b/examples/pg.rb index 543fa8c..34670f3 100644 --- a/examples/pg.rb +++ b/examples/pg.rb @@ -66,9 +66,9 @@ def gets(sep = $/, _limit = nil, _chomp: nil) puts 'Listening on port 1234' def handle_connection(fd) - stream = UM::Stream.new($machine, fd) + conn = UM::Connection.new($machine, fd) - while (l = stream.gets) + while (l = conn.gets) $machine.write(fd, "You said: #{l}") end rescue Exception => e diff --git a/examples/stream.rb b/examples/stream.rb index 543fa8c..1a0a2f8 100644 --- a/examples/stream.rb +++ b/examples/stream.rb @@ -66,9 +66,9 @@ def gets(sep = $/, _limit = nil, _chomp: nil) puts 'Listening on port 1234' def handle_connection(fd) - stream = UM::Stream.new($machine, fd) + conn = UM::Stream.new($machine, fd) - while (l = stream.gets) + while (l = conn.gets) $machine.write(fd, "You said: #{l}") end rescue Exception => e diff --git a/ext/um/um.c b/ext/um/um.c index 7640295..829a320 100644 --- a/ext/um/um.c +++ b/ext/um/um.c @@ -652,11 +652,11 @@ VALUE um_write(struct um *machine, int fd, VALUE buffer, size_t len, __u64 file_ return ret; } -size_t um_write_raw(struct um *machine, int fd, const char *buffer, size_t maxlen) { +size_t um_write_raw(struct um *machine, int fd, const char *buffer, size_t len) { struct um_op *op = um_op_acquire(machine); um_prep_op(machine, op, OP_WRITE, 2, 0); struct io_uring_sqe *sqe = um_get_sqe(machine, op); - io_uring_prep_write(sqe, fd, buffer, maxlen, 0); + io_uring_prep_write(sqe, fd, buffer, len, 0); int res = 0; VALUE ret = um_yield(machine); @@ -831,6 +831,23 @@ VALUE um_send(struct um *machine, int fd, VALUE buffer, size_t len, int flags) { return ret; } +size_t um_send_raw(struct um *machine, int fd, const char *buffer, size_t len, int flags) { + struct um_op *op = um_op_acquire(machine); + um_prep_op(machine, op, OP_SEND, 2, 0); + struct io_uring_sqe *sqe = um_get_sqe(machine, op); + io_uring_prep_send(sqe, fd, buffer, len, flags); + + int res = 0; + VALUE ret = um_yield(machine); + + if (likely(um_verify_op_completion(machine, op, true))) res = op->result.res; + um_op_release(machine, op); + + RAISE_IF_EXCEPTION(ret); + RB_GC_GUARD(ret); + return res; +} + // for some reason we don't get this define from liburing/io_uring.h #define IORING_SEND_VECTORIZED (1U << 5) @@ -1478,7 +1495,7 @@ int read_recv_each_multishot_process_result(struct op_ctx *ctx, struct um_op_res return false; *total += result->res; - VALUE buf = um_get_string_from_buffer_ring(ctx->machine, ctx->bgid, result->res, result->flags); + VALUE buf = um_read_from_buffer_ring(ctx->machine, ctx->bgid, result->res, result->flags); rb_yield(buf); RB_GC_GUARD(buf); diff --git a/ext/um/um.h b/ext/um/um.h index 5740ff5..1f9fb21 100644 --- a/ext/um/um.h +++ b/ext/um/um.h @@ -80,12 +80,12 @@ enum um_op_kind { OP_TIMEOUT_MULTISHOT, }; -enum um_stream_mode { - STREAM_BP_READ, - STREAM_BP_RECV, - STREAM_SSL, - STREAM_STRING, - STREAM_IO_BUFFER +enum um_connection_mode { + CONNECTION_FD, + CONNECTION_SOCKET, + CONNECTION_SSL, + CONNECTION_STRING, + CONNECTION_IO_BUFFER }; #define OP_F_CQE_SEEN (1U << 0) // CQE has been seen @@ -272,11 +272,11 @@ struct um_async_op { struct um_op *op; }; -struct um_stream { +struct um_connection { VALUE self; struct um *machine; - enum um_stream_mode mode; + enum um_connection_mode mode; union { int fd; VALUE target; @@ -304,7 +304,7 @@ extern VALUE eUMError; extern VALUE cMutex; extern VALUE cQueue; extern VALUE cAsyncOp; -extern VALUE eStreamRESPError; +extern VALUE eConnectionRESPError; struct um *um_get_machine(VALUE self); void um_setup(VALUE self, struct um *machine, uint size, uint sqpoll_timeout_msec, int sidecar_mode); @@ -352,7 +352,7 @@ int um_get_buffer_bytes_for_writing(VALUE buffer, const void **base, size_t *siz void * um_prepare_read_buffer(VALUE buffer, ssize_t len, ssize_t ofs); void um_update_read_buffer(VALUE buffer, ssize_t buffer_offset, __s32 result); int um_setup_buffer_ring(struct um *machine, unsigned size, unsigned count); -VALUE um_get_string_from_buffer_ring(struct um *machine, int bgid, __s32 result, __u32 flags); +VALUE um_read_from_buffer_ring(struct um *machine, int bgid, __s32 result, __u32 flags); void um_add_strings_to_buffer_ring(struct um *machine, int bgid, VALUE strings); struct iovec *um_alloc_iovecs_for_writing(int argc, VALUE *argv, size_t *total_len); void um_advance_iovecs_for_writing(struct iovec **ptr, int *len, size_t adv); @@ -378,7 +378,7 @@ VALUE um_read(struct um *machine, int fd, VALUE buffer, size_t maxlen, ssize_t b size_t um_read_raw(struct um *machine, int fd, char *buffer, size_t maxlen); VALUE um_read_each(struct um *machine, int fd, int bgid); VALUE um_write(struct um *machine, int fd, VALUE buffer, size_t len, __u64 file_offset); -size_t um_write_raw(struct um *machine, int fd, const char *buffer, size_t maxlen); +size_t um_write_raw(struct um *machine, int fd, const char *buffer, size_t len); VALUE um_writev(struct um *machine, int fd, int argc, VALUE *argv); VALUE um_write_async(struct um *machine, int fd, VALUE buffer, size_t len, __u64 file_offset); VALUE um_close(struct um *machine, int fd); @@ -403,6 +403,7 @@ VALUE um_accept_into_queue(struct um *machine, int fd, VALUE queue); VALUE um_socket(struct um *machine, int domain, int type, int protocol, uint flags); VALUE um_connect(struct um *machine, int fd, const struct sockaddr *addr, socklen_t addrlen); VALUE um_send(struct um *machine, int fd, VALUE buffer, size_t len, int flags); +size_t um_send_raw(struct um *machine, int fd, const char *buffer, size_t len, int flags); VALUE um_sendv(struct um *machine, int fd, int argc, VALUE *argv); VALUE um_send_bundle(struct um *machine, int fd, int bgid, VALUE strings); VALUE um_recv(struct um *machine, int fd, VALUE buffer, size_t maxlen, int flags); @@ -437,14 +438,16 @@ VALUE um_queue_pop(struct um *machine, struct um_queue *queue); VALUE um_queue_unshift(struct um *machine, struct um_queue *queue, VALUE value); VALUE um_queue_shift(struct um *machine, struct um_queue *queue); -void stream_teardown(struct um_stream *stream); -void stream_clear(struct um_stream *stream); -VALUE stream_get_line(struct um_stream *stream, VALUE out_buffer, size_t maxlen); -VALUE stream_get_string(struct um_stream *stream, VALUE out_buffer, ssize_t len, size_t inc, int safe_inc); -VALUE stream_get_to_delim(struct um_stream *stream, VALUE out_buffer, VALUE delim, ssize_t maxlen); -void stream_skip(struct um_stream *stream, size_t inc, int safe_inc); -void stream_each(struct um_stream *stream); -VALUE resp_decode(struct um_stream *stream, VALUE out_buffer); +void connection_teardown(struct um_connection *conn); +void connection_clear(struct um_connection *conn); +VALUE connection_read_line(struct um_connection *conn, VALUE out_buffer, size_t maxlen); +VALUE connection_read(struct um_connection *conn, VALUE out_buffer, ssize_t len, size_t inc, int safe_inc); +VALUE connection_read_to_delim(struct um_connection *conn, VALUE out_buffer, VALUE delim, ssize_t maxlen); +void connection_skip(struct um_connection *conn, size_t inc, int safe_inc); +void connection_read_each(struct um_connection *conn); +size_t connection_write_raw(struct um_connection *conn, const char *buffer, size_t len); +VALUE connection_writev(struct um_connection *conn, int argc, VALUE *argv); +VALUE resp_read(struct um_connection *conn, VALUE out_buffer); void resp_encode(struct um_write_buffer *buf, VALUE obj); void resp_encode_cmd(struct um_write_buffer *buf, int argc, VALUE *argv); @@ -464,6 +467,8 @@ void um_ssl_set_bio(struct um *machine, VALUE ssl_obj); int um_ssl_read(struct um *machine, VALUE ssl, VALUE buf, size_t maxlen); int um_ssl_read_raw(struct um *machine, VALUE ssl_obj, char *ptr, size_t maxlen); int um_ssl_write(struct um *machine, VALUE ssl, VALUE buf, size_t len); +int um_ssl_write_raw(struct um *machine, VALUE ssl, const char *buffer, size_t len); +int um_ssl_writev(struct um *machine, VALUE ssl, int argc, VALUE *argv); void bp_setup(struct um *machine); void bp_teardown(struct um *machine); diff --git a/ext/um/um_stream.c b/ext/um/um_connection.c similarity index 51% rename from ext/um/um_stream.c rename to ext/um/um_connection.c index 8e1da42..16aca8c 100644 --- a/ext/um/um_stream.c +++ b/ext/um/um_connection.c @@ -2,88 +2,88 @@ #include #include "um.h" -inline void stream_add_segment(struct um_stream *stream, struct um_segment *segment) { +inline void connection_add_segment(struct um_connection *conn, struct um_segment *segment) { segment->next = NULL; - if (stream->tail) { - stream->tail->next = segment; - stream->tail = segment; + if (conn->tail) { + conn->tail->next = segment; + conn->tail = segment; } else - stream->head = stream->tail = segment; - stream->pending_bytes += segment->len; + conn->head = conn->tail = segment; + conn->pending_bytes += segment->len; } -inline int stream_process_op_result(struct um_stream *stream, struct um_op_result *result) { +inline int connection_process_op_result(struct um_connection *conn, struct um_op_result *result) { if (likely(result->res > 0)) { if (likely(result->segment)) { - stream_add_segment(stream, result->segment); + connection_add_segment(conn, result->segment); result->segment = NULL; } } else - stream->eof = 1; + conn->eof = 1; return result->res; } -#define STREAM_OP_FLAGS (OP_F_MULTISHOT | OP_F_BUFFER_POOL) +#define CONNECTION_OP_FLAGS (OP_F_MULTISHOT | OP_F_BUFFER_POOL) -void stream_multishot_op_start(struct um_stream *stream) { - if (!stream->op) - stream->op = um_op_acquire(stream->machine); +void connection_multishot_op_start(struct um_connection *conn) { + if (!conn->op) + conn->op = um_op_acquire(conn->machine); struct io_uring_sqe *sqe; - bp_ensure_commit_level(stream->machine); + bp_ensure_commit_level(conn->machine); - switch (stream->mode) { - case STREAM_BP_READ: - um_prep_op(stream->machine, stream->op, OP_READ_MULTISHOT, 2, STREAM_OP_FLAGS); - sqe = um_get_sqe(stream->machine, stream->op); - io_uring_prep_read_multishot(sqe, stream->fd, 0, -1, BP_BGID); + switch (conn->mode) { + case CONNECTION_FD: + um_prep_op(conn->machine, conn->op, OP_READ_MULTISHOT, 2, CONNECTION_OP_FLAGS); + sqe = um_get_sqe(conn->machine, conn->op); + io_uring_prep_read_multishot(sqe, conn->fd, 0, -1, BP_BGID); break; - case STREAM_BP_RECV: - um_prep_op(stream->machine, stream->op, OP_RECV_MULTISHOT, 2, STREAM_OP_FLAGS); - sqe = um_get_sqe(stream->machine, stream->op); - io_uring_prep_recv_multishot(sqe, stream->fd, NULL, 0, 0); + case CONNECTION_SOCKET: + um_prep_op(conn->machine, conn->op, OP_RECV_MULTISHOT, 2, CONNECTION_OP_FLAGS); + sqe = um_get_sqe(conn->machine, conn->op); + io_uring_prep_recv_multishot(sqe, conn->fd, NULL, 0, 0); sqe->buf_group = BP_BGID; sqe->flags |= IOSQE_BUFFER_SELECT; break; default: um_raise_internal_error("Invalid multishot op"); } - stream->op->bp_commit_level = stream->machine->bp_commit_level; + conn->op->bp_commit_level = conn->machine->bp_commit_level; } -void stream_multishot_op_stop(struct um_stream *stream) { - assert(!stream->op); +void connection_multishot_op_stop(struct um_connection *conn) { + assert(!conn->op); - if (!(stream->op->flags & OP_F_CQE_DONE)) { - stream->op->flags |= OP_F_ASYNC; - um_cancel_op(stream->machine, stream->op); + if (!(conn->op->flags & OP_F_CQE_DONE)) { + conn->op->flags |= OP_F_ASYNC; + um_cancel_op(conn->machine, conn->op); } else - um_op_release(stream->machine, stream->op); - stream->op = NULL; + um_op_release(conn->machine, conn->op); + conn->op = NULL; } -void um_stream_cleanup(struct um_stream *stream) { - if (stream->op) stream_multishot_op_stop(stream); +void um_connection_cleanup(struct um_connection *conn) { + if (conn->op) connection_multishot_op_stop(conn); - while (stream->head) { - struct um_segment *next = stream->head->next; - um_segment_checkin(stream->machine, stream->head); - stream->head = next; + while (conn->head) { + struct um_segment *next = conn->head->next; + um_segment_checkin(conn->machine, conn->head); + conn->head = next; } - stream->pending_bytes = 0; + conn->pending_bytes = 0; } // returns true if case of ENOBUFS error, sets more to true if more data forthcoming -inline int stream_process_segments( - struct um_stream *stream, size_t *total_bytes, int *more) { +inline int connection_process_segments( + struct um_connection *conn, size_t *total_bytes, int *more) { *more = 0; - struct um_op_result *result = &stream->op->result; - stream->op->flags &= ~OP_F_CQE_SEEN; + struct um_op_result *result = &conn->op->result; + conn->op->flags &= ~OP_F_CQE_SEEN; while (result) { if (unlikely(result->res == -ENOBUFS)) { *more = 0; @@ -97,142 +97,142 @@ inline int stream_process_segments( *more = (result->flags & IORING_CQE_F_MORE); *total_bytes += result->res; - stream_process_op_result(stream, result); + connection_process_op_result(conn, result); result = result->next; } return false; } -void stream_clear(struct um_stream *stream) { - if (stream->op && stream->machine->ring_initialized) { - if (OP_CQE_SEEN_P(stream->op)) { +void connection_clear(struct um_connection *conn) { + if (conn->op && conn->machine->ring_initialized) { + if (OP_CQE_SEEN_P(conn->op)) { size_t total_bytes = 0; int more = false; - stream_process_segments(stream, &total_bytes, &more); - um_op_multishot_results_clear(stream->machine, stream->op); + connection_process_segments(conn, &total_bytes, &more); + um_op_multishot_results_clear(conn->machine, conn->op); } - if (OP_CQE_DONE_P(stream->op)) - um_op_release(stream->machine, stream->op); + if (OP_CQE_DONE_P(conn->op)) + um_op_release(conn->machine, conn->op); else - um_cancel_op_and_discard_cqe(stream->machine, stream->op); + um_cancel_op_and_discard_cqe(conn->machine, conn->op); - stream->op = NULL; + conn->op = NULL; } - while (stream->head) { - struct um_segment *next = stream->head->next; - um_segment_checkin(stream->machine, stream->head); - stream->head = next; + while (conn->head) { + struct um_segment *next = conn->head->next; + um_segment_checkin(conn->machine, conn->head); + conn->head = next; } - stream->pending_bytes = 0; + conn->pending_bytes = 0; - if (stream->working_buffer) { - bp_buffer_checkin(stream->machine, stream->working_buffer); - stream->working_buffer = NULL; + if (conn->working_buffer) { + bp_buffer_checkin(conn->machine, conn->working_buffer); + conn->working_buffer = NULL; } } -inline void stream_await_segments(struct um_stream *stream) { - if (unlikely(!stream->op)) stream_multishot_op_start(stream); +inline void connection_await_segments(struct um_connection *conn) { + if (unlikely(!conn->op)) connection_multishot_op_start(conn); - if (!OP_CQE_SEEN_P(stream->op)) { - stream->op->flags &= ~OP_F_ASYNC; - VALUE ret = um_yield(stream->machine); - stream->op->flags |= OP_F_ASYNC; - if (!OP_CQE_SEEN_P(stream->op)) RAISE_IF_EXCEPTION(ret); + if (!OP_CQE_SEEN_P(conn->op)) { + conn->op->flags &= ~OP_F_ASYNC; + VALUE ret = um_yield(conn->machine); + conn->op->flags |= OP_F_ASYNC; + if (!OP_CQE_SEEN_P(conn->op)) RAISE_IF_EXCEPTION(ret); RB_GC_GUARD(ret); } } -int stream_get_more_segments_bp(struct um_stream *stream) { +int connection_get_more_segments_bp(struct um_connection *conn) { size_t total_bytes = 0; int more = false; int enobufs = false; while (1) { - if (unlikely(stream->eof)) return 0; + if (unlikely(conn->eof)) return 0; - stream_await_segments(stream); - enobufs = stream_process_segments(stream, &total_bytes, &more); - um_op_multishot_results_clear(stream->machine, stream->op); + connection_await_segments(conn); + enobufs = connection_process_segments(conn, &total_bytes, &more); + um_op_multishot_results_clear(conn->machine, conn->op); if (unlikely(enobufs)) { - int should_restart = stream->pending_bytes < (stream->machine->bp_buffer_size * 4); - // int same_threshold = stream->op->bp_commit_level == stream->machine->bp_commit_level; + int should_restart = conn->pending_bytes < (conn->machine->bp_buffer_size * 4); + // int same_threshold = conn->op->bp_commit_level == conn->machine->bp_commit_level; // fprintf(stderr, "%p enobufs total: %ld pending: %ld threshold: %ld bc: %d (same: %d, restart: %d)\n", - // stream, - // total_bytes, stream->pending_bytes, stream->machine->bp_commit_level, - // stream->machine->bp_buffer_count, + // conn, + // total_bytes, conn->pending_bytes, conn->machine->bp_commit_level, + // conn->machine->bp_buffer_count, // same_threshold, should_restart // ); - // If multiple stream ops are happening at the same time, they'll all get - // ENOBUFS! We track the commit threshold in the op in order to prevent - // running bp_handle_enobufs() more than once. + // If multiple connection ops are happening at the same time, they'll all + // get ENOBUFS! We track the commit threshold in the op in order to + // prevent running bp_handle_enobufs() more than once. if (should_restart) { - if (stream->op->bp_commit_level == stream->machine->bp_commit_level) - bp_handle_enobufs(stream->machine); + if (conn->op->bp_commit_level == conn->machine->bp_commit_level) + bp_handle_enobufs(conn->machine); - um_op_release(stream->machine, stream->op); - stream->op = NULL; - // stream_multishot_op_start(stream); + um_op_release(conn->machine, conn->op); + conn->op = NULL; + // connection_multishot_op_start(conn); } else { - um_op_release(stream->machine, stream->op); - stream->op = NULL; + um_op_release(conn->machine, conn->op); + conn->op = NULL; } if (total_bytes) return total_bytes; } else { if (more) - stream->op->flags &= ~OP_F_CQE_SEEN; - if (total_bytes || stream->eof) return total_bytes; + conn->op->flags &= ~OP_F_CQE_SEEN; + if (total_bytes || conn->eof) return total_bytes; } } } -int stream_get_more_segments_ssl(struct um_stream *stream) { - if (!stream->working_buffer) - stream->working_buffer = bp_buffer_checkout(stream->machine); +int connection_get_more_segments_ssl(struct um_connection *conn) { + if (!conn->working_buffer) + conn->working_buffer = bp_buffer_checkout(conn->machine); - char *ptr = stream->working_buffer->buf + stream->working_buffer->pos; - size_t maxlen = stream->working_buffer->len - stream->working_buffer->pos; - int res = um_ssl_read_raw(stream->machine, stream->target, ptr, maxlen); + char *ptr = conn->working_buffer->buf + conn->working_buffer->pos; + size_t maxlen = conn->working_buffer->len - conn->working_buffer->pos; + int res = um_ssl_read_raw(conn->machine, conn->target, ptr, maxlen); if (res == 0) return 0; if (res < 0) rb_raise(eUMError, "Failed to read segment"); - struct um_segment *segment = bp_buffer_consume(stream->machine, stream->working_buffer, res); + struct um_segment *segment = bp_buffer_consume(conn->machine, conn->working_buffer, res); if ((size_t)res == maxlen) { - bp_buffer_checkin(stream->machine, stream->working_buffer); - stream->working_buffer = NULL; + bp_buffer_checkin(conn->machine, conn->working_buffer); + conn->working_buffer = NULL; } - stream_add_segment(stream, segment); + connection_add_segment(conn, segment); return 1; } -int stream_get_more_segments(struct um_stream *stream) { - switch (stream->mode) { - case STREAM_BP_READ: - case STREAM_BP_RECV: - return stream_get_more_segments_bp(stream); - case STREAM_SSL: - return stream_get_more_segments_ssl(stream); +int connection_get_more_segments(struct um_connection *conn) { + switch (conn->mode) { + case CONNECTION_FD: + case CONNECTION_SOCKET: + return connection_get_more_segments_bp(conn); + case CONNECTION_SSL: + return connection_get_more_segments_ssl(conn); default: - rb_raise(eUMError, "Invalid stream mode"); + rb_raise(eUMError, "Invalid connection mode"); } } //////////////////////////////////////////////////////////////////////////////// -inline void stream_shift_head(struct um_stream *stream) { - struct um_segment *consumed = stream->head; - stream->head = consumed->next; - if (!stream->head) stream->tail = NULL; - um_segment_checkin(stream->machine, consumed); - stream->pos = 0; +inline void connection_shift_head(struct um_connection *conn) { + struct um_segment *consumed = conn->head; + conn->head = consumed->next; + if (!conn->head) conn->tail = NULL; + um_segment_checkin(conn->machine, consumed); + conn->pos = 0; } inline VALUE make_segment_io_buffer(struct um_segment *segment, size_t pos) { @@ -242,32 +242,32 @@ inline VALUE make_segment_io_buffer(struct um_segment *segment, size_t pos) { ); } -inline void stream_skip(struct um_stream *stream, size_t inc, int safe_inc) { - if (unlikely(stream->eof && !stream->head)) return; - if (safe_inc && !stream->tail && !stream_get_more_segments(stream)) return; +inline void connection_skip(struct um_connection *conn, size_t inc, int safe_inc) { + if (unlikely(conn->eof && !conn->head)) return; + if (safe_inc && !conn->tail && !connection_get_more_segments(conn)) return; while (inc) { - size_t segment_len = stream->head->len - stream->pos; + size_t segment_len = conn->head->len - conn->pos; size_t inc_len = (segment_len <= inc) ? segment_len : inc; inc -= inc_len; - stream->pos += inc_len; - stream->consumed_bytes += inc_len; - stream->pending_bytes -= inc_len; - if (stream->pos == stream->head->len) { - stream_shift_head(stream); - if (inc && safe_inc && !stream->head) { - if (!stream_get_more_segments(stream)) break; + conn->pos += inc_len; + conn->consumed_bytes += inc_len; + conn->pending_bytes -= inc_len; + if (conn->pos == conn->head->len) { + connection_shift_head(conn); + if (inc && safe_inc && !conn->head) { + if (!connection_get_more_segments(conn)) break; } } } } -inline void stream_each(struct um_stream *stream) { - if (unlikely(stream->eof && !stream->head)) return; - if (!stream->tail && !stream_get_more_segments(stream)) return; +inline void connection_read_each(struct um_connection *conn) { + if (unlikely(conn->eof && !conn->head)) return; + if (!conn->tail && !connection_get_more_segments(conn)) return; - struct um_segment *current = stream->head; - size_t pos = stream->pos; + struct um_segment *current = conn->head; + size_t pos = conn->pos; VALUE buffer = Qnil; while (true) { @@ -275,34 +275,34 @@ inline void stream_each(struct um_stream *stream) { buffer = make_segment_io_buffer(current, pos); rb_yield(buffer); rb_io_buffer_free_locked(buffer); - stream_shift_head(stream); + connection_shift_head(conn); if (!next) { - if (!stream_get_more_segments(stream)) return; + if (!connection_get_more_segments(conn)) return; } - current = stream->head; + current = conn->head; pos = 0; } RB_GC_GUARD(buffer); } -inline void stream_copy(struct um_stream *stream, char *dest, size_t len) { +inline void connection_copy(struct um_connection *conn, char *dest, size_t len) { while (len) { - char *segment_ptr = stream->head->ptr + stream->pos; - size_t segment_len = stream->head->len - stream->pos; + char *segment_ptr = conn->head->ptr + conn->pos; + size_t segment_len = conn->head->len - conn->pos; size_t cpy_len = (segment_len <= len) ? segment_len : len; memcpy(dest, segment_ptr, cpy_len); len -= cpy_len; - stream->pos += cpy_len; - stream->consumed_bytes += cpy_len; - stream->pending_bytes -= cpy_len; + conn->pos += cpy_len; + conn->consumed_bytes += cpy_len; + conn->pending_bytes -= cpy_len; dest += cpy_len; - if (stream->pos == stream->head->len) stream_shift_head(stream); + if (conn->pos == conn->head->len) connection_shift_head(conn); } } -VALUE stream_consume_string(struct um_stream *stream, VALUE out_buffer, size_t len, size_t inc, int safe_inc) { +VALUE connection_consume_string(struct um_connection *conn, VALUE out_buffer, size_t len, size_t inc, int safe_inc) { VALUE str = Qnil; if (!NIL_P(out_buffer)) { str = out_buffer; @@ -316,8 +316,8 @@ VALUE stream_consume_string(struct um_stream *stream, VALUE out_buffer, size_t l str = rb_str_new(NULL, len); char *dest = RSTRING_PTR(str); - stream_copy(stream, dest, len); - stream_skip(stream, inc, safe_inc); + connection_copy(conn, dest, len); + connection_skip(conn, inc, safe_inc); return str; RB_GC_GUARD(str); } @@ -326,16 +326,16 @@ inline int trailing_cr_p(char *ptr, size_t len) { return ptr[len - 1] == '\r'; } -VALUE stream_get_line(struct um_stream *stream, VALUE out_buffer, size_t maxlen) { - if (unlikely(stream->eof && !stream->head)) return Qnil; - if (!stream->tail && !stream_get_more_segments(stream)) return Qnil; +VALUE connection_read_line(struct um_connection *conn, VALUE out_buffer, size_t maxlen) { + if (unlikely(conn->eof && !conn->head)) return Qnil; + if (!conn->tail && !connection_get_more_segments(conn)) return Qnil; struct um_segment *last = NULL; - struct um_segment *current = stream->head; + struct um_segment *current = conn->head; size_t remaining_len = maxlen; size_t total_len = 0; size_t inc = 1; - size_t pos = stream->pos; + size_t pos = conn->pos; while (true) { size_t segment_len = current->len - pos; @@ -359,7 +359,7 @@ VALUE stream_get_line(struct um_stream *stream, VALUE out_buffer, size_t maxlen) } } - return stream_consume_string(stream, out_buffer, total_len, inc, false); + return connection_consume_string(conn, out_buffer, total_len, inc, false); } else { // not found, early return if segment len exceeds maxlen @@ -370,7 +370,7 @@ VALUE stream_get_line(struct um_stream *stream, VALUE out_buffer, size_t maxlen) } if (!current->next) { - if (!stream_get_more_segments(stream)) { + if (!connection_get_more_segments(conn)) { return Qnil; } } @@ -381,15 +381,15 @@ VALUE stream_get_line(struct um_stream *stream, VALUE out_buffer, size_t maxlen) } } -VALUE stream_get_string(struct um_stream *stream, VALUE out_buffer, ssize_t len, size_t inc, int safe_inc) { - if (unlikely(stream->eof && !stream->head)) return Qnil; - if (!stream->tail && !stream_get_more_segments(stream)) return Qnil; +VALUE connection_read(struct um_connection *conn, VALUE out_buffer, ssize_t len, size_t inc, int safe_inc) { + if (unlikely(conn->eof && !conn->head)) return Qnil; + if (!conn->tail && !connection_get_more_segments(conn)) return Qnil; - struct um_segment *current = stream->head; + struct um_segment *current = conn->head; size_t abs_len = labs(len); size_t remaining_len = abs_len; size_t total_len = 0; - size_t pos = stream->pos; + size_t pos = conn->pos; while (true) { size_t segment_len = current->len - pos; @@ -400,14 +400,14 @@ VALUE stream_get_string(struct um_stream *stream, VALUE out_buffer, ssize_t len, if (abs_len) { remaining_len -= segment_len; if (!remaining_len) - return stream_consume_string(stream, out_buffer, total_len, inc, safe_inc); + return connection_consume_string(conn, out_buffer, total_len, inc, safe_inc); } if (!current->next) { if (len <= 0) - return stream_consume_string(stream, out_buffer, total_len, inc, safe_inc); + return connection_consume_string(conn, out_buffer, total_len, inc, safe_inc); - if (!stream_get_more_segments(stream)) + if (!connection_get_more_segments(conn)) return Qnil; } current = current->next; @@ -425,17 +425,17 @@ static inline char delim_to_char(VALUE delim) { return *RSTRING_PTR(delim); } -VALUE stream_get_to_delim(struct um_stream *stream, VALUE out_buffer, VALUE delim, ssize_t maxlen) { +VALUE connection_read_to_delim(struct um_connection *conn, VALUE out_buffer, VALUE delim, ssize_t maxlen) { char delim_char = delim_to_char(delim); - if (unlikely(stream->eof && !stream->head)) return Qnil; - if (unlikely(!stream->tail) && !stream_get_more_segments(stream)) return Qnil; + if (unlikely(conn->eof && !conn->head)) return Qnil; + if (unlikely(!conn->tail) && !connection_get_more_segments(conn)) return Qnil; - struct um_segment *current = stream->head; + struct um_segment *current = conn->head; size_t abs_maxlen = labs(maxlen); size_t remaining_len = abs_maxlen; size_t total_len = 0; - size_t pos = stream->pos; + size_t pos = conn->pos; while (true) { size_t segment_len = current->len - pos; @@ -447,7 +447,7 @@ VALUE stream_get_to_delim(struct um_stream *stream, VALUE out_buffer, VALUE deli if (delim_ptr) { size_t len = delim_ptr - start; total_len += len; - return stream_consume_string(stream, out_buffer, total_len, 1, false); + return connection_consume_string(conn, out_buffer, total_len, 1, false); } else { // delimiter not found @@ -455,26 +455,51 @@ VALUE stream_get_to_delim(struct um_stream *stream, VALUE out_buffer, VALUE deli remaining_len -= search_len; if (abs_maxlen && total_len >= abs_maxlen) - return (maxlen > 0) ? Qnil : stream_consume_string(stream, out_buffer, abs_maxlen, 1, false); + return (maxlen > 0) ? Qnil : connection_consume_string(conn, out_buffer, abs_maxlen, 1, false); } - if (!current->next && !stream_get_more_segments(stream)) return Qnil; + if (!current->next && !connection_get_more_segments(conn)) return Qnil; current = current->next; pos = 0; } } +size_t connection_write_raw(struct um_connection *conn, const char *buffer, size_t len) { + switch (conn->mode) { + case CONNECTION_FD: + return um_write_raw(conn->machine, conn->fd, buffer, len); + case CONNECTION_SOCKET: + return um_send_raw(conn->machine, conn->fd, buffer, len, 0); + case CONNECTION_SSL: + return um_ssl_write_raw(conn->machine, conn->target, buffer, len); + default: + rb_raise(eUMError, "Invalid connection mode"); + } +} + +VALUE connection_writev(struct um_connection *conn, int argc, VALUE *argv) { + switch (conn->mode) { + case CONNECTION_FD: + return um_writev(conn->machine, conn->fd, argc, argv); + case CONNECTION_SOCKET: + return um_sendv(conn->machine, conn->fd, argc, argv); + case CONNECTION_SSL: + return ULONG2NUM(um_ssl_writev(conn->machine, conn->target, argc, argv)); + default: + rb_raise(eUMError, "Invalid connection mode"); + } +} //////////////////////////////////////////////////////////////////////////////// -VALUE resp_get_line(struct um_stream *stream, VALUE out_buffer) { - if (unlikely(stream->eof && !stream->head)) return Qnil; - if (!stream->tail && !stream_get_more_segments(stream)) return Qnil; +VALUE resp_read_line(struct um_connection *conn, VALUE out_buffer) { + if (unlikely(conn->eof && !conn->head)) return Qnil; + if (!conn->tail && !connection_get_more_segments(conn)) return Qnil; - struct um_segment *current = stream->head; + struct um_segment *current = conn->head; size_t total_len = 0; - size_t pos = stream->pos; + size_t pos = conn->pos; while (true) { size_t segment_len = current->len - pos; @@ -483,20 +508,20 @@ VALUE resp_get_line(struct um_stream *stream, VALUE out_buffer) { if (lf_ptr) { size_t len = lf_ptr - start; total_len += len; - return stream_consume_string(stream, out_buffer, total_len, 2, true); + return connection_consume_string(conn, out_buffer, total_len, 2, true); } else total_len += segment_len; if (!current->next) - if (!stream_get_more_segments(stream)) return Qnil; + if (!connection_get_more_segments(conn)) return Qnil; current = current->next; } } -inline VALUE resp_get_string(struct um_stream *stream, ulong len, VALUE out_buffer) { - return stream_get_string(stream, out_buffer, len, 2, true); +inline VALUE resp_read_string(struct um_connection *conn, ulong len, VALUE out_buffer) { + return connection_read(conn, out_buffer, len, 2, true); } inline ulong resp_parse_length_field(const char *ptr, int len) { @@ -506,12 +531,12 @@ inline ulong resp_parse_length_field(const char *ptr, int len) { return acc; } -VALUE resp_decode_hash(struct um_stream *stream, VALUE out_buffer, ulong len) { +VALUE resp_decode_hash(struct um_connection *conn, VALUE out_buffer, ulong len) { VALUE hash = rb_hash_new(); for (ulong i = 0; i < len; i++) { - VALUE key = resp_decode(stream, out_buffer); - VALUE value = resp_decode(stream, out_buffer); + VALUE key = resp_read(conn, out_buffer); + VALUE value = resp_read(conn, out_buffer); rb_hash_aset(hash, key, value); RB_GC_GUARD(key); RB_GC_GUARD(value); @@ -521,11 +546,11 @@ VALUE resp_decode_hash(struct um_stream *stream, VALUE out_buffer, ulong len) { return hash; } -VALUE resp_decode_array(struct um_stream *stream, VALUE out_buffer, ulong len) { +VALUE resp_decode_array(struct um_connection *conn, VALUE out_buffer, ulong len) { VALUE array = rb_ary_new2(len); for (ulong i = 0; i < len; i++) { - VALUE value = resp_decode(stream, out_buffer); + VALUE value = resp_read(conn, out_buffer); rb_ary_push(array, value); RB_GC_GUARD(value); } @@ -538,12 +563,12 @@ static inline VALUE resp_decode_simple_string(char *ptr, ulong len) { return rb_str_new(ptr + 1, len - 1); } -static inline VALUE resp_decode_string(struct um_stream *stream, VALUE out_buffer, ulong len) { - return resp_get_string(stream, len, out_buffer); +static inline VALUE resp_decode_string(struct um_connection *conn, VALUE out_buffer, ulong len) { + return resp_read_string(conn, len, out_buffer); } -static inline VALUE resp_decode_string_with_encoding(struct um_stream *stream, VALUE out_buffer, ulong len) { - VALUE with_enc = resp_get_string(stream, len, out_buffer); +static inline VALUE resp_decode_string_with_encoding(struct um_connection *conn, VALUE out_buffer, ulong len) { + VALUE with_enc = resp_read_string(conn, len, out_buffer); char *ptr = RSTRING_PTR(with_enc); len = RSTRING_LEN(with_enc); if ((len < 4) || (ptr[3] != ':')) return Qnil; @@ -566,23 +591,23 @@ static inline VALUE resp_decode_simple_error(char *ptr, ulong len) { if (!ID_new) ID_new = rb_intern("new"); VALUE msg = rb_str_new(ptr + 1, len - 1); - VALUE err = rb_funcall(eStreamRESPError, ID_new, 1, msg); + VALUE err = rb_funcall(eConnectionRESPError, ID_new, 1, msg); RB_GC_GUARD(msg); return err; } -static inline VALUE resp_decode_error(struct um_stream *stream, VALUE out_buffer, ulong len) { +static inline VALUE resp_decode_error(struct um_connection *conn, VALUE out_buffer, ulong len) { static ID ID_new = 0; if (!ID_new) ID_new = rb_intern("new"); - VALUE msg = resp_decode_string(stream, out_buffer, len); - VALUE err = rb_funcall(eStreamRESPError, ID_new, 1, msg); + VALUE msg = resp_decode_string(conn, out_buffer, len); + VALUE err = rb_funcall(eConnectionRESPError, ID_new, 1, msg); RB_GC_GUARD(msg); return err; } -VALUE resp_decode(struct um_stream *stream, VALUE out_buffer) { - VALUE msg = resp_get_line(stream, out_buffer); +VALUE resp_read(struct um_connection *conn, VALUE out_buffer) { + VALUE msg = resp_read_line(conn, out_buffer); if (msg == Qnil) return Qnil; char *ptr = RSTRING_PTR(msg); @@ -594,22 +619,22 @@ VALUE resp_decode(struct um_stream *stream, VALUE out_buffer) { case '%': // hash case '|': // attributes hash data_len = resp_parse_length_field(ptr, len); - return resp_decode_hash(stream, out_buffer, data_len); + return resp_decode_hash(conn, out_buffer, data_len); case '*': // array case '~': // set case '>': // pub/sub push data_len = resp_parse_length_field(ptr, len); - return resp_decode_array(stream, out_buffer, data_len); + return resp_decode_array(conn, out_buffer, data_len); case '+': // simple string return resp_decode_simple_string(ptr, len); case '$': // string data_len = resp_parse_length_field(ptr, len); - return resp_decode_string(stream, out_buffer, data_len); + return resp_decode_string(conn, out_buffer, data_len); case '=': // string with encoding data_len = resp_parse_length_field(ptr, len); - return resp_decode_string_with_encoding(stream, out_buffer, data_len); + return resp_decode_string_with_encoding(conn, out_buffer, data_len); case '_': // null return Qnil; @@ -627,7 +652,7 @@ VALUE resp_decode(struct um_stream *stream, VALUE out_buffer) { return resp_decode_simple_error(ptr, len); case '!': // error data_len = resp_parse_length_field(ptr, len); - return resp_decode_error(stream, out_buffer, data_len); + return resp_decode_error(conn, out_buffer, data_len); default: um_raise_internal_error("Invalid character encountered"); } diff --git a/ext/um/um_connection_class.c b/ext/um/um_connection_class.c new file mode 100644 index 0000000..221a16c --- /dev/null +++ b/ext/um/um_connection_class.c @@ -0,0 +1,394 @@ +#include "um.h" + +VALUE cConnection; +VALUE eConnectionRESPError; + +VALUE SYM_fd; +VALUE SYM_socket; +VALUE SYM_ssl; + +inline int connection_has_target_obj_p(struct um_connection *conn) { + switch (conn->mode) { + case CONNECTION_SSL: + case CONNECTION_STRING: + case CONNECTION_IO_BUFFER: + return true; + default: + return false; + } +} + +inline void connection_mark_segments(struct um_connection *conn) { + struct um_segment *curr = conn->head; + while (curr) { + // rb_gc_mark_movable(curr->obj); + curr = curr->next; + } +} + +inline void connection_compact_segments(struct um_connection *conn) { + struct um_segment *curr = conn->head; + while (curr) { + // curr->obj = rb_gc_location(curr->obj); + curr = curr->next; + } +} + +static void Connection_mark(void *ptr) { + struct um_connection *conn = ptr; + rb_gc_mark_movable(conn->self); + rb_gc_mark_movable(conn->machine->self); + + if (connection_has_target_obj_p(conn)) { + rb_gc_mark_movable(conn->target); + connection_mark_segments(conn); + } +} + +static void Connection_compact(void *ptr) { + struct um_connection *conn = ptr; + conn->self = rb_gc_location(conn->self); + + if (connection_has_target_obj_p(conn)) { + conn->target = rb_gc_location(conn->target); + connection_compact_segments(conn); + } +} + +static void Connection_free(void *ptr) { + struct um_connection *conn = ptr; + connection_clear(conn); +} + +static const rb_data_type_t Connection_type = { + .wrap_struct_name = "UringMachine::Connection", + .function = { + .dmark = Connection_mark, + .dfree = Connection_free, + .dsize = NULL, + .dcompact = Connection_compact + }, + .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED | RUBY_TYPED_EMBEDDABLE +}; + +static VALUE Connection_allocate(VALUE klass) { + struct um_connection *conn; + VALUE self = TypedData_Make_Struct(klass, struct um_connection, &Connection_type, conn); + return self; +} + +static inline struct um_connection *um_get_connection(VALUE self) { + struct um_connection *conn; + TypedData_Get_Struct(self, struct um_connection, &Connection_type, conn); + return conn; +} + +static inline void connection_set_target(struct um_connection *conn, VALUE target, enum um_connection_mode mode) { + conn->mode = mode; + switch (mode) { + case CONNECTION_FD: + case CONNECTION_SOCKET: + conn->fd = NUM2INT(target); + return; + case CONNECTION_SSL: + conn->target = target; + um_ssl_set_bio(conn->machine, target); + return; + default: + rb_raise(eUMError, "Invalid connection mode"); + } +} + +static inline void connection_setup(struct um_connection *conn, VALUE target, VALUE mode) { + conn->working_buffer = NULL; + if (NIL_P(mode)) { + if (TYPE(target) == T_DATA) + connection_set_target(conn, target, CONNECTION_SSL); + else + connection_set_target(conn, target, CONNECTION_FD); + } + else if (mode == SYM_fd) + connection_set_target(conn, target, CONNECTION_FD); + else if (mode == SYM_socket) + connection_set_target(conn, target, CONNECTION_SOCKET); + else if (mode == SYM_ssl) + connection_set_target(conn, target, CONNECTION_SSL); + else + rb_raise(eUMError, "Invalid connection mode"); +} + +/* call-seq: + * UM::Stream.new(machine, fd, mode = nil) -> conn + * machine.connection(fd, mode = nil) -> conn + * machine.connection(fd, mode = nil) { |conn| ... } + * + * Initializes a new connection with the given UringMachine instance, target and + * optional mode. The target maybe a file descriptor, or an instance of + * OpenSSL::SSL::SSLSocket. In case of an SSL socket, the mode should be :ssl. + * + * @param machine [UringMachine] UringMachine instance + * @param target [integer, OpenSSL::SSL::SSLSocket] connection target: file descriptor or SSL socket + * @param mode [Symbol] optional connection mode: :fd, :socket, :ssl + * @return [void] + */ +VALUE Connection_initialize(int argc, VALUE *argv, VALUE self) { + VALUE machine; + VALUE target; + VALUE mode; + rb_scan_args(argc, argv, "21", &machine, &target, &mode); + + struct um_connection *conn = um_get_connection(self); + memset(conn, 0, sizeof(struct um_connection)); + + RB_OBJ_WRITE(self, &conn->self, self); + conn->machine = um_get_machine(machine); + connection_setup(conn, target, mode); + + return self; +} + +/* call-seq: + * conn.mode -> mode + * + * Returns the connection mode. + * + * @return [Symbol] connection mode + */ +VALUE Connection_mode(VALUE self) { + struct um_connection *conn = um_get_connection(self); + switch (conn->mode) { + case CONNECTION_FD: return SYM_fd; + case CONNECTION_SOCKET: return SYM_socket; + case CONNECTION_SSL: return SYM_ssl; + default: return Qnil; + } + return Qnil; +} + +/* call-seq: + * conn.read_line(limit) -> str + * + * Reads from the string until a newline character is encountered. Returns the + * line without the newline delimiter. If limit is 0, the line length is not + * limited. If no newline delimiter is found before EOF, returns nil. + * + * @param limit [integer] maximum line length (0 means no limit) + * @return [String, nil] read data or nil + */ +VALUE Connection_read_line(VALUE self, VALUE limit) { + struct um_connection *conn = um_get_connection(self); + return connection_read_line(conn, Qnil, NUM2ULONG(limit)); +} + +/* call-seq: + * conn.read(len) -> str + * + * Reads len bytes from the conn. If len is 0, reads all available bytes. If + * len is negative, reads up to -len available bytes. If len is positive and eof + * is encountered before len bytes are read, returns nil. + * + * @param len [integer] number of bytes to read + * @return [String, nil] read data or nil + */ +VALUE Connection_read(VALUE self, VALUE len) { + struct um_connection *conn = um_get_connection(self); + return connection_read(conn, Qnil, NUM2LONG(len), 0, false); +} + +/* call-seq: + * conn.read_to_delim(delim, limit) -> str + * + * Reads from the string until a the given delimiter is encountered. Returns the + * line without the delimiter. If limit is 0, the length is not limited. If a + * delimiter is not found before EOF and limit is 0 or greater, returns nil. + * + * If no delimiter is found before EOF and limit is negative, returns the + * buffered data up to EOF or until the absolute-value length limit is reached. + * + * The `delim` parameter must be a single byte string. + * + * @param delim [String] delimiter (single byte) @param limit [integer] maximum + * line length (0 means no limit) @return [String, nil] read data or nil + */ +VALUE Connection_read_to_delim(VALUE self, VALUE delim, VALUE limit) { + struct um_connection *conn = um_get_connection(self); + return connection_read_to_delim(conn, Qnil, delim, NUM2LONG(limit)); +} + +/* call-seq: + * conn.skip(len) -> len + * + * Skips len bytes in the conn. + * + * @param len [integer] number of bytes to skip + * @return [Integer] len + */ +VALUE Connection_skip(VALUE self, VALUE len) { + struct um_connection *conn = um_get_connection(self); + connection_skip(conn, NUM2LONG(len), true); + return len; +} + +/* call-seq: + * conn.read_each { |data| } -> conn + * + * Reads from the target, passing each chunk to the given block. + * + * @return [UringMachine::Connection] conn + */ +VALUE Connection_read_each(VALUE self) { + struct um_connection *conn = um_get_connection(self); + connection_read_each(conn); + return self; +} + +/* call-seq: + * conn.write(*bufs) -> len + * + * Writes to the connection, ensuring that all data has been written before + * returning the total number of bytes written. + * + * @param bufs [Array] data to write + * @return [Integer] total bytes written + */ +VALUE Connection_write(int argc, VALUE *argv, VALUE self) { + struct um_connection *conn = um_get_connection(self); + return connection_writev(conn, argc, argv); +} + +/* call-seq: + * conn.resp_read -> obj + * + * Decodes an object from a RESP (Redis protocol) message. + * + * @return [any] decoded object + */ +VALUE Connection_resp_read(VALUE self) { + struct um_connection *conn = um_get_connection(self); + VALUE out_buffer = rb_utf8_str_new_literal(""); + VALUE obj = resp_read(conn, out_buffer); + RB_GC_GUARD(out_buffer); + return obj; +} + +/* call-seq: + * conn.resp_write(obj) -> conn + * + * Writes the given object using RESP (Redis protocol) to the connection target. + * Returns the number of bytes written. + * + * @param obj [any] object to write + * @return [Integer] total bytes written + */ +VALUE Connection_resp_write(VALUE self, VALUE obj) { + struct um_connection *conn = um_get_connection(self); + + VALUE str = rb_str_new(NULL, 0); + struct um_write_buffer buf; + write_buffer_init(&buf, str); + rb_str_modify(str); + resp_encode(&buf, obj); + write_buffer_update_len(&buf); + + size_t len = connection_write_raw(conn, buf.ptr, buf.len); + RB_GC_GUARD(str); + return ULONG2NUM(len); +} + +/* call-seq: + * conn.resp_encode(obj) -> string + * + * Encodes an object into a RESP (Redis protocol) message. + * + * @param str [String] string buffer + * @param obj [any] object to be encoded + * @return [String] str + */ +VALUE Connection_resp_encode(VALUE self, VALUE str, VALUE obj) { + struct um_write_buffer buf; + write_buffer_init(&buf, str); + rb_str_modify(str); + resp_encode(&buf, obj); + write_buffer_update_len(&buf); + return str; +} + +/* call-seq: + * conn.eof? -> bool + * + * Returns true if connection has reached EOF. + * + * @return [bool] EOF reached + */ +VALUE Connection_eof_p(VALUE self) { + struct um_connection *conn = um_get_connection(self); + return conn->eof ? Qtrue : Qfalse; +} + +/* call-seq: + * conn.consumed -> int + * + * Returns the total number of bytes consumed from the conn. + * + * @return [Integer] total bytes consumed + */ +VALUE Connection_consumed(VALUE self) { + struct um_connection *conn = um_get_connection(self); + return LONG2NUM(conn->consumed_bytes); +} + +/* call-seq: + * conn.pending -> int + * + * Returns the number of bytes available for reading. + * + * @return [Integer] bytes available + */ +VALUE Connection_pending(VALUE self) { + struct um_connection *conn = um_get_connection(self); + return LONG2NUM(conn->pending_bytes); +} + +/* call-seq: + * conn.clear -> conn + * + * Clears all available bytes and stops any ongoing read operation. + * + * @return [UM::Stream] self + */ +VALUE Connection_clear(VALUE self) { + struct um_connection *conn = um_get_connection(self); + connection_clear(conn); + return self; +} + +void Init_Stream(void) { + cConnection = rb_define_class_under(cUM, "Connection", rb_cObject); + rb_define_alloc_func(cConnection, Connection_allocate); + + rb_define_method(cConnection, "initialize", Connection_initialize, -1); + rb_define_method(cConnection, "mode", Connection_mode, 0); + + rb_define_method(cConnection, "read_line", Connection_read_line, 1); + rb_define_method(cConnection, "read", Connection_read, 1); + rb_define_method(cConnection, "read_to_delim", Connection_read_to_delim, 2); + rb_define_method(cConnection, "skip", Connection_skip, 1); + rb_define_method(cConnection, "read_each", Connection_read_each, 0); + + rb_define_method(cConnection, "write", Connection_write, -1); + + rb_define_method(cConnection, "resp_read", Connection_resp_read, 0); + rb_define_method(cConnection, "resp_write", Connection_resp_write, 1); + rb_define_singleton_method(cConnection, "resp_encode", Connection_resp_encode, 2); + + rb_define_method(cConnection, "eof?", Connection_eof_p, 0); + rb_define_method(cConnection, "consumed", Connection_consumed, 0); + rb_define_method(cConnection, "pending", Connection_pending, 0); + rb_define_method(cConnection, "clear", Connection_clear, 0); + + eConnectionRESPError = rb_define_class_under(cConnection, "RESPError", rb_eStandardError); + + SYM_fd = ID2SYM(rb_intern("fd")); + SYM_socket = ID2SYM(rb_intern("socket")); + SYM_ssl = ID2SYM(rb_intern("ssl")); +} diff --git a/ext/um/um_ssl.c b/ext/um/um_ssl.c index 5a07cab..ae3980a 100644 --- a/ext/um/um_ssl.c +++ b/ext/um/um_ssl.c @@ -105,3 +105,38 @@ int um_ssl_write(struct um *machine, VALUE ssl_obj, VALUE buf, size_t len) { return ret; } + +int um_ssl_write_raw(struct um *machine, VALUE ssl_obj, const char *buffer, size_t len) { + SSL *ssl = RTYPEDDATA_GET_DATA(ssl_obj); + if (unlikely(!len)) return INT2NUM(0); + + int ret = SSL_write(ssl, buffer, (int)len); + if (ret <= 0) rb_raise(eUMError, "Failed to write"); + + return ret; +} + +int um_ssl_write_all(struct um *machine, VALUE ssl_obj, VALUE buf) { + SSL *ssl = RTYPEDDATA_GET_DATA(ssl_obj); + const char *base; + size_t size; + um_get_buffer_bytes_for_writing(buf, (const void **)&base, &size, true); + + size_t left = size; + while (left) { + int ret = SSL_write(ssl, base, (int)left); + if (ret <= 0) rb_raise(eUMError, "Failed to write"); + + left -= ret; + base += ret; + } + + return size; +} + +int um_ssl_writev(struct um *machine, VALUE ssl, int argc, VALUE *argv) { + size_t total = 0; + for (int i = 0; i < argc; i++) + total += um_ssl_write_all(machine, ssl, argv[i]); + return total; +} diff --git a/ext/um/um_stream_class.c b/ext/um/um_stream_class.c deleted file mode 100644 index fd096c2..0000000 --- a/ext/um/um_stream_class.c +++ /dev/null @@ -1,338 +0,0 @@ -#include "um.h" - -VALUE cStream; -VALUE eStreamRESPError; - -VALUE SYM_bp_read; -VALUE SYM_bp_recv; -VALUE SYM_ssl; - -inline int stream_has_target_obj_p(struct um_stream *stream) { - switch (stream->mode) { - case STREAM_SSL: - case STREAM_STRING: - case STREAM_IO_BUFFER: - return true; - default: - return false; - } -} - -inline void stream_mark_segments(struct um_stream *stream) { - struct um_segment *curr = stream->head; - while (curr) { - // rb_gc_mark_movable(curr->obj); - curr = curr->next; - } -} - -inline void stream_compact_segments(struct um_stream *stream) { - struct um_segment *curr = stream->head; - while (curr) { - // curr->obj = rb_gc_location(curr->obj); - curr = curr->next; - } -} - -static void Stream_mark(void *ptr) { - struct um_stream *stream = ptr; - rb_gc_mark_movable(stream->self); - rb_gc_mark_movable(stream->machine->self); - - if (stream_has_target_obj_p(stream)) { - rb_gc_mark_movable(stream->target); - stream_mark_segments(stream); - } -} - -static void Stream_compact(void *ptr) { - struct um_stream *stream = ptr; - stream->self = rb_gc_location(stream->self); - - if (stream_has_target_obj_p(stream)) { - stream->target = rb_gc_location(stream->target); - stream_compact_segments(stream); - } -} - -static void Stream_free(void *ptr) { - struct um_stream *stream = ptr; - stream_clear(stream); -} - -static const rb_data_type_t Stream_type = { - .wrap_struct_name = "UringMachine::Stream", - .function = { - .dmark = Stream_mark, - .dfree = Stream_free, - .dsize = NULL, - .dcompact = Stream_compact - }, - .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED | RUBY_TYPED_EMBEDDABLE -}; - -static VALUE Stream_allocate(VALUE klass) { - struct um_stream *stream; - VALUE self = TypedData_Make_Struct(klass, struct um_stream, &Stream_type, stream); - return self; -} - -static inline struct um_stream *um_get_stream(VALUE self) { - struct um_stream *stream; - TypedData_Get_Struct(self, struct um_stream, &Stream_type, stream); - return stream; -} - -static inline void stream_setup(struct um_stream *stream, VALUE target, VALUE mode) { - stream->working_buffer = NULL; - if (mode == SYM_bp_read || mode == Qnil) { - stream->mode = STREAM_BP_READ; - stream->fd = NUM2INT(target); - } - else if (mode == SYM_bp_recv) { - stream->mode = STREAM_BP_RECV; - stream->fd = NUM2INT(target); - } - else if (mode == SYM_ssl) { - stream->mode = STREAM_SSL; - stream->target = target; - um_ssl_set_bio(stream->machine, target); - } - else - rb_raise(eUMError, "Invalid stream mode"); -} - -/* call-seq: - * UM::Stream.new(machine, fd, mode = nil) -> stream - * machine.stream(fd, mode = nil) -> stream - * machine.stream(fd, mode = nil) { |stream| ... } - * - * Initializes a new stream with the given UringMachine instance, target and - * optional mode. The target maybe a file descriptor, or an instance of - * OpenSSL::SSL::SSLSocket. In case of an SSL socket, the mode should be :ssl. - * - * @param machine [UringMachine] UringMachine instance - * @param target [integer, OpenSSL::SSL::SSLSocket] stream target: file descriptor or SSL socket - * @param mode [Symbol] optional stream mode: :bp_read, :bp_recv, :ssl - * @return [void] - */ -VALUE Stream_initialize(int argc, VALUE *argv, VALUE self) { - VALUE machine; - VALUE target; - VALUE mode; - rb_scan_args(argc, argv, "21", &machine, &target, &mode); - - struct um_stream *stream = um_get_stream(self); - memset(stream, 0, sizeof(struct um_stream)); - - RB_OBJ_WRITE(self, &stream->self, self); - stream->machine = um_get_machine(machine); - stream_setup(stream, target, mode); - - return self; -} - -/* call-seq: - * stream.mode -> mode - * - * Returns the stream mode. - * - * @return [Symbol] stream mode - */ -VALUE Stream_mode(VALUE self) { - struct um_stream *stream = um_get_stream(self); - switch (stream->mode) { - case STREAM_BP_READ: return SYM_bp_read; - case STREAM_BP_RECV: return SYM_bp_recv; - case STREAM_SSL: return SYM_ssl; - default: return Qnil; - } - return Qnil; -} - -/* call-seq: - * stream.get_line(limit) -> str - * - * Reads from the string until a newline character is encountered. Returns the - * line without the newline delimiter. If limit is 0, the line length is not - * limited. If no newline delimiter is found before EOF, returns nil. - * - * @param limit [integer] maximum line length (0 means no limit) - * @return [String, nil] read data or nil - */ -VALUE Stream_get_line(VALUE self, VALUE limit) { - struct um_stream *stream = um_get_stream(self); - return stream_get_line(stream, Qnil, NUM2ULONG(limit)); -} - -/* call-seq: - * stream.get_string(len) -> str - * - * Reads len bytes from the stream. If len is 0, reads all available bytes. If - * len is negative, reads up to -len available bytes. If len is positive and eof - * is encountered before len bytes are read, returns nil. - * - * @param len [integer] number of bytes to read - * @return [String, nil] read data or nil - */ -VALUE Stream_get_string(VALUE self, VALUE len) { - struct um_stream *stream = um_get_stream(self); - return stream_get_string(stream, Qnil, NUM2LONG(len), 0, false); -} - -/* call-seq: - * stream.get_to_delim(delim, limit) -> str - * - * Reads from the string until a the given delimiter is encountered. Returns the - * line without the delimiter. If limit is 0, the length is not limited. If a - * delimiter is not found before EOF and limit is 0 or greater, returns nil. - * - * If no delimiter is found before EOF and limit is negative, returns the - * buffered data up to EOF or until the absolute-value length limit is reached. - * - * The `delim` parameter must be a single byte string. - * - * @param delim [String] delimiter (single byte) @param limit [integer] maximum - * line length (0 means no limit) @return [String, nil] read data or nil - */ -VALUE Stream_get_to_delim(VALUE self, VALUE delim, VALUE limit) { - struct um_stream *stream = um_get_stream(self); - return stream_get_to_delim(stream, Qnil, delim, NUM2LONG(limit)); -} - -/* call-seq: - * stream.skip(len) -> len - * - * Skips len bytes in the stream. - * - * @param len [integer] number of bytes to skip - * @return [Integer] len - */ -VALUE Stream_skip(VALUE self, VALUE len) { - struct um_stream *stream = um_get_stream(self); - stream_skip(stream, NUM2LONG(len), true); - return len; -} - -/* call-seq: - * stream.each { |data| } -> stream - * - * Reads from the target, passing each chunk to the given block. - * - * @return [UringMachine::Stream] stream - */ -VALUE Stream_each(VALUE self) { - struct um_stream *stream = um_get_stream(self); - stream_each(stream); - return self; -} - -/* call-seq: - * stream.resp_decode -> obj - * - * Decodes an object from a RESP (Redis protocol) message. - * - * @return [any] decoded object - */ -VALUE Stream_resp_decode(VALUE self) { - struct um_stream *stream = um_get_stream(self); - VALUE out_buffer = rb_utf8_str_new_literal(""); - VALUE obj = resp_decode(stream, out_buffer); - RB_GC_GUARD(out_buffer); - return obj; -} - -/* call-seq: - * stream.resp_encode(obj) -> string - * - * Encodes an object into a RESP (Redis protocol) message. - * - * @param str [String] string buffer - * @param obj [any] object to be encoded - * @return [String] str - */ -VALUE Stream_resp_encode(VALUE self, VALUE str, VALUE obj) { - struct um_write_buffer buf; - write_buffer_init(&buf, str); - rb_str_modify(str); - resp_encode(&buf, obj); - write_buffer_update_len(&buf); - return str; -} - -/* call-seq: - * stream.eof? -> bool - * - * Returns true if stream has reached EOF. - * - * @return [bool] EOF reached - */ -VALUE Stream_eof_p(VALUE self) { - struct um_stream *stream = um_get_stream(self); - return stream->eof ? Qtrue : Qfalse; -} - -/* call-seq: - * stream.consumed -> int - * - * Returns the total number of bytes consumed from the stream. - * - * @return [Integer] total bytes consumed - */ -VALUE Stream_consumed(VALUE self) { - struct um_stream *stream = um_get_stream(self); - return LONG2NUM(stream->consumed_bytes); -} - -/* call-seq: - * stream.pending -> int - * - * Returns the number of bytes available for reading. - * - * @return [Integer] bytes available - */ -VALUE Stream_pending(VALUE self) { - struct um_stream *stream = um_get_stream(self); - return LONG2NUM(stream->pending_bytes); -} - -/* call-seq: - * stream.clear -> stream - * - * Clears all available bytes and stops any ongoing read operation. - * - * @return [UM::Stream] self - */ -VALUE Stream_clear(VALUE self) { - struct um_stream *stream = um_get_stream(self); - stream_clear(stream); - return self; -} - -void Init_Stream(void) { - cStream = rb_define_class_under(cUM, "Stream", rb_cObject); - rb_define_alloc_func(cStream, Stream_allocate); - - rb_define_method(cStream, "initialize", Stream_initialize, -1); - rb_define_method(cStream, "mode", Stream_mode, 0); - - rb_define_method(cStream, "get_line", Stream_get_line, 1); - rb_define_method(cStream, "get_string", Stream_get_string, 1); - rb_define_method(cStream, "get_to_delim", Stream_get_to_delim, 2); - rb_define_method(cStream, "skip", Stream_skip, 1); - rb_define_method(cStream, "each", Stream_each, 0); - - rb_define_method(cStream, "resp_decode", Stream_resp_decode, 0); - rb_define_singleton_method(cStream, "resp_encode", Stream_resp_encode, 2); - - rb_define_method(cStream, "eof?", Stream_eof_p, 0); - rb_define_method(cStream, "consumed", Stream_consumed, 0); - rb_define_method(cStream, "pending", Stream_pending, 0); - rb_define_method(cStream, "clear", Stream_clear, 0); - - eStreamRESPError = rb_define_class_under(cStream, "RESPError", rb_eStandardError); - - SYM_bp_read = ID2SYM(rb_intern("bp_read")); - SYM_bp_recv = ID2SYM(rb_intern("bp_recv")); - SYM_ssl = ID2SYM(rb_intern("ssl")); -} diff --git a/ext/um/um_utils.c b/ext/um/um_utils.c index ffc6c9d..cd86b61 100644 --- a/ext/um/um_utils.c +++ b/ext/um/um_utils.c @@ -159,7 +159,7 @@ int um_setup_buffer_ring(struct um *machine, unsigned size, unsigned count) { return bg_id; } -inline VALUE um_get_string_from_buffer_ring(struct um *machine, int bgid, __s32 result, __u32 flags) { +inline VALUE um_read_from_buffer_ring(struct um *machine, int bgid, __s32 result, __u32 flags) { if (!result) return Qnil; unsigned buf_idx = flags >> IORING_CQE_BUFFER_SHIFT; diff --git a/grant-2025/journal.md b/grant-2025/journal.md index 312442c..19d6c5f 100644 --- a/grant-2025/journal.md +++ b/grant-2025/journal.md @@ -121,7 +121,7 @@ Ruby I/O layer. Some interesting warts in the Ruby `IO` implementation: ```ruby def io_write(io, buffer, length, offset) reset_nonblock(io) - @machine.write(io.fileno, buffer.get_string) + @machine.write(io.fileno, buffer.read) rescue Errno::EINTR retry end diff --git a/lib/uringmachine.rb b/lib/uringmachine.rb index 81c908b..ab4aafd 100644 --- a/lib/uringmachine.rb +++ b/lib/uringmachine.rb @@ -195,28 +195,28 @@ def file_watch(root, mask) end # call-seq: - # machine.stream(fd, mode = nil) -> stream - # machine.stream(fd, mode = nil) { |stream| } + # machine.connection(fd, mode = nil) -> conn + # machine.connection(fd, mode = nil) { |conn| } # - # Creates a stream for reading from the given target fd or other object. The - # mode indicates the type of target and how it is read from: + # Creates a connection for reading from the given target fd or other object. + # The mode indicates the type of target and how it is read from: # - # - :bp_read - read from the given fd using the buffer pool (default mode) - # - :bp_recv - receive from the given socket fd using the buffer pool + # - :fd - read from the given fd using the buffer pool (default mode) + # - :socket - receive from the given socket fd using the buffer pool # - :ssl - read from the given SSL connection # - # If a block is given, the block will be called with the stream object and the - # method will return the block's return value. + # If a block is given, the block will be called with the connection object and + # the method will return the block's return value. # # @param target [Integer, OpenSSL::SSL::SSLSocket] fd or ssl connection - # @param mode [Symbol, nil] stream mode - # @return [UringMachine::Stream] stream object - def stream(target, mode = nil) - stream = UM::Stream.new(self, target, mode) - return stream if !block_given? - - res = yield(stream) - stream.clear + # @param mode [Symbol, nil] connection mode + # @return [UringMachine::Stream] connection object + def connection(target, mode = nil) + conn = UM::Connection.new(self, target, mode) + return conn if !block_given? + + res = yield(conn) + conn.clear res end diff --git a/test/test_stream.rb b/test/test_connection.rb similarity index 56% rename from test/test_stream.rb rename to test/test_connection.rb index 45553a1..c5137db 100644 --- a/test/test_stream.rb +++ b/test/test_connection.rb @@ -5,24 +5,24 @@ require 'openssl' require 'localhost/authority' -class StreamBaseTest < UMBaseTest - attr_reader :stream +class ConnectionBaseTest < UMBaseTest + attr_reader :conn def setup super @rfd, @wfd = UM.pipe - @stream = UM::Stream.new(@machine, @rfd) + @conn = UM::Connection.new(@machine, @rfd) end def teardown - @stream = nil + @conn = nil machine.close(@rfd) rescue nil machine.close(@wfd) rescue nil super end end -class StreamTest < StreamBaseTest +class ConnectionTest < ConnectionBaseTest def buffer_metrics machine.metrics.fetch_values( :buffers_allocated, @@ -33,19 +33,19 @@ def buffer_metrics ) end - def test_stream_basic_usage + def test_connection_basic_usage assert_equal [0, 0, 0, 0, 0], buffer_metrics machine.write(@wfd, "foobar") machine.close(@wfd) - buf = stream.get_string(3) + buf = conn.read(3) assert_equal 'foo', buf - buf = stream.get_string(-6) + buf = conn.read(-6) assert_equal 'bar', buf - assert stream.eof? + assert conn.eof? - stream.clear + conn.clear # initial buffer size: 6BKV, initial buffers commited: 16 (256KB) # (plus an additional buffer commited after first usage) @@ -53,20 +53,20 @@ def test_stream_basic_usage assert_equal 0, machine.metrics[:ops_pending] end - def test_stream_clear + def test_connection_clear rfd, wfd = UM.pipe - stream = UM::Stream.new(machine, rfd) + conn = UM::Connection.new(machine, rfd) assert_equal [0, 0, 0, 0, 0], buffer_metrics machine.write(wfd, "foobar") - buf = stream.get_string(3) + buf = conn.read(3) assert_equal 'foo', buf assert_equal 1, machine.metrics[:ops_pending] assert_equal 255, machine.metrics[:segments_free] - stream.clear + conn.clear machine.snooze assert_equal 0, machine.metrics[:ops_pending] assert_equal 256, machine.metrics[:segments_free] @@ -77,9 +77,9 @@ def test_stream_clear machine.close(wfd) rescue nil end - def test_stream_big_read + def test_connection_big_read s1, s2 = UM.socketpair(UM::AF_UNIX, UM::SOCK_STREAM, 0) - stream = UM::Stream.new(machine, s2) + conn = UM::Connection.new(machine, s2) msg = '1234567' * 20000 @@ -89,16 +89,16 @@ def test_stream_big_read machine.shutdown(s1, UM::SHUT_WR) end - buf = stream.get_string(msg.bytesize) + buf = conn.read(msg.bytesize) assert_equal msg, buf ensure machine.terminate(f) machine.join(f) end - def test_stream_buffer_reuse + def test_connection_buffer_reuse s1, s2 = UM.socketpair(UM::AF_UNIX, UM::SOCK_STREAM, 0) - stream = UM::Stream.new(machine, s2) + conn = UM::Connection.new(machine, s2) msg = '1234567' * 20000 @@ -109,19 +109,19 @@ def test_stream_buffer_reuse machine.shutdown(s1, UM::SHUT_WR) end - buf = stream.get_string(msg.bytesize) + buf = conn.read(msg.bytesize) assert_equal msg, buf - buf = stream.get_string(msg.bytesize) + buf = conn.read(msg.bytesize) assert_equal msg, buf - buf = stream.get_string(msg.bytesize) + buf = conn.read(msg.bytesize) assert_equal msg, buf - buf = stream.get_string(msg.bytesize) + buf = conn.read(msg.bytesize) assert_equal msg, buf - stream.clear + conn.clear # numbers may vary with different kernel versions assert_in_range 24..32, machine.metrics[:buffers_allocated] assert_in_range 10..18, machine.metrics[:buffers_free] @@ -131,23 +131,23 @@ def test_stream_buffer_reuse machine.join(f) end - def test_stream_get_line + def test_connection_read_line machine.write(@wfd, "foo\nbar\r\nbaz") machine.close(@wfd) assert_equal [0, 0, 0, 0, 0], buffer_metrics - assert_equal 'foo', stream.get_line(0) + assert_equal 'foo', conn.read_line(0) assert_equal [16, 0, 255, 16384 * 16, 16384 * 16 - 12], buffer_metrics - assert_equal 'bar', stream.get_line(0) - assert_nil stream.get_line(0) - assert_equal "baz", stream.get_string(-6) + assert_equal 'bar', conn.read_line(0) + assert_nil conn.read_line(0) + assert_equal "baz", conn.read(-6) end - def test_stream_get_line_segmented + def test_connection_read_line_segmented machine.write(@wfd, "foo\n") - assert_equal 'foo', stream.get_line(0) + assert_equal 'foo', conn.read_line(0) machine.write(@wfd, "bar") machine.write(@wfd, "\r\n") @@ -156,19 +156,19 @@ def test_stream_get_line_segmented # three segments received assert_equal [16, 0, 253, 16384 * 16, 16384 * 16 - 13], buffer_metrics - assert_equal 'bar', stream.get_line(0) + assert_equal 'bar', conn.read_line(0) assert_equal [16, 0, 255, 16384 * 16, 16384 * 16 - 13], buffer_metrics - assert_equal 'baz', stream.get_line(0) + assert_equal 'baz', conn.read_line(0) assert_equal [16, 0, 256, 16384 * 16, 16384 * 16 - 13], buffer_metrics - assert_nil stream.get_line(0) + assert_nil conn.read_line(0) end - def test_stream_get_line_maxlen + def test_connection_read_line_maxlen machine.write(@wfd, "foobar\r\n") - assert_nil stream.get_line(3) - # verify that stream pos has not changed - assert_equal 'foobar', stream.get_line(0) + assert_nil conn.read_line(3) + # verify that connecvtion pos has not changed + assert_equal 'foobar', conn.read_line(0) machine.write(@wfd, "baz") machine.write(@wfd, "\n") @@ -176,83 +176,83 @@ def test_stream_get_line_maxlen machine.write(@wfd, "\n") machine.close(@wfd) - assert_nil stream.get_line(2) - assert_nil stream.get_line(3) - assert_equal 'baz', stream.get_line(4) + assert_nil conn.read_line(2) + assert_nil conn.read_line(3) + assert_equal 'baz', conn.read_line(4) - assert_nil stream.get_line(3) - assert_nil stream.get_line(4) - assert_equal 'bizz', stream.get_line(5) + assert_nil conn.read_line(3) + assert_nil conn.read_line(4) + assert_equal 'bizz', conn.read_line(5) - assert_nil stream.get_line(8) + assert_nil conn.read_line(8) assert_equal [16, 0, 256, 16384 * 16, 16384 * 16 - 17], buffer_metrics end - def test_stream_get_string + def test_connection_read machine.write(@wfd, "foobarbazblahzzz") machine.close(@wfd) - assert_equal 'foobar', stream.get_string(6) - assert_equal 'baz', stream.get_string(3) - assert_equal 'blah', stream.get_string(4) - assert_nil stream.get_string(4) + assert_equal 'foobar', conn.read(6) + assert_equal 'baz', conn.read(3) + assert_equal 'blah', conn.read(4) + assert_nil conn.read(4) end - def test_stream_get_string_zero_len + def test_connection_read_zero_len machine.write(@wfd, "foobar") - assert_equal 'foobar', stream.get_string(0) + assert_equal 'foobar', conn.read(0) machine.write(@wfd, "bazblah") machine.close(@wfd) - assert_equal 'bazblah', stream.get_string(0) - assert_nil stream.get_string(0) + assert_equal 'bazblah', conn.read(0) + assert_nil conn.read(0) end - def test_stream_get_string_negative_len + def test_connection_read_negative_len machine.write(@wfd, "foobar") - assert_equal 'foo', stream.get_string(-3) - assert_equal 'bar', stream.get_string(-6) + assert_equal 'foo', conn.read(-3) + assert_equal 'bar', conn.read(-6) machine.write(@wfd, "bazblah") machine.close(@wfd) - assert_equal 'bazblah', stream.get_string(-12) - assert_nil stream.get_string(-3) + assert_equal 'bazblah', conn.read(-12) + assert_nil conn.read(-3) end - def test_stream_get_to_delim + def test_connection_read_to_delim machine.write(@wfd, "abc,def,ghi") machine.close(@wfd) - assert_nil stream.get_to_delim('!', 0) # not there - assert_nil stream.get_to_delim(',', 2) # too long - assert_equal 'abc', stream.get_to_delim(',', 0) - assert_equal 'def', stream.get_to_delim(',', 0) - assert_nil stream.get_to_delim(',', 0) - assert_equal 'ghi', stream.get_to_delim(',', -3) + assert_nil conn.read_to_delim('!', 0) # not there + assert_nil conn.read_to_delim(',', 2) # too long + assert_equal 'abc', conn.read_to_delim(',', 0) + assert_equal 'def', conn.read_to_delim(',', 0) + assert_nil conn.read_to_delim(',', 0) + assert_equal 'ghi', conn.read_to_delim(',', -3) end - def test_stream_get_to_delim_invalid_delim + def test_connection_read_to_delim_invalid_delim machine.write(@wfd, "abc,def,ghi") - assert_raises(ArgumentError) { stream.get_to_delim(:foo, 0) } - assert_raises(UM::Error) { stream.get_to_delim('', 0) } - assert_raises(UM::Error) { stream.get_to_delim('ab', 0) } - assert_raises(UM::Error) { stream.get_to_delim('🙂', 0) } + assert_raises(ArgumentError) { conn.read_to_delim(:foo, 0) } + assert_raises(UM::Error) { conn.read_to_delim('', 0) } + assert_raises(UM::Error) { conn.read_to_delim('ab', 0) } + assert_raises(UM::Error) { conn.read_to_delim('🙂', 0) } end - def test_stream_skip + def test_connection_skip machine.write(@wfd, "foobarbaz") - stream.skip(2) - assert_equal 'obar', stream.get_string(4) + conn.skip(2) + assert_equal 'obar', conn.read(4) - stream.skip(1) - assert_equal 'az', stream.get_string(0) + conn.skip(1) + assert_equal 'az', conn.read(0) end - def test_stream_big_data + def test_connection_big_data data = SecureRandom.random_bytes(300_000) fiber = machine.spin { machine.writev(@wfd, data) @@ -261,7 +261,7 @@ def test_stream_big_data received = [] loop { - msg = stream.get_string(-60_000) + msg = conn.read(-60_000) break if !msg received << msg @@ -273,11 +273,11 @@ def test_stream_big_data assert_equal data, received.join end - def test_stream_each + def test_connection_read_each bufs = [] f = machine.spin do bufs << :ready - stream.each { + conn.read_each { assert_kind_of IO::Buffer, it bufs << it.get_string } @@ -307,87 +307,190 @@ def test_stream_each end end -class StreamRespTest < StreamBaseTest - def test_stream_resp_decode +class ConnectionWriteTest < UMBaseTest + attr_reader :conn + + def setup + super + @s1, @s2 = UM.socketpair(UM::AF_UNIX, UM::SOCK_STREAM, 0) + @conn = UM::Connection.new(@machine, @s1) + end + + def teardown + @conn = nil + machine.close(@s1) rescue nil + machine.close(@s2) rescue nil + super + end + + def test_connection_write_single_buf + assert_equal 3, conn.write('foo') + + buf = +'' + machine.read(@s2, buf, 100) + assert_equal 'foo', buf + end + + def test_connection_write_multi_buf + assert_equal 6, conn.write('foo', 'bar') + + buf = +'' + machine.read(@s2, buf, 100) + assert_equal 'foobar', buf + end + + def test_connection_write_socket_mode + conn = machine.connection(@s2, :socket) + + assert_equal 6, conn.write('foo', 'bar') + + buf = +'' + machine.read(@s1, buf, 100) + assert_equal 'foobar', buf + end + + def test_connection_write_ssl_mode + ssl1 = OpenSSL::SSL::SSLSocket.new(IO.for_fd(@s1), Localhost::Authority.fetch.server_context) + ssl1.sync_close = true + ssl2 = OpenSSL::SSL::SSLSocket.new(IO.for_fd(@s2), OpenSSL::SSL::SSLContext.new) + ssl2.sync_close = true + + machine.ssl_set_bio(ssl1) + machine.ssl_set_bio(ssl2) + + f = machine.spin { ssl1.accept rescue nil } + + ssl2.connect + refute_equal 0, @machine.metrics[:total_ops] + + conn1 = machine.connection(ssl1) + conn2 = machine.connection(ssl2) + + assert_equal 10, conn1.write('foobar', "\n", 'baz') + + assert_equal "foobar\nbaz", conn2.read(10) + ensure + ssl1.close rescue nil + ss2.close rescue nil + if f + machine.terminate(f) + machine.join(f) + end + end +end + +class ConnectionRespTest < ConnectionBaseTest + def test_connection_resp_read machine.write(@wfd, "+foo bar\r\n") - assert_equal "foo bar", stream.resp_decode + assert_equal "foo bar", conn.resp_read machine.write(@wfd, "+baz\r\n") - assert_equal "baz", stream.resp_decode + assert_equal "baz", conn.resp_read machine.write(@wfd, "-foobar\r\n") - o = stream.resp_decode - assert_kind_of UM::Stream::RESPError, o + o = conn.resp_read + assert_kind_of UM::Connection::RESPError, o assert_equal "foobar", o.message machine.write(@wfd, "!3\r\nbaz\r\n") - o = stream.resp_decode - assert_kind_of UM::Stream::RESPError, o + o = conn.resp_read + assert_kind_of UM::Connection::RESPError, o assert_equal "baz", o.message machine.write(@wfd, ":123\r\n") - assert_equal 123, stream.resp_decode + assert_equal 123, conn.resp_read machine.write(@wfd, ":-123\r\n") - assert_equal(-123, stream.resp_decode) + assert_equal(-123, conn.resp_read) machine.write(@wfd, ",123.321\r\n") - assert_equal 123.321, stream.resp_decode + assert_equal 123.321, conn.resp_read machine.write(@wfd, "_\r\n") - assert_nil stream.resp_decode + assert_nil conn.resp_read machine.write(@wfd, "#t\r\n") - assert_equal true, stream.resp_decode + assert_equal true, conn.resp_read machine.write(@wfd, "#f\r\n") - assert_equal false, stream.resp_decode + assert_equal false, conn.resp_read machine.write(@wfd, "$6\r\nfoobar\r\n") - assert_equal "foobar", stream.resp_decode + assert_equal "foobar", conn.resp_read machine.write(@wfd, "$3\r\nbaz\r\n") - assert_equal "baz", stream.resp_decode + assert_equal "baz", conn.resp_read machine.write(@wfd, "=10\r\ntxt:foobar\r\n") - assert_equal "foobar", stream.resp_decode + assert_equal "foobar", conn.resp_read machine.write(@wfd, "*3\r\n+foo\r\n:42\r\n$3\r\nbar\r\n") - assert_equal ['foo', 42, 'bar'], stream.resp_decode + assert_equal ['foo', 42, 'bar'], conn.resp_read machine.write(@wfd, "~3\r\n+foo\r\n:42\r\n$3\r\nbar\r\n") - assert_equal ['foo', 42, 'bar'], stream.resp_decode + assert_equal ['foo', 42, 'bar'], conn.resp_read machine.write(@wfd, ">3\r\n+foo\r\n:42\r\n$3\r\nbar\r\n") - assert_equal ['foo', 42, 'bar'], stream.resp_decode + assert_equal ['foo', 42, 'bar'], conn.resp_read machine.write(@wfd, "%2\r\n+a\r\n:42\r\n+b\r\n:43\r\n") - assert_equal({ 'a' => 42, 'b' => 43 }, stream.resp_decode) + assert_equal({ 'a' => 42, 'b' => 43 }, conn.resp_read) machine.write(@wfd, "|2\r\n+a\r\n:42\r\n+b\r\n:43\r\n") - assert_equal({ 'a' => 42, 'b' => 43 }, stream.resp_decode) + assert_equal({ 'a' => 42, 'b' => 43 }, conn.resp_read) machine.write(@wfd, "%2\r\n+a\r\n:42\r\n+b\r\n*3\r\n+foo\r\n+bar\r\n+baz\r\n") - assert_equal({ 'a' => 42, 'b' => ['foo', 'bar', 'baz'] }, stream.resp_decode) + assert_equal({ 'a' => 42, 'b' => ['foo', 'bar', 'baz'] }, conn.resp_read) end - def test_stream_resp_decode_segmented + def test_connection_resp_read_segmented machine.write(@wfd, "\n") - assert_equal "", stream.get_line(0) + assert_equal "", conn.read_line(0) machine.write(@wfd, "+foo") machine.write(@wfd, " ") machine.write(@wfd, "bar\r") machine.write(@wfd, "\n") - assert_equal "foo bar", stream.resp_decode + assert_equal "foo bar", conn.resp_read machine.write(@wfd, "$6\r") machine.write(@wfd, "\nbazbug") machine.write(@wfd, "\r\n") - assert_equal "bazbug", stream.resp_decode + assert_equal "bazbug", conn.resp_read end - def test_stream_resp_encode - s = UM::Stream + def test_connection_resp_write + writer = machine.connection(@wfd) + + writer.resp_write(nil); + assert_equal "_\r\n", conn.read(-100) + + writer.resp_write(true); + assert_equal "#t\r\n", conn.read(-100) + + writer.resp_write(false); + assert_equal "#f\r\n", conn.read(-100) + + writer.resp_write(42); + assert_equal ":42\r\n", conn.read(-100) + + writer.resp_write(42.1) + assert_equal ",42.1\r\n", conn.read(-100) + + writer.resp_write('foobar') + assert_equal "$6\r\nfoobar\r\n", conn.read(-100) + + writer.resp_write('פובאר') + assert_equal (+"$10\r\nפובאר\r\n").force_encoding('ASCII-8BIT'), conn.read(-100) + + writer.resp_write(['foo', 'bar']) + assert_equal "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n", conn.read(-100) + + writer.resp_write({ 'foo' => 'bar', 'baz' => 42 }) + assert_equal "%2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$3\r\nbaz\r\n:42\r\n", conn.read(-100) + end + + def test_connection_resp_encode + s = UM::Connection assert_equal "_\r\n", s.resp_encode(+'', nil) assert_equal "#t\r\n", s.resp_encode(+'', true) assert_equal "#f\r\n", s.resp_encode(+'', false) @@ -404,7 +507,7 @@ def test_stream_resp_encode end end -class StreamStressTest < UMBaseTest +class ConnectionStressTest < UMBaseTest def setup super @@ -422,8 +525,8 @@ def setup def start_connection_fiber(fd) machine.spin do - stream = UM::Stream.new(machine, fd) - while (msg = stream.get_line(0)) + conn = UM::Connection.new(machine, fd) + while (msg = conn.read_line(0)) @received << msg end machine.sendv(fd, @response_headers, @response_body) @@ -434,7 +537,7 @@ def start_connection_fiber(fd) end end - def test_stream_server_big_lines + def test_connection_server_big_lines server_fibers = [] server_fibers << machine.spin do machine.accept_each(@listen_fd) { |fd| @@ -477,7 +580,7 @@ def test_stream_server_big_lines assert_equal msg * client_count, @received.map { it + "\n" }.join end - def test_stream_server_http + def test_connection_server_http server_fibers = [] server_fibers << machine.spin do machine.accept_each(@listen_fd) { |fd| @@ -517,50 +620,50 @@ def test_stream_server_http end end -class StreamDevRandomTest < UMBaseTest - def test_stream_dev_random_get_line +class ConnectionDevRandomTest < UMBaseTest + def test_connection_dev_random_read_line fd = machine.open('/dev/random', UM::O_RDONLY) - stream = UM::Stream.new(machine, fd) + conn = UM::Connection.new(machine, fd) n = 100000 lines = [] n.times { - lines << stream.get_line(0) + lines << conn.read_line(0) } assert_equal n, lines.size ensure - stream.clear rescue nil + conn.clear rescue nil machine.close(fd) rescue nil end - def get_line_do(n, acc) + def read_line_do(n, acc) fd = @machine.open('/dev/random', UM::O_RDONLY) - stream = UM::Stream.new(@machine, fd) - n.times { acc << stream.get_line(0) } + conn = UM::Connection.new(@machine, fd) + n.times { acc << conn.read_line(0) } end - def test_stream_dev_random_get_line_concurrent + def test_connection_dev_random_read_line_concurrent acc = [] c = 1 n = 100000 ff = c.times.map { - machine.spin { get_line_do(n, acc) } + machine.spin { read_line_do(n, acc) } } machine.await(ff) assert_equal c * n, acc.size end - def test_stream_dev_random_get_string + def test_connection_dev_random_read fd = machine.open('/dev/random', UM::O_RDONLY) - stream = UM::Stream.new(machine, fd) + conn = UM::Connection.new(machine, fd) n = 256 size = 65536 * 8 count = 0 # lines = [] n.times { - l = stream.get_string(size) + l = conn.read(size) refute_nil l assert_equal size, l.bytesize @@ -569,49 +672,62 @@ def test_stream_dev_random_get_string assert_equal n, count ensure - stream.clear rescue nil + conn.clear rescue nil end end -class StreamModeTest < UMBaseTest - def test_stream_default_mode +class ConnectionModeTest < UMBaseTest + def test_connection_default_mode r, w = UM.pipe - stream = UM::Stream.new(machine, r) - assert_equal :bp_read, stream.mode + conn = UM::Connection.new(machine, r) + assert_equal :fd, conn.mode ensure machine.close(r) rescue nil machine.close(w) rescue nil end - def test_stream_recv_mode_non_socket + def test_connection_default_mode_ssl + authority = Localhost::Authority.fetch + @server_ctx = authority.server_context + sock1, sock2 = UNIXSocket.pair + + s1 = OpenSSL::SSL::SSLSocket.new(sock1, @server_ctx) + conn = UM::Connection.new(machine, s1) + assert_equal :ssl, conn.mode + ensure + sock1&.close rescue nil + sock2&.close rescue nil + end + + def test_connection_socket_mode_non_socket r, w = UM.pipe machine.write(w, 'foobar') machine.close(w) - stream = UM::Stream.new(machine, r, :bp_recv) - assert_equal :bp_recv, stream.mode - # assert :bp_recv, stream.mode - assert_raises(Errno::ENOTSOCK) { stream.get_string(0) } + conn = UM::Connection.new(machine, r, :socket) + assert_equal :socket, conn.mode + # assert :socket, conn.mode + assert_raises(Errno::ENOTSOCK) { conn.read(0) } ensure machine.close(r) rescue nil machine.close(w) rescue nil end - def test_stream_recv_mode_socket + def test_connection_socket_mode_socket r, w = UM.socketpair(UM::AF_UNIX, UM::SOCK_STREAM, 0) machine.write(w, 'foobar') machine.close(w) - stream = UM::Stream.new(machine, r, :bp_recv) - assert_equal :bp_recv, stream.mode - buf = stream.get_string(0) + conn = UM::Connection.new(machine, r, :socket) + assert_equal :socket, conn.mode + buf = conn.read(0) assert_equal 'foobar', buf ensure machine.close(r) rescue nil machine.close(w) rescue nil end - def test_stream_ssl_mode + def test_connection_ssl_mode authority = Localhost::Authority.fetch @server_ctx = authority.server_context sock1, sock2 = UNIXSocket.pair @@ -635,18 +751,18 @@ def test_stream_ssl_mode assert_equal 10, @machine.ssl_write(s1, buf, buf.bytesize) buf = +'' - stream = UM::Stream.new(machine, s2, :ssl) - assert_equal "foobar", stream.get_line(0) + conn = UM::Connection.new(machine, s2, :ssl) + assert_equal "foobar", conn.read_line(0) buf = "buh" @machine.ssl_write(s1, buf, buf.bytesize) - assert_equal "baz", stream.get_string(0) - assert_equal "buh", stream.get_string(0) + assert_equal "baz", conn.read(0) + assert_equal "buh", conn.read(0) s1.close - assert_nil stream.get_string(0) + assert_nil conn.read(0) rescue => e p e p e.backtrace @@ -658,35 +774,35 @@ def test_stream_ssl_mode end end -class StreamByteCountsTest < StreamBaseTest - def test_stream_byte_counts +class ConnectionByteCountsTest < ConnectionBaseTest + def test_connection_byte_counts machine.write(@wfd, "foobar") - assert_equal 0, stream.consumed - assert_equal 0, stream.pending + assert_equal 0, conn.consumed + assert_equal 0, conn.pending - buf = stream.get_string(2) + buf = conn.read(2) assert_equal 'fo', buf - assert_equal 2, stream.consumed - assert_equal 4, stream.pending + assert_equal 2, conn.consumed + assert_equal 4, conn.pending - buf = stream.get_string(3) + buf = conn.read(3) assert_equal 'oba', buf - assert_equal 5, stream.consumed - assert_equal 1, stream.pending + assert_equal 5, conn.consumed + assert_equal 1, conn.pending machine.write(@wfd, "abc\ndef") machine.snooze - assert_equal 5, stream.consumed - assert_equal 1, stream.pending + assert_equal 5, conn.consumed + assert_equal 1, conn.pending - buf = stream.get_line(0) + buf = conn.read_line(0) assert_equal 'rabc', buf - assert_equal 10, stream.consumed - assert_equal 3, stream.pending + assert_equal 10, conn.consumed + assert_equal 3, conn.pending - stream.clear - assert_equal 10, stream.consumed - assert_equal 0, stream.pending + conn.clear + assert_equal 10, conn.consumed + assert_equal 0, conn.pending end end diff --git a/test/test_um.rb b/test/test_um.rb index 95256c4..ee94cef 100644 --- a/test/test_um.rb +++ b/test/test_um.rb @@ -3506,53 +3506,53 @@ def test_pr_set_child_subreaper end end -class StreamMethodTest < UMBaseTest +class ConnectionMethodTest < UMBaseTest def setup super @rfd, @wfd = UM.pipe end def teardown - @stream = nil + @conn = nil machine.close(@rfd) rescue nil machine.close(@wfd) rescue nil super end - def test_stream_method + def test_connection_method machine.write(@wfd, "foobar") machine.close(@wfd) - stream = machine.stream(@rfd) - assert_kind_of UM::Stream, stream + conn = machine.connection(@rfd) + assert_kind_of UM::Connection, conn - buf = stream.get_string(3) + buf = conn.read(3) assert_equal 'foo', buf - buf = stream.get_string(-6) + buf = conn.read(-6) assert_equal 'bar', buf - assert stream.eof? + assert conn.eof? - stream.clear + conn.clear end - def test_stream_method_with_block + def test_connection_method_with_block machine.write(@wfd, "foobar") machine.close(@wfd) bufs = [] - stream_obj = nil - res = machine.stream(@rfd) do |s| - stream_obj = s + conn_obj = nil + res = machine.connection(@rfd) do |s| + conn_obj = s - bufs << s.get_string(3) - bufs << s.get_string(-6) + bufs << s.read(3) + bufs << s.read(-6) :foo end - assert_kind_of UM::Stream, stream_obj - assert stream_obj.eof? + assert_kind_of UM::Connection, conn_obj + assert conn_obj.eof? assert_equal ['foo', 'bar'], bufs assert_equal :foo, res end @@ -3688,7 +3688,7 @@ def test_tee assert_equal 6, len1 assert_equal 'foobar', result2 - assert_equal 6, len2 + assert_equal 6, len2 ensure machine.terminate(f) machine.join(f)