From ef403c2e5b13a6b3cf34e47194598e34a2f672b8 Mon Sep 17 00:00:00 2001 From: "Timothy M. Shead" Date: Thu, 1 Aug 2013 22:41:46 -0600 Subject: [PATCH] Additional sanity checking for redimension() arguments. --- .../slycat/analysis/coordinator/__init__.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/packages/slycat/analysis/coordinator/__init__.py b/packages/slycat/analysis/coordinator/__init__.py index 36e97a67c..9a1d38784 100644 --- a/packages/slycat/analysis/coordinator/__init__.py +++ b/packages/slycat/analysis/coordinator/__init__.py @@ -84,6 +84,10 @@ def require_attributes(self, attributes): else: attributes = [self.require_attribute(attribute) for attribute in attributes] return attributes + def require_attribute_names(self, names): + if isinstance(names, basestring): + return [self.require_attribute_name(names)] + return [self.require_attribute_name(name) for attribute_name in names] def require_chunk_size(self, chunk_size): if not isinstance(chunk_size, int): raise InvalidArgument("Chunk size must be an integer.") @@ -99,15 +103,14 @@ def require_chunk_sizes(self, shape, chunk_sizes): if len(shape) != len(chunk_sizes): raise InvalidArgument("Array shape and chunk sizes must contain the same number of dimensions.") return chunk_sizes - def require_dimension(self, dimension): - if isinstance(dimension, basestring): - dimension = {"name":dimension, "type":"int64"} - return dimension - def require_dimensions(self, dimensions): - dimensions = [self.require_dimension(dimension) for dimension in dimensions] - if not len(dimensions): - raise InvalidArgument("Array must have at least one dimension.") - return dimensions + def require_dimension_name(self, name): + if not isinstance(name, basestring): + raise InvalidArgument("Dimension name must be a string.") + return name + def require_dimension_names(self, names): + if isinstance(names, basestring): + return [self.require_dimension_name(names)] + return [self.require_dimension_name(name) for name in names] def require_expression(self, expression): if isinstance(expression, basestring): expression = ast.parse(expression) @@ -236,6 +239,8 @@ def random(self, shape, chunk_sizes, seed, attributes): return self.pyro_register(array(array_workers, [])) def redimension(self, source, dimensions, attributes): source = self.require_object(source) + dimensions = self.require_dimension_names(dimensions) + attributes = self.require_attribute_names(attributes) array_workers = [] for worker_index, (source_proxy, worker) in enumerate(zip(source.workers, self.workers())): array_workers.append(worker.redimension(worker_index, source_proxy._pyroUri, dimensions, attributes))