File Coverage

blib/lib/AI/MXNet/Base.pm
Criterion Covered Total %
statement 13 15 86.6
branch n/a
condition n/a
subroutine 5 5 100.0
pod n/a
total 18 20 90.0


line stmt bran cond sub pod time code
1             # Licensed to the Apache Software Foundation (ASF) under one
2             # or more contributor license agreements. See the NOTICE file
3             # distributed with this work for additional information
4             # regarding copyright ownership. The ASF licenses this file
5             # to you under the Apache License, Version 2.0 (the
6             # "License"); you may not use this file except in compliance
7             # with the License. You may obtain a copy of the License at
8             #
9             # http://www.apache.org/licenses/LICENSE-2.0
10             #
11             # Unless required by applicable law or agreed to in writing,
12             # software distributed under the License is distributed on an
13             # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14             # KIND, either express or implied. See the License for the
15             # specific language governing permissions and limitations
16             # under the License.
17              
18             package AI::MXNet::Base;
19 26     26   203 use strict;
  26         73  
  26         898  
20 26     26   178 use warnings;
  26         70  
  26         1102  
21 26     26   19406 use PDL;
  26         502  
  26         498  
22 26     26   7445243 use PDL::Types qw();
  26         92  
  26         796  
23 26     26   28468 use AI::MXNetCAPI 1.1;
  0            
  0            
24             use AI::NNVMCAPI 1.1;
25             use AI::MXNet::Types;
26             use Time::HiRes;
27             use Carp;
28             use Exporter;
29             use base qw(Exporter);
30             use List::Util qw(shuffle);
31              
32             @AI::MXNet::Base::EXPORT = qw(product enumerate assert zip check_call build_param_doc
33             pdl cat dog svd bisect_left pdl_shuffle as_array
34             DTYPE_STR_TO_MX DTYPE_MX_TO_STR DTYPE_MX_TO_PDL
35             DTYPE_PDL_TO_MX DTYPE_MX_TO_PERL GRAD_REQ_MAP);
36             @AI::MXNet::Base::EXPORT_OK = qw(pzeros pceil);
37              
38             use constant DTYPE_STR_TO_MX => {
39             float32 => 0,
40             float64 => 1,
41             float16 => 2,
42             uint8 => 3,
43             int32 => 4,
44             int8 => 5,
45             int64 => 6
46             };
47             use constant DTYPE_MX_TO_STR => {
48             0 => 'float32',
49             1 => 'float64',
50             2 => 'float16',
51             3 => 'uint8',
52             4 => 'int32',
53             5 => 'int8',
54             6 => 'int64'
55             };
56             use constant DTYPE_MX_TO_PDL => {
57             0 => 6,
58             1 => 7,
59             2 => 6,
60             3 => 0,
61             4 => 3,
62             5 => 0,
63             6 => 5,
64             float32 => 6,
65             float64 => 7,
66             float16 => 6,
67             uint8 => 0,
68             int32 => 3,
69             int8 => 0,
70             int64 => 5
71             };
72             use constant DTYPE_PDL_TO_MX => {
73             6 => 0,
74             7 => 1,
75             0 => 3,
76             3 => 4,
77             5 => 6
78             };
79             use constant DTYPE_MX_TO_PERL => {
80             0 => 'f',
81             1 => 'd',
82             2 => 'S',
83             3 => 'C',
84             4 => 'l',
85             5 => 'c',
86             6 => 'q',
87             float32 => 'f',
88             float64 => 'd',
89             float16 => 'S',
90             uint8 => 'C',
91             int32 => 'l',
92             int8 => 'c',
93             int64 => 'q'
94             };
95             use constant GRAD_REQ_MAP => {
96             null => 0,
97             write => 1,
98             add => 3
99             };
100              
101             =head1 NAME
102              
103             AI::MXNet::Base - Helper functions
104              
105             =head1 DEFINITION
106              
107             Helper functions
108              
109             =head2 zip
110              
111             Perl version of for x,y,z in zip (arr_x, arr_y, arr_z)
112              
113             Parameters
114             ----------
115             $sub_ref, called with @_ filled with $arr_x->[$i], $arr_y->[$i], $arr_z->[$i]
116             for each loop iteration.
117              
118             @array_refs
119             =cut
120              
121             sub zip
122             {
123             my ($sub, @arrays) = @_;
124             my $len = @{ $arrays[0] };
125             for (my $i = 0; $i < $len; $i++)
126             {
127             $sub->(map { $_->[$i] } @arrays);
128             }
129             }
130              
131             =head2 enumerate
132              
133             Same as zip, but the argument list in the anonymous sub is prepended
134             by the iteration count.
135             =cut
136              
137             sub enumerate
138             {
139             my ($sub, @arrays) = @_;
140             my $len = @{ $arrays[0] };
141             zip($sub, [0..$len-1], @arrays);
142             }
143              
144             =head2 product
145              
146             Calculates the product of the input agruments.
147             =cut
148              
149             sub product
150             {
151             my $p = 1;
152             map { $p = $p * $_ } @_;
153             return $p;
154             }
155              
156             =head2 bisect_left
157              
158             https://hg.python.org/cpython/file/2.7/Lib/bisect.py
159             =cut
160              
161             sub bisect_left
162             {
163             my ($a, $x, $lo, $hi) = @_;
164             $lo //= 0;
165             $hi //= @{ $a };
166             if($lo < 0)
167             {
168             Carp::confess('lo must be non-negative');
169             }
170             while($lo < $hi)
171             {
172             my $mid = int(($lo+$hi)/2);
173             if($a->[$mid] < $x)
174             {
175             $lo = $mid+1;
176             }
177             else
178             {
179             $hi = $mid;
180             }
181             }
182             return $lo;
183             }
184              
185             =head2 pdl_shuffle
186              
187             Shuffle the pdl by the last dimension
188              
189             Parameters
190             -----------
191             PDL $pdl
192             $preshuffle Maybe[ArrayRef[Index]], if defined the array elements are used
193             as shuffled last dimension's indexes
194             =cut
195              
196              
197             sub pdl_shuffle
198             {
199             my ($pdl, $preshuffle) = @_;
200             my $c = $pdl->copy;
201             my @shuffle = $preshuffle ? @{ $preshuffle } : shuffle(0..$pdl->dim(-1)-1);
202             my $rem = $pdl->ndims-1;
203             for my $i (0..$pdl->dim(-1)-1)
204             {
205             $c->slice(('X')x$rem, $i) .= $pdl->slice(('X')x$rem, $shuffle[$i])
206             }
207             $c;
208             }
209              
210             =head2 assert
211              
212             Parameters
213             -----------
214             Bool $input
215             Str $error_str
216             Calls Carp::confess with $error_str//"AssertionError" if the $input is false
217             =cut
218              
219             sub assert
220             {
221             my ($input, $error_str) = @_;
222             local($Carp::CarpLevel) = 1;
223             Carp::confess($error_str//'AssertionError')
224             unless $input;
225             }
226              
227             =head2 check_call
228              
229             Checks the return value of C API call
230              
231             This function will raise an exception when error occurs.
232             Every API call is wrapped with this function.
233              
234             Returns the C API call return values stripped of first return value,
235             checks for return context and returns first element in
236             the values list when called in scalar context.
237             =cut
238              
239             sub check_call
240             {
241             Carp::confess(AI::MXNetCAPI::GetLastError()) if shift;
242             return wantarray ? @_ : $_[0];
243             }
244              
245             =head2 build_param_doc
246              
247             Builds argument docs in python style.
248              
249             arg_names : array ref of str
250             Argument names.
251              
252             arg_types : array ref of str
253             Argument type information.
254              
255             arg_descs : array ref of str
256             Argument description information.
257              
258             remove_dup : boolean, optional
259             Whether to remove duplication or not.
260              
261             Returns
262             -------
263             docstr : str
264             Python docstring of parameter sections.
265             =cut
266              
267             sub build_param_doc
268             {
269             my ($arg_names, $arg_types, $arg_descs, $remove_dup) = @_;
270             $remove_dup //= 1;
271             my %param_keys;
272             my @param_str;
273             zip(sub {
274             my ($key, $type_info, $desc) = @_;
275             return if exists $param_keys{$key} and $remove_dup;
276             $param_keys{$key} = 1;
277             my $ret = sprintf("%s : %s", $key, $type_info);
278             $ret .= "\n ".$desc if length($desc);
279             push @param_str, $ret;
280             },
281             $arg_names, $arg_types, $arg_descs
282             );
283             return sprintf("Parameters\n----------\n%s\n", join("\n", @param_str));
284             }
285              
286             =head2 _notify_shutdown
287              
288             Notify MXNet about shutdown.
289             =cut
290              
291             sub _notify_shutdown
292             {
293             check_call(AI::MXNetCAPI::NotifyShutdown());
294             }
295              
296             sub _indent
297             {
298             my ($s_, $numSpaces) = @_;
299             my @s = split(/\n/, $s_);
300             if (@s == 1)
301             {
302             return $s_;
303             }
304             my $first = shift(@s);
305             @s = ($first, map { (' 'x$numSpaces) . $_ } @s);
306             return join("\n", @s);
307             }
308              
309             sub as_array
310             {
311             return ref $_[0] eq 'ARRAY' ? $_[0] : [$_[0]];
312             }
313              
314             my %internal_arguments = (prefix => 1, params => 1, shared => 1);
315             my %attributes_per_class;
316             sub process_arguments
317             {
318             my $orig = shift;
319             my $class = shift;
320             if($class->can('python_constructor_arguments'))
321             {
322             if(not exists $attributes_per_class{$class})
323             {
324             %{ $attributes_per_class{$class} } = map { $_->name => 1 } $class->meta->get_all_attributes;
325             }
326             my %kwargs;
327             while(@_ >= 2 and not ref $_[-2] and (exists $attributes_per_class{$class}{ $_[-2] } or exists $internal_arguments{ $_[-2] }))
328             {
329             my $v = pop(@_);
330             my $k = pop(@_);
331             $kwargs{ $k } = $v;
332             }
333             if(@_)
334             {
335             @kwargs{ @{ $class->python_constructor_arguments }[0..@_-1] } = @_;
336             }
337             return $class->$orig(%kwargs);
338             }
339             return $class->$orig(@_);
340             }
341              
342             END {
343             _notify_shutdown();
344             Time::HiRes::sleep(0.01);
345             }
346              
347             *pzeros = \&zeros;
348             *pceil = \&ceil;
349             ## making sure that we can stringify arbitrarily large piddles
350             $PDL::toolongtoprint = 1000_000_000;
351             ## convenience subs
352             {
353             my $orig_at = PDL->can('at');
354             no warnings 'redefine';
355             *PDL::at = sub {
356             my ($self, @args) = @_;
357             return $orig_at->($self, @args) if @args != 1;
358             return $orig_at->($self, @args) if $self->ndims == 1;
359             return $self->slice(('X')x($self->ndims-1), $args[0])->squeeze;
360             };
361             *PDL::len = sub { shift->dim(-1) };
362             *PDL::dtype = sub { DTYPE_MX_TO_STR->{ DTYPE_PDL_TO_MX->{ shift->type->numval } } };
363             }
364              
365             1;