File Coverage

blib/lib/AI/DecisionTree.pm
Criterion Covered Total %
statement 206 233 88.4
branch 65 92 70.6
condition 16 21 76.1
subroutine 31 35 88.5
pod 15 21 71.4
total 333 402 82.8


line stmt bran cond sub pod time code
1 2     2   15832 use strict;
  2         5  
  2         214  
2              
3             package AI::DecisionTree;
4             {
5             $AI::DecisionTree::VERSION = '0.11';
6             }
7              
8 2     2   906 use AI::DecisionTree::Instance;
  2         4  
  2         56  
9 2     2   12 use Carp;
  2         3  
  2         138  
10 2     2   9 use vars qw(@ISA);
  2         4  
  2         5693  
11              
12              
13             sub new {
14 7     7 1 1274 my $package = shift;
15 7         68 return bless {
16             noise_mode => 'fatal',
17             prune => 1,
18             purge => 1,
19             verbose => 0,
20             max_depth => 0,
21             @_,
22             nodes => 0,
23             instances => [],
24             name_gen => 0,
25             }, $package;
26             }
27              
28 1     1 1 979 sub nodes { $_[0]->{nodes} }
29 0     0 1 0 sub noise_mode { $_[0]->{noise_mode} }
30 1     1 1 200 sub depth { $_[0]->{depth} }
31              
32             sub add_instance {
33 626     626 1 37736 my ($self, %args) = @_;
34 626 50       1284 croak "Missing 'attributes' parameter" unless $args{attributes};
35 626 50       1168 croak "Missing 'result' parameter" unless defined $args{result};
36 626 100       1955 $args{name} = $self->{name_gen}++ unless exists $args{name};
37            
38 626         570 my @attributes;
39 626         643 while (my ($k, $v) = each %{$args{attributes}}) {
  21726         53193  
40 21100         35006 $attributes[ _hlookup($self->{attributes}, $k) ] = _hlookup($self->{attribute_values}{$k}, $v);
41             }
42 626   100     18514 $_ ||= 0 foreach @attributes;
43            
44 626         778 push @{$self->{instances}}, AI::DecisionTree::Instance->new(\@attributes, _hlookup($self->{results}, $args{result}), $args{name});
  626         1767  
45             }
46              
47             sub _hlookup {
48 42826   100 42826   66874 $_[0] ||= {}; # Autovivify as a hash
49 42826         50930 my ($hash, $key) = @_;
50 42826 100       72243 unless (exists $hash->{$key}) {
51 288         568 $hash->{$key} = 1 + keys %$hash;
52             }
53 42826         105021 return $hash->{$key};
54             }
55              
56             sub _create_lookup_hashes {
57 8     8   11 my $self = shift;
58 8         15 my $h = $self->{results};
59 8         43 $self->{results_reverse} = [ undef, sort {$h->{$a} <=> $h->{$b}} keys %$h ];
  65         121  
60            
61 8         15 foreach my $attr (keys %{$self->{attribute_values}}) {
  8         33  
62 82         109 my $h = $self->{attribute_values}{$attr};
63 82         219 $self->{attribute_values_reverse}{$attr} = [ undef, sort {$h->{$a} <=> $h->{$b}} keys %$h ];
  178         329  
64             }
65             }
66              
67             sub train {
68 7     7 1 80 my ($self, %args) = @_;
69 7 50       9 if (not @{ $self->{instances} }) {
  7         29  
70 0 0       0 croak "Training data has been purged, can't re-train" if $self->{tree};
71 0         0 croak "Must add training instances before calling train()";
72             }
73            
74 7         22 $self->_create_lookup_hashes;
75 7         31 local $self->{curr_depth} = 0;
76 7 100       25 local $self->{max_depth} = $args{max_depth} if exists $args{max_depth};
77 7         13 $self->{depth} = 0;
78 7         26 $self->{tree} = $self->_expand_node( instances => $self->{instances} );
79 6         17 $self->{total_instances} = @{$self->{instances}};
  6         16  
80            
81 6 50       29 $self->prune_tree if $self->{prune};
82 6 100       21 $self->do_purge if $self->purge;
83 6         30 return 1;
84             }
85              
86             sub do_purge {
87 3     3 1 6 my $self = shift;
88 3         5 delete @{$self}{qw(instances attribute_values attribute_values_reverse results results_reverse)};
  3         510  
89             }
90              
91             sub copy_instances {
92 1     1 1 4 my ($self, %opt) = @_;
93 1 50       4 croak "Missing 'from' parameter to copy_instances()" unless exists $opt{from};
94 1         2 my $other = $opt{from};
95 1 50       6 croak "'from' parameter is not a decision tree" unless UNIVERSAL::isa($other, __PACKAGE__);
96              
97 1         2 foreach (qw(instances attributes attribute_values results)) {
98 4         9 $self->{$_} = $other->{$_};
99             }
100 1         4 $self->_create_lookup_hashes;
101             }
102              
103             sub set_results {
104 1     1 1 3 my ($self, $hashref) = @_;
105 1         2 foreach my $instance (@{$self->{instances}}) {
  1         3  
106 2         7 my $name = $instance->name;
107 2 50       6 croak "No result given for instance '$name'" unless exists $hashref->{$name};
108 2         10 $instance->set_result( $self->{results}{ $hashref->{$name} } );
109             }
110             }
111              
112 10     10 1 218 sub instances { $_[0]->{instances} }
113              
114             sub purge {
115 6     6 1 12 my $self = shift;
116 6 50       14 $self->{purge} = shift if @_;
117 6         23 return $self->{purge};
118             }
119              
120             # Each node contains:
121             # { split_on => $attr_name,
122             # children => { $attr_value1 => $node1,
123             # $attr_value2 => $node2, ... }
124             # }
125             # or
126             # { result => $result }
127              
128             sub _expand_node {
129 169     169   303 my ($self, %args) = @_;
130 169         201 my $instances = $args{instances};
131 169 50       338 print STDERR '.' if $self->{verbose};
132            
133 169 100       931 $self->{depth} = $self->{curr_depth} if $self->{curr_depth} > $self->{depth};
134 169         343 local $self->{curr_depth} = $self->{curr_depth} + 1;
135 169         198 $self->{nodes}++;
136              
137 169         154 my %results;
138 169         434 $results{$self->_result($_)}++ foreach @$instances;
139 169         416 my @results = map {$_,$results{$_}} sort {$results{$b} <=> $results{$a}} keys %results;
  292         659  
  205         320  
140 169         518 my %node = ( distribution => \@results, instances => scalar @$instances );
141              
142 169         295 foreach (keys %results) {
143 292         573 $self->{prior_freqs}{$_} += $results{$_};
144             }
145              
146 169 100       367 if (keys(%results) == 1) {
147             # All these instances have the same result - make this node a leaf
148 106         221 $node{result} = $self->_result($instances->[0]);
149 106         548 return \%node;
150             }
151            
152             # Multiple values are present - find the best predictor attribute and split on it
153 63         124 my $best_attr = $self->best_attr($instances);
154              
155 63 100 100     398 croak "Inconsistent data, can't build tree with noise_mode='fatal'"
156             if $self->{noise_mode} eq 'fatal' and !defined $best_attr;
157              
158 62 100 100     241 if ( !defined($best_attr)
      66        
159             or $self->{max_depth} && $self->{curr_depth} > $self->{max_depth} ) {
160             # Pick the most frequent result for this leaf
161 3         10 $node{result} = (sort {$results{$b} <=> $results{$a}} keys %results)[0];
  3         12  
162 3         17 return \%node;
163             }
164            
165 59         90 $node{split_on} = $best_attr;
166            
167 59         61 my %split;
168 59         84 foreach my $i (@$instances) {
169 2254         3753 my $v = $self->_value($i, $best_attr);
170 2254 100       2190 push @{$split{ defined($v) ? $v : '' }}, $i;
  2254         5155  
171             }
172 59 50       144 die ("Something's wrong: attribute '$best_attr' didn't split ",
173 0         0 scalar @$instances, " instances into multiple buckets (@{[ keys %split ]})")
174             unless keys %split > 1;
175              
176 59         113 foreach my $value (keys %split) {
177 162         431 $node{children}{$value} = $self->_expand_node( instances => $split{$value} );
178             }
179            
180 59         411 return \%node;
181             }
182              
183             sub best_attr {
184 63     63 0 70 my ($self, $instances) = @_;
185              
186             # 0 is a perfect score, entropy(#instances) is the worst possible score
187            
188 63         1419 my ($best_score, $best_attr) = (@$instances * $self->entropy( map $_->result_int, @$instances ), undef);
189 63         240 my $all_attr = $self->{attributes};
190 63         291 foreach my $attr (keys %$all_attr) {
191              
192             # %tallies is correlation between each attr value and result
193             # %total is number of instances with each attr value
194 1917         1695 my (%totals, %tallies);
195 1917         30972 my $num_undef = AI::DecisionTree::Instance::->tally($instances, \%tallies, \%totals, $all_attr->{$attr});
196 1917 50       4021 next unless keys %totals; # Make sure at least one instance defines this attribute
197            
198 1917         1773 my $score = 0;
199 1917         4118 while (my ($opt, $vals) = each %tallies) {
200 3165         6253 $score += $totals{$opt} * $self->entropy2( $vals, $totals{$opt} );
201             }
202              
203 1917 100       9473 ($best_attr, $best_score) = ($attr, $score) if $score < $best_score;
204             }
205            
206 63         212 return $best_attr;
207             }
208              
209             sub entropy2 {
210 3165     3165 0 2909 shift;
211 3165         3330 my ($counts, $total) = @_;
212              
213             # Entropy is defined with log base 2 - we just divide by log(2) at the end to adjust.
214 3165         2834 my $sum = 0;
215 3165         12711 $sum += $_ * log($_) foreach values %$counts;
216 3165         13709 return +(log($total) - $sum/$total)/log(2);
217             }
218              
219             sub entropy {
220 63     63 0 63 shift;
221              
222 63         75 my %count;
223 63         1100 $count{$_}++ foreach @_;
224              
225             # Entropy is defined with log base 2 - we just divide by log(2) at the end to adjust.
226 63         79 my $sum = 0;
227 63         318 $sum += $_ * log($_) foreach values %count;
228 63         269 return +(log(@_) - $sum/@_)/log(2);
229             }
230              
231             sub prune_tree {
232 6     6 0 10 my $self = shift;
233              
234             # We use a minimum-description-length approach. We calculate the
235             # score of each node:
236             # n = number of nodes below
237             # r = number of results (categories) in the entire tree
238             # i = number of instances in the entire tree
239             # e = number of errors below this node
240              
241             # Hypothesis description length (MML):
242             # describe tree: number of nodes + number of edges
243             # describe exceptions: num_exceptions * log2(total_num_instances) * log2(total_num_results)
244            
245 6         8 my $r = keys %{ $self->{results} };
  6         14  
246 6         13 my $i = $self->{tree}{instances};
247 6         15 my $exception_cost = log($r) * log($i) / log(2)**2;
248              
249             # Pruning can turn a branch into a leaf
250             my $maybe_prune = sub {
251 161     161   156 my ($self, $node) = @_;
252 161 100       327 return unless $node->{children}; # Can't prune leaves
253              
254 58         104 my $nodes_below = $self->nodes_below($node);
255 58         86 my $tree_cost = 2 * $nodes_below - 1; # $edges_below == $nodes_below - 1
256            
257 58         98 my $exceptions = $self->exceptions( $node );
258 58         88 my $simple_rule_exceptions = $node->{instances} - $node->{distribution}[1];
259              
260 58         108 my $score = -$nodes_below - ($exceptions - $simple_rule_exceptions) * $exception_cost;
261             #warn "Score = $score = -$nodes_below - ($exceptions - $simple_rule_exceptions) * $exception_cost\n";
262 58 100       133 if ($score < 0) {
263 2         5 delete @{$node}{'children', 'split_on', 'exceptions', 'nodes_below'};
  2         16  
264 2         6 $node->{result} = $node->{distribution}[0];
265             # XXX I'm not cleaning up 'exceptions' or 'nodes_below' keys up the tree
266             }
267 6         40 };
268              
269 6         21 $self->_traverse($maybe_prune);
270             }
271              
272             sub exceptions {
273 770     770 0 744 my ($self, $node) = @_;
274 770 50       1319 return $node->{exceptions} if exists $node->{exeptions};
275            
276 770         708 my $count = 0;
277 770 100       1126 if ( exists $node->{result} ) {
278 501         739 $count = $node->{instances} - $node->{distribution}[1];
279             } else {
280 269         214 foreach my $child ( values %{$node->{children}} ) {
  269         492  
281 712         1089 $count += $self->exceptions($child);
282             }
283             }
284            
285 770         1462 return $node->{exceptions} = $count;
286             }
287              
288             sub nodes_below {
289 58     58 0 58 my ($self, $node) = @_;
290 58 50       117 return $node->{nodes_below} if exists $node->{nodes_below};
291              
292 58         57 my $count = 0;
293 58     770   211 $self->_traverse( sub {$count++}, $node );
  770         797  
294              
295 58         220 return $node->{nodes_below} = $count - 1;
296             }
297              
298             # This is *not* for external use, I may change it.
299             sub _traverse {
300 931     931   1130 my ($self, $callback, $node, $parent, $node_name) = @_;
301 931   66     1503 $node ||= $self->{tree};
302            
303 931 50       1915 ref($callback) ? $callback->($self, $node, $parent, $node_name) : $self->$callback($node, $parent, $node_name);
304            
305 931 100       2495 return unless $node->{children};
306 325         307 foreach my $child ( keys %{$node->{children}} ) {
  325         694  
307 867         2171 $self->_traverse($callback, $node->{children}{$child}, $node, $child);
308             }
309             }
310              
311             sub get_result {
312 90     90 1 7050 my ($self, %args) = @_;
313 90 50 66     218 croak "Missing 'attributes' or 'callback' parameter" unless $args{attributes} or $args{callback};
314              
315 90 50       170 $self->train unless $self->{tree};
316 90         94 my $tree = $self->{tree};
317            
318 90         79 while (1) {
319 435 100       739 if (exists $tree->{result}) {
320 88         103 my $r = $tree->{result};
321 88 100       230 return $r unless wantarray;
322              
323 81         70 my %dist = @{$tree->{distribution}};
  81         180  
324 81         151 my $confidence = $tree->{distribution}[1] / $tree->{instances};
325              
326             # my $confidence = P(H|D) = [P(D|H)P(H)]/[P(D|H)P(H)+P(D|H')P(H')]
327             # = [P(D|H)P(H)]/P(D);
328             # my $confidence =
329             # $confidence *= $self->{prior_freqs}{$r} / $self->{total_instances};
330            
331 81         311 return ($r, $confidence, \%dist);
332             }
333            
334 347 100       924 my $instance_val = (exists $args{callback} ? $args{callback}->($tree->{split_on}) :
    100          
335             exists $args{attributes}{$tree->{split_on}} ? $args{attributes}{$tree->{split_on}} :
336             '');
337             ## no critic (ProhibitExplicitReturnUndef)
338 347 100       774 $tree = $tree->{children}{ $instance_val }
339             or return undef;
340             }
341             }
342              
343             sub as_graphviz {
344 0     0 1 0 my ($self, %args) = @_;
345 0   0     0 my $colors = delete $args{leaf_colors} || {};
346 0         0 require GraphViz;
347 0         0 my $g = GraphViz->new(%args);
348              
349 0         0 my $id = 1;
350             my $add_edge = sub {
351 0     0   0 my ($self, $node, $parent, $node_name) = @_;
352             # We use stringified reference names for node names, as a convenient hack.
353              
354 0 0       0 if ($node->{split_on}) {
355 0         0 $g->add_node( "$node",
356             label => $node->{split_on},
357             shape => 'ellipse',
358             );
359             } else {
360 0         0 my $i = 0;
361 0         0 my $distr = join ',', grep {$i++ % 2} @{$node->{distribution}};
  0         0  
  0         0  
362 0 0       0 my %fill = (exists $colors->{$node->{result}} ?
363             (fillcolor => $colors->{$node->{result}},
364             style => 'filled') :
365             ()
366             );
367 0         0 $g->add_node( "$node",
368             label => "$node->{result} ($distr)",
369             shape => 'box',
370             %fill,
371             );
372             }
373 0 0       0 $g->add_edge( "$parent" => "$node",
374             label => $node_name,
375             ) if $parent;
376 0         0 };
377              
378 0         0 $self->_traverse( $add_edge );
379 0         0 return $g;
380             }
381              
382             sub rule_tree {
383 16     16 1 319 my $self = shift;
384 16 100       35 my ($tree) = @_ ? @_ : $self->{tree};
385            
386             # build tree:
387             # [ question, { results => [ question, { ... } ] } ]
388            
389 16 100       67 return $tree->{result} if exists $tree->{result};
390            
391             return [
392 14         29 $tree->{split_on}, {
393 6         7 map { $_ => $self->rule_tree($tree->{children}{$_}) } keys %{$tree->{children}},
  6         14  
394             }
395             ];
396             }
397              
398             sub rule_statements {
399 12     12 1 206 my $self = shift;
400 12 100       26 my ($stmt, $tree) = @_ ? @_ : ('', $self->{tree});
401 12 100       39 return("$stmt -> '$tree->{result}'") if exists $tree->{result};
402            
403 4         4 my @out;
404 4 100       10 my $prefix = $stmt ? "$stmt and" : "if";
405 4         5 foreach my $val (keys %{$tree->{children}}) {
  4         8  
406 10         37 push @out, $self->rule_statements("$prefix $tree->{split_on}='$val'", $tree->{children}{$val});
407             }
408 4         15 return @out;
409             }
410              
411             ### Some instance accessor stuff:
412              
413             sub _result {
414 3006     3006   3185 my ($self, $instance) = @_;
415 3006         4465 my $int = $instance->result_int;
416 3006         7604 return $self->{results_reverse}[$int];
417             }
418              
419             sub _delete_value {
420 0     0   0 my ($self, $instance, $attr) = @_;
421 0         0 my $val = $self->_value($instance, $attr);
422 0 0       0 return unless defined $val;
423            
424 0         0 $instance->set_value($self->{attributes}{$attr}, 0);
425 0         0 return $val;
426             }
427              
428             sub _value {
429 2254     2254   2408 my ($self, $instance, $attr) = @_;
430 2254 50       4035 return unless exists $self->{attributes}{$attr};
431 2254         4204 my $val_int = $instance->value_int($self->{attributes}{$attr});
432 2254         3915 return $self->{attribute_values_reverse}{$attr}[$val_int];
433             }
434              
435              
436              
437             1;
438             __END__