diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py
index ba1290b91..fdbe1411b 100644
--- a/maskrcnn_benchmark/structures/segmentation_mask.py
+++ b/maskrcnn_benchmark/structures/segmentation_mask.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-
+import numpy as np
+from torch.nn.functional import interpolate
 import pycocotools.mask as mask_utils
 
 # transpose
@@ -15,8 +16,36 @@ class Mask(object):
     a 2d tensor
     """
 
-    def __init__(self, masks, size, mode):
-        self.masks = masks
+    def __init__(self, segm, size, mode):
+        width, height = size
+        if isinstance(segm, Mask):
+            mask = segm.mask
+        else:
+            if type(segm) == list:
+                # polygons
+                mask = (
+                    Polygons(segm, size, "polygon")
+                    .convert("mask")
+                    .to(dtype=torch.float32)
+                )
+            elif type(segm) == dict and "counts" in segm:
+                if type(segm["counts"]) == list:
+                    # uncompressed RLE
+                    h, w = segm["size"]
+                    rle = mask_utils.frPyObjects(segm, h, w)
+                    mask = mask_utils.decode(rle)
+                    mask = torch.from_numpy(mask).to(dtype=torch.float32)
+                else:
+                    # compressed RLE
+                    mask = mask_utils.decode(segm)
+                    mask = torch.from_numpy(mask).to(dtype=torch.float32)
+            else:
+                # binary mask
+                if type(segm) == np.ndarray:
+                    mask = torch.from_numpy(segm).to(dtype=torch.float32)
+                else:  # torch.Tensor
+                    mask = segm.to(dtype=torch.float32)
+        self.mask = mask
         self.size = size
         self.mode = mode
 
@@ -28,24 +57,45 @@ def transpose(self, method):
 
         width, height = self.size
         if method == FLIP_LEFT_RIGHT:
-            dim = width
-            idx = 2
+            max_idx = width
+            dim = 1
         elif method == FLIP_TOP_BOTTOM:
-            dim = height
-            idx = 1
+            max_idx = height
+            dim = 0
 
-        flip_idx = list(range(dim)[::-1])
-        flipped_masks = self.masks.index_select(dim, flip_idx)
-        return Mask(flipped_masks, self.size, self.mode)
+        flip_idx = torch.tensor(list(range(max_idx)[::-1]))
+        flipped_mask = self.mask.index_select(dim, flip_idx)
+        return Mask(flipped_mask, self.size, self.mode)
 
     def crop(self, box):
-        w, h = box[2] - box[0], box[3] - box[1]
-
-        cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]]
-        return Mask(cropped_masks, size=(w, h), mode=self.mode)
+        box = [round(float(b)) for b in box]
+        w, h = box[2] - box[0] + 1, box[3] - box[1] + 1
+        w = max(w, 1)
+        h = max(h, 1)
+        cropped_mask = self.mask[box[1] : box[3], box[0] : box[2]]
+        return Mask(cropped_mask, size=(w, h), mode=self.mode)
 
     def resize(self, size, *args, **kwargs):
-        pass
+        width, height = size
+        scaled_mask = interpolate(
+            self.mask[None, None, :, :], (height, width), mode="bilinear"
+        )[0, 0]
+        return Mask(scaled_mask, size=size, mode=self.mode)
+
+    def convert(self, mode):
+        mask = self.mask.to(dtype=torch.uint8)
+        return mask
+
+    def __iter__(self):
+        return iter(self.mask)
+
+    def __repr__(self):
+        s = self.__class__.__name__ + "("
+        # s += "num_mask={}, ".format(len(self.mask))
+        s += "image_width={}, ".format(self.size[0])
+        s += "image_height={}, ".format(self.size[1])
+        s += "mode={})".format(self.mode)
+        return s
 
 
 class Polygons(object):
@@ -148,17 +198,26 @@ class SegmentationMask(object):
     This class stores the segmentations for all objects in the image
     """
 
-    def __init__(self, polygons, size, mode=None):
+    def __init__(self, segms, size, mode=None):
         """
         Arguments:
-            polygons: a list of list of lists of numbers. The first
+            segms: three types
+                (1) polygons: a list of list of lists of numbers. The first
                 level of the list correspond to individual instances,
                 the second level to all the polygons that compose the
                 object, and the third level to the polygon coordinates.
+                (2) rles: COCO's run length encoding format, uncompressed or compressed
+                (3) binary masks
+            size: (width, height)
+            mode: 'polygon', 'mask'. if mode is 'mask', convert mask of any format to binary mask
         """
-        assert isinstance(polygons, list)
-
-        self.polygons = [Polygons(p, size, mode) for p in polygons]
+        assert isinstance(segms, list)
+        if not isinstance(segms[0], (list, Polygons)):
+            mode = "mask"
+        if mode == "mask":
+            self.masks = [Mask(m, size, mode) for m in segms]
+        else:  # polygons
+            self.masks = [Polygons(p, size, mode) for p in segms]
         self.size = size
         self.mode = mode
 
@@ -169,21 +228,21 @@ def transpose(self, method):
             )
 
         flipped = []
-        for polygon in self.polygons:
-            flipped.append(polygon.transpose(method))
+        for mask in self.masks:
+            flipped.append(mask.transpose(method))
         return SegmentationMask(flipped, size=self.size, mode=self.mode)
 
     def crop(self, box):
         w, h = box[2] - box[0], box[3] - box[1]
         cropped = []
-        for polygon in self.polygons:
-            cropped.append(polygon.crop(box))
+        for mask in self.masks:
+            cropped.append(mask.crop(box))
         return SegmentationMask(cropped, size=(w, h), mode=self.mode)
 
     def resize(self, size, *args, **kwargs):
         scaled = []
-        for polygon in self.polygons:
-            scaled.append(polygon.resize(size, *args, **kwargs))
+        for mask in self.masks:
+            scaled.append(mask.resize(size, *args, **kwargs))
         return SegmentationMask(scaled, size=size, mode=self.mode)
 
     def to(self, *args, **kwargs):
@@ -191,24 +250,24 @@ def to(self, *args, **kwargs):
 
     def __getitem__(self, item):
         if isinstance(item, (int, slice)):
-            selected_polygons = [self.polygons[item]]
+            selected_masks = [self.masks[item]]
         else:
             # advanced indexing on a single dimension
-            selected_polygons = []
+            selected_masks = []
             if isinstance(item, torch.Tensor) and item.dtype == torch.uint8:
                 item = item.nonzero()
                 item = item.squeeze(1) if item.numel() > 0 else item
                 item = item.tolist()
             for i in item:
-                selected_polygons.append(self.polygons[i])
-        return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)
+                selected_masks.append(self.masks[i])
+        return SegmentationMask(selected_masks, size=self.size, mode=self.mode)
 
     def __iter__(self):
-        return iter(self.polygons)
+        return iter(self.masks)
 
     def __repr__(self):
         s = self.__class__.__name__ + "("
-        s += "num_instances={}, ".format(len(self.polygons))
+        s += "num_instances={}, ".format(len(self.masks))
         s += "image_width={}, ".format(self.size[0])
         s += "image_height={})".format(self.size[1])
         return s
diff --git a/tests/test_segmentation_mask.py b/tests/test_segmentation_mask.py
new file mode 100644
index 000000000..0c0a810a3
--- /dev/null
+++ b/tests/test_segmentation_mask.py
@@ -0,0 +1,54 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import torch
+import numpy as np
+import unittest
+from maskrcnn_benchmark.structures.segmentation_mask import Mask, Polygons, SegmentationMask
+
+
+class TestSegmentationMask(unittest.TestCase):
+    def __init__(self, method_name='runTest'):
+        super(TestSegmentationMask, self).__init__(method_name)
+        self.poly = [[423.0, 306.5, 406.5, 277.0, 400.0, 271.5, 389.5, 277.0, 387.5, 292.0,
+                    384.5, 295.0, 374.5, 220.0, 378.5, 210.0, 391.0, 200.5, 404.0, 199.5,
+                    414.0, 203.5, 425.5, 221.0, 438.5, 297.0, 423.0, 306.5],
+                   [385.5, 240.0, 404.0, 234.5, 419.5, 234.0, 416.5, 219.0, 409.0, 209.5,
+                    394.0, 207.5, 385.5, 213.0, 382.5, 221.0, 385.5, 240.0]]
+        self.width = 640
+        self.height = 480
+        self.size = (self.width, self.height)
+        self.box = [35, 55, 540, 400] # xyxy
+
+        self.polygon = Polygons(self.poly, self.size, 'polygon')
+        self.mask = Mask(self.poly, self.size, 'mask')
+
+    def test_crop(self):
+        poly_crop = self.polygon.crop(self.box)
+        mask_from_poly_crop = poly_crop.convert('mask')
+        mask_crop = self.mask.crop(self.box).convert('mask')
+
+        self.assertTrue(torch.equal(mask_from_poly_crop, mask_crop))
+    
+    def test_convert(self):
+        mask_from_poly_convert = self.polygon.convert('mask')
+        mask = self.mask.convert('mask')
+        self.assertTrue(torch.equal(mask_from_poly_convert, mask))
+
+    def test_transpose(self):
+        FLIP_LEFT_RIGHT = 0
+        FLIP_TOP_BOTTOM = 1
+        methods = (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM)
+        for method in methods:
+            mask_from_poly_flip = self.polygon.transpose(method).convert('mask')
+            mask_flip = self.mask.transpose(method).convert('mask')
+            print(method, torch.abs(mask_flip.float() - mask_from_poly_flip.float()).sum())
+            self.assertTrue(torch.equal(mask_flip, mask_from_poly_flip))
+    
+    def test_resize(self):
+        new_size = (600, 500)
+        mask_from_poly_resize = self.polygon.resize(new_size).convert('mask')
+        mask_resize = self.mask.resize(new_size).convert('mask')
+        print('diff resize: ', torch.abs(mask_from_poly_resize.float() - mask_resize.float()).sum())
+        self.assertTrue(torch.equal(mask_from_poly_resize, mask_resize))
+
+if __name__ == "__main__":
+    unittest.main()