File Coverage

blib/lib/Algorithm/BaumWelch.pm
Criterion Covered Total %
statement 21 348 6.0
branch 0 44 0.0
condition 0 26 0.0
subroutine 7 22 31.8
pod 0 6 0.0
total 28 446 6.2


line stmt bran cond sub pod time code
1             package Algorithm::BaumWelch;
2 1     1   45216 use warnings;
  1         3  
  1         33  
3 1     1   5 use strict;
  1         2  
  1         1437  
4 1     1   8 use Carp;
  1         6  
  1         99  
5 1     1   2481 use Math::Cephes qw/:explog/;
  1         27088  
  1         398  
6 1     1   12 use List::Util qw/sum/;
  1         3  
  1         1826  
7 1     1   4661 use Text::SimpleTable;
  1         5515  
  1         65  
8              
9             # vale a pena em fazer forward/backward com normalisation? e gamma!?!
10              
11             =head1 NAME
12              
13             Algorithm::BaumWelch - Baum-Welch Algorithm for Hidden Markov Chain parameter estimation.
14              
15             =cut
16              
17             =head1 VERSION
18              
19             This document describes Algorithm::BaumWelch version 0.0.2
20              
21             =cut
22              
23             =head1 SYNOPSIS
24              
25             use Algorithm::BaumWelch;
26              
27             # The observation series see http://www.cs.jhu.edu/~jason/.
28             my $obs_series = [qw/ obs2 obs3 obs3 obs2 obs3 obs2 obs3 obs2 obs2
29             obs3 obs1 obs3 obs3 obs1 obs1 obs1 obs2 obs1
30             obs1 obs1 obs3 obs1 obs2 obs1 obs1 obs1 obs2
31             obs3 obs3 obs2 obs3 obs2 obs2
32             /];
33              
34             # The emission matrix - each nested array corresponds to the probabilities of a single observation type.
35             my $emis = {
36             obs1 => [0.3, 0.3],
37             obs2 => [0.3, 0.4],
38             obs3 => [0.4, 0.3],
39             };
40              
41             # The transition matrixi - each row and column correspond to a particular state e.g. P(state1_x|state1_x-1) = 0.9...
42             my $trans = [
43             [0.9, 0.1],
44             [0.1, 0.9],
45             ];
46              
47             # The probabilities of each state at the start of the series.
48             my $start = [0.5, 0.5];
49              
50             # Create an Algorithm::BaumWelch object.
51             my $ba = Algorithm::BaumWelch->new;
52              
53             # Feed in the observation series.
54             $ba->feed_obs($obs_series);
55              
56             # Feed in the transition and emission matrices and the starting probabilities.
57             $ba->feed_values($trans, $emis, $start);
58              
59             # Alternatively you can randomly initialise the values - pass it the number of hidden states -
60             # i.e. to determine the parameters we need to make a first guess).
61             # $ba->random_initialise(2);
62            
63             # Perform the algorithm.
64             $ba->baum_welch;
65              
66             # Use results to pass data.
67             # In VOID-context prints formated results to STDOUT.
68             # In LIST-context returns references to the predicted transition & emission matrices and the starting parameters.
69             $ba->results;
70              
71             =cut
72              
73             =head1 DESCRIPTION
74              
75             The Baum-Welch algorithm is used to compute the parameters (transition and emission probabilities) of an Hidden Markov
76             Model (HMM). The algorithm calculates the forward and backwards probabilities for each HMM state in a series and then re-estimates the parameters of
77             the model.
78              
79             =cut
80              
81 1     1   3830 use version; our $VERSION = qv('0.0.2');
  1         5549  
  1         8  
82              
83             #r/ matrices de BW sao 1xN_states matrices - quer dizer quasi arrays - entao nao usa matrices reais. arrays são bastante
84             sub new {
85 0     0 0   my $class = shift;
86 0           my $self = [undef, undef, []]; bless $self, $class;
  0            
87 0           return $self;
88             }
89              
90             sub feed_obs {
91 0     0 0   my ($self, $series) = @_;
92 0           my %uniq;
93 0           @uniq{@{$series}} = 1;
  0            
94 0           my @obs = (keys %uniq);
95 0           @obs = sort { $a cmp $b } @obs;
  0            
96 0           $self->[0][0] = $series;
97 0           $self->[0][1] = [@obs];
98 0           $self->[0][2] = scalar @obs;
99 0           return;
100             }
101              
102             sub feed_values {
103 0 0   0 0   croak qq{\nThis method expects 3 arguments.} if @_ != 4;
104 0           my ($self, $trans, $emis, $start) = @_;
105 0 0 0       croak qq{\nThis method expects 3 arguments.} if (ref $trans ne q{ARRAY} || ref $emis ne q{HASH} || ref $start ne q{ARRAY});
      0        
106 0           my $obs_tipos = $self->[0][1];
107 0           my $obs_numero = $self->[0][2];
108 0           my $t_length = &_check_trans($trans);
109 0           &_check_emis($emis, $obs_tipos, $obs_numero, $t_length);
110 0           &_check_start($start, $t_length);
111 0           $self->[1][0] = $trans;
112 0           $self->[1][1] = $emis;
113 0           $self->[1][2] = $start;
114 0           my @stop; # 0.1/1 nao faz diferenca e para|comeca (stop|start) sempre iguala = 0
115 0           for (0..$#{$trans}) { push @stop, 1 };
  0            
  0            
116 0           $self->[1][3] = [@stop];
117 0           return;
118             }
119              
120             sub _check_start {
121 0     0     my ($start, $t_length) = @_;
122 0 0         croak qq{\nThere must be an initial probablity for each state in the start ARRAY.} if scalar @{$start} != $t_length;
  0            
123 0 0         for (@{$start}) { croak qq{\nThe start ARRAY values must be numeric.} if !(/^[+-]?\ *(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$/) };
  0            
  0            
124 0           my $sum =0;
125 0           for (@{$start}) { $sum += $_ }
  0            
  0            
126 0 0 0       croak qq{\nThe starting probabilities must sum to 1.} if ($sum <= 0.95 || $sum >= 1.05);
127 0           return;
128             }
129              
130             sub _check_emis {
131 0     0     my ($emis, $obs_tipos, $obs_numero, $t_length) = @_;
132 0           my @emis_keys = (keys %{$emis});
  0            
133 0           @emis_keys = sort {$a cmp $b} @emis_keys;
  0            
134 0 0         croak qq{\nThere must be an entry in the emission matrix for each type of observation in the observation series.} if $obs_numero != scalar @emis_keys;
135 0 0         for (0..$#emis_keys) { croak qq{\nThe observations in the emission matrix do not match those in the observation series.} if $emis_keys[$_] ne $obs_tipos->[$_]; }
  0            
136 0           for (values %{$emis}) {
  0            
137 0 0         croak qq{\nThere must be a probability value for each state in the emission matrix.} if scalar @{$_} != $t_length;
  0            
138 0 0         for my $cell (@{$_}) { croak qq{\nThe emission matrix values must be numeric.} if $cell !~ /^[+-]?\ *(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$/; }
  0            
  0            
139             }
140 0           for my $i (0..$t_length-1) { # só fazendo 2-estado agora
141 0           my $sum = 0;
142 0           for my $o (@{$obs_tipos}) { $sum += $emis->{$o}[$i] }
  0            
  0            
143 0 0 0       croak qq{\nThe emission matrix column must sum to 1.} if ($sum <= 0.95 || $sum >= 1.05);
144             }
145 0           return;
146             }
147              
148             sub _check_trans {
149 0     0     my $trans = shift;
150 0           my $t_length = scalar @{$trans};
  0            
151 0           for (@{$trans}) {
  0            
152 0 0         croak qq{\nThe transition matrix much be square.} if scalar @{$_} != $t_length;
  0            
153 0           my $sum = 0;
154 0           for my $cell (@{$_}) {
  0            
155 0 0         croak qq{\nThe transition matrix values must be numeric.} if $cell !~ /^[+-]?\ *(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$/;
156 0           $sum += $cell
157             }
158 0 0 0       croak qq{\nThe transition matrix row must sum to 1.} if ($sum <= 0.95 || $sum >= 1.05);
159             }
160 0           return $t_length;
161             }
162              
163             sub random_initialise {
164 0     0 0   my ($self, $states) = @_;
165 0           my $obs_names = $self->[0][1];
166 0           my $trans = &_gera_trans($states);
167 0           my $emis = &_gera_emis($states, $obs_names);
168 0           my $start = &_gera_init($states);
169 0           $self->[1][0] = $trans;
170 0           $self->[1][1] = $emis;
171 0           $self->[1][2] = $start;
172 0           my @stop; # 0.1/1 nao faz diferenca e para|comeca (stop|start) sempre iguala = 0
173 0           for (0..$states-1) { push @stop, 1 };
  0            
174 0           $self->[1][3] = [@stop];
175 0           return;
176             }
177              
178             sub _gera_init {
179 0     0     my $length = shift;
180 0           my $sum = 0;
181 0           my $init = [];
182 0           srand;
183 0           $#{$init} = $length-1; # só fazendo 2-estado agora
  0            
184 0           for (@{$init}) { $_ = rand; $sum += $_ }
  0            
  0            
  0            
185             #/ normalise such that sum is equal to 1
186 0           for (@{$init}) { $_ /= $sum }
  0            
  0            
187 0           return $init;
188             }
189              
190             sub _gera_trans {
191 0     0     my $length = shift;
192 0           my $t = [];
193 0           $#{$t} = $length-1; # só fazendo 2-estado agora
  0            
194             #/ gera_init normalises
195 0           for (@{$t}) { $_ = &_gera_init($length); }
  0            
  0            
196 0           return $t;
197             }
198              
199             sub _gera_emis {
200 0     0     my ($length, $obs_names) = @_;
201 0           my $e = {};
202 0           srand;
203 0           for (@{$obs_names}) {
  0            
204 0           my $init = [];
205 0           $#{$init} = $length-1; # só fazendo 2-estado agora
  0            
206 0           for (@{$init}) { $_ = rand; }
  0            
  0            
207 0           $e->{$_} = $init;
208             }
209             # para cada estado a suma deve iguala 1 - normalise such that sum of obs_x|state = 1
210 0           for my $i (0..$length-1) { # só fazendo 2-estado agora
211 0           my $sum = 0;
212 0           for my $o (@{$obs_names}) { $sum += $e->{$o}[$i] }
  0            
  0            
213 0           for my $o (@{$obs_names}) { $e->{$o}[$i] /= $sum }
  0            
  0            
214             }
215             #print qq{\n\nauto-gera emis de numeros aleatorios que sumam 1 para cada estado}; draw($e);
216 0           return $e;
217             }
218              
219             sub _forwardbackward_reestimacao {
220 0     0     my $self = shift;
221 0           my $obs_series = $self->[0][0];
222 0           my $obs_types = $self->[0][1];
223 0           my $trans = $self->[1][0];
224 0           my $emis = $self->[1][1];
225 0           my $start = $self->[1][2];
226 0           my $stop = $self->[1][3];
227 0           my $alpha = [];
228             #y initialise
229 0           for (0..$#{$trans}) { $alpha->[$_][0] = $start->[$_] * $emis->{$obs_series->[0]}[$_]; }
  0            
  0            
230             #y not sure if i´ve extrapolated to higher-state number BW algorithm equations correctly?!?
231 0           for my $n (1..$#{$obs_series}) {
  0            
232 0           for my $s (0..$#{$trans}) {
  0            
233             #push @{$alpha->[$s]}, ( ( ($alpha->[0][$n-1]*$trans->[$s][0]) + ($alpha->[1][$n-1]*$trans->[$s][1]) ) * $emis->{$obs_series->[$n]}[$s] ) ;
234 0           my $sum = 0;
235 0           for my $s_other (0..$#{$trans}) { $sum += $alpha->[$s_other][$n-1]*$trans->[$s][$s_other]; }
  0            
  0            
236 0           push @{$alpha->[$s]}, ( ($sum) * $emis->{$obs_series->[$n]}[$s] ) ;
  0            
237             }
238             }
239              
240 0           my $beta = [];
241             #y initialise
242 0           for (0..$#{$trans}) { $beta->[$_][$#{$obs_series}] = $stop->[$_]; }
  0            
  0            
  0            
243 0           for ( my $n = $#{$obs_series}-1; $n > -1; $n-- ) {
  0            
244 0           for my $s (0..$#{$trans}) {
  0            
245             #$beta->[$s][$i] = ( ($trans->[0][$s]*$beta->[0][$i+1]*$emis->{$obs_series->[$i+1]}[0]) + ($trans->[1][$s]*$beta->[1][$i+1]*$emis->{$obs_series->[$i+1]}[1]) );
246 0           my $sum = 0;
247 0           for my $s_other (0..$#{$trans}) { $sum += ($trans->[$s_other][$s]*$beta->[$s_other][$n+1]*$emis->{$obs_series->[$n+1]}[$s_other]); }
  0            
  0            
248 0           $beta->[$s][$n] = $sum;
249             }
250             }
251              
252             #=fs normalisation?!?
253             #for my $n (0..$#{$obs_series}) { my $sum = 0; for my $s (0..$#{$trans}) { $sum += $alpha->[$s][$n] } for my $s (0..$#{$trans}) { $alpha->[$s][$n] = $alpha->[$s][$n] / $sum; } }
254             #for my $n (0..$#{$obs_series}) { my $sum = 0; for my $s (0..$#{$trans}) { $sum += $beta->[$s][$n] } for my $s (0..$#{$trans}) { $beta->[$s][$n] = $beta->[$s][$n] / $sum; } }
255             #=fe
256              
257             # per state gamma - i.e. gamma é matric de 1 x N_states
258 0           my $gamma = [];
259 0           for my $s (0..$#{$trans}) { @{$gamma->[$s]} = map { $alpha->[$s][$_] * $beta->[$s][$_] } (0..$#{$obs_series}); }
  0            
  0            
  0            
  0            
  0            
260              
261             #=fs normalisation?!?
262             #for my $n (0..$#{$obs_series}) { my $sum = 0; for my $s (0..$#{$trans}) { $sum += $gamma->[$s][$n] } for my $s (0..$#{$trans}) { $gamma->[$s][$n] = $gamma->[$s][$n] / $sum; } }
263             #=fe
264              
265             #y gamma_sum = probadilidade total - entao nos nao normalisar dados como normal - faz differenca?!?
266 0           my $gamma_sum = []; # should be same for all elements or...
267             #@{$gamma_sum} = map { $gamma->[0][$_] + $gamma->[1][$_] } (0..$#{$obs_series});
268              
269             # map so devolve o último statement / map only returns the last statement
270 0           @{$gamma_sum} = map { my $sum = 0; for my $s (0..$#{$trans}) { $sum += $gamma->[$s][$_] }; $sum } (0..$#{$obs_series});
  0            
  0            
  0            
  0            
  0            
  0            
  0            
271             #push @{$perp}, 2**(-log2($gamma_sum->[0])/(scalar @{$obs_series} + 1));
272 0           push @{$self->[2]}, 2**(-log2($gamma_sum->[0])/(scalar @{$obs_series} + 1));
  0            
  0            
273              
274 0           my $p_too_state_trans = [];
275 0           for my $s (0..$#{$trans}) { @{$p_too_state_trans->[$s]} = map { $gamma->[$s][$_] / $gamma_sum->[$_] } (0..$#{$obs_series}); }
  0            
  0            
  0            
  0            
  0            
276              
277 0           my $p_too_state_trans_with_obs = []; # estado será primeira índice e obs será a segunda - é uma matric real mas facil
278 0           for my $s (0..$#{$trans}) {
  0            
279 0           for my $o (0..$#{$obs_types}) {
  0            
280 0 0         @{$p_too_state_trans_with_obs->[$s][$o]} = map { $obs_series->[$_] eq $obs_types->[$o] ? $p_too_state_trans->[$s][$_] : 0; } (0..$#{$obs_series});
  0            
  0            
  0            
281             }
282             }
283              
284 0           my $p_state_too_state_trans = [];
285 0           for my $s_1st (0..$#{$trans}) {
  0            
286 0           for my $s_2nd (0..$#{$trans}) {
  0            
287             #/ this is pretty inefficient - but its fun
288 0 0         @{$p_state_too_state_trans->[$s_1st][$s_2nd]} = map { $_ != 0 ? ( $alpha->[$s_1st][$_-1] * $trans->[$s_2nd][$s_1st]
  0            
  0            
289             * $beta->[$s_2nd][$_] * $emis->{$obs_series->[$_]}[$s_2nd] )
290 0           / $gamma_sum->[$_] : 0 } (0..$#{$obs_series});
291             }
292             }
293              
294 0           my $emis_new = {};
295 0           for my $s (0..$#{$trans}) {
  0            
296 0           for my $o (0..$#{$obs_types}) {
  0            
297 0           $emis_new->{$obs_types->[$o]}[$s] = (sum @{$p_too_state_trans_with_obs->[$s][$o]} ) / (sum @{$p_too_state_trans->[$s]} );
  0            
  0            
298             }
299             }
300              
301 0           my $trans_new = [];
302 0           for my $s_1st (0..$#{$trans}) {
  0            
303 0           for my $s_2nd (0..$#{$trans}) {
  0            
304 0           $trans_new->[$s_2nd][$s_1st] = (sum @{$p_state_too_state_trans->[$s_1st][$s_2nd]} ) / (sum @{$p_too_state_trans->[$s_1st]} );
  0            
  0            
305             }
306             }
307              
308 0           my $stop_new = [];
309 0           for my $s (0..$#{$trans}) { $stop_new->[$s] = ( $p_too_state_trans->[$s][$#{$obs_series}] ) / (sum @{$p_too_state_trans->[$s]} ); }
  0            
  0            
  0            
  0            
310 0           my $start_new = [];
311 0           for my $s (0..$#{$trans}) { $start_new->[$s] = $p_too_state_trans->[$s][0]; }
  0            
  0            
312              
313 0           $self->[1][0] = $trans_new;
314 0           $self->[1][1] = $emis_new;
315 0           $self->[1][2] = $start_new;
316 0           $self->[1][3] = $stop_new;
317              
318 0           return;
319             }
320              
321             sub baum_welch {
322             #/ i´m being lazy this is an acceptable cut-off mechanism for now
323 0     0 0   my ($self, $max) = @_;
324 0   0       $max ||= 100;
325 0           my $val;
326 0           my $count = 1;
327 0           while (1) {
328 0           $self->_forwardbackward_reestimacao;
329 0 0 0       last if defined $val && $val < ${$self->[2]}[-1];
  0            
330 0 0         $val = ${$self->[2]}[-1] - ( ${$self->[2]}[-1]/1000000000) if $count > 3;
  0            
  0            
331 0           $count++;
332 0 0         last if $count > 100;
333             }
334 0           return;
335             }
336              
337             sub _baum_welch_10 {
338 0     0     my $self = shift;
339 0           for (0..10) { $self->_forwardbackward_reestimacao; }
  0            
340 0           return;
341             }
342              
343             sub _baum_welch_length {
344 0     0     my $self = shift;
345 0           for (0..$#{$self->[0][0]}) { $self->_forwardbackward_reestimacao; }
  0            
  0            
346 0           return;
347             }
348              
349             sub results {
350 0     0 0   my $self = shift;
351 0           my $trans = $self->[1][0];
352 0           my $emis = $self->[1][1];
353 0           my $start = $self->[1][2];
354 0 0         if (wantarray) {
355 0           return ($trans, $emis, $start);
356             }
357             else {
358 0           my $keys = $self->[0][1];
359 0           my @config = ( [15, q{}] );
360 0           push @config, (map { [ 15, q{P(...|State_}.$_.q{)} ] } (1..$#{$trans->[0]}+1));
  0            
  0            
361 0           my $tbl = Text::SimpleTable->new(@config);
362 0           for my $row (0..$#{$trans}) {
  0            
363 0           my @data;
364             # quem liga qual serie
365 0           for my $col (0..$#{$trans->[0]}) { push @data, sprintf(q{%.8e},$trans->[$row][$col]); }
  0            
  0            
366 0           my $row_num = $row+1;
367 0           $tbl->row( qq{P(State_${row_num}|...)}, @data );
368 0 0         $tbl->hr if $row != $#{$trans};
  0            
369             }
370 0           print qq{\nTransition matrix.\n};
371 0           print $tbl->draw;
372              
373 0           undef @config;
374 0           @config = ( [15, q{}] );
375 0           push @config, (map { [ 15, q{P(...|State_}.$_.q{)} ] } (1..$#{$trans->[0]}+1));
  0            
  0            
376 0           my $tbl1 = Text::SimpleTable->new(@config);
377 0           my $count = 0;
378 0           for my $row (@{$keys}) {
  0            
379             #$tbl1->row( $row, ( map { my $v = $emis->{$row}[$_]; if ($v > 1e-4 || $v < 1e4 ) { $v = sprintf(q{%.12f},$start->[$_]) } else { $v = sprintf(q{%.8e},$start->[$_]) }; $v } (0..$#{$trans->[0]}) ) );
380 0           my @data;
381 0           for my $col (0..$#{$trans->[0]}) { push @data, sprintf(q{%.8e},$emis->{$row}[$col]); }
  0            
  0            
382 0           $tbl1->row( qq{P($row|...)}, @data );
383 0 0         $tbl1->hr if $count != $#{$keys};
  0            
384 0           $count++;
385             }
386 0           print qq{\nEmission matrix.\n};
387 0           print $tbl1->draw;
388              
389 0           undef @config;
390 0           push @config, (map { [ 15, q{State_}.$_ ] } (1..$#{$start}+1));
  0            
  0            
391 0           my $tbl2 = Text::SimpleTable->new(@config);
392             #my @data;
393             #for my $i (0..$#{$trans->[0]}) { push @data, sprintf(q{%.8e},$start->[$i]); }
394             #$tbl2->row(@data);
395 0 0 0       $tbl2->row( ( map { my $v = $start->[$_]; if ($v > 1e-4 && $v < 1e4 || $v == 0 ) {
  0   0        
  0            
396 0           $v = sprintf(q{%.12f},$start->[$_])
397             }
398             else {
399 0           $v = sprintf(q{%.8e},$start->[$_]) }; $v
  0            
400 0           } (0..$#{$trans->[0]}) ) );
401 0           print qq{\nStart probabilities.\n};
402 0           print $tbl2->draw;
403             }
404 0           return;
405             }
406              
407             1; # Magic true value required at end of module
408              
409             __END__