Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion openspec/specs/api/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ const output = await add.forward([inputA, inputB]);
| Input | Shape | Description |
|-------|-------|-------------|
| inputA | any | First input tensor |
| inputB | any | Second input tensor (must match inputA shape exactly) |
| inputB | any | Second input tensor (must match inputA shape and layout exactly) |

### Output

Expand Down Expand Up @@ -669,6 +669,12 @@ Load a model definition.
**Parameters:**
- `modelDef`: Model definition with layers and weights

**Throws:**
- `Error`: If the model has no layers
- `Error`: If a layer type is unknown
- `Error`: If a layer input cannot be resolved from `input`, prior layers, or weights
- `Error`: If layer names are duplicated

#### tensorFromArray

```typescript
Expand Down
4 changes: 4 additions & 0 deletions openspec/specs/architecture/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ The system SHALL provide high-level inference orchestration.
- **WHEN** model is loaded
- **THEN** operators are mapped by type name for dynamic dispatch

#### Scenario: Model graph compilation
- **WHEN** model is loaded
- **THEN** layer names, operator types, and tensor references are validated before inference begins

#### Scenario: Intermediate cleanup
- **WHEN** inference completes
- **THEN** intermediate tensors are destroyed to free GPU memory
Expand Down
4 changes: 4 additions & 0 deletions openspec/specs/product/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ The system SHALL provide model loading and end-to-end inference.
- **WHEN** loading model definition with layers and weights
- **THEN** weights are allocated as GPU tensors

#### Scenario: Reject invalid graph definitions
- **WHEN** model definition contains duplicate layer names, unknown operators, or unresolved tensor references
- **THEN** loading fails before inference starts

#### Scenario: Run inference
- **WHEN** calling infer() with input tensor
- **THEN** output tensor is returned with correct shape
Expand Down
31 changes: 31 additions & 0 deletions openspec/specs/testing/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,27 @@ Feature: Dense Operator

---

## Feature: Add Operator

```gherkin
Feature: Add Operator
As a deep learning developer
I want to add residual tensors element-wise
So that I can express skip connections safely
```

### Scenario: Add basic execution

- **WHEN** I execute Add with two tensors of the same shape and layout
- **THEN** the output shape should equal the input shape

### Scenario: Add rejects layout mismatch

- **WHEN** I execute Add with tensors that share the same shape but use different layouts
- **THEN** it should throw an error "same layout"

---

## Feature: Flatten Operator

```gherkin
Expand Down Expand Up @@ -254,6 +275,16 @@ Feature: Inference Engine
- **WHEN** I load a model with layers and weights into initialized InferenceEngine
- **THEN** the weights should be allocated as GPU tensors

### Scenario: Reject invalid model graph at load time

- **WHEN** I load a model with duplicate layer names, unknown operator types, or missing tensor references
- **THEN** `loadModel()` should throw before inference starts

### Scenario: Preserve previous model on failed reload

- **WHEN** I load a valid model and then attempt to load an invalid replacement model
- **THEN** the previously loaded model should remain executable

### Scenario: Run inference

- **WHEN** I run inference on loaded model with correct input tensor shape
Expand Down
74 changes: 44 additions & 30 deletions src/core/GPUContext.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { WebGPUNotSupportedError, DeviceInitializationError, ShaderCompilationError } from './errors';
import { WebGPUNotSupportedError, DeviceInitializationError } from './errors';

export interface GPUContextConfig {
powerPreference?: 'low-power' | 'high-performance';
Expand All @@ -12,47 +12,51 @@ export class GPUContext {
private adapter: GPUAdapter | null = null;
private device: GPUDevice | null = null;
private _isInitialized = false;
private pendingResourceCleanup = new Set<Promise<void>>();

private trackCleanup(cleanup: Promise<void>): void {
this.pendingResourceCleanup.add(cleanup);
void cleanup.finally(() => {
this.pendingResourceCleanup.delete(cleanup);
});
}
private deferredBuffers = new Set<GPUBuffer>();

deferDestroy(buffer: GPUBuffer | null | undefined): void {
if (!buffer) return;

const cleanup = this.waitForSubmittedWork()
.then(() => {
buffer.destroy();
})
.catch(() => {
try {
buffer.destroy();
} catch {
// Ignore cleanup failures after device loss/destroy.
}
});
this.deferredBuffers.add(buffer);
}

this.trackCleanup(cleanup);
private destroyDeferredBuffers(): void {
for (const buffer of this.deferredBuffers) {
buffer.destroy();
}
this.deferredBuffers.clear();
}

async flushDeferredDestroys(): Promise<void> {
await Promise.allSettled([...this.pendingResourceCleanup]);
this.destroyDeferredBuffers();
}

async waitForSubmittedWork(): Promise<void> {
// Yield to the event loop to allow pending GPU work to complete.
// The deprecated onSubmittedWorkDone() was removed from the WebGPU spec.
// In practice, yielding briefly is sufficient for most testing scenarios.
return new Promise(resolve => setTimeout(resolve, 0));
const queue = this.getDevice().queue as GPUQueue & {
onSubmittedWorkDone?: () => Promise<void>;
};

if (typeof queue.onSubmittedWorkDone === 'function') {
await queue.onSubmittedWorkDone();
return;
}

await new Promise(resolve => setTimeout(resolve, 0));
}

async sync(): Promise<void> {
await this.waitForSubmittedWork();
let waitError: unknown;
try {
await this.waitForSubmittedWork();
} catch (error) {
waitError = error;
}

await this.flushDeferredDestroys();

if (waitError) {
throw waitError;
}
}

/**
Expand Down Expand Up @@ -213,10 +217,20 @@ export class GPUContext {
* Release all GPU resources.
*/
destroy(): void {
for (const cleanup of this.pendingResourceCleanup) {
void cleanup.catch(() => {});
const queue = this.device?.queue as (GPUQueue & {
onSubmittedWorkDone?: () => Promise<void>;
}) | undefined;

if (this.deferredBuffers.size > 0) {
if (typeof queue?.onSubmittedWorkDone === 'function') {
void queue.onSubmittedWorkDone().then(
() => this.destroyDeferredBuffers(),
() => this.destroyDeferredBuffers()
);
} else {
this.destroyDeferredBuffers();
}
}
this.pendingResourceCleanup.clear();

if (this.device) {
this.device.destroy();
Expand Down
129 changes: 78 additions & 51 deletions src/engine/InferenceEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@ import { DenseOperator } from '../operators/DenseOperator';
import { AddOperator } from '../operators/AddOperator';
import { BatchNorm2dOperator } from '../operators/BatchNorm2dOperator';
import { ModelDefinition } from './ModelLoader';
import { CompiledModel, ModelCompiler } from './ModelCompiler';

export class InferenceEngine {
private context: GPUContext;
private operators: Map<string, Operator>;
private weights: Map<string, Tensor> = new Map();
private modelDef: ModelDefinition | null = null;
private compiledModel: CompiledModel | null = null;
private readonly compiler: ModelCompiler;

constructor() {
this.context = new GPUContext();
this.operators = new Map();
this.compiler = new ModelCompiler();
}

async initialize(): Promise<void> {
Expand All @@ -39,20 +42,14 @@ export class InferenceEngine {
}

async loadModel(modelDef: ModelDefinition): Promise<void> {
this.modelDef = modelDef;
const compiledModel = this.compiler.compile(modelDef, this.operators.keys());
const nextWeights = this.materializeWeights(modelDef);

for (const tensor of this.weights.values()) {
tensor.destroy();
}
this.weights.clear();

for (const [name, weightDef] of Object.entries(modelDef.weights)) {
if (!weightDef.shape || weightDef.shape.length === 0) {
throw new Error(`Weight "${name}" is missing shape metadata`);
}
const tensor = Tensor.fromArray(this.context, weightDef.data, weightDef.shape);
this.weights.set(name, tensor);
}
this.weights = nextWeights;
this.compiledModel = compiledModel;
}

tensorFromArray(
Expand All @@ -63,70 +60,100 @@ export class InferenceEngine {
return Tensor.fromArray(this.context, data, shape, options);
}

async infer(input: Tensor): Promise<Tensor> {
if (!this.modelDef) {
throw new Error('Model not loaded');
}
if (!input.usesContext(this.context)) {
throw new Error('Input tensor must be created from the same GPUContext as the inference engine');
}
private materializeWeights(modelDef: ModelDefinition): Map<string, Tensor> {
const nextWeights = new Map<string, Tensor>();

const activations = new Map<string, Tensor>();
activations.set('input', input);

// Execute layers in order
for (const layer of this.modelDef.layers) {
const operator = this.operators.get(layer.type);
if (!operator) {
throw new Error(`Unknown operator type: ${layer.type}`);
}

// Get inputs
const inputs: Tensor[] = [];
for (const inputName of layer.inputs) {
const tensor = activations.get(inputName) ?? this.weights.get(inputName);
if (!tensor) {
throw new Error(`Missing input: ${inputName}`);
try {
for (const [name, weightDef] of Object.entries(modelDef.weights)) {
if (!weightDef.shape || weightDef.shape.length === 0) {
throw new Error(`Weight "${name}" is missing shape metadata`);
}
inputs.push(tensor);
const tensor = Tensor.fromArray(this.context, weightDef.data, weightDef.shape);
nextWeights.set(name, tensor);
}

// Execute
const output = await operator.forward(inputs, layer.params);
activations.set(layer.name, output);
}

// Return final output
const lastLayer = this.modelDef.layers[this.modelDef.layers.length - 1];
const finalOutput = activations.get(lastLayer.name);
if (!finalOutput) {
throw new Error(`Final output not found for layer: ${lastLayer.name}`);
return nextWeights;
} catch (error) {
for (const tensor of nextWeights.values()) {
tensor.destroy();
}
throw error;
}
}

// Ensure queued GPU work sees all intermediate activations before releasing them.
private async cleanupActivations(
activations: Map<string, Tensor>,
retainedBuffer: GPUBuffer | null
): Promise<void> {
await this.context.sync();

// Destroy intermediate activations to free GPU memory.
// If the final output is a view (e.g. flatten/reshape), keep any tensor sharing its buffer alive.
for (const [name, tensor] of activations.entries()) {
if (
name !== 'input' &&
name !== lastLayer.name &&
!this.weights.has(name) &&
tensor.buffer !== finalOutput.buffer
tensor.buffer !== retainedBuffer
) {
tensor.destroy();
}
}
}

async infer(input: Tensor): Promise<Tensor> {
if (!this.compiledModel) {
throw new Error('Model not loaded');
}
if (!input.usesContext(this.context)) {
throw new Error('Input tensor must be created from the same GPUContext as the inference engine');
}

return finalOutput;
const activations = new Map<string, Tensor>();
activations.set('input', input);
let retainedBuffer: GPUBuffer | null = null;

try {
// Execute layers in order
for (const layer of this.compiledModel.layers) {
const operator = this.operators.get(layer.type);
if (!operator) {
throw new Error(`Unknown operator type: ${layer.type}`);
}

// Get inputs
const inputs: Tensor[] = [];
for (const source of layer.inputs) {
const tensor = source.kind === 'weight'
? this.weights.get(source.name)
: activations.get(source.name);
if (!tensor) {
throw new Error(`Missing input: ${source.name}`);
}
inputs.push(tensor);
}

// Execute
const output = await operator.forward(inputs, layer.params);
activations.set(layer.name, output);
}
// Return final output
// Return final output
const finalOutput = activations.get(this.compiledModel.outputName);
if (!finalOutput) {
throw new Error(`Final output not found for layer: ${this.compiledModel.outputName}`);
}

retainedBuffer = finalOutput.buffer;
return finalOutput;
} finally {
await this.cleanupActivations(activations, retainedBuffer);
}
}

destroy(): void {
for (const tensor of this.weights.values()) {
tensor.destroy();
}
this.weights.clear();
this.compiledModel = null;

for (const operator of this.operators.values()) {
operator.destroy();
Expand Down
Loading
Loading