Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 37 additions & 34 deletions src/explainer/lstsq.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/

import math from '../../src/utils/math-import';
import { tensor2d } from '@tensorflow/tfjs';
import { tensor2d, tidy } from '@tensorflow/tfjs';

/**
* Solves linear least squares problems for given input matrix `x`,
Expand Down Expand Up @@ -48,45 +48,48 @@ export const lstsq = (
wMat = math.matrix(math.diag(values));
}

// Matrix multiplication is too slow in math.js, we use ml-matrix instead
const xTensor = tensor2d(
x.toArray() as number[][],
x.size() as [number, number]
);
const wTensor = tensor2d(
wMat.toArray() as number[][],
wMat.size() as [number, number]
);
const yTensor = tensor2d(
y.toArray() as number[][],
y.size() as [number, number]
);
const result = tidy(() => {
// Matrix multiplication is too slow in math.js, we use ml-matrix instead
const xTensor = tensor2d(
x.toArray() as number[][],
x.size() as [number, number]
);
const wTensor = tensor2d(
wMat.toArray() as number[][],
wMat.size() as [number, number]
);
const yTensor = tensor2d(
y.toArray() as number[][],
y.size() as [number, number]
);

const left = xTensor.transpose().matMul(wTensor).matMul(xTensor);
const right = xTensor.transpose().matMul(wTensor).matMul(yTensor);
const left = xTensor.transpose().matMul(wTensor).matMul(xTensor);
const right = xTensor.transpose().matMul(wTensor).matMul(yTensor);

// Convert `left` back to math.js for inversion
const left2D = left.arraySync() as number[][];
const leftMat = math.matrix(left2D);
const leftDet = math.det(leftMat);
// Convert `left` back to math.js for inversion
const left2D = left.arraySync() as number[][];
const leftMat = math.matrix(left2D);
const leftDet = math.det(leftMat);

// Invertible matrix
let leftInverse: math.Matrix;
if (leftDet !== 0) {
leftInverse = math.inv(leftMat);
} else {
// Singular matrix => we take pseudo-inverse instead
console.warn('Matrix x is singular, use pseudo-inverse instead.');
leftInverse = math.pinv(leftMat);
}
// Invertible matrix
let leftInverse: math.Matrix;
if (leftDet !== 0) {
leftInverse = math.inv(leftMat);
} else {
// Singular matrix => we take pseudo-inverse instead
console.warn('Matrix x is singular, use pseudo-inverse instead.');
leftInverse = math.pinv(leftMat);
}

const leftInverseTensor = tensor2d(
leftInverse.toArray() as number[][],
leftInverse.size() as [number, number]
);
const result = leftInverseTensor.matMul(right);
const leftInverseTensor = tensor2d(
leftInverse.toArray() as number[][],
leftInverse.size() as [number, number]
);
return leftInverseTensor.matMul(right);
});

// Convert the result Matrix to math.Matrix
const result2D = result.arraySync() as number[][];
result.dispose();
return math.matrix(result2D);
};