Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/pool2d.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
'use strict';

import {Tensor} from './lib/tensor.js';
import {Tensor, Scalar} from './lib/tensor.js';
import {transpose} from './transpose.js';
import {l2Reducer, meanReducer, maxReducer} from './reduce.js';
import {pow} from './binary.js';
import {sqrt} from './unary.js';
import {maxReducer, meanReducer, sumReducer} from './reduce.js';
import {validatePool2dParams} from './lib/validate-input.js';

/**
Expand Down Expand Up @@ -120,5 +122,7 @@ export function maxPool2d(input, options = {}) {
* @return {Tensor}
*/
export function l2Pool2d(input, options = {}) {
return pool2d(input, l2Reducer, options);
const squaredInput = pow(input, new Scalar(2));
const pooledInput = pool2d(squaredInput, sumReducer, options);
return sqrt(pooledInput);
}
19 changes: 6 additions & 13 deletions src/reduce.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ export function meanReducer(previousValue, currentValue, currentIndex, array) {
}
}

/* The sum reducer */
export function sumReducer(previousValue, currentValue) {
return previousValue + currentValue;
}

/**
* Compute the average value of all the input values along the axes.
* @param {Tensor} input
Expand Down Expand Up @@ -140,8 +145,7 @@ export function reduceProduct(input, options = {}) {
* @return {Tensor}
*/
export function reduceSum(input, options = {}) {
return reduce(input,
(previousValue, currentValue) => previousValue + currentValue, options);
return reduce(input, sumReducer, options);
}

/**
Expand All @@ -164,17 +168,6 @@ export function reduceL1(input, options = {}) {
return reduceSum(abs(input), options);
}

/* The l2 reducer */
export function l2Reducer(previousValue, currentValue, currentIndex, array) {
if (currentIndex == 1) {
const sumOfSquares = previousValue * previousValue + currentValue * currentValue;
return sumOfSquares;
} else {
const sumOfSquares = previousValue + currentValue * currentValue;
return (currentIndex === array.length - 1) ? Math.sqrt(sumOfSquares) :sumOfSquares;
}
}

/**
* Compute the L2 norm of all the input values along the axes.
* @param {Tensor} input
Expand Down
18 changes: 18 additions & 0 deletions test/l2pool2d_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,22 @@ describe('test pool2d', function() {
];
utils.checkValue(y, expected);
});

it('l2Pool2d pads roundingType=ceil', function() {
const x = new Tensor([1, 1, 5, 5], [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
]);
const windowDimensions = [3, 3];
const padding = [1, 0, 0, 1];
const strides = [2, 2];
const y = l2Pool2d(x, {windowDimensions, padding, strides, roundingType: 'ceil'});
utils.checkShape(y, [1, 1, 3, 3]);
const expected = [
12.767145334803704, 17.175564037317667, 11.180339887498949,
38.1051177665153, 43.81780460041329, 26.92582403567252,
48.19751030914356, 53.056573579529235, 32.01562118716424,
];
utils.checkValue(y, expected);
});
});