diff --git a/src/route-options-requests-for-cors.test.ts b/src/route-options-requests-for-cors.test.ts new file mode 100644 index 0000000..3feee23 --- /dev/null +++ b/src/route-options-requests-for-cors.test.ts @@ -0,0 +1,94 @@ +import { beforeEach, describe, expect, it, jest } from '@jest/globals'; +import type { Express } from 'express'; +import type { Route } from './parse-stack-config.js'; +import { routeOptionsRequestsForCors } from './route-options-requests-for-cors.js'; + +// Mock print utility +jest.mock('./utils/print'); + +describe('routeOptionsRequestsForCors', () => { + let app: jest.Mocked; + let optionsHandlers: Record; + + beforeEach(() => { + optionsHandlers = {}; + app = { + options: jest.fn((path: string, handler: Function) => { + optionsHandlers[path] = handler; + }), + } as any; + jest.clearAllMocks(); + }); + + it('should set up CORS OPTIONS handlers for routes with corsEnabled', () => { + const routes: Route[] = [ + { publicPath: '/foo', httpMethod: 'GET', corsEnabled: true } as any, + { publicPath: '/foo', httpMethod: 'POST', corsEnabled: true } as any, + { publicPath: '/bar', httpMethod: 'PUT', corsEnabled: true } as any, + { publicPath: '/baz', httpMethod: 'DELETE', corsEnabled: false } as any, + ]; + + routeOptionsRequestsForCors(routes, app); + + expect(app.options).toHaveBeenCalledTimes(2); + expect(app.options).toHaveBeenCalledWith('/foo', expect.any(Function)); + expect(app.options).toHaveBeenCalledWith('/bar', expect.any(Function)); + expect(optionsHandlers['/foo']).toBeDefined(); + expect(optionsHandlers['/bar']).toBeDefined(); + expect(optionsHandlers['/baz']).toBeUndefined(); + }); + + it('should set correct headers and status in the handler', () => { + const routes: Route[] = [ + { publicPath: '/foo', httpMethod: 'GET', corsEnabled: true } as any, + { publicPath: '/foo', httpMethod: 'POST', corsEnabled: true } as any, + ]; + + routeOptionsRequestsForCors(routes, app); + + const res = { + header: jest.fn(), + sendStatus: jest.fn(), + }; + + if (!optionsHandlers['/foo']) { + throw new Error('Expected options handler for /foo to be defined'); + } + optionsHandlers['/foo']({}, res); + + expect(res.header).toHaveBeenCalledWith('Access-Control-Allow-Origin', '*'); + expect(res.header).toHaveBeenCalledWith('Access-Control-Allow-Methods', 'GET, POST'); + expect(res.header).toHaveBeenCalledWith( + 'Access-Control-Allow-Headers', + 'Content-Type, Authorization', + ); + expect(res.sendStatus).toHaveBeenCalledWith(204); + }); + + it('should not set up handler for routes without corsEnabled', () => { + const routes: Route[] = [ + { publicPath: '/foo', httpMethod: 'GET', corsEnabled: false } as any, + { publicPath: '/bar', httpMethod: 'POST', corsEnabled: false } as any, + ]; + + routeOptionsRequestsForCors(routes, app); + + expect(app.options).not.toHaveBeenCalled(); + expect(optionsHandlers['/foo']).toBeUndefined(); + expect(optionsHandlers['/bar']).toBeUndefined(); + }); + + it('should handle routes with missing httpMethod gracefully', () => { + const routes: Route[] = [ + { publicPath: '/foo', corsEnabled: true } as any, + { publicPath: '/bar', httpMethod: 'PUT', corsEnabled: true } as any, + ]; + + routeOptionsRequestsForCors(routes, app); + + expect(app.options).toHaveBeenCalledTimes(1); + expect(app.options).toHaveBeenCalledWith('/bar', expect.any(Function)); + expect(optionsHandlers['/foo']).toBeUndefined(); + expect(optionsHandlers['/bar']).toBeDefined(); + }); +}); diff --git a/src/route-options-requests-for-cors.ts b/src/route-options-requests-for-cors.ts new file mode 100644 index 0000000..c7709c9 --- /dev/null +++ b/src/route-options-requests-for-cors.ts @@ -0,0 +1,29 @@ +import type { Express } from 'express'; +import type { Route } from './parse-stack-config.js'; +import { print } from './utils/print.js'; + +export function routeOptionsRequestsForCors(routes: readonly Route[], app: Express) { + const corsEnabledMethodsByRoute = routes.reduce((corsEnabledMethodsByRoute, route) => { + if (route.corsEnabled && route.httpMethod) { + print.info(`CORS is enabled for route: ${route.publicPath} with method: ${route.httpMethod}`); + + const existingMethods = corsEnabledMethodsByRoute.get(route.publicPath); + if (existingMethods) { + existingMethods.push(route.httpMethod); + } else { + corsEnabledMethodsByRoute.set(route.publicPath, [route.httpMethod]); + } + } + return corsEnabledMethodsByRoute; + }, new Map()); + corsEnabledMethodsByRoute.forEach((methods, path) => { + print.info(`Setting up CORS for path: ${path} with methods: ${methods.join(`, `)}`); + + app.options(path, (_req, res) => { + res.header(`Access-Control-Allow-Origin`, `*`); + res.header(`Access-Control-Allow-Methods`, methods.join(`, `)); + res.header(`Access-Control-Allow-Headers`, `Content-Type, Authorization`); + res.sendStatus(204); + }); + }); +} diff --git a/src/start-command.ts b/src/start-command.ts index d91c2a0..b26ec02 100644 --- a/src/start-command.ts +++ b/src/start-command.ts @@ -17,6 +17,7 @@ import getPort from 'get-port'; import * as lambdaLocal from 'lambda-local'; import { mkdirp } from 'mkdirp'; import { dirname } from 'path'; +import { routeOptionsRequestsForCors } from './route-options-requests-for-cors.js'; const commandName = `start`; @@ -51,6 +52,8 @@ export const startCommand: CommandModule<{}, { readonly port: number }> = { const routes = sortRoutes(stackConfig.routes); + routeOptionsRequestsForCors(routes, app); + for (const route of routes) { if (route.type === `function`) { const { cacheTtlInSeconds = 300 } = route; diff --git a/src/utils/__mocks__/print.ts b/src/utils/__mocks__/print.ts new file mode 100644 index 0000000..af7d2a7 --- /dev/null +++ b/src/utils/__mocks__/print.ts @@ -0,0 +1,15 @@ +export function print(): void {} + +print.warning = (): void => {}; + +print.error = (): void => {}; + +print.confirmation = async (): Promise => { + return false; +}; + +print.listItem = (): void => {}; + +print.success = (): void => {}; + +print.info = (): void => {};