
Commit 076174f

gh-122: Fix CONV.
1 parent: 8a3bd91

1 file changed: src/builtins.c

Lines changed: 130 additions & 3 deletions
@@ -3545,14 +3545,140 @@ static Value builtin_shape(Interpreter* interp, Value* args, int argc, Expr** ar
 // CONV: N-D discrete convolution (two-argument backward-compatible form)
 // Usage: CONV(TNS: x, TNS: kernel) -> TNS (same shape as x)
 static Value builtin_conv(Interpreter* interp, Value* args, int argc, Expr** arg_nodes, Env* env, int line, int col) {
-    (void)arg_nodes; (void)env; (void)argc;
+    (void)arg_nodes; (void)env;
     if (args[0].type != VAL_TNS || args[1].type != VAL_TNS) {
         RUNTIME_ERROR(interp, "CONV expects (TNS, TNS)", line, col);
     }
     Tensor* x = args[0].as.tns;
     Tensor* k = args[1].as.tns;
 
-    // kernel must have same rank
+    // Extended 2-D multi-output form, triggered when more than two arguments are provided
+    if (argc > 2) {
+        if (x->ndim != 3) {
+            RUNTIME_ERROR(interp, "CONV extended form requires input rank 3", line, col);
+        }
+        if (k->ndim != 4) {
+            RUNTIME_ERROR(interp, "CONV extended form requires kernel rank 4", line, col);
+        }
+
+        size_t in_w = x->shape[0];
+        size_t in_h = x->shape[1];
+        size_t in_c = x->shape[2];
+        size_t kw = k->shape[0];
+        size_t kh = k->shape[1];
+        size_t k_in_c = k->shape[2];
+        size_t out_c = k->shape[3];
+
+        if (k_in_c != in_c) {
+            RUNTIME_ERROR(interp, "CONV kernel input channels must match x channels", line, col);
+        }
+
+        // Element types must be numeric
+        if (!((x->elem_type == TYPE_INT || x->elem_type == TYPE_FLT) && (k->elem_type == TYPE_INT || k->elem_type == TYPE_FLT))) {
+            RUNTIME_ERROR(interp, "CONV only supports INT or FLT element types", line, col);
+        }
+
+        // Parse optional args: stride_w, stride_h, pad_w, pad_h, bias
+        int64_t stride_w = 1, stride_h = 1, pad_w = 0, pad_h = 0;
+        if (argc > 2 && args[2].type != VAL_NULL) { EXPECT_INT(args[2], "CONV", interp, line, col); stride_w = args[2].as.i; }
+        if (argc > 3 && args[3].type != VAL_NULL) { EXPECT_INT(args[3], "CONV", interp, line, col); stride_h = args[3].as.i; }
+        if (argc > 4 && args[4].type != VAL_NULL) { EXPECT_INT(args[4], "CONV", interp, line, col); pad_w = args[4].as.i; }
+        if (argc > 5 && args[5].type != VAL_NULL) { EXPECT_INT(args[5], "CONV", interp, line, col); pad_h = args[5].as.i; }
+
+        if (stride_w <= 0 || stride_h <= 0 || pad_w < 0 || pad_h < 0) {
+            RUNTIME_ERROR(interp, "CONV invalid stride/pad", line, col);
+        }
+
+        bool bias_present = false;
+        Tensor* bias_t = NULL;
+        if (argc > 6 && args[6].type != VAL_NULL) {
+            if (args[6].type != VAL_TNS) {
+                RUNTIME_ERROR(interp, "CONV bias must be TNS", line, col);
+            }
+            bias_present = true;
+            bias_t = args[6].as.tns;
+            if ((bias_t->ndim != 1 && bias_t->length != 0) || (bias_t->length != 0 && bias_t->shape[0] != out_c)) {
+                RUNTIME_ERROR(interp, "CONV bias size mismatch", line, col);
+            }
+        }
+
+        // Output typing: INT only when both operands are INT, otherwise FLT
+        DeclType out_decl = (x->elem_type == TYPE_INT && k->elem_type == TYPE_INT) ? TYPE_INT : TYPE_FLT;
+
+        // Compute output shape
+        int64_t out_w_i = ((int64_t)in_w + 2 * pad_w - (int64_t)kw) / stride_w + 1;
+        int64_t out_h_i = ((int64_t)in_h + 2 * pad_h - (int64_t)kh) / stride_h + 1;
+        if (out_w_i <= 0 || out_h_i <= 0) {
+            size_t out_shape_zero[3] = {0, 0, out_c};
+            return value_tns_new(out_decl, 3, out_shape_zero);
+        }
+        size_t out_w = (size_t)out_w_i;
+        size_t out_h = (size_t)out_h_i;
+
+        size_t out_shape[3]; out_shape[0] = out_w; out_shape[1] = out_h; out_shape[2] = out_c;
+        Value out = value_tns_new(out_decl, 3, out_shape);
+        Tensor* ot = out.as.tns;
+
+        // Perform convolution; output index order is [w, h, oc]
+        for (size_t ow = 0; ow < out_w; ow++) {
+            for (size_t oh = 0; oh < out_h; oh++) {
+                for (size_t oc = 0; oc < out_c; oc++) {
+                    if (out_decl == TYPE_INT) {
+                        int64_t acc = 0;
+                        for (size_t kx = 0; kx < kw; kx++) {
+                            for (size_t ky = 0; ky < kh; ky++) {
+                                for (size_t ic = 0; ic < in_c; ic++) {
+                                    int64_t in_x = (int64_t)ow * stride_w + (int64_t)kx - pad_w;
+                                    int64_t in_y = (int64_t)oh * stride_h + (int64_t)ky - pad_h;
+                                    if (in_x < 0 || in_y < 0 || (size_t)in_x >= in_w || (size_t)in_y >= in_h) continue; // zero pad
+                                    size_t in_off = (size_t)in_x * x->strides[0] + (size_t)in_y * x->strides[1] + ic * x->strides[2];
+                                    size_t k_off = kx * k->strides[0] + ky * k->strides[1] + ic * k->strides[2] + oc * k->strides[3];
+                                    Value vx = x->data[in_off];
+                                    Value vk = k->data[k_off];
+                                    if (vx.type != VAL_INT || vk.type != VAL_INT) { value_free(out); RUNTIME_ERROR(interp, "CONV integer-mode requires INT elements", line, col); }
+                                    acc += vx.as.i * vk.as.i;
+                                }
+                            }
+                        }
+                        if (bias_present && bias_t->length > 0) {
+                            Value bv = bias_t->data[oc];
+                            if (bv.type == VAL_INT) acc += bv.as.i;
+                            else if (bv.type == VAL_FLT) acc += (int64_t)bv.as.f;
+                            else { value_free(out); RUNTIME_ERROR(interp, "CONV bias must be numeric", line, col); }
+                        }
+                        ot->data[ow * ot->strides[0] + oh * ot->strides[1] + oc * ot->strides[2]] = value_int(acc);
+                    } else {
+                        double acc = 0.0;
+                        for (size_t kx = 0; kx < kw; kx++) {
+                            for (size_t ky = 0; ky < kh; ky++) {
+                                for (size_t ic = 0; ic < in_c; ic++) {
+                                    int64_t in_x = (int64_t)ow * stride_w + (int64_t)kx - pad_w;
+                                    int64_t in_y = (int64_t)oh * stride_h + (int64_t)ky - pad_h;
+                                    if (in_x < 0 || in_y < 0 || (size_t)in_x >= in_w || (size_t)in_y >= in_h) continue; // zero pad
+                                    size_t in_off = (size_t)in_x * x->strides[0] + (size_t)in_y * x->strides[1] + ic * x->strides[2];
+                                    size_t k_off = kx * k->strides[0] + ky * k->strides[1] + ic * k->strides[2] + oc * k->strides[3];
+                                    Value vx = x->data[in_off];
+                                    Value vk = k->data[k_off];
+                                    double aval = (vx.type == VAL_FLT) ? vx.as.f : (double)vx.as.i;
+                                    double kval = (vk.type == VAL_FLT) ? vk.as.f : (double)vk.as.i;
+                                    acc += aval * kval;
+                                }
+                            }
+                        }
+                        if (bias_present && bias_t->length > 0) {
+                            Value bv = bias_t->data[oc];
+                            double bval = (bv.type == VAL_FLT) ? bv.as.f : (double)bv.as.i;
+                            acc += bval;
+                        }
+                        ot->data[ow * ot->strides[0] + oh * ot->strides[1] + oc * ot->strides[2]] = value_flt(acc);
+                    }
+                }
+            }
+        }
+        return out;
+    }
+
+    // Legacy two-argument N-D convolution (backward-compatible)
     if (x->ndim != k->ndim) {
         RUNTIME_ERROR(interp, "CONV kernel must have same rank as input", line, col);
     }
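
The output dimensions in the hunk above follow the standard strided-convolution formula, out = (in + 2*pad - kernel) / stride + 1 per axis with truncating integer division, and non-positive results short-circuit to an empty output. A minimal standalone C sketch of that arithmetic (the helper name conv_out_extent is illustrative, not part of the patch):

#include <stdint.h>
#include <stdio.h>

/* Output extent along one axis, mirroring the patch:
   out = (in + 2*pad - kernel) / stride + 1, truncating division. */
static int64_t conv_out_extent(int64_t in, int64_t kernel, int64_t stride, int64_t pad) {
    return (in + 2 * pad - kernel) / stride + 1;
}

int main(void) {
    /* 32-wide input, 3-wide kernel: stride 1, pad 0 -> 30; pad 1 keeps 32. */
    printf("%lld\n", (long long)conv_out_extent(32, 3, 1, 0)); /* 30 */
    printf("%lld\n", (long long)conv_out_extent(32, 3, 1, 1)); /* 32 */
    /* Same sizes with stride 2, pad 1 -> 16. */
    printf("%lld\n", (long long)conv_out_extent(32, 3, 2, 1)); /* 16 */
    return 0;
}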
@@ -8166,6 +8292,7 @@ static const char* builtin_params_match[] = {"value", "template", "typing", "rec
 static const char* builtin_params_readfile[] = {"path", "coding"};
 static const char* builtin_params_writefile[] = {"data", "path", "coding"};
 static const char* builtin_params_pause[] = {"thr", "seconds"};
+static const char* builtin_params_conv[] = {"x", "kernel", "stride_w", "stride_h", "pad_w", "pad_h", "bias"};
 
 static BuiltinFunction builtins_table[] = {
     // Arithmetic
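
Judging from the initializers in this table, each BuiltinFunction entry carries a name, minimum and maximum argument counts, the handler, and an optional named-parameter list with its length. A speculative sketch of that layout, inferred from the initializers and the builtin_conv signature alone (the actual struct definition in builtins.c may differ):

typedef struct {
    const char* name;     /* builtin name, e.g. "CONV" */
    int min_args;         /* fewest arguments accepted */
    int max_args;         /* most arguments accepted */
    Value (*fn)(Interpreter*, Value*, int, Expr**, Env*, int, int);
    const char** params;  /* optional named-parameter table, or NULL */
    int n_params;         /* number of entries in params */
} BuiltinFunction;

Shorter entries such as {"FILL", 2, 2, builtin_fill} would then leave the trailing fields zero-initialized, which is why only CONV needs the six-field form below.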
@@ -8204,7 +8331,7 @@ static BuiltinFunction builtins_table[] = {
     {"TINT", 1, 1, builtin_tint},
     {"TFLT", 1, 1, builtin_tflt},
     {"TSTR", 1, 1, builtin_tstr},
-    {"CONV", 2, 2, builtin_conv},
+    {"CONV", 2, 7, builtin_conv, builtin_params_conv, 7},
     {"FILL", 2, 2, builtin_fill},
     {"TADD", 2, 2, builtin_tadd},
     {"TSUB", 2, 2, builtin_tsub},
