File Coverage

blib/lib/AI/Embedding.pm
Criterion Covered Total %
statement 17 83 20.4
branch 0 24 0.0
condition n/a
subroutine 6 17 35.2
pod 7 9 77.7
total 30 133 22.5


line stmt bran cond sub pod time code
1             package AI::Embedding;
2            
3 1     1   70395 use strict;
  1         2  
  1         29  
4 1     1   4 use warnings;
  1         2  
  1         23  
5            
6 1     1   895 use HTTP::Tiny;
  1         50995  
  1         32  
7 1     1   732 use JSON::PP;
  1         15663  
  1         64  
8 1     1   489 use Data::CosineSimilarity;
  1         15402  
  1         1007  
9            
10             our $VERSION = '0.1_2';
11             $VERSION = eval $VERSION;
12            
13             my $http = HTTP::Tiny->new;
14            
15             # Create Embedding object
16             sub new {
17 0     0 1 0 my $class = shift;
18 0         0 my %attr = @_;
19            
20 0         0 $attr{'error'} = '';
21            
22 0 0       0 $attr{'api'} = 'OpenAI' unless $attr{'api'};
23 0 0       0 $attr{'error'} = 'Invalid API' unless $attr{'api'} eq 'OpenAI';
24 0 0       0 $attr{'error'} = 'API Key missing' unless $attr{'key'};
25            
26 0 0       0 $attr{'model'} = 'text-embedding-ada-002' unless $attr{'model'};
27            
28 0 0       0 $attr{'comparator'} = $class->_make_vector($attr{'comparator'}) if exists $attr{'comparator'};
29            
30 0         0 return bless \%attr, $class;
31             }
32            
33             # Define endpoints for APIs
34             my %url = (
35             'OpenAI' => 'https://api.openai.com/v1/embeddings',
36             );
37            
38             # Define HTTP Headers for APIs
39             my %header = (
40             'OpenAI' => &_get_header_openai,
41             );
42            
43             # Returns true if last operation was success
44             sub success {
45 0     0 1 0 my $self = shift;
46 0         0 return !$self->{'error'};
47             }
48            
49             # Returns error if last operation failed
50             sub error {
51 0     0 1 0 my $self = shift;
52 0         0 return $self->{'error'};
53             }
54            
55             # Header for calling OpenAI
56             sub _get_header_openai {
57 1     1   2 my $self = shift;
58             return {
59 1         65 'Authorization' => 'Bearer ' . $self->{'key'},
60             'Content-type' => 'application/json'
61             };
62             }
63            
64             # Fetch Embedding response
65             sub _get_embedding {
66 0     0     my ($self, $text) = @_;
67            
68             return $http->post($url{$self->{'api'}}, {
69             'headers' => {
70             'Authorization' => 'Bearer ' . $self->{'key'},
71             'Content-type' => 'application/json'
72             },
73             content => encode_json {
74             input => $text,
75 0           model => $self->{'model'},
76             }
77             });
78             }
79            
80             # TODO:
81             # Make 'headers' use $header{$self->{'api'}}
82             # Currently hard coded to OpenAI
83            
84             # Added purely for testing - IGNORE!
85             sub test {
86 0     0 0   my $self = shift;
87             # return $self->{'api'};
88 0           return $header{$self->{'api'}};
89             }
90            
91             # Return Embedding as a CSV string
92             sub embedding {
93 0     0 1   my ($self, $text) = @_;
94            
95 0           my $response = $self->_get_embedding($text);
96 0 0         if ($response->{'success'}) {
97 0           my $embedding = decode_json($response->{'content'});
98 0           return join (',', @{$embedding->{'data'}[0]->{'embedding'}});
  0            
99             }
100 0           $self->{'error'} = 'HTTP Error - ' . $response->{'reason'};
101 0           return $response;
102             }
103            
104             # Return Embedding as an array
105             sub raw_embedding {
106 0     0 1   my ($self, $text) = @_;
107            
108 0           my $response = $self->_get_embedding($text);
109 0 0         if ($response->{'success'}) {
110 0           my $embedding = decode_json($response->{'content'});
111 0           return @{$embedding->{'data'}[0]->{'embedding'}};
  0            
112             }
113 0           $self->{'error'} = 'HTTP Error - ' . $response->{'reason'};
114 0           return $response;
115             }
116            
117             # Return Test Embedding
118             sub test_embedding {
119 0     0 0   my ($self, $text, $dimension) = @_;
120            
121 0 0         $dimension = 1536 unless defined $dimension;
122            
123 0 0         if ($text) {
124 0           srand scalar split /\s+/, $text;
125             }
126            
127 0           my @vector;
128 0           for (1...$dimension) {
129 0           push @vector, rand(2) - 1;
130             }
131 0           return join ',', @vector;
132             }
133            
134             # Convert a CSV Embedding into a hashref
135             sub _make_vector {
136 0     0     my ($self, $embed_string) = @_;
137            
138 0           my %vector;
139 0           my @embed = split /,/, $embed_string;
140 0           for (my $i = 0; $i < @embed; $i++) {
141 0           $vector{'feature' . $i} = $embed[$i];
142             }
143 0           return \%vector;
144             }
145            
146             # Set a vector to compare
147             sub comparator {
148 0     0 1   my ($self, $embed) = @_;
149            
150 0           $self->{'comparator'} = $self->_make_vector($embed);
151 0           return;
152             }
153            
154             # Compare 2 Embeddings
155             sub compare {
156 0     0 1   my ($self, $embed1, $embed2) = @_;
157            
158 0           my $vector1 = $self->_make_vector($embed1);
159 0           my $vector2;
160 0 0         if (defined $embed2) {
161 0           $vector2 = $self->_make_vector($embed2);
162             } else {
163 0           $vector2 = $self->{'comparator'};
164             }
165            
166 0 0         if (!defined $vector2) {
167 0           $self->{'error'} = 'Nothing to compare!';
168 0           return;
169             }
170            
171 0 0         if (scalar keys %$vector1 != scalar keys %$vector2) {
172 0           $self->{'error'} = 'Embeds are unequal length';
173 0           return;
174             }
175            
176 0           my $cs = Data::CosineSimilarity->new;
177 0           $cs->add( label1 => $vector1 );
178 0           $cs->add( label2 => $vector2 );
179 0           return $cs->similarity('label1', 'label2')->cosine;
180             }
181            
182             1;
183            
184             __END__