-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCustomOP.cpp
More file actions
358 lines (296 loc) · 11.9 KB
/
CustomOP.cpp
File metadata and controls
358 lines (296 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
#define NS_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "metal-cpp/SingleHeader/Metal.hpp"
#include "CustomOP.metal"
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <unordered_map>
// Metal Context singleton: owns the Metal device, command queue, compiled
// kernel library, and a cache of compute pipeline states for the custom ops.
// NOTE(review): getInstance()/getPipelineState() are not thread-safe; this is
// fine while the ops are only invoked from a single Python thread — confirm.
class MetalContext {
private:
static std::unique_ptr<MetalContext> instance;
// Metal resources (owned; released in the destructor)
MTL::Device* device;
MTL::CommandQueue* commandQueue;
std::unordered_map<std::string, MTL::ComputePipelineState*> pipelineCache;
MTL::Library* library;
// Private constructor: acquire the default device and command queue, then
// compile the kernel source embedded via CustomOP.metal (CUSTOM_KERNEL).
// Throws std::runtime_error on any failure, releasing partial state first.
MetalContext() {
device = MTL::CreateSystemDefaultDevice();
if (!device) {
throw std::runtime_error("Unable to create Metal device");
}
commandQueue = device->newCommandQueue();
if (!commandQueue) {
device->release();
throw std::runtime_error("Unable to create command queue");
}
// Compile the Metal kernel library from source at runtime.
NS::Error* error = nullptr;
library = device->newLibrary(NS::String::string(CUSTOM_KERNEL, NS::UTF8StringEncoding), nullptr, &error);
if (!library) {
std::string errorMsg = "Unable to create kernel library";
if (error) {
errorMsg += ": " + std::string(error->localizedDescription()->utf8String());
}
commandQueue->release();
device->release();
throw std::runtime_error(errorMsg);
}
}
public:
// Prevent copying: there must be exactly one owner of the Metal resources.
MetalContext(const MetalContext&) = delete;
MetalContext& operator=(const MetalContext&) = delete;
// Destructor: release cached pipelines first, then library, queue, device.
~MetalContext() {
for (auto& entry : pipelineCache) {
if (entry.second) {
entry.second->release();
}
}
if (library) {
library->release();
}
if (commandQueue) {
commandQueue->release();
}
if (device) {
device->release();
}
}
// Get the lazily-created singleton instance.
static MetalContext& getInstance() {
if (!instance) {
instance = std::unique_ptr<MetalContext>(new MetalContext());
}
return *instance;
}
// Get device (non-owning pointer; valid for the lifetime of the context).
MTL::Device* getDevice() const {
return device;
}
// Get command queue (non-owning pointer).
MTL::CommandQueue* getCommandQueue() const {
return commandQueue;
}
// Get or lazily create the compute pipeline state for a kernel function.
// Returned pointer is owned by the cache; callers must not release it.
// Throws std::runtime_error if the function is missing or compilation fails.
MTL::ComputePipelineState* getPipelineState(const std::string& functionName) {
// Single lookup instead of find-then-operator[] (which hashed twice).
auto cached = pipelineCache.find(functionName);
if (cached != pipelineCache.end()) {
return cached->second;
}
// Otherwise create a new pipeline from the library function.
NS::Error* error = nullptr;
MTL::Function* function = library->newFunction(NS::String::string(functionName.c_str(), NS::UTF8StringEncoding));
if (!function) {
throw std::runtime_error("Unable to find function: " + functionName);
}
MTL::ComputePipelineState* pipelineState = device->newComputePipelineState(function, &error);
function->release();
if (!pipelineState) {
std::string errorMsg = "Unable to create compute pipeline state: " + functionName;
if (error) {
errorMsg += ": " + std::string(error->localizedDescription()->utf8String());
}
throw std::runtime_error(errorMsg);
}
// Cache and return.
pipelineCache.emplace(functionName, pipelineState);
return pipelineState;
}
// Create a command buffer and compute encoder pair for one dispatch.
// Throws std::runtime_error if either cannot be created.
std::pair<MTL::CommandBuffer*, MTL::ComputeCommandEncoder*> createCommandBufferAndEncoder() {
MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
if (!commandBuffer) {
throw std::runtime_error("Unable to create command buffer");
}
MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
if (!encoder) {
// NOTE(review): commandBuffer() is a non-new/copy factory, so the buffer
// is autoreleased; this explicit release matches the file's convention
// but should be confirmed against an enclosing NS::AutoreleasePool.
commandBuffer->release();
throw std::runtime_error("Unable to create compute command encoder");
}
return {commandBuffer, encoder};
}
};
// Out-of-class definition of the singleton storage; created lazily by
// MetalContext::getInstance() on first use.
std::unique_ptr<MetalContext> MetalContext::instance = nullptr;
// Helper function: compute a 1-D dispatch configuration covering at least
// numElements threads. threadsPerThreadgroup is the pipeline's maximum group
// width; threadgroups is the ceiling of numElements / threadsPerThreadgroup.
// The kernel is expected to bounds-check against the size buffer, since the
// dispatch may overshoot numElements.
void calculateThreadGroups(MTL::ComputePipelineState* pipelineState, int64_t numElements,
                           uint& threadgroups, uint& threadsPerThreadgroup) {
    threadsPerThreadgroup = static_cast<uint>(pipelineState->maxTotalThreadsPerThreadgroup());
    // Do the ceiling division in 64-bit: the previous uint arithmetic
    // (numElements + tpg - 1) could wrap for element counts near UINT_MAX.
    const int64_t tpg = static_cast<int64_t>(threadsPerThreadgroup);
    threadgroups = static_cast<uint>((numElements + tpg - 1) / tpg);
}
// Helper function: validate that a tensor meets the constraints the custom
// ops rely on. Device is verified before dtype, and either check can be
// disabled via the flags. Raises via TORCH_CHECK on violation.
void checkTensorConstraints(const torch::Tensor& tensor, bool checkMPS = true, bool checkFloat = true) {
    const bool deviceOk = !checkMPS || tensor.device().is_mps();
    TORCH_CHECK(deviceOk, "Tensor must be an MPS tensor");
    const bool dtypeOk = !checkFloat || tensor.dtype() == torch::kFloat;
    TORCH_CHECK(dtypeOk, "Tensor must be of float type");
}
// Helper function: allocate a Metal buffer of `size` bytes (shared storage by
// default so the CPU can read/write its contents directly). The caller owns
// the returned buffer and must release it. Throws on allocation failure.
MTL::Buffer* createMetalBuffer(MTL::Device* device, size_t size, MTL::ResourceOptions options = MTL::ResourceStorageModeShared) {
    MTL::Buffer* const result = device->newBuffer(size, options);
    if (result == nullptr) {
        throw std::runtime_error("Unable to create Metal buffer");
    }
    return result;
}
// Helper function: copy the first numElements floats of `tensor` into a
// shared-storage Metal buffer. The buffer must be at least
// numElements * sizeof(float) bytes.
void copyTensorToBuffer(const torch::Tensor& tensor, MTL::Buffer* buffer, int64_t numElements) {
    // Force CPU placement AND contiguous layout: data_ptr<float>() on a
    // non-contiguous tensor (e.g. a transposed view) does not match logical
    // element order, so the previous memcpy copied the wrong values.
    torch::Tensor tensor_cpu = tensor.to(torch::kCPU).contiguous();
    float* bufferPtr = static_cast<float*>(buffer->contents());
    std::memcpy(bufferPtr, tensor_cpu.data_ptr<float>(),
                static_cast<size_t>(numElements) * sizeof(float));
}
// Helper function: dispatch `kernelName` over `buffers`, wait synchronously
// for completion, and return the result as a new tensor of `shape` on
// `device`.
//
// Ownership contract: this function releases every buffer in `buffers`, on
// success and on failure alike; callers must not touch them after calling.
//
// Throws std::runtime_error on pipeline/encoder creation failure or if the
// command buffer completes with an error.
torch::Tensor executeMetalComputation(
    const std::string& kernelName,
    std::vector<MTL::Buffer*> buffers,
    torch::IntArrayRef shape,
    int64_t numElements,
    const torch::Device& device) {
    try {
        // Get Metal context and pipeline state
        MetalContext& context = MetalContext::getInstance();
        MTL::ComputePipelineState* pipelineState = context.getPipelineState(kernelName);
        // Create command buffer and encoder
        auto [commandBuffer, encoder] = context.createCommandBufferAndEncoder();
        // Bind the pipeline and all argument buffers at indices 0..n-1.
        encoder->setComputePipelineState(pipelineState);
        for (size_t i = 0; i < buffers.size(); i++) {
            encoder->setBuffer(buffers[i], 0, i);
        }
        // Calculate and set thread groups configuration
        uint threadgroups, threadsPerThreadgroup;
        calculateThreadGroups(pipelineState, numElements, threadgroups, threadsPerThreadgroup);
        encoder->dispatchThreadgroups(MTL::Size(threadgroups, 1, 1), MTL::Size(threadsPerThreadgroup, 1, 1));
        encoder->endEncoding();
        // Run the command buffer synchronously.
        commandBuffer->commit();
        commandBuffer->waitUntilCompleted();
        // Surface GPU-side failures instead of silently returning garbage:
        // previously a failed command buffer went undetected.
        if (commandBuffer->error()) {
            throw std::runtime_error(
                "Metal command buffer failed for " + kernelName + ": " +
                std::string(commandBuffer->error()->localizedDescription()->utf8String()));
        }
        // The output buffer is at index 0 for custom_fill (which takes no
        // tensor inputs) and at index 2 for the binary kernels.
        MTL::Buffer* outputBuffer = (kernelName == "custom_fill") ? buffers[0] : buffers[2];
        torch::Tensor outputTensor = torch::from_blob(
            outputBuffer->contents(), shape, torch::kFloat)
            .clone()   // clone before releasing the buffer from_blob aliases
            .to(device);
        // Release Metal resources.
        // NOTE(review): commandBuffer/encoder come from non-new/copy factory
        // methods and are therefore autoreleased per Cocoa ownership rules;
        // these explicit releases match the rest of this file but should be
        // confirmed against an enclosing NS::AutoreleasePool.
        for (auto* buffer : buffers) {
            buffer->release();
        }
        encoder->release();
        commandBuffer->release();
        return outputTensor;
    } catch (const std::exception& e) {
        // Clean up buffers on error (per the ownership contract above).
        for (auto* buffer : buffers) {
            buffer->release();
        }
        std::cerr << "Metal execution error: " << e.what() << std::endl;
        throw; // Re-throw the exception
    }
}
// Custom add operation: element-wise addition of two equally-shaped float MPS
// tensors via the custom_add Metal kernel. Returns a new tensor on the same
// device as input1. Throws on validation or Metal failures.
torch::Tensor custom_add(torch::Tensor input1, torch::Tensor input2)
{
    // Validate inputs
    checkTensorConstraints(input1);
    checkTensorConstraints(input2);
    TORCH_CHECK(input1.sizes() == input2.sizes(), "Input tensors must have the same shape");
    auto shape = input1.sizes();
    int64_t numElements = input1.numel();
    // Buffers created so far; released in the catch below only if an error
    // occurs BEFORE ownership passes to executeMetalComputation (which
    // releases them itself on both success and failure). The previous version
    // leaked earlier buffers when a later allocation or copy threw.
    std::vector<MTL::Buffer*> buffers;
    try {
        // Get Metal context
        MetalContext& context = MetalContext::getInstance();
        MTL::Device* device = context.getDevice();
        // Allocate buffers: input1, input2, output, element count.
        buffers.push_back(createMetalBuffer(device, numElements * sizeof(float)));
        buffers.push_back(createMetalBuffer(device, numElements * sizeof(float)));
        buffers.push_back(createMetalBuffer(device, numElements * sizeof(float)));
        buffers.push_back(createMetalBuffer(device, sizeof(uint)));
        // Copy input data into the Metal buffers
        copyTensorToBuffer(input1, buffers[0], numElements);
        copyTensorToBuffer(input2, buffers[1], numElements);
        // Set the size
        *static_cast<uint*>(buffers[3]->contents()) = static_cast<uint>(numElements);
        // Transfer ownership to executeMetalComputation; clear the local list
        // so the catch block cannot double-release.
        std::vector<MTL::Buffer*> owned = std::move(buffers);
        buffers.clear();
        return executeMetalComputation("custom_add", owned, shape, numElements, input1.device());
    } catch (const std::exception& e) {
        // Release only buffers not yet handed off.
        for (auto* buffer : buffers) {
            if (buffer) buffer->release();
        }
        std::cerr << "Metal execution error: " << e.what() << std::endl;
        throw; // Re-throw the exception
    }
}
// Custom fill operation: returns a float MPS tensor shaped like `input` where
// every element equals `fill_val`. `input` supplies only shape and device;
// its values are never read.
torch::Tensor custom_fill(torch::Tensor input, float fill_val)
{
    // Ensure the input tensor is on the MPS device and is float type
    checkTensorConstraints(input);
    auto shape = input.sizes();
    int64_t numElements = input.numel();
    TORCH_CHECK(numElements > 0, "Input tensor must have at least one element");
    // Buffers created so far; released in the catch below only if an error
    // occurs BEFORE ownership passes to executeMetalComputation (which
    // releases them itself). The previous version leaked earlier buffers
    // when a later allocation threw.
    std::vector<MTL::Buffer*> buffers;
    try {
        // Get Metal context
        MetalContext& context = MetalContext::getInstance();
        MTL::Device* device = context.getDevice();
        // Create Metal buffers: output, fill value, element count.
        buffers.push_back(createMetalBuffer(device, numElements * sizeof(float)));
        buffers.push_back(createMetalBuffer(device, sizeof(float)));
        buffers.push_back(createMetalBuffer(device, sizeof(uint)));
        // Initialize the fill value
        *static_cast<float*>(buffers[1]->contents()) = fill_val;
        // Set the size buffer
        *static_cast<uint*>(buffers[2]->contents()) = static_cast<uint>(numElements);
        // Transfer ownership to executeMetalComputation; clear the local list
        // so the catch block cannot double-release.
        std::vector<MTL::Buffer*> owned = std::move(buffers);
        buffers.clear();
        return executeMetalComputation("custom_fill", owned, shape, numElements, input.device());
    } catch (const std::exception& e) {
        // Release only buffers not yet handed off.
        for (auto* buffer : buffers) {
            if (buffer) buffer->release();
        }
        std::cerr << "Metal execution error: " << e.what() << std::endl;
        throw; // Re-throw the exception
    }
}
// Custom multiply operation: element-wise product of two equally-shaped float
// MPS tensors via the custom_multiply Metal kernel. Returns a new tensor on
// the same device as input1. Throws on validation or Metal failures.
torch::Tensor custom_multiply(torch::Tensor input1, torch::Tensor input2)
{
    // Validate inputs
    checkTensorConstraints(input1);
    checkTensorConstraints(input2);
    TORCH_CHECK(input1.sizes() == input2.sizes(), "Input tensors must have the same shape");
    auto shape = input1.sizes();
    int64_t numElements = input1.numel();
    // Buffers created so far; released in the catch below only if an error
    // occurs BEFORE ownership passes to executeMetalComputation (which
    // releases them itself on both success and failure). The previous version
    // leaked earlier buffers when a later allocation or copy threw.
    std::vector<MTL::Buffer*> buffers;
    try {
        // Get Metal context
        MetalContext& context = MetalContext::getInstance();
        MTL::Device* device = context.getDevice();
        // Allocate buffers: input1, input2, output, element count.
        buffers.push_back(createMetalBuffer(device, numElements * sizeof(float)));
        buffers.push_back(createMetalBuffer(device, numElements * sizeof(float)));
        buffers.push_back(createMetalBuffer(device, numElements * sizeof(float)));
        buffers.push_back(createMetalBuffer(device, sizeof(uint)));
        // Copy input data into the Metal buffers
        copyTensorToBuffer(input1, buffers[0], numElements);
        copyTensorToBuffer(input2, buffers[1], numElements);
        // Set the size
        *static_cast<uint*>(buffers[3]->contents()) = static_cast<uint>(numElements);
        // Transfer ownership to executeMetalComputation; clear the local list
        // so the catch block cannot double-release.
        std::vector<MTL::Buffer*> owned = std::move(buffers);
        buffers.clear();
        return executeMetalComputation("custom_multiply", owned, shape, numElements, input1.device());
    } catch (const std::exception& e) {
        // Release only buffers not yet handed off.
        for (auto* buffer : buffers) {
            if (buffer) buffer->release();
        }
        std::cerr << "Metal execution error: " << e.what() << std::endl;
        throw; // Re-throw the exception
    }
}
// Python bindings: expose the three custom MPS operators through the torch
// extension module (module name supplied by the build via TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("custom_fill", &custom_fill, "Custom fill function for MPS");
m.def("custom_add", &custom_add, "Custom add function for MPS");
m.def("custom_multiply", &custom_multiply, "Custom multiply function for MPS");
}