Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/from-python-serialization/SocketSerialization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ export async function serializePythonObjectUsingSocketServer(
requestId,
objectAsString,
makeOptions(options),
socketServer.secretHex,
);
logDebug('Sending code to python: ', code);
logDebug('Sending request to python with reqId ', requestId);
Expand Down
4 changes: 3 additions & 1 deletion src/python-communication/BuildPythonCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ export function constructOpenSendAndCloseCode(
request_id: number,
expression: string,
options?: OpenSendAndCloseOptions,
secret?: string,
): EvalCodePython<Result<PythonObjectShape>> {
function makeOptionsString(options: OpenSendAndCloseOptions): string {
return `dict(${Object.entries(options)
Expand All @@ -248,8 +249,9 @@ export function constructOpenSendAndCloseCode(
.join(', ')})`;
}
const optionsStr = options ? makeOptionsString(options) : '{}';
const secretArg = secret ? `, secret="${secret}"` : '';
return convertExpressionIntoValueWrappedExpression(
`${OPEN_SEND_AND_CLOSE}(${port}, ${request_id}, ${expression}, ${optionsStr})`,
`${OPEN_SEND_AND_CLOSE}(${port}, ${request_id}, ${expression}, ${optionsStr}${secretArg})`,
);
}

Expand Down
44 changes: 42 additions & 2 deletions src/python-communication/socket-based/Server.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import type { MessageChunkHeader } from './protocol';
import { Buffer } from 'node:buffer';
import * as crypto from 'node:crypto';
import * as net from 'node:net';
import { Service } from 'typedi';
import { logDebug, logInfo, logTrace } from '../../Logging';
import { MessageChunks } from './MessageChunks';
import { HEADER_LENGTH, MAX_MESSAGE_SIZE, splitHeaderContentRest } from './protocol';
import { AUTH_SECRET_LENGTH, HEADER_LENGTH, MAX_MESSAGE_SIZE, splitHeaderContentRest } from './protocol';
import { RequestsManager } from './RequestsManager';

const EMPTY_BUFFER = Buffer.alloc(0);
Expand All @@ -19,6 +20,7 @@ export class SocketServer {

private outgoingRequestsManager: RequestsManager = new RequestsManager();
private chunksByMessageId: Map<number, MessageChunks> = new Map();
private readonly secret: Buffer = crypto.randomBytes(AUTH_SECRET_LENGTH);

constructor() {
const options: net.ServerOpts = {
Expand Down Expand Up @@ -65,6 +67,10 @@ export class SocketServer {
return this.port;
}

get secretHex(): string {
return this.secret.toString('hex');
}

onClientConnected(socket: net.Socket): void {
const outgoingRequestsManager = this.outgoingRequestsManager;
const handleMessage = (header: MessageChunkHeader, data: Buffer) => {
Expand Down Expand Up @@ -133,6 +139,33 @@ export class SocketServer {
}
}
};

// Authentication state per connection
let authenticated = false;
let authBuffer: Buffer = EMPTY_BUFFER;
const serverSecret = this.secret;

const handleAuth = (data: Buffer) => {
authBuffer = authBuffer.length > 0 ? Buffer.concat([authBuffer, data]) : data;
if (authBuffer.length < AUTH_SECRET_LENGTH) {
logTrace(`Auth: waiting for more bytes (${authBuffer.length}/${AUTH_SECRET_LENGTH})`);
return;
}
const token = authBuffer.subarray(0, AUTH_SECRET_LENGTH);
const rest = authBuffer.subarray(AUTH_SECRET_LENGTH);
authBuffer = EMPTY_BUFFER;
if (!crypto.timingSafeEqual(token, serverSecret)) {
logDebug('Socket auth failed: invalid secret');
socket.destroy();
return;
}
logTrace('Socket auth succeeded');
authenticated = true;
if (rest.length > 0) {
handleData(rest);
}
};

const makeSafe = (fn: (...args: any[]) => void) => {
return (...args: any[]) => {
try {
Expand All @@ -145,7 +178,14 @@ export class SocketServer {
};
};

socket.on('data', makeSafe(handleData));
socket.on('data', makeSafe((data: Buffer) => {
if (!authenticated) {
handleAuth(data);
}
else {
handleData(data);
}
}));
socket.on('close', () => {
logTrace('Client closed connection');
});
Expand Down
3 changes: 3 additions & 0 deletions src/python-communication/socket-based/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ export const HEADER_LENGTH = Object.entries(BytesPerKey).reduce(
/** Maximum allowed message size (256 MB). Prevents unbounded memory allocation from malformed or malicious headers. */
export const MAX_MESSAGE_SIZE = 256 * 1024 * 1024;

/** Length of the authentication secret in bytes. */
export const AUTH_SECRET_LENGTH = 32;

export enum Sender {
Server = 0x01,
Python = 0x02,
Expand Down
5 changes: 4 additions & 1 deletion src/python/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,13 @@ def torch_to_numpy(tensor):
return tensor.numpy()


def open_send_and_close(port, request_id, obj, options=None):
def open_send_and_close(port, request_id, obj, options=None, secret=None):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("localhost", port))

if secret is not None:
s.sendall(bytes.fromhex(secret))

try:
if _Internal.is_numpy_array(obj):
if _Internal.is_numpy_tensor(obj):
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/python/test_socket_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Tests for socket_client.py secret authentication parameter."""
import socket
import struct
import threading

# Inline the relevant function signature from socket_client.py
# We test that the secret parameter is sent as raw bytes before message data


def test_secret_hex_to_bytes():
"""bytes.fromhex converts hex secret to correct bytes."""
hex_secret = "aa" * 32 # 64 hex chars = 32 bytes
raw = bytes.fromhex(hex_secret)
assert len(raw) == 32
assert all(b == 0xAA for b in raw)


def test_secret_roundtrip():
"""Secret survives hex encode/decode roundtrip."""
import os
secret = os.urandom(32)
hex_str = secret.hex()
assert len(hex_str) == 64
recovered = bytes.fromhex(hex_str)
assert recovered == secret


def test_secret_sent_before_data():
"""When secret is provided, it is sent as the first 32 bytes on the socket."""
received_data = bytearray()
server_ready = threading.Event()

def server_thread(port_holder):
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.bind(("localhost", 0))
port_holder.append(srv.getsockname()[1])
srv.listen(1)
server_ready.set()
conn, _ = srv.accept()
while True:
chunk = conn.recv(4096)
if not chunk:
break
received_data.extend(chunk)
conn.close()
srv.close()

port_holder = []
t = threading.Thread(target=server_thread, args=(port_holder,))
t.daemon = True
t.start()
server_ready.wait(timeout=5)

port = port_holder[0]
secret_hex = "ab" * 32
expected_secret = bytes.fromhex(secret_hex)
test_payload = b"hello"

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("localhost", port))
# Simulate what socket_client.py does: send secret then payload
s.sendall(expected_secret)
s.sendall(test_payload)
s.close()

t.join(timeout=5)

assert received_data[:32] == expected_secret, "First 32 bytes should be the secret"
assert received_data[32:] == test_payload, "Remaining bytes should be the payload"


def test_no_secret_sends_no_prefix():
"""When secret is None, no prefix bytes are sent."""
received_data = bytearray()
server_ready = threading.Event()

def server_thread(port_holder):
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srv.bind(("localhost", 0))
port_holder.append(srv.getsockname()[1])
srv.listen(1)
server_ready.set()
conn, _ = srv.accept()
while True:
chunk = conn.recv(4096)
if not chunk:
break
received_data.extend(chunk)
conn.close()
srv.close()

port_holder = []
t = threading.Thread(target=server_thread, args=(port_holder,))
t.daemon = True
t.start()
server_ready.wait(timeout=5)

port = port_holder[0]
test_payload = b"hello"

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("localhost", port))
# No secret - just payload
s.sendall(test_payload)
s.close()

t.join(timeout=5)

assert received_data == test_payload, "All bytes should be payload when no secret"
13 changes: 10 additions & 3 deletions tests/unit/ts/max-msg-size.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Buffer } from 'node:buffer';
import * as net from 'node:net';
import { afterEach, describe, expect, it } from 'vitest';
import { afterEach, describe, expect, it, vi } from 'vitest';
import {
HEADER_LENGTH,
MAX_MESSAGE_SIZE,
Expand Down Expand Up @@ -258,8 +258,10 @@ describe('socketServer integration — MAX_MESSAGE_SIZE rejection', () => {
const client = new net.Socket();
const timer = setTimeout(() => reject(new Error('timed out waiting for socket close')), 5000);
client.connect(port, '127.0.0.1', () => {
// Send auth token first (required by socket server)
const secret = Buffer.from(serverInstance!.secretHex, 'hex');
const buf = buildBuffer({ messageLength: MAX_MESSAGE_SIZE + 1 });
client.write(buf);
client.write(Buffer.concat([secret, buf]));
});
client.once('close', () => {
clearTimeout(timer);
Expand Down Expand Up @@ -293,8 +295,10 @@ describe('socketServer integration — MAX_MESSAGE_SIZE rejection', () => {
const client = new net.Socket();
const timer = setTimeout(() => reject(new Error('timed out')), 5000);
client.connect(port, '127.0.0.1', () => {
// Send auth token first (required by socket server)
const secret = Buffer.from(serverInstance!.secretHex, 'hex');
const buf = buildBuffer({ messageLength: MAX_MESSAGE_SIZE + 1 });
client.write(buf);
client.write(Buffer.concat([secret, buf]));
});
client.once('close', () => {
clearTimeout(timer);
Expand Down Expand Up @@ -323,10 +327,13 @@ describe('socketServer integration — MAX_MESSAGE_SIZE rejection', () => {
const client = new net.Socket();
const timer = setTimeout(() => reject(new Error('timed out')), 5000);
client.connect(port, '127.0.0.1', () => {
// Send auth token first (required by socket server)
const secret = Buffer.from(serverInstance!.secretHex, 'hex');
// Send the full header split into two writes: first 4 bytes (messageLength),
// then the rest. After the second write the server has >= HEADER_LENGTH
// bytes and should detect the oversized messageLength and destroy the socket.
const buf = buildBuffer({ messageLength: MAX_MESSAGE_SIZE + 1 });
client.write(secret);
client.write(buf.subarray(0, 4), () => {
client.write(buf.subarray(4));
});
Expand Down
11 changes: 8 additions & 3 deletions tests/unit/ts/response-timeout.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,16 @@ describe('socketServer.onResponse — integration (real server)', () => {
const clients: net.Socket[] = [];
const serverSideConnections = new Set<net.Socket>();

async function connectClient(port: number): Promise<net.Socket> {
async function connectClient(port: number, serverSecret: string): Promise<net.Socket> {
const client = net.createConnection({ port });
clients.push(client);
await new Promise<void>((resolve, reject) => {
client.once('connect', resolve);
client.once('connect', () => {
// Send auth token first (required by socket server)
const secret = Buffer.from(serverSecret, 'hex');
client.write(secret);
resolve();
});
client.once('error', reject);
});
return client;
Expand Down Expand Up @@ -207,7 +212,7 @@ describe('socketServer.onResponse — integration (real server)', () => {
server.onResponse(requestId, (_header, data) => resolve(data));
});

const client = await connectClient(server.portNumber);
const client = await connectClient(server.portNumber, server.secretHex);
client.write(chunk);

const received = await responsePromise;
Expand Down
Loading
Loading