File Coverage

blib/lib/AI/MXNet/Base.pm
Criterion Covered Total %
statement 13 15 86.6
branch n/a
condition n/a
subroutine 5 5 100.0
pod n/a
total 18 20 90.0


line stmt bran cond sub pod time code
1             package AI::MXNet::Base;
2 20     20   128 use strict;
  20         46  
  20         504  
3 20     20   101 use warnings;
  20         44  
  20         435  
4 20     20   12833 use PDL;
  20         294  
  20         98  
5 20     20   4151169 use PDL::Types qw();
  20         62  
  20         449  
6 20     20   18513 use AI::MXNetCAPI 1.0101;
  0            
  0            
7             use AI::NNVMCAPI 1.01;
8             use AI::MXNet::Types;
9             use Time::HiRes;
10             use Carp;
11             use Exporter;
12             use base qw(Exporter);
13             use List::Util qw(shuffle);
14              
15             @AI::MXNet::Base::EXPORT = qw(product enumerate assert zip check_call build_param_doc
16             pdl cat dog svd bisect_left pdl_shuffle
17             DTYPE_STR_TO_MX DTYPE_MX_TO_STR DTYPE_MX_TO_PDL
18             DTYPE_PDL_TO_MX DTYPE_MX_TO_PERL GRAD_REQ_MAP);
19             @AI::MXNet::Base::EXPORT_OK = qw(pzeros pceil);
20             use constant DTYPE_STR_TO_MX => {
21             float32 => 0,
22             float64 => 1,
23             float16 => 2,
24             uint8 => 3,
25             int32 => 4
26             };
27             use constant DTYPE_MX_TO_STR => {
28             0 => 'float32',
29             1 => 'float64',
30             2 => 'float16',
31             3 => 'uint8',
32             4 => 'int32'
33             };
34             use constant DTYPE_MX_TO_PDL => {
35             0 => 6,
36             1 => 7,
37             2 => 6,
38             3 => 0,
39             4 => 3,
40             float32 => 6,
41             float64 => 7,
42             float16 => 6,
43             uint8 => 0,
44             int32 => 3
45             };
46             use constant DTYPE_PDL_TO_MX => {
47             6 => 0,
48             7 => 1,
49             0 => 3,
50             3 => 4,
51             };
52             use constant DTYPE_MX_TO_PERL => {
53             0 => 'f',
54             1 => 'd',
55             2 => 'S',
56             3 => 'C',
57             4 => 'l',
58             float32 => 'f',
59             float64 => 'd',
60             float16 => 'S',
61             uint8 => 'C',
62             int32 => 'l'
63             };
64             use constant GRAD_REQ_MAP => {
65             null => 0,
66             write => 1,
67             add => 3
68             };
69              
70             =head1 NAME
71              
72             AI::MXNet::Base - Helper functions
73              
74             =head1 DEFINITION
75              
76             Helper functions
77              
78             =head2 zip
79              
80             Perl version of for x,y,z in zip (arr_x, arr_y, arr_z)
81              
82             Parameters
83             ----------
84             $sub_ref, called with @_ filled with $arr_x->[$i], $arr_y->[$i], $arr_z->[$i]
85             for each loop iteration.
86              
87             @array_refs
88             =cut
89              
90             sub zip
91             {
92             my ($sub, @arrays) = @_;
93             my $len = @{ $arrays[0] };
94             for (my $i = 0; $i < $len; $i++)
95             {
96             $sub->(map { $_->[$i] } @arrays);
97             }
98             }
99              
100             =head2 enumerate
101              
102             Same as zip, but the argument list in the anonymous sub is prepended
103             by the iteration count.
104             =cut
105              
106             sub enumerate
107             {
108             my ($sub, @arrays) = @_;
109             my $len = @{ $arrays[0] };
110             zip($sub, [0..$len-1], @arrays);
111             }
112              
113             =head2 product
114              
115             Calculates the product of the input agruments.
116             =cut
117              
118             sub product
119             {
120             my $p = 1;
121             map { $p = $p * $_ } @_;
122             return $p;
123             }
124              
125             =head2 bisect_left
126              
127             https://hg.python.org/cpython/file/2.7/Lib/bisect.py
128             =cut
129              
130             sub bisect_left
131             {
132             my ($a, $x, $lo, $hi) = @_;
133             $lo //= 0;
134             $hi //= @{ $a };
135             if($lo < 0)
136             {
137             Carp::confess('lo must be non-negative');
138             }
139             while($lo < $hi)
140             {
141             my $mid = int(($lo+$hi)/2);
142             if($a->[$mid] < $x)
143             {
144             $lo = $mid+1;
145             }
146             else
147             {
148             $hi = $mid;
149             }
150             }
151             return $lo;
152             }
153              
154             =head2 pdl_shuffle
155              
156             Shuffle the pdl by the last dimension
157              
158             Parameters
159             -----------
160             PDL $pdl
161             $preshuffle Maybe[ArrayRef[Index]], if defined the array elements are used
162             as shuffled last dimension's indexes
163             =cut
164              
165              
166             sub pdl_shuffle
167             {
168             my ($pdl, $preshuffle) = @_;
169             my $c = $pdl->copy;
170             my @shuffle = $preshuffle ? @{ $preshuffle } : shuffle(0..$pdl->dim(-1)-1);
171             my $rem = $pdl->ndims-1;
172             for my $i (0..$pdl->dim(-1)-1)
173             {
174             $c->slice(('X')x$rem, $i) .= $pdl->slice(('X')x$rem, $shuffle[$i])
175             }
176             $c;
177             }
178              
179             =head2 assert
180              
181             Parameters
182             -----------
183             Bool $input
184             Str $error_str
185             Calls Carp::confess with $error_str//"AssertionError" if the $input is false
186             =cut
187              
188             sub assert
189             {
190             my ($input, $error_str) = @_;
191             local($Carp::CarpLevel) = 1;
192             Carp::confess($error_str//'AssertionError')
193             unless $input;
194             }
195              
196             =head2 check_call
197              
198             Checks the return value of C API call
199              
200             This function will raise an exception when error occurs.
201             Every API call is wrapped with this function.
202              
203             Returns the C API call return values stripped of first return value,
204             checks for return context and returns first element in
205             the values list when called in scalar context.
206             =cut
207              
208             sub check_call
209             {
210             Carp::confess(AI::MXNetCAPI::GetLastError()) if shift;
211             return wantarray ? @_ : $_[0];
212             }
213              
214             =head2 build_param_doc
215              
216             Builds argument docs in python style.
217              
218             arg_names : array ref of str
219             Argument names.
220              
221             arg_types : array ref of str
222             Argument type information.
223              
224             arg_descs : array ref of str
225             Argument description information.
226              
227             remove_dup : boolean, optional
228             Whether to remove duplication or not.
229              
230             Returns
231             -------
232             docstr : str
233             Python docstring of parameter sections.
234             =cut
235              
236             sub build_param_doc
237             {
238             my ($arg_names, $arg_types, $arg_descs, $remove_dup) = @_;
239             $remove_dup //= 1;
240             my %param_keys;
241             my @param_str;
242             zip(sub {
243             my ($key, $type_info, $desc) = @_;
244             return if exists $param_keys{$key} and $remove_dup;
245             $param_keys{$key} = 1;
246             my $ret = sprintf("%s : %s", $key, $type_info);
247             $ret .= "\n ".$desc if length($desc);
248             push @param_str, $ret;
249             },
250             $arg_names, $arg_types, $arg_descs
251             );
252             return sprintf("Parameters\n----------\n%s\n", join("\n", @param_str));
253             }
254              
255             =head2 _notify_shutdown
256              
257             Notify MXNet about shutdown.
258             =cut
259              
260             sub _notify_shutdown
261             {
262             check_call(AI::MXNetCAPI::NotifyShutdown());
263             }
264              
265             END {
266             _notify_shutdown();
267             Time::HiRes::sleep(0.01);
268             }
269              
270             *pzeros = \&zeros;
271             *pceil = \&ceil;
272             ## making sure that we can stringify arbitrarily large piddles
273             $PDL::toolongtoprint = 1000_000_000;
274              
275             1;