Added a default constructor for torch.finfo.
Summary: Pull Request resolved: pytorch#12847

Differential Revision: D10457487

Pulled By: benoitsteiner

fbshipit-source-id: 7d164a71ba52631e5906098f643eecb0630879d1
benoitsteiner authored and facebook-github-bot committed Oct 23, 2018
1 parent 1b07eb7 commit 3fb3a07
Showing 3 changed files with 44 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/source/type_info.rst
@@ -28,6 +28,9 @@ max       float The largest representable number.
 tiny      float The smallest positive representable number.
 ========= ===== ========================================
 
+.. note::
+    The constructor of :class:`torch.finfo` can be called without an argument, in which case the object is created for the PyTorch default dtype (as returned by ``torch.get_default_dtype()``).
+
 
 .. _iinfo-doc:
 
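For context, here is a minimal usage sketch of the constructor behavior the note above documents (a hypothetical session; the printed values assume float32 is the default dtype, PyTorch's out-of-the-box setting):

import torch

# Existing behavior: query properties of an explicit floating point dtype.
info = torch.finfo(torch.float32)
print(info.bits)  # 32
print(info.eps)   # 1.1920928955078125e-07

# New behavior: calling finfo() with no argument uses the default dtype,
# equivalent to torch.finfo(torch.get_default_dtype()).
default_info = torch.finfo()
assert default_info.eps == torch.finfo(torch.get_default_dtype()).eps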
6 changes: 5 additions & 1 deletion test/test_type_info.py
@@ -30,6 +30,7 @@ def test_iinfo(self):
 
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_finfo(self):
+        initial_default_type = torch.get_default_dtype()
         for dtype in [torch.float32, torch.float64]:
             x = torch.zeros((2, 2), dtype=dtype)
             xinfo = torch.finfo(x.dtype)
@@ -39,7 +40,10 @@ def test_finfo(self):
             self.assertEqual(xinfo.max, xninfo.max)
             self.assertEqual(xinfo.eps, xninfo.eps)
             self.assertEqual(xinfo.tiny, xninfo.tiny)
-
+            torch.set_default_dtype(dtype)
+            self.assertEqual(torch.finfo(dtype), torch.finfo())
+        # Restore the default dtype to ensure that the test has no side effects
+        torch.set_default_dtype(initial_default_type)
 
 if __name__ == '__main__':
     run_tests()
45 changes: 36 additions & 9 deletions torch/csrc/TypeInfo.cpp
@@ -50,16 +50,25 @@ PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
   HANDLE_TH_ERRORS
   static torch::PythonArgParser parser({
       "finfo(ScalarType type)",
+      "finfo()",
   });
 
   torch::ParsedArgs<1> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
-  AT_CHECK(r.idx == 0, "Not a type");
-  at::ScalarType scalar_type = r.scalartype(0);
-  if (!at::isFloatingType(scalar_type)) {
-    return PyErr_Format(
-        PyExc_TypeError,
-        "torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",
-        type->tp_name);
+  AT_CHECK(r.idx < 2, "Not a type");
+  at::ScalarType scalar_type;
+  if (r.idx == 1) {
+    scalar_type = torch::tensors::get_default_tensor_type().scalarType();
+    // The default tensor type can only be set to a floating point type.
+    AT_ASSERT(at::isFloatingType(scalar_type));
+  } else {
+    scalar_type = r.scalartype(0);
+    if (!at::isFloatingType(scalar_type)) {
+      return PyErr_Format(
+          PyExc_TypeError,
+          "torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",
+          type->tp_name);
+    }
+  }
   return THPFInfo_New(scalar_type);
   END_HANDLE_TH_ERRORS
@@ -85,6 +94,24 @@ PyObject* THPIInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPDTypeInfo_compare(THPDTypeInfo* a, THPDTypeInfo* b, int op) {
+  switch (op) {
+    case Py_EQ:
+      if (a->type == b->type) {
+        Py_RETURN_TRUE;
+      } else {
+        Py_RETURN_FALSE;
+      }
+    case Py_NE:
+      if (a->type != b->type) {
+        Py_RETURN_TRUE;
+      } else {
+        Py_RETURN_FALSE;
+      }
+  }
+  return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
+}
+
 static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
   int bits = elementSize(self->type) * 8;
   return THPUtils_packInt64(bits);
@@ -153,7 +180,7 @@ PyTypeObject THPFInfoType = {
     nullptr, /* tp_doc */
     0, /* tp_traverse */
     0, /* tp_clear */
-    0, /* tp_richcompare */
+    (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
     0, /* tp_weaklistoffset */
     0, /* tp_iter */
     0, /* tp_iternext */
@@ -202,7 +229,7 @@ PyTypeObject THPIInfoType = {
     nullptr, /* tp_doc */
     0, /* tp_traverse */
     0, /* tp_clear */
-    0, /* tp_richcompare */
+    (richcmpfunc)THPDTypeInfo_compare, /* tp_richcompare */
     0, /* tp_weaklistoffset */
     0, /* tp_iter */
     0, /* tp_iternext */
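The THPDTypeInfo_compare hook wired into tp_richcompare above is what lets the updated test assert torch.finfo(dtype) == torch.finfo(): info objects now compare equal exactly when they describe the same dtype. A small illustrative sketch of the resulting Python-level behavior:

import torch

# Equality and inequality compare the underlying dtype.
assert torch.finfo(torch.float32) == torch.finfo(torch.float32)
assert torch.finfo(torch.float32) != torch.finfo(torch.float64)
assert torch.iinfo(torch.int32) != torch.iinfo(torch.int64)

# Other comparison operators return NotImplemented at the C level,
# so Python raises TypeError for orderings such as <.
try:
    torch.finfo(torch.float32) < torch.finfo(torch.float64)
except TypeError:
    pass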
