Skip to content

Commit

Permalink
Merge pull request #87 from jakartaresearch/dry
Browse files Browse the repository at this point in the history
refactor return exists
  • Loading branch information
adhiiisetiawan authored Oct 17, 2021
2 parents 3a467c1 + f695e54 commit 31f33e8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 70 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ __pycache__
settings.json
.DS_Store
train
env-earthvision
40 changes: 16 additions & 24 deletions earthvision/datasets/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ class EuroSat():

mirrors = "http://madm.dfki.de/files/sentinel"
resources = "EuroSAT.zip"
classes = {"AnnualCrop": 0, \
"Forest": 1, \
"HerbaceousVegetation": 2, \
"Highway": 3, \
"Industrial": 4, \
"Pasture": 5, \
"PermanentCrop": 6, \
"Residential": 7, \
"River": 8, \
"SeaLake": 9}


def __init__(self,
root: str,
Expand All @@ -31,7 +42,7 @@ def __init__(self,
self.data_mode = data_mode
self.transform = transform
self.target_transform = target_transform

if not self._check_exists():
self.download()
self.extract_file()
Expand Down Expand Up @@ -62,18 +73,9 @@ def __iter__(self):
def _check_exists(self) -> None:
self.data_path = os.path.join(
self.root, self.data_mode)

return os.path.exists(os.path.join(self.data_path, "AnnualCrop")) and \
os.path.exists(os.path.join(self.data_path, "Forest")) and \
os.path.exists(os.path.join(self.data_path, "HerbaceousVegetation")) and \
os.path.exists(os.path.join(self.data_path, "Highway")) and \
os.path.exists(os.path.join(self.data_path, "Industrial")) and \
os.path.exists(os.path.join(self.data_path, "Pasture")) and \
os.path.exists(os.path.join(self.data_path, "PermanentCrop")) and \
os.path.exists(os.path.join(self.data_path, "Residential")) and \
os.path.exists(os.path.join(self.data_path, "River")) and \
os.path.exists(os.path.join(self.data_path, "SeaLake"))

self.dir_classes = list(self.classes.keys())

return all([os.path.exists(os.path.join(self.data_path, i)) for i in self.dir_classes])

def download(self):
"""Download file"""
Expand All @@ -87,19 +89,9 @@ def extract_file(self):

def get_path_and_label(self):
"""Return dataframe type consist of image path and corresponding label."""
classes = {"AnnualCrop": 0, \
"Forest": 1, \
"HerbaceousVegetation": 2, \
"Highway": 3, \
"Industrial": 4, \
"Pasture": 5, \
"PermanentCrop": 6, \
"Residential": 7, \
"River": 8, \
"SeaLake": 9}
image_path = []
label = []
for cat, enc in classes.items():
for cat, enc in self.classes.items():
cat_path = os.path.join(
self.root, self.data_mode, cat)
cat_image = [os.path.join(cat_path, path)
Expand Down
70 changes: 24 additions & 46 deletions earthvision/datasets/ucmercedland.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,27 @@ class UCMercedLand(Dataset):

mirrors = "http://weegee.vision.ucmerced.edu/datasets/"
resources = "UCMerced_LandUse.zip"
classes = {'agricultural': 0,
'airplane': 1,
'baseballdiamond': 2,
'beach': 3,
'buildings': 4,
'chaparral': 5,
'denseresidential': 6,
'forest': 7,
'freeway': 8,
'golfcourse': 9,
'harbor': 10,
'intersection': 11,
'mediumresidential': 12,
'mobilehomepark': 13,
'overpass': 14,
'parkinglot': 15,
'river': 16,
'runway': 17,
'sparseresidential': 18,
'storagetanks': 19,
'tenniscourt': 20}

def __init__(self,
root: str,
Expand Down Expand Up @@ -57,30 +78,9 @@ def __iter__(self):

def get_path_and_label(self):
"""Return dataframe type consist of image path and corresponding label."""
classes = {'agricultural': 0,
'airplane': 1,
'baseballdiamond': 2,
'beach': 3,
'buildings': 4,
'chaparral': 5,
'denseresidential': 6,
'forest': 7,
'freeway': 8,
'golfcourse': 9,
'harbor': 10,
'intersection': 11,
'mediumresidential': 12,
'mobilehomepark': 13,
'overpass': 14,
'parkinglot': 15,
'river': 16,
'runway': 17,
'sparseresidential': 18,
'storagetanks': 19,
'tenniscourt': 20}
image_path = []
label = []
for cat, enc in classes.items():
for cat, enc in self.classes.items():
cat_path = os.path.join(
self.root, 'UCMerced_LandUse', self.data_mode, cat)
cat_image = [os.path.join(cat_path, path)
Expand All @@ -95,28 +95,9 @@ def get_path_and_label(self):
def _check_exists(self):
self.data_path = os.path.join(
self.root, "UCMerced_LandUse", "Images")
self.dir_classes = list(self.classes.keys())

return os.path.exists(os.path.join(self.data_path, "agricultural")) and \
os.path.exists(os.path.join(self.data_path, "airplane")) and \
os.path.exists(os.path.join(self.data_path, "baseballdiamond")) and \
os.path.exists(os.path.join(self.data_path, "beach")) and \
os.path.exists(os.path.join(self.data_path, "buildings")) and \
os.path.exists(os.path.join(self.data_path, "chaparral")) and \
os.path.exists(os.path.join(self.data_path, "denseresidential")) and \
os.path.exists(os.path.join(self.data_path, "forest")) and \
os.path.exists(os.path.join(self.data_path, "freeway")) and \
os.path.exists(os.path.join(self.data_path, "golfcourse")) and \
os.path.exists(os.path.join(self.data_path, "harbor")) and \
os.path.exists(os.path.join(self.data_path, "intersection")) and \
os.path.exists(os.path.join(self.data_path, "mediumresidential")) and \
os.path.exists(os.path.join(self.data_path, "mobilehomepark")) and \
os.path.exists(os.path.join(self.data_path, "overpass")) and \
os.path.exists(os.path.join(self.data_path, "parkinglot")) and \
os.path.exists(os.path.join(self.data_path, "river")) and \
os.path.exists(os.path.join(self.data_path, "runway")) and \
os.path.exists(os.path.join(self.data_path, "sparseresidential")) and \
os.path.exists(os.path.join(self.data_path, "storagetanks")) and \
os.path.exists(os.path.join(self.data_path, "tenniscourt"))
return all([os.path.exists(os.path.join(self.data_path, i)) for i in self.dir_classes])

def download(self):
"""download and extract file."""
Expand All @@ -125,9 +106,6 @@ def download(self):

def extract_file(self):
"""Extract file from compressed."""
# path_destination = os.path.join(
# self.root, self.resources.replace(".zip", ""))
# os.makedirs(path_destination, exist_ok=True)
shutil.unpack_archive(os.path.join(
self.root, self.resources), self.root)
os.remove(os.path.join(self.root, self.resources))

0 comments on commit 31f33e8

Please sign in to comment.