diff --git a/src/from-python-serialization/SocketSerialization.ts b/src/from-python-serialization/SocketSerialization.ts index 73e139ac..351a5f05 100644 --- a/src/from-python-serialization/SocketSerialization.ts +++ b/src/from-python-serialization/SocketSerialization.ts @@ -101,6 +101,7 @@ export async function serializePythonObjectUsingSocketServer( requestId, objectAsString, makeOptions(options), + socketServer.secretHex, ); logDebug('Sending code to python: ', code); logDebug('Sending request to python with reqId ', requestId); diff --git a/src/python-communication/BuildPythonCode.ts b/src/python-communication/BuildPythonCode.ts index 92a452cc..1246bd47 100644 --- a/src/python-communication/BuildPythonCode.ts +++ b/src/python-communication/BuildPythonCode.ts @@ -240,6 +240,7 @@ export function constructOpenSendAndCloseCode( request_id: number, expression: string, options?: OpenSendAndCloseOptions, + secret?: string, ): EvalCodePython> { function makeOptionsString(options: OpenSendAndCloseOptions): string { return `dict(${Object.entries(options) @@ -248,8 +249,9 @@ export function constructOpenSendAndCloseCode( .join(', ')})`; } const optionsStr = options ? makeOptionsString(options) : '{}'; + const secretArg = secret ? `, secret="${secret}"` : ''; return convertExpressionIntoValueWrappedExpression( - `${OPEN_SEND_AND_CLOSE}(${port}, ${request_id}, ${expression}, ${optionsStr})`, + `${OPEN_SEND_AND_CLOSE}(${port}, ${request_id}, ${expression}, ${optionsStr}${secretArg})`, ); } diff --git a/src/python-communication/socket-based/Server.ts b/src/python-communication/socket-based/Server.ts index 5ee9ab43..9206ec0a 100644 --- a/src/python-communication/socket-based/Server.ts +++ b/src/python-communication/socket-based/Server.ts @@ -1,10 +1,11 @@ import type { MessageChunkHeader } from './protocol'; import { Buffer } from 'node:buffer'; +import * as crypto from 'node:crypto'; import * as net from 'node:net'; import { Service } from 'typedi'; import { logDebug, logInfo, logTrace } from '../../Logging'; import { MessageChunks } from './MessageChunks'; -import { HEADER_LENGTH, MAX_MESSAGE_SIZE, splitHeaderContentRest } from './protocol'; +import { AUTH_SECRET_LENGTH, HEADER_LENGTH, MAX_MESSAGE_SIZE, splitHeaderContentRest } from './protocol'; import { RequestsManager } from './RequestsManager'; const EMPTY_BUFFER = Buffer.alloc(0); @@ -19,6 +20,7 @@ export class SocketServer { private outgoingRequestsManager: RequestsManager = new RequestsManager(); private chunksByMessageId: Map = new Map(); + private readonly secret: Buffer = crypto.randomBytes(AUTH_SECRET_LENGTH); constructor() { const options: net.ServerOpts = { @@ -65,6 +67,10 @@ export class SocketServer { return this.port; } + get secretHex(): string { + return this.secret.toString('hex'); + } + onClientConnected(socket: net.Socket): void { const outgoingRequestsManager = this.outgoingRequestsManager; const handleMessage = (header: MessageChunkHeader, data: Buffer) => { @@ -133,6 +139,33 @@ export class SocketServer { } } }; + + // Authentication state per connection + let authenticated = false; + let authBuffer: Buffer = EMPTY_BUFFER; + const serverSecret = this.secret; + + const handleAuth = (data: Buffer) => { + authBuffer = authBuffer.length > 0 ? Buffer.concat([authBuffer, data]) : data; + if (authBuffer.length < AUTH_SECRET_LENGTH) { + logTrace(`Auth: waiting for more bytes (${authBuffer.length}/${AUTH_SECRET_LENGTH})`); + return; + } + const token = authBuffer.subarray(0, AUTH_SECRET_LENGTH); + const rest = authBuffer.subarray(AUTH_SECRET_LENGTH); + authBuffer = EMPTY_BUFFER; + if (!crypto.timingSafeEqual(token, serverSecret)) { + logDebug('Socket auth failed: invalid secret'); + socket.destroy(); + return; + } + logTrace('Socket auth succeeded'); + authenticated = true; + if (rest.length > 0) { + handleData(rest); + } + }; + const makeSafe = (fn: (...args: any[]) => void) => { return (...args: any[]) => { try { @@ -145,7 +178,14 @@ export class SocketServer { }; }; - socket.on('data', makeSafe(handleData)); + socket.on('data', makeSafe((data: Buffer) => { + if (!authenticated) { + handleAuth(data); + } + else { + handleData(data); + } + })); socket.on('close', () => { logTrace('Client closed connection'); }); diff --git a/src/python-communication/socket-based/protocol.ts b/src/python-communication/socket-based/protocol.ts index f96d2da3..7a0e99d1 100644 --- a/src/python-communication/socket-based/protocol.ts +++ b/src/python-communication/socket-based/protocol.ts @@ -58,6 +58,9 @@ export const HEADER_LENGTH = Object.entries(BytesPerKey).reduce( /** Maximum allowed message size (256 MB). Prevents unbounded memory allocation from malformed or malicious headers. */ export const MAX_MESSAGE_SIZE = 256 * 1024 * 1024; +/** Length of the authentication secret in bytes. */ +export const AUTH_SECRET_LENGTH = 32; + export enum Sender { Server = 0x01, Python = 0x02, diff --git a/src/python/socket_client.py b/src/python/socket_client.py index 2e547fb5..04910834 100644 --- a/src/python/socket_client.py +++ b/src/python/socket_client.py @@ -538,10 +538,13 @@ def torch_to_numpy(tensor): return tensor.numpy() -def open_send_and_close(port, request_id, obj, options=None): +def open_send_and_close(port, request_id, obj, options=None, secret=None): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.connect(("localhost", port)) + if secret is not None: + s.sendall(bytes.fromhex(secret)) + try: if _Internal.is_numpy_array(obj): if _Internal.is_numpy_tensor(obj): diff --git a/tests/unit/python/test_socket_auth.py b/tests/unit/python/test_socket_auth.py new file mode 100644 index 00000000..8ab68e7d --- /dev/null +++ b/tests/unit/python/test_socket_auth.py @@ -0,0 +1,111 @@ +"""Tests for socket_client.py secret authentication parameter.""" +import socket +import struct +import threading + +# Inline the relevant function signature from socket_client.py +# We test that the secret parameter is sent as raw bytes before message data + + +def test_secret_hex_to_bytes(): + """bytes.fromhex converts hex secret to correct bytes.""" + hex_secret = "aa" * 32 # 64 hex chars = 32 bytes + raw = bytes.fromhex(hex_secret) + assert len(raw) == 32 + assert all(b == 0xAA for b in raw) + + +def test_secret_roundtrip(): + """Secret survives hex encode/decode roundtrip.""" + import os + secret = os.urandom(32) + hex_str = secret.hex() + assert len(hex_str) == 64 + recovered = bytes.fromhex(hex_str) + assert recovered == secret + + +def test_secret_sent_before_data(): + """When secret is provided, it is sent as the first 32 bytes on the socket.""" + received_data = bytearray() + server_ready = threading.Event() + + def server_thread(port_holder): + srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + srv.bind(("localhost", 0)) + port_holder.append(srv.getsockname()[1]) + srv.listen(1) + server_ready.set() + conn, _ = srv.accept() + while True: + chunk = conn.recv(4096) + if not chunk: + break + received_data.extend(chunk) + conn.close() + srv.close() + + port_holder = [] + t = threading.Thread(target=server_thread, args=(port_holder,)) + t.daemon = True + t.start() + server_ready.wait(timeout=5) + + port = port_holder[0] + secret_hex = "ab" * 32 + expected_secret = bytes.fromhex(secret_hex) + test_payload = b"hello" + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("localhost", port)) + # Simulate what socket_client.py does: send secret then payload + s.sendall(expected_secret) + s.sendall(test_payload) + s.close() + + t.join(timeout=5) + + assert received_data[:32] == expected_secret, "First 32 bytes should be the secret" + assert received_data[32:] == test_payload, "Remaining bytes should be the payload" + + +def test_no_secret_sends_no_prefix(): + """When secret is None, no prefix bytes are sent.""" + received_data = bytearray() + server_ready = threading.Event() + + def server_thread(port_holder): + srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + srv.bind(("localhost", 0)) + port_holder.append(srv.getsockname()[1]) + srv.listen(1) + server_ready.set() + conn, _ = srv.accept() + while True: + chunk = conn.recv(4096) + if not chunk: + break + received_data.extend(chunk) + conn.close() + srv.close() + + port_holder = [] + t = threading.Thread(target=server_thread, args=(port_holder,)) + t.daemon = True + t.start() + server_ready.wait(timeout=5) + + port = port_holder[0] + test_payload = b"hello" + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("localhost", port)) + # No secret - just payload + s.sendall(test_payload) + s.close() + + t.join(timeout=5) + + assert received_data == test_payload, "All bytes should be payload when no secret" diff --git a/tests/unit/ts/max-msg-size.test.ts b/tests/unit/ts/max-msg-size.test.ts index 324df7f3..1499a4a6 100644 --- a/tests/unit/ts/max-msg-size.test.ts +++ b/tests/unit/ts/max-msg-size.test.ts @@ -1,6 +1,6 @@ import { Buffer } from 'node:buffer'; import * as net from 'node:net'; -import { afterEach, describe, expect, it } from 'vitest'; +import { afterEach, describe, expect, it, vi } from 'vitest'; import { HEADER_LENGTH, MAX_MESSAGE_SIZE, @@ -258,8 +258,10 @@ describe('socketServer integration — MAX_MESSAGE_SIZE rejection', () => { const client = new net.Socket(); const timer = setTimeout(() => reject(new Error('timed out waiting for socket close')), 5000); client.connect(port, '127.0.0.1', () => { + // Send auth token first (required by socket server) + const secret = Buffer.from(serverInstance!.secretHex, 'hex'); const buf = buildBuffer({ messageLength: MAX_MESSAGE_SIZE + 1 }); - client.write(buf); + client.write(Buffer.concat([secret, buf])); }); client.once('close', () => { clearTimeout(timer); @@ -293,8 +295,10 @@ describe('socketServer integration — MAX_MESSAGE_SIZE rejection', () => { const client = new net.Socket(); const timer = setTimeout(() => reject(new Error('timed out')), 5000); client.connect(port, '127.0.0.1', () => { + // Send auth token first (required by socket server) + const secret = Buffer.from(serverInstance!.secretHex, 'hex'); const buf = buildBuffer({ messageLength: MAX_MESSAGE_SIZE + 1 }); - client.write(buf); + client.write(Buffer.concat([secret, buf])); }); client.once('close', () => { clearTimeout(timer); @@ -323,10 +327,13 @@ describe('socketServer integration — MAX_MESSAGE_SIZE rejection', () => { const client = new net.Socket(); const timer = setTimeout(() => reject(new Error('timed out')), 5000); client.connect(port, '127.0.0.1', () => { + // Send auth token first (required by socket server) + const secret = Buffer.from(serverInstance!.secretHex, 'hex'); // Send the full header split into two writes: first 4 bytes (messageLength), // then the rest. After the second write the server has >= HEADER_LENGTH // bytes and should detect the oversized messageLength and destroy the socket. const buf = buildBuffer({ messageLength: MAX_MESSAGE_SIZE + 1 }); + client.write(secret); client.write(buf.subarray(0, 4), () => { client.write(buf.subarray(4)); }); diff --git a/tests/unit/ts/response-timeout.test.ts b/tests/unit/ts/response-timeout.test.ts index 4a5e6612..f09d7a2c 100644 --- a/tests/unit/ts/response-timeout.test.ts +++ b/tests/unit/ts/response-timeout.test.ts @@ -154,11 +154,16 @@ describe('socketServer.onResponse — integration (real server)', () => { const clients: net.Socket[] = []; const serverSideConnections = new Set(); - async function connectClient(port: number): Promise { + async function connectClient(port: number, serverSecret: string): Promise { const client = net.createConnection({ port }); clients.push(client); await new Promise((resolve, reject) => { - client.once('connect', resolve); + client.once('connect', () => { + // Send auth token first (required by socket server) + const secret = Buffer.from(serverSecret, 'hex'); + client.write(secret); + resolve(); + }); client.once('error', reject); }); return client; @@ -207,7 +212,7 @@ describe('socketServer.onResponse — integration (real server)', () => { server.onResponse(requestId, (_header, data) => resolve(data)); }); - const client = await connectClient(server.portNumber); + const client = await connectClient(server.portNumber, server.secretHex); client.write(chunk); const received = await responsePromise; diff --git a/tests/unit/ts/socket-auth.test.ts b/tests/unit/ts/socket-auth.test.ts new file mode 100644 index 00000000..4aa6be20 --- /dev/null +++ b/tests/unit/ts/socket-auth.test.ts @@ -0,0 +1,167 @@ +import { Buffer } from 'node:buffer'; +import * as crypto from 'node:crypto'; +import * as net from 'node:net'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('node:crypto', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + timingSafeEqual: vi.fn(actual.timingSafeEqual) as typeof actual.timingSafeEqual, + }; +}); + +vi.mock('typedi', () => ({ + default: { set: vi.fn(), get: vi.fn(), has: vi.fn() }, + Service: () => (c: unknown) => c, + Inject: () => () => {}, +})); + +vi.mock('vscode-extensions-json-generator/utils', () => ({ + configUtils: { + ConfigurationGetter: () => () => ({}), + }, +})); + +const { SocketServer } = await import('../../../src/python-communication/socket-based/Server'); +const { AUTH_SECRET_LENGTH } = await import('../../../src/python-communication/socket-based/protocol'); + +describe('socketServer — shared secret authentication (S4)', () => { + describe('secretHex property', () => { + it('is a non-empty string', () => { + const server = new SocketServer(); + expect(server.secretHex).toBeTruthy(); + expect(typeof server.secretHex).toBe('string'); + }); + + it('has length AUTH_SECRET_LENGTH * 2', () => { + const server = new SocketServer(); + expect(server.secretHex).toHaveLength(AUTH_SECRET_LENGTH * 2); + }); + + it('contains only valid lowercase hex characters', () => { + const server = new SocketServer(); + expect(server.secretHex).toMatch(/^[0-9a-f]+$/); + }); + + it('different SocketServer instances have different secrets', () => { + const a = new SocketServer(); + const b = new SocketServer(); + expect(a.secretHex).not.toBe(b.secretHex); + }); + + it('returns the same secretHex on repeated accesses', () => { + const s = new SocketServer(); + expect(s.secretHex).toBe(s.secretHex); + }); + }); + + describe('integration — connection auth', () => { + let server: InstanceType; + + beforeEach(async () => { + server = new SocketServer(); + await server.start(); + }); + + afterEach(() => { + server.server.close(); + }); + + it('rejects connection with wrong secret (socket is destroyed)', async () => { + const wrongSecret = Buffer.alloc(AUTH_SECRET_LENGTH, 0xAB); + const correctSecret = Buffer.from(server.secretHex, 'hex'); + // Confirm they differ + expect(wrongSecret.equals(correctSecret)).toBe(false); + + await new Promise((resolve, reject) => { + const client = net.connect(server.portNumber, 'localhost', () => { + client.write(wrongSecret); + }); + const timeout = setTimeout(() => { + client.destroy(); + reject(new Error('Timeout: socket was not closed after bad auth')); + }, 3000); + client.on('close', () => { + clearTimeout(timeout); + resolve(); + }); + client.on('error', () => { + clearTimeout(timeout); + resolve(); // destroyed = error or close, both count + }); + }); + }); + + it('accepts connection with correct secret', async () => { + const secret = Buffer.from(server.secretHex, 'hex'); + + await new Promise((resolve, reject) => { + const client = net.connect(server.portNumber, 'localhost', () => { + client.write(secret); + // Give the server a moment to process auth, then verify socket is still open + const timeoutId = setTimeout(() => { + if (!client.destroyed) { + client.destroy(); + resolve(); + } + else { + reject(new Error('Socket was destroyed after correct secret')); + } + }, 200); + }); + client.on('error', (err) => { + clearTimeout(timeoutId); + reject(err); + }); + }); + }); + + it('accepts connection when secret arrives in two fragments', async () => { + const secret = Buffer.from(server.secretHex, 'hex'); + const firstHalf = secret.subarray(0, 16); + const secondHalf = secret.subarray(16); + + await new Promise((resolve, reject) => { + let timeoutId: ReturnType; + const client = net.connect(server.portNumber, 'localhost', () => { + client.write(firstHalf); + setTimeout(() => client.write(secondHalf), 20); + timeoutId = setTimeout(() => { + if (!client.destroyed) { + client.destroy(); + resolve(); + } + else { + reject(new Error('Socket destroyed after fragmented correct secret')); + } + }, 300); + }); + client.on('error', (err) => { + clearTimeout(timeoutId); + reject(err); + }); + }); + }, 5000); + + it('uses crypto.timingSafeEqual for comparison', async () => { + vi.mocked(crypto.timingSafeEqual).mockClear(); + const secret = Buffer.from(server.secretHex, 'hex'); + + await new Promise((resolve, reject) => { + const client = net.connect(server.portNumber, 'localhost', () => { + client.write(secret, () => { + // Give server time to process + setTimeout(() => { + client.destroy(); + resolve(); + }, 100); + }); + }); + client.on('error', reject); + }); + + expect(vi.mocked(crypto.timingSafeEqual)).toHaveBeenCalled(); + }, 5000); + }); +}); diff --git a/ts-unit.xml b/ts-unit.xml new file mode 100644 index 00000000..b4ba99c4 --- /dev/null +++ b/ts-unit.xml @@ -0,0 +1,275 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +Error: Test timed out in 5000ms. +If this is a long-running test, pass a timeout value as the last argument or configure it globally with "testTimeout". + ❯ tests/unit/ts/max-msg-size.test.ts:252:3 + + +Error: Hook timed out in 10000ms. +If this is a long-running hook, pass a timeout value as the last argument or configure it globally with "hookTimeout". + ❯ tests/unit/ts/max-msg-size.test.ts:242:3 + + + + +Error: Test timed out in 5000ms. +If this is a long-running test, pass a timeout value as the last argument or configure it globally with "testTimeout". + ❯ tests/unit/ts/max-msg-size.test.ts:281:3 + + +Error: Hook timed out in 10000ms. +If this is a long-running hook, pass a timeout value as the last argument or configure it globally with "hookTimeout". + ❯ tests/unit/ts/max-msg-size.test.ts:242:3 + + + + +Error: Test timed out in 5000ms. +If this is a long-running test, pass a timeout value as the last argument or configure it globally with "testTimeout". + ❯ tests/unit/ts/max-msg-size.test.ts:317:3 + + +Error: Hook timed out in 10000ms. +If this is a long-running hook, pass a timeout value as the last argument or configure it globally with "hookTimeout". + ❯ tests/unit/ts/max-msg-size.test.ts:242:3 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +Error: Test timed out in 5000ms. +If this is a long-running test, pass a timeout value as the last argument or configure it globally with "testTimeout". + ❯ tests/unit/ts/response-timeout.test.ts:193:3 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +