Skip to content

Commit f75bc33

Browse files
committed
Add compute boids sample
1 parent b6fded3 commit f75bc33

File tree

3 files changed

+317
-0
lines changed

3 files changed

+317
-0
lines changed
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
package io.github.bahaa.webgpu.samples;
2+
3+
import io.github.bahaa.webgpu.api.*;
4+
import io.github.bahaa.webgpu.api.model.*;
5+
6+
import java.lang.foreign.MemorySegment;
7+
import java.util.List;
8+
9+
public class ComputeBoids extends SampleBase {
10+
11+
private static final int NUM_PARTICLES = 1500;
12+
13+
private final SimParams simParams = new SimParams();
14+
private final Buffer[] particleBuffers = new Buffer[2];
15+
private final BindGroup[] particleBindGroups = new BindGroup[2];
16+
private int time = 0;
17+
private Buffer simParamBuffer;
18+
private ComputePipeline computePipeline;
19+
private RenderPipeline renderPipeline;
20+
private Buffer spriteVertexBuffer;
21+
22+
static void main(final String... args) {
23+
new ComputeBoids().run(args);
24+
}
25+
26+
@Override
27+
protected void setup(final Device device, final Queue queue) {
28+
final var renderShaderModule = loadShader(device, "wgsl/compute-boids-render.wgsl");
29+
30+
this.renderPipeline = device.createRenderPipeline(RenderPipelineDescriptor.builder()
31+
.vertex(builder -> builder
32+
.module(renderShaderModule)
33+
.entryPoint("vert_main")
34+
// instanced particles buffer
35+
.addBuffer(VertexBufferLayout.builder()
36+
.arrayStride(4 * 4)
37+
.stepMode(VertexStepMode.INSTANCE)
38+
.addAttribute(VertexAttribute.builder()
39+
.shaderLocation(0)
40+
.offset(0)
41+
.format(VertexFormat.FLOAT32X2)
42+
.build())
43+
.addAttribute(VertexAttribute.builder()
44+
.shaderLocation(1)
45+
.offset(2 * 4)
46+
.format(VertexFormat.FLOAT32X2)
47+
.build())
48+
.build())
49+
// vertex buffer
50+
.addBuffer(VertexBufferLayout.builder()
51+
.arrayStride(2 * 4)
52+
.stepMode(VertexStepMode.VERTEX)
53+
.addAttribute(VertexAttribute.builder()
54+
.shaderLocation(2)
55+
.offset(0)
56+
.format(VertexFormat.FLOAT32X2)
57+
.build())
58+
.build())
59+
)
60+
.fragment(builder -> builder
61+
.module(renderShaderModule)
62+
.entryPoint("frag_main")
63+
.addTarget(ColorTargetState.builder()
64+
.format(getPreferredFormat())
65+
.build())
66+
)
67+
.primitive(builder -> builder
68+
.topology(PrimitiveTopology.TRIANGLE_LIST))
69+
.build());
70+
71+
this.computePipeline = device.createComputePipeline(ComputePipelineDescriptor.builder()
72+
.compute(builder -> builder
73+
.module(loadShader(device, "wgsl/compute-boids-compute.wgsl"))
74+
.entryPoint("main")
75+
)
76+
.build());
77+
78+
final var vertexBufferData = new float[]{
79+
-0.01f, -0.02f, 0.01f,
80+
-0.02f, 0.0f, 0.02f,
81+
};
82+
83+
this.spriteVertexBuffer = device.createBuffer(BufferDescriptor.builder()
84+
.size((long) vertexBufferData.length * Float.BYTES)
85+
.addUsage(BufferUsage.VERTEX)
86+
.mappedAtCreation(true)
87+
.build());
88+
89+
this.spriteVertexBuffer.getMappedRange().copyFrom(MemorySegment.ofArray(vertexBufferData));
90+
this.spriteVertexBuffer.unmap();
91+
92+
this.simParamBuffer = device.createBuffer(BufferDescriptor.builder()
93+
.size(SimParams.BYTE_SIZE)
94+
.addUsage(BufferUsage.UNIFORM)
95+
.addUsage(BufferUsage.COPY_DST)
96+
.build());
97+
updateSimParams(queue);
98+
99+
final var initialParticleData = new float[NUM_PARTICLES * 4];
100+
for (var i = 0; i < NUM_PARTICLES; ++i) {
101+
initialParticleData[4 * i] = 2 * (float) (Math.random() - 0.5);
102+
initialParticleData[4 * i + 1] = 2 * (float) (Math.random() - 0.5);
103+
initialParticleData[4 * i + 2] = 2 * (float) (Math.random() - 0.5) * 0.1f;
104+
initialParticleData[4 * i + 3] = 2 * (float) (Math.random() - 0.5) * 0.1f;
105+
}
106+
107+
for (var i = 0; i < 2; i++) {
108+
this.particleBuffers[i] = device.createBuffer(BufferDescriptor.builder()
109+
.size((long) initialParticleData.length * Float.BYTES)
110+
.addUsage(BufferUsage.VERTEX)
111+
.addUsage(BufferUsage.STORAGE)
112+
.mappedAtCreation(true)
113+
.build());
114+
this.particleBuffers[i].getMappedRange().copyFrom(MemorySegment.ofArray(initialParticleData));
115+
this.particleBuffers[i].unmap();
116+
}
117+
118+
for (var i = 0; i < 2; ++i) {
119+
final var index = i;
120+
this.particleBindGroups[i] = device.createBindGroup(BindGroupDescriptor.builder()
121+
.layout(this.computePipeline.getBindGroupLayout(0))
122+
.addEntry(builder -> builder
123+
.binding(0)
124+
.buffer(this.simParamBuffer)
125+
.size(this.simParamBuffer.size())
126+
)
127+
.addEntry(builder -> builder
128+
.binding(1)
129+
.buffer(this.particleBuffers[index])
130+
.offset(0)
131+
.size((long) initialParticleData.length * Float.BYTES)
132+
)
133+
.addEntry(builder -> builder
134+
.binding(2)
135+
.buffer(this.particleBuffers[(index + 1) % 2])
136+
.offset(0)
137+
.size((long) initialParticleData.length * Float.BYTES)
138+
)
139+
.build());
140+
}
141+
}
142+
143+
@Override
144+
protected void render(final Device device, final Queue queue, final Surface surface, final Texture texture) {
145+
final var targetView = texture.createView();
146+
147+
final var encoder = device.createCommandEncoder(CommandEncoderDescriptor.builder()
148+
.build());
149+
150+
final var computePass = encoder.beginComputePass(ComputePassDescriptor.create());
151+
computePass.setPipeline(this.computePipeline);
152+
computePass.setBindGroup(0, this.particleBindGroups[this.time % 2]);
153+
computePass.dispatchWorkgroups(Math.ceilDiv(NUM_PARTICLES, 64));
154+
computePass.end();
155+
156+
final var renderPass = encoder.beginRenderPass(RenderPassDescriptor.builder()
157+
.addColorAttachment(builder -> builder
158+
.view(targetView)
159+
.loadOp(LoadOp.CLEAR)
160+
.storeOp(StoreOp.STORE)
161+
.clearValue(Color.rgba(0, 0, 0, 1)))
162+
.build());
163+
renderPass.setPipeline(this.renderPipeline);
164+
renderPass.setVertexBuffer(0, this.particleBuffers[(this.time + 1) % 2]);
165+
renderPass.setVertexBuffer(1, this.spriteVertexBuffer);
166+
renderPass.draw(3, NUM_PARTICLES, 0, 0);
167+
renderPass.end();
168+
169+
queue.submit(List.of(encoder.finish()));
170+
surface.present();
171+
172+
this.time++;
173+
}
174+
175+
@Override
176+
protected String title() {
177+
return "WebGPU Compute Boids";
178+
}
179+
180+
private void updateSimParams(final Queue queue) {
181+
queue.writeBuffer(this.simParamBuffer, 0, new float[]{
182+
this.simParams.deltaT,
183+
this.simParams.rule1Distance,
184+
this.simParams.rule2Distance,
185+
this.simParams.rule3Distance,
186+
this.simParams.rule1Scale,
187+
this.simParams.rule2Scale,
188+
this.simParams.rule3Scale,
189+
});
190+
}
191+
192+
private static class SimParams {
193+
public static final long BYTE_SIZE = 7 * Float.BYTES;
194+
195+
float deltaT = 0.04f;
196+
float rule1Distance = 0.1f;
197+
float rule2Distance = 0.025f;
198+
float rule3Distance = 0.025f;
199+
float rule1Scale = 0.02f;
200+
float rule2Scale = 0.05f;
201+
float rule3Scale = 0.005f;
202+
}
203+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
struct Particle {
2+
pos : vec2f,
3+
vel : vec2f,
4+
}
5+
struct SimParams {
6+
deltaT : f32,
7+
rule1Distance : f32,
8+
rule2Distance : f32,
9+
rule3Distance : f32,
10+
rule1Scale : f32,
11+
rule2Scale : f32,
12+
rule3Scale : f32,
13+
}
14+
struct Particles {
15+
particles : array<Particle>,
16+
}
17+
@binding(0) @group(0) var<uniform> params : SimParams;
18+
@binding(1) @group(0) var<storage, read> particlesA : Particles;
19+
@binding(2) @group(0) var<storage, read_write> particlesB : Particles;
20+
21+
// https://github.com/austinEng/Project6-Vulkan-Flocking/blob/master/data/shaders/computeparticles/particle.comp
22+
@compute @workgroup_size(64)
23+
fn main(@builtin(global_invocation_id) GlobalInvocationID : vec3u) {
24+
var index = GlobalInvocationID.x;
25+
26+
var vPos = particlesA.particles[index].pos;
27+
var vVel = particlesA.particles[index].vel;
28+
var cMass = vec2(0.0);
29+
var cVel = vec2(0.0);
30+
var colVel = vec2(0.0);
31+
var cMassCount = 0u;
32+
var cVelCount = 0u;
33+
var pos : vec2f;
34+
var vel : vec2f;
35+
36+
for (var i = 0u; i < arrayLength(&particlesA.particles); i++) {
37+
if (i == index) {
38+
continue;
39+
}
40+
41+
pos = particlesA.particles[i].pos.xy;
42+
vel = particlesA.particles[i].vel.xy;
43+
if (distance(pos, vPos) < params.rule1Distance) {
44+
cMass += pos;
45+
cMassCount++;
46+
}
47+
if (distance(pos, vPos) < params.rule2Distance) {
48+
colVel -= pos - vPos;
49+
}
50+
if (distance(pos, vPos) < params.rule3Distance) {
51+
cVel += vel;
52+
cVelCount++;
53+
}
54+
}
55+
if (cMassCount > 0) {
56+
cMass = (cMass / vec2(f32(cMassCount))) - vPos;
57+
}
58+
if (cVelCount > 0) {
59+
cVel /= f32(cVelCount);
60+
}
61+
vVel += (cMass * params.rule1Scale) + (colVel * params.rule2Scale) + (cVel * params.rule3Scale);
62+
63+
// clamp velocity for a more pleasing simulation
64+
vVel = normalize(vVel) * clamp(length(vVel), 0.0, 0.1);
65+
// kinematic update
66+
vPos = vPos + (vVel * params.deltaT);
67+
// Wrap around boundary
68+
if (vPos.x < -1.0) {
69+
vPos.x = 1.0;
70+
}
71+
if (vPos.x > 1.0) {
72+
vPos.x = -1.0;
73+
}
74+
if (vPos.y < -1.0) {
75+
vPos.y = 1.0;
76+
}
77+
if (vPos.y > 1.0) {
78+
vPos.y = -1.0;
79+
}
80+
// Write back
81+
particlesB.particles[index].pos = vPos;
82+
particlesB.particles[index].vel = vVel;
83+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
struct VertexOutput {
2+
@builtin(position) position : vec4f,
3+
@location(4) color : vec4f,
4+
}
5+
6+
@vertex
7+
fn vert_main(
8+
@location(0) a_particlePos : vec2f,
9+
@location(1) a_particleVel : vec2f,
10+
@location(2) a_pos : vec2f
11+
) -> VertexOutput {
12+
let angle = -atan2(a_particleVel.x, a_particleVel.y);
13+
let pos = vec2(
14+
(a_pos.x * cos(angle)) - (a_pos.y * sin(angle)),
15+
(a_pos.x * sin(angle)) + (a_pos.y * cos(angle))
16+
);
17+
18+
var output : VertexOutput;
19+
output.position = vec4(pos + a_particlePos, 0.0, 1.0);
20+
output.color = vec4(
21+
1.0 - sin(angle + 1.0) - a_particleVel.y,
22+
pos.x * 100.0 - a_particleVel.y + 0.1,
23+
a_particleVel.x + cos(angle + 0.5),
24+
1.0);
25+
return output;
26+
}
27+
28+
@fragment
29+
fn frag_main(@location(4) color : vec4f) -> @location(0) vec4f {
30+
return color;
31+
}

0 commit comments

Comments
 (0)