
topk regression in v1.5 #15703

Closed
leezu opened this issue Jul 30, 2019 · 18 comments · Fixed by #15948

Comments

@leezu
Contributor

leezu commented Jul 30, 2019

Description

https://github.com/dmlc/gluon-nlp/blob/v0.7.1/scripts/word_embeddings/evaluate_pretrained.py stopped working with MXNet v1.5 due to out-of-memory errors. While the script runs fine with 16 GB of GPU memory (p3.2xlarge) on MXNet v1.4, it runs out of GPU memory even with 32 GB of GPU memory (p3dn.24xlarge) on MXNet v1.5.

Environment info (Required)

----------Python Info----------
Version      : 3.7.3
Compiler     : GCC 5.4.0 20160609
Build        : ('default', 'Jun 13 2019 13:24:27')
Arch         : ('64bit', 'ELF')
------------Pip Info-----------
Version      : 19.0.3
Directory    : /home/ubuntu/.pyenv/versions/3.7.3/lib/python3.7/site-packages/pip
----------MXNet Info-----------
Version      : 1.5.0
Directory    : /home/ubuntu/.local/lib/python3.7/site-packages/mxnet
Commit Hash   : 75a9e187d00a8b7ebc71412a02ed0e3ae489d91f
Library      : ['/home/ubuntu/.local/lib/python3.7/site-packages/mxnet/libmxnet.so']
Build features:
✔ CUDA
✔ CUDNN
✔ NCCL
✔ CUDA_RTC
✖ TENSORRT
✔ CPU_SSE
✔ CPU_SSE2
✔ CPU_SSE3
✔ CPU_SSE4_1
✔ CPU_SSE4_2
✖ CPU_SSE4A
✔ CPU_AVX
✖ CPU_AVX2
✖ OPENMP
✖ SSE
✔ F16C
✖ JEMALLOC
✖ BLAS_OPEN
✖ BLAS_ATLAS
✖ BLAS_MKL
✖ BLAS_APPLE
✔ LAPACK
✖ MKLDNN
✔ OPENCV
✖ CAFFE
✖ PROFILER
✔ DIST_KVSTORE
✖ CXX14
✖ INT64_TENSOR_SIZE
✔ SIGNAL_HANDLER
✖ DEBUG
----------System Info----------
Platform     : Linux-4.4.0-1088-aws-x86_64-with-debian-stretch-sid
system       : Linux
node         : ip-172-31-31-153
release      : 4.4.0-1088-aws
version      : #99-Ubuntu SMP Thu Jul 4 14:25:53 UTC 2019
----------Hardware Info----------
machine      : x86_64
processor    : x86_64
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                4
On-line CPU(s) list:   0-3
Thread(s) per core:    2
Core(s) per socket:    2
Socket(s):             1
NUMA node(s):          1
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 79
Model name:            Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping:              1
CPU MHz:               2699.535
CPU max MHz:           3000.0000
CPU min MHz:           1200.0000
BogoMIPS:              4600.08
Hypervisor vendor:     Xen
Virtualization type:   full
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              46080K
NUMA node0 CPU(s):     0-3
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single kaiser fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0023 sec, LOAD: 0.4837 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.1444 sec, LOAD: 0.3984 sec.
Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.2908 sec, LOAD: 0.4350 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0080 sec, LOAD: 0.1284 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0034 sec, LOAD: 0.2064 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0067 sec, LOAD: 0.0290 sec.

Minimum reproducible example

import argparse
import mxnet as mx


class ExampleBlock(mx.gluon.HybridBlock):
    def __init__(self, idx_to_vec, **kwargs):
        super().__init__(**kwargs)

        self.k = 1
        self.eps = 1E-10

        self._vocab_size, self._embed_size = idx_to_vec.shape

        idx_to_vec = mx.nd.L2Normalization(idx_to_vec, eps=self.eps)
        with self.name_scope():
            self.weight = self.params.get_constant('weight', idx_to_vec)

    def hybrid_forward(self, F, words1, words2, words3, weight):  # pylint: disable=arguments-differ
        words123 = F.concat(words1, words2, words3, dim=0)
        embeddings_words123 = F.Embedding(words123, weight, input_dim=self._vocab_size,
                                          output_dim=self._embed_size)
        similarities = F.FullyConnected(embeddings_words123, weight, no_bias=True,
                                        num_hidden=self._vocab_size, flatten=False)
        # Map cosine similarities to [0, 1]
        similarities = (similarities + 1) / 2

        sim_w1w4, sim_w2w4, sim_w3w4 = F.split(similarities, num_outputs=3, axis=0)

        sim = (sim_w2w4 * sim_w3w4) / (sim_w1w4 + self.eps)

        for words in [words1, words2, words3]:
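            # Note: on_value=0, off_value=1 builds a mask that zeroes out the
            # similarity of the query words themselves, so they cannot be
            # returned by the subsequent topk.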
            sim = sim * F.one_hot(words, self.weight.shape[0], 0, 1)

        pred_idxs = F.topk(sim, k=self.k)
        return pred_idxs


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=1024)
    args = parser.parse_args()

    ctx = mx.gpu(0)
    idx_to_vec = mx.nd.zeros(shape=(111066, 300))
    block = ExampleBlock(idx_to_vec)
    block.initialize(ctx=ctx)
    block.hybridize()
    words = [mx.nd.zeros((args.batch_size, ), ctx=ctx) for i in range(3)]
    block(*words)
    mx.nd.waitall()

Alternatively

  • git clone https://github.com/dmlc/gluon-nlp
  • cd gluon-nlp/
  • git checkout v0.7.1
  • python3 ./scripts/word_embeddings/evaluate_pretrained.py --embedding-name fasttext --embedding-source wiki.simple --gpu 0 --similarity-datasets --eval-batch-size 1024
@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Performance

@vrakesh
Contributor

vrakesh commented Jul 30, 2019

@leezu Thank you for reporting this
@mxnet-label-bot add [Performance]

@pengzhao-intel
Contributor

@eric-haibin-lin @TaoLv

@TaoLv
Member

TaoLv commented Aug 7, 2019

I went through the nightly builds and commits between the 1.4.0 and 1.5.0 releases and found that the regression is most likely caused by #14570.

@apeforest could you help to double check?

@TaoLv
Member

TaoLv commented Aug 7, 2019

@leezu Please let us know if we need to fix this issue in the 1.5.1 patch release. Thanks.

@szha
Member

szha commented Aug 7, 2019

I thought that feature would not be enabled if the compilation flag isn't turned on. What causes this in builds without that flag?

@TaoLv
Member

TaoLv commented Aug 8, 2019

I'm not sure. I located the issue at this commit by bisecting. The compilation flag was not set when I built MXNet from source.
@apeforest do you have any idea?

@leezu
Contributor Author

leezu commented Aug 14, 2019

@TaoLv the issue is not blocking me, but I find it concerning that #14570, which should have had no side effects when the compilation flag is unset, introduces this issue. It should be fixed in a 1.5.x point release.
Pinging @apeforest

@apeforest
Contributor

Sorry, I just saw this. Looking into it now.

@apeforest
Contributor

Interestingly, I found that turning on the USE_INT64_TENSOR_SIZE flag (meaning using int64_t instead of int32_t as the index_t type) solves the OOM issue. Still root-causing it.

@apeforest
Contributor

Further narrowed it down to the topk operator. The implementation of TopKImpl does not allocate the correct amount of GPU memory. Working on a PR now.

@apeforest
Contributor

Root cause found:

It is due to this line: https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/ordering_op-inl.h#L434

mshadow::Shape is constructed using index_t, which is int32_t by default in MXNet 1.5. In this case, the required workspace size is 3184736511, which exceeds 2^31 and hence causes an integer overflow.
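
As a minimal illustration (not MXNet code) of the overflow, the workspace size above does not fit into a signed 32-bit integer and wraps to a negative value:

workspace_size = 3184736511            # workspace size reported above
wrapped = workspace_size & 0xFFFFFFFF  # keep the low 32 bits
if wrapped >= 2**31:                   # reinterpret as two's-complement int32
    wrapped -= 2**32
print(wrapped)                         # -1110230785, a nonsensical negative size
print(workspace_size > 2**31 - 1)      # True: exceeds the int32_t maximum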

Workaround: turn on the USE_INT64_TENSOR_SIZE compiler flag

Possible Fix:

  1. Turn on the USE_INT64_TENSOR_SIZE flag by default in 1.6.
  2. Change the constructor of mshadow::Shape to always use int64_t.

Lin

@apeforest
Contributor

@leezu Based on the analysis above, this is not really a memory usage regression but a bug due to integer overflow. The memory space required by the topk operator in your script is 2729810175, which exceeds 2^31 (the maximum of int32_t). It did not overflow in MXNet 1.4 because int64_t was used by default as the type for index_t. Therefore, this is another case where large integer support is needed in MXNet. Given that we plan to turn on the USE_INT64_TENSOR_SIZE flag by default in MXNet 1.6, could you use the workaround of turning on the compiler flag manually and building MXNet from source? Please let me know if this solution is acceptable before the MXNet 1.6 release. Thanks!
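
For anyone applying the workaround, a quick way to verify the rebuilt library is to query the runtime feature flags (the same mechanism that produced the build-features list in the environment info above). A minimal sketch, assuming MXNet >= 1.5 where mxnet.runtime.Features is available:

from mxnet.runtime import Features

# Should print True only for a build compiled with USE_INT64_TENSOR_SIZE
# enabled, i.e. index_t is int64_t and the topk workspace size no longer
# overflows.
print(Features().is_enabled('INT64_TENSOR_SIZE'))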

@leezu
Contributor Author

leezu commented Aug 20, 2019

Thank you for diving deep to find the root cause! I'm not blocked if the fix has to wait for MXNet 1.6, but we may want to ask @TaoLv as release manager. If the fix has to wait for 1.6, the regression should be documented, for example in the "known issues" section of the release notes: https://github.com/apache/incubator-mxnet/releases

What do you mean by "partially fixes" in #15948? Could that fix in principle be cherry-picked for 1.5.x?

@TaoLv
Member

TaoLv commented Aug 20, 2019

@apeforest Thank you for the analysis. What's the blocker to get this issue fixed on the v1.5.x branch?

@apeforest
Contributor

apeforest commented Aug 20, 2019

@TaoLv This is not a bug per se but a limitation of the int32_t data type we use in MXNet. As I pointed out, at the line https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/ordering_op-inl.h#L434 the workspace is created using a 1D mshadow::Shape object, whose length is bounded by index_t, which is int32_t by default. When the required workspace size is larger than 2^31, it overflows, causing the OOM.

@leezu #15948 is a partial fix because it only fixes the memory misalignment, not the OOM caused by integer overflow. To really fix this issue, we need to support int64_t in MXNet by default.

leezu changed the title from "Storage manager / memory usage regression in v1.5" to "topk regression in v1.5" on Aug 21, 2019
@leezu
Contributor Author

leezu commented Aug 21, 2019

To really fix this issue, we need to support int64_t in mxnet by default.

@apeforest It seems int32_t was not used for the topk operator prior to v1.5. Thus, by changing the default for topk, v1.5 introduced this regression. This needs to be documented, given that most users will not be aware of the work done on the USE_INT64_TENSOR_SIZE flag or why it introduced int32_t for topk.

@apeforest
Contributor

@leezu int64_t was used by default in v1.4. However, we identified performance regressions (#14496, #14790) due to the change of data type and therefore introduced a compiler flag to switch the default data type back to int32_t in v1.5. We are working on performance optimization with int64_t and plan to make it the default again in v1.6.
