Skip to content

Commit fc3275c

Browse files
authored
test: add alpha & beta tests for blas/base/ggemm
PR-URL: #12194 Reviewed-by: Athan Reines <kgryte@gmail.com>
1 parent 9330de7 commit fc3275c

4 files changed

Lines changed: 135 additions & 2 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"transA": "no-transpose",
3+
"transB": "no-transpose",
4+
"M": 2,
5+
"N": 4,
6+
"K": 3,
7+
"alpha": 2.0,
8+
"A": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ],
9+
"strideA1": 3,
10+
"strideA2": 1,
11+
"offsetA": 0,
12+
"B": [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
13+
"strideB1": 4,
14+
"strideB2": 1,
15+
"offsetB": 0,
16+
"beta": 3.0,
17+
"C": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 ],
18+
"strideC1": 4,
19+
"strideC2": 1,
20+
"offsetC": 0,
21+
"C_out": [ 15.0, 18.0, 21.0, 24.0, 45.0, 48.0, 51.0, 54.0 ]
22+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"order": "row-major",
3+
"transA": "no-transpose",
4+
"transB": "no-transpose",
5+
"M": 2,
6+
"N": 4,
7+
"K": 3,
8+
"alpha": 2.0,
9+
"A": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ],
10+
"lda": 3,
11+
"B": [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
12+
"ldb": 4,
13+
"beta": 3.0,
14+
"C": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 ],
15+
"ldc": 4,
16+
"C_out": [ 15.0, 18.0, 21.0, 24.0, 45.0, 48.0, 51.0, 54.0 ]
17+
}

lib/node_modules/@stdlib/blas/base/ggemm/test/test.main.js

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* limitations under the License.
1717
*/
1818

19-
/* eslint-disable max-len */
19+
/* eslint-disable max-len, stdlib/no-empty-lines-between-requires */
2020

2121
'use strict';
2222

@@ -41,6 +41,7 @@ var rntantb = require( './fixtures/row_major_nta_ntb.json' );
4141
var rtantb = require( './fixtures/row_major_ta_ntb.json' );
4242
var rntatb = require( './fixtures/row_major_nta_tb.json' );
4343
var rtatb = require( './fixtures/row_major_ta_tb.json' );
44+
var rntantbAlpha2Beta3 = require( './fixtures/row_major_nta_ntb_alpha2_beta3.json' );
4445

4546

4647
// TESTS //
@@ -1393,3 +1394,49 @@ tape( 'if `α` is `0` and `β` is neither `0` nor `1`, the function returns the
13931394

13941395
t.end();
13951396
});
1397+
1398+
tape( 'the function correctly applies both `α` and `β` scalars (row-major, no-transpose, no-transpose, α=2, β=3)', function test( t ) {
1399+
var expected;
1400+
var data;
1401+
var out;
1402+
var a;
1403+
var b;
1404+
var c;
1405+
1406+
data = rntantbAlpha2Beta3;
1407+
1408+
a = copy( data.A );
1409+
b = copy( data.B );
1410+
c = copy( data.C );
1411+
1412+
expected = data.C_out;
1413+
1414+
out = ggemm( data.order, data.transA, data.transB, data.M, data.N, data.K, data.alpha, a, data.lda, b, data.ldb, data.beta, c, data.ldc );
1415+
t.strictEqual( out, c, 'returns expected value' );
1416+
t.deepEqual( out, expected, 'returns expected value' );
1417+
t.end();
1418+
});
1419+
1420+
tape( 'the function correctly applies both `α` and `β` scalars (row-major, no-transpose, no-transpose, α=2, β=3) (accessors)', function test( t ) {
1421+
var expected;
1422+
var data;
1423+
var cbuf;
1424+
var out;
1425+
var a;
1426+
var b;
1427+
var c;
1428+
1429+
data = rntantbAlpha2Beta3;
1430+
1431+
a = toAccessorArray( copy( data.A ) );
1432+
b = toAccessorArray( copy( data.B ) );
1433+
cbuf = copy( data.C );
1434+
c = toAccessorArray( cbuf );
1435+
1436+
expected = data.C_out;
1437+
1438+
out = ggemm( data.order, data.transA, data.transB, data.M, data.N, data.K, data.alpha, a, data.lda, b, data.ldb, data.beta, c, data.ldc );
1439+
t.strictEqual( out, c, 'returns expected value' );
1440+
t.deepEqual( cbuf, expected, 'returns expected value' );
1441+
t.end();
1442+
});

lib/node_modules/@stdlib/blas/base/ggemm/test/test.ndarray.js

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* limitations under the License.
1717
*/
1818

19-
/* eslint-disable max-len */
19+
/* eslint-disable max-len, stdlib/no-empty-lines-between-requires */
2020

2121
'use strict';
2222

@@ -85,6 +85,7 @@ var rarbrcntantboa = require( './fixtures/ra_rb_rc_nta_ntb_oa.json' );
8585
var rarbrcntantbob = require( './fixtures/ra_rb_rc_nta_ntb_ob.json' );
8686
var rarbrcntantboc = require( './fixtures/ra_rb_rc_nta_ntb_oc.json' );
8787
var cap = require( './fixtures/ra_rb_rc_nta_ntb_complex_access_pattern.json' );
88+
var rarbrcntantbAlpha2Beta3 = require( './fixtures/ra_rb_rc_nta_ntb_alpha2_beta3.json' );
8889

8990

9091
// TESTS //
@@ -2945,3 +2946,49 @@ tape( 'the function supports computation over large arrays (column-major, column
29452946
t.deepEqual( out, expected, 'returns expected value' );
29462947
t.end();
29472948
});
2949+
2950+
tape( 'the function correctly applies both `α` and `β` scalars (row-major, row-major, row-major, no-transpose, no-transpose, α=2, β=3)', function test( t ) {
2951+
var expected;
2952+
var data;
2953+
var out;
2954+
var a;
2955+
var b;
2956+
var c;
2957+
2958+
data = rarbrcntantbAlpha2Beta3;
2959+
2960+
a = copy( data.A );
2961+
b = copy( data.B );
2962+
c = copy( data.C );
2963+
2964+
expected = data.C_out;
2965+
2966+
out = ggemm( data.transA, data.transB, data.M, data.N, data.K, data.alpha, a, data.strideA1, data.strideA2, data.offsetA, b, data.strideB1, data.strideB2, data.offsetB, data.beta, c, data.strideC1, data.strideC2, data.offsetC );
2967+
t.strictEqual( out, c, 'returns expected value' );
2968+
t.deepEqual( out, expected, 'returns expected value' );
2969+
t.end();
2970+
});
2971+
2972+
tape( 'the function correctly applies both `α` and `β` scalars (row-major, row-major, row-major, no-transpose, no-transpose, α=2, β=3) (accessors)', function test( t ) {
2973+
var expected;
2974+
var data;
2975+
var cbuf;
2976+
var out;
2977+
var a;
2978+
var b;
2979+
var c;
2980+
2981+
data = rarbrcntantbAlpha2Beta3;
2982+
2983+
a = toAccessorArray( copy( data.A ) );
2984+
b = toAccessorArray( copy( data.B ) );
2985+
cbuf = copy( data.C );
2986+
c = toAccessorArray( cbuf );
2987+
2988+
expected = data.C_out;
2989+
2990+
out = ggemm( data.transA, data.transB, data.M, data.N, data.K, data.alpha, a, data.strideA1, data.strideA2, data.offsetA, b, data.strideB1, data.strideB2, data.offsetB, data.beta, c, data.strideC1, data.strideC2, data.offsetC );
2991+
t.strictEqual( out, c, 'returns expected value' );
2992+
t.deepEqual( cbuf, expected, 'returns expected value' );
2993+
t.end();
2994+
});

0 commit comments

Comments
 (0)