Skip to content
This repository has been archived by the owner on Oct 22, 2024. It is now read-only.

Commit

Permalink
Merge pull request #189 from BruceDai/fix_matmul_outputshape
Browse files Browse the repository at this point in the history
Fixed computed output shape of matmul op.
  • Loading branch information
BruceDai authored Nov 24, 2022
2 parents 7f42089 + b46181d commit c011e73
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/nn/ops/binary.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ export abstract class Binary extends SingleOutputOperation {
const rankA = a.rank;
const rankB = b.rank;
if (rankA === 1 && rankB === 1) {
outputShape = [];
} else if (rankA === 2 && rankB === 1) {
outputShape = [a.shape[0], 1];
} else if (rankA === 1 && rankB === 2) {
outputShape = [1, b.shape[1]];
outputShape = []; // scalar
} else if (rankA >= 2 && rankB === 1) {
outputShape = a.shape.slice();
outputShape[rankA - 1] = 1;
} else if (rankA === 1 && rankB >= 2) {
outputShape = b.shape.slice();
outputShape[rankB - 2] = 1;
} else if (rankA >= 2 && rankB >= 2) {
outputShape = utils.getBroadcastShape(a.shape.slice(0, -2),
b.shape.slice(0, -2));
Expand Down

0 comments on commit c011e73

Please sign in to comment.