diff --git a/src/explainer/lstsq.ts b/src/explainer/lstsq.ts index 741ea2f..b2c1185 100644 --- a/src/explainer/lstsq.ts +++ b/src/explainer/lstsq.ts @@ -4,7 +4,7 @@ */ import math from '../../src/utils/math-import'; -import { tensor2d } from '@tensorflow/tfjs'; +import { tensor2d, tidy } from '@tensorflow/tfjs'; /** * Solves linear least squares problems for given input matrix `x`, @@ -48,45 +48,48 @@ export const lstsq = ( wMat = math.matrix(math.diag(values)); } - // Matrix multiplication is too slow in math.js, we use ml-matrix instead - const xTensor = tensor2d( - x.toArray() as number[][], - x.size() as [number, number] - ); - const wTensor = tensor2d( - wMat.toArray() as number[][], - wMat.size() as [number, number] - ); - const yTensor = tensor2d( - y.toArray() as number[][], - y.size() as [number, number] - ); + const result = tidy(() => { + // Matrix multiplication is too slow in math.js, we use ml-matrix instead + const xTensor = tensor2d( + x.toArray() as number[][], + x.size() as [number, number] + ); + const wTensor = tensor2d( + wMat.toArray() as number[][], + wMat.size() as [number, number] + ); + const yTensor = tensor2d( + y.toArray() as number[][], + y.size() as [number, number] + ); - const left = xTensor.transpose().matMul(wTensor).matMul(xTensor); - const right = xTensor.transpose().matMul(wTensor).matMul(yTensor); + const left = xTensor.transpose().matMul(wTensor).matMul(xTensor); + const right = xTensor.transpose().matMul(wTensor).matMul(yTensor); - // Convert `left` back to math.js for inversion - const left2D = left.arraySync() as number[][]; - const leftMat = math.matrix(left2D); - const leftDet = math.det(leftMat); + // Convert `left` back to math.js for inversion + const left2D = left.arraySync() as number[][]; + const leftMat = math.matrix(left2D); + const leftDet = math.det(leftMat); - // Invertible matrix - let leftInverse: math.Matrix; - if (leftDet !== 0) { - leftInverse = math.inv(leftMat); - } else { - // Singular matrix => we take pseudo-inverse instead - console.warn('Matrix x is singular, use pseudo-inverse instead.'); - leftInverse = math.pinv(leftMat); - } + // Invertible matrix + let leftInverse: math.Matrix; + if (leftDet !== 0) { + leftInverse = math.inv(leftMat); + } else { + // Singular matrix => we take pseudo-inverse instead + console.warn('Matrix x is singular, use pseudo-inverse instead.'); + leftInverse = math.pinv(leftMat); + } - const leftInverseTensor = tensor2d( - leftInverse.toArray() as number[][], - leftInverse.size() as [number, number] - ); - const result = leftInverseTensor.matMul(right); + const leftInverseTensor = tensor2d( + leftInverse.toArray() as number[][], + leftInverse.size() as [number, number] + ); + return leftInverseTensor.matMul(right); + }); // Convert the result Matrix to math.Matrix const result2D = result.arraySync() as number[][]; + result.dispose(); return math.matrix(result2D); };