File Coverage

blib/lib/AI/Embedding.pm
Criterion Covered Total %
statement 67 92 72.8
branch 14 28 50.0
condition n/a
subroutine 15 19 78.9
pod 8 8 100.0
total 104 147 70.7


line stmt bran cond sub pod time code
1             package AI::Embedding;
2            
3 3     3   204985 use strict;
  3         27  
  3         93  
4 3     3   17 use warnings;
  3         5  
  3         75  
5            
6 3     3   2117 use HTTP::Tiny;
  3         156773  
  3         126  
7 3     3   2387 use JSON::PP;
  3         52623  
  3         274  
8 3     3   1570 use Data::CosineSimilarity;
  3         52019  
  3         3828  
9            
10             our $VERSION = '1.01';
11             $VERSION = eval $VERSION;
12            
13             my $http = HTTP::Tiny->new;
14            
15             # Create Embedding object
16             sub new {
17 4     4 1 218 my $class = shift;
18 4         19 my %attr = @_;
19            
20 4         11 $attr{'error'} = '';
21            
22 4 100       15 $attr{'api'} = 'OpenAI' unless $attr{'api'};
23 4 100       15 $attr{'error'} = 'Invalid API' unless $attr{'api'} eq 'OpenAI';
24 4 100       13 $attr{'error'} = 'API Key missing' unless $attr{'key'};
25            
26 4 50       14 $attr{'model'} = 'text-embedding-ada-002' unless $attr{'model'};
27            
28 4         15 return bless \%attr, $class;
29             }
30            
31             # Define endpoints for APIs
32             my %url = (
33             'OpenAI' => 'https://api.openai.com/v1/embeddings',
34             );
35            
36             # Define HTTP Headers for APIs
37             my %header = (
38             'OpenAI' => &_get_header_openai,
39             );
40            
41             # Returns true if last operation was success
42             sub success {
43 6     6 1 2370 my $self = shift;
44 6         38 return !$self->{'error'};
45             }
46            
47             # Returns error if last operation failed
48             sub error {
49 1     1 1 2 my $self = shift;
50 1         5 return $self->{'error'};
51             }
52            
53             # Header for calling OpenAI
54             sub _get_header_openai {
55 3     3   7 my $self = shift;
56 3 50       19 $self->{'key'} = '' unless defined $self->{'key'};
57             return {
58 3         20 'Authorization' => 'Bearer ' . $self->{'key'},
59             'Content-type' => 'application/json'
60             };
61             }
62            
63             # Fetch Embedding response
64             sub _get_embedding {
65 0     0   0 my ($self, $text) = @_;
66            
67             return $http->post($url{$self->{'api'}}, {
68             'headers' => {
69             'Authorization' => 'Bearer ' . $self->{'key'},
70             'Content-type' => 'application/json'
71             },
72             content => encode_json {
73             input => $text,
74 0         0 model => $self->{'model'},
75             }
76             });
77             }
78            
79             # TODO:
80             # Make 'headers' use $header{$self->{'api'}}
81             # Currently hard coded to OpenAI
82            
83             # Added purely for testing - IGNORE!
84             sub _test {
85 0     0   0 my $self = shift;
86             # return $self->{'api'};
87 0         0 return $header{$self->{'api'}};
88             }
89            
90             # Return Embedding as a CSV string
91             sub embedding {
92 0     0 1 0 my ($self, $text, $verbose) = @_;
93            
94 0         0 my $response = $self->_get_embedding($text);
95 0 0       0 if ($response->{'success'}) {
96 0         0 my $embedding = decode_json($response->{'content'});
97 0         0 return join (',', @{$embedding->{'data'}[0]->{'embedding'}});
  0         0  
98             }
99 0         0 $self->{'error'} = 'HTTP Error - ' . $response->{'reason'};
100 0 0       0 return $response if defined $verbose;
101 0         0 return undef;
102             }
103            
104             # Return Embedding as an array
105             sub raw_embedding {
106 0     0 1 0 my ($self, $text, $verbose) = @_;
107            
108 0         0 my $response = $self->_get_embedding($text);
109 0 0       0 if ($response->{'success'}) {
110 0         0 my $embedding = decode_json($response->{'content'});
111 0         0 return @{$embedding->{'data'}[0]->{'embedding'}};
  0         0  
112             }
113 0         0 $self->{'error'} = 'HTTP Error - ' . $response->{'reason'};
114 0 0       0 return $response if defined $verbose;
115 0         0 return undef;
116             }
117            
118             # Return Test Embedding
119             sub test_embedding {
120 3     3 1 1353 my ($self, $text, $dimension) = @_;
121 3         8 $self->{'error'} = '';
122            
123 3 50       9 $dimension = 1536 unless defined $dimension;
124            
125 3 50       9 if ($text) {
126 3         21 srand scalar split /\s+/, $text;
127             }
128            
129 3         6 my @vector;
130 3         10 for (1...$dimension) {
131 4608         6373 push @vector, rand(2) - 1;
132             }
133 3         5605 return join ',', @vector;
134             }
135            
136             # Convert a CSV Embedding into a hashref
137             sub _make_vector {
138 6     6   50 my ($self, $embed_string) = @_;
139            
140 6         14 my %vector;
141 6         26 my @embed = split /,/, $embed_string;
142 6         20 for (my $i = 0; $i < @embed; $i++) {
143 55         150 $vector{'feature' . $i} = $embed[$i];
144             }
145 6         20 return \%vector;
146             }
147            
148             # Return a comparator to compare to a set vector
149             sub comparator {
150 1     1 1 828 my($self, $embed) = @_;
151 1         3 $self->{'error'} = '';
152            
153 1         5 my $vector1 = $self->_make_vector($embed);
154             return sub {
155 1     1   3 my($embed2) = @_;
156 1         4 my $vector2 = $self->_make_vector($embed2);
157 1         3 return $self->_compare_vector($vector1, $vector2);
158 1         6 };
159             }
160            
161             # Compare 2 Embeddings
162             sub compare {
163 2     2 1 8 my ($self, $embed1, $embed2) = @_;
164            
165 2         8 my $vector1 = $self->_make_vector($embed1);
166 2         3 my $vector2;
167 2 50       6 if (defined $embed2) {
168 2         14 $vector2 = $self->_make_vector($embed2);
169             } else {
170 0         0 $vector2 = $self->{'comparator'};
171             }
172            
173 2 50       7 if (!defined $vector2) {
174 0         0 $self->{'error'} = 'Nothing to compare!';
175 0         0 return;
176             }
177            
178 2 100       8 if (scalar keys %$vector1 != scalar keys %$vector2) {
179 1         4 $self->{'error'} = 'Embeds are unequal length';
180 1         6 return;
181             }
182            
183 1         4 return $self->_compare_vector($vector1, $vector2);
184             }
185            
186             # Compare 2 Vectors
187             sub _compare_vector {
188 2     2   5 my ($self, $vector1, $vector2) = @_;
189 2         15 my $cs = Data::CosineSimilarity->new;
190 2         26 $cs->add( label1 => $vector1 );
191 2         102 $cs->add( label2 => $vector2 );
192 2         74 return $cs->similarity('label1', 'label2')->cosine;
193             }
194            
195             1;
196            
197             __END__