Skip to content

Commit

Permalink
[Perl] emulate Python zip() for Perl (apache#8192)
Browse files Browse the repository at this point in the history
* [Perl] emulate Python zip() for Perl

* [Perl] retool zip() uses away from the callback form
  • Loading branch information
tlby authored and crazy-cat committed Oct 26, 2017
1 parent a3b4c5c commit d76d050
Show file tree
Hide file tree
Showing 23 changed files with 318 additions and 177 deletions.
8 changes: 4 additions & 4 deletions perl-package/AI-MXNet/lib/AI/MXNet/AutoGrad.pm
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,10 @@ method grad(
);

my @ret;
zip(sub {
my ($handle, $stype) = @_;
for(zip($grad_vars, $grad_stypes)) {
my ($handle, $stype) = @$_;
push @ret, AI::MXNet::NDArray->new(handle => $handle, stype => $stype);
}, $grad_vars, $grad_stypes);
}
if(blessed $variables)
{
return $ret[0];
Expand Down Expand Up @@ -474,4 +474,4 @@ func _parse_head($heads, $head_grads)
return (\@head_handles, \@hgrad_handles);
}

1;
1;
23 changes: 13 additions & 10 deletions perl-package/AI-MXNet/lib/AI/MXNet/Base.pm
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,17 @@ use constant GRAD_REQ_MAP => {

sub zip
{
my ($sub, @arrays) = @_;
my $len = @{ $arrays[0] };
for (my $i = 0; $i < $len; $i++)
if('CODE' eq ref $_[0])
{
$sub->(map { $_->[$i] } @arrays);
# continue supporting the callback style
my $code = shift;
$code->(@$_) for AI::MXNetCAPI::py_zip(map { \@$_ } @_);
return;
}
# the map() here may seem like a no-op, but triggers overloading or
# whatever else is needed to make array-ish things actually arrays
# before entering the low level list builder.
return AI::MXNetCAPI::py_zip(map { \@$_ } @_);
}

=head2 enumerate
Expand Down Expand Up @@ -270,16 +275,14 @@ sub build_param_doc
$remove_dup //= 1;
my %param_keys;
my @param_str;
zip(sub {
my ($key, $type_info, $desc) = @_;
return if exists $param_keys{$key} and $remove_dup;
for(zip($arg_names, $arg_types, $arg_descs)) {
my ($key, $type_info, $desc) = @$_;
next if exists $param_keys{$key} and $remove_dup;
$param_keys{$key} = 1;
my $ret = sprintf("%s : %s", $key, $type_info);
$ret .= "\n ".$desc if length($desc);
push @param_str, $ret;
},
$arg_names, $arg_types, $arg_descs
);
}
return sprintf("Parameters\n----------\n%s\n", join("\n", @param_str));
}

Expand Down
74 changes: 37 additions & 37 deletions perl-package/AI-MXNet/lib/AI/MXNet/Executor/Group.pm
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ func _split_input_slice($batch_size, $work_load_list)
# Load a array ref of arrays into a array ref of arrays specified by slices
func _load_general($data, $targets, $major_axis)
{
zip(sub {
my ($d_src, $d_targets, $axis) = @_;
for(zip($data, $targets, $major_axis)) {
my ($d_src, $d_targets, $axis) = @$_;
if(blessed($d_targets) and $d_targets->isa('AI::MXNet::NDarray'))
{
$d_src->copyto($d_targets);
}
elsif(ref $d_targets eq 'ARRAY' and blessed $d_targets->[0])
{
zip(sub {
my ($src, $dst) = @_;
for(zip($d_src, $d_targets)) {
my ($src, $dst) = @$_;
$src->copyto($dst);
}, $d_src, $d_targets);
}
}
else
{
Expand Down Expand Up @@ -124,7 +124,7 @@ func _load_general($data, $targets, $major_axis)
}
}
}
}, $data, $targets, $major_axis);
}
}

# Load data into sliced arrays
Expand All @@ -144,8 +144,8 @@ func _load_label($batch, $targets, $major_axis)
func _merge_multi_context($outputs, $major_axis)
{
my @rets;
zip(sub {
my ($tensors, $axis) = @_;
for(zip($outputs, $major_axis)) {
my ($tensors, $axis) = @$_;
if($axis >= 0)
{
if(@$tensors == 1)
Expand All @@ -165,7 +165,7 @@ func _merge_multi_context($outputs, $major_axis)
# first one, without checking they are actually the same
push @rets, $tensors->[0];
}
}, $outputs, $major_axis);
}
return \@rets;
}

Expand Down Expand Up @@ -353,9 +353,9 @@ method decide_slices(ArrayRef[AI::MXNet::DataDesc] $data_shapes)
{
confess("empty data_shapes array") unless @{ $data_shapes } > 0;
my $major_axis = [map { AI::MXNet::DataDesc->get_batch_axis($_->layout) } @{ $data_shapes }];
zip(sub {
my ($desc, $axis) = @_;
return if($axis == -1);
for(zip($data_shapes, $major_axis)) {
my ($desc, $axis) = @$_;
next if($axis == -1);
my $batch_size = $desc->shape->[$axis];
if(defined $self->_p->batch_size)
{
Expand All @@ -370,7 +370,7 @@ method decide_slices(ArrayRef[AI::MXNet::DataDesc] $data_shapes)
$self->_p->batch_size($batch_size);
$self->_p->slices(AI::MXNet::Executor::Group::_split_input_slice($self->_p->batch_size, $self->workload));
}
}, $data_shapes, $major_axis);
}
return $major_axis;
}

Expand Down Expand Up @@ -590,16 +590,16 @@ method set_params(HashRef[AI::MXNet::NDArray] $arg_params, HashRef[AI::MXNet::ND
method get_params(HashRef[AI::MXNet::NDArray] $arg_params, HashRef[AI::MXNet::NDArray] $aux_params)
{
my $weight = 0;
zip(sub {
my ($name, $block) = @_;
for(zip($self->param_names, $self->_p->param_arrays)) {
my ($name, $block) = @$_;
my $weight = sum(map { $_->copyto(AI::MXNet::Context->cpu) } @{ $block }) / @{ $block };
$weight->astype($arg_params->{$name}->dtype)->copyto($arg_params->{$name});
}, $self->param_names, $self->_p->param_arrays);
zip(sub {
my ($name, $block) = @_;
}
for(zip($self->_p->aux_names, $self->_p->aux_arrays)) {
my ($name, $block) = @$_;
my $weight = sum(map { $_->copyto(AI::MXNet::Context->cpu) } @{ $block }) / @{ $block };
$weight->astype($aux_params->{$name}->dtype)->copyto($aux_params->{$name});
}, $self->_p->aux_names, $self->_p->aux_arrays);
}
}


Expand Down Expand Up @@ -668,15 +668,15 @@ method get_output_shapes()
{
my @shapes = map { $_->shape } @{ $self->execs->[0]->outputs };
my @concat_shapes;
zip(sub {
my ($key, $shape, $axis) = @_;
for(zip($self->symbol->list_outputs, \@shapes, $self->_p->output_layouts)) {
my ($key, $shape, $axis) = @$_;
my @the_shape = @{ $shape };
if($axis >= 0)
{
$the_shape[$axis] = $self->_p->batch_size;
}
push @concat_shapes, AI::MXNet::DataDesc->new(name => $key, shape => \@the_shape);
}, $self->symbol->list_outputs, \@shapes, $self->_p->output_layouts);
}
return \@concat_shapes;
}

Expand Down Expand Up @@ -765,11 +765,11 @@ method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out_grad
{
confess('re-bind with for_training=1 to run backward') unless $self->for_training;
$out_grads //= [];
zip(sub {
my ($i, $exec, $islice) = @_;
for(zip([0..@{ $self->_p->execs }-1], $self->_p->execs, $self->_p->slices)) {
my ($i, $exec, $islice) = @$_;
my @out_grads_slice;
zip(sub{
my ($grad, $axis) = @_;
for(zip($out_grads, $self->_p->output_layouts)) {
my ($grad, $axis) = @$_;
if($axis >= 0)
{
my $og_my_slice = $grad->slice_axis({
Expand All @@ -783,9 +783,9 @@ method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out_grad
{
push @out_grads_slice, $grad->copyto($self->contexts->[$i]);
}
}, $out_grads, $self->_p->output_layouts);
}
$exec->backward(\@out_grads_slice);
}, [0..@{ $self->_p->execs }-1], $self->_p->execs, $self->_p->slices);
}
}

=head2 update_metric
Expand All @@ -802,11 +802,11 @@ method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out_grad

method update_metric(AI::MXNet::EvalMetric $eval_metric, ArrayRef[AI::MXNet::NDArray] $labels)
{
zip(sub {
my ($texec, $islice) = @_;
for(zip($self->_p->execs, $self->_p->slices)) {
my ($texec, $islice) = @$_;
my @labels_slice;
zip(sub {
my ($label, $axis) = @_;
for(zip($labels, $self->_p->label_layouts)) {
my ($label, $axis) = @$_;
if($axis == 0)
{
# slicing NDArray along axis 0 can avoid copying
Expand All @@ -825,9 +825,9 @@ method update_metric(AI::MXNet::EvalMetric $eval_metric, ArrayRef[AI::MXNet::NDA
{
push @labels_slice, $label;
}
}, $labels, $self->_p->label_layouts);
}
$eval_metric->update(\@labels_slice, $texec->outputs);
}, $self->_p->execs, $self->_p->slices);
}
}

method _bind_ith_exec(
Expand Down Expand Up @@ -874,8 +874,8 @@ method _bind_ith_exec(
method _sliced_shape(ArrayRef[AI::MXNet::DataDesc] $shapes, Int $i, ArrayRef[Int] $major_axis)
{
my @sliced_shapes;
zip(sub {
my ($desc, $axis) = @_;
for(zip($shapes, $major_axis)) {
my ($desc, $axis) = @$_;
my @shape = @{ $desc->shape };
if($axis >= 0)
{
Expand All @@ -887,7 +887,7 @@ method _sliced_shape(ArrayRef[AI::MXNet::DataDesc] $shapes, Int $i, ArrayRef[Int
dtype => $desc->dtype,
layout => $desc->layout
);
}, $shapes, $major_axis);
}
return \@sliced_shapes;
}

Expand Down
24 changes: 12 additions & 12 deletions perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Block.pm
Original file line number Diff line number Diff line change
Expand Up @@ -565,21 +565,21 @@ method infer_shape(@args)
my $args = \@args;
($args) = __PACKAGE__->_flatten($args);
my %in;
zip(sub {
my ($i, $j) = @_;
for(zip($inputs, $args)) {
my ($i, $j) = @$_;
$in{ $i->name } = $j->shape;
}, $inputs, $args);
}
my ($arg_shapes, undef, $aux_shapes) = $out->infer_shape(%in);
my %sdict;
zip(sub {
my ($i, $j) = @_;
for(zip($out->list_arguments(), $arg_shapes)) {
my ($i, $j) = @$_;
$sdict{ $i } = $j;
}, $out->list_arguments(), $arg_shapes);
}
my %aux;
zip(sub {
my ($i, $j) = @_;
for(zip($out->list_auxiliary_states(), $aux_shapes)) {
my ($i, $j) = @$_;
$aux{ $i } = $j;
}, $out->list_auxiliary_states(), $aux_shapes);
}
%sdict = (%sdict, %aux);
for my $i ($self->collect_params->values)
{
Expand Down Expand Up @@ -878,10 +878,10 @@ method forward($x, @args)
assert((Data::Dumper::Dumper($in_fmt) eq Data::Dumper::Dumper($self->_in_format)), "Invalid input format");
my $ret = $self->_cached_graph->[1]->deepcopy;
my %in;
zip(sub {
my ($k, $v) = @_;
for(zip($self->_cached_graph->[0], $args)) {
my ($k, $v) = @$_;
$in{$k->name} = $v;
}, $self->_cached_graph->[0], $args);
}
$ret->_compose(%in);
$ret = (__PACKAGE__->_regroup($ret, $self->_out_format))[0];
if(ref($ret) eq 'ARRAY' and wantarray)
Expand Down
8 changes: 4 additions & 4 deletions perl-package/AI-MXNet/lib/AI/MXNet/Gluon/Parameter.pm
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ method _load_init($data, $ctx)
{
if($self->shape)
{
zip(sub {
my ($i, $j) = @_;
for(zip($self->shape, $data->shape)) {
my ($i, $j) = @$_;
assert(
($i == 0 or $i == $j),
sprintf(
Expand All @@ -204,7 +204,7 @@ method _load_init($data, $ctx)
$self->name, "@{$self->shape}", "@{$data->shape}"
)
);
}, $self->shape, $data->shape);
}
}
if($self->dtype)
{
Expand Down Expand Up @@ -923,4 +923,4 @@ method load(
}
}

1;
1;
14 changes: 7 additions & 7 deletions perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN/Cell.pm
Original file line number Diff line number Diff line change
Expand Up @@ -1047,10 +1047,10 @@ method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
if($p_states != 0)
{
my @tmp;
zip(sub {
my ($new_s, $old_s) = @_;
for(zip($next_states, $states)) {
my ($new_s, $old_s) = @$_;
push @tmp, $F->where($mask->($p_states, $new_s), $new_s, $old_s);
}, $next_states, $states);
}
$states = \@tmp;
}
else
Expand Down Expand Up @@ -1109,10 +1109,10 @@ method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=,
else
{
my @tmp;
zip(sub {
my ($i, $j) = @_;
for(zip($outputs, $inputs)) {
my ($i, $j) = @$_;
push @tmp, $F->elemwise_add($i, $j);
}, $outputs, $inputs);
}
$outputs = \@tmp;
}
return ($outputs, $states);
Expand Down Expand Up @@ -1222,4 +1222,4 @@ method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=,

__PACKAGE__->register('AI::MXNet::Gluon::RNN');

1;
1;
6 changes: 3 additions & 3 deletions perl-package/AI-MXNet/lib/AI/MXNet/Gluon/RNN/Layer.pm
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,14 @@ method forward(GluonInput $inputs, Maybe[GluonInput] $states=)
{
$states = [$states];
}
zip(sub {
my ($state, $info) = @_;
for(zip($states, $self->state_info($batch_size))) {
my ($state, $info) = @$_;
if(Dumper($state->shape) ne Dumper($info->{shape}))
{
my @state_shape = @{ $state->shape };
confess("Invalid recurrent state shape. Expecting @{$info->{shape}}, got @state_shape.");
}
}, $states, $self->state_info($batch_size));
}
if($self->input_size == 0)
{
for my $i (0..$self->dir-1)
Expand Down
Loading

0 comments on commit d76d050

Please sign in to comment.