File Coverage

blib/lib/AI/NaiveBayes1.pm
Criterion Covered Total %
statement 265 302 87.7
branch 64 94 68.0
condition 13 24 54.1
subroutine 20 22 90.9
pod 16 16 100.0
total 378 458 82.5


line stmt bran cond sub pod time code
1             # (c) 2003-21 Vlado Keselj https://web.cs.dal.ca/~vlado
2              
3             package AI::NaiveBayes1;
4 8     8   6781 use strict;
  8         20  
  8         330  
5             require Exporter;
6 8     8   41 use vars qw($VERSION @ISA @EXPORT @EXPORT_OK %EXPORT_TAGS);
  8         16  
  8         796  
7             @EXPORT = qw(new);
8 8     8   54 use vars qw($Version);
  8         15  
  8         458  
9             $Version = $VERSION = '2.011';
10              
11 8     8   71 use vars @EXPORT_OK;
  8         16  
  8         223  
12              
13             # non-exported package globals go here
14 8     8   49 use vars qw();
  8         14  
  8         30547  
15              
16             sub new {
17 11     11 1 12919 my $package = shift;
18 11         165 return bless {
19             attributes => [ ],
20             labels => [ ],
21             attvals => {},
22             real_stat => {},
23             numof_instances => 0,
24             stat_labels => {},
25             stat_attributes => {},
26             smoothing => {},
27             attribute_type => {},
28             }, $package;
29             }
30              
31             sub set_real {
32 4     4 1 26 my ($self, @attr) = @_;
33 4         12 foreach my $a (@attr) { $self->{attribute_type}{$a} = 'real' }
  5         21  
34             }
35              
36             sub import_from_YAML {
37 5     5 1 169971 my $package = shift;
38 5         15 my $yaml = shift;
39 5         28 my $self = YAML::Load($yaml);
40 5         282416 return bless $self, $package;
41             }
42              
43             sub import_from_YAML_file {
44 9     9 1 338462 my $package = shift;
45 9         26 my $yamlf = shift;
46 9         40 my $self = YAML::LoadFile($yamlf);
47 9         496049 return bless $self, $package;
48             }
49              
50             # assume that the last header count means counts
51             # after optionally removing counts, the last header is label
52             sub add_table {
53 3     3 1 17 my $self = shift;
54 3         8 my @atts = (); my $lbl=''; my $cnt = '';
  3         8  
  3         6  
55 3         12 while (@_) {
56 3         8 my $table = shift;
57 3 50       24 if ($table =~ /^(.*)\n[ \t]*-+\n/) {
58 3         11 my $a = $1; $table = $';
  3         13  
59 3         11 $a =~ s/^\s+//; $a =~ s/\s+$//;
  3         14  
60 3 50       24 if ($a =~ /\s*\bcount\s*$/) {
61 3         8 $a=$`; $cnt=1; } else { $cnt='' }
  3         76  
  0         0  
62 3         22 @atts = split(/\s+/, $a);
63 3         10 $lbl = pop @atts;
64             }
65 3         27 while ($table ne '') {
66 43 50       148 $table =~ /^(.*)\n?/ or die;
67 43         91 my $r=$1; $table = $';
  43         77  
68 43         95 $r =~ s/^\s+//; $r=~ s/\s+$//;
  43         122  
69 43 100       87 if ($r =~ /^-+$/) { next }
  2         13  
70 41         158 my @v = split(/\s+/, $r);
71 41 50       110 die "values (#=$#v): {@v}\natts (#=$#atts): @atts, lbl=$lbl,\n".
    50          
72             "count: $cnt\n" unless $#v-($cnt?2:1) == $#atts;
73 41         61 my %av=(); my @a = @atts;
  41         90  
74 41         81 while (@a) { $av{shift @a} = shift(@v) }
  144         345  
75 41 50       156 $self->add_instances(attributes=>\%av,
76             label=>"$lbl=$v[0]",
77             cases=>($cnt?$v[1]:1) );
78             }
79             }
80             } # end of add_table
81              
82             # Simplified; not generally compatible.
83             # Assume that the last header is label. The first row contains
84             # attribute names.
85             sub add_csv_file {
86 0     0 1 0 my $self = shift; my $fn = shift; local *F;
  0         0  
  0         0  
87 0 0       0 open(F,$fn) or die "Cannot open CSV file `$fn': $!";
88 0         0 local $_ = ; my @atts = (); my $lbl=''; my $cnt = '';
  0         0  
  0         0  
  0         0  
89 0         0 chomp; @atts = split(/\s*,\s*/, $_); $lbl = pop @atts;
  0         0  
  0         0  
90 0         0 while () {
91 0         0 chomp; my @v = split(/\s*,\s*/, $_);
  0         0  
92 0 0       0 die "values (#=$#v): {@v}\natts (#=$#atts): @atts, lbl=$lbl,\n".
    0          
93             "count: $cnt\n" unless $#v-($cnt?2:1) == $#atts;
94 0         0 my %av=(); my @a = @atts;
  0         0  
95 0         0 while (@a) { $av{shift @a} = shift(@v) }
  0         0  
96 0 0       0 $self->add_instances(attributes=>\%av,
97             label=>"$lbl=$v[0]",
98             cases=>($cnt?$v[1]:1) );
99             }
100 0         0 close(F);
101             } # end of add_csv_file
102              
103             sub drop_attributes {
104 0     0 1 0 my $self = shift;
105 0         0 foreach my $a (@_) {
106 0         0 my @tmp = grep { $a ne $_ } @{ $self->{attributes} };
  0         0  
  0         0  
107 0         0 $self->{attributes} = \@tmp;
108 0         0 delete($self->{attvals}{$a});
109 0         0 delete($self->{stat_attributes}{$a});
110 0         0 delete($self->{attribute_type}{$a});
111 0         0 delete($self->{real_stat}{$a});
112 0         0 delete($self->{smoothing}{$a});
113             }
114             } # end of drop_attributes
115              
116             sub add_instances {
117 147     147 1 522 my ($self, %params) = @_;
118 147         247 for ('attributes', 'label', 'cases') {
119 441 50       839 die "Missing required '$_' parameter" unless exists $params{$_};
120             }
121              
122 147 100       180 if (scalar(keys(%{ $self->{stat_attributes} })) == 0) {
  147         353  
123 11         21 foreach my $a (keys(%{$params{attributes}})) {
  11         40  
124 31         60 $self->{stat_attributes}{$a} = {};
125 31         47 push @{ $self->{attributes} }, $a;
  31         61  
126 31         67 $self->{attvals}{$a} = [ ];
127 31 100       115 $self->{attribute_type}{$a} = 'nominal' unless defined($self->{attribute_type}{$a});
128             }
129             } else {
130 136         181 foreach my $a (keys(%{$self->{stat_attributes}}))
  136         257  
131             { die "attribute not given in instance: $a"
132 421 50       754 unless exists($params{attributes}{$a}) }
133             }
134              
135 147         312 $self->{numof_instances} += $params{cases};
136              
137 22         48 push @{ $self->{labels} }, $params{label} unless
138 147 100       335 exists $self->{stat_labels}->{$params{label}};
139              
140 147         237 $self->{stat_labels}{$params{label}} += $params{cases};
141              
142 147         213 foreach my $a (keys(%{$self->{stat_attributes}})) {
  147         281  
143 452 50       779 if ( not exists($params{attributes}{$a}) )
144 0         0 { die "attribute $a not given" }
145 452         614 my $attval = $params{attributes}{$a};
146 452 100       793 if (not exists($self->{stat_attributes}{$a}{$attval})) {
147 110         143 push @{ $self->{attvals}{$a} }, $attval;
  110         234  
148 110         253 $self->{stat_attributes}{$a}{$attval} = {};
149             }
150 452         1103 $self->{stat_attributes}{$a}{$attval}{$params{label}} += $params{cases};
151             }
152             }
153              
154             sub add_instance {
155 68     68 1 504 my ($self, %params) = @_; $params{cases} = 1;
  68         100  
156 68         153 $self->add_instances(%params);
157             }
158              
159             sub train {
160 11     11 1 58 my $self = shift;
161 11         26 my $m = $self->{model} = {};
162            
163 11         26 $m->{labelprob} = {};
164 11         21 foreach my $label (keys(%{$self->{stat_labels}}))
  11         38  
165             { $m->{labelprob}{$label} = $self->{stat_labels}{$label} /
166 22         76 $self->{numof_instances} }
167              
168 11         28 $m->{condprob} = {};
169 11         34 $m->{condprobe} = {};
170 11         18 foreach my $att (keys(%{$self->{stat_attributes}})) {
  11         39  
171 31 100       89 next if $self->{attribute_type}{$att} eq 'real';
172 26         69 $m->{condprob}{$att} = {};
173 26         66 $m->{condprobe}{$att} = {};
174 26         39 foreach my $label (keys(%{$self->{stat_labels}})) {
  26         58  
175 52         78 my $total = 0; my @attvals = ();
  52         85  
176 52         90 foreach my $attval (keys(%{$self->{stat_attributes}{$att}})) {
  52         142  
177             next unless
178             exists($self->{stat_attributes}{$att}{$attval}{$label}) and
179 128 100 66     451 $self->{stat_attributes}{$att}{$attval}{$label} > 0;
180 121         211 push @attvals, $attval;
181             $m->{condprob}{$att}{$attval} = {} unless
182 121 100       244 exists( $m->{condprob}{$att}{$attval} );
183             $m->{condprob}{$att}{$attval}{$label} =
184 121         250 $self->{stat_attributes}{$att}{$attval}{$label};
185             $m->{condprobe}{$att}{$attval} = {} unless
186 121 50       233 exists( $m->{condprob}{$att}{$attval} );
187             $m->{condprobe}{$att}{$attval}{$label} =
188 121         248 $self->{stat_attributes}{$att}{$attval}{$label};
189 121         217 $total += $m->{condprob}{$att}{$attval}{$label};
190             }
191 52 100 66     184 if (exists($self->{smoothing}{$att}) and
192             $self->{smoothing}{$att} =~ /^unseen count=/) {
193 6 50       14 my $uc = $'; $uc = 0.5 if $uc <= 0;
  6         21  
194 6 100       12 if(! exists($m->{condprob}{$att}{'*'}) ) {
195 3         7 $m->{condprob}{$att}{'*'} = {};
196 3         6 $m->{condprobe}{$att}{'*'} = {};
197             }
198 6         21 $m->{condprob}{$att}{'*'}{$label} = $uc;
199 6         10 $total += $uc;
200 6 50       14 if (grep {$_ eq '*'} @attvals) { die }
  24         57  
  0         0  
201 6         13 push @attvals, '*';
202             }
203 52         94 foreach my $attval (@attvals) {
204 127         393 $m->{condprobe}{$att}{$attval}{$label} =
205             "(= $m->{condprob}{$att}{$attval}{$label} / $total)";
206 127         278 $m->{condprob}{$att}{$attval}{$label} /= $total;
207             }
208             }
209             }
210              
211             # For real-valued attributes, we use Gaussian distribution
212             # let us collect statistics
213 11         22 foreach my $att (keys(%{$self->{stat_attributes}})) {
  11         32  
214 31 100       87 next unless $self->{attribute_type}{$att} eq 'real';
215             print STDERR "Smoothing ignored for real attribute $att!\n" if
216 5 0 33     16 defined($self->{smoothing}{att}) and $self->{smoothing}{att};
217 5         12 $m->{real_stat}->{$att} = {};
218 5         10 foreach my $attval (keys %{$self->{stat_attributes}{$att}}){
  5         21  
219 46         55 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         119  
220             $m->{real_stat}{$att}{$label}{sum}
221 53         135 += $attval * $self->{stat_attributes}{$att}{$attval}{$label};
222              
223             $m->{real_stat}{$att}{$label}{count}
224 53         109 += $self->{stat_attributes}{$att}{$attval}{$label};
225             }
226 46         61 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         108  
227             next if
228             !defined($m->{real_stat}{$att}{$label}{count}) ||
229 53 50 33     178 $m->{real_stat}{$att}{$label}{count} == 0;
230              
231             $m->{real_stat}{$att}{$label}{mean} =
232             $m->{real_stat}{$att}{$label}{sum} /
233 53         126 $m->{real_stat}{$att}{$label}{count};
234             }
235             }
236              
237             # calculate stddev
238 5         12 foreach my $attval (keys %{$self->{stat_attributes}{$att}}) {
  5         29  
239 46         65 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         84  
240             $m->{real_stat}{$att}{$label}{stddev} +=
241             ($attval - $m->{real_stat}{$att}{$label}{mean})**2 *
242 53         175 $self->{stat_attributes}{$att}{$attval}{$label};
243             }
244             }
245 5         12 foreach my $label (keys %{$m->{real_stat}{$att}}) {
  5         26  
246             $m->{real_stat}{$att}{$label}{stddev} =
247             sqrt($m->{real_stat}{$att}{$label}{stddev} /
248 10         38 ($m->{real_stat}{$att}{$label}{count}-1)
249             );
250             }
251             } # foreach real attribute
252             } # end of sub train
253              
254             sub predict {
255 13     13 1 24136 my ($self, %params) = @_;
256 13 50       67 my $newattrs = $params{attributes} or die "Missing 'attributes' parameter for predict()";
257 13         96 my $m = $self->{model}; # For convenience
258            
259 13         25 my %scores;
260 13         25 my @labels = @{ $self->{labels} };
  13         52  
261 13         114 $scores{$_} = $m->{labelprob}{$_} foreach (@labels);
262 13         26 foreach my $att (keys(%{ $newattrs })) {
  13         56  
263 41 50       121 if (!defined($self->{attribute_type}{$att})) { die "Unknown attribute: `$att'" }
  0         0  
264 41 100       107 next if $self->{attribute_type}{$att} eq 'real';
265 36 50       90 die unless exists($self->{stat_attributes}{$att});
266 36         115 my $attval = $newattrs->{$att};
267             die "Unknown value `$attval' for attribute `$att'."
268             unless exists($self->{stat_attributes}{$att}{$attval}) or
269 36 0 33     97 exists($self->{smoothing}{$att});
270 36         64 foreach my $label (@labels) {
271 72 100 66     409 if (exists($m->{condprob}{$att}{$attval}) and
    100 66        
272             exists($m->{condprob}{$att}{$attval}{$label}) and
273             $m->{condprob}{$att}{$attval}{$label} > 0 ) {
274             $scores{$label} *=
275 68         164 $m->{condprob}{$att}{$attval}{$label};
276             } elsif (exists($self->{smoothing}{$att})) {
277             $scores{$label} *=
278 3         9 $m->{condprob}{$att}{'*'}{$label};
279 1         4 } else { $scores{$label} = 0 }
280              
281             }
282             }
283              
284 13         31 foreach my $att (keys %{$newattrs}){
  13         39  
285 41 100       102 next unless $self->{attribute_type}{$att} eq 'real';
286 5         11 my $sum=0; my %nscores;
  5         11  
287 5         13 foreach my $label (@labels) {
288 10 50       37 die unless exists $m->{real_stat}{$att}{$label}{mean};
289             $nscores{$label} =
290             0.398942280401433 / $m->{real_stat}{$att}{$label}{stddev}*
291             exp( -0.5 *
292             ( ( $newattrs->{$att} -
293             $m->{real_stat}{$att}{$label}{mean})
294             / $m->{real_stat}{$att}{$label}{stddev}
295 10         111 ) ** 2
296             );
297 10         24 $sum += $nscores{$label};
298             }
299 5 50       20 if ($sum==0) { print STDERR "Ignoring all Gaussian probabilities: all=0!\n" }
  0         0  
300             else {
301 5         12 foreach my $label (@labels) { $scores{$label} *= $nscores{$label} }
  10         27  
302             }
303             }
304              
305 13         27 my $sumPx = 0.0;
306 13         59 $sumPx += $scores{$_} foreach (keys(%scores));
307 13         52 $scores{$_} /= $sumPx foreach (keys(%scores));
308 13         60 return \%scores;
309             }
310              
311             sub print_model {
312 25     25 1 226 my $self = shift;
313 25         98 my $withcounts = '';
314 25 100 66     126 if ($#_>-1 && $_[0] eq 'with counts')
315 1         2 { shift @_; $withcounts = 1; }
  1         2  
316 25         71 my $m = $self->{model};
317 25         99 my @labels = $self->labels;
318 25         78 my $r;
319              
320             # prepare table category P(category)
321             my @lines;
322 25         64 push @lines, 'category ', '-';
323 25         140 push @lines, "$_ " foreach @labels;
324 25         93 @lines = _append_lines(@lines);
325 25         58 @lines = map { $_.='| ' } @lines;
  100         216  
326 25         85 $lines[1] = substr($lines[1],0,length($lines[1])-2).'+-';
327 25         66 $lines[0] .= "P(category) ";
328 25         109 foreach my $i (2..$#lines) {
329 50         117 my $label = $labels[$i-2];
330 50         235 $lines[$i] .= $m->{labelprob}{$label} .' ';
331 50 100       147 if ($withcounts) {
332 2         8 $lines[$i] .= "(= $self->{stat_labels}{$label} / ".
333             "$self->{numof_instances} ) ";
334             }
335             }
336 25         102 @lines = _append_lines(@lines);
337              
338 25         149 $r .= join("\n", @lines) . "\n". $lines[1]. "\n\n";
339              
340             # prepare conditional tables
341 25         102 my @attributes = sort $self->attributes;
342 25         78 foreach my $att (@attributes) {
343 71         185 @lines = ( "category ", '-' );
344 71         171 my @lines1 = ( "$att ", '-' );
345 71         237 my @lines2 = ( "P( $att | category ) ", '-' );
346 71         151 my @attvals = sort keys(%{ $m->{condprob}{$att} });
  71         339  
347 71         179 foreach my $label (@labels) {
348 142 100       341 if ( $self->{attribute_type}{$att} ne 'real' ) {
349 116         217 foreach my $attval (@attvals) {
350 274 100       644 next unless exists($m->{condprob}{$att}{$attval}{$label});
351 263         482 push @lines, "$label ";
352 263         514 push @lines1, "$attval ";
353              
354 263         511 my $line = $m->{condprob}{$att}{$attval}{$label};
355 263 100       455 if ($withcounts)
356 35         109 { $line.= ' '.$m->{condprobe}{$att}{$attval}{$label} }
357 263         554 $line .= ' ';
358 263         501 push @lines2, $line;
359             }
360             } else {
361 26         60 push @lines, "$label ";
362 26         42 push @lines1, "real ";
363             push @lines2, "Gaussian(mean=".
364             $m->{real_stat}{$att}{$label}{mean}.",stddev=".
365 26         153 $m->{real_stat}{$att}{$label}{stddev}.") ";
366             }
367 142         222 push @lines, '-'; push @lines1, '-'; push @lines2, '-';
  142         213  
  142         227  
368             }
369 71         160 @lines = _append_lines(@lines);
370 71         305 foreach my $i (0 .. $#lines)
371 573 100       1720 { $lines[$i] .= ($lines[$i]=~/-$/?'+-':'| ') . $lines1[$i] }
372 71         207 @lines = _append_lines(@lines);
373 71         179 foreach my $i (0 .. $#lines)
374 573 100       1530 { $lines[$i] .= ($lines[$i]=~/-$/?'+-':'| ') . $lines2[$i] }
375 71         186 @lines = _append_lines(@lines);
376              
377 71         503 $r .= join("\n", @lines). "\n\n";
378             }
379              
380 25         209 return $r;
381             }
382              
383             sub _append_lines {
384 263     263   662 my @l = @_;
385 263         471 my $m = 0;
386 263 100       432 foreach (@l) { $m = length($_) if length($_) > $m }
  1919         3420  
387             @l = map
388 263         449 { while (length($_) < $m) { $_.=substr($_,length($_)-1) }; $_ }
  1919         3190  
  13181         24299  
  1919         3441  
389             @l;
390 263         869 return @l;
391             }
392              
393             sub labels {
394 25     25 1 46 my $self = shift;
395 25         43 return @{ $self->{labels} };
  25         119  
396             }
397              
398             sub attributes {
399 25     25 1 55 my $self = shift;
400 25         57 return keys %{ $self->{stat_attributes} };
  25         261  
401             }
402              
403             sub export_to_YAML {
404 5     5 1 3787 my $self = shift;
405 5         59 require YAML;
406 5         26 return YAML::Dump($self);
407             }
408              
409             sub export_to_YAML_file {
410 9     9 1 58730 my $self = shift;
411 9         21 my $file = shift;
412 9         66 require YAML;
413 9         30 YAML::DumpFile($file, $self);
414             }
415              
416             1;
417             __END__