File Coverage

blib/lib/AI/Embedding.pm
Criterion Covered Total %
statement 68 98 69.3
branch 15 32 46.8
condition n/a
subroutine 15 19 78.9
pod 8 8 100.0
total 106 157 67.5


line stmt bran cond sub pod time code
1             package AI::Embedding;
2            
3 4     4   669377 use strict;
  4         8  
  4         183  
4 4     4   34 use warnings;
  4         8  
  4         228  
5            
6 4     4   4408 use HTTP::Tiny;
  4         255087  
  4         280  
7 4     4   3251 use JSON::PP;
  4         88940  
  4         390  
8 4     4   2394 use Data::CosineSimilarity;
  4         80013  
  4         6546  
9            
10             our $VERSION = '1.11';
11             $VERSION = eval $VERSION;
12            
13             my $http = HTTP::Tiny->new;
14            
15             # Create Embedding object
16             sub new {
17 4     4 1 458950 my $class = shift;
18 4         18 my %attr = @_;
19            
20 4         9 $attr{'error'} = '';
21            
22 4 100       17 $attr{'api'} = 'OpenAI' unless $attr{'api'};
23 4 100       15 $attr{'error'} = 'Invalid API' unless $attr{'api'} eq 'OpenAI';
24 4 100       12 $attr{'error'} = 'API Key missing' unless $attr{'key'};
25            
26 4 50       14 $attr{'model'} = 'text-embedding-ada-002' unless $attr{'model'};
27            
28 4         13 return bless \%attr, $class;
29             }
30            
31             # Define endpoints for APIs
32             my %url = (
33             'OpenAI' => 'https://api.openai.com/v1/embeddings',
34             );
35            
36             # Define HTTP Headers for APIs
37             my %header = (
38             'OpenAI' => &_get_header_openai,
39             );
40            
41             # Returns true if last operation was success
42             sub success {
43 6     6 1 2343 my $self = shift;
44 6         33 return !$self->{'error'};
45             }
46            
47             # Returns error if last operation failed
48             sub error {
49 1     1 1 2 my $self = shift;
50 1         4 return $self->{'error'};
51             }
52            
53             # Header for calling OpenAI
54             sub _get_header_openai {
55 4     4   12 my $self = shift;
56 4 50       35 $self->{'key'} = '' unless defined $self->{'key'};
57             return {
58 4         27 'Authorization' => 'Bearer ' . $self->{'key'},
59             'Content-type' => 'application/json'
60             };
61             }
62            
63             # Fetch Embedding response
64             sub _get_embedding {
65 0     0   0 my ($self, $text) = @_;
66            
67             my $response = $http->post($url{$self->{'api'}}, {
68             'headers' => {
69             'Authorization' => 'Bearer ' . $self->{'key'},
70             'Content-type' => 'application/json'
71             },
72             content => encode_json {
73             input => $text,
74 0         0 model => $self->{'model'},
75             }
76             });
77 0 0       0 if ($response->{'content'} =~ 'invalid_api_key') {
78 0         0 die 'Incorrect API Key - check your API Key is correct';
79             }
80 0         0 return $response;
81             }
82            
83             # TODO:
84             # Make 'headers' use $header{$self->{'api'}}
85             # Currently hard coded to OpenAI
86            
87             # Added purely for testing - IGNORE!
88             sub _test {
89 0     0   0 my $self = shift;
90             # return $self->{'api'};
91 0         0 return $header{$self->{'api'}};
92             }
93            
94             # Return Embedding as a CSV string
95             sub embedding {
96 0     0 1 0 my ($self, $text, $verbose) = @_;
97            
98 0         0 my $response = $self->_get_embedding($text);
99 0 0       0 if ($response->{'success'}) {
100 0         0 my $embedding = decode_json($response->{'content'});
101 0         0 return join (',', @{$embedding->{'data'}[0]->{'embedding'}});
  0         0  
102             }
103 0         0 $self->{'error'} = 'HTTP Error - ' . $response->{'reason'};
104 0 0       0 return $response if defined $verbose;
105 0         0 return undef;
106             }
107            
108             # Return Embedding as an array
109             sub raw_embedding {
110 0     0 1 0 my ($self, $text, $verbose) = @_;
111            
112 0         0 my $response = $self->_get_embedding($text);
113 0 0       0 if ($response->{'success'}) {
114 0         0 my $embedding = decode_json($response->{'content'});
115 0         0 return @{$embedding->{'data'}[0]->{'embedding'}};
  0         0  
116             }
117 0         0 $self->{'error'} = 'HTTP Error - ' . $response->{'reason'};
118 0 0       0 return $response if defined $verbose;
119 0         0 return undef;
120             }
121            
122             # Return Test Embedding
123             sub test_embedding {
124 3     3 1 2267 my ($self, $text, $dimension) = @_;
125 3         11 $self->{'error'} = '';
126            
127 3 50       14 $dimension = 1536 unless defined $dimension;
128            
129 3 50       10 if ($text) {
130 3         30 srand scalar split /\s+/, $text;
131             }
132            
133 3         7 my @vector;
134 3         18 for (1...$dimension) {
135 4608         10330 push @vector, rand(2) - 1;
136             }
137 3         7846 return join ',', @vector;
138             }
139            
140             # Convert a CSV Embedding into a hashref
141             sub _make_vector {
142 6     6   7 my ($self, $embed_string) = @_;
143            
144 6 50       12 if (!defined $embed_string) {
145 0         0 $self->{'error'} = 'Nothing to compare!';
146 0         0 return;
147             }
148            
149 6         7 my %vector;
150 6         32 my @embed = split /,/, $embed_string;
151 6         22 for (my $i = 0; $i < @embed; $i++) {
152 55         100 $vector{'feature' . $i} = $embed[$i];
153             }
154 6         13 return \%vector;
155             }
156            
157             # Return a comparator to compare to a set vector
158             sub comparator {
159 1     1 1 735 my($self, $embed) = @_;
160 1         3 $self->{'error'} = '';
161            
162 1         3 my $vector1 = $self->_make_vector($embed);
163             return sub {
164 1     1   3 my($embed2) = @_;
165 1         3 my $vector2 = $self->_make_vector($embed2);
166 1         3 return $self->_compare_vector($vector1, $vector2);
167 1         5 };
168             }
169            
170             # Compare 2 Embeddings
171             sub compare {
172 2     2 1 5 my ($self, $embed1, $embed2) = @_;
173            
174 2         7 my $vector1 = $self->_make_vector($embed1);
175 2         3 my $vector2;
176 2 50       5 if (defined $embed2) {
177 2         3 $vector2 = $self->_make_vector($embed2);
178             } else {
179 0         0 $vector2 = $self->{'comparator'};
180             }
181            
182 2 50       4 if (!defined $vector2) {
183 0         0 $self->{'error'} = 'Nothing to compare!';
184 0         0 return;
185             }
186            
187 2 100       6 if (scalar keys %$vector1 != scalar keys %$vector2) {
188 1         10 $self->{'error'} = 'Embeds are unequal length';
189 1         5 return;
190             }
191            
192 1         4 return $self->_compare_vector($vector1, $vector2);
193             }
194            
195             # Compare 2 Vectors
196             sub _compare_vector {
197 2     2   3 my ($self, $vector1, $vector2) = @_;
198 2         16 my $cs = Data::CosineSimilarity->new;
199 2         23 $cs->add( label1 => $vector1 );
200 2         85 $cs->add( label2 => $vector2 );
201 2         44 return $cs->similarity('label1', 'label2')->cosine;
202             }
203            
204             1;
205            
206             __END__