Skip to content
This repository has been archived by the owner on Oct 22, 2024. It is now read-only.

Commit

Permalink
Fixed computed output shape of matmul op.
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Nov 23, 2022
1 parent 7f42089 commit fe6aff2
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/nn/ops/binary.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ export abstract class Binary extends SingleOutputOperation {
const rankA = a.rank;
const rankB = b.rank;
if (rankA === 1 && rankB === 1) {
outputShape = [];
} else if (rankA === 2 && rankB === 1) {
outputShape = [a.shape[0], 1];
} else if (rankA === 1 && rankB === 2) {
outputShape = [1, b.shape[1]];
outputShape = []; // scalar
} else if (rankA >= 2 && rankB === 1) {
outputShape = [...a.shape.slice(0, -1), 1];
} else if (rankA === 1 && rankB >= 2) {
outputShape = [...b.shape.slice(0, -2), 1, b.shape[1]];
} else if (rankA >= 2 && rankB >= 2) {
outputShape = utils.getBroadcastShape(a.shape.slice(0, -2),
b.shape.slice(0, -2));
Expand Down

0 comments on commit fe6aff2

Please sign in to comment.