Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow test python model locally #22

Merged
merged 1 commit into from
Oct 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ plugins {

defaultTasks 'build'

def djl_version = System.getenv("DJL_VERSION")
def serving_version = System.getenv("DJL_VERSION")
def stagingRepo = System.getenv("DJL_STAGING")
djl_version = (djl_version == null) ? "0.14.0" : djl_version
serving_version = (serving_version == null) ? djl_version : serving_version
if (!project.hasProperty("staging")) {
djl_version += "-SNAPSHOT"
serving_version += "-SNAPSHOT"
}

allprojects {
group 'ai.djl.serving'
version "${djl_version}"
version "${serving_version}"

repositories {
mavenCentral()
Expand Down
1 change: 1 addition & 0 deletions engines/python/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__pycache__
*.egg-info/
24 changes: 24 additions & 0 deletions engines/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,27 @@ You can pull the Python engine from the central Maven repository by including th
<scope>runtime</scope>
</dependency>
```

## Test your python model

Testing python code within Java environment is challenging. We provide a tool to help you develop
and test your python model locally. You can easily use IDE to debug your model.

1. Install djl_python toolkit:

```
cd engines/python/setup
pip install -U .
```

2. You can use command line tool or python to run djl model testing. The following command is
an example:

```shell
curl -O https://resources.djl.ai/images/kitten.jpg

djl-test-model --model-dir src/test/resources/resnet18 --input kitten.jpg

# or use python
python -m djl_python.test_model --model-dir src/test/resources/resnet18 --input kitten.jpg
```
23 changes: 12 additions & 11 deletions engines/python/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,30 @@ dependencies {
}
}

sourceSets {
main.resources.srcDirs "setup"
}

processResources {
exclude "build", "*.egg-info", "__pycache__", "PyPiDescription.rst", "setup.py"
outputs.dir file("${project.buildDir}/classes/java/main/ai/djl/python")
doLast {
delete "src/main/resources/ai/djl/python/scripts/__pycache__/"
// write properties
def propFile = file("${project.buildDir}/classes/java/main/ai/djl/python/python.properties")
def sb = new StringBuilder()
sb.append("version=${version}\nlibraries=")
def first = true
for (String name : file("src/main/resources/ai/djl/python/scripts").list().sort()) {
if (first) {
first = false
} else {
sb.append(',')
}
sb.append(name)
sb.append("version=${version}\nlibraries=djl_python_engine.py")
for (String name : file("setup/djl_python").list().sort()) {
sb.append(",djl_python/").append(name)
}
propFile.text = sb.toString()
}
}

clean.doFirst {
delete "src/main/resources/ai/djl/python/scripts/__pycache__/"
delete "setup/build/"
delete "setup/djl_python.egg-info/"
delete "setup/__pycache__/"
delete "setup/djl_python/__pycache__/"
delete "src/test/resources/accumulate/__pycache__/"
delete System.getProperty("user.home") + "/.djl.ai/python"
}
Expand Down
5 changes: 5 additions & 0 deletions engines/python/setup/PyPiDescription.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Project Description
===================

DJL Python engine allows you run python model in a JVM based application. However, you still
need to install your python environment and dependencies.
16 changes: 16 additions & 0 deletions engines/python/setup/djl_python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env python
#
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from .inputs import Input
from .outputs import Output
from .pair_list import PairList
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,32 @@ def python_engine_args():
'If \'sock-type\' is \'tcp\' this is expected to have the host port to bind on'
)
return parser

@staticmethod
def test_model_args():
parser = argparse.ArgumentParser(prog='djl-test-model',
description='Test DJL Python model')
parser.add_argument('--model-dir',
type=str,
dest='model_dir',
help='Model directory')
parser.add_argument('--entry-point',
required=False,
type=str,
dest="entry_point",
default="model.py",
help='The model entry point file')
parser.add_argument('--handler',
type=str,
dest='handler',
required=False,
default="handle",
help='Python function to invoke')
parser.add_argument('--input',
type=str,
dest='input',
required=False,
nargs='+',
default='input.txt',
help='Input file')
return parser
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import io
import struct

from np_util import from_nd_list
from pair_list import PairList
from .np_util import from_nd_list
from .pair_list import PairList


def retrieve_buffer(conn, length):
Expand Down Expand Up @@ -111,6 +111,8 @@ def get_data(self, key=None) -> list:
return self.get_as_numpy(key)
elif content_type == "application/json":
return self.get_as_json(key)
elif content_type is not None and content_type.startswith("text/"):
return self.get_as_string(key)
elif content_type is not None and content_type.startswith("image/"):
return self.get_as_image(key)
else:
Expand All @@ -128,6 +130,9 @@ def get_as_bytes(self, key=None):
return self.content.value_at(0)
return ret

def get_as_string(self, key=None):
return self.get_as_bytes(key=key).decode("utf-8")

def get_as_json(self, key=None) -> list:
return ast.literal_eval(self.get_as_bytes(key=key).decode("utf-8"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import json
import struct

from np_util import to_nd_list
from pair_list import PairList
from .np_util import to_nd_list
from .pair_list import PairList


class Output(object):
Expand All @@ -25,6 +25,20 @@ def __init__(self):
self.properties = dict()
self.content = PairList()

def __str__(self):
d = dict()
for i in range(self.content.size()):
v = "type: " + str(type(self.content.value_at(i)))
d[self.content.key_at(i)] = v
return json.dumps(
{
"code": self.code,
"message": self.message,
"properties": self.properties,
"content": d
},
indent=2)

def set_code(self, code):
self.code = code

Expand Down
98 changes: 98 additions & 0 deletions engines/python/setup/djl_python/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
#
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import importlib
import logging
import os
import sys

from .arg_parser import ArgParser
from .inputs import Input


def create_request(input_files):
request = Input()
request.properties["device_id"] = "-1"

data_file = None
for file in input_files:
pair = file.split("=")
if len(pair) == 1:
key = None
val = pair[0]
else:
key = pair[0]
val = pair[1]

if data_file is None or key == "data":
data_file = val

if not os.path.exists(val):
raise ValueError("--input file not found {}.".format(val))

with open(val, "rb") as f:
request.content.add(key=key, value=f.read(-1))

if data_file.endswith(".json"):
request.properties["content-type"] = "application/json"
elif data_file.endswith(".txt"):
request.properties["content-type"] = "text/plain"
elif data_file.endswith(".gif"):
request.properties["content-type"] = "images/gif"
elif data_file.endswith(".png"):
request.properties["content-type"] = "images/png"
elif data_file.endswith(".jpeg") or data_file.endswith(".jpg"):
request.properties["content-type"] = "images/jpeg"
elif data_file.endswith(".ndlist") or data_file.endswith(
".npy") or data_file.endswith(".npz"):
request.properties["content-type"] = "tensor/ndlist"

return request


def run():
logging.basicConfig(stream=sys.stdout,
format="%(message)s",
level=logging.INFO)
args = ArgParser.test_model_args().parse_args()

inputs = create_request(args.input)
inputs.function_name = args.handler

os.chdir(args.model_dir)
model_dir = os.getcwd()
sys.path.append(model_dir)

entry_point = args.entry_point
entry_point_file = os.path.join(model_dir, entry_point)
if not os.path.exists(entry_point_file):
logging.error(
"entry-point file not found {}.".format(entry_point_file))
return

entry_point = entry_point[:-3]
module = importlib.import_module(entry_point)
if module is None:
logging.error("Unable to load entry_point {}/{}.py".format(
model_dir, entry_point))
return

logging.info("model loaded: %s/%s.py", model_dir, entry_point)

function_name = inputs.get_function_name()
outputs = getattr(module, function_name)(inputs)
print("output: " + str(outputs))


if __name__ == "__main__":
run()
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import socket
import sys

from arg_parser import ArgParser
from inputs import Input
from outputs import Output
from djl_python.arg_parser import ArgParser
from djl_python.inputs import Input
from djl_python.outputs import Output

SOCKET_ACCEPT_TIMEOUT = 30.0

Expand Down Expand Up @@ -112,7 +112,7 @@ def run_server(self):
entry_point_file = os.path.join(model_dir, entry_point)
if not os.path.exists(entry_point_file):
raise ValueError(
"entry-point file not file {}.".format(entry_point_file))
"entry-point file not found {}.".format(entry_point_file))

entry_point = entry_point[:-3]
module = importlib.import_module(entry_point)
Expand Down
66 changes: 66 additions & 0 deletions engines/python/setup/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python
#
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import setuptools.command.build_py
from setuptools import setup, find_packages

pkgs = find_packages(exclude='src')


def detect_version():
with open("../../../gradle.properties", "r") as f:
for line in f:
if not line.startswith('#'):
prop = line.split('=')
if prop[0] == "djl_version":
return prop[1].strip()

return None


def pypi_description():
with open('PyPiDescription.rst') as df:
return df.read()


class BuildPy(setuptools.command.build_py.build_py):
def run(self):
setuptools.command.build_py.build_py.run(self)


if __name__ == '__main__':
version = detect_version()

requirements = ['psutil', 'packaging', 'wheel']

setup(name='djl_python',
version=version,
description=
'djl_python is a tool to build and test DJL Python model locally',
author='Deep Java Library team',
author_email='[email protected]',
long_description=pypi_description(),
url='https://github.com/deepjavalibrary/djl.git',
keywords='DJL Serving Deep Learning Inference AI',
packages=pkgs,
cmdclass={
'build_py': BuildPy,
},
install_requires=requirements,
entry_points={
'console_scripts': [
'djl-test-model=djl_python.test_model:run',
]
},
include_package_data=True,
license='Apache License Version 2.0')
Loading