File Coverage

blib/lib/AI/NaiveBayes1.pm
Criterion Covered Total %
statement 265 302 87.7
branch 64 94 68.0
condition 13 24 54.1
subroutine 20 22 90.9
pod 16 16 100.0
total 378 458 82.5


line stmt bran cond sub pod time code
1             # (c) 2003-21 Vlado Keselj https://web.cs.dal.ca/~vlado
2              
3             package AI::NaiveBayes1;
4 8     8   7025 use strict;
  8         19  
  8         327  
5             require Exporter;
6 8     8   50 use vars qw($VERSION @ISA @EXPORT @EXPORT_OK %EXPORT_TAGS);
  8         16  
  8         746  
7             @EXPORT = qw(new);
8 8     8   51 use vars qw($Version);
  8         15  
  8         449  
9             $Version = $VERSION = '2.012';
10              
11 8     8   61 use vars @EXPORT_OK;
  8         16  
  8         239  
12              
13             # non-exported package globals go here
14 8     8   44 use vars qw();
  8         13  
  8         30819  
15              
16             sub new {
17 11     11 1 13805 my $package = shift;
18 11         182 return bless {
19             attributes => [ ],
20             labels => [ ],
21             attvals => {},
22             real_stat => {},
23             numof_instances => 0,
24             stat_labels => {},
25             stat_attributes => {},
26             smoothing => {},
27             attribute_type => {},
28             }, $package;
29             }
30              
31             sub set_real {
32 4     4 1 28 my ($self, @attr) = @_;
33 4         12 foreach my $a (@attr) { $self->{attribute_type}{$a} = 'real' }
  5         23  
34             }
35              
36             sub import_from_YAML {
37 5     5 1 170184 my $package = shift;
38 5         16 my $yaml = shift;
39 5         29 my $self = YAML::Load($yaml);
40 5         278619 return bless $self, $package;
41             }
42              
43             sub import_from_YAML_file {
44 9     9 1 341656 my $package = shift;
45 9         33 my $yamlf = shift;
46 9         44 my $self = YAML::LoadFile($yamlf);
47 9         497223 return bless $self, $package;
48             }
49              
50             # assume that the last header count means counts
51             # after optionally removing counts, the last header is label
52             sub add_table {
53 3     3 1 22 my $self = shift;
54 3         7 my @atts = (); my $lbl=''; my $cnt = '';
  3         7  
  3         7  
55 3         13 while (@_) {
56 3         8 my $table = shift;
57 3 50       26 if ($table =~ /^(.*)\n[ \t]*-+\n/) {
58 3         11 my $a = $1; $table = $';
  3         11  
59 3         11 $a =~ s/^\s+//; $a =~ s/\s+$//;
  3         16  
60 3 50       22 if ($a =~ /\s*\bcount\s*$/) {
61 3         8 $a=$`; $cnt=1; } else { $cnt='' }
  3         7  
  0         0  
62 3         18 @atts = split(/\s+/, $a);
63 3         8 $lbl = pop @atts;
64             }
65 3         22 while ($table ne '') {
66 43 50       145 $table =~ /^(.*)\n?/ or die;
67 43         86 my $r=$1; $table = $';
  43         78  
68 43         93 $r =~ s/^\s+//; $r=~ s/\s+$//;
  43         133  
69 43 100       91 if ($r =~ /^-+$/) { next }
  2         13  
70 41         158 my @v = split(/\s+/, $r);
71 41 50       107 die "values (#=$#v): {@v}\natts (#=$#atts): @atts, lbl=$lbl,\n".
    50          
72             "count: $cnt\n" unless $#v-($cnt?2:1) == $#atts;
73 41         70 my %av=(); my @a = @atts;
  41         90  
74 41         77 while (@a) { $av{shift @a} = shift(@v) }
  144         304  
75 41 50       170 $self->add_instances(attributes=>\%av,
76             label=>"$lbl=$v[0]",
77             cases=>($cnt?$v[1]:1) );
78             }
79             }
80             } # end of add_table
81              
82             # Simplified; not generally compatible.
83             # Assume that the last header is label. The first row contains
84             # attribute names.
85             sub add_csv_file {
86 0     0 1 0 my $self = shift; my $fn = shift; local *F;
  0         0  
  0         0  
87 0 0       0 open(F,$fn) or die "Cannot open CSV file `$fn': $!";
88 0         0 local $_ = ; my @atts = (); my $lbl=''; my $cnt = '';
  0         0  
  0         0  
  0         0  
89 0         0 chomp; @atts = split(/\s*,\s*/, $_); $lbl = pop @atts;
  0         0  
  0         0  
90 0         0 while () {
91 0         0 chomp; my @v = split(/\s*,\s*/, $_);
  0         0  
92 0 0       0 die "values (#=$#v): {@v}\natts (#=$#atts): @atts, lbl=$lbl,\n".
    0          
93             "count: $cnt\n" unless $#v-($cnt?2:1) == $#atts;
94 0         0 my %av=(); my @a = @atts;
  0         0  
95 0         0 while (@a) { $av{shift @a} = shift(@v) }
  0         0  
96 0 0       0 $self->add_instances(attributes=>\%av,
97             label=>"$lbl=$v[0]",
98             cases=>($cnt?$v[1]:1) );
99             }
100 0         0 close(F);
101             } # end of add_csv_file
102              
103             sub drop_attributes {
104 0     0 1 0 my $self = shift;
105 0         0 foreach my $a (@_) {
106 0         0 my @tmp = grep { $a ne $_ } @{ $self->{attributes} };
  0         0  
  0         0  
107 0         0 $self->{attributes} = \@tmp;
108 0         0 delete($self->{attvals}{$a});
109 0         0 delete($self->{stat_attributes}{$a});
110 0         0 delete($self->{attribute_type}{$a});
111 0         0 delete($self->{real_stat}{$a});
112 0         0 delete($self->{smoothing}{$a});
113             }
114             } # end of drop_attributes
115              
116             sub add_instances {
117 147     147 1 583 my ($self, %params) = @_;
118 147         245 for ('attributes', 'label', 'cases') {
119 441 50       937 die "Missing required '$_' parameter" unless exists $params{$_};
120             }
121              
122 147 100       183 if (scalar(keys(%{ $self->{stat_attributes} })) == 0) {
  147         418  
123 11         23 foreach my $a (keys(%{$params{attributes}})) {
  11         41  
124 31         67 $self->{stat_attributes}{$a} = {};
125 31         45 push @{ $self->{attributes} }, $a;
  31         64  
126 31         70 $self->{attvals}{$a} = [ ];
127 31 100       100 $self->{attribute_type}{$a} = 'nominal' unless defined($self->{attribute_type}{$a});
128             }
129             } else {
130 136         178 foreach my $a (keys(%{$self->{stat_attributes}}))
  136         303  
131             { die "attribute not given in instance: $a"
132 421 50       749 unless exists($params{attributes}{$a}) }
133             }
134              
135 147         267 $self->{numof_instances} += $params{cases};
136              
137 22         50 push @{ $self->{labels} }, $params{label} unless
138 147 100       342 exists $self->{stat_labels}->{$params{label}};
139              
140 147         239 $self->{stat_labels}{$params{label}} += $params{cases};
141              
142 147         204 foreach my $a (keys(%{$self->{stat_attributes}})) {
  147         282  
143 452 50       802 if ( not exists($params{attributes}{$a}) )
144 0         0 { die "attribute $a not given" }
145 452         622 my $attval = $params{attributes}{$a};
146 452 100       793 if (not exists($self->{stat_attributes}{$a}{$attval})) {
147 110         143 push @{ $self->{attvals}{$a} }, $attval;
  110         225  
148 110         235 $self->{stat_attributes}{$a}{$attval} = {};
149             }
150 452         1101 $self->{stat_attributes}{$a}{$attval}{$params{label}} += $params{cases};
151             }
152             }
153              
154             sub add_instance {
155 68     68 1 489 my ($self, %params) = @_; $params{cases} = 1;
  68         103  
156 68         157 $self->add_instances(%params);
157             }
158              
159             sub train {
160 11     11 1 62 my $self = shift;
161 11         29 my $m = $self->{model} = {};
162            
163 11         25 $m->{labelprob} = {};
164 11         20 foreach my $label (keys(%{$self->{stat_labels}}))
  11         38  
165             { $m->{labelprob}{$label} = $self->{stat_labels}{$label} /
166 22         73 $self->{numof_instances} }
167              
168 11         33 $m->{condprob} = {};
169 11         35 $m->{condprobe} = {};
170 11         33 foreach my $att (keys(%{$self->{stat_attributes}})) {
  11         42  
171 31 100       93 next if $self->{attribute_type}{$att} eq 'real';
172 26         57 $m->{condprob}{$att} = {};
173 26         62 $m->{condprobe}{$att} = {};
174 26         43 foreach my $label (keys(%{$self->{stat_labels}})) {
  26         66  
175 52         70 my $total = 0; my @attvals = ();
  52         84  
176 52         81 foreach my $attval (keys(%{$self->{stat_attributes}{$att}})) {
  52         139  
177             next unless
178             exists($self->{stat_attributes}{$att}{$attval}{$label}) and
179 128 100 66     446 $self->{stat_attributes}{$att}{$attval}{$label} > 0;
180 121         216 push @attvals, $attval;
181             $m->{condprob}{$att}{$attval} = {} unless
182 121 100       268 exists( $m->{condprob}{$att}{$attval} );
183             $m->{condprob}{$att}{$attval}{$label} =
184 121         240 $self->{stat_attributes}{$att}{$attval}{$label};
185             $m->{condprobe}{$att}{$attval} = {} unless
186 121 50       237 exists( $m->{condprob}{$att}{$attval} );
187             $m->{condprobe}{$att}{$attval}{$label} =
188 121         279 $self->{stat_attributes}{$att}{$attval}{$label};
189 121         220 $total += $m->{condprob}{$att}{$attval}{$label};
190             }
191 52 100 66     170 if (exists($self->{smoothing}{$att}) and
192             $self->{smoothing}{$att} =~ /^unseen count=/) {
193 6 50       21 my $uc = $'; $uc = 0.5 if $uc <= 0;
  6         18  
194 6 100       14 if(! exists($m->{condprob}{$att}{'*'}) ) {
195 3         5 $m->{condprob}{$att}{'*'} = {};
196 3         5 $m->{condprobe}{$att}{'*'} = {};
197             }
198 6         12 $m->{condprob}{$att}{'*'}{$label} = $uc;
199 6         9 $total += $uc;
200 6 50       11 if (grep {$_ eq '*'} @attvals) { die }
  24         56  
  0         0  
201 6         9 push @attvals, '*';
202             }
203 52         89 foreach my $attval (@attvals) {
204 127         380 $m->{condprobe}{$att}{$attval}{$label} =
205             "(= $m->{condprob}{$att}{$attval}{$label} / $total)";
206 127         280 $m->{condprob}{$att}{$attval}{$label} /= $total;
207             }
208             }
209             }
210              
211             # For real-valued attributes, we use Gaussian distribution
212             # let us collect statistics
213 11         21 foreach my $att (keys(%{$self->{stat_attributes}})) {
  11         33  
214 31 100       85 next unless $self->{attribute_type}{$att} eq 'real';
215             print STDERR "Smoothing ignored for real attribute $att!\n" if
216 5 0 33     17 defined($self->{smoothing}{att}) and $self->{smoothing}{att};
217 5         15 $m->{real_stat}->{$att} = {};
218 5         8 foreach my $attval (keys %{$self->{stat_attributes}{$att}}){
  5         22  
219 46         61 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         122  
220             $m->{real_stat}{$att}{$label}{sum}
221 53         128 += $attval * $self->{stat_attributes}{$att}{$attval}{$label};
222              
223             $m->{real_stat}{$att}{$label}{count}
224 53         106 += $self->{stat_attributes}{$att}{$attval}{$label};
225             }
226 46         72 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         89  
227             next if
228             !defined($m->{real_stat}{$att}{$label}{count}) ||
229 53 50 33     199 $m->{real_stat}{$att}{$label}{count} == 0;
230              
231             $m->{real_stat}{$att}{$label}{mean} =
232             $m->{real_stat}{$att}{$label}{sum} /
233 53         130 $m->{real_stat}{$att}{$label}{count};
234             }
235             }
236              
237             # calculate stddev
238 5         13 foreach my $attval (keys %{$self->{stat_attributes}{$att}}) {
  5         22  
239 46         60 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         88  
240             $m->{real_stat}{$att}{$label}{stddev} +=
241             ($attval - $m->{real_stat}{$att}{$label}{mean})**2 *
242 53         181 $self->{stat_attributes}{$att}{$attval}{$label};
243             }
244             }
245 5         11 foreach my $label (keys %{$m->{real_stat}{$att}}) {
  5         25  
246             $m->{real_stat}{$att}{$label}{stddev} =
247             sqrt($m->{real_stat}{$att}{$label}{stddev} /
248 10         53 ($m->{real_stat}{$att}{$label}{count}-1)
249             );
250             }
251             } # foreach real attribute
252             } # end of sub train
253              
254             sub predict {
255 13     13 1 23245 my ($self, %params) = @_;
256 13 50       79 my $newattrs = $params{attributes} or die "Missing 'attributes' parameter for predict()";
257 13         93 my $m = $self->{model}; # For convenience
258            
259 13         27 my %scores;
260 13         27 my @labels = @{ $self->{labels} };
  13         51  
261 13         81 $scores{$_} = $m->{labelprob}{$_} foreach (@labels);
262 13         30 foreach my $att (keys(%{ $newattrs })) {
  13         54  
263 41 50       125 if (!defined($self->{attribute_type}{$att})) { die "Unknown attribute: `$att'" }
  0         0  
264 41 100       123 next if $self->{attribute_type}{$att} eq 'real';
265 36 50       90 die unless exists($self->{stat_attributes}{$att});
266 36         114 my $attval = $newattrs->{$att};
267             die "Unknown value `$attval' for attribute `$att'."
268             unless exists($self->{stat_attributes}{$att}{$attval}) or
269 36 0 33     103 exists($self->{smoothing}{$att});
270 36         66 foreach my $label (@labels) {
271 72 100 66     409 if (exists($m->{condprob}{$att}{$attval}) and
    100 66        
272             exists($m->{condprob}{$att}{$attval}{$label}) and
273             $m->{condprob}{$att}{$attval}{$label} > 0 ) {
274             $scores{$label} *=
275 68         155 $m->{condprob}{$att}{$attval}{$label};
276             } elsif (exists($self->{smoothing}{$att})) {
277             $scores{$label} *=
278 3         9 $m->{condprob}{$att}{'*'}{$label};
279 1         3 } else { $scores{$label} = 0 }
280              
281             }
282             }
283              
284 13         33 foreach my $att (keys %{$newattrs}){
  13         35  
285 41 100       103 next unless $self->{attribute_type}{$att} eq 'real';
286 5         12 my $sum=0; my %nscores;
  5         11  
287 5         15 foreach my $label (@labels) {
288 10 50       35 die unless exists $m->{real_stat}{$att}{$label}{mean};
289             $nscores{$label} =
290             0.398942280401433 / $m->{real_stat}{$att}{$label}{stddev}*
291             exp( -0.5 *
292             ( ( $newattrs->{$att} -
293             $m->{real_stat}{$att}{$label}{mean})
294             / $m->{real_stat}{$att}{$label}{stddev}
295 10         122 ) ** 2
296             );
297 10         26 $sum += $nscores{$label};
298             }
299 5 50       21 if ($sum==0) { print STDERR "Ignoring all Gaussian probabilities: all=0!\n" }
  0         0  
300             else {
301 5         13 foreach my $label (@labels) { $scores{$label} *= $nscores{$label} }
  10         24  
302             }
303             }
304              
305 13         32 my $sumPx = 0.0;
306 13         61 $sumPx += $scores{$_} foreach (keys(%scores));
307 13         50 $scores{$_} /= $sumPx foreach (keys(%scores));
308 13         61 return \%scores;
309             }
310              
311             sub print_model {
312 25     25 1 221 my $self = shift;
313 25         62 my $withcounts = '';
314 25 100 66     147 if ($#_>-1 && $_[0] eq 'with counts')
315 1         1 { shift @_; $withcounts = 1; }
  1         3  
316 25         71 my $m = $self->{model};
317 25         85 my @labels = $self->labels;
318 25         76 my $r;
319              
320             # prepare table category P(category)
321             my @lines;
322 25         67 push @lines, 'category ', '-';
323 25         151 push @lines, "$_ " foreach @labels;
324 25         92 @lines = _append_lines(@lines);
325 25         58 @lines = map { $_.='| ' } @lines;
  100         225  
326 25         101 $lines[1] = substr($lines[1],0,length($lines[1])-2).'+-';
327 25         60 $lines[0] .= "P(category) ";
328 25         98 foreach my $i (2..$#lines) {
329 50         99 my $label = $labels[$i-2];
330 50         254 $lines[$i] .= $m->{labelprob}{$label} .' ';
331 50 100       138 if ($withcounts) {
332 2         7 $lines[$i] .= "(= $self->{stat_labels}{$label} / ".
333             "$self->{numof_instances} ) ";
334             }
335             }
336 25         110 @lines = _append_lines(@lines);
337              
338 25         155 $r .= join("\n", @lines) . "\n". $lines[1]. "\n\n";
339              
340             # prepare conditional tables
341 25         98 my @attributes = sort $self->attributes;
342 25         85 foreach my $att (@attributes) {
343 71         185 @lines = ( "category ", '-' );
344 71         185 my @lines1 = ( "$att ", '-' );
345 71         188 my @lines2 = ( "P( $att | category ) ", '-' );
346 71         149 my @attvals = sort keys(%{ $m->{condprob}{$att} });
  71         341  
347 71         180 foreach my $label (@labels) {
348 142 100       347 if ( $self->{attribute_type}{$att} ne 'real' ) {
349 116         253 foreach my $attval (@attvals) {
350 274 100       683 next unless exists($m->{condprob}{$att}{$attval}{$label});
351 263         483 push @lines, "$label ";
352 263         471 push @lines1, "$attval ";
353              
354 263         466 my $line = $m->{condprob}{$att}{$attval}{$label};
355 263 100       469 if ($withcounts)
356 35         114 { $line.= ' '.$m->{condprobe}{$att}{$attval}{$label} }
357 263         546 $line .= ' ';
358 263         498 push @lines2, $line;
359             }
360             } else {
361 26         74 push @lines, "$label ";
362 26         43 push @lines1, "real ";
363             push @lines2, "Gaussian(mean=".
364             $m->{real_stat}{$att}{$label}{mean}.",stddev=".
365 26         162 $m->{real_stat}{$att}{$label}{stddev}.") ";
366             }
367 142         224 push @lines, '-'; push @lines1, '-'; push @lines2, '-';
  142         198  
  142         260  
368             }
369 71         200 @lines = _append_lines(@lines);
370 71         192 foreach my $i (0 .. $#lines)
371 573 100       1694 { $lines[$i] .= ($lines[$i]=~/-$/?'+-':'| ') . $lines1[$i] }
372 71         196 @lines = _append_lines(@lines);
373 71         231 foreach my $i (0 .. $#lines)
374 573 100       1537 { $lines[$i] .= ($lines[$i]=~/-$/?'+-':'| ') . $lines2[$i] }
375 71         197 @lines = _append_lines(@lines);
376              
377 71         518 $r .= join("\n", @lines). "\n\n";
378             }
379              
380 25         215 return $r;
381             }
382              
383             sub _append_lines {
384 263     263   701 my @l = @_;
385 263         371 my $m = 0;
386 263 100       417 foreach (@l) { $m = length($_) if length($_) > $m }
  1919         3376  
387             @l = map
388 263         432 { while (length($_) < $m) { $_.=substr($_,length($_)-1) }; $_ }
  1919         3321  
  13181         24858  
  1919         3458  
389             @l;
390 263         884 return @l;
391             }
392              
393             sub labels {
394 25     25 1 54 my $self = shift;
395 25         42 return @{ $self->{labels} };
  25         103  
396             }
397              
398             sub attributes {
399 25     25 1 49 my $self = shift;
400 25         58 return keys %{ $self->{stat_attributes} };
  25         245  
401             }
402              
403             sub export_to_YAML {
404 5     5 1 3628 my $self = shift;
405 5         54 require YAML;
406 5         28 return YAML::Dump($self);
407             }
408              
409             sub export_to_YAML_file {
410 9     9 1 62352 my $self = shift;
411 9         21 my $file = shift;
412 9         78 require YAML;
413 9         39 YAML::DumpFile($file, $self);
414             }
415              
416             1;
417             __END__