From c8063305d71184e6ae41e4987dac074e39bb39de Mon Sep 17 00:00:00 2001 From: Alexey Romanov Date: Tue, 9 Apr 2019 17:27:15 +0300 Subject: [PATCH 1/2] Simplify code in Tensorflow frontend --- python/tvm/relay/frontend/tensorflow.py | 190 +++++++++--------- .../frontend/tensorflow/test_forward.py | 43 ++-- topi/python/topi/util.py | 12 +- 3 files changed, 105 insertions(+), 140 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 11026b9e5ad8..705762d465e9 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -63,7 +63,7 @@ def _get_relay_op(op_name): return op class AttrCvt(object): - """Common attribute conveter. An AttrConverter instance is a callable: + """Common attribute converter. An AttrConverter instance is a callable: ``` attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) new_op_name, new_attr = attr_converter(attrs) @@ -222,17 +222,37 @@ def _dim_check(attrs): return False return _dim_check, "Only 2d kernel supported." -def _infer_channels(inputs, params, transpose=False): - """A hack for getting 'channles' or 'units' since tensorflow don't provide +def _infer_channels(node, params, transpose=False): + """A hack for getting 'channels' or 'units' since tensorflow don't provide these attributes. We check the shape of weights provided to get the number. """ - out_type = ir_pass.infer_type(inputs) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - channels = out_shapes[0][0] if not transpose else out_shapes[0][1] + out_shape = _infer_shape(node, params) + channels = out_shape[0] if not transpose else out_shape[1] return channels +def _infer_out_shapes(inputs, params): + """A method to get the output shape of intermediate nodes in the relay graph.""" + return [_infer_shape(inputs, params)] + +def _infer_shape(node, params=None): + """A method to get the output shape of an intermediate node in the relay graph.""" + out_type = ir_pass.infer_type(node) + return get_const_tuple(out_type.checked_type.shape) + +def _get_param(params, input_node): + return params.pop(input_node.name_hint).asnumpy() + +def _get_num_param(params, input_node): + return _get_param(params, input_node)[0] + +def _get_list_param(params, input_node): + return _get_param(params, input_node).tolist() + +def _get_tuple_param(params, input_node): + return tuple(_get_param(params, input_node)) + def _rsqrt(): - def _impl(inputs, attr, *args): + def _impl(inputs, attr, params): inputs.append(tvm.relay.const(-0.5, attr['T'].name)) return AttrCvt(op_name="power")(inputs, attr) return _impl @@ -243,16 +263,15 @@ def _impl(inputs, attr, params): try: # In Tensorflow, `axis` argument is a Tensor, not attribute. We # support the case where it inputs from a scalar constant. - axis_input_name = inputs[1].name_hint - axis_input_vlaue = [params[axis_input_name].asnumpy()[0]] + axis_input_value = [_get_num_param(params, inputs[1])] except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) - return func(inputs[0], axis=axis_input_vlaue, keepdims=False) + return func(inputs[0], axis=axis_input_value, keepdims=False) return _impl def _elemwise(name): - def _impl(inputs, attr, *args): + def _impl(inputs, attr, params): assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) return _get_relay_op(name)(*inputs) return _impl @@ -472,7 +491,7 @@ def _impl(inputs, attr, params): def _expand_dims(): def _impl(inputs, attr, params): dim_input = inputs.pop(1) - axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0] + axis = _get_num_param(params, dim_input) return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr) return _impl @@ -527,21 +546,19 @@ def _impl(inputs, attr, params): def _concatV2(): def _impl(inputs, attr, params): pop_node = inputs.pop(len(inputs)-1) - axis = params[pop_node.name_hint] - params.pop(pop_node.name_hint) + axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['T', 'N', 'Tidx'], - extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + extras={'axis': axis})([inputs], attr) return _impl def _concat(): def _impl(inputs, attr, params): pop_node = inputs.pop(0) - axis = params[pop_node.name_hint] - params.pop(pop_node.name_hint) + axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['N'], - extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + extras={'axis': axis})([inputs], attr) return _impl def _pack(): @@ -565,8 +582,8 @@ def _impl(inputs, attr, params): def _slice(): def _impl(inputs, attr, params): - begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist() - size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist() + begin = _get_list_param(params, inputs[1]) + size = _get_list_param(params, inputs[2]) data_shape = attr['_input_shapes'][inputs[0]] data_dim = len(data_shape) end = size @@ -581,24 +598,18 @@ def _impl(inputs, attr, params): def _reshape(): def _impl(inputs, attr, params): + pop_node = inputs.pop(1) try: - pop_node = inputs[1] - shape_arg = params.pop(pop_node.name_hint) - inputs.pop(1) - - return AttrCvt( - op_name="reshape", - extras={'newshape':tuple(shape_arg.asnumpy())}, - ignores=['Tshape'])(inputs, attr) + shape_arg = _get_tuple_param(params, pop_node) except AttributeError: # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. - params_new = _infer_value(inputs[1], params) - inputs.pop(1) - return AttrCvt( - op_name="reshape", - extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())}, - ignores=['Tshape'])(inputs, attr) + params_new = _infer_value(pop_node, params) + shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) + return AttrCvt( + op_name="reshape", + extras={'newshape': shape_arg}, + ignores=['Tshape'])(inputs, attr) return _impl @@ -737,9 +748,10 @@ def _impl(inputs, attr, params): if -1 in output_shape: output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() - fill_arg = params.pop(inputs.pop(1).name_hint) - return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name), - output_shape, attr['T'].name) + fill_arg = _get_num_param(params, inputs.pop(1)) + dtype = attr['T'].name + return _op.full(tvm.relay.const(fill_arg, dtype), + output_shape, dtype) return _impl def _lrn(): @@ -757,9 +769,7 @@ def _impl(inputs, attr, params): def _sum(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint).asnumpy() - # convert to tuple for preventing invalid parameter format error - axis = tuple(axis) + axis = _get_tuple_param(params, inputs[1]) return AttrCvt( op_name='sum', extras={'axis': axis}, @@ -775,25 +785,17 @@ def _impl(inputs, attr, params): def _gather(): "GatherV2, Gather" def _impl(inputs, attr, params): - - axis = 0 if len(inputs) > 2: - axis = params[inputs.pop(2).name_hint].asnumpy()[0] - new_input = [] - new_input.append(inputs.pop(0)) - new_input.append(inputs.pop(0)) + axis = _get_num_param(params, inputs.pop(2)) + else: + axis = 0 + new_input = inputs[0:2] return AttrCvt(op_name="take", extras={'axis': tvm.const(axis, 'int32')}, - ignores=['Tindices', 'Tparams', 'validate_indices', \ + ignores=['Tindices', 'Tparams', 'validate_indices', 'Taxis', '_class'])(new_input, attr) return _impl -def _infer_out_shapes(inputs, params): - """A method to get the output shape of an intermediate node in the relay graph.""" - out_type = ir_pass.infer_type(inputs) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - return out_shapes - def _stridedSlice(): def _impl(inputs, attr, params): """Strided Slice. @@ -801,9 +803,9 @@ def _impl(inputs, attr, params): Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ tensorflow/core/util/strided_slice_op.cc#L147-L368 """ - begin = params.pop(inputs[1].name_hint).asnumpy().tolist() - end = params.pop(inputs[2].name_hint).asnumpy().tolist() - stride = params.pop(inputs[3].name_hint).asnumpy().tolist() + begin = _get_list_param(params, inputs[1]) + end = _get_list_param(params, inputs[2]) + stride = _get_list_param(params, inputs[3]) begin_mask = int(attr.get('begin_mask', 0)) end_mask = int(attr.get('end_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0)) @@ -878,7 +880,7 @@ def _transform_mask(stride_dim, ellipsis_mask): if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_out_shapes(out, params)[0] + out_shape = _infer_shape(out, params) if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -899,19 +901,14 @@ def _transform_mask(stride_dim, ellipsis_mask): def _pad(name): def _impl(inputs, attr, params): - padlist_key = inputs[1].name_hint - if padlist_key in params: - padlist = params.pop(padlist_key).asnumpy() - else: - raise tvm.error.OpAttributeRequired( - 'Attribute {} not found in operator Pad.'.format(padlist_key)) - paddings = tuple([tuple(l) for l in padlist]) + padlist = _get_param(params, inputs[1]) + paddings = tuple(tuple(l) for l in padlist) attr['pad_width'] = paddings attr['pad_value'] = 0 new_inputs = [inputs[0]] if name == 'PadV2': - constant_values = params.pop(inputs[2].name_hint).asnumpy() - attr['pad_value'] = constant_values[0] + constant_values = _get_num_param(params, inputs[2]) + attr['pad_value'] = constant_values return AttrCvt( op_name='pad', ignores=['Tpaddings'],)(new_inputs, attr) @@ -921,10 +918,9 @@ def _transpose(): def _impl(inputs, attr, params): # If perm is not specified, axes is left empty, # otherwise its value is get from params - param_name = _get_name_hint(inputs[1]) - if param_name in params: - axes = tuple(params.get(param_name).asnumpy()) - else: + try: + axes = _get_list_param(params, inputs[1]) + except (IndexError, KeyError): axes = None return _op.transpose(inputs[0], axes=axes) return _impl @@ -936,7 +932,7 @@ def _impl(inputs, attr, params): def _reverse_v2(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint).asnumpy()[0] + axis = _get_num_param(params, inputs[1]) return AttrCvt( op_name="reverse", ignores=['Tidx'], @@ -957,9 +953,9 @@ def _impl(inputs, attr, params): def _range(): def _impl(inputs, attr, params): - start = params.pop(inputs[0].name_hint).asnumpy()[0] - limit = params.pop(inputs[1].name_hint).asnumpy()[0] - delta = params.pop(inputs[2].name_hint).asnumpy()[0] + start = _get_num_param(params, inputs[0]) + limit = _get_num_param(params, inputs[1]) + delta = _get_num_param(params, inputs[2]) name = attr["_node_name"] params[name] = tvm.nd.array([start, limit, delta]) @@ -970,25 +966,27 @@ def _impl(inputs, attr, params): def _elu(): def _impl(inputs, attr, params): - alpha = tvm.relay.const(-1.0, attr['T'].name) - return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + dtype = attr['T'].name + alpha = tvm.relay.const(-1.0, dtype) + return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]) return _impl def _selu(): def _impl(inputs, attr, params): - alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name) - gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name) - return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + dtype = attr['T'].name + alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) + gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) + return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])) return _impl def _mean(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint) + axis = _get_tuple_param(params, inputs[1]) return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], transforms={'keep_dims': 'keepdims'}, - extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr) + extras={'axis': axis})([inputs[0]], attr) return _impl def _broadcast(name): @@ -1014,8 +1012,7 @@ def _impl(inputs, attr, params): if has_size_vector: input_node_index = 0 input_axis_index = 2 - size_splits_input_name = _get_name_hint(inputs[1]) - size_splits = params[size_splits_input_name].asnumpy() + size_splits = _get_param(params, inputs[1]) section_beginnings = np.cumsum(size_splits)[:-1] indices_or_sections = tuple(section_beginnings) else: @@ -1023,8 +1020,7 @@ def _impl(inputs, attr, params): input_axis_index = 0 indices_or_sections = attr['num_split'] input_node = inputs[input_node_index] - axis_input_name = _get_name_hint(inputs[input_axis_index]) - axis_input_value = params[axis_input_name].asnumpy()[0] + axis_input_value = _get_num_param(params, inputs[input_axis_index]) except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for split: `axis` and `num_or_size_splits` " \ @@ -1080,8 +1076,8 @@ def _space_to_batch_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] - block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() - paddings = params.pop(inputs[2].name_hint).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1]) + paddings = _get_list_param(params, inputs[2]) N = len(input_shape) M = len(block_shape) batch = input_shape[0] @@ -1102,7 +1098,7 @@ def _impl(inputs, attr, params): axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) - permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0] + permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params) # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # producing an output tensor of shape: # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., @@ -1119,8 +1115,8 @@ def _batch_to_space_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] - block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() - crops = params.pop(inputs[2].name_hint).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1]) + crops = _get_list_param(params, inputs[2]) M = len(block_shape) batch = input_shape[0] # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: @@ -1145,7 +1141,7 @@ def _impl(inputs, attr, params): # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # input_shape[M+1], ..., input_shape[N-1]] - reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0] + reshaped_permuted_shape = _infer_shape(reshaped_permuted, params) cropped = reshaped_permuted for axis in range(1, M+1): crop = crops[axis - 1] @@ -1944,23 +1940,17 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Infer shapes even without specifying "add_shapes=True" if output_shapes == [None]: - out_shapes = [] - for node_item in self._nodes[node.name]: - out_type = ir_pass.infer_type(node_item) - out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]] self._output_shapes[node.name] = out_shapes if self._output_shapes[node.name] and shape and node.name in shape: assert self._output_shapes[node.name] == list(shape[node.name]) - # Infer shapes if passed explicitely + # Infer shapes if passed explicitly node_output = self._nodes[node.name] if shape and (not self._output_shapes[node.name][0] or -1 in self._output_shapes[node.name][0]): - out_shapes = [] - for node_item in node_output: - out_type = ir_pass.infer_type(node_item) - out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + out_shapes = [_infer_shape(node_item) for node_item in node_output] self._output_shapes[node.name] = out_shapes out = [] diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index e4626e0d60ff..449795faf02b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, layout = None if target == "cuda": layout = "NCHW" - target_host = 'llvm' - - if isinstance(input_data, list): - shape_dict = {} - dtype_dict = {} - for i, e in enumerate(input_node): - shape_dict[e] = input_data[i].shape - dtype_dict[e] = input_data[i].dtype - else: - shape_dict = {input_node: input_data.shape} - dtype_dict = {input_node: input_data.dtype} + target_host = None + + shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} sym, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, outputs=out_names) with relay.build_config(opt_level=opt_level): - graph, lib, params = relay.build(sym, target, params=params) + graph, lib, params = relay.build(sym, target, target_host, params) ctx = tvm.context(target, 0) from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) # set inputs - for i, e in enumerate(input_node): - m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + for e, i in zip(input_node, input_data): + m.set_input(e, tvm.nd.array(i)) m.set_input(**params) # execute @@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, # get outputs assert out_names is None or num_output == len(out_names), ( "out_names: {} num_output: {}".format(out_names, num_output)) - tvm_output_list = [] - for i in range(0, num_output): - tvm_output = m.get_output(i) - tvm_output_list.append(tvm_output.asnumpy()) + tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)] return tvm_output_list def run_tf_graph(sess, input_data, input_node, output_node): @@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node): input_node = convert_to_list(input_node) output_node = convert_to_list(output_node) - tensor = [0] * len(output_node) - for i in range(len(output_node)): - tensor[i] = sess.graph.get_tensor_by_name(output_node[i]) + tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node] - input_dict = {} - for i, e in enumerate(input_node): - input_dict[e] = input_data[i] + input_dict = {e: input_data[i] for i, e in enumerate(input_node)} output_data = sess.run(tensor, input_dict) return output_data @@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node): def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False, opt_level=3): """Generic function to generate and compare tensorflow and TVM output""" + def name_without_num(name): + return name.split(':')[0] if ":" in name else name out_name = convert_to_list(out_name) - out_node = [0]*len(out_name) - for i in range(len(out_name)): - out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i] + out_node = [name_without_num(name) for name in out_name] in_data = convert_to_list(in_data) in_name = convert_to_list(in_name) - in_node = [0]*len(in_name) - for i in range(len(in_name)): - in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i] + in_node = [name_without_num(name) for name in in_name] with tf.Session() as sess: if init_global_variables: sess.run(variables.global_variables_initializer()) diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index f648245c6bb7..623c81a07da8 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -151,11 +151,7 @@ def get_const_tuple(in_tuple): out_tuple : tuple of int The output. """ - out_tuple = () - for elem in in_tuple: - value = get_const_int(elem) - out_tuple = out_tuple + (value, ) - return out_tuple + return tuple(get_const_int(elem) for elem in in_tuple) def get_float_tuple(in_tuple): @@ -171,11 +167,7 @@ def get_float_tuple(in_tuple): out_tuple : tuple of float The output. """ - out_tuple = () - for elem in in_tuple: - value = get_const_float(elem) - out_tuple = out_tuple + (value, ) - return out_tuple + return tuple(get_const_float(elem) for elem in in_tuple) def simplify(expr): From da3d09a88287d4f6c9654ffd011bf9b15721b2f7 Mon Sep 17 00:00:00 2001 From: Alexey Romanov Date: Tue, 9 Apr 2019 17:42:23 +0300 Subject: [PATCH 2/2] Add tests for TF matmul op --- .../frontend/tensorflow/test_forward.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 449795faf02b..c74b3b7bcbf6 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -560,6 +560,38 @@ def test_forward_variable(): _test_variable(np.random.uniform(size=(32, 100)).astype('float32')) +####################################################################### +# MatMul +# ------ + +def _test_matmul(i, j, k, dtype, outer=None): + """ One iteration of matmul """ + + A_shape_init = [i, j] + B_shape_init = [j, k] + + for transpose_a in [False, True]: + for transpose_b in [False, True]: + outer = outer or [] + A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init) + B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init) + + with tf.Graph().as_default(): + A = tf.placeholder(shape=A_shape, dtype=dtype, name='A') + B = tf.placeholder(shape=B_shape, dtype=dtype, name='B') + result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b) + + A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name) + +def test_forward_matmul(): + """ Matmul op test""" + _test_matmul(1, 3, 6, 'int32') + _test_matmul(5, 3, 1, 'float64') + # TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support + + ####################################################################### # StridedSlice # ------------ @@ -1737,3 +1769,6 @@ def test_placeholder(): test_forward_rel_ops() test_forward_logical() test_where() + + test_forward_matmul() + # TODO missing tests: rank, range \ No newline at end of file