[Relay] Fix `CombineParallelDense` slicing axis (apache#13597)
The current implementation of `CombineParallelDense` is hardcoded to slice along the last axis after the combined dense. I hit an error using this pass on the stable diffusion UNet, since it has a combined group where the dense is followed by `expand_dims`, which changes the slicing axis (see https://github.com/masahi/torchscript-to-tvm/blob/master/stable-diffusion/compile.py for a repro):

```
%76 = concatenate(%74) /* ty=Tensor[(20160, 1280), float32] */;
%79 = concatenate(%77) /* ty=Tensor[(20160), float32] */;
%78 = nn.dense(%75, %76, units=20160) /* ty=Tensor[(2, 20160), float32] */;
%80 = nn.bias_add(%78, %79, axis=-1) /* ty=Tensor[(2, 20160), float32] */;
%81 = expand_dims(%80, axis=2) /* ty=Tensor[(2, 20160, 1), float32] */;
%82 = expand_dims(%81, axis=3) /* ty=Tensor[(2, 20160, 1, 1), float32] */;
```

The correct way to generate the `strided_slice` here is to slice along axis 1 (the output-channel axis) rather than the last axis:

```
%84 = strided_slice(%82, begin=[0, 0, 0, 0], end=[-1, 320, -1, -1], strides=[1, 1, 1, 1], slice_mode="size", axes=None) /* ty=Tensor[(2, 320, 1, 1), float32] */;
```

As documented in the code, this fix is probably not 100% fail-proof. This is a difficult problem, since it requires tracking how the original output-channel axis of the combined dense moves across shape-changing operations like `reshape` / `transpose` / `split`. But this is at least "more correct" than the current implementation, so I'm submitting the fix as-is for now.

With this fix, `CombineParallelDense` runs successfully on the stable diffusion UNet and reduces the number of `nn.dense` ops from 184 to 100.
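For anyone who wants to poke at the change, here is a minimal Python sketch (not the actual UNet; the shapes are made up for illustration) that reproduces the dense-followed-by-`expand_dims` pattern and runs the pass with `to_batch=False`, which takes the concatenate + `strided_slice` path shown above:

```python
import tvm
from tvm import relay

# Two parallel dense branches off the same input, each followed by
# expand_dims -- the pattern that tripped up the old last-axis slicing.
x = relay.var("x", shape=(2, 1280))
w1 = relay.var("w1", shape=(320, 1280))
w2 = relay.var("w2", shape=(320, 1280))

def branch(w):
    y = relay.nn.dense(x, w)             # (2, 320)
    y = relay.expand_dims(y, axis=2)     # (2, 320, 1)
    return relay.expand_dims(y, axis=3)  # (2, 320, 1, 1)

func = relay.Function([x, w1, w2], relay.Tuple([branch(w1), branch(w2)]))
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)

# to_batch=False combines the branches into a single dense followed by
# strided_slice; after the fix the slice runs along axis 1, not axis 3.
mod = relay.transform.CombineParallelDense(min_num_branches=2, to_batch=False)(mod)
print(mod)
```

With the default `to_batch=True`, the branches would instead be combined into a `batch_matmul`, which sidesteps the slicing path exercised here.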