Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ add_library(grilly_core_lib STATIC
cpp/src/ops/moe_forward.cpp
cpp/src/ops/vsa_lm_forward.cpp
cpp/src/ops/prefix_scan.cpp
cpp/src/ops/mingru.cpp
cpp/src/ops/bandit.cpp
cpp/src/ops/eggroll.cpp
cpp/src/shader_fusion.cpp
# ── Experimental ──
cpp/src/experimental/paged_latent_pool.cpp
Expand Down Expand Up @@ -312,6 +315,9 @@ pybind11_add_module(grilly_core
cpp/python/bindings_vsa_lm.cpp
cpp/python/bindings_grl.cpp
cpp/python/bindings_prefix_scan.cpp
cpp/python/bindings_mingru.cpp
cpp/python/bindings_bandit.cpp
cpp/python/bindings_eggroll.cpp
)
target_link_libraries(grilly_core PRIVATE grilly_core_lib)

Expand Down
157 changes: 157 additions & 0 deletions backend/_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2467,3 +2467,160 @@ def blockcode_similarity(query, codebook, num_blocks, block_size):
)
sims[entry_idx] = dot_sum / num_blocks
return sims


# ── VSA Core (Bitpacked / Full-Vector) ───────────────────────────────────


def vsa_bind(a, b):
    """Bind two hypervectors on the GPU via cyclic binding (circular convolution).

    Both operands are coerced to contiguous float32 before the native call.
    Returns the bound vector, or None when no device is available or the
    native kernel raises (the failure is recorded for fallback accounting).
    """
    device = _get_device()
    if device is not None:
        try:
            lhs = _ensure_f32_contiguous(a)
            rhs = _ensure_f32_contiguous(b)
            return _core.vsa_bind(device, lhs, rhs)
        except Exception as err:
            _record_fallback("vsa_bind", err)
    return None


def vsa_bundle(a, b):
    """Bundle two hypervectors on the GPU (element-wise vector addition).

    Both operands are coerced to contiguous float32 before the native call.
    Returns the bundled vector, or None when no device is available or the
    native kernel raises (the failure is recorded for fallback accounting).
    """
    device = _get_device()
    if device is not None:
        try:
            lhs = _ensure_f32_contiguous(a)
            rhs = _ensure_f32_contiguous(b)
            return _core.vsa_bundle(device, lhs, rhs)
        except Exception as err:
            _record_fallback("vsa_bundle", err)
    return None


def vsa_unbind(composite, key):
    """Unbind a key from a composite hypervector on the GPU (circular correlation).

    Both operands are coerced to contiguous float32 before the native call.
    Returns the recovered vector, or None when no device is available or the
    native kernel raises (the failure is recorded for fallback accounting).
    """
    device = _get_device()
    if device is not None:
        try:
            comp = _ensure_f32_contiguous(composite)
            role = _ensure_f32_contiguous(key)
            return _core.vsa_unbind(device, comp, role)
        except Exception as err:
            _record_fallback("vsa_unbind", err)
    return None


def vsa_bitpack(x):
    """Pack a float hypervector into a uint32 bitset on the GPU.

    The input is coerced to contiguous float32 before the native call.
    Returns the packed array, or None when no device is available or the
    native kernel raises (the failure is recorded for fallback accounting).
    """
    device = _get_device()
    if device is not None:
        try:
            return _core.vsa_bitpack(device, _ensure_f32_contiguous(x))
        except Exception as err:
            _record_fallback("vsa_bitpack", err)
    return None


def blake3_role(seed_string, dim):
    """Derive a pseudo-random VSA role vector from a seed string via BLAKE3.

    Requires the native module to be loaded and to expose a ``blake3_role``
    entry point. The seed is stringified and the dimension cast to int before
    the call. Returns None when the native path is unavailable or raises
    (the failure is recorded for fallback accounting).
    """
    if _NATIVE and hasattr(_core, "blake3_role"):
        try:
            return _core.blake3_role(str(seed_string), int(dim))
        except Exception as err:
            _record_fallback("blake3_role", err)
    return None


# ── Search & Retrieval ───────────────────────────────────────────────────


def hamming_search(query, codebook):
    """GPU Hamming search for the nearest bitpacked vector in a codebook.

    Arguments are forwarded to the native kernel untouched: the codebook is
    expected to already be bitpacked uint32, so no float32 coercion is
    performed here. Returns None when no device is available or the native
    kernel raises (the failure is recorded for fallback accounting).
    """
    device = _get_device()
    if device is not None:
        try:
            return _core.hamming_search(device, query, codebook)
        except Exception as err:
            _record_fallback("hamming_search", err)
    return None


# ── CubeMind & Cognitive Ops ─────────────────────────────────────────────


def resonator_step(query, codebooks):
    """Run one GPU Resonator Network step for VSA factorisation.

    ``codebooks`` is an iterable of ndarrays; each codebook and the query are
    coerced to contiguous float32 before the native call. Returns the step
    result, or None when no device is available or the native kernel raises
    (the failure is recorded for fallback accounting).
    """
    device = _get_device()
    if device is not None:
        try:
            prepared = list(map(_ensure_f32_contiguous, codebooks))
            return _core.resonator(device, _ensure_f32_contiguous(query), prepared)
        except Exception as err:
            _record_fallback("resonator_step", err)
    return None


def semantic_assigner(x, semantic_codes):
    """GPU Semantic Assigner: map continuous features into VSA semantic space.

    Both inputs are coerced to contiguous float32 before the native call.
    Returns the assignment result, or None when no device is available or the
    native kernel raises (the failure is recorded for fallback accounting).
    """
    device = _get_device()
    if device is not None:
        try:
            feats = _ensure_f32_contiguous(x)
            codes = _ensure_f32_contiguous(semantic_codes)
            return _core.semantic_assigner(device, feats, codes)
        except Exception as err:
            _record_fallback("semantic_assigner", err)
    return None


# ── Optimized Transformer & Training ─────────────────────────────────────


def flash_attention2(q, k, v, mask=None, scale=None, use_rope=False):
    """GPU Flash Attention 2. Returns None on failure.

    Args:
        q, k, v: Attention tensors; each is coerced to contiguous float32.
        mask: Optional attention mask. Only coerced to float32 when actually
            supplied — previously ``_ensure_f32_contiguous(None)`` was called
            unconditionally, so the documented ``mask=None`` default would
            raise and silently record a fallback on every mask-less call.
        scale: Score scaling factor; defaults to 1/sqrt(head_dim) computed
            from the last axis of ``q``.
        use_rope: Accepted for interface compatibility but not forwarded to
            the native kernel.  # NOTE(review): confirm whether the binding
            should receive a RoPE flag — today it is ignored.

    Returns:
        The attention output array, or None when no device is available or
        the native call fails (failure recorded via _record_fallback).
    """
    dev = _get_device()
    if dev is None:
        return None
    try:
        q = _ensure_f32_contiguous(q)
        k = _ensure_f32_contiguous(k)
        v = _ensure_f32_contiguous(v)
        # Fix: skip coercion when no mask was provided; pass None through.
        if mask is not None:
            mask = _ensure_f32_contiguous(mask)
        # Use default scale 1/sqrt(head_dim) if None.
        s = float(scale) if scale is not None else 1.0 / np.sqrt(q.shape[-1])
        return _core.flash_attention2(dev, q, k, v, mask, s)
    except Exception as e:
        _record_fallback("flash_attention2", e)
        return None


def adamw_update(
    weights, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01,
    beta1_t=None, beta2_t=None, clear_grad=False
):
    """GPU AdamW optimizer update step.

    Updates ``weights`` with decoupled weight decay, using the first/second
    moment buffers ``m`` and ``v``. ``beta1_t``/``beta2_t`` are the caller's
    running bias-correction factors; when omitted, 1.0 is used (i.e. no bias
    correction applied). Returns the native call's result, or None when no
    device is available or the native kernel raises (the failure is recorded
    for fallback accounting).
    """
    device = _get_device()
    if device is None:
        return None
    try:
        # Bias-correction factors default to 1.0 when not supplied.
        bc1 = 1.0 if beta1_t is None else float(beta1_t)
        bc2 = 1.0 if beta2_t is None else float(beta2_t)

        tensors = (
            _ensure_f32_contiguous(weights),
            _ensure_f32_contiguous(grad),
            _ensure_f32_contiguous(m),
            _ensure_f32_contiguous(v),
        )
        return _core.adamw_update(
            device,
            *tensors,
            float(lr), float(beta1), float(beta2), float(eps),
            float(weight_decay), bc1, bc2, bool(clear_grad)
        )
    except Exception as err:
        _record_fallback("adamw_update", err)
        return None
44 changes: 30 additions & 14 deletions backend/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,19 @@ def attention_scores(self, queries, keys, num_heads, head_dim, scale=None):
"IIIIfI", batch_size, seq_len, num_heads, head_dim, scale, 0
)

# Dispatch
# Dispatch: x = seq_len, y = seq_len, z = batch_size * num_heads
workgroups_x = (seq_len + 15) // 16
workgroups_y = ((batch_size * num_heads * seq_len) + 15) // 16
workgroups_y = (seq_len + 15) // 16
workgroups_z = batch_size * num_heads

self.core._dispatch_compute(
pipeline,
pipeline_layout,
descriptor_set,
workgroups_x,
push_constants,
workgroups_y,
workgroups_z,
)

# Download results
Expand Down Expand Up @@ -958,7 +961,7 @@ def gqa_decode_attention(
self._upload_buffer(buf_v, v_flat)

pipeline, pipeline_layout, _ = self.pipelines.get_or_create_pipeline(
"gqa-attention", 5, push_constant_size=24
"gqa-attention", 5, push_constant_size=28
)

descriptor_set = self.pipelines.get_cached_descriptor_set(
Expand All @@ -972,17 +975,30 @@ def gqa_decode_attention(
],
)

push_constants = struct.pack(
"IIIIIf", batch_size, num_q_heads, num_kv_heads, head_dim, cache_len, scale
)

total_q = batch_size * num_q_heads
workgroups_x = (max(cache_len, head_dim) + 15) // 16
workgroups_y = (total_q + 15) // 16

self.core._dispatch_compute(
pipeline, pipeline_layout, descriptor_set, workgroups_x, push_constants, workgroups_y
)
total_q = batch_size * num_q_heads;
workgroups_scores_x = (cache_len + 15) // 16;
workgroups_scores_y = (total_q + 15) // 16;

workgroups_softmax_x = 1; # One thread per head
workgroups_softmax_y = (total_q + 15) // 16;

workgroups_out_x = (head_dim + 15) // 16;
workgroups_out_y = (total_q + 15) // 16;

with self.core.record_commands() as rec:
# Phase 0: Scores
push_0 = struct.pack("IIIIIfI", batch_size, num_q_heads, num_kv_heads, head_dim, cache_len, scale, 0)
rec.dispatch(pipeline, pipeline_layout, descriptor_set, (workgroups_scores_x, workgroups_scores_y, 1), push_0)
rec.barrier()

# Phase 1: Softmax
push_1 = struct.pack("IIIIIfI", batch_size, num_q_heads, num_kv_heads, head_dim, cache_len, scale, 1)
rec.dispatch(pipeline, pipeline_layout, descriptor_set, (workgroups_softmax_x, workgroups_softmax_y, 1), push_1)
rec.barrier()

# Phase 2: Output
push_2 = struct.pack("IIIIIfI", batch_size, num_q_heads, num_kv_heads, head_dim, cache_len, scale, 2)
rec.dispatch(pipeline, pipeline_layout, descriptor_set, (workgroups_out_x, workgroups_out_y, 1), push_2)

result = self._download_buffer(buf_out, out_bytes, np.float32)
self._release_buffers([buf_q, buf_k, buf_v, buf_out, buf_scores])
Expand Down
8 changes: 6 additions & 2 deletions backend/fnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,17 +795,21 @@ def activation_swiglu(self, input_data):
output_shape = original_shape[:-1] + (hidden_dim,)
return result.reshape(output_shape)

def activation_softmax(self, input_data, axis=-1):
def activation_softmax(self, input_data, axis=-1, **kwargs):
"""
Apply softmax activation: exp(x) / sum(exp(x))

Args:
input_data: Input array
axis: Axis along which to compute softmax (default: -1)
axis: Axis along which to compute softmax
kwargs: Accepts 'dim' as an alias for 'axis'

Returns:
Softmax probabilities
"""
if "dim" in kwargs:
axis = kwargs.pop("dim")

# Check if shader is available
if "activation-softmax" not in self.shaders:
# CPU fallback (numba if available)
Expand Down
29 changes: 29 additions & 0 deletions cpp/include/grilly/ops/bandit.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include "grilly/command_batch.h"
#include "grilly/buffer_pool.h"
#include "grilly/pipeline_cache.h"

namespace grilly {
namespace ops {

// Dispatch parameters for the bandit top-2 solver kernel.
struct BanditParams {
    uint32_t nArms;       // K: number of arms per bandit instance
    uint32_t nInstances;  // number of independent bandit instances solved in one dispatch
    uint32_t iters;       // solver iterations to run per launch
    float delta;          // confidence parameter — presumably drives the stopping rule; confirm against the shader
};

/**
 * Bandit Top-2 Solver & Stopping Criterion
 *
 * Computes TargetW (K, nInstances) and StopFlags (nInstances).
 *
 * muHat / N are the per-arm mean estimates and pull counts read by the
 * kernel; targetW and stopFlags are the output buffers it fills.
 * Recorded into the supplied CommandBatch; buffers come from `pool` and
 * the compute pipeline is resolved through `cache`.
 */
void banditSolve(CommandBatch& batch, BufferPool& pool,
                 PipelineCache& cache,
                 const float* muHat, const float* N,
                 float* targetW, uint32_t* stopFlags,
                 const BanditParams& p);

} // namespace ops
} // namespace grilly
46 changes: 46 additions & 0 deletions cpp/include/grilly/ops/eggroll.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "grilly/command_batch.h"
#include "grilly/buffer_pool.h"
#include "grilly/pipeline_cache.h"

namespace grilly {
namespace ops {

// Parameters for generating low-rank perturbation factors.
struct EggrollGenParams {
    uint32_t dOut;     // output dimension of the perturbed weight matrix
    uint32_t dIn;      // input dimension of the perturbed weight matrix
    uint32_t nWorkers; // number of parallel workers / perturbation samples
    uint32_t seed;     // RNG seed for deterministic perturbation generation
    float sigma;       // perturbation scale (std-dev) applied by the kernel
};

// Parameters for the fused weight + merit update.
struct EggrollUpdateParams {
    uint32_t dOut;        // output dimension of W
    uint32_t dIn;         // input dimension of W
    uint32_t topK;        // number of top-ranked workers merged into the update
    uint32_t nWorkers;    // total number of workers the top-K were selected from
    float meritIncrease;  // merit bump applied per update — exact semantics defined by the shader
    float meritDecay;     // multiplicative merit decay — exact semantics defined by the shader
};

/**
 * Generate perturbations (U, V)
 *
 * Fills the U and V factor buffers (shapes implied by dOut/dIn/nWorkers;
 * confirm layout against the kernel). Recorded into `batch`.
 */
void eggrollGenerate(CommandBatch& batch, BufferPool& pool,
                     PipelineCache& cache,
                     float* U, float* V,
                     const EggrollGenParams& p);

/**
 * Apply fused update to weights and merit
 *
 * Updates W and Merit in place from the (U, V) factors, using the selected
 * topIdx / topWeights of the best-performing workers. Recorded into `batch`.
 */
void eggrollUpdate(CommandBatch& batch, BufferPool& pool,
                   PipelineCache& cache,
                   float* W, float* Merit,
                   const float* U, const float* V,
                   const uint32_t* topIdx, const float* topWeights,
                   const EggrollUpdateParams& p);

} // namespace ops
} // namespace grilly
25 changes: 24 additions & 1 deletion cpp/include/grilly/ops/loss.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,30 @@ void crossEntropyBackward(CommandBatch& batch, BufferPool& pool,
PipelineCache& cache,
const float* logits, const uint32_t* targets,
float* gradLogits,
const CrossEntropyBackwardParams& p);
const CrossEntropyBackwardParams& p);

// ── MSE Loss ────────────────────────────────────────────────────────────

struct MSELossParams {
    uint32_t n;  // total number of elements in preds/targets
};

// Mean-squared-error loss kernel dispatch over n elements. Writes into
// `losses`; whether the output is per-element or reduced is defined by the
// shader — confirm before relying on its shape. Recorded into `batch`.
void mseLoss(CommandBatch& batch, BufferPool& pool,
             PipelineCache& cache,
             const float* preds, const float* targets,
             float* losses, const MSELossParams& p);

// ── Cosine Similarity Loss ──────────────────────────────────────────────

struct CosineLossParams {
    uint32_t batchSize;  // number of (preds, targets) row pairs
    uint32_t dim;        // vector dimension of each row
};

// Cosine-similarity loss kernel dispatch over batchSize rows of length dim.
// Writes into `losses` (presumably one value per row — confirm against the
// shader). Recorded into `batch`.
void cosineSimilarityLoss(CommandBatch& batch, BufferPool& pool,
                          PipelineCache& cache,
                          const float* preds, const float* targets,
                          float* losses, const CosineLossParams& p);

} // namespace ops
} // namespace grilly
Loading
Loading