From 6727133f3d2695263450e1408908874ccd5a7555 Mon Sep 17 00:00:00 2001 From: David Riazati Date: Wed, 24 Oct 2018 16:36:56 -0700 Subject: [PATCH] Support warnings.warn (#12964) Summary: `warnings.warn` is used commonly thoughout `nn.functional`, so this adds support for it by forwarding its arguments to `print` Pull Request resolved: https://github.com/pytorch/pytorch/pull/12964 Differential Revision: D10559427 Pulled By: driazati fbshipit-source-id: 5b591f6f446c906418f9fc7730c17e301f263d9b --- test/expect/TestJit.test_warnings.expect | 15 +++++++++++++++ test/test_jit.py | 11 +++++++++++ torch/csrc/jit/script/builtin_functions.cpp | 6 ++++++ torch/jit/__init__.py | 1 + 4 files changed, 33 insertions(+) create mode 100644 test/expect/TestJit.test_warnings.expect diff --git a/test/expect/TestJit.test_warnings.expect b/test/expect/TestJit.test_warnings.expect new file mode 100644 index 0000000000000..4828cbdae6b8f --- /dev/null +++ b/test/expect/TestJit.test_warnings.expect @@ -0,0 +1,15 @@ +graph(%x : Dynamic) { + %1 : string = prim::Constant[value="x is less than 2"]() + %2 : int = prim::Constant[value=2]() + %3 : Dynamic = aten::lt(%x, %2) + %4 : bool = prim::TensorToBool(%3) + = prim::If(%4) + block0() { + = prim::Print(%1) + -> () + } + block1() { + -> () + } + return (%x); +} diff --git a/test/test_jit.py b/test/test_jit.py index 5641b5aa00695..453f907c0b94b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2033,6 +2033,17 @@ def forward(self, input, other=four): t = Test() self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4) + def test_warnings(self): + import warnings + + @torch.jit.script + def fn(x): + if bool(x < 2): + warnings.warn("x is less than 2") + return x + + self.assertExpectedGraph(fn.graph) + class TestBatched(TestCase): # generate random examples and create an batchtensor with them diff --git a/torch/csrc/jit/script/builtin_functions.cpp b/torch/csrc/jit/script/builtin_functions.cpp index ea82d06879d7c..129e9fca2a213 100644 --- a/torch/csrc/jit/script/builtin_functions.cpp +++ b/torch/csrc/jit/script/builtin_functions.cpp @@ -28,6 +28,11 @@ def div(a : ${Scalar}, b : Tensor) -> Tensor: return torch.reciprocal(b) * a )SCRIPT"); +auto python_builtins_source = R"SCRIPT( +def warn(string: str): + print(string) +)SCRIPT"; + struct BuiltinFunctionRegistry { const std::vector& getAllBuiltinFunctionsFor(Symbol name) { @@ -68,6 +73,7 @@ struct BuiltinFunctionRegistry { env.s("Scalar", scalar); loadSource(scalar_operators_source.format(env)); } + loadSource(python_builtins_source); } enum {UNINITIALIZED, INTIIALIZING, INITIALIZED} state = UNINITIALIZED; std::recursive_mutex mutex; diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 6387139d4421d..6bb32192e79b9 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1321,6 +1321,7 @@ def register_all(mod): _builtin_table[id(v)] = "aten::" + name for mod in _modules_containing_builtins: register_all(mod) + _builtin_table[id(warnings.warn)] = "aten::warn" return _builtin_table