Fixing mypy errors #3335

Merged
merged 2 commits into from
Feb 1, 2021
2 changes: 0 additions & 2 deletions torchvision/datasets/semeion.py
@@ -41,8 +41,6 @@ def __init__(
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

- self.data = []
- self.labels = []
fp = os.path.join(self.root, self.filename)
data = np.loadtxt(fp)
# convert value to 8 bit unsigned integer
6 changes: 4 additions & 2 deletions torchvision/datasets/stl10.py
@@ -67,7 +67,7 @@ def __init__(
'You can use download=True to download it')

# now load the picked numpy arrays
- self.labels: np.ndarray
+ self.labels: Optional[np.ndarray]
if self.split == 'train':
self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0])
@@ -182,4 +182,6 @@ def __load_folds(self, folds: Optional[int]) -> None:
with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
- self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
+ self.data = self.data[list_idx, :, :, :]
+ if self.labels is not None:
Member

Can self.labels actually be None in this code path?

If it can, then mypy properly caught a potential bug and we should probably add a test for it.

If it can't, mypy is unfortunately forcing us to obfuscate the code (as it often does...)

Contributor Author

Yes, it can be None. See:

def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    labels = None

@pmeier: Why are there no tests for the entire STL10 class? I think this is something that could be addressed in a separate PR; this one focuses on fixing the failures on master.
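
To make the quoted signature concrete, here is a minimal sketch of a loader whose second return value is Optional. The function name `load_file`, the file names, and the fake list contents are assumptions for illustration only; plain lists stand in for the NumPy arrays used in torchvision.

```python
from typing import List, Optional, Tuple


def load_file(
    data_file: str, labels_file: Optional[str] = None
) -> Tuple[List[int], Optional[List[int]]]:
    """Hypothetical stand-in for STL10.__loadfile; returns fake values, reads no files."""
    labels = None
    if labels_file is not None:
        # Labels exist only when a labels file was passed in.
        labels = [0, 1, 2]
    data = [10, 20, 30]
    return data, labels


# Labelled split: both values are present.
data, labels = load_file("train_X.bin", "train_y.bin")

# Unlabelled split: the second value is None, which is why the
# attribute holding it has to be annotated as Optional.
unlabeled_data, unlabeled_labels = load_file("unlabeled_X.bin")
```

Because the return type is Optional for every caller, mypy treats any attribute assigned from it as possibly None, regardless of which call site produced the value.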

Collaborator

@datumbox There are no tests for a lot of datasets, see #963 (comment). I'm on it.

Member

@NicolasHug Feb 1, 2021

Unfortunately, the type annotation of the function doesn't tell us much. It tells us that self.labels can be None in general, but it doesn't tell us whether it can be None in that particular code path.

As far as I can tell, __load_folds is only called right after __loadfile is called with labels_file, so self.labels can never be None within __load_folds.
Unfortunately, this isn't something that mypy can figure out, but I feel like the fix hurts the readability of the code.

Perhaps we should tell mypy to ignore this line instead?
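
As a rough sketch of what ignoring the line could look like: the class `FoldSelector` below is a hypothetical, simplified stand-in for the STL10 fold logic, with plain lists in place of NumPy arrays. Only the trailing `# type: ignore` comment is the point of the example.

```python
from typing import List, Optional


class FoldSelector:
    """Hypothetical, simplified stand-in for the fold logic in STL10."""

    def __init__(self, data: List[int], labels: Optional[List[int]]) -> None:
        self.data = data
        self.labels = labels

    def load_fold(self, idx: List[int]) -> None:
        self.data = [self.data[i] for i in idx]
        # mypy rejects indexing an Optional value; the trailing comment
        # silences the checker for this one line and changes nothing at runtime.
        self.labels = [self.labels[i] for i in idx]  # type: ignore


selector = FoldSelector(data=[10, 20, 30], labels=[0, 1, 2])
selector.load_fold([2, 0])
```

The trade-off is that the comment hides every type error on that line, not just the Optional one, and the code still raises a TypeError at runtime if labels ever is None.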

Contributor Author

@datumbox Feb 1, 2021

I agree that, based on the current code, self.labels always seems to have a value. It's not clear whether the intention is to also support None values, because other parts of the class check whether self.labels is None:

if self.labels is not None:
    img, target = self.data[index], int(self.labels[index])
else:
    img, target = self.data[index], None

In any case, there is no point in assessing the removal of this check, because without it the type check continues to fail.

Perhaps in the future it's worth refactoring the class to simplify the logic and improve readability. Would you like to create an issue with your points?

Member

Other parts of the class seem to also check whether self.labels is None

It makes sense to check in these parts because self.labels can indeed be None there (although I haven't double-checked).

But checking for None in __load_folds is misleading IMHO:

  • it suggests that labels can be None even though it can't
  • it doesn't say what to do if labels is None, i.e. there's no else clause.
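
One option that is not raised in this thread, sketched here only for comparison, is an assert. It states that labels cannot be None at this point, fails loudly if that assumption is ever broken, and also narrows the type for mypy. The helper `select_fold` and its arguments are hypothetical, and plain lists stand in for NumPy arrays.

```python
from typing import List, Optional, Tuple


def select_fold(
    data: List[int], labels: Optional[List[int]], idx: List[int]
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper mirroring what __load_folds does to data and labels."""
    # States the invariant for readers and narrows Optional[...] for mypy.
    assert labels is not None, "folds are only defined for labelled splits"
    return [data[i] for i in idx], [labels[i] for i in idx]


fold_data, fold_labels = select_fold([10, 20, 30], [0, 1, 2], [1, 2])
```

Note that assert statements are skipped when Python runs with the -O flag, so this documents the invariant rather than enforcing it in every deployment.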

Contributor Author

@NicolasHug Sounds good. The scope of this PR is not to do massive rewrites of the dataset class but rather to fix the failing master and facilitate FBcode syncing. Please create an issue with your proposals, describing which parts you think should be rewritten or simplified, so that we can discuss it in more detail.

Member

I'm not suggesting that we refactor the code; I'm suggesting that we tell mypy to ignore this line.

+     self.labels = self.labels[list_idx]
4 changes: 2 additions & 2 deletions torchvision/datasets/usps.py
@@ -57,8 +57,8 @@ def __init__(
import bz2
with bz2.open(full_path) as fp:
raw_data = [line.decode().split() for line in fp.readlines()]
- imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
- imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
+ tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
Member

This is a typical instance of perfectly fine code that mypy flags as unsafe, which leads to obfuscated code.

Maybe we can just avoid the tmp variable and directly declare

imgs = np.asarray(
	[...],
	dtype=...
)

Contributor Author

Mypy flags the code because imgs is originally defined as a list of lists and then becomes a NumPy array.

I'm not sure that inlining the expression increases readability or makes the code less obfuscated. Line 61 already does the casting and reshaping; should we add the splitting and iterating to it as well?

+ imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]
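
The two options discussed for this hunk can be compared in a small sketch. It assumes two made-up rows in the "label index:value" layout the parsing code implies, and uses plain lists of floats in place of the NumPy array so that it runs without third-party packages; the names `pixels` and `imgs_inline` are illustrative only.

```python
from typing import List

# Two made-up rows: a label followed by index:value pairs.
raw_data = [["1", "1:-1.0", "2:0.0"], ["2", "1:1.0", "2:-1.0"]]

# Option taken in the PR: one name per type, so no variable changes type.
tmp_list: List[List[str]] = [[x.split(":")[-1] for x in row[1:]] for row in raw_data]
imgs: List[List[float]] = [[float(v) for v in row] for row in tmp_list]

# Option suggested in review: build the final value in a single expression,
# which avoids the temporary name at the cost of a denser line.
imgs_inline: List[List[float]] = [
    [float(x.split(":")[-1]) for x in row[1:]] for row in raw_data
]

# Same rescaling as the diff: map values in [-1, 1] to integers in [0, 255].
pixels = [[int((v + 1) / 2 * 255) for v in row] for row in imgs]
targets = [int(row[0]) - 1 for row in raw_data]
```

Both options produce the same values; the disagreement in the thread is only about which form reads better once mypy forbids reusing one name for two types.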
