File Coverage

blib/lib/AI/NaiveBayes1.pm
Criterion Covered Total %
statement 265 302 87.7
branch 64 94 68.0
condition 13 24 54.1
subroutine 20 22 90.9
pod 16 16 100.0
total 378 458 82.5


line stmt bran cond sub pod time code
1             # (c) 2003-21 Vlado Keselj http://web.cs.dal.ca/~vlado
2              
3             package AI::NaiveBayes1;
4 8     8   6320 use strict;
  8         15  
  8         307  
5             require Exporter;
6 8     8   35 use vars qw($VERSION @ISA @EXPORT @EXPORT_OK %EXPORT_TAGS);
  8         14  
  8         692  
7             @EXPORT = qw(new);
8 8     8   41 use vars qw($Version);
  8         14  
  8         414  
9             $Version = $VERSION = '2.010';
10              
11 8     8   41 use vars @EXPORT_OK;
  8         15  
  8         205  
12              
13             # non-exported package globals go here
14 8     8   36 use vars qw();
  8         12  
  8         25674  
15              
16             sub new {
17 11     11 1 11924 my $package = shift;
18 11         180 return bless {
19             attributes => [ ],
20             labels => [ ],
21             attvals => {},
22             real_stat => {},
23             numof_instances => 0,
24             stat_labels => {},
25             stat_attributes => {},
26             smoothing => {},
27             attribute_type => {},
28             }, $package;
29             }
30              
31             sub set_real {
32 4     4 1 21 my ($self, @attr) = @_;
33 4         9 foreach my $a (@attr) { $self->{attribute_type}{$a} = 'real' }
  5         17  
34             }
35              
36             sub import_from_YAML {
37 5     5 1 140529 my $package = shift;
38 5         16 my $yaml = shift;
39 5         33 my $self = YAML::Load($yaml);
40 5         233783 return bless $self, $package;
41             }
42              
43             sub import_from_YAML_file {
44 9     9 1 288936 my $package = shift;
45 9         25 my $yamlf = shift;
46 9         38 my $self = YAML::LoadFile($yamlf);
47 9         418763 return bless $self, $package;
48             }
49              
50             # assume that the last header count means counts
51             # after optionally removing counts, the last header is label
52             sub add_table {
53 3     3 1 17 my $self = shift;
54 3         7 my @atts = (); my $lbl=''; my $cnt = '';
  3         6  
  3         5  
55 3         10 while (@_) {
56 3         6 my $table = shift;
57 3 50       31 if ($table =~ /^(.*)\n[ \t]*-+\n/) {
58 3         10 my $a = $1; $table = $';
  3         9  
59 3         12 $a =~ s/^\s+//; $a =~ s/\s+$//;
  3         13  
60 3 50       23 if ($a =~ /\s*\bcount\s*$/) {
61 3         7 $a=$`; $cnt=1; } else { $cnt='' }
  3         5  
  0         0  
62 3         16 @atts = split(/\s+/, $a);
63 3         8 $lbl = pop @atts;
64             }
65 3         18 while ($table ne '') {
66 43 50       118 $table =~ /^(.*)\n?/ or die;
67 43         75 my $r=$1; $table = $';
  43         63  
68 43         78 $r =~ s/^\s+//; $r=~ s/\s+$//;
  43         103  
69 43 100       69 if ($r =~ /^-+$/) { next }
  2         11  
70 41         142 my @v = split(/\s+/, $r);
71 41 50       93 die "values (#=$#v): {@v}\natts (#=$#atts): @atts, lbl=$lbl,\n".
    50          
72             "count: $cnt\n" unless $#v-($cnt?2:1) == $#atts;
73 41         60 my %av=(); my @a = @atts;
  41         69  
74 41         71 while (@a) { $av{shift @a} = shift(@v) }
  144         295  
75 41 50       130 $self->add_instances(attributes=>\%av,
76             label=>"$lbl=$v[0]",
77             cases=>($cnt?$v[1]:1) );
78             }
79             }
80             } # end of add_table
81              
82             # Simplified; not generally compatible.
83             # Assume that the last header is label. The first row contains
84             # attribute names.
85             sub add_csv_file {
86 0     0 1 0 my $self = shift; my $fn = shift; local *F;
  0         0  
  0         0  
87 0 0       0 open(F,$fn) or die "Cannot open CSV file `$fn': $!";
88 0         0 local $_ = ; my @atts = (); my $lbl=''; my $cnt = '';
  0         0  
  0         0  
  0         0  
89 0         0 chomp; @atts = split(/\s*,\s*/, $_); $lbl = pop @atts;
  0         0  
  0         0  
90 0         0 while () {
91 0         0 chomp; my @v = split(/\s*,\s*/, $_);
  0         0  
92 0 0       0 die "values (#=$#v): {@v}\natts (#=$#atts): @atts, lbl=$lbl,\n".
    0          
93             "count: $cnt\n" unless $#v-($cnt?2:1) == $#atts;
94 0         0 my %av=(); my @a = @atts;
  0         0  
95 0         0 while (@a) { $av{shift @a} = shift(@v) }
  0         0  
96 0 0       0 $self->add_instances(attributes=>\%av,
97             label=>"$lbl=$v[0]",
98             cases=>($cnt?$v[1]:1) );
99             }
100 0         0 close(F);
101             } # end of add_csv_file
102              
103             sub drop_attributes {
104 0     0 1 0 my $self = shift;
105 0         0 foreach my $a (@_) {
106 0         0 my @tmp = grep { $a ne $_ } @{ $self->{attributes} };
  0         0  
  0         0  
107 0         0 $self->{attributes} = \@tmp;
108 0         0 delete($self->{attvals}{$a});
109 0         0 delete($self->{stat_attributes}{$a});
110 0         0 delete($self->{attribute_type}{$a});
111 0         0 delete($self->{real_stat}{$a});
112 0         0 delete($self->{smoothing}{$a});
113             }
114             } # end of drop_attributes
115              
116             sub add_instances {
117 147     147 1 453 my ($self, %params) = @_;
118 147         209 for ('attributes', 'label', 'cases') {
119 441 50       730 die "Missing required '$_' parameter" unless exists $params{$_};
120             }
121              
122 147 100       168 if (scalar(keys(%{ $self->{stat_attributes} })) == 0) {
  147         312  
123 11         22 foreach my $a (keys(%{$params{attributes}})) {
  11         40  
124 31         55 $self->{stat_attributes}{$a} = {};
125 31         39 push @{ $self->{attributes} }, $a;
  31         50  
126 31         62 $self->{attvals}{$a} = [ ];
127 31 100       82 $self->{attribute_type}{$a} = 'nominal' unless defined($self->{attribute_type}{$a});
128             }
129             } else {
130 136         165 foreach my $a (keys(%{$self->{stat_attributes}}))
  136         218  
131             { die "attribute not given in instance: $a"
132 421 50       617 unless exists($params{attributes}{$a}) }
133             }
134              
135 147         224 $self->{numof_instances} += $params{cases};
136              
137 22         43 push @{ $self->{labels} }, $params{label} unless
138 147 100       286 exists $self->{stat_labels}->{$params{label}};
139              
140 147         208 $self->{stat_labels}{$params{label}} += $params{cases};
141              
142 147         179 foreach my $a (keys(%{$self->{stat_attributes}})) {
  147         287  
143 452 50       641 if ( not exists($params{attributes}{$a}) )
144 0         0 { die "attribute $a not given" }
145 452         538 my $attval = $params{attributes}{$a};
146 452 100       669 if (not exists($self->{stat_attributes}{$a}{$attval})) {
147 110         108 push @{ $self->{attvals}{$a} }, $attval;
  110         187  
148 110         215 $self->{stat_attributes}{$a}{$attval} = {};
149             }
150 452         863 $self->{stat_attributes}{$a}{$attval}{$params{label}} += $params{cases};
151             }
152             }
153              
154             sub add_instance {
155 68     68 1 444 my ($self, %params) = @_; $params{cases} = 1;
  68         81  
156 68         129 $self->add_instances(%params);
157             }
158              
159             sub train {
160 11     11 1 60 my $self = shift;
161 11         25 my $m = $self->{model} = {};
162            
163 11         24 $m->{labelprob} = {};
164 11         19 foreach my $label (keys(%{$self->{stat_labels}}))
  11         34  
165             { $m->{labelprob}{$label} = $self->{stat_labels}{$label} /
166 22         67 $self->{numof_instances} }
167              
168 11         24 $m->{condprob} = {};
169 11         38 $m->{condprobe} = {};
170 11         27 foreach my $att (keys(%{$self->{stat_attributes}})) {
  11         40  
171 31 100       82 next if $self->{attribute_type}{$att} eq 'real';
172 26         50 $m->{condprob}{$att} = {};
173 26         55 $m->{condprobe}{$att} = {};
174 26         37 foreach my $label (keys(%{$self->{stat_labels}})) {
  26         53  
175 52         71 my $total = 0; my @attvals = ();
  52         86  
176 52         65 foreach my $attval (keys(%{$self->{stat_attributes}{$att}})) {
  52         116  
177             next unless
178             exists($self->{stat_attributes}{$att}{$attval}{$label}) and
179 128 100 66     386 $self->{stat_attributes}{$att}{$attval}{$label} > 0;
180 121         158 push @attvals, $attval;
181             $m->{condprob}{$att}{$attval} = {} unless
182 121 100       212 exists( $m->{condprob}{$att}{$attval} );
183             $m->{condprob}{$att}{$attval}{$label} =
184 121         204 $self->{stat_attributes}{$att}{$attval}{$label};
185             $m->{condprobe}{$att}{$attval} = {} unless
186 121 50       187 exists( $m->{condprob}{$att}{$attval} );
187             $m->{condprobe}{$att}{$attval}{$label} =
188 121         210 $self->{stat_attributes}{$att}{$attval}{$label};
189 121         179 $total += $m->{condprob}{$att}{$attval}{$label};
190             }
191 52 100 66     154 if (exists($self->{smoothing}{$att}) and
192             $self->{smoothing}{$att} =~ /^unseen count=/) {
193 6 50       11 my $uc = $'; $uc = 0.5 if $uc <= 0;
  6         16  
194 6 100       10 if(! exists($m->{condprob}{$att}{'*'}) ) {
195 3         12 $m->{condprob}{$att}{'*'} = {};
196 3         4 $m->{condprobe}{$att}{'*'} = {};
197             }
198 6         11 $m->{condprob}{$att}{'*'}{$label} = $uc;
199 6         7 $total += $uc;
200 6 50       8 if (grep {$_ eq '*'} @attvals) { die }
  24         41  
  0         0  
201 6         9 push @attvals, '*';
202             }
203 52         85 foreach my $attval (@attvals) {
204 127         328 $m->{condprobe}{$att}{$attval}{$label} =
205             "(= $m->{condprob}{$att}{$attval}{$label} / $total)";
206 127         230 $m->{condprob}{$att}{$attval}{$label} /= $total;
207             }
208             }
209             }
210              
211             # For real-valued attributes, we use Gaussian distribution
212             # let us collect statistics
213 11         20 foreach my $att (keys(%{$self->{stat_attributes}})) {
  11         29  
214 31 100       79 next unless $self->{attribute_type}{$att} eq 'real';
215             print STDERR "Smoothing ignored for real attribute $att!\n" if
216 5 0 33     14 defined($self->{smoothing}{att}) and $self->{smoothing}{att};
217 5         11 $m->{real_stat}->{$att} = {};
218 5         8 foreach my $attval (keys %{$self->{stat_attributes}{$att}}){
  5         18  
219 46         52 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         94  
220             $m->{real_stat}{$att}{$label}{sum}
221 53         111 += $attval * $self->{stat_attributes}{$att}{$attval}{$label};
222              
223             $m->{real_stat}{$att}{$label}{count}
224 53         76 += $self->{stat_attributes}{$att}{$attval}{$label};
225             }
226 46         59 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         85  
227             next if
228             !defined($m->{real_stat}{$att}{$label}{count}) ||
229 53 50 33     160 $m->{real_stat}{$att}{$label}{count} == 0;
230              
231             $m->{real_stat}{$att}{$label}{mean} =
232             $m->{real_stat}{$att}{$label}{sum} /
233 53         117 $m->{real_stat}{$att}{$label}{count};
234             }
235             }
236              
237             # calculate stddev
238 5         10 foreach my $attval (keys %{$self->{stat_attributes}{$att}}) {
  5         16  
239 46         49 foreach my $label (keys %{$self->{stat_attributes}{$att}{$attval}}){
  46         73  
240             $m->{real_stat}{$att}{$label}{stddev} +=
241             ($attval - $m->{real_stat}{$att}{$label}{mean})**2 *
242 53         156 $self->{stat_attributes}{$att}{$attval}{$label};
243             }
244             }
245 5         8 foreach my $label (keys %{$m->{real_stat}{$att}}) {
  5         19  
246             $m->{real_stat}{$att}{$label}{stddev} =
247             sqrt($m->{real_stat}{$att}{$label}{stddev} /
248 10         40 ($m->{real_stat}{$att}{$label}{count}-1)
249             );
250             }
251             } # foreach real attribute
252             } # end of sub train
253              
254             sub predict {
255 13     13 1 19845 my ($self, %params) = @_;
256 13 50       58 my $newattrs = $params{attributes} or die "Missing 'attributes' parameter for predict()";
257 13         88 my $m = $self->{model}; # For convenience
258            
259 13         22 my %scores;
260 13         25 my @labels = @{ $self->{labels} };
  13         60  
261 13         73 $scores{$_} = $m->{labelprob}{$_} foreach (@labels);
262 13         25 foreach my $att (keys(%{ $newattrs })) {
  13         47  
263 41 50       107 if (!defined($self->{attribute_type}{$att})) { die "Unknown attribute: `$att'" }
  0         0  
264 41 100       87 next if $self->{attribute_type}{$att} eq 'real';
265 36 50       77 die unless exists($self->{stat_attributes}{$att});
266 36         102 my $attval = $newattrs->{$att};
267             die "Unknown value `$attval' for attribute `$att'."
268             unless exists($self->{stat_attributes}{$att}{$attval}) or
269 36 0 33     84 exists($self->{smoothing}{$att});
270 36         56 foreach my $label (@labels) {
271 72 100 66     360 if (exists($m->{condprob}{$att}{$attval}) and
    100 66        
272             exists($m->{condprob}{$att}{$attval}{$label}) and
273             $m->{condprob}{$att}{$attval}{$label} > 0 ) {
274             $scores{$label} *=
275 68         133 $m->{condprob}{$att}{$attval}{$label};
276             } elsif (exists($self->{smoothing}{$att})) {
277             $scores{$label} *=
278 3         8 $m->{condprob}{$att}{'*'}{$label};
279 1         3 } else { $scores{$label} = 0 }
280              
281             }
282             }
283              
284 13         28 foreach my $att (keys %{$newattrs}){
  13         35  
285 41 100       93 next unless $self->{attribute_type}{$att} eq 'real';
286 5         9 my $sum=0; my %nscores;
  5         11  
287 5         13 foreach my $label (@labels) {
288 10 50       31 die unless exists $m->{real_stat}{$att}{$label}{mean};
289             $nscores{$label} =
290             0.398942280401433 / $m->{real_stat}{$att}{$label}{stddev}*
291             exp( -0.5 *
292             ( ( $newattrs->{$att} -
293             $m->{real_stat}{$att}{$label}{mean})
294             / $m->{real_stat}{$att}{$label}{stddev}
295 10         108 ) ** 2
296             );
297 10         21 $sum += $nscores{$label};
298             }
299 5 50       16 if ($sum==0) { print STDERR "Ignoring all Gaussian probabilities: all=0!\n" }
  0         0  
300             else {
301 5         14 foreach my $label (@labels) { $scores{$label} *= $nscores{$label} }
  10         21  
302             }
303             }
304              
305 13         28 my $sumPx = 0.0;
306 13         51 $sumPx += $scores{$_} foreach (keys(%scores));
307 13         43 $scores{$_} /= $sumPx foreach (keys(%scores));
308 13         50 return \%scores;
309             }
310              
311             sub print_model {
312 25     25 1 271 my $self = shift;
313 25         60 my $withcounts = '';
314 25 100 66     141 if ($#_>-1 && $_[0] eq 'with counts')
315 1         2 { shift @_; $withcounts = 1; }
  1         1  
316 25         63 my $m = $self->{model};
317 25         92 my @labels = $self->labels;
318 25         106 my $r;
319              
320             # prepare table category P(category)
321             my @lines;
322 25         71 push @lines, 'category ', '-';
323 25         121 push @lines, "$_ " foreach @labels;
324 25         89 @lines = _append_lines(@lines);
325 25         54 @lines = map { $_.='| ' } @lines;
  100         182  
326 25         84 $lines[1] = substr($lines[1],0,length($lines[1])-2).'+-';
327 25         52 $lines[0] .= "P(category) ";
328 25         95 foreach my $i (2..$#lines) {
329 50         93 my $label = $labels[$i-2];
330 50         234 $lines[$i] .= $m->{labelprob}{$label} .' ';
331 50 100       137 if ($withcounts) {
332 2         6 $lines[$i] .= "(= $self->{stat_labels}{$label} / ".
333             "$self->{numof_instances} ) ";
334             }
335             }
336 25         98 @lines = _append_lines(@lines);
337              
338 25         176 $r .= join("\n", @lines) . "\n". $lines[1]. "\n\n";
339              
340             # prepare conditional tables
341 25         100 my @attributes = sort $self->attributes;
342 25         80 foreach my $att (@attributes) {
343 71         174 @lines = ( "category ", '-' );
344 71         145 my @lines1 = ( "$att ", '-' );
345 71         156 my @lines2 = ( "P( $att | category ) ", '-' );
346 71         125 my @attvals = sort keys(%{ $m->{condprob}{$att} });
  71         295  
347 71         145 foreach my $label (@labels) {
348 142 100       298 if ( $self->{attribute_type}{$att} ne 'real' ) {
349 116         181 foreach my $attval (@attvals) {
350 274 100       592 next unless exists($m->{condprob}{$att}{$attval}{$label});
351 263         410 push @lines, "$label ";
352 263         354 push @lines1, "$attval ";
353              
354 263         406 my $line = $m->{condprob}{$att}{$attval}{$label};
355 263 100       387 if ($withcounts)
356 35         92 { $line.= ' '.$m->{condprobe}{$att}{$attval}{$label} }
357 263         505 $line .= ' ';
358 263         404 push @lines2, $line;
359             }
360             } else {
361 26         58 push @lines, "$label ";
362 26         39 push @lines1, "real ";
363             push @lines2, "Gaussian(mean=".
364             $m->{real_stat}{$att}{$label}{mean}.",stddev=".
365 26         128 $m->{real_stat}{$att}{$label}{stddev}.") ";
366             }
367 142         198 push @lines, '-'; push @lines1, '-'; push @lines2, '-';
  142         162  
  142         202  
368             }
369 71         161 @lines = _append_lines(@lines);
370 71         148 foreach my $i (0 .. $#lines)
371 573 100       1381 { $lines[$i] .= ($lines[$i]=~/-$/?'+-':'| ') . $lines1[$i] }
372 71         186 @lines = _append_lines(@lines);
373 71         151 foreach my $i (0 .. $#lines)
374 573 100       1261 { $lines[$i] .= ($lines[$i]=~/-$/?'+-':'| ') . $lines2[$i] }
375 71         165 @lines = _append_lines(@lines);
376              
377 71         440 $r .= join("\n", @lines). "\n\n";
378             }
379              
380 25         189 return $r;
381             }
382              
383             sub _append_lines {
384 263     263   571 my @l = @_;
385 263         307 my $m = 0;
386 263 100       357 foreach (@l) { $m = length($_) if length($_) > $m }
  1919         2860  
387             @l = map
388 263         356 { while (length($_) < $m) { $_.=substr($_,length($_)-1) }; $_ }
  1919         2700  
  13181         20153  
  1919         2800  
389             @l;
390 263         707 return @l;
391             }
392              
393             sub labels {
394 25     25 1 45 my $self = shift;
395 25         83 return @{ $self->{labels} };
  25         106  
396             }
397              
398             sub attributes {
399 25     25 1 50 my $self = shift;
400 25         46 return keys %{ $self->{stat_attributes} };
  25         251  
401             }
402              
403             sub export_to_YAML {
404 5     5 1 3220 my $self = shift;
405 5         51 require YAML;
406 5         24 return YAML::Dump($self);
407             }
408              
409             sub export_to_YAML_file {
410 9     9 1 54521 my $self = shift;
411 9         23 my $file = shift;
412 9         86 require YAML;
413 9         38 YAML::DumpFile($file, $self);
414             }
415              
416             1;
417             __END__