|
| 1 | +package io.github.bahaa.webgpu.samples; |
| 2 | + |
| 3 | +import io.github.bahaa.webgpu.api.*; |
| 4 | +import io.github.bahaa.webgpu.api.model.*; |
| 5 | + |
| 6 | +import java.lang.foreign.MemorySegment; |
| 7 | +import java.util.List; |
| 8 | + |
| 9 | +public class ComputeBoids extends SampleBase { |
| 10 | + |
| 11 | + private static final int NUM_PARTICLES = 1500; |
| 12 | + |
| 13 | + private final SimParams simParams = new SimParams(); |
| 14 | + private final Buffer[] particleBuffers = new Buffer[2]; |
| 15 | + private final BindGroup[] particleBindGroups = new BindGroup[2]; |
| 16 | + private int time = 0; |
| 17 | + private Buffer simParamBuffer; |
| 18 | + private ComputePipeline computePipeline; |
| 19 | + private RenderPipeline renderPipeline; |
| 20 | + private Buffer spriteVertexBuffer; |
| 21 | + |
| 22 | + static void main(final String... args) { |
| 23 | + new ComputeBoids().run(args); |
| 24 | + } |
| 25 | + |
| 26 | + @Override |
| 27 | + protected void setup(final Device device, final Queue queue) { |
| 28 | + final var renderShaderModule = loadShader(device, "wgsl/compute-boids-render.wgsl"); |
| 29 | + |
| 30 | + this.renderPipeline = device.createRenderPipeline(RenderPipelineDescriptor.builder() |
| 31 | + .vertex(builder -> builder |
| 32 | + .module(renderShaderModule) |
| 33 | + .entryPoint("vert_main") |
| 34 | + // instanced particles buffer |
| 35 | + .addBuffer(VertexBufferLayout.builder() |
| 36 | + .arrayStride(4 * 4) |
| 37 | + .stepMode(VertexStepMode.INSTANCE) |
| 38 | + .addAttribute(VertexAttribute.builder() |
| 39 | + .shaderLocation(0) |
| 40 | + .offset(0) |
| 41 | + .format(VertexFormat.FLOAT32X2) |
| 42 | + .build()) |
| 43 | + .addAttribute(VertexAttribute.builder() |
| 44 | + .shaderLocation(1) |
| 45 | + .offset(2 * 4) |
| 46 | + .format(VertexFormat.FLOAT32X2) |
| 47 | + .build()) |
| 48 | + .build()) |
| 49 | + // vertex buffer |
| 50 | + .addBuffer(VertexBufferLayout.builder() |
| 51 | + .arrayStride(2 * 4) |
| 52 | + .stepMode(VertexStepMode.VERTEX) |
| 53 | + .addAttribute(VertexAttribute.builder() |
| 54 | + .shaderLocation(2) |
| 55 | + .offset(0) |
| 56 | + .format(VertexFormat.FLOAT32X2) |
| 57 | + .build()) |
| 58 | + .build()) |
| 59 | + ) |
| 60 | + .fragment(builder -> builder |
| 61 | + .module(renderShaderModule) |
| 62 | + .entryPoint("frag_main") |
| 63 | + .addTarget(ColorTargetState.builder() |
| 64 | + .format(getPreferredFormat()) |
| 65 | + .build()) |
| 66 | + ) |
| 67 | + .primitive(builder -> builder |
| 68 | + .topology(PrimitiveTopology.TRIANGLE_LIST)) |
| 69 | + .build()); |
| 70 | + |
| 71 | + this.computePipeline = device.createComputePipeline(ComputePipelineDescriptor.builder() |
| 72 | + .compute(builder -> builder |
| 73 | + .module(loadShader(device, "wgsl/compute-boids-compute.wgsl")) |
| 74 | + .entryPoint("main") |
| 75 | + ) |
| 76 | + .build()); |
| 77 | + |
| 78 | + final var vertexBufferData = new float[]{ |
| 79 | + -0.01f, -0.02f, 0.01f, |
| 80 | + -0.02f, 0.0f, 0.02f, |
| 81 | + }; |
| 82 | + |
| 83 | + this.spriteVertexBuffer = device.createBuffer(BufferDescriptor.builder() |
| 84 | + .size((long) vertexBufferData.length * Float.BYTES) |
| 85 | + .addUsage(BufferUsage.VERTEX) |
| 86 | + .mappedAtCreation(true) |
| 87 | + .build()); |
| 88 | + |
| 89 | + this.spriteVertexBuffer.getMappedRange().copyFrom(MemorySegment.ofArray(vertexBufferData)); |
| 90 | + this.spriteVertexBuffer.unmap(); |
| 91 | + |
| 92 | + this.simParamBuffer = device.createBuffer(BufferDescriptor.builder() |
| 93 | + .size(SimParams.BYTE_SIZE) |
| 94 | + .addUsage(BufferUsage.UNIFORM) |
| 95 | + .addUsage(BufferUsage.COPY_DST) |
| 96 | + .build()); |
| 97 | + updateSimParams(queue); |
| 98 | + |
| 99 | + final var initialParticleData = new float[NUM_PARTICLES * 4]; |
| 100 | + for (var i = 0; i < NUM_PARTICLES; ++i) { |
| 101 | + initialParticleData[4 * i] = 2 * (float) (Math.random() - 0.5); |
| 102 | + initialParticleData[4 * i + 1] = 2 * (float) (Math.random() - 0.5); |
| 103 | + initialParticleData[4 * i + 2] = 2 * (float) (Math.random() - 0.5) * 0.1f; |
| 104 | + initialParticleData[4 * i + 3] = 2 * (float) (Math.random() - 0.5) * 0.1f; |
| 105 | + } |
| 106 | + |
| 107 | + for (var i = 0; i < 2; i++) { |
| 108 | + this.particleBuffers[i] = device.createBuffer(BufferDescriptor.builder() |
| 109 | + .size((long) initialParticleData.length * Float.BYTES) |
| 110 | + .addUsage(BufferUsage.VERTEX) |
| 111 | + .addUsage(BufferUsage.STORAGE) |
| 112 | + .mappedAtCreation(true) |
| 113 | + .build()); |
| 114 | + this.particleBuffers[i].getMappedRange().copyFrom(MemorySegment.ofArray(initialParticleData)); |
| 115 | + this.particleBuffers[i].unmap(); |
| 116 | + } |
| 117 | + |
| 118 | + for (var i = 0; i < 2; ++i) { |
| 119 | + final var index = i; |
| 120 | + this.particleBindGroups[i] = device.createBindGroup(BindGroupDescriptor.builder() |
| 121 | + .layout(this.computePipeline.getBindGroupLayout(0)) |
| 122 | + .addEntry(builder -> builder |
| 123 | + .binding(0) |
| 124 | + .buffer(this.simParamBuffer) |
| 125 | + .size(this.simParamBuffer.size()) |
| 126 | + ) |
| 127 | + .addEntry(builder -> builder |
| 128 | + .binding(1) |
| 129 | + .buffer(this.particleBuffers[index]) |
| 130 | + .offset(0) |
| 131 | + .size((long) initialParticleData.length * Float.BYTES) |
| 132 | + ) |
| 133 | + .addEntry(builder -> builder |
| 134 | + .binding(2) |
| 135 | + .buffer(this.particleBuffers[(index + 1) % 2]) |
| 136 | + .offset(0) |
| 137 | + .size((long) initialParticleData.length * Float.BYTES) |
| 138 | + ) |
| 139 | + .build()); |
| 140 | + } |
| 141 | + } |
| 142 | + |
| 143 | + @Override |
| 144 | + protected void render(final Device device, final Queue queue, final Surface surface, final Texture texture) { |
| 145 | + final var targetView = texture.createView(); |
| 146 | + |
| 147 | + final var encoder = device.createCommandEncoder(CommandEncoderDescriptor.builder() |
| 148 | + .build()); |
| 149 | + |
| 150 | + final var computePass = encoder.beginComputePass(ComputePassDescriptor.create()); |
| 151 | + computePass.setPipeline(this.computePipeline); |
| 152 | + computePass.setBindGroup(0, this.particleBindGroups[this.time % 2]); |
| 153 | + computePass.dispatchWorkgroups(Math.ceilDiv(NUM_PARTICLES, 64)); |
| 154 | + computePass.end(); |
| 155 | + |
| 156 | + final var renderPass = encoder.beginRenderPass(RenderPassDescriptor.builder() |
| 157 | + .addColorAttachment(builder -> builder |
| 158 | + .view(targetView) |
| 159 | + .loadOp(LoadOp.CLEAR) |
| 160 | + .storeOp(StoreOp.STORE) |
| 161 | + .clearValue(Color.rgba(0, 0, 0, 1))) |
| 162 | + .build()); |
| 163 | + renderPass.setPipeline(this.renderPipeline); |
| 164 | + renderPass.setVertexBuffer(0, this.particleBuffers[(this.time + 1) % 2]); |
| 165 | + renderPass.setVertexBuffer(1, this.spriteVertexBuffer); |
| 166 | + renderPass.draw(3, NUM_PARTICLES, 0, 0); |
| 167 | + renderPass.end(); |
| 168 | + |
| 169 | + queue.submit(List.of(encoder.finish())); |
| 170 | + surface.present(); |
| 171 | + |
| 172 | + this.time++; |
| 173 | + } |
| 174 | + |
| 175 | + @Override |
| 176 | + protected String title() { |
| 177 | + return "WebGPU Compute Boids"; |
| 178 | + } |
| 179 | + |
| 180 | + private void updateSimParams(final Queue queue) { |
| 181 | + queue.writeBuffer(this.simParamBuffer, 0, new float[]{ |
| 182 | + this.simParams.deltaT, |
| 183 | + this.simParams.rule1Distance, |
| 184 | + this.simParams.rule2Distance, |
| 185 | + this.simParams.rule3Distance, |
| 186 | + this.simParams.rule1Scale, |
| 187 | + this.simParams.rule2Scale, |
| 188 | + this.simParams.rule3Scale, |
| 189 | + }); |
| 190 | + } |
| 191 | + |
| 192 | + private static class SimParams { |
| 193 | + public static final long BYTE_SIZE = 7 * Float.BYTES; |
| 194 | + |
| 195 | + float deltaT = 0.04f; |
| 196 | + float rule1Distance = 0.1f; |
| 197 | + float rule2Distance = 0.025f; |
| 198 | + float rule3Distance = 0.025f; |
| 199 | + float rule1Scale = 0.02f; |
| 200 | + float rule2Scale = 0.05f; |
| 201 | + float rule3Scale = 0.005f; |
| 202 | + } |
| 203 | +} |
0 commit comments