diff --git a/internal/stream/server.go b/internal/stream/server.go new file mode 100644 index 00000000..82b69f7e --- /dev/null +++ b/internal/stream/server.go @@ -0,0 +1,81 @@ +package stream + +import ( + "context" + "sync" +) + +// MessageHandler handles an incoming stream message on the server side. +// It is called in a new goroutine for each received message. +// The handler must eventually release the mutex (by calling mut.Unlock()) +// to allow the next request to be processed. The finished channel is used +// to send response messages back to the client. +type MessageHandler func(ctx context.Context, mut *sync.Mutex, finished chan<- *Message, msg *Message) + +// Server implements the Gorums gRPC service for handling node streams. +type Server struct { + handlers map[string]MessageHandler + buffer uint + connectCallback func(context.Context) + UnimplementedGorumsServer +} + +// NewServer creates a new StreamServer with the given buffer size +// and optional connect callback. +func NewServer(buffer uint, connectCallback func(context.Context)) *Server { + return &Server{ + handlers: make(map[string]MessageHandler), + buffer: buffer, + connectCallback: connectCallback, + } +} + +// RegisterHandler registers a message handler for the specified method name. +func (s *Server) RegisterHandler(method string, handler MessageHandler) { + s.handlers[method] = handler +} + +// NodeStream handles a connection to a single client. The stream is aborted if there +// is any error with sending or receiving. +func (s *Server) NodeStream(srv Gorums_NodeStreamServer) error { + var mut sync.Mutex + finished := make(chan *Message, s.buffer) + ctx := srv.Context() + + if s.connectCallback != nil { + s.connectCallback(ctx) + } + + go func() { + for { + select { + case <-ctx.Done(): + return + case streamOut := <-finished: + if err := srv.Send(streamOut); err != nil { + return + } + } + } + }() + + // Start with a locked mutex + mut.Lock() + defer mut.Unlock() + + for { + streamIn, err := srv.Recv() + if err != nil { + return err + } + if handler, ok := s.handlers[streamIn.GetMethod()]; ok { + // We start the handler in a new goroutine in order to allow multiple + // handlers to run concurrently. However, to preserve request ordering, + // the handler must unlock the shared mutex when it has either finished, + // or when it is safe to start processing the next request. + go handler(ctx, &mut, finished, streamIn) + // Wait until the handler releases the mutex. + mut.Lock() + } + } +} diff --git a/server.go b/server.go index cba5bdfb..c0192bb4 100644 --- a/server.go +++ b/server.go @@ -20,87 +20,6 @@ type ( Handler func(ServerCtx, *Message) (*Message, error) ) -type streamServer struct { - handlers map[string]Handler - opts *serverOptions - stream.UnimplementedGorumsServer -} - -func newStreamServer(opts *serverOptions) *streamServer { - return &streamServer{ - handlers: make(map[string]Handler), - opts: opts, - } -} - -// NodeStream handles a connection to a single client. The stream is aborted if there -// is any error with sending or receiving. -func (s *streamServer) NodeStream(srv stream.Gorums_NodeStreamServer) error { - var mut sync.Mutex // used to achieve mutex between request handlers - finished := make(chan *stream.Message, s.opts.buffer) - ctx := srv.Context() - - if s.opts.connectCallback != nil { - s.opts.connectCallback(ctx) - } - - go func() { - for { - select { - case <-ctx.Done(): - return - case streamOut := <-finished: - if err := srv.Send(streamOut); err != nil { - return - } - } - } - }() - - // Start with a locked mutex - mut.Lock() - defer mut.Unlock() - - for { - streamIn, err := srv.Recv() - if err != nil { - return err - } - if handler, ok := s.handlers[streamIn.GetMethod()]; ok { - // We start the handler in a new goroutine in order to allow multiple handlers to run concurrently. - // However, to preserve request ordering, the handler must unlock the shared mutex when it has either - // finished, or when it is safe to start processing the next request. - // - // This func() is the default interceptor; it is the first and last handler in the chain. - // It is responsible for releasing the mutex when the handler chain is done. - go func() { - srvCtx := newServerCtx(streamIn.AppendToIncomingContext(ctx), &mut, finished) - defer srvCtx.Release() - - msg, err := stream.UnmarshalRequest(streamIn) - in := &Message{Msg: msg, Message: streamIn} - if err != nil { - _ = srvCtx.SendMessage(messageWithError(in, nil, err)) - return - } - - out, err := handler(srvCtx, in) - // If there is no response and no error, we do not send anything back to the client. - // This corresponds to a unidirectional message from client to server, where clients - // are not expected to receive a response. - if out == nil && err == nil { - return - } - _ = srvCtx.SendMessage(messageWithError(in, out, err)) - // We ignore the error from SendMessage here; it means that the stream is closed. - // The for-loop above will exit on the next Recv call. - }() - // Wait until the handler releases the mutex. - mut.Lock() - } - } -} - type serverOptions struct { buffer uint grpcOpts []grpc.ServerOption @@ -169,7 +88,7 @@ func chainInterceptors(final Handler, interceptors ...Interceptor) Handler { // Server serves all ordering based RPCs using registered handlers. type Server struct { - srv *streamServer + srv *stream.Server grpcServer *grpc.Server interceptors []Interceptor } @@ -181,7 +100,7 @@ func NewServer(opts ...ServerOption) *Server { opt(&serverOpts) } s := &Server{ - srv: newStreamServer(&serverOpts), + srv: stream.NewServer(serverOpts.buffer, serverOpts.connectCallback), grpcServer: grpc.NewServer(serverOpts.grpcOpts...), interceptors: serverOpts.interceptors, } @@ -193,7 +112,29 @@ func NewServer(opts ...ServerOption) *Server { // // This function should only be used by generated code. func (s *Server) RegisterHandler(method string, handler Handler) { - s.srv.handlers[method] = chainInterceptors(handler, s.interceptors...) + wrapped := chainInterceptors(handler, s.interceptors...) + s.srv.RegisterHandler(method, func(ctx context.Context, mut *sync.Mutex, finished chan<- *stream.Message, streamIn *stream.Message) { + srvCtx := newServerCtx(streamIn.AppendToIncomingContext(ctx), mut, finished) + defer srvCtx.Release() + + msg, err := stream.UnmarshalRequest(streamIn) + in := &Message{Msg: msg, Message: streamIn} + if err != nil { + _ = srvCtx.SendMessage(messageWithError(in, nil, err)) + return + } + + out, err := wrapped(srvCtx, in) + // If there is no response and no error, we do not send anything back to the client. + // This corresponds to a unidirectional message from client to server, where clients + // are not expected to receive a response. + if out == nil && err == nil { + return + } + _ = srvCtx.SendMessage(messageWithError(in, out, err)) + // We ignore the error from SendMessage here; it means that the stream is closed. + // The for-loop in srv.NodeStream() will exit on the next Recv() call. + }) } // Serve starts serving on the listener.