Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions internal/stream/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package stream

import (
"context"
"sync"
)

// MessageHandler handles an incoming stream message on the server side.
// It is called in a new goroutine for each received message.
// The handler must eventually release the mutex (by calling mut.Unlock())
// to allow the next request to be processed. The finished channel is used
// to send response messages back to the client.
type MessageHandler func(ctx context.Context, mut *sync.Mutex, finished chan<- *Message, msg *Message)

// Server implements the Gorums gRPC service for handling node streams.
type Server struct {
handlers map[string]MessageHandler
buffer uint
connectCallback func(context.Context)
UnimplementedGorumsServer
}

// NewServer creates a new StreamServer with the given buffer size
// and optional connect callback.
func NewServer(buffer uint, connectCallback func(context.Context)) *Server {
return &Server{
handlers: make(map[string]MessageHandler),
buffer: buffer,
connectCallback: connectCallback,
}
}

// RegisterHandler registers a message handler for the specified method name.
func (s *Server) RegisterHandler(method string, handler MessageHandler) {
s.handlers[method] = handler
}

// NodeStream handles a connection to a single client. The stream is aborted if there
// is any error with sending or receiving.
func (s *Server) NodeStream(srv Gorums_NodeStreamServer) error {
var mut sync.Mutex
finished := make(chan *Message, s.buffer)
ctx := srv.Context()

if s.connectCallback != nil {
s.connectCallback(ctx)
}

go func() {
for {
select {
case <-ctx.Done():
return
case streamOut := <-finished:
if err := srv.Send(streamOut); err != nil {
return
}
}
}
}()

// Start with a locked mutex
mut.Lock()
defer mut.Unlock()

for {
streamIn, err := srv.Recv()
if err != nil {
return err
}
if handler, ok := s.handlers[streamIn.GetMethod()]; ok {
// We start the handler in a new goroutine in order to allow multiple
// handlers to run concurrently. However, to preserve request ordering,
// the handler must unlock the shared mutex when it has either finished,
// or when it is safe to start processing the next request.
go handler(ctx, &mut, finished, streamIn)
// Wait until the handler releases the mutex.
mut.Lock()
}
}
}
109 changes: 25 additions & 84 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,87 +20,6 @@ type (
Handler func(ServerCtx, *Message) (*Message, error)
)

type streamServer struct {
handlers map[string]Handler
opts *serverOptions
stream.UnimplementedGorumsServer
}

func newStreamServer(opts *serverOptions) *streamServer {
return &streamServer{
handlers: make(map[string]Handler),
opts: opts,
}
}

// NodeStream handles a connection to a single client. The stream is aborted if there
// is any error with sending or receiving.
func (s *streamServer) NodeStream(srv stream.Gorums_NodeStreamServer) error {
var mut sync.Mutex // used to achieve mutex between request handlers
finished := make(chan *stream.Message, s.opts.buffer)
ctx := srv.Context()

if s.opts.connectCallback != nil {
s.opts.connectCallback(ctx)
}

go func() {
for {
select {
case <-ctx.Done():
return
case streamOut := <-finished:
if err := srv.Send(streamOut); err != nil {
return
}
}
}
}()

// Start with a locked mutex
mut.Lock()
defer mut.Unlock()

for {
streamIn, err := srv.Recv()
if err != nil {
return err
}
if handler, ok := s.handlers[streamIn.GetMethod()]; ok {
// We start the handler in a new goroutine in order to allow multiple handlers to run concurrently.
// However, to preserve request ordering, the handler must unlock the shared mutex when it has either
// finished, or when it is safe to start processing the next request.
//
// This func() is the default interceptor; it is the first and last handler in the chain.
// It is responsible for releasing the mutex when the handler chain is done.
go func() {
srvCtx := newServerCtx(streamIn.AppendToIncomingContext(ctx), &mut, finished)
defer srvCtx.Release()

msg, err := stream.UnmarshalRequest(streamIn)
in := &Message{Msg: msg, Message: streamIn}
if err != nil {
_ = srvCtx.SendMessage(messageWithError(in, nil, err))
return
}

out, err := handler(srvCtx, in)
// If there is no response and no error, we do not send anything back to the client.
// This corresponds to a unidirectional message from client to server, where clients
// are not expected to receive a response.
if out == nil && err == nil {
return
}
_ = srvCtx.SendMessage(messageWithError(in, out, err))
// We ignore the error from SendMessage here; it means that the stream is closed.
// The for-loop above will exit on the next Recv call.
}()
// Wait until the handler releases the mutex.
mut.Lock()
}
}
}

type serverOptions struct {
buffer uint
grpcOpts []grpc.ServerOption
Expand Down Expand Up @@ -169,7 +88,7 @@ func chainInterceptors(final Handler, interceptors ...Interceptor) Handler {

// Server serves all ordering based RPCs using registered handlers.
type Server struct {
srv *streamServer
srv *stream.Server
grpcServer *grpc.Server
interceptors []Interceptor
}
Expand All @@ -181,7 +100,7 @@ func NewServer(opts ...ServerOption) *Server {
opt(&serverOpts)
}
s := &Server{
srv: newStreamServer(&serverOpts),
srv: stream.NewServer(serverOpts.buffer, serverOpts.connectCallback),
grpcServer: grpc.NewServer(serverOpts.grpcOpts...),
interceptors: serverOpts.interceptors,
}
Expand All @@ -193,7 +112,29 @@ func NewServer(opts ...ServerOption) *Server {
//
// This function should only be used by generated code.
func (s *Server) RegisterHandler(method string, handler Handler) {
s.srv.handlers[method] = chainInterceptors(handler, s.interceptors...)
wrapped := chainInterceptors(handler, s.interceptors...)
s.srv.RegisterHandler(method, func(ctx context.Context, mut *sync.Mutex, finished chan<- *stream.Message, streamIn *stream.Message) {
srvCtx := newServerCtx(streamIn.AppendToIncomingContext(ctx), mut, finished)
defer srvCtx.Release()

msg, err := stream.UnmarshalRequest(streamIn)
in := &Message{Msg: msg, Message: streamIn}
if err != nil {
_ = srvCtx.SendMessage(messageWithError(in, nil, err))
return
}

out, err := wrapped(srvCtx, in)
// If there is no response and no error, we do not send anything back to the client.
// This corresponds to a unidirectional message from client to server, where clients
// are not expected to receive a response.
if out == nil && err == nil {
return
}
_ = srvCtx.SendMessage(messageWithError(in, out, err))
// We ignore the error from SendMessage here; it means that the stream is closed.
// The for-loop in srv.NodeStream() will exit on the next Recv() call.
})
}

// Serve starts serving on the listener.
Expand Down