Fix print precision and match numpy behavior (pytorch#12746)
Summary:
Fixes pytorch#12578 pytorch#9395.

* Fix and simplify print logic

* Follow the numpy print rule: https://github.com/numpy/numpy/blob/eb2bd11870731ea19a0eee72e616c7deb00f6c54/numpy/core/arrayprint.py#L859
> scientific notation is used when the absolute value of the smallest number is < 1e-4, the maximum is > 1e8, or the ratio of the maximum absolute value to the minimum is > 1e3
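
To make the rule concrete, here is a minimal sketch (a hypothetical `should_use_sci_mode` helper, not the actual implementation; the real logic lives in `_Formatter` in `torch/_tensor_str.py`, see the diff below) that applies the same thresholds to the nonzero finite values of a floating-point tensor:
```python
import torch

def should_use_sci_mode(t):
    # Hypothetical helper illustrating the numpy-style rule quoted above.
    # Only nonzero finite values participate in the decision.
    vals = t[torch.isfinite(t) & t.ne(0)].abs().double()
    if vals.numel() == 0:
        return False
    lo, hi = vals.min().item(), vals.max().item()
    return hi / lo > 1000. or hi > 1.e8 or lo < 1.e-4

print(should_use_sci_mode(torch.tensor([0.01, 10.])))  # False -> fixed-point print
print(should_use_sci_mode(torch.tensor([0.01, 11.])))  # True  -> scientific print
```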

I hope I didn't break anything, since there seem to be a lot of edge cases here... Here are some easy sanity checks.
```
In [5]: torch.tensor(1)
Out[5]: tensor(1)
Out[2]: array(1) # numpy

In [6]: torch.tensor(10)
Out[6]: tensor(10)
Out[3]: array(10) # numpy

In [8]: torch.tensor(99000000)
Out[8]: tensor(99000000)
Out[5]: array(99000000) # numpy

In [9]: torch.tensor(100000000)
Out[9]: tensor(100000000)
Out[6]: array(100000000) # numpy

In [10]: torch.tensor(100000001)
Out[10]: tensor(100000001)
Out[7]: array(100000001) # numpy

In [11]: torch.tensor(1000000000)
Out[11]: tensor(1000000000)
Out[8]: array(1000000000) # numpy

In [12]: torch.tensor([1, 1000])
Out[12]: tensor([   1, 1000])
Out[9]: array([   1, 1000]) # numpy

In [13]: torch.tensor([1, 1010])
Out[13]: tensor([   1, 1010])
Out[10]: array([   1, 1010]) # numpy
```
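As a rough illustration of the integer path (a hypothetical `int_strs` sketch, not the code in this PR): integer tensors never switch to scientific notation; each element is printed as a plain decimal and right-aligned to the widest element, which is where the leading spaces in `tensor([   1, 1000])` come from.
```python
import torch

def int_strs(t):
    # Hypothetical sketch: integers print as plain decimals, padded to the
    # width of the widest element (no scientific notation for ints).
    strs = ['{}'.format(v) for v in t.tolist()]
    width = max(len(s) for s in strs)
    return ['{:>{}}'.format(s, width) for s in strs]

print(int_strs(torch.tensor([1, 1000])))  # ['   1', '1000']
```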
For floating-point values, we use scientific notation when `max/min > 1000 || max > 1e8 || min < 1e-4`.
Lines marked "old" show previous behaviors that either had precision issues or did not match numpy.
```
In [14]: torch.tensor(0.01)
Out[14]: tensor(0.0100)
Out[11]: array(0.01) # numpy

In [15]: torch.tensor(0.1)
Out[15]: tensor(0.1000)
Out[12]: array(0.1) # numpy

In [16]: torch.tensor(0.0001)
Out[16]: tensor(0.0001)
Out[14]: array(0.0001) # numpy

In [17]: torch.tensor(0.00002)
Out[17]: tensor(2.0000e-05)
Out[15]: array(2e-05) # numpy
Out[5]: tensor(0.0000) # old

In [18]: torch.tensor(1e8)
Out[18]: tensor(100000000.)
Out[16]: array(100000000.0) # numpy

In [19]: torch.tensor(1.1e8)
Out[19]: tensor(1.1000e+08)
Out[17]: array(1.1e8) # numpy 1.14.5, In <= 1.13 this was not using scientific print
Out[10]: tensor(110000000.) # old

In [20]: torch.tensor([0.01, 10.])
Out[20]: tensor([ 0.0100, 10.0000])
Out[18]: array([  0.01,  10.  ]) # numpy

In [21]: torch.tensor([0.01, 11.])
Out[21]: tensor([1.0000e-02, 1.1000e+01])
Out[19]: array([  1.00000000e-02,   1.10000000e+01]) # numpy
Out[7]: tensor([ 0.0100, 11.0000]) # old
```
When printing floating-point numbers in int mode, we still apply the scientific-notation rules first:
```
In [22]: torch.tensor([1., 1000.])
Out[22]: tensor([   1., 1000.])
Out[20]: array([    1.,  1000.]) # numpy

In [23]: torch.tensor([1., 1010.])
Out[23]: tensor([1.0000e+00, 1.0100e+03])
Out[21]: array([  1.00000000e+00,   1.01000000e+03]) # numpy
Out[9]: tensor([   1., 1010.]) # old
```
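For reference, here is a rough sketch of how int mode and scientific mode interact for float tensors (finite values only, padding ignored; the hypothetical `float_strs` mirrors, but is not, the `_Formatter` changes in `torch/_tensor_str.py` below):
```python
import torch

def float_strs(t, precision=4):
    # Sketch only: nonfinite handling and column alignment are omitted.
    vals = t[torch.isfinite(t) & t.ne(0)].abs().double()
    if vals.numel() == 0:
        return [str(v) for v in t.tolist()]
    int_mode = bool((vals == vals.ceil()).all())
    sci_mode = bool(vals.max() / vals.min() > 1000. or vals.max() > 1.e8 or vals.min() < 1.e-4)
    if sci_mode:
        return [('{{:.{}e}}').format(precision).format(v) for v in t.tolist()]
    if int_mode:
        return ['{:.0f}.'.format(v) for v in t.tolist()]  # integral floats keep a trailing dot
    return [('{{:.{}f}}').format(precision).format(v) for v in t.tolist()]

print(float_strs(torch.tensor([1., 1000.])))  # ['1.', '1000.']
print(float_strs(torch.tensor([1., 1010.])))  # ['1.0000e+00', '1.0100e+03']
```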
Pull Request resolved: pytorch#12746

Differential Revision: D10443800

Pulled By: ailzhang

fbshipit-source-id: f5e4e3fe9bf0b44af2c64c93a9ed42b73fa613f5
Ailing Zhang authored and facebook-github-bot committed Oct 25, 2018
1 parent 3761adc commit 478886b
Showing 13 changed files with 120 additions and 74 deletions.
1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-bigint.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-default_device.expect

This file was deleted.

2 changes: 0 additions & 2 deletions test/expect/TestTorch.test_print-default_dtype.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-device.expect

This file was deleted.

2 changes: 0 additions & 2 deletions test/expect/TestTorch.test_print-dtype.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-negint.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-nonfinite.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-posint.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-requires_grad.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-scimode.expect

This file was deleted.

1 change: 0 additions & 1 deletion test/expect/TestTorch.test_print-summary.expect

This file was deleted.

94 changes: 82 additions & 12 deletions test/test_torch.py
@@ -8105,67 +8105,137 @@ def test_print(self):
# test big integer
x = torch.tensor(2341234123412341)
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='bigint')
self.assertExpectedInline(str(x), '''tensor(2341234123412341)''')

# test scientific notation
x = torch.tensor([1e28, 1e-28])
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='scimode')
self.assertExpectedInline(str(x), '''tensor([1.0000e+28, 1.0000e-28])''')

# test no leading space if all elements positive
x = torch.tensor([1, 2])
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='posint')
self.assertExpectedInline(str(x), '''tensor([1, 2])''')

# test for leading space if there are negative elements
x = torch.tensor([1, -2])
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='negint')
self.assertExpectedInline(str(x), '''tensor([ 1, -2])''')

# test inf and nan
x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1])
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='nonfinite')
self.assertExpectedInline(str(x), '''tensor([4.0000, inf, 1.5000, -inf, 0.0000, nan, 1.0000])''')

# test dtype
torch.set_default_dtype(torch.float)
x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='dtype')
expected_str = '''\
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
inf], dtype=torch.float64)'''
self.assertExpectedInline(str(x), expected_str)

# test changing default dtype
torch.set_default_dtype(torch.float64)
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='default_dtype')
expected_str = '''\
tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308,
inf])'''
self.assertExpectedInline(str(x), expected_str)

# test summary
x = torch.zeros(10000)
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='summary')
self.assertExpectedInline(str(x), '''tensor([0., 0., 0., ..., 0., 0., 0.])''')

# test device
if torch.cuda.is_available():
x = torch.tensor([123], device='cuda:0')
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='device')
self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''')

# test changing default to cuda
torch.set_default_tensor_type(torch.cuda.FloatTensor)
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='default_device')
self.assertExpectedInline(str(x), '''tensor([123])''')
torch.set_default_tensor_type(default_type)

# test integral floats and requires_grad
x = torch.tensor([123.], requires_grad=True)
self.assertEqual(x.__repr__(), str(x))
self.assertExpected(str(x), subname='requires_grad')
self.assertExpectedInline(str(x), '''tensor([123.], requires_grad=True)''')

# test non-contiguous print
# sliced tensor should have > PRINT_OPTS.threshold elements
x = torch.ones(100, 2, 2, 10)
y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
self.assertEqual(str(y), y.__repr__())
self.assertExpected(str(y), subname='non_contiguous')
expected_str = '''\
tensor([[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
...,
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]],
[[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]]])\
'''

self.assertExpectedInline(str(y), expected_str)

# test print 0-dim tensor: there's no 0-dim in Numpy, we match arrayprint style
x = torch.tensor(0.00002)
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor(2.0000e-05)''')

# [Numpy] test print float in sci_mode when min < 0.0001.
x = torch.tensor([0.00002])
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([2.0000e-05])''')

# [Numpy] test print float in sci_mode when max > 1e8.
# TODO: Pytorch uses fixed precision to print, while Numpy uses dragon4_scientific
# to do automatic trimming and padding.
x = torch.tensor([123456789.])
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([1.2346e+08])''')

# [Numpy] test print float in sci_mode when max / min > 1000.
x = torch.tensor([0.01, 11])
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([1.0000e-02, 1.1000e+01])''')

# [Numpy] test print int max / min > 1000, no sci_mode
x = torch.tensor([1, 1010])
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([ 1, 1010])''')

# [Numpy] test print int > 1e8, no sci_mode
x = torch.tensor([1000000000]) # 1e9
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([1000000000])''')

# [Numpy] test printing float in int_mode
x = torch.tensor([1., 1000.])
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([ 1., 1000.])''')

# [Numpy] test printing float in int_mode in sci format when max / min > 1000.
x = torch.tensor([1., 1010.])
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''')

def test_sizeof(self):
sizeof_empty = torch.randn(0).storage().__sizeof__()
87 changes: 38 additions & 49 deletions torch/_tensor_str.py
@@ -27,7 +27,7 @@ def set_printoptions(
Args:
precision: Number of digits of precision for floating point output
(default = 8).
(default = 4).
threshold: Total number of array elements which trigger summarization
rather than full `repr` (default = 1000).
edgeitems: Number of array items in summary at beginning and end of
@@ -72,65 +72,54 @@ def __init__(self, tensor):
self.sci_mode = False
self.max_width = 1

with torch.no_grad():
tensor_view = tensor.reshape(-1)

if not self.floating_dtype:
copy = torch.empty(tensor.size(), dtype=torch.long).copy_(tensor).view(tensor.nelement())
for value in copy.tolist():
for value in tensor_view:
value_str = '{}'.format(value)
self.max_width = max(self.max_width, len(value_str))

else:
copy = torch.empty(tensor.size(), dtype=torch.float64).copy_(tensor).view(tensor.nelement())
copy_list = copy.tolist()
try:
for value in copy_list:
if value != math.ceil(value):
self.int_mode = False
break
# nonfinites will throw errors
except (ValueError, OverflowError):
self.int_mode = False
nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))

if nonzero_finite_vals.numel() == 0:
# no valid number, do nothing
return

# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
nonzero_finite_abs = nonzero_finite_vals.abs().double()
nonzero_finite_min = nonzero_finite_abs.min().double()
nonzero_finite_max = nonzero_finite_abs.max().double()

for value in nonzero_finite_vals:
if value != torch.ceil(value):
self.int_mode = False
break

if self.int_mode:
for value in copy_list:
value_str = '{:.0f}'.format(value)
if math.isnan(value) or math.isinf(value):
# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
# to indicate that the tensor is of floating type. add 1 to the len to account for this.
if nonzero_finite_max / nonzero_finite_min > 1000. or nonzero_finite_max > 1.e8:
self.sci_mode = True
for value in nonzero_finite_vals:
value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value)
self.max_width = max(self.max_width, len(value_str))
else:
# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
# to indicate that the tensor is of floating type. add 1 to the len to account for this.
else:
for value in nonzero_finite_vals:
value_str = ('{:.0f}').format(value)
self.max_width = max(self.max_width, len(value_str) + 1)

else:
copy_abs = copy.abs()
pos_inf_mask = copy_abs.eq(inf)
neg_inf_mask = copy_abs.eq(-inf)
nan_mask = copy_abs.ne(copy)
invalid_value_mask = pos_inf_mask + neg_inf_mask + nan_mask
if invalid_value_mask.all():
example_value = 0
else:
example_value = copy_abs[invalid_value_mask.eq(0)][0]
copy_abs[invalid_value_mask] = example_value

exp_min = copy_abs.min()
if exp_min != 0:
exp_min = math.floor(math.log10(exp_min)) + 1
else:
exp_min = 1
exp_max = copy_abs.max()
if exp_max != 0:
exp_max = math.floor(math.log10(exp_max)) + 1
else:
exp_max = 1

# these conditions for using scientific notation are based on numpy
if exp_max - exp_min > PRINT_OPTS.precision or exp_max > 8 or exp_min < -4:
# Check if scientific representation should be used.
if nonzero_finite_max / nonzero_finite_min > 1000.\
or nonzero_finite_max > 1.e8\
or nonzero_finite_min < 1.e-4:
self.sci_mode = True
for value in copy_list:
for value in nonzero_finite_vals:
value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value)
self.max_width = max(self.max_width, len(value_str))
else:
for value in copy_list:
for value in nonzero_finite_vals:
value_str = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value)
self.max_width = max(self.max_width, len(value_str))

@@ -139,12 +128,12 @@ def width(self):

def format(self, value):
if self.floating_dtype:
if self.int_mode:
if self.sci_mode:
ret = ('{{:{}.{}e}}').format(self.max_width, PRINT_OPTS.precision).format(value)
elif self.int_mode:
ret = '{:.0f}'.format(value)
if not (math.isinf(value) or math.isnan(value)):
ret += '.'
elif self.sci_mode:
ret = ('{{:{}.{}e}}').format(self.max_width, PRINT_OPTS.precision).format(value)
else:
ret = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value)
else:
