Skip to content

Commit

Permalink
Update export.py, yolo.py sys.path.append() (#3579)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Jun 10, 2021
1 parent 095197b commit 53ed872
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
12 changes: 7 additions & 5 deletions models/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import time
from pathlib import Path

sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

import models
FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path

from models.common import Conv
from models.yolo import Detect
from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU
from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
Expand Down Expand Up @@ -56,12 +58,12 @@ def export(weights='./yolov5s.pt', # weights path
model.train() if train else model.eval() # training mode = no Detect() layer grid construction
for k, m in model.named_modules():
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, models.common.Conv): # assign export-friendly activations
if isinstance(m, Conv): # assign export-friendly activations
if isinstance(m.act, nn.Hardswish):
m.act = Hardswish()
elif isinstance(m.act, nn.SiLU):
m.act = SiLU()
elif isinstance(m, models.yolo.Detect):
elif isinstance(m, Detect):
m.inplace = inplace
m.onnx_dynamic = dynamic
# m.forward = m.forward_export # assign forward (optional)
Expand Down
6 changes: 4 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from copy import deepcopy
from pathlib import Path

sys.path.append(Path(__file__).parent.parent.absolute().__str__()) # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)
FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path

from models.common import *
from models.experimental import *
Expand All @@ -25,6 +25,8 @@
except ImportError:
thop = None

logger = logging.getLogger(__name__)


class Detect(nn.Module):
stride = None # strides computed during build
Expand Down

0 comments on commit 53ed872

Please sign in to comment.