Skip to content

Commit

Permalink
polish(nyz): simplify requirements (#672)
Browse files Browse the repository at this point in the history
* polish(nyz): simplify requirements

* fix(nyz): correct flake8 style
  • Loading branch information
PaParaZz1 authored May 31, 2023
1 parent caf8b4c commit 2ab7c44
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 52 deletions.
17 changes: 14 additions & 3 deletions ding/framework/middleware/functional/logger.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union
import os
from ditk import logging
from easydict import EasyDict
from matplotlib import pyplot as plt
from matplotlib import animation
from matplotlib import ticker as mtick
from torch.nn import functional as F
from sklearn.manifold import TSNE
import os
import numpy as np
import torch
import wandb
import h5py
import pickle
import treetensor.numpy as tnp
from ding.framework import task
Expand Down Expand Up @@ -346,6 +345,18 @@ def wandb_offline_logger(
)

def _vis_dataset(datasetpath: str):
try:
from sklearn.manifold import TSNE
except ImportError:
import sys
logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.")
sys.exit(1)
try:
import h5py
except ImportError:
import sys
logging.warning("Please install h5py first, such as `pip3 install h5py`.")
sys.exit(1)
assert os.path.splitext(datasetpath)[-1] in ['.pkl', '.h5', '.hdf5']
if os.path.splitext(datasetpath)[-1] == '.pkl':
with open(datasetpath, 'rb') as f:
Expand Down
15 changes: 12 additions & 3 deletions ding/reward_model/pdeil_irl_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Dict
from ditk import logging
import numpy as np
import torch
import pickle
from typing import List, Dict
import scipy.stats as stats
try:
from sklearn.svm import SVC
except ImportError:
Expand Down Expand Up @@ -71,6 +71,13 @@ def __init__(self, cfg: dict, device, tb_logger: 'SummaryWriter') -> None: # no
- tb_logger (:obj:`str`): Logger, defaultly set as 'SummaryWriter' for model summary
"""
super(PdeilRewardModel, self).__init__()
try:
import scipy.stats as stats
self.stats = stats
except ImportError:
import sys
logging.warning("Please install scipy first, such as `pip3 install scipy`.")
sys.exit(1)
self.cfg: dict = cfg
self.e_u_s = None
self.e_sigma_s = None
Expand Down Expand Up @@ -145,7 +152,9 @@ def _batch_mn_pdf(self, x: np.ndarray, mean: np.ndarray, cov: np.ndarray) -> np.
Overview:
Get multivariate normal pdf of given np array.
"""
return np.asarray(stats.multivariate_normal.pdf(x, mean=mean, cov=cov, allow_singular=False), dtype=np.float32)
return np.asarray(
self.stats.multivariate_normal.pdf(x, mean=mean, cov=cov, allow_singular=False), dtype=np.float32
)

def estimate(self, data: list) -> List[Dict]:
"""
Expand Down
15 changes: 14 additions & 1 deletion ding/utils/compression_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pickle
import cloudpickle
import zlib
import lz4.block
import numpy as np


Expand Down Expand Up @@ -52,6 +51,13 @@ def lz4_data_compressor(data):
>>> lz4.block.compress(pickle.dumps("Hello"))
b'\x14\x00\x00\x00R\x80\x04\x95\t\x00\x01\x00\x90\x8c\x05Hello\x94.'
"""
try:
import lz4.block
except ImportError:
from ditk import logging
import sys
logging.warning("Please install lz4 first, such as `pip3 install lz4`")
sys.exit(1)
return lz4.block.compress(pickle.dumps(data))


Expand Down Expand Up @@ -112,6 +118,13 @@ def lz4_data_decompressor(compressed_data):
Overview:
Return the decompressed original data (lz4 compressor).
"""
try:
import lz4.block
except ImportError:
from ditk import logging
import sys
logging.warning("Please install lz4 first, such as `pip3 install lz4`")
sys.exit(1)
return pickle.loads(lz4.block.decompress(compressed_data))


Expand Down
22 changes: 13 additions & 9 deletions ding/utils/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import List, Dict, Tuple
import pickle

import easydict
import torch
import numpy as np
from ditk import logging
from copy import deepcopy
from easydict import EasyDict
from torch.utils.data import Dataset
from dataclasses import dataclass

from easydict import EasyDict
from ding.utils.bfs_helper import get_vi_sequence
import pickle
import easydict
import torch
import numpy as np

from ding.utils.bfs_helper import get_vi_sequence
from ding.utils import DATASET_REGISTRY, import_module
from ding.rl_utils import discount_cumsum

Expand Down Expand Up @@ -50,7 +49,9 @@ def __init__(self, cfg: dict) -> None:
try:
import d4rl # register d4rl enviroments with open ai gym
except ImportError:
import sys
logging.warning("not found d4rl env, please install it, refer to https://github.com/rail-berkeley/d4rl")
sys.exit(1)

# Init parameters
data_path = cfg.policy.collect.get('data_path', None)
Expand Down Expand Up @@ -127,7 +128,9 @@ def __init__(self, cfg: dict) -> None:
try:
import h5py
except ImportError:
logging.warning("not found h5py package, please install it trough 'pip install h5py' ")
import sys
logging.warning("not found h5py package, please install it trough `pip install h5py ")
sys.exit(1)
data_path = cfg.policy.collect.get('data_path', None)
data = h5py.File(data_path, 'r')
self._load_data(data)
Expand Down Expand Up @@ -526,8 +529,9 @@ def hdf5_save(exp_data, expert_data_path):
try:
import h5py
except ImportError:
import sys
logging.warning("not found h5py package, please install it trough 'pip install h5py' ")
import numpy as np
sys.exit(1)
dataset = dataset = h5py.File('%s_demos.hdf5' % expert_data_path.replace('.pkl', ''), 'w')
dataset.create_dataset('obs', data=np.array([d['obs'].numpy() for d in exp_data]), compression='gzip')
dataset.create_dataset('action', data=np.array([d['action'].numpy() for d in exp_data]), compression='gzip')
Expand Down
14 changes: 10 additions & 4 deletions ding/world_model/ddppo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import partial
from ditk import logging
import itertools
import copy
import numpy as np
import multiprocessing
import copy
import torch
from torch import nn
import torch.nn as nn

from scipy.spatial import KDTree
from functools import partial
from ding.utils import WORLD_MODEL_REGISTRY
from ding.utils.data import default_collate
from ding.torch_utils import unsqueeze_repeat
Expand All @@ -27,6 +27,12 @@ def get_neighbor_index(data, k, serial=False):
ret: [B, k]
"""
try:
from scipy.spatial import KDTree
except ImportError:
import sys
logging.warning("Please install scipy first, such as `pip3 install scipy`.")
sys.exit(1)
data = data.cpu().numpy()
tree = KDTree(data)

Expand Down
59 changes: 27 additions & 32 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,47 +51,35 @@
python_requires=">=3.7",
install_requires=[
'setuptools<=66.1.1',
'yapf==0.29.0',
'gym==0.25.1', # pypy incompatible; some environmrnt only support gym==0.22.0
'gymnasium',
'torch>=1.1.0',
'numpy>=1.18.0',
'pandas',
'DI-treetensor>=0.4.0',
'DI-toolkit>=0.1.0',
'trueskill',
'tensorboardX>=2.2',
'requests>=2.25.1',
'pyyaml',
'wandb',
'matplotlib',
'easydict==1.9',
'protobuf',
'flask~=1.1.2',
'tqdm',
'lz4',
'scipy',
'pyyaml',
'enum_tools',
'cloudpickle',
'hickle',
'tabulate',
'click>=7.0.0',
'URLObject>=2.4.0',
'urllib3>=1.26.5',
'responses~=0.12.1',
'enum_tools',
'trueskill',
'h5py',
'mpire>=2.3.5',
'pynng',
'redis',
'pettingzoo<=1.22.3',
'DI-treetensor>=0.3.0',
'DI-toolkit>=0.0.2',
'hbutils>=0.5.0',
'wandb',
'matplotlib',
'MarkupSafe==2.0.1', # compatibility
'h5py',
'scikit-learn',
'hickle',
'gymnasium',
'requests>=2.25.1', # interaction
'flask~=1.1.2', # interaction
'responses~=0.12.1', # interaction
'URLObject>=2.4.0', # interaction
'MarkupSafe==2.0.1', # interaction, compatibility
'pynng', # parallel
'redis', # parallel
'mpire>=2.3.5', # parallel
],
extras_require={
'test': [
'gym[box2d]>=0.25.0',
'opencv-python', # pypy incompatible
'coverage>=5,<=7.0.1',
'mock>=4.0.3',
'pytest~=7.0.1', # required by gym>=0.25.0
Expand All @@ -101,6 +89,14 @@
'pytest-rerunfailures~=10.2',
'pytest-timeout~=2.0.2',
'readerwriterlock',
'pandas',
'lz4',
'h5py',
'scipy',
'scikit-learn',
'gym[box2d]==0.25.1',
'pettingzoo<=1.22.3',
'opencv-python', # pypy incompatible
],
'style': [
'yapf==0.29.0',
Expand All @@ -112,13 +108,12 @@
'numba>=0.53.0',
],
'dist': [
'redis==3.5.3',
'redis-py-cluster==2.1.0',
],
'common_env': [
'ale-py', # >=0.7.5', # atari
'autorom',
'gym[all]>=0.25.0',
'gym[all]==0.25.1',
'cmake>=3.18.4',
'opencv-python', # pypy incompatible
],
Expand Down

0 comments on commit 2ab7c44

Please sign in to comment.