File Coverage

blib/lib/AI/XGBoost/DMatrix.pm
Criterion Covered Total %
statement 10 12 83.3
branch n/a
condition n/a
subroutine 4 4 100.0
pod n/a
total 14 16 87.5


line stmt bran cond sub pod time code
1             package AI::XGBoost::DMatrix;
2              
3 1     1   50373 use strict;
  1         2  
  1         27  
4 1     1   5 use warnings;
  1         1  
  1         28  
5 1     1   4 use utf8;
  1         2  
  1         7  
6              
7             our $VERSION = '0.11'; # VERSION
8              
9             # ABSTRACT: XGBoost class for data
10              
11 1     1   122 use Moose;
  0            
  0            
12             use AI::XGBoost::CAPI qw(:all);
13             use Carp;
14             use namespace::autoclean;
15              
16             has handle => ( is => 'ro', );
17              
18             sub From {
19             my ( $package, %args ) = @_;
20             return __PACKAGE__->FromFile( filename => $args{file}, silent => $args{silent} ) if ( defined $args{file} );
21             return __PACKAGE__->FromMat( map { $_ => $args{$_} if defined $_ } qw(matrix missing label) )
22             if ( defined $args{matrix} );
23             Carp::cluck( "I don't know how to build a " . __PACKAGE__ . " with this data: " . join( ", ", %args ) );
24             }
25              
26             sub FromFile {
27             my ( $package, %args ) = @_;
28             my $handle = XGDMatrixCreateFromFile( @args{qw(filename silent)} );
29             return __PACKAGE__->new( handle => $handle );
30             }
31              
32             sub FromMat {
33             my ( $package, %args ) = @_;
34             my $handle = XGDMatrixCreateFromMat( @args{qw(matrix missing)} );
35             my $matrix = __PACKAGE__->new( handle => $handle );
36             if ( defined $args{label} ) {
37             $matrix->set_label( $args{label} );
38             }
39             return $matrix;
40             }
41              
42             sub set_float_info {
43             my $self = shift();
44             my ( $field, $info ) = @_;
45             XGDMatrixSetFloatInfo( $self->handle, $field, $info );
46             return $self;
47             }
48              
49             sub set_float_info_pdl {
50             my $self = shift();
51             my ( $field, $info ) = @_;
52             XGDMatrixSetFloatInfo( $self->handle, $field, $info->flat()->unpdl() );
53             return $self;
54             }
55              
56             sub get_float_info {
57             my $self = shift();
58             my $field = shift();
59             XGDMatrixGetFloatInfo( $self->handle, $field );
60             }
61              
62             sub set_uint_info {
63             my $self = shift();
64             my ( $field, $info ) = @_;
65             XGDMatrixSetUintInfo( $self->handle, $field, $info );
66             return $self;
67             }
68              
69             sub get_uint_info {
70             my $self = shift();
71             my $field = shift();
72             XGDMatrixGetUintInfo( $self->handle, $field );
73             }
74              
75             sub save_binary {
76             my $self = shift();
77             my ( $filename, $silent ) = @_;
78             $silent //= 1;
79             XGDMatrixSaveBinary( $self->handle, $filename, $silent );
80             return $self;
81             }
82              
83             sub set_label {
84             my $self = shift();
85             my $label = shift();
86             $self->set_float_info( 'label', $label );
87             }
88              
89             sub set_label_pdl {
90             my $self = shift();
91             my $label = shift();
92             $self->set_float_info_pdl( 'label', $label->flat()->unpdl() );
93             }
94              
95             sub get_label {
96             my $self = shift();
97             $self->get_float_info('label');
98             }
99              
100             sub set_weight {
101             my $self = shift();
102             my $weight = shift();
103             $self->set_float_info( 'weight', $weight );
104             return $self;
105             }
106              
107             sub set_weight_pdl {
108             my $self = shift();
109             my $weight = shift();
110             $self->set_float_info( 'weight', $weight->flat()->unpdl() );
111             return $self;
112             }
113              
114             sub get_weight {
115             my $self = shift();
116             $self->get_float_info('weight');
117             }
118              
119             sub set_base_margin {
120             my $self = shift();
121             my $margin = shift();
122             $self->set_float_info( 'base_margin', $margin );
123             return $self;
124             }
125              
126             sub get_base_margin {
127             my $self = shift();
128             $self->get_float_info('base_margin');
129             }
130              
131             sub set_group {
132             my $self = shift();
133             my $group = shift();
134             XGDMatrixSetGroup( $self->handle, $group );
135             return $self;
136             }
137              
138             sub num_row {
139             my $self = shift();
140             XGDMatrixNumRow( $self->handle );
141             }
142              
143             sub num_col {
144             my $self = shift();
145             XGDMatrixNumCol( $self->handle );
146             }
147              
148             sub dims {
149             my $self = shift();
150             return ( $self->num_row(), $self->num_col() );
151             }
152              
153             sub slice {
154             my $self = shift;
155             my ($list_of_indices) = @_;
156             my $handle = XGDMatrixSliceDMatrix( $self->handle(), $list_of_indices );
157             return __PACKAGE__->new( handle => $handle );
158             }
159              
160             sub DEMOLISH {
161             my $self = shift();
162             XGDMatrixFree( $self->handle );
163             }
164              
165             __PACKAGE__->meta->make_immutable();
166              
167             1;
168              
169             __END__