Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(paramserver): validates yaml or json data #77

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ urllib3 = ">=1.23"
python-dateutil = ">=2.8.1"
pytz = "*"
jsonschema = "==4.0.0"
pyyaml = ">=5.4.1"
pallabpain marked this conversation as resolved.
Show resolved Hide resolved

[dev-packages]
testtools = "==2.5.0"
Expand Down
722 changes: 412 additions & 310 deletions Pipfile.lock

Large diffs are not rendered by default.

55 changes: 32 additions & 23 deletions rapyuta_io/clients/paramserver.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from __future__ import absolute_import

import enum
import errno
import hashlib
import mimetypes
import os
import tempfile
from concurrent import futures
from os import listdir, makedirs
from os.path import isdir, join
from shutil import rmtree, copyfile

from concurrent import futures
import enum
import tempfile
import os
import hashlib
import mimetypes
import six

from rapyuta_io.utils import RestClient, InvalidParameterException, ConfigNotFoundException
from rapyuta_io.utils.error import InvalidJSONError, InvalidYAMLError
from rapyuta_io.utils.rest_client import HttpMethod
from rapyuta_io.utils.settings import PARAMSERVER_API_TREE_PATH, PARAMSERVER_API_TREEBLOBS_PATH, PARAMSERVER_API_FILENODE_PATH
from rapyuta_io.utils.utils import create_auth_header, prepend_bearer_to_auth_token, get_api_response_data, \
validate_list_of_strings
import six
from rapyuta_io.utils.settings import PARAMSERVER_API_TREE_PATH, PARAMSERVER_API_TREEBLOBS_PATH, \
PARAMSERVER_API_FILENODE_PATH
from rapyuta_io.utils.utils import (
create_auth_header,
prepend_bearer_to_auth_token,
get_api_response_data,
validate_list_of_strings,
parse_json,
parse_yaml
)


class _Node(str, enum.Enum):

Expand All @@ -28,6 +38,7 @@ def __str__(self):
Attribute = 'AttributeNode'
Folder = 'FolderNode'


class _ParamserverClient:
"""
Internal client for paramserver. Not for public use.
Expand All @@ -37,7 +48,6 @@ class _ParamserverClient:
default_binary_content_type = "application/octet-stream"
max_non_binary_size = 128 * 1024


def __init__(self, auth_token, project, core_api_host):
self._auth_token = auth_token
self._headers = create_auth_header(prepend_bearer_to_auth_token(auth_token), project)
Expand Down Expand Up @@ -140,12 +150,10 @@ def process_dir(self, executor, rootdir, tree_path, level, dir_futures, file_fut
if file_stat.st_size > self.max_non_binary_size:
future = executor.submit(self.create_binary_file, new_tree_path, full_path)
if file_name.endswith('.yaml'):
with open(full_path, 'r') as f:
data = f.read()
data = parse_yaml(full_path)
future = executor.submit(self.create_file, new_tree_path, data)
elif file_name.endswith('.json'):
with open(full_path, 'r') as f:
data = f.read()
data = parse_json(full_path)
future = executor.submit(self.create_file, new_tree_path, data, content_type=self.json_content_type)
else:
future = executor.submit(self.create_binary_file, new_tree_path, full_path)
Expand All @@ -165,19 +173,17 @@ def process_folder(self, executor, rootdir, tree_path, level, dir_futures, file_
if file_stat.st_size > self.max_non_binary_size:
future = executor.submit(self.create_binary_file, new_tree_path, full_path)
elif file_name.endswith('.yaml'):
with open(full_path, 'r') as f:
data = f.read()
data = parse_yaml(full_path)
future = executor.submit(self.create_file, new_tree_path, data)
elif file_name.endswith('.json'):
with open(full_path, 'r') as f:
data = f.read()
data = parse_json(full_path)
future = executor.submit(self.create_file, new_tree_path, data, content_type=self.json_content_type)
else:
future = executor.submit(self.create_binary_file, new_tree_path, full_path)
file_futures[future] = new_tree_path
return dir_futures, file_futures

def upload_configurations(self, rootdir, tree_names, delete_existing_trees, as_folder = False):
def upload_configurations(self, rootdir, tree_names, delete_existing_trees, as_folder=False):
self.validate_args(rootdir, tree_names, delete_existing_trees, as_folder)
with futures.ThreadPoolExecutor(max_workers=15) as executor:
dir_futures = self.process_root_dir(executor, rootdir, tree_names, delete_existing_trees)
Expand All @@ -193,7 +199,8 @@ def upload_configurations(self, rootdir, tree_names, delete_existing_trees, as_f
raise exc

processor_func = self.process_dir if not as_folder else self.process_folder
dir_futures, file_futures = processor_func(executor, rootdir, tree_path, level, dir_futures, file_futures)
dir_futures, file_futures = processor_func(executor, rootdir, tree_path, level, dir_futures,
file_futures)
done = futures.wait(dir_futures, return_when=futures.FIRST_COMPLETED).done
future = done.pop() if len(done) else None

Expand Down Expand Up @@ -241,7 +248,8 @@ def download_tree(self, tree_name, rootdir, delete_existing, blob_temp_dir):

def get_blob_data(self, tree_names):
url = self._core_api_host + PARAMSERVER_API_TREEBLOBS_PATH
response = RestClient(url).method(HttpMethod.GET).query_param({'treeNames': tree_names}).headers(self._headers).retry(0).execute()
response = RestClient(url).method(HttpMethod.GET).query_param({'treeNames': tree_names}).headers(
self._headers).retry(0).execute()
blob_data = get_api_response_data(response, parse_full=True).get('data', {})
return blob_data

Expand All @@ -254,7 +262,7 @@ def download_blob_file(blob, blob_temp_dir):
f.write(chunk)

@staticmethod
def validate_args(rootdir, tree_names, delete_existing_trees, as_folder = False):
def validate_args(rootdir, tree_names, delete_existing_trees, as_folder=False):
if not isinstance(rootdir, six.string_types):
raise InvalidParameterException('rootdir must be a string')
if tree_names:
Expand All @@ -263,6 +271,7 @@ def validate_args(rootdir, tree_names, delete_existing_trees, as_folder = False)
raise InvalidParameterException('delete_existing_trees must be a boolean')
if not isinstance(as_folder, bool):
raise InvalidParameterException('as_folder must be a boolean')

def download_configurations(self, rootdir, tree_names, delete_existing_trees):
self.validate_args(rootdir, tree_names, delete_existing_trees)
self._safe_makedirs(rootdir)
Expand Down
19 changes: 19 additions & 0 deletions rapyuta_io/utils/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class DeploymentNotRunningException(Exception):
"""
:ivar deployment_status: Deployment status object retrieved from the last poll
"""

def __init__(self, msg, deployment_status=None):
self.deployment_status = deployment_status
Exception.__init__(self, msg)
Expand Down Expand Up @@ -167,3 +168,21 @@ def __init__(self, msg=None):
class BuildOperationFailed(Exception):
def __init__(self, msg):
Exception.__init__(self, msg)


class InvalidJSONError(Exception):
def __init__(self, file_path=None):
msg = "Invalid JSON"
if file_path:
msg += ": {}".format(file_path)

Exception.__init__(self, msg)


class InvalidYAMLError(Exception):
def __init__(self, file_path=None):
msg = "Invalid YAML"
if file_path:
msg += ": {}".format(file_path)

Exception.__init__(self, msg)
45 changes: 43 additions & 2 deletions rapyuta_io/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import requests
import six
import yaml
from six.moves import range

from rapyuta_io.utils import APIError, ParameterMissingException, InvalidParameterException, \
UnauthorizedError, ResourceNotFoundError, BadRequestError, InternalServerError, ConflictError, \
ForbiddenError
ForbiddenError, InvalidJSONError, InvalidYAMLError
from rapyuta_io.utils.settings import EMPTY, DEFAULT_RANDOM_VALUE_LENGTH
from six.moves import range

BEARER = "Bearer"

Expand Down Expand Up @@ -140,3 +142,42 @@ def is_true(value):
def is_false(value):
return value in [False, 'False', 'false']


def parse_json(filepath):
"""Parses the given file and checks if it is a valid JSON. If not, raises an error."""
try:
with open(filepath, 'r') as f:
data = f.read()
except Exception as e:
raise e

try:
json.loads(data)
except json.decoder.JSONDecodeError:
return InvalidJSONError(filepath)

return data


def parse_yaml(filepath):
"""Parse the given file and checks if it is a valid YAML. If not, raises an error."""
try:
with open(filepath, 'r') as f:
data = f.read()
except Exception as e:
raise e

try:
loaded = yaml.safe_load(data)
except yaml.YAMLError:
raise InvalidYAMLError(filepath)

# For example, consider a file with just the following text.
# The yaml.safe_load() function will still parse this file.
#
# invalid data
#
if not isinstance(loaded, dict):
raise InvalidYAMLError(filepath)

return data
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"urllib3>=1.23",
"python-dateutil>=2.8.2",
"pytz",
"pyyaml>=5.4.1",
"setuptools",
"jsonschema==4.0.0",
],
Expand Down
14 changes: 6 additions & 8 deletions tests/paramserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyfakefs import fake_filesystem_unittest
from requests import Response

import rapyuta_io.utils.error
from rapyuta_io.utils.error import BadRequestError, InternalServerError
from tests.utils.client import get_client, headers
from tests.utils.paramserver import UPLOAD_SUCCESS_TREE_PATHS, UPLOAD_SUCCESS_MOCK_CALLS, UPLOAD_FAILURE_400CASE_TREE_PATHS, \
Expand Down Expand Up @@ -149,20 +150,17 @@ def side_effect(*args, **kwargs):
mock_response = MagicMock(spec=Response)
url = kwargs['url']
url_suffix = url[len(self.URL_PREFIX):]
if url_suffix == '/tree2/robot_type/AMR/motors.yaml':
mock_response.status_code = requests.codes.BAD_REQUEST
mock_response.text = '{"error": "invalid data"}'
else:
if url_suffix != '/tree2/robot_type/AMR/motors.yaml':
mock_response.status_code = requests.codes.OK
mock_response.text = 'null'
return mock_response
mock_request.side_effect = side_effect

with self.assertRaisesRegex(BadRequestError, 'invalid data') as exc:
with self.assertRaisesRegex(rapyuta_io.utils.error.InvalidYAMLError, 'Invalid YAML') as exc:
get_client().upload_configurations(rootdir)
self.assertEqual('tree2/robot_type/AMR/motors.yaml', exc.exception.tree_path)
mock_request.assert_has_calls(expected_mock_calls, any_order=True)
self.assertEqual(len(expected_mock_calls), mock_request.call_count, 'extra request calls were made')
self.assertRegex(str(exc.exception), 'tree2/robot_type/AMR/motors.yaml')
self.assertNotEqual(len(expected_mock_calls), mock_request.call_count,
'expected fewer calls due to client side exception')

@patch('requests.request')
def test_upload_configurations_failure_500case(self, mock_request):
Expand Down
Loading
Loading