diff --git a/packages/wobe/package.json b/packages/wobe/package.json index c7bd98b..fc188cd 100644 --- a/packages/wobe/package.json +++ b/packages/wobe/package.json @@ -1,6 +1,6 @@ { "name": "wobe", - "version": "1.1.15", + "version": "1.1.16", "description": "A fast, lightweight and simple web framework", "keywords": [ "bun", diff --git a/packages/wobe/src/Wobe.ts b/packages/wobe/src/Wobe.ts index 1f80029..66b90a5 100644 --- a/packages/wobe/src/Wobe.ts +++ b/packages/wobe/src/Wobe.ts @@ -58,6 +58,7 @@ export type Hook = 'beforeHandler' | 'afterHandler' | 'beforeAndAfterHandler' * @param backpressureLimit The limit of the backpressure * @param closeOnBackpressureLimit Close the WebSocket server if the backpressure limit is reached * @param beforeWebSocketUpgrade Array of handlers before the WebSocket server is upgraded + * @param getUpgradeData Function to extract data passed to the WebSocket on upgrade * @param onOpen Handler when the WebSocket server is opened * @param onMessage Handler when the WebSocket server receives a message * @param onClose Handler when the WebSocket server is closed @@ -71,6 +72,7 @@ export interface WobeWebSocket { backpressureLimit?: number closeOnBackpressureLimit?: boolean beforeWebSocketUpgrade?: Array> + getUpgradeData?: (context: Context) => unknown onOpen?(ws: ServerWebSocket): void onMessage?(ws: ServerWebSocket, message: string | Buffer): void onClose?(ws: ServerWebSocket, code: number, message: string | Buffer): void diff --git a/packages/wobe/src/adapters/bun/bun.ts b/packages/wobe/src/adapters/bun/bun.ts index 4c380ed..16c43be 100644 --- a/packages/wobe/src/adapters/bun/bun.ts +++ b/packages/wobe/src/adapters/bun/bun.ts @@ -124,7 +124,11 @@ export const BunAdapter = (): RuntimeAdapter => ({ for (const hookBeforeSocketUpgrade of webSocket.beforeWebSocketUpgrade || []) await hookBeforeSocketUpgrade(context) - if (server.upgrade(req)) return + const upgraded = server.upgrade(req, { + data: webSocket.getUpgradeData?.(context) as any, + }) + + if (upgraded) return } if (!context.handler) { @@ -146,5 +150,5 @@ export const BunAdapter = (): RuntimeAdapter => ({ } }, }), - stopServer: (server) => server.stop(), + stopServer: (server) => server.stop(true), }) diff --git a/packages/wobe/src/adapters/bun/websocket.test.ts b/packages/wobe/src/adapters/bun/websocket.test.ts index f1cc6af..6a595d9 100644 --- a/packages/wobe/src/adapters/bun/websocket.test.ts +++ b/packages/wobe/src/adapters/bun/websocket.test.ts @@ -1,3 +1,4 @@ +import type { ServerWebSocket } from 'bun' import { describe, expect, it, beforeAll, afterAll, mock, beforeEach } from 'bun:test' import { Wobe } from '../../Wobe' import getPort from 'get-port' @@ -13,6 +14,15 @@ const waitWebsocketClosed = (ws: WebSocket) => ws.onclose = resolve }) +const waitForMockCall = (mockFn: { mock: { calls: unknown[] } }) => + new Promise((resolve) => { + const check = () => { + if (mockFn.mock.calls.length > 0) resolve() + else setTimeout(check, 10) + } + check() + }) + describe.skipIf(process.env.NODE_TEST === 'true')('Bun - websocket', () => { const mockOnOpen = mock(() => {}) const mockOnMessage = mock(() => {}) @@ -142,6 +152,47 @@ describe.skipIf(process.env.NODE_TEST === 'true')('Bun - websocket', () => { wobe2.stop() }) + it('should pass getUpgradeData to ws.data on open and message', async () => { + const port2 = await getPort() + + const upgradeData = { userId: 'user-123', role: 'admin' } + const mockOnOpenWithData = mock((_ws: ServerWebSocket) => {}) + const mockOnMessageWithData = mock( + (_ws: ServerWebSocket, _message: string | Buffer) => {}, + ) + + const wobe2 = new Wobe() + + wobe2 + .useWebSocket({ + path: '/ws', + getUpgradeData: () => upgradeData, + onOpen: mockOnOpenWithData, + onMessage: mockOnMessageWithData, + }) + .listen(port2) + + const ws = new WebSocket(`ws://localhost:${port2}/ws`) + + await waitWebsocketOpened(ws) + + ws.send('Hello') + + await waitForMockCall(mockOnMessageWithData) + + expect(mockOnOpenWithData).toHaveBeenCalledTimes(1) + const [openWs] = mockOnOpenWithData.mock.calls[0]! + expect(openWs.data).toEqual(upgradeData) + + expect(mockOnMessageWithData).toHaveBeenCalledTimes(1) + const [messageWs] = mockOnMessageWithData.mock.calls[0]! + expect(messageWs.data).toEqual(upgradeData) + + ws.close() + + wobe2.stop() + }) + it('should not established the socket connection if one of the beforeSocketUpgrade failed', async () => { const port2 = await getPort()