diff --git a/scripts/datasets/general_nlp_benchmark/prepare_glue.py b/scripts/datasets/general_nlp_benchmark/prepare_glue.py
index 85d19e7a6c..e109e6bf5b 100644
--- a/scripts/datasets/general_nlp_benchmark/prepare_glue.py
+++ b/scripts/datasets/general_nlp_benchmark/prepare_glue.py
@@ -614,8 +614,9 @@ def main(args):
     if args.data_dir is None:
         args.data_dir = args.benchmark
     args.cache_path = os.path.join(args.cache_path, args.benchmark)
-    print('Downloading {} to {}. Selected tasks = {}'.format(args.benchmark,
-                                                             args.data_dir, args.tasks))
+    print('Downloading {} to "{}". Selected tasks = {}'.format(args.benchmark,
+                                                               args.data_dir,
+                                                               args.tasks))
     os.makedirs(args.cache_path, exist_ok=True)
     os.makedirs(args.data_dir, exist_ok=True)
     tasks = get_tasks(args.benchmark, args.tasks)
diff --git a/src/gluonnlp/layers.py b/src/gluonnlp/layers.py
index a2eff2bf4f..479137ce03 100644
--- a/src/gluonnlp/layers.py
+++ b/src/gluonnlp/layers.py
@@ -33,36 +33,42 @@
 @use_np
-def get_layer_norm(normalization: str = 'layer_norm',
+def get_norm_layer(normalization: str = 'layer_norm',
                    axis: int = -1,
                    epsilon: float = 1e-5,
                    in_channels: int = 0, **kwargs):
     """
-    Get the layer normalization based on the type
+    Get the normalization layer based on the type
 
     Parameters
     ----------
-    normalization: str, default: 'layer_norm'
-        The type of the layer normalization from ['layer_norm', 'no_norm']
+    normalization
+        The type of the layer normalization from ['layer_norm', 'no_norm', 'batch_norm']
     axis
         The axis to normalize the
     epsilon
+        The epsilon of the normalization layer
     in_channels
+        Input channel
 
     Returns
     -------
-    ln
+    norm_layer
         The layer normalization layer
     """
     if isinstance(normalization, str):
         if normalization == 'layer_norm':
-            ln = nn.LayerNorm(axis=axis, epsilon=epsilon, in_channels=in_channels,
-                              **kwargs)
+            norm_layer = nn.LayerNorm(axis=axis, epsilon=epsilon, in_channels=in_channels,
+                                      **kwargs)
         elif normalization == 'no_norm':
-            ln = NoNorm(in_channels=in_channels, **kwargs)
+            norm_layer = NoNorm(in_channels=in_channels, **kwargs)
+        elif normalization == 'identity':
+            norm_layer = IdentityActivation()
+        elif normalization == 'batch_norm':
+            norm_layer = nn.BatchNorm(axis=axis, epsilon=epsilon, in_channels=in_channels, **kwargs)
         else:
             raise NotImplementedError('normalization={} is not supported'.format(normalization))
-        return ln
+        return norm_layer
     else:
         raise NotImplementedError('The type of normalization must be str')
@@ -629,7 +635,7 @@ def __init__(self,
                               bias_initializer=bias_initializer,
                               dtype=dtype)
         # TODO(sxjscience) We may need to set the dtype flag in LayerNorm, need to double check
-        self.layer_norm = get_layer_norm(normalization=normalization,
+        self.layer_norm = get_norm_layer(normalization=normalization,
                                          in_channels=units,
                                          epsilon=layer_norm_eps)
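A minimal usage sketch of the extended `get_norm_layer` helper, mirroring the parametrized test added in `tests/test_layers.py` further below; the channel size and input shape are illustrative assumptions, not repository code:

```python
# Illustrative sketch: exercise every normalization type handled by get_norm_layer.
import mxnet as mx
from gluonnlp.layers import get_norm_layer

mx.npx.set_np()

for normalization in ['layer_norm', 'no_norm', 'identity', 'batch_norm']:
    norm_layer = get_norm_layer(normalization=normalization,
                                axis=-1, epsilon=1e-5, in_channels=16)
    norm_layer.initialize()
    out = norm_layer(mx.np.random.normal(0, 1, (8, 16)))
    print(normalization, out.shape)   # every variant preserves the (8, 16) shape
```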
diff --git a/src/gluonnlp/models/__init__.py b/src/gluonnlp/models/__init__.py
index 67a33c5f6d..3d8ba68237 100644
--- a/src/gluonnlp/models/__init__.py
+++ b/src/gluonnlp/models/__init__.py
@@ -4,6 +4,7 @@
 from . import albert
 from . import bert
 from . import electra
+from . import gpt2
 from . import mobilebert
 from . import roberta
 from . import transformer
diff --git a/src/gluonnlp/models/albert.py b/src/gluonnlp/models/albert.py
index 13bbc2458f..69b6e29a9e 100644
--- a/src/gluonnlp/models/albert.py
+++ b/src/gluonnlp/models/albert.py
@@ -336,10 +336,12 @@ def __init__(self,
                                        dtype=dtype)
         if embed_size != units:
             self.embed_factorized_proj = nn.Dense(units=units,
+                                                  in_units=embed_size,
                                                   flatten=False,
                                                   weight_initializer=weight_initializer,
                                                   bias_initializer=bias_initializer)
-        self.embed_layer_norm = nn.LayerNorm(epsilon=self.layer_norm_eps)
+        self.embed_layer_norm = nn.LayerNorm(epsilon=self.layer_norm_eps,
+                                             in_channels=embed_size)
         self.embed_dropout = nn.Dropout(hidden_dropout_prob)
         # Construct token type embedding
         self.token_type_embed = nn.Embedding(input_dim=num_token_types,
@@ -561,15 +563,18 @@ def __init__(self, backbone_cfg,
         self.mlm_decoder = nn.HybridSequential()
         # Extra non-linear layer
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.embed_size,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder.add(get_activation(self.backbone_model.activation))
-        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps))
+        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+                                          in_channels=self.backbone_model.embed_size))
         # only load the dense weights with a re-initialized bias
         # parameters are stored in 'word_embed_bias' which is
         # not used in original embedding
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.vocab_size,
+                                      in_units=self.backbone_model.embed_size,
                                       flatten=False,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight
@@ -649,19 +654,23 @@ def __init__(self, backbone_cfg,
         bias_initializer = self.backbone_model.bias_initializer
         # Construct sop_classifier for sentence order prediction
         self.sop_classifier = nn.Dense(units=2,
+                                       in_units=self.backbone_model.units,
                                        weight_initializer=weight_initializer)
         self.mlm_decoder = nn.HybridSequential()
         # Extra non-linear layer
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.embed_size,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder.add(get_activation(self.backbone_model.activation))
-        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps))
+        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+                                          in_channels=self.backbone_model.embed_size))
         # only load the dense weights with a re-initialized bias
         # parameters are stored in 'word_embed_bias' which is
         # not used in original embedding
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.vocab_size,
+                                      in_units=self.backbone_model.embed_size,
                                       flatten=False,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight
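The recurring change across the model files is passing `in_units`/`in_channels` explicitly. A standalone Gluon sketch (toy sizes, not taken from the repository) of what explicit input shapes buy over deferred shape inference:

```python
# Illustrative comparison of deferred vs. explicit shape initialization in Gluon.
import mxnet as mx
from mxnet.gluon import nn

mx.npx.set_np()

deferred = nn.Dense(units=8)               # in_units unknown until the first forward pass
explicit = nn.Dense(units=8, in_units=16)  # parameter shapes known at construction time

explicit.initialize()
print(explicit.weight.shape)               # (8, 16) immediately, no dummy forward needed

deferred.initialize()
print(deferred.weight.shape)               # (8, 0): still a placeholder
deferred(mx.np.ones((2, 16)))              # shape inference happens here
print(deferred.weight.shape)               # (8, 16) only after data has been seen
```

With the shapes fixed at construction time, parameters can be inspected, shared, and loaded before any data flows through the network.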
diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py
index 3d6a3329e8..522b8536a5 100644
--- a/src/gluonnlp/models/bart.py
+++ b/src/gluonnlp/models/bart.py
@@ -171,13 +171,16 @@ def __init__(self,
         if not extract_feature:
             if self.tie_weights:
                 self.tgt_final_layer = \
-                    nn.Dense(self._tgt_vocab_size, flatten=False,
+                    nn.Dense(units=self._tgt_vocab_size,
+                             in_units=self.dec_units,
+                             flatten=False,
                              use_bias=False,
                              dtype=self._dtype)
                 self.tgt_final_layer.weight = self.tgt_embed_layer.weight
             else:
                 self.tgt_final_layer = \
-                    nn.Dense(self._tgt_vocab_size,
+                    nn.Dense(units=self._tgt_vocab_size,
+                             in_units=self.dec_units,
                              flatten=False,
                              weight_initializer=self.weight_initializer,
                              use_bias=False,
diff --git a/src/gluonnlp/models/bert.py b/src/gluonnlp/models/bert.py
index 2bc57a7124..3600ea4c88 100644
--- a/src/gluonnlp/models/bert.py
+++ b/src/gluonnlp/models/bert.py
@@ -370,7 +370,8 @@ def __init__(self,
                                          output_dim=units,
                                          weight_initializer=embed_initializer,
                                          dtype=dtype)
-        self.embed_layer_norm = nn.LayerNorm(epsilon=self.layer_norm_eps)
+        self.embed_layer_norm = nn.LayerNorm(epsilon=self.layer_norm_eps,
+                                             in_channels=units)
         self.embed_dropout = nn.Dropout(hidden_dropout_prob)
         # Construct token type embedding
         self.token_type_embed = nn.Embedding(input_dim=num_token_types,
@@ -585,15 +586,18 @@ def __init__(self, backbone_cfg,
         self.mlm_decoder = nn.HybridSequential()
         # Extra non-linear layer
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.units,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder.add(get_activation(self.backbone_model.activation))
-        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps))
+        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+                                          in_channels=self.backbone_model.units))
         # only load the dense weights with a re-initialized bias
         # parameters are stored in 'word_embed_bias' which is
         # not used in original embedding
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.vocab_size,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight
@@ -674,19 +678,23 @@ def __init__(self, backbone_cfg,
         bias_initializer = self.backbone_model.bias_initializer
         # Construct nsp_classifier for next sentence prediction
         self.nsp_classifier = nn.Dense(units=2,
+                                       in_units=self.backbone_model.units,
                                        weight_initializer=weight_initializer)
         self.mlm_decoder = nn.HybridSequential()
         # Extra non-linear layer
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.units,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder.add(get_activation(self.backbone_model.activation))
-        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps))
+        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+                                          in_channels=self.backbone_model.units))
         # only load the dense weights with a re-initialized bias
         # parameters are stored in 'word_embed_bias' which is
         # not used in original embedding
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.vocab_size,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight
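The `mlm_decoder[-1].weight = ...word_embed.weight` assignments above tie the vocabulary projection to the input embedding while keeping a separately initialized bias. A rough standalone sketch of that pattern with made-up sizes:

```python
# Illustrative sketch of tying a vocabulary projection to an embedding matrix.
import mxnet as mx
from mxnet.gluon import nn

mx.npx.set_np()

vocab_size, embed_size = 100, 16
word_embed = nn.Embedding(input_dim=vocab_size, output_dim=embed_size)
decoder = nn.Dense(units=vocab_size, in_units=embed_size, flatten=False)

# Share the (vocab_size, embed_size) weight; the decoder keeps its own bias.
decoder.weight = word_embed.weight
word_embed.initialize()
decoder.bias.initialize()                  # only the non-shared parameter is left

tokens = mx.np.array([[1, 2, 3]], dtype='int32')
scores = decoder(word_embed(tokens))       # shape (1, 3, vocab_size)
print(scores.shape)
```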
diff --git a/src/gluonnlp/models/electra.py b/src/gluonnlp/models/electra.py
index bb26f37d15..cb7dfc61f0 100644
--- a/src/gluonnlp/models/electra.py
+++ b/src/gluonnlp/models/electra.py
@@ -29,7 +29,7 @@
            'ElectraForPretrain', 'list_pretrained_electra', 'get_pretrained_electra']
 
 import os
-from typing import Tuple, Optional
+from typing import Tuple, Optional, List
 
 import mxnet as mx
 import numpy as np
@@ -388,11 +388,13 @@ def __init__(self,
                                                    max_length=max_length,
                                                    dtype=self._dtype,
                                                    method=pos_embed_type)
-        self.embed_layer_norm = nn.LayerNorm(epsilon=self.layer_norm_eps)
+        self.embed_layer_norm = nn.LayerNorm(epsilon=self.layer_norm_eps,
+                                             in_channels=embed_size)
         self.embed_dropout = nn.Dropout(hidden_dropout_prob)
         if embed_size != units:
             self.embed_factorized_proj = nn.Dense(units=units,
+                                                  in_units=embed_size,
                                                   flatten=False,
                                                   weight_initializer=weight_initializer,
                                                   bias_initializer=bias_initializer)
@@ -509,7 +511,9 @@ def get_initial_embedding(self, F, inputs, token_types=None):
             embedding = self.embed_dropout(embedding)
         return embedding
 
-    def apply_layerwise_decay(self, layerwise_decay, not_included=None):
+    def apply_layerwise_decay(self, layerwise_decay: int,
+                              not_included: Optional[List[str]] = None,
+                              num_additional_layers: int = 2):
         """Apply the layer-wise gradient decay
 
         .. math::
@@ -517,17 +521,20 @@
 
         Parameters:
         ----------
-        layerwise_decay: int
-            layer-wise decay power
-        not_included: list of str
+        layerwise_decay
+            Power rate of the layer-wise decay
+        not_included
             A list or parameter names that not included in the layer-wise decay
+        num_additional_layers
+            The number of layers after the current backbone. This helps determine the max depth
         """
-        # consider the task specific finetuning layer as the last layer, following with pooler
-        # In addition, the embedding parameters have the smaller learning rate based on this setting.
-        max_depth = self.num_layers + 2
+        # Consider the task specific finetuning layer as the last layer, following with pooler
+        # In addition, the embedding parameters have the smaller learning rate based on this
+        # setting.
+        max_depth = self.num_layers + num_additional_layers
         for _, value in self.collect_params('.*embed*').items():
-            value.lr_mult = layerwise_decay**(max_depth)
+            value.lr_mult = layerwise_decay ** max_depth
 
         for (layer_depth, layer) in enumerate(self.encoder.all_encoder_layers):
             layer_params = layer.collect_params()
@@ -630,11 +637,13 @@ def __init__(self, backbone_cfg,
         self.rtd_encoder = nn.HybridSequential()
         # Extra non-linear layer
         self.rtd_encoder.add(nn.Dense(units=self.backbone_model.units,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer))
         self.rtd_encoder.add(get_activation(self.backbone_model.activation))
         self.rtd_encoder.add(nn.Dense(units=1,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer))
@@ -711,17 +720,20 @@ def __init__(self, backbone_cfg,
         self.mlm_decoder = nn.HybridSequential()
         # Extra non-linear layer
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.embed_size,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer))
         self.mlm_decoder.add(get_activation(self.backbone_model.activation))
-        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps))
+        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+                                          in_channels=self.backbone_model.embed_size))
         # only load the dense weights with a re-initialized bias
         # parameters are stored in 'word_embed_bias' which is
         # not used in original embedding
         self.mlm_decoder.add(
             nn.Dense(
                 units=self.backbone_model.vocab_size,
+                in_units=self.backbone_model.embed_size,
                 flatten=False,
                 bias_initializer=bias_initializer))
         self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight
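`apply_layerwise_decay` scales each parameter group's `lr_mult` so that layers close to the task head keep most of the base learning rate and the embedding receives the smallest multiplier. A generic standalone illustration of such a schedule (the exact exponent bookkeeping in the patch may differ by an offset):

```python
# Illustrative schedule for layer-wise learning-rate decay; numbers are made up.
layerwise_decay = 0.8
num_layers = 12
num_additional_layers = 2                  # task head + pooler, as in the docstring
max_depth = num_layers + num_additional_layers

embed_lr_mult = layerwise_decay ** max_depth
layer_lr_mults = [layerwise_decay ** (max_depth - depth)
                  for depth in range(1, num_layers + 1)]

print('embedding lr_mult      : %.4f' % embed_lr_mult)        # smallest multiplier
print('bottom encoder lr_mult : %.4f' % layer_lr_mults[0])
print('top encoder lr_mult    : %.4f' % layer_lr_mults[-1])    # closest to the task head
```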
diff --git a/src/gluonnlp/models/mobilebert.py b/src/gluonnlp/models/mobilebert.py
index 96ada137f3..55aa081cab 100644
--- a/src/gluonnlp/models/mobilebert.py
+++ b/src/gluonnlp/models/mobilebert.py
@@ -37,7 +37,7 @@
 from ..op import select_vectors_by_position
 from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir
-from ..layers import InitializerType, PositionwiseFFN, PositionalEmbedding, get_layer_norm, get_activation
+from ..layers import InitializerType, PositionwiseFFN, PositionalEmbedding, get_norm_layer, get_activation
 from ..initializer import TruncNorm
 from ..utils.config import CfgNode as CN
 from ..utils.misc import load_checksum_stats, download
@@ -177,7 +177,7 @@ def __init__(self,
                                                 weight_initializer=weight_initializer,
                                                 bias_initializer=bias_initializer,
                                                 dtype=self._dtype)
-            self.in_bottleneck_ln = get_layer_norm(normalization=normalization,
+            self.in_bottleneck_ln = get_norm_layer(normalization=normalization,
                                                    in_channels=real_units,
                                                    epsilon=layer_norm_eps)
             self.out_bottleneck_proj = nn.Dense(units=units,
@@ -186,7 +186,7 @@ def __init__(self,
                                                 weight_initializer=weight_initializer,
                                                 bias_initializer=bias_initializer,
                                                 dtype=self._dtype)
-            self.out_bottleneck_ln = get_layer_norm(normalization=normalization,
+            self.out_bottleneck_ln = get_norm_layer(normalization=normalization,
                                                     in_channels=units,
                                                     epsilon=layer_norm_eps)
 
@@ -197,7 +197,7 @@ def __init__(self,
                                                  weight_initializer=weight_initializer,
                                                  bias_initializer=bias_initializer,
                                                  dtype=self._dtype)
-                self.shared_qk_ln = get_layer_norm(normalization=normalization,
+                self.shared_qk_ln = get_norm_layer(normalization=normalization,
                                                    in_channels=real_units,
                                                    epsilon=layer_norm_eps)
         self.attention_proj = nn.Dense(units=real_units,
@@ -258,7 +258,7 @@ def __init__(self,
             dtype=self._dtype,
             layout=attention_layout
         )
-        self.layer_norm = get_layer_norm(normalization=normalization,
+        self.layer_norm = get_norm_layer(normalization=normalization,
                                          in_channels=real_units,
                                          epsilon=layer_norm_eps)
 
@@ -577,7 +577,7 @@ def __init__(self,
                                                   flatten=False,
                                                   weight_initializer=weight_initializer,
                                                   bias_initializer=bias_initializer)
-        self.embed_layer_norm = get_layer_norm(normalization=normalization,
+        self.embed_layer_norm = get_norm_layer(normalization=normalization,
                                                in_channels=units,
                                                epsilon=self.layer_norm_eps)
 
@@ -815,13 +815,15 @@ def __init__(self, backbone_cfg,
         self.mlm_decoder = nn.HybridSequential()
         # Extra non-linear layer
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.units,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer,
                                       dtype=self.backbone_model.dtype))
         self.mlm_decoder.add(get_activation(self.backbone_model.activation))
         # use basic layer normalization for pretaining
-        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps))
+        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+                                          in_channels=self.backbone_model.units))
         self.mlm_decoder.hybridize()
         # only load the dense weights with a re-initialized bias
         # parameters are stored in 'word_embed_bias' which is
@@ -918,18 +920,21 @@ def __init__(self, backbone_cfg,
         bias_initializer = self.backbone_model.bias_initializer
         # Construct nsp_classifier for next sentence prediction
         self.nsp_classifier = nn.Dense(units=2,
+                                       in_units=self.backbone_model.units,
                                        weight_initializer=weight_initializer,
                                        dtype=self.backbone_model.dtype)
         self.mlm_decoder = nn.HybridSequential()
         # Extra non-linear layer
         self.mlm_decoder.add(nn.Dense(units=self.backbone_model.units,
+                                      in_units=self.backbone_model.units,
                                       flatten=False,
                                       weight_initializer=weight_initializer,
                                       bias_initializer=bias_initializer,
                                       dtype=self.backbone_model.dtype))
         self.mlm_decoder.add(get_activation(self.backbone_model.activation))
         # use basic layer normalization for pretaining
-        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps))
+        self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps,
+                                          in_channels=self.backbone_model.units))
         self.mlm_decoder.hybridize()
         # only load the dense weights with a re-initialized bias
         # parameters are stored in 'word_embed_bias' which is
@@ -1019,8 +1024,8 @@ def list_pretrained_mobilebert():
 
 def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert',
                               root: str = get_model_zoo_home_dir(),
-                             load_backbone: str = True,
-                             load_mlm: str = False)\
+                              load_backbone: str = True,
+                              load_mlm: str = False)\
         -> Tuple[CN, HuggingFaceWordPieceTokenizer, str, str]:
     """Get the pretrained mobile bert weights
diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index bbb7605f18..7bbdae06af 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -1112,7 +1112,9 @@ def __init__(self, src_vocab_size: int,
                                           layout=layout)
         if tie_weights:
             self.tgt_final_layer = \
-                nn.Dense(tgt_vocab_size, flatten=False,
+                nn.Dense(units=tgt_vocab_size,
+                         flatten=False,
+                         in_units=self.dec_units,
                          bias_initializer=bias_initializer,
                          use_bias=False,
                          dtype=self._dtype)
@@ -1121,6 +1123,7 @@ def __init__(self, src_vocab_size: int,
             self.tgt_final_layer = \
                 nn.Dense(tgt_vocab_size,
                          flatten=False,
+                         in_units=self.dec_units,
                          weight_initializer=weight_initializer,
                          bias_initializer=bias_initializer,
                          use_bias=False,
diff --git a/src/gluonnlp/utils/misc.py b/src/gluonnlp/utils/misc.py
index 38d1fa6258..11bccbb259 100644
--- a/src/gluonnlp/utils/misc.py
+++ b/src/gluonnlp/utils/misc.py
@@ -273,7 +273,7 @@ def count_parameters(params) -> Tuple[int, int]:
     Parameters
     ----------
     params
-
+        The input parameter dict
 
     Returns
     -------
@@ -651,3 +651,24 @@ def init_comm(backend, gpus):
         logging.info('GPU communication supported by KVStore')
 
     return store, num_workers, rank, local_rank, is_master_node, ctx_l
+
+
+def get_mxnet_visible_ctx():
+    """Get the visible contexts in MXNet.
+
+    - If GPU is available, it will return all the visible GPUs, which can be controlled via
+      "CUDA_VISIBLE_DEVICES".
+    - If no GPU is available, it will return the cpu device.
+
+    Returns
+    -------
+    ctx_l
+        The recommended contexts to use for MXNet
+    """
+    import mxnet as mx
+    num_gpus = mx.context.num_gpus()
+    if num_gpus == 0:
+        ctx_l = [mx.cpu()]
+    else:
+        ctx_l = [mx.gpu(i) for i in range(num_gpus)]
+    return ctx_l
diff --git a/src/gluonnlp/utils/parameter.py b/src/gluonnlp/utils/parameter.py
index 2898933c93..7bff88ebe2 100644
--- a/src/gluonnlp/utils/parameter.py
+++ b/src/gluonnlp/utils/parameter.py
@@ -24,6 +24,7 @@
 import mxnet as mx
 from collections import defaultdict
 from mxnet.gluon import Parameter
+from mxnet.util import use_np
 from typing import Iterable, Optional, Tuple
 
@@ -152,3 +153,27 @@ def clip_grad_global_norm(parameters: Iterable[Parameter],
             for arr in p.list_grad():
                 arr *= scale
     return total_norm, ratio, is_finite
+
+
+@use_np
+def move_to_ctx(arr, ctx):
+    """Move a nested structure of array to the given context
+
+    Parameters
+    ----------
+    arr
+        The input array
+    ctx
+        The MXNet context
+
+    Returns
+    -------
+    new_arr
+        The array that has been moved to context
+    """
+    if isinstance(arr, tuple):
+        return tuple(move_to_ctx(ele, ctx) for ele in arr)
+    elif isinstance(arr, list):
+        return [move_to_ctx(ele, ctx) for ele in arr]
+    else:
+        return None if arr is None else arr.as_in_ctx(ctx)
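A sketch of how the two new utilities compose; the import paths follow this patch, while the nested batch structure is a made-up example:

```python
# Illustrative sketch: pick the visible devices, then move a nested batch onto one of them.
import mxnet as mx
from gluonnlp.utils.misc import get_mxnet_visible_ctx
from gluonnlp.utils.parameter import move_to_ctx

mx.npx.set_np()

ctx_l = get_mxnet_visible_ctx()            # [mx.gpu(0), ...] or [mx.cpu()]
batch = (mx.np.ones((4, 8)), [mx.np.zeros((4,)), None])

# Tuples/lists are walked recursively and None entries are passed through unchanged.
batch_on_ctx = move_to_ctx(batch, ctx_l[0])
print(batch_on_ctx[0].ctx, batch_on_ctx[1][0].ctx, batch_on_ctx[1][1])
```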
diff --git a/src/gluonnlp/utils/preprocessing.py b/src/gluonnlp/utils/preprocessing.py
index 25fd057298..4c194d010c 100644
--- a/src/gluonnlp/utils/preprocessing.py
+++ b/src/gluonnlp/utils/preprocessing.py
@@ -29,7 +29,7 @@ def get_trimmed_lengths(lengths: List[int],
     Returns
     -------
     trimmed_lengths
-        The trimmed lengths of the
+        The trimmed lengths of the sequences.
     """
     lengths = np.array(lengths)
     if do_merge:
diff --git a/tests/README.md b/tests/README.md
index a231fc928f..f17980cf48 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -24,6 +24,12 @@ To test both for cpu and gpu device, use the following command
 python3 -m pytest --device="cpu" --device="gpu" test_models_transformer.py
 ```
 
+In addition, to run all the tests, you should add the `--runslow` flag
+
+```bash
+python3 -m pytest --device="gpu" --runslow test_models.py
+```
+
 Refer to the [official guide of pytest](https://docs.pytest.org/en/latest/) for more details.
 
 # Naming Convention
diff --git a/tests/test_layers.py b/tests/test_layers.py
index e88f2c2167..77032a4b52 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -9,7 +9,8 @@
     BucketPositionalEmbedding, \
     AdaptiveEmbedding, \
     ProjectedAdaptiveLogSoftmaxWithLoss, \
-    get_activation
+    get_activation, \
+    get_norm_layer
 from gluonnlp.op import relative_position_bucket
 
 mx.npx.set_np()
@@ -247,3 +248,19 @@ def test_bucket_positional_embedding(units, num_buckets, bidirectional, max_dist
     out_of_bound_cnt = buckets[relative_positions > max_distance].sum()
     if out_of_bound_cnt.asnumpy() > 0:
         assert buckets[relative_positions > max_distance].std().asnumpy() == 0
+
+
+@pytest.mark.parametrize('normalization', ['layer_norm', 'no_norm', 'identity', 'batch_norm'])
+def test_get_norm_layer(normalization, ctx):
+    with ctx:
+        norm_layer = get_norm_layer(normalization=normalization,
+                                    in_channels=16)
+        net = mx.gluon.nn.HybridSequential()
+        net.add(mx.gluon.nn.Dense(16, in_units=16))
+        net.add(norm_layer)
+        net.add(mx.gluon.nn.Dense(16, in_units=16))
+        net.hybridize()
+        net.initialize()
+        data_in = mx.np.random.normal(0, 1, (8, 16))
+        out = net(data_in)
+        out_np = out.asnumpy()
diff --git a/tests/test_models.py b/tests/test_models.py
index a68e53f6c0..5df3701a5e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -36,7 +36,7 @@ def test_get_backbone(name, ctx):
             out = net(inputs, valid_length, inputs, valid_length)
         elif 'gpt2' in name:
             # Temporarily skip GPT-2 test
-            pass
+            return
         else:
             out = net(inputs, token_types, valid_length)
         mx.npx.waitall()
diff --git a/tests/test_utils_misc.py b/tests/test_utils_misc.py
index 83593cb3f6..7e3b692bcf 100644
--- a/tests/test_utils_misc.py
+++ b/tests/test_utils_misc.py
@@ -5,11 +5,13 @@
 import mxnet as mx
 import multiprocessing
 import functools
+from mxnet.util import use_np
 from mxnet.gluon import nn
 from pathlib import Path
 import numpy as np
 from numpy.testing import assert_allclose
-from gluonnlp.utils.misc import AverageSGDTracker, download, sha1sum, logging_config
+from gluonnlp.utils.misc import AverageSGDTracker, download, sha1sum, logging_config,\
+    get_mxnet_visible_ctx
 
 mx.npx.set_np()
@@ -151,3 +153,11 @@ def test_logging_config():
     assert file_size_test3 == file_size_test2
     assert file_size_foo2 == file_size_foo1
     assert file_size_zoo1 > 0
+
+
+@use_np
+def test_get_mxnet_visible_ctx(ctx):
+    ctx_l = get_mxnet_visible_ctx()
+    for ele_ctx in ctx_l:
+        arr = mx.np.array(1.0, ctx=ele_ctx)
+        arr_np = arr.asnumpy()
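The `--device` and `--runslow` options referenced in the README come from the test suite's pytest configuration, which is not part of this diff. The sketch below is only an assumption of how such options are commonly wired up; the option names follow the README, everything else is hypothetical:

```python
# Hypothetical conftest.py sketch; only the --device/--runslow names come from tests/README.md.
import mxnet as mx
import pytest


def pytest_addoption(parser):
    parser.addoption('--device', action='append', default=None,
                     help='Device(s) to run the tests on, e.g. "cpu" or "gpu".')
    parser.addoption('--runslow', action='store_true', default=False,
                     help='Also run tests marked as slow.')


def pytest_collection_modifyitems(config, items):
    if config.getoption('--runslow'):
        return                              # run everything, skip nothing
    skip_slow = pytest.mark.skip(reason='need --runslow option to run')
    for item in items:
        if 'slow' in item.keywords:
            item.add_marker(skip_slow)


def pytest_generate_tests(metafunc):
    # Parametrize every test that requests a `ctx` fixture over the chosen devices.
    if 'ctx' in metafunc.fixturenames:
        devices = metafunc.config.getoption('--device') or ['cpu']
        metafunc.parametrize('ctx', devices, indirect=True)


@pytest.fixture
def ctx(request):
    return mx.cpu() if request.param == 'cpu' else mx.gpu()
```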