diff --git a/src/nn/ops/binary.ts b/src/nn/ops/binary.ts index 01b67584..c9fac31d 100644 --- a/src/nn/ops/binary.ts +++ b/src/nn/ops/binary.ts @@ -31,11 +31,13 @@ export abstract class Binary extends SingleOutputOperation { const rankA = a.rank; const rankB = b.rank; if (rankA === 1 && rankB === 1) { - outputShape = []; - } else if (rankA === 2 && rankB === 1) { - outputShape = [a.shape[0], 1]; - } else if (rankA === 1 && rankB === 2) { - outputShape = [1, b.shape[1]]; + outputShape = []; // scalar + } else if (rankA >= 2 && rankB === 1) { + outputShape = a.shape.slice(); + outputShape[rankA - 1] = 1; + } else if (rankA === 1 && rankB >= 2) { + outputShape = b.shape.slice(); + outputShape[rankB - 2] = 1; } else if (rankA >= 2 && rankB >= 2) { outputShape = utils.getBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));