Support warnings.warn (pytorch#12964)
Summary:
`warnings.warn` is used commonly throughout `nn.functional`, so this adds
support for it by forwarding its arguments to `print`.
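
Concretely, a scripted function can now call `warnings.warn` directly. A minimal usage sketch (mirroring the new test below; the function name and tensor value are illustrative):

```python
import warnings

import torch

@torch.jit.script
def fn(x):
    if bool(x < 2):
        # Inside TorchScript, warnings.warn is mapped to the aten::warn
        # builtin, which forwards the message to print.
        warnings.warn("x is less than 2")
    return x

fn(torch.tensor(1.0))  # prints: x is less than 2
```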
Pull Request resolved: pytorch#12964

Differential Revision: D10559427

Pulled By: driazati

fbshipit-source-id: 5b591f6f446c906418f9fc7730c17e301f263d9b
David Riazati authored and facebook-github-bot committed Oct 24, 2018
1 parent b790fca commit 6727133
Showing 4 changed files with 33 additions and 0 deletions.
15 changes: 15 additions & 0 deletions test/expect/TestJit.test_warnings.expect
@@ -0,0 +1,15 @@
graph(%x : Dynamic) {
  %1 : string = prim::Constant[value="x is less than 2"]()
  %2 : int = prim::Constant[value=2]()
  %3 : Dynamic = aten::lt(%x, %2)
  %4 : bool = prim::TensorToBool(%3)
   = prim::If(%4)
    block0() {
       = prim::Print(%1)
      -> ()
    }
    block1() {
      -> ()
    }
  return (%x);
}
11 changes: 11 additions & 0 deletions test/test_jit.py
@@ -2033,6 +2033,17 @@ def forward(self, input, other=four):
        t = Test()
        self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)

    def test_warnings(self):
        import warnings

        @torch.jit.script
        def fn(x):
            if bool(x < 2):
                warnings.warn("x is less than 2")
            return x

        self.assertExpectedGraph(fn.graph)


class TestBatched(TestCase):
    # generate random examples and create an batchtensor with them
6 changes: 6 additions & 0 deletions torch/csrc/jit/script/builtin_functions.cpp
@@ -28,6 +28,11 @@ def div(a : ${Scalar}, b : Tensor) -> Tensor:
  return torch.reciprocal(b) * a
)SCRIPT");

auto python_builtins_source = R"SCRIPT(
def warn(string: str):
  print(string)
)SCRIPT";

struct BuiltinFunctionRegistry {

  const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
@@ -68,6 +73,7 @@ struct BuiltinFunctionRegistry {
      env.s("Scalar", scalar);
      loadSource(scalar_operators_source.format(env));
    }
    loadSource(python_builtins_source);
  }
  enum {UNINITIALIZED, INTIIALIZING, INITIALIZED} state = UNINITIALIZED;
  std::recursive_mutex mutex;
1 change: 1 addition & 0 deletions torch/jit/__init__.py
@@ -1321,6 +1321,7 @@ def register_all(mod):
                _builtin_table[id(v)] = "aten::" + name
    for mod in _modules_containing_builtins:
        register_all(mod)
    _builtin_table[id(warnings.warn)] = "aten::warn"

    return _builtin_table

