Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI] Fix windows compile tests #2511

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,57 +7,47 @@
import functools
import itertools
import operator

import sys
import warnings
from copy import deepcopy
from dataclasses import asdict, dataclass

from packaging import version as pack_version
import numpy as np
import pytest
import torch
from _utils_internal import ( # noqa
dtype_fixture,
get_available_devices,
get_default_devices,
)
from mocking_classes import ContinuousActionConvMockEnv

from packaging import version, version as pack_version

from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict._C import unravel_keys
from tensordict.nn import (
CompositeDistribution,
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictModule as ProbMod,
ProbabilisticTensorDictSequential,
ProbabilisticTensorDictSequential as ProbSeq,
TensorDictModule,
TensorDictModule as Mod,
TensorDictSequential,
TensorDictSequential as Seq,
)
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.modules.models import QMixer

_has_functorch = True
try:
import functorch as ft # noqa

make_functional_with_buffers = ft.make_functional_with_buffers
FUNCTORCH_ERR = ""
except ImportError as err:
_has_functorch = False
FUNCTORCH_ERR = str(err)

import numpy as np
import pytest
import torch
from _utils_internal import ( # noqa
dtype_fixture,
get_available_devices,
get_default_devices,
)
from mocking_classes import ContinuousActionConvMockEnv
from packaging import version

# from torchrl.data.postprocs.utils import expand_as_right
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict.nn import NormalParamExtractor, TensorDictModule
from tensordict.nn.utils import Buffer
from tensordict.utils import unravel_key
from torch import autograd, nn
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
from torchrl.data.postprocs.postprocs import MultiStep
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.modules import (
DistributionalQValueActor,
OneHotCategorical,
Expand All @@ -66,6 +56,7 @@
WorldModelWrapper,
)
from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
from torchrl.modules.models import QMixer
from torchrl.modules.models.model_based import (
DreamerActor,
ObsDecoder,
Expand Down Expand Up @@ -147,7 +138,18 @@
_split_and_pad_sequence,
)

_has_functorch = True
try:
import functorch as ft # noqa

make_functional_with_buffers = ft.make_functional_with_buffers
FUNCTORCH_ERR = ""
except ImportError as err:
_has_functorch = False
FUNCTORCH_ERR = str(err)

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
IS_WINDOWS = sys.platform == "win32"

# Capture all warnings
pytestmark = [
Expand Down Expand Up @@ -15735,7 +15737,13 @@ def __init__(self):
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
def test_exploration_compile():
try:
torch._dynamo.reset_code_caches()
except Exception:
# older versions of PT don't have that function
pass
m = ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["sample"],
Expand Down
Loading