-
Notifications
You must be signed in to change notification settings - Fork 30
/
setup.py
59 lines (51 loc) · 2.19 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from importlib.machinery import SourceFileLoader
import pkg_resources
from distutils.version import LooseVersion
import re
import codecs
from setuptools import setup
from setuptools import find_packages
# Read the long description from the README (Markdown); it is shown on the
# Python Package Index project page. The built-in open() is used because
# codecs.open() is deprecated since Python 3.14.
with open('README.md', encoding='utf-8') as f:
    long_description = f.read()
# Minimal required runtime dependencies (the full list lives in requirements.txt).
install_requires = [
    'numpy',
    'scipy',
    'gym>=0.15.3',
    'cloudpickle',
    'pyyaml',
    'colorama',
]
# Extra packages for the test / lint / docs tooling.
tests_require = [
    'pytest',
    'flake8',
    'sphinx',
    'sphinx_rtd_theme',
]
# Read __version__ from lagom/version.py without importing the lagom package
# itself (importing the package may pull in dependencies that are not installed
# yet). SourceFileLoader compiles the file and the code object is executed into
# a private namespace; this replaces the deprecated Loader.load_module() API,
# which emits a DeprecationWarning and is slated for removal from importlib.
_version_loader = SourceFileLoader('version', 'lagom/version.py')
_version_namespace = {'__name__': 'version', '__file__': _version_loader.path}
exec(_version_loader.get_code('version'), _version_namespace)

setup(name='lagom',
      version=_version_namespace['__version__'],
      # List all lagom packages (folders with __init__.py); needed to distribute a release
      packages=find_packages(),
      install_requires=install_requires,
      tests_require=tests_require,
      python_requires='>=3',
      author='Xingdong Zuo',
      author_email='[email protected]',
      description='lagom: A light PyTorch infrastructure to quickly prototype reinforcement learning algorithms.',
      long_description=long_description,
      long_description_content_type='text/markdown',
      url='https://github.com/zuoxingdong/lagom',
      # Trove classifiers: package metadata (Python version, license, OS, ...) shown by pip/PyPI
      classifiers=['Programming Language :: Python :: 3',
                   'License :: OSI Approved :: MIT License',
                   'Operating System :: OS Independent',
                   'Natural Language :: English',
                   'Topic :: Scientific/Engineering :: Artificial Intelligence'],
      )
# Verify that PyTorch is available. It is not listed in install_requires, so it
# must be installed separately; both the stable ('torch') and the nightly
# ('torch-nightly') distributions are accepted. If both are installed, the last
# one found ('torch-nightly') is the one checked.
# Explicit `raise` statements are used instead of `assert` so the check still
# runs when Python is started with -O.
pkg = None
for name in ['torch', 'torch-nightly']:
    try:
        pkg = pkg_resources.get_distribution(name)
    except pkg_resources.DistributionNotFound:
        pass
if pkg is None:
    raise RuntimeError('PyTorch is not correctly installed.')

version_msg = 'PyTorch of version 1.2.0 or above expected'
# Compare only the leading numeric release, e.g. '1.4.0' from '1.4.0+cu101' or
# '1.2.0' from '1.2.0.dev20190801'; any pre-release/local suffix is ignored.
version_match = re.search(r'\d+(?:[.]\d+)*', pkg.version)
if version_match is None:
    raise RuntimeError('{} (could not parse installed version {!r})'.format(version_msg, pkg.version))
# A tuple of ints compares component-wise numerically (so 1.10 > 1.2), which is
# what distutils' deprecated LooseVersion did for such purely numeric strings.
installed_version = tuple(int(part) for part in version_match.group(0).split('.'))
if installed_version < (1, 2, 0):
    raise RuntimeError('{} (found {})'.format(version_msg, pkg.version))