diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm b/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm index a8da8470f574..4002bbf5d369 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/Base.pm @@ -120,12 +120,17 @@ use constant GRAD_REQ_MAP => { sub zip { - my ($sub, @arrays) = @_; - my $len = @{ $arrays[0] }; - for (my $i = 0; $i < $len; $i++) + if('CODE' eq ref $_[0]) { - $sub->(map { $_->[$i] } @arrays); + # continue supporting the callback style + my $code = shift; + $code->(@$_) for AI::MXNetCAPI::py_zip(map { \@$_ } @_); + return; } + # the map() here may seem like a no-op, but triggers overloading or + # whatever else is needed to make array-ish things actually arrays + # before entering the low level list builder. + return AI::MXNetCAPI::py_zip(map { \@$_ } @_); } =head2 enumerate diff --git a/perl-package/AI-MXNet/t/test_base.t b/perl-package/AI-MXNet/t/test_base.t new file mode 100644 index 000000000000..935f3b2bd03e --- /dev/null +++ b/perl-package/AI-MXNet/t/test_base.t @@ -0,0 +1,108 @@ +use strict; +use warnings; +use Test::More; +use AI::MXNet qw(mx); +use AI::MXNet::TestUtils qw(same reldiff GetMNIST_ubyte GetCifar10); + +sub test_builtin_zip() +{ + is_deeply( + [ AI::MXNet::zip([ 0 .. 9 ], [ 10 .. 19 ]) ], + [ map { [ $_, 10 + $_ ] } 0 .. 9 ]); + is_deeply( + [ AI::MXNet::zip([ 0 .. 9 ], [ 10 .. 19 ], [ 20 .. 29 ]) ], + [ map { [ $_, 10 + $_, 20 + $_ ] } 0 .. 9 ]); + my $over = ListOverload->new(10 .. 19); + is_deeply( + [ AI::MXNet::zip([ 0 .. 9 ], \@$over) ], + [ map { [ $_, 10 + $_ ] } 0 .. 9 ]); + my $tied = ListTied->new(10 .. 19); + is_deeply( + [ AI::MXNet::zip([ 0 .. 9 ], \@$tied) ], + [ map { [ $_, 10 + $_ ] } 0 .. 9 ]); +} + + +test_builtin_zip(); +done_testing(); + +package ListTied { + sub new { + my($class, @list) = @_; + my @tied; + tie @tied, $class, @list; + return \@tied; + } + sub TIEARRAY { + my($class, @list) = @_; + return bless { list => \@list }, $class; + } + sub FETCH { + my($self, $index) = @_; + return $self->{list}[$index]; + } + sub STORE { + my($self, $index, $value) = @_; + return $self->{list}[$index] = $value; + } + sub FETCHSIZE { + my($self) = @_; + return scalar @{$self->{list}}; + } + sub STORESIZE { + my($self, $count) = @_; + return $self->{list}[$count - 1] //= undef; + } + sub EXTEND { + my($self, $count) = @_; + return $self->STORESIZE($count); + } + sub EXISTS { + my($self, $key) = @_; + return exists $self->{list}[$key]; + } + sub DELETE { + my($self, $key) = @_; + return delete $self->{list}[$key]; + } + sub CLEAR { + my($self) = @_; + return @{$self->{list}} = (); + } + sub PUSH { + my($self, @list) = @_; + return push @{$self->{list}}, @list; + } + sub POP { + my($self) = @_; + return pop @{$self->{list}}; + } + sub SHIFT { + my($self) = @_; + return shift @{$self->{list}}; + } + sub UNSHIFT { + my($self, @list) = @_; + return unshift @{$self->{list}}, @list; + } + sub SPLICE { + my($self, $offset, $length, @list) = @_; + return splice @{$self->{list}}, $offset, $length, @list; + } + sub UNTIE { + my($self) = @_; + } + sub DESTROY { + my($self) = @_; + } +} + +package ListOverload { + use overload '@{}' => \&as_list; + sub new { + my($class, @list) = @_; + return bless { list => \@list }, $class; + } + sub as_list { return $_[0]{list} } +} + diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i index e466e98b7842..663a0c285f0b 100644 --- a/perl-package/AI-MXNetCAPI/mxnet.i +++ b/perl-package/AI-MXNetCAPI/mxnet.i @@ -106,7 +106,44 @@ static void ExecutorMonitor_callback(const char* name, NDArrayHandle handle, voi %} +%{ + +/* this is an adaptation of Python/bltinmodule.c's builtin_zip() */ +XS(py_zip) { + dXSARGS; + I32 i; + I32 len = -1; + AV *l[items]; + + for(i = 0; i < items; i++) { + AV *av = (AV *)SvRV(ST(i)); + I32 thislen; + + if(SvTYPE(av) != SVt_PVAV) + croak("zip argument#%d must be an array", i); + thislen = av_len(av) + 1; + if(len < 0 || thislen < len) + len = thislen; + l[i] = av; + } + EXTEND(SP, len); + for(i = 0; i < len; i++) { + I32 j; + SV *next[items]; + + for(j = 0; j < items; j++) { + SV **sv = av_fetch(l[j], i, 0); + next[j] = sv ? *sv : &PL_sv_undef; + } + ST(i) = sv_2mortal(newRV_noinc((SV *)av_make(items, next))); + } + XSRETURN(len); +} + +%} + %init %{ + newXS(SWIG_prefix "py_zip", py_zip, (char *)__FILE__); /* These SWIG_TypeClientData() calls might break in the future, but * %rename should work on these types before that happens. */ SWIG_TypeClientData(SWIGTYPE_p_MXNDArray, (void *)"NDArrayHandle");