Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/example/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"pod:install:ios": "pod install --project-directory=ios",
"pod:install:macos": "pod install --project-directory=macos",
"build:android": "cd android && ./gradlew assembleDebug --warning-mode all",
"build:ios": "react-native build-ios --scheme Example --mode Debug --extra-params \"-sdk iphonesimulator CC=clang CPLUSPLUS=clang++ LD=clang LDPLUSPLUS=clang++ GCC_OPTIMIZATION_LEVEL=0 GCC_PRECOMPILE_PREFIX_HEADER=YES ASSETCATALOG_COMPILER_OPTIMIZATION=time DEBUG_INFORMATION_FORMAT=dwarf COMPILER_INDEX_STORE_ENABLE=NO\"",
"build:ios": "react-native build-ios --scheme Example --mode Debug --extra-params \"-sdk iphonesimulator CC=clang CPLUSPLUS=clang++ LD=clang LDPLUSPLUS=clang++ GCC_OPTIMIZATION_LEVEL=0 GCC_PRECOMPILE_PREFIX_HEADER=YES ASSETCATALOG_COMPILER_OPTIMIZATION=time DEBUG_INFORMATION_FORMAT=dwarf COMPILER_INDEX_STORE_ENABLE=NO CLANG_CXX_LANGUAGE_STANDARD=c++20\"",
"build:macos": "react-native build-macos --scheme Example --mode Debug",
"mkdist": "node -e \"require('node:fs').mkdirSync('dist', { recursive: true, mode: 0o755 })\"",
"postinstall": "node -e \"if (process.platform !== 'darwin') { console.log('Skipping iOS pod install on non-macOS environment.'); process.exit(0); } const { execSync } = require('child_process'); execSync('yarn pod:install:ios', { stdio: 'inherit' });\""
Expand Down
12 changes: 8 additions & 4 deletions apps/example/src/Cube/TexturedCube.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ export const TexturedCube = () => {

// Fetch the image and upload it into a GPUTexture.
let cubeTexture: GPUTexture;
{
const response = await fetchAsset(require("../assets/Di-3d.png"));
const imageBitmap = await createImageBitmap(await response.blob());
try {
const asset = await fetchAsset(require("../assets/Di-3d.png"));
const arrayBuffer = await asset.arrayBuffer();
const imageBitmap = await createImageBitmap(arrayBuffer);
cubeTexture = device.createTexture({
size: [imageBitmap.width, imageBitmap.height, 1],
format: "rgba8unorm",
Expand All @@ -116,8 +117,11 @@ export const TexturedCube = () => {
device.queue.copyExternalImageToTexture(
{ source: imageBitmap },
{ texture: cubeTexture },
[imageBitmap.width, imageBitmap.height],
[imageBitmap.width, imageBitmap.height]
);
} catch (err) {
console.error("Failed to fetch asset", err);
throw err;
}

// Create a sampler with linear filtering for smooth interpolation.
Expand Down
3 changes: 2 additions & 1 deletion packages/webgpu/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.4.1)
project(RNWGPU)

set (CMAKE_VERBOSE_MAKEFILE ON)
set (CMAKE_CXX_STANDARD 17)
set (CMAKE_CXX_STANDARD 20)
set (CMAKE_CXX_STANDARD_REQUIRED True)

set (PACKAGE_NAME "react-native-wgpu")

Expand Down
2 changes: 1 addition & 1 deletion packages/webgpu/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ android {
buildConfigField "boolean", "IS_NEW_ARCHITECTURE_ENABLED", "true"
externalNativeBuild {
cmake {
cppFlags "-fexceptions", "-frtti", "-std=c++1y", "-DONANDROID"
cppFlags "-fexceptions", "-frtti", "-DONANDROID"
abiFilters (*reactNativeArchitectures())
arguments '-DANDROID_STL=c++_shared',
"-DNODE_MODULES_DIR=${nodeModules}",
Expand Down
243 changes: 104 additions & 139 deletions packages/webgpu/android/cpp/AndroidPlatformContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,43 @@ class AndroidPlatformContext : public PlatformContext {
private:
jobject _blobModule;

std::vector<uint8_t> resolveBlob(JNIEnv *env, const std::string &blobId,
double offset, double size) {
if (!_blobModule) {
throw std::runtime_error("BlobModule instance is null");
}

jclass blobModuleClass = env->GetObjectClass(_blobModule);
if (!blobModuleClass) {
throw std::runtime_error("Couldn't find BlobModule class");
}

jmethodID resolveMethod = env->GetMethodID(blobModuleClass, "resolve",
"(Ljava/lang/String;II)[B");
env->DeleteLocalRef(blobModuleClass);

if (!resolveMethod) {
throw std::runtime_error("Couldn't find resolve method in BlobModule");
}

jstring jBlobId = env->NewStringUTF(blobId.c_str());
jbyteArray blobData = (jbyteArray)env->CallObjectMethod(
_blobModule, resolveMethod, jBlobId, static_cast<jint>(offset),
static_cast<jint>(size));
env->DeleteLocalRef(jBlobId);

if (!blobData) {
throw std::runtime_error("Couldn't retrieve blob data");
}

jsize len = env->GetArrayLength(blobData);
std::vector<uint8_t> data(len);
env->GetByteArrayRegion(blobData, 0, len,
reinterpret_cast<jbyte *>(data.data()));
env->DeleteLocalRef(blobData);
return data;
}

public:
explicit AndroidPlatformContext(jobject blobModule)
: _blobModule(blobModule) {}
Expand Down Expand Up @@ -52,188 +89,116 @@ class AndroidPlatformContext : public PlatformContext {
throw std::runtime_error("Couldn't get JNI environment");
}

// Use the BlobModule instance from _blobModule
if (!_blobModule) {
throw std::runtime_error("BlobModule instance is null");
}
auto data = resolveBlob(env, blobId, offset, size);
return createImageBitmapFromData(data);
}

// Get the resolve method ID
jclass blobModuleClass = env->GetObjectClass(_blobModule);
if (!blobModuleClass) {
throw std::runtime_error("Couldn't find BlobModule class");
}
void createImageBitmapAsync(
std::string blobId, double offset, double size,
std::function<void(ImageData)> onSuccess,
std::function<void(std::string)> onError) override {
std::thread([this, blobId = std::move(blobId), offset, size,
onSuccess = std::move(onSuccess),
onError = std::move(onError)]() {
jni::Environment::ensureCurrentThreadIsAttached();
try {
JNIEnv *env = facebook::jni::Environment::current();
if (!env) {
throw std::runtime_error("Couldn't get JNI environment");
}
auto data = resolveBlob(env, blobId, offset, size);
auto result = createImageBitmapFromData(data);
onSuccess(std::move(result));
} catch (const std::exception &e) {
onError(e.what());
}
}).detach();
}

jmethodID resolveMethod = env->GetMethodID(blobModuleClass, "resolve",
"(Ljava/lang/String;II)[B");
if (!resolveMethod) {
throw std::runtime_error("Couldn't find resolve method in BlobModule");
}
ImageData createImageBitmapFromData(std::span<const uint8_t> data) override {
jni::Environment::ensureCurrentThreadIsAttached();

// Resolve the blob data
jstring jBlobId = env->NewStringUTF(blobId.c_str());
jbyteArray blobData = (jbyteArray)env->CallObjectMethod(
_blobModule, resolveMethod, jBlobId, static_cast<jint>(offset),
static_cast<jint>(size));
env->DeleteLocalRef(jBlobId);
JNIEnv *env = facebook::jni::Environment::current();
if (!env) {
throw std::runtime_error("Couldn't get JNI environment");
}

if (!blobData) {
throw std::runtime_error("Couldn't retrieve blob data");
// Create jbyteArray from the raw bytes
jbyteArray byteArray = env->NewByteArray(static_cast<jsize>(data.size()));
if (!byteArray) {
throw std::runtime_error("Couldn't allocate byte array");
}
env->SetByteArrayRegion(byteArray, 0, static_cast<jsize>(data.size()),
reinterpret_cast<const jbyte *>(data.data()));

// Create a Bitmap from the blob data
// Decode via BitmapFactory
jclass bitmapFactoryClass =
env->FindClass("android/graphics/BitmapFactory");
if (!bitmapFactoryClass) {
env->DeleteLocalRef(byteArray);
throw std::runtime_error("Couldn't find BitmapFactory class");
}
jmethodID decodeByteArrayMethod =
env->GetStaticMethodID(bitmapFactoryClass, "decodeByteArray",
"([BII)Landroid/graphics/Bitmap;");
jint blobLength = env->GetArrayLength(blobData);
if (!decodeByteArrayMethod) {
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmapFactoryClass);
throw std::runtime_error("Couldn't find decodeByteArray method");
}
jint length = static_cast<jint>(data.size());
jobject bitmap = env->CallStaticObjectMethod(
bitmapFactoryClass, decodeByteArrayMethod, blobData, 0, blobLength);
bitmapFactoryClass, decodeByteArrayMethod, byteArray, 0, length);
env->DeleteLocalRef(bitmapFactoryClass);

if (!bitmap) {
env->DeleteLocalRef(blobData);
env->DeleteLocalRef(byteArray);
throw std::runtime_error("Couldn't decode image");
}

// Get bitmap info
AndroidBitmapInfo bitmapInfo;
if (AndroidBitmap_getInfo(env, bitmap, &bitmapInfo) !=
ANDROID_BITMAP_RESULT_SUCCESS) {
env->DeleteLocalRef(blobData);
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmap);
throw std::runtime_error("Couldn't get bitmap info");
}

// Lock the bitmap pixels
void *bitmapPixels;
if (AndroidBitmap_lockPixels(env, bitmap, &bitmapPixels) !=
ANDROID_BITMAP_RESULT_SUCCESS) {
env->DeleteLocalRef(blobData);
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmap);
throw std::runtime_error("Couldn't lock bitmap pixels");
}

// Copy the bitmap data
std::vector<uint8_t> imageData(bitmapInfo.height * bitmapInfo.stride);
memcpy(imageData.data(), bitmapPixels, imageData.size());
ImageData result;
result.width = static_cast<int>(bitmapInfo.width);
result.height = static_cast<int>(bitmapInfo.height);
result.data.resize(bitmapInfo.height * bitmapInfo.stride);
memcpy(result.data.data(), bitmapPixels, result.data.size());

// Unlock the bitmap pixels
AndroidBitmap_unlockPixels(env, bitmap);

// Clean up JNI references
env->DeleteLocalRef(blobData);
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmap);

ImageData result;
result.width = static_cast<int>(bitmapInfo.width);
result.height = static_cast<int>(bitmapInfo.height);
result.data = imageData;
return result;
}

void createImageBitmapAsync(
std::string blobId, double offset, double size,
std::function<void(ImageData)> onSuccess,
void createImageBitmapFromDataAsync(
std::span<const uint8_t> data, std::function<void(ImageData)> onSuccess,
std::function<void(std::string)> onError) override {
// Capture blobModule for the background thread
jobject blobModule = _blobModule;

// Dispatch to a background thread
std::thread([blobModule, blobId = std::move(blobId), offset, size,
std::thread([this, ownedData = std::vector<uint8_t>(data.begin(), data.end()),
onSuccess = std::move(onSuccess),
onError = std::move(onError)]() {
onError = std::move(onError)]() mutable {
jni::Environment::ensureCurrentThreadIsAttached();

JNIEnv *env = facebook::jni::Environment::current();
if (!env) {
onError("Couldn't get JNI environment");
return;
}

if (!blobModule) {
onError("BlobModule instance is null");
return;
}

// Get the resolve method ID
jclass blobModuleClass = env->GetObjectClass(blobModule);
if (!blobModuleClass) {
onError("Couldn't find BlobModule class");
return;
}

jmethodID resolveMethod = env->GetMethodID(blobModuleClass, "resolve",
"(Ljava/lang/String;II)[B");
if (!resolveMethod) {
onError("Couldn't find resolve method in BlobModule");
return;
}

// Resolve the blob data
jstring jBlobId = env->NewStringUTF(blobId.c_str());
jbyteArray blobData = (jbyteArray)env->CallObjectMethod(
blobModule, resolveMethod, jBlobId, static_cast<jint>(offset),
static_cast<jint>(size));
env->DeleteLocalRef(jBlobId);

if (!blobData) {
onError("Couldn't retrieve blob data");
return;
try {
auto result = createImageBitmapFromData(ownedData);
onSuccess(std::move(result));
} catch (const std::exception &e) {
onError(e.what());
}

// Create a Bitmap from the blob data
jclass bitmapFactoryClass =
env->FindClass("android/graphics/BitmapFactory");
jmethodID decodeByteArrayMethod =
env->GetStaticMethodID(bitmapFactoryClass, "decodeByteArray",
"([BII)Landroid/graphics/Bitmap;");
jint blobLength = env->GetArrayLength(blobData);
jobject bitmap = env->CallStaticObjectMethod(
bitmapFactoryClass, decodeByteArrayMethod, blobData, 0, blobLength);

if (!bitmap) {
env->DeleteLocalRef(blobData);
onError("Couldn't decode image");
return;
}

// Get bitmap info
AndroidBitmapInfo bitmapInfo;
if (AndroidBitmap_getInfo(env, bitmap, &bitmapInfo) !=
ANDROID_BITMAP_RESULT_SUCCESS) {
env->DeleteLocalRef(blobData);
env->DeleteLocalRef(bitmap);
onError("Couldn't get bitmap info");
return;
}

// Lock the bitmap pixels
void *bitmapPixels;
if (AndroidBitmap_lockPixels(env, bitmap, &bitmapPixels) !=
ANDROID_BITMAP_RESULT_SUCCESS) {
env->DeleteLocalRef(blobData);
env->DeleteLocalRef(bitmap);
onError("Couldn't lock bitmap pixels");
return;
}

// Copy the bitmap data
std::vector<uint8_t> imageData(bitmapInfo.height * bitmapInfo.stride);
memcpy(imageData.data(), bitmapPixels, imageData.size());

// Unlock the bitmap pixels
AndroidBitmap_unlockPixels(env, bitmap);

// Clean up JNI references
env->DeleteLocalRef(blobData);
env->DeleteLocalRef(bitmap);

ImageData result;
result.width = static_cast<int>(bitmapInfo.width);
result.height = static_cast<int>(bitmapInfo.height);
result.data = std::move(imageData);

onSuccess(std::move(result));
}).detach();
}
};
Expand Down
6 changes: 6 additions & 0 deletions packages/webgpu/apple/ApplePlatformContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ class ApplePlatformContext : public PlatformContext {
std::string blobId, double offset, double size,
std::function<void(ImageData)> onSuccess,
std::function<void(std::string)> onError) override;

ImageData createImageBitmapFromData(std::span<const uint8_t> data) override;

void createImageBitmapFromDataAsync(
std::span<const uint8_t> data, std::function<void(ImageData)> onSuccess,
std::function<void(std::string)> onError) override;
};

} // namespace rnwgpu
Loading