Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactor cachedop to specialize (#6735)
Browse files Browse the repository at this point in the history
revert cachedop for perl

fix
t st
imt qprove error message

ci

fix

fix

fix
  • Loading branch information
piiswrong authored Jun 21, 2017
1 parent 0df68e8 commit 3ceb6d2
Show file tree
Hide file tree
Showing 25 changed files with 292 additions and 499 deletions.
31 changes: 7 additions & 24 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,24 +585,20 @@ MXNET_DLL int MXAutogradBackward(mx_uint num_output,
/*!
* \brief create cached operator
*/
MXNET_DLL int MXCachedCreateOp(AtomicSymbolCreator creator,
int num_inputs,
int num_params,
const char **param_keys,
const char **param_vals,
MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
CachedOpHandle *out);
/*!
* \brief free cached operator
*/
MXNET_DLL int MXCachedFree(CachedOpHandle handle);
MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle);
/*!
* \brief invoke cached operator
*/
MXNET_DLL int MXCachedInvoke(CachedOpHandle handle,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs);
MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle,
int num_inputs,
NDArrayHandle *inputs,
int *num_outputs,
NDArrayHandle **outputs);
//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
Expand Down Expand Up @@ -670,19 +666,6 @@ MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
const char **keys,
const char **vals,
SymbolHandle *out);
/*!
* \brief Create an AtomicSymbol from cached op.
* \param handle cached node attribute.
* \param name name of new symbol.
* \param num_args the number of symbol arguments
* \param args symbol arguments
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXCachedCreateSymbol(CachedOpHandle handle,
const char* name,
mx_uint num_args,
SymbolHandle* args,
SymbolHandle* out);
/*!
* \brief Create a Variable Symbol.
* \param name name of the variable
Expand Down
1 change: 0 additions & 1 deletion perl-package/AI-MXNet/lib/AI/MXNet.pm
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use v5.14.0;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::CachedOp;
use AI::MXNet::Callback;
use AI::MXNet::NDArray;
use AI::MXNet::Symbol;
Expand Down
41 changes: 0 additions & 41 deletions perl-package/AI-MXNet/lib/AI/MXNet/CachedOp.pm

This file was deleted.

1 change: 0 additions & 1 deletion perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,6 @@ method backward(Maybe[AI::MXNet::NDArray] $out_grad=, Bool $retain_graph=0)
)
}

method CachedOp(@args) { AI::MXNet::CachedOp->new(@args) }

my $lvalue_methods = join "\n", map {"use attributes 'AI::MXNet::NDArray', \\&AI::MXNet::NDArray::$_, 'lvalue';"}
qw/at slice aspdl asmpdl reshape copy sever T astype as_in_context copyto empty zero ones full
Expand Down
38 changes: 0 additions & 38 deletions perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Base.pm
Original file line number Diff line number Diff line change
Expand Up @@ -140,44 +140,6 @@ method _init_ndarray_module()
}
}

method invoke(
AI::MXNet::CachedOp $cached_op,
ArrayRef[AI::MXNet::NDArray] $args,
Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out=,
Maybe[Str] $name=
)
{
my $original_output;
if(defined $out)
{
$original_output = $out;
if(not ref($out) eq 'ARRAY')
{
$out = [$out];
}
}
else
{
$out = [];
}
my $output = check_call(
AI::MXNetCAPI::CachedInvoke(
$cached_op->handle,
scalar(@$args),
[map { $_->handle } @$args],
[map { $_->handle } @$out]
)
);
return $original_output if defined $original_output;
if(@$output == 1)
{
return $self->new(handle => $output->[0]);
}
else
{
return [map { $self->new(handle => $_) } @$output];
}
}

__PACKAGE__->_init_ndarray_module;

Expand Down
1 change: 0 additions & 1 deletion perl-package/AI-MXNet/lib/AI/MXNet/Symbol.pm
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,6 @@ method arange(Index :$start=0, Index :$stop=, Num :$step=1.0, Index :$repeat=1,
});
}

method CachedOp(@args) { AI::MXNet::CachedOp->new(@args) }

sub _parse_arguments
{
Expand Down
14 changes: 0 additions & 14 deletions perl-package/AI-MXNet/lib/AI/MXNet/Symbol/Base.pm
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,6 @@ method _init_symbol_module()
}
}

method invoke(AI::MXNet::CachedOp $cached_op, ArrayRef[AI::MXNet::Symbol] $args, Maybe[Str] $name=)
{
my $hint = lc($cached_op->op);
$name = AI::MXNet::Symbol::NameManager->current->get($name, $hint);
my $handle = check_call(
AI::MXNetCAPI::CachedCreateSymbol(
$cached_op->handle,
$name,
scalar(@$args),
[map { $_->handle } @$args]
)
);
return $self->new(handle => $handle);
}

__PACKAGE__->_init_symbol_module;

Expand Down
16 changes: 2 additions & 14 deletions perl-package/AI-MXNet/t/test_ndarray.t
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(almost_equal);
use Test::More tests => 9;
use Test::More tests => 8;

sub test_ndarray_reshape
{
Expand Down Expand Up @@ -36,17 +36,6 @@ sub test_moveaxis
is_deeply($X->moveaxis(2, 0)->shape, [3, 2, 2]);
}

sub test_cached
{
my $op = mx->nd->CachedOp('Convolution', 3, kernel=>[3, 3], num_filter=>10);
my $data = mx->nd->ones([3, 4, 10, 10]);
my $weight = mx->nd->ones([10, 4, 3, 3]);
my $bias = mx->nd->ones([10]);
my $o1 = mx->nd->invoke($op, [$data, $weight, $bias]);
$bias .= 2;
my $o2 = mx->nd->invoke($op, [$data, $weight, $bias]);
ok(almost_equal($o2->aspdl, $o1->aspdl + 1));
}

sub test_output
{
Expand All @@ -64,5 +53,4 @@ sub test_output

test_ndarray_reshape();
test_moveaxis();
test_cached();
test_output();
test_output();
22 changes: 2 additions & 20 deletions perl-package/AI-MXNet/t/test_symbol.t
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use strict;
use warnings;
use Test::More tests => 102;
use Test::More tests => 98;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(mlp2 conv check_consistency zip assert enumerate);
use Storable qw(freeze thaw);
Expand Down Expand Up @@ -221,24 +221,6 @@ sub test_load_000800

test_load_000800();

sub test_cached
{
my $op = mx->sym->CachedOp('Convolution', 3, kernel=>[3, 3], num_filter=>10);
my $data = mx->sym->var('data');
my $weight = mx->sym->var('weight');
my $bias = mx->sym->var('bias');
my $out = mx->sym->invoke($op, [$data, $weight, $bias], 'conv');
is_deeply($out->list_arguments, ['data', 'weight', 'bias']);
is_deeply($out->list_outputs, ['conv_output']);
{
local($mx::NameManager) = mx->name->Prefix('test_');
is(mx->sym->invoke($op, [$data, $weight, $bias])->name,'test_convolution0');
is(mx->sym->invoke($op, [$data, $weight, $bias])->name, 'test_convolution1');
}
}

test_cached();

__DATA__
{
"nodes": [
Expand Down Expand Up @@ -427,4 +409,4 @@ __DATA__
],
"arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12, 13, 15],
"heads": [[16, 0]]
}
}
41 changes: 0 additions & 41 deletions perl-package/AI-MXNetCAPI/mxnet.i
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ static void ExecutorMonitor_callback(const char* name, NDArrayHandle handle, voi
SWIG_TypeClientData(SWIGTYPE_p_MXKVStore, (void *)"KVStoreHandle");
SWIG_TypeClientData(SWIGTYPE_p_MXRecordIO, (void *)"RecordIOHandle");
SWIG_TypeClientData(SWIGTYPE_p_MXRtc, (void *)"RtcHandle");
SWIG_TypeClientData(SWIGTYPE_p_MXCachedOp, (void *)"CachedOpHandle");
%}

/*! \brief manually define unsigned int */
Expand Down Expand Up @@ -151,8 +150,6 @@ typedef MXKVStore *KVStoreHandle;
typedef MXRecordIO *RecordIOHandle;
/*! \brief handle to MXRtc*/
typedef MXRtc *RtcHandle;
/*! \brief handle to cached operator */
typedef MXCachedOp *CachedOpHandle;

typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
Expand Down Expand Up @@ -628,30 +625,6 @@ int MXAutogradBackward(mx_uint num_output,
NDArrayHandle* in,
int retain_graph);

/*!
* \brief create cached operator
*/
int MXCachedCreateOp(AtomicSymbolCreator in,
int num_inputs,
int num_params,
const char **keys,
const char **vals,
CachedOpHandle *out);

/*!
* \brief free cached operator
*/
int MXCachedFree(CachedOpHandle handle);

/*!
* \brief invoke cached operator
*/
int MXCachedInvoke(CachedOpHandle handle,
int num_inputs,
NDArrayHandle *in,
int *out_size,
NDArrayHandle** out_array);

//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
Expand Down Expand Up @@ -719,20 +692,6 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator in,
const char **keys,
const char **vals,
SymbolHandle *out);
/*!
* \brief Create an AtomicSymbol from cached op.
* \param handle cached node attribute.
* \param name name of new symbol.
* \param num_args the number of symbol arguments
* \param args symbol arguments
* \return 0 when success, -1 when failure happens
*/
int MXCachedCreateSymbol(CachedOpHandle handle,
const char* name,
mx_uint num_args,
SymbolHandle* in,
SymbolHandle* out);

/*!
* \brief Create a Variable Symbol.
* \param name name of the variable
Expand Down
5 changes: 2 additions & 3 deletions perl-package/AI-MXNetCAPI/mxnet_typemaps.i
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,12 @@
(DataIterHandle *out) (ExecutorHandle temp),
(KVStoreHandle *out) (KVStoreHandle temp),
(RecordIOHandle *out) (RecordIOHandle temp),
(RtcHandle *out) (RtcHandle temp),
(CachedOpHandle *out) (CachedOpHandle temp)
(RtcHandle *out) (RtcHandle temp)
{
$1 = &temp;
}
%typemap(argout) (NDArrayHandle *out), (FunctionHandle* out), (SymbolHandle *out), (ExecutorHandle *out), (DataIterHandle *out),
(KVStoreHandle *out), (RecordIOHandle *out), (RtcHandle *out) (RtcHandle temp), (CachedOpHandle *out) (CachedOpHandle temp)
(KVStoreHandle *out), (RecordIOHandle *out), (RtcHandle *out) (RtcHandle temp)
{
if(!result)
{
Expand Down
30 changes: 0 additions & 30 deletions python/mxnet/_ctypes/common.py

This file was deleted.

Loading

0 comments on commit 3ceb6d2

Please sign in to comment.