File Coverage

blib/lib/AI/XGBoost/Booster.pm
Criterion Covered Total %
statement 10 12 83.3
branch n/a
condition n/a
subroutine 4 4 100.0
pod n/a
total 14 16 87.5


line stmt bran cond sub pod time code
1             package AI::XGBoost::Booster;
2              
3 1     1   8 use strict;
  1         2  
  1         22  
4 1     1   5 use warnings;
  1         1  
  1         18  
5 1     1   5 use utf8;
  1         2  
  1         4  
6              
7             our $VERSION = '0.11'; # VERSION
8              
9             # ABSTRACT: XGBoost main class for training, prediction and evaluation
10              
11 1     1   202 use Moose;
  0            
  0            
12             use AI::XGBoost::CAPI qw(:all);
13             use namespace::autoclean;
14              
15             has _handle => ( is => 'rw',
16             init_arg => undef, );
17              
18             sub update {
19             my $self = shift;
20             my %args = @_;
21             my ( $iteration, $dtrain ) = @args{qw(iteration dtrain)};
22             XGBoosterUpdateOneIter( $self->_handle, $iteration, $dtrain->handle );
23             return $self;
24             }
25              
26             sub boost {
27             my $self = shift;
28             my %args = @_;
29             my ( $dtrain, $grad, $hess ) = @args{qw(dtrain grad hess)};
30             XGBoosterBoostOneIter( $self->_handle, $dtrain, $grad, $hess );
31             return $self;
32             }
33              
34             sub predict {
35             my $self = shift;
36             my %args = @_;
37             my $data = $args{'data'};
38             my $result = XGBoosterPredict( $self->_handle, $data->handle );
39             my $result_size = scalar @$result;
40             my $matrix_rows = $data->num_row;
41             if ( $result_size != $matrix_rows && $result_size % $matrix_rows == 0 ) {
42             my $col_size = $result_size / $matrix_rows;
43             return [ map { [ @$result[ $_ * $col_size .. $_ * $col_size + $col_size - 1 ] ] } 0 .. $matrix_rows - 1 ];
44             }
45             return $result;
46             }
47              
48             sub set_param {
49             my $self = shift;
50             my ( $name, $value ) = @_;
51             XGBoosterSetParam( $self->_handle, $name, $value );
52             return $self;
53             }
54              
55             sub set_attr {
56             my $self = shift;
57             my ( $name, $value ) = @_;
58             XGBoosterSetAttr( $self->_handle, $name, $value );
59             return $self;
60             }
61              
62             sub get_attr {
63             my $self = shift;
64             my ($name) = @_;
65             XGBoosterGetAttr( $self->_handle, $name );
66             }
67              
68             sub get_score {
69             my $self = shift;
70             my %args = @_;
71             my ( $fmap, $importance_type ) = @args{qw(fmap importance_type)};
72              
73             if ( $importance_type eq "weight" ) {
74             my @trees = $self->get_dump;
75             } else {
76              
77             }
78              
79             }
80              
81             sub get_dump {
82             my $self = shift;
83             return XGBoosterDumpModelEx( $self->_handle, "", 1, "text" );
84             }
85              
86             sub attributes {
87             my $self = shift;
88             return { map { $_ => $self->get_attr($_) } @{ XGBoosterGetAttrNames( $self->_handle ) } };
89             }
90              
91             sub TO_JSON {
92             my $self = shift;
93             my $trees = XGBoosterDumpModelEx( $self->_handle, "", 1, "json" );
94             return "[" . join( ',', @$trees ) . "]";
95             }
96              
97             sub BUILD {
98             my $self = shift;
99             my $args = shift;
100             $self->_handle( XGBoosterCreate( [ map { $_->handle } @{ $args->{'cache'} } ] ) );
101             }
102              
103             sub DEMOLISH {
104             my $self = shift();
105             XGBoosterFree( $self->_handle );
106             }
107              
108             __PACKAGE__->meta->make_immutable();
109              
110             1;
111              
112             __END__