Skip to content

Commit

Permalink
[FRONTEND] make torch optional (#1604)
Browse files Browse the repository at this point in the history
make torch optional to fix circular dependency issue
  • Loading branch information
pommedeterresautee authored May 3, 2023
1 parent 33174cc commit d196302
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 0 additions & 4 deletions python/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
# ---------------------------------------
# Note: import order is significant here.

# TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch # noqa: F401

# submodules
from .runtime import (
autotune,
Expand Down
6 changes: 4 additions & 2 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from pathlib import Path
from typing import Any, Tuple

import torch

import triton
import triton._C.libtriton.triton as _triton
from ..runtime import driver
Expand Down Expand Up @@ -324,6 +322,10 @@ def _is_cuda(arch):


def get_architecture_descriptor(capability):
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
if capability is None:
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
Expand Down
6 changes: 4 additions & 2 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar

import torch

import triton
from . import core as tl
from triton._C.libtriton.triton import ir
Expand Down Expand Up @@ -1183,6 +1181,10 @@ def dot(lhs: tl.tensor,
allow_tf32: bool,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
capability = triton.runtime.jit.get_device_capability(device)
Expand Down

0 comments on commit d196302

Please sign in to comment.