I can craft adversarial attacks on my model using PGD, but AutoLiRPA is failing. Why? #75

aknirala commented Jul 10, 2024

Update: I found that the issue arises even when I simply use YOLO's PyTorch model. It is likely because the model outputs a nested list/tuple of tensors. I tried wrapping it so that only the first tensor is returned, but that failed too.

from ultralytics import YOLO
import torch

device = "cuda:0"
yolo_model = YOLO("yolov8n-seg.pt").to(device)
yolo_model.model.eval()
data = torch.randn([2, 3, 640, 640]).to(device)
from auto_LiRPA import BoundedModule
lirpa_model = BoundedModule(yolo_model.model, data)  #I get error here itself.

And the error is:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[8], line 3
      1 data = torch.randn([2, 3, 640, 640]).to(device)
      2 from auto_LiRPA import BoundedModule
----> 3 lirpa_model = BoundedModule(yolo_model.model, data)  #I get error here itself.

File /usr/local/lib/python3.10/dist-packages/auto_LiRPA/bound_general.py:128, in BoundedModule.__init__(self, model, global_input, bound_opts, device, verbose, custom_ops)
    125 object.__setattr__(self, 'ori_state_dict', state_dict_copy)
    126 model.to(self.device)
    127 self.final_shape = model(
--> 128     *unpack_inputs(global_input, device=self.device)).shape
    129 self.bound_opts.update({'final_shape': self.final_shape})
    130 self._convert(model, global_input)

AttributeError: 'tuple' object has no attribute 'shape'
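
For reference, the wrapper I tried was along these lines (a minimal sketch; the class name and the recursive unpacking are my own):

import torch.nn as nn

class FirstTensorWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        # Descend into nested tuples/lists until a plain tensor remains,
        # since BoundedModule calls .shape on the model's output.
        while isinstance(out, (tuple, list)):
            out = out[0]
        return out

lirpa_model = BoundedModule(FirstTensorWrapper(yolo_model.model), data)  # this still failed for me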

========== Previous Post ==========
I am trying to estimate the robustness of object detectors, and I thought of using YOLOv8 for that. I needed to modify the code so that gradients can propagate properly, and with that I am able to attack YOLOv8 using PGD. However, when I use auto_LiRPA it gives an error.
I am wondering what the high-level reason for the error is. Any pointers/help is highly appreciated.
The error comes from the non_max_suppression (i.e., post-processing) part.
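
To illustrate what I suspect the high-level reason is (a toy example of my own, not ultralytics code): the confidence filter inside non_max_suppression, x = x[xc[xi]], is boolean-mask indexing, so the output shape depends on the tensor's values, and the JIT tracing that auto_LiRPA performs (see parse_module in the stack trace below) cannot capture that as a static graph.

import torch

def confidence_filter(x, conf_thres=0.25):
    # Same pattern as `x = x[xc[xi]]` in ops.non_max_suppression:
    # the output's first dimension depends on the values in x.
    return x[x[:, 4] > conf_thres]

high = torch.rand(10, 6)              # confidences spread over [0, 1)
low = torch.rand(10, 6) * 0.1         # all confidences below 0.25
print(confidence_filter(high).shape)  # e.g. torch.Size([6, 6])
print(confidence_filter(low).shape)   # torch.Size([0, 6]) -- a different shape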

Here is my code (there is also a Colab notebook to reproduce the issue: here):

from ultralytics import YOLO
import torch
import torch.nn as nn
from ultralytics.engine.results import Results
from ultralytics.utils import ops

import matplotlib.pyplot as plt
#from torch.utils.data import DataLoader
from torchvision import transforms

import numpy as np
import cv2


"""
All the functions are modified from: https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py
to ensure that inplace update not happens. These are helper functions.
"""
def clip_boxes(boxes, shape):
    """
    Takes a set of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.

    Args:
        boxes (torch.Tensor | numpy.ndarray): the bounding boxes to clip
        shape (tuple): the shape of the image

    Returns:
        (torch.Tensor | numpy.ndarray): Clipped boxes
    """
    if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: in-place .clamp_() Apple MPS bug)
        clamped_boxes = boxes.clone()
        clamped_boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
        clamped_boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
        clamped_boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
        clamped_boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        clamped_boxes = boxes.copy()  # numpy arrays have .copy(), not .clone()
        clamped_boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        clamped_boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
    return clamped_boxes


def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
    """
    Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
    specified in (img1_shape) to the shape of a different image (img0_shape).
    AKN: img1_shape: [1280, 1280]; boxes: 1x4
         img0_shape: (1280, 1280, 3)
    Args:
        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
        img0_shape (tuple): the shape of the target image, in the format of (height, width).
        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
            calculated based on the size difference between the two images.
        padding (bool): If True, assumes the boxes are based on an image augmented in YOLO style. If False, do
            regular rescaling.
        xywh (bool): The box format is xywh or not, default=False.

    Returns:
        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    scaled_boxes = boxes.clone()
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (
            round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
            round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
        )  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        scaled_boxes[..., 0] = boxes[..., 0] - pad[0]  # x padding
        scaled_boxes[..., 1] = boxes[..., 1] - pad[1]  # y padding
        if not xywh:
            scaled_boxes[..., 2] = boxes[..., 2] - pad[0]  # x padding
            scaled_boxes[..., 3] = boxes[..., 3] - pad[1]  # y padding
    return clip_boxes(scaled_boxes / gain, img0_shape)  # divide by gain only after removing the padding
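
A quick sanity check of scale_boxes with made-up numbers (mine, for illustration): a box predicted on a 640x640 letterboxed image is mapped back to a 480x640 original, so gain is 1.0 and only the 80-pixel vertical padding is removed.

boxes = torch.tensor([[100.0, 120.0, 200.0, 240.0]])  # x1, y1, x2, y2 on the 640x640 letterboxed image
print(scale_boxes((640, 640), boxes, (480, 640)))
# tensor([[100.,  40., 200., 160.]])  -- y coordinates shifted up by the 80 px pad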


def postprocess(preds, img, orig_imgs, m_names):
    """Applies non-max suppression and processes detections for each image in an input batch."""
    p = ops.non_max_suppression(  #List, p[0].shape: [1, 38]
        preds[0],
        0.25, #self.args.conf,                     #0.25
        0.7,  #self.args.iou,                      #0.7
        agnostic=False, #self.args.agnostic_nms,    #False
        max_det=300,    #self.args.max_det,          #300
        nc=len(m_names),           #len(self.model.names),           #1   model.names = {0: 'runway'}
        classes=None    #self.args.classes,          #None!
    )

    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

    results = []
    proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]  # tuple if PyTorch model or array if exported
    
    for i, pred in enumerate(p):
        orig_img = orig_imgs[i]
        img_path = f"image{i}.jpg"#self.batch[0][i]
        pred_clone = pred.clone()
        if not len(pred):  # save empty boxes
            masks = None
        #elif self.args.retina_masks:
        #    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        #    masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
        else:
            masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
            pred_clone[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        #results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
        results.append(Results(orig_img, path=img_path, names=m_names, boxes=pred_clone[:, :6], masks=masks))
        #Results init takes:
        #self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None, speed=None
    return results


device = "cuda:0"
yolo_model = YOLO("yolov8n-seg.pt").to(device)
yolo_model.model.eval()


class ModelWrapper(nn.Module):
    def __init__(self, model, names):
        super(ModelWrapper, self).__init__()
        self.model = model.model  # use the argument, not the global yolo_model
        self.names = names

    def forward(self, x):
        # Pass the input through the original model
        outputs = self.model(x)
        results = postprocess(outputs, x, x, self.names)
        opt_box_list = []
        valid_indexes = []
        for idx in range(len(results)):
            if len(results[idx].boxes) > 0:
                #opt_box_list.append(results[idx].boxes.xyxyn[results[idx].boxes.conf.argmax()])
                opt_box_list.append(results[idx].boxes.xyxyn[0])
                valid_indexes.append(idx)
        if opt_box_list:
            opt_box = torch.stack(opt_box_list, dim=0)
        else:
            opt_box = torch.tensor([])
        return opt_box#, valid_indexes

#!wget https://ultralytics.com/images/bus.jpg --no-check-certificate
img_path = "bus.jpg"
data_np = np.array(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB))

inp = torch.Tensor(data_np).permute([2, 0, 1]).unsqueeze(0)/255.0
transform = transforms.Resize((inp.shape[2]//32*32, inp.shape[3]//32*32))  # round H, W down to multiples of 32
data = transform(inp).to(device)


bbox_m = ModelWrapper(yolo_model, yolo_model.names)
bbox = bbox_m(data)
vals = [float(v) for v in bbox[0]]
X = [vals[idx]*data.shape[3] for idx in [0, 2, 2, 0, 0]]
Y = [vals[idx]*data.shape[2] for idx in [1, 1, 3, 3, 1]]
plt.imshow(data[0].detach().cpu().permute([1, 2, 0]))
plt.plot(X, Y)
plt.savefig('detect.png', dpi=300, bbox_inches='tight')   #This works, bounding box is computed!

from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm
lirpa_model = BoundedModule(bbox_m, data)  #I get error here itself.
print("All done")

It gives me this error:

Exception has occurred: IndexError
The shape of the mask [132, 100] at index 0 does not match the shape of the indexed tensor [144, 100, 132] at index 0
  File "/cnvrg/ultralytics/ultralytics/utils/ops.py", line 250, in non_max_suppression
    x = x[xc[xi]]  # confidence
  File "/cnvrg/ultralytics/AutoLiRPA.py", line 84, in postprocess
    p = ops.non_max_suppression(  #List, p[0].shape: [1, 38]
  File "/cnvrg/ultralytics/AutoLiRPA.py", line 135, in forward
    results = postprocess(outputs, x, x, self.names)
  File "/cnvrg/ultralytics/AutoLiRPA.py", line 170, in <module>
    lirpa_model = BoundedModule(bbox_m, data)  #I get error here itself.
IndexError: The shape of the mask [132, 100] at index 0 does not match the shape of the indexed tensor [144, 100, 132] at index 0

Complete stack trace:

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:2447: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if size_prods == 1:
/cnvrg/ultralytics/ultralytics/utils/ops.py:220: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if prediction.shape[-1] == 6:  # end-to-end model (BNC, i.e. 1,300,6)
/cnvrg/ultralytics/ultralytics/utils/ops.py:246: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for xi, x in enumerate(prediction):  # image index, image inference
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/conf/.code-server/extensions/ms-python.python-2023.2.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/conf/.code-server/extensions/ms-python.python-2023.2.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/conf/.code-server/extensions/ms-python.python-2023.2.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/conf/.code-server/extensions/ms-python.python-2023.2.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/conf/.code-server/extensions/ms-python.python-2023.2.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/conf/.code-server/extensions/ms-python.python-2023.2.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/cnvrg/ultralytics/AutoLiRPA.py", line 170, in <module>
    lirpa_model = BoundedModule(bbox_m, data)  #I get error here itself.
  File "/usr/local/lib/python3.10/dist-packages/auto_LiRPA/bound_general.py", line 130, in __init__
    self._convert(model, global_input)
  File "/usr/local/lib/python3.10/dist-packages/auto_LiRPA/bound_general.py", line 848, in _convert
    nodesOP, nodesIn, nodesOut, template = self._convert_nodes(
  File "/usr/local/lib/python3.10/dist-packages/auto_LiRPA/bound_general.py", line 672, in _convert_nodes
    nodesOP, nodesIn, nodesOut, template = parse_module(
  File "/usr/local/lib/python3.10/dist-packages/auto_LiRPA/parse_graph.py", line 194, in parse_module
    trace, out = torch.jit._get_trace_graph(module, inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1296, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/cnvrg/ultralytics/AutoLiRPA.py", line 135, in forward
    results = postprocess(outputs, x, x, self.names)
  File "/cnvrg/ultralytics/AutoLiRPA.py", line 84, in postprocess
    p = ops.non_max_suppression(  #List, p[0].shape: [1, 38]
  File "/cnvrg/ultralytics/ultralytics/utils/ops.py", line 250, in non_max_suppression
    x = x[xc[xi]]  # confidence
IndexError: The shape of the mask [132, 100] at index 0 does not match the shape of the indexed tensor [144, 100, 132] at index 0