@@ -3545,14 +3545,140 @@ static Value builtin_shape(Interpreter* interp, Value* args, int argc, Expr** ar
// CONV: discrete convolution built-in.
//
// Legacy form (backward compatible): CONV(TNS: x, TNS: kernel) -> TNS,
// N-D convolution requiring kernel rank == input rank (legacy path at the
// bottom; its body continues beyond this excerpt).
//
// Extended form (selected when argc > 2):
//   CONV(x, kernel [, stride_w, stride_h, pad_w, pad_h, bias])
//   x:      rank-3 tensor [width, height, in_channels]
//   kernel: rank-4 tensor [kw, kh, in_channels, out_channels]
//   bias:   optional rank-1 tensor of length out_channels
// Returns a rank-3 tensor [out_w, out_h, out_channels]; out-of-range taps
// read as zero (zero padding).
static Value builtin_conv(Interpreter *interp, Value *args, int argc, Expr **arg_nodes, Env *env, int line, int col) {
    (void)arg_nodes; (void)env; /* patch note: "(void)argc;" removed — argc now selects the form */
    if (args[0].type != VAL_TNS || args[1].type != VAL_TNS) {
        RUNTIME_ERROR(interp, "CONV expects (TNS, TNS)", line, col);
    }
    Tensor *x = args[0].as.tns;
    Tensor *k = args[1].as.tns;

    // Extended 2-D multi-output form triggered when more than two arguments provided
    if (argc > 2) {
        if (x->ndim != 3) {
            RUNTIME_ERROR(interp, "CONV extended form requires input rank 3", line, col);
        }
        if (k->ndim != 4) {
            RUNTIME_ERROR(interp, "CONV extended form requires kernel rank 4", line, col);
        }

        // Shapes: x is [w, h, c]; kernel is [kw, kh, in_c, out_c].
        size_t in_w = x->shape[0];
        size_t in_h = x->shape[1];
        size_t in_c = x->shape[2];
        size_t kw = k->shape[0];
        size_t kh = k->shape[1];
        size_t k_in_c = k->shape[2];
        size_t out_c = k->shape[3];

        if (k_in_c != in_c) {
            RUNTIME_ERROR(interp, "CONV kernel input channels must match x channels", line, col);
        }

        // Element types must be numeric
        if (!((x->elem_type == TYPE_INT || x->elem_type == TYPE_FLT) && (k->elem_type == TYPE_INT || k->elem_type == TYPE_FLT))) {
            RUNTIME_ERROR(interp, "CONV only supports INT or FLT element types", line, col);
        }

        // Parse optional args: stride_w, stride_h, pad_w, pad_h, bias.
        // A VAL_NULL in any slot keeps that option's default (stride 1, pad 0).
        int64_t stride_w = 1, stride_h = 1, pad_w = 0, pad_h = 0;
        if (argc > 2 && args[2].type != VAL_NULL) { EXPECT_INT(args[2], "CONV", interp, line, col); stride_w = args[2].as.i; }
        if (argc > 3 && args[3].type != VAL_NULL) { EXPECT_INT(args[3], "CONV", interp, line, col); stride_h = args[3].as.i; }
        if (argc > 4 && args[4].type != VAL_NULL) { EXPECT_INT(args[4], "CONV", interp, line, col); pad_w = args[4].as.i; }
        if (argc > 5 && args[5].type != VAL_NULL) { EXPECT_INT(args[5], "CONV", interp, line, col); pad_h = args[5].as.i; }

        if (stride_w <= 0 || stride_h <= 0 || pad_w < 0 || pad_h < 0) {
            RUNTIME_ERROR(interp, "CONV invalid stride/pad", line, col);
        }

        bool bias_present = false;
        Tensor *bias_t = NULL;
        if (argc > 6 && args[6].type != VAL_NULL) {
            if (args[6].type != VAL_TNS) {
                RUNTIME_ERROR(interp, "CONV bias must be TNS", line, col);
            }
            bias_present = true;
            bias_t = args[6].as.tns;
            // NOTE(review): a zero-length bias of ANY rank passes this check; the
            // "length > 0" guards below then treat it as "no bias" — confirm intended.
            if ((bias_t->ndim != 1 && bias_t->length != 0) || (bias_t->length != 0 && bias_t->shape[0] != out_c)) {
                RUNTIME_ERROR(interp, "CONV bias size mismatch", line, col);
            }
        }

        // Output typing: INT only when both operands are INT, otherwise FLT.
        DeclType out_decl = (x->elem_type == TYPE_INT && k->elem_type == TYPE_INT) ? TYPE_INT : TYPE_FLT;

        // Output extent per spatial axis: floor((in + 2*pad - k) / stride) + 1.
        int64_t out_w_i = ((int64_t)in_w + 2 * pad_w - (int64_t)kw) / stride_w + 1;
        int64_t out_h_i = ((int64_t)in_h + 2 * pad_h - (int64_t)kh) / stride_h + 1;
        if (out_w_i <= 0 || out_h_i <= 0) {
            // Kernel larger than the padded input: return an empty [0,0,out_c] tensor.
            size_t out_shape_zero[3] = {0, 0, out_c};
            return value_tns_new(out_decl, 3, out_shape_zero);
        }
        size_t out_w = (size_t)out_w_i;
        size_t out_h = (size_t)out_h_i;

        size_t out_shape[3]; out_shape[0] = out_w; out_shape[1] = out_h; out_shape[2] = out_c;
        Value out = value_tns_new(out_decl, 3, out_shape);
        Tensor *ot = out.as.tns;

        // Perform convolution: output indices order [w,h,oc].
        // strides[] are used directly as element offsets into data[] (Value slots).
        for (size_t ow = 0; ow < out_w; ow++) {
            for (size_t oh = 0; oh < out_h; oh++) {
                for (size_t oc = 0; oc < out_c; oc++) {
                    if (out_decl == TYPE_INT) {
                        // Integer accumulation path (both inputs INT).
                        int64_t acc = 0;
                        for (size_t kx = 0; kx < kw; kx++) {
                            for (size_t ky = 0; ky < kh; ky++) {
                                for (size_t ic = 0; ic < in_c; ic++) {
                                    int64_t in_x = (int64_t)ow * stride_w + (int64_t)kx - pad_w;
                                    int64_t in_y = (int64_t)oh * stride_h + (int64_t)ky - pad_h;
                                    if (in_x < 0 || in_y < 0 || (size_t)in_x >= in_w || (size_t)in_y >= in_h) continue; // zero pad
                                    size_t in_off = (size_t)in_x * x->strides[0] + (size_t)in_y * x->strides[1] + ic * x->strides[2];
                                    size_t k_off = kx * k->strides[0] + ky * k->strides[1] + ic * k->strides[2] + oc * k->strides[3];
                                    Value vx = x->data[in_off];
                                    Value vk = k->data[k_off];
                                    // Defensive per-element check; partial output is freed before erroring
                                    // (RUNTIME_ERROR presumably does not return here — confirm macro semantics).
                                    if (vx.type != VAL_INT || vk.type != VAL_INT) { value_free(out); RUNTIME_ERROR(interp, "CONV integer-mode requires INT elements", line, col); }
                                    acc += vx.as.i * vk.as.i; // NOTE(review): int64 signed overflow unchecked
                                }
                            }
                        }
                        if (bias_present && bias_t->length > 0) {
                            Value bv = bias_t->data[oc];
                            if (bv.type == VAL_INT) acc += bv.as.i;
                            else if (bv.type == VAL_FLT) acc += (int64_t)bv.as.f; // FLT bias truncated in INT mode
                            else { value_free(out); RUNTIME_ERROR(interp, "CONV bias must be numeric", line, col); }
                        }
                        ot->data[ow * ot->strides[0] + oh * ot->strides[1] + oc * ot->strides[2]] = value_int(acc);
                    } else {
                        // Floating-point accumulation path (any FLT operand).
                        double acc = 0.0;
                        for (size_t kx = 0; kx < kw; kx++) {
                            for (size_t ky = 0; ky < kh; ky++) {
                                for (size_t ic = 0; ic < in_c; ic++) {
                                    int64_t in_x = (int64_t)ow * stride_w + (int64_t)kx - pad_w;
                                    int64_t in_y = (int64_t)oh * stride_h + (int64_t)ky - pad_h;
                                    if (in_x < 0 || in_y < 0 || (size_t)in_x >= in_w || (size_t)in_y >= in_h) continue; // zero pad
                                    size_t in_off = (size_t)in_x * x->strides[0] + (size_t)in_y * x->strides[1] + ic * x->strides[2];
                                    size_t k_off = kx * k->strides[0] + ky * k->strides[1] + ic * k->strides[2] + oc * k->strides[3];
                                    Value vx = x->data[in_off];
                                    Value vk = k->data[k_off];
                                    double aval = (vx.type == VAL_FLT) ? vx.as.f : (double)vx.as.i;
                                    double kval = (vk.type == VAL_FLT) ? vk.as.f : (double)vk.as.i;
                                    acc += aval * kval;
                                }
                            }
                        }
                        if (bias_present && bias_t->length > 0) {
                            Value bv = bias_t->data[oc];
                            double bval = (bv.type == VAL_FLT) ? bv.as.f : (double)bv.as.i;
                            acc += bval;
                        }
                        ot->data[ow * ot->strides[0] + oh * ot->strides[1] + oc * ot->strides[2]] = value_flt(acc);
                    }
                }
            }
        }
        return out;
    }

    // Legacy two-argument N-D convolution (backward-compatible)
    if (x->ndim != k->ndim) {
        RUNTIME_ERROR(interp, "CONV kernel must have same rank as input", line, col);
    }
@@ -8166,6 +8292,7 @@ static const char* builtin_params_match[] = {"value", "template", "typing", "rec
// Named-parameter tables for built-ins that accept keyword-style arguments
// (order matches the positional argument order each built-in parses).
static const char *builtin_params_readfile[] = {"path", "coding"};
static const char *builtin_params_writefile[] = {"data", "path", "coding"};
static const char *builtin_params_pause[] = {"thr", "seconds"};
// New in this patch: parameter names for the extended CONV form
// (7 entries, matching CONV's max arity in builtins_table).
static const char *builtin_params_conv[] = {"x", "kernel", "stride_w", "stride_h", "pad_w", "pad_h", "bias"};
81698296
81708297static BuiltinFunction builtins_table [] = {
81718298 // Arithmetic
@@ -8204,7 +8331,7 @@ static BuiltinFunction builtins_table[] = {
    // Entry layout appears to be {name, min_argc, max_argc, fn[, param_names, n_params]}
    // — inferred from usage here; confirm against the BuiltinFunction declaration.
    {"TINT", 1, 1, builtin_tint},
    {"TFLT", 1, 1, builtin_tflt},
    {"TSTR", 1, 1, builtin_tstr},
    // patch: CONV widened from a fixed 2 args to 2..7 with named optional parameters
    /* was: {"CONV", 2, 2, builtin_conv}, */
    {"CONV", 2, 7, builtin_conv, builtin_params_conv, 7},
    {"FILL", 2, 2, builtin_fill},
    {"TADD", 2, 2, builtin_tadd},
    {"TSUB", 2, 2, builtin_tsub},
0 commit comments