Skip to content

Commit 7bdcb76

Browse files
committed
Fix variable-length arrays
1 parent c304bc9 commit 7bdcb76

6 files changed

Lines changed: 83 additions & 66 deletions

File tree

R/paragraph2vec.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ summary.paragraph2vec_trained <- function(object, type = "vocabulary", which = c
284284
#' \item{'doc2doc', 'word2doc', 'word2word', 'sent2doc' can be chosen if \code{type} is set to 'nearest' indicating to extract respectively
285285
#' the closest document to a document (doc2doc), the closest document to a word (word2doc), the closest word to a word (word2word) or the closest document to sentences (sent2doc).}
286286
#' }
287-
#' @param top_n show only the top n nearest neighbours. Defaults to 10. Only used for \code{type} 'nearest'.
287+
#' @param top_n show only the top n nearest neighbours. Defaults to 10, with a maximum value of 100. Only used for \code{type} 'nearest'.
288288
#' @param normalize logical indicating to normalize the embeddings. Defaults to \code{TRUE}. Only used for \code{type} 'embedding'.
289289
#' @param encoding set the encoding of the text elements to the specified encoding. Defaults to 'UTF-8'.
290290
#' @param ... not used
@@ -350,6 +350,7 @@ predict.paragraph2vec <- function(object, newdata,
350350
type <- match.arg(type)
351351
which <- match.arg(which)
352352
top_n <- as.integer(top_n)
353+
stopifnot(top_n <= 100)
353354
if(type == "embedding"){
354355
stopifnot(which %in% c("docs", "words"))
355356
if(is.character(newdata)){

README.md

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,18 @@ x <- subset(x, nwords < 1000 & nchar(text) > 0)
5050
- Build the model
5151

5252

53-
```r
54-
model <- paragraph2vec(x = x, type = "PV-DBOW", dim = 100, iter = 20, min_count = 5,
55-
lr = 0.05, threads = 4)
56-
```
57-
58-
5953
```r
6054
## Low-dimensional model using DM, low number of iterations, for speed and display purposes
61-
model <- paragraph2vec(x = x, type = "PV-DM", dim = 5, iter = 3, min_count = 5,
62-
lr = 0.05, threads = 1)
55+
model <- paragraph2vec(x = x, type = "PV-DM", dim = 5, iter = 3,
56+
min_count = 5, lr = 0.05, threads = 1)
6357
str(model)
6458
```
6559

6660
```
6761
## List of 3
6862
## $ model :<externalptr>
6963
## $ data :List of 4
70-
## ..$ file : chr "C:\\Users\\Jan\\AppData\\Local\\Temp\\RtmpApjuPd\\textspace_1ef05c50176.txt"
64+
## ..$ file : chr "C:\\Users\\Jan\\AppData\\Local\\Temp\\Rtmpk9Npjg\\textspace_1c4458cb6943.txt"
7165
## ..$ n : num 170469
7266
## ..$ n_vocabulary: num 3867
7367
## ..$ n_docs : num 1000
@@ -84,6 +78,13 @@ str(model)
8478
## - attr(*, "class")= chr "paragraph2vec_trained"
8579
```
8680

81+
82+
```r
83+
## More realistic model
84+
model <- paragraph2vec(x = x, type = "PV-DBOW", dim = 100, iter = 20,
85+
min_count = 5, lr = 0.05, threads = 4)
86+
```
87+
8788
- Get the embedding of the documents or words and get the vocabulary
8889

8990

@@ -104,14 +105,22 @@ sentences <- list(
104105
embedding <- predict(model, newdata = sentences, type = "embedding")
105106
embedding <- predict(model, newdata = c("geld", "koning"), type = "embedding", which = "words")
106107
embedding <- predict(model, newdata = c("doc_1", "doc_10", "doc_3"), type = "embedding", which = "docs")
107-
embedding
108+
ncol(embedding)
108109
```
109110

110111
```
111-
## [,1] [,2] [,3] [,4] [,5]
112-
## doc_1 0.09160496 0.5503142 -0.5195833 0.162630379 -0.62637627
113-
## doc_10 0.43539885 0.1009961 -0.8531511 0.266749799 0.03471836
114-
## doc_3 0.59375095 0.3877517 -0.6868675 0.002579026 -0.15910600
112+
## [1] 100
113+
```
114+
115+
```r
116+
embedding[, 1:4]
117+
```
118+
119+
```
120+
## [,1] [,2] [,3] [,4]
121+
## doc_1 0.08172660 -0.03679979 0.05726605 -0.06496991
122+
## doc_10 0.13976580 0.10821507 -0.06986591 -0.05825572
123+
## doc_3 0.09486584 -0.07999156 0.03448128 0.02999697
115124
```
116125

117126
- Get similar documents or words when providing sentences, documents or words
@@ -124,20 +133,20 @@ nn
124133

125134
```
126135
## [[1]]
127-
## term1 term2 similarity rank
128-
## 1 proximus neemt 0.9994797 1
129-
## 2 proximus plaatse 0.9994527 2
130-
## 3 proximus ver 0.9993714 3
131-
## 4 proximus gratis 0.9992922 4
132-
## 5 proximus hiermee 0.9992417 5
136+
## term1 term2 similarity rank
137+
## 1 proximus telefoontoestellen 0.5571629 1
138+
## 2 proximus belfius 0.4994604 2
139+
## 3 proximus toenmalige 0.4873388 3
140+
## 4 proximus internetverbinding 0.4730936 4
141+
## 5 proximus gefactureerd 0.4568973 5
133142
##
134143
## [[2]]
135-
## term1 term2 similarity rank
136-
## 1 koning pleiten 0.9984228 1
137-
## 2 koning ongeacht 0.9983451 2
138-
## 3 koning pensionering 0.9982112 3
139-
## 4 koning profielen 0.9981233 4
140-
## 5 koning beschermd 0.9978001 5
144+
## term1 term2 similarity rank
145+
## 1 koning grondwet 0.5572801 1
146+
## 2 koning verplaatsingen 0.5373006 2
147+
## 3 koning ministerie 0.5140343 3
148+
## 4 koning familie 0.4943074 4
149+
## 5 koning vereiste 0.4715540 5
141150
```
142151

143152
```r
@@ -148,19 +157,19 @@ nn
148157
```
149158
## [[1]]
150159
## term1 term2 similarity rank
151-
## 1 proximus doc_77 0.9989672 1
152-
## 2 proximus doc_263 0.9989251 2
153-
## 3 proximus doc_260 0.9982057 3
154-
## 4 proximus doc_344 0.9980863 4
155-
## 5 proximus doc_408 0.9979483 5
160+
## 1 proximus doc_105 0.6922343 1
161+
## 2 proximus doc_863 0.5826316 2
162+
## 3 proximus doc_186 0.5146015 3
163+
## 4 proximus doc_862 0.5051525 4
164+
## 5 proximus doc_746 0.4467830 5
156165
##
157166
## [[2]]
158167
## term1 term2 similarity rank
159-
## 1 koning doc_553 0.9980003 1
160-
## 2 koning doc_477 0.9964797 2
161-
## 3 koning doc_658 0.9955103 3
162-
## 4 koning doc_99 0.9953933 4
163-
## 5 koning doc_163 0.9953347 5
168+
## 1 koning doc_44 0.6228581 1
169+
## 2 koning doc_583 0.5643232 2
170+
## 3 koning doc_45 0.5535781 3
171+
## 4 koning doc_797 0.4408725 4
172+
## 5 koning doc_943 0.4039679 5
164173
```
165174

166175
```r
@@ -171,19 +180,19 @@ nn
171180
```
172181
## [[1]]
173182
## term1 term2 similarity rank
174-
## 1 doc_198 doc_882 0.9992993 1
175-
## 2 doc_198 doc_709 0.9990637 2
176-
## 3 doc_198 doc_122 0.9989671 3
177-
## 4 doc_198 doc_121 0.9988763 4
178-
## 5 doc_198 doc_569 0.9988336 5
183+
## 1 doc_198 doc_343 0.4893735 1
184+
## 2 doc_198 doc_569 0.4858374 2
185+
## 3 doc_198 doc_358 0.4831750 3
186+
## 4 doc_198 doc_498 0.4766597 4
187+
## 5 doc_198 doc_983 0.4761481 5
179188
##
180189
## [[2]]
181190
## term1 term2 similarity rank
182-
## 1 doc_285 doc_722 0.9988106 1
183-
## 2 doc_285 doc_467 0.9977189 2
184-
## 3 doc_285 doc_250 0.9976925 3
185-
## 4 doc_285 doc_174 0.9975280 4
186-
## 5 doc_285 doc_294 0.9968556 5
191+
## 1 doc_285 doc_319 0.5304061 1
192+
## 2 doc_285 doc_286 0.5205777 2
193+
## 3 doc_285 doc_76 0.5086077 3
194+
## 4 doc_285 doc_74 0.4975725 4
195+
## 5 doc_285 doc_537 0.4802507 5
187196
```
188197

189198
```r
@@ -197,19 +206,19 @@ nn
197206
```
198207
## $sent1
199208
## term1 term2 similarity rank
200-
## 1 sent1 doc_980 0.9784521 1
201-
## 2 sent1 doc_758 0.9678799 2
202-
## 3 sent1 doc_806 0.9547009 3
203-
## 4 sent1 doc_764 0.9544759 4
204-
## 5 sent1 doc_842 0.9529226 5
209+
## 1 sent1 doc_740 0.4637638 1
210+
## 2 sent1 doc_742 0.4621139 2
211+
## 3 sent1 doc_206 0.4315273 3
212+
## 4 sent1 doc_825 0.4221503 4
213+
## 5 sent1 doc_151 0.4183135 5
205214
##
206215
## $sent2
207216
## term1 term2 similarity rank
208-
## 1 sent2 doc_842 0.9873239 1
209-
## 2 sent2 doc_764 0.9832168 2
210-
## 3 sent2 doc_564 0.9739662 3
211-
## 4 sent2 doc_980 0.9675324 4
212-
## 5 sent2 doc_542 0.9622889 5
217+
## 1 sent2 doc_105 0.5789919 1
218+
## 2 sent2 doc_186 0.4938067 2
219+
## 3 sent2 doc_862 0.4848365 3
220+
## 4 sent2 doc_863 0.4685720 4
221+
## 5 sent2 doc_620 0.4497271 5
213222
```
214223

215224
```r

man/predict.paragraph2vec.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/RcppExports.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,28 +62,28 @@ BEGIN_RCPP
6262
END_RCPP
6363
}
6464
// paragraph2vec_nearest
65-
Rcpp::DataFrame paragraph2vec_nearest(SEXP ptr, std::string x, std::size_t top_n, std::string type);
65+
Rcpp::DataFrame paragraph2vec_nearest(SEXP ptr, std::string x, int top_n, std::string type);
6666
RcppExport SEXP _doc2vec_paragraph2vec_nearest(SEXP ptrSEXP, SEXP xSEXP, SEXP top_nSEXP, SEXP typeSEXP) {
6767
BEGIN_RCPP
6868
Rcpp::RObject rcpp_result_gen;
6969
Rcpp::RNGScope rcpp_rngScope_gen;
7070
Rcpp::traits::input_parameter< SEXP >::type ptr(ptrSEXP);
7171
Rcpp::traits::input_parameter< std::string >::type x(xSEXP);
72-
Rcpp::traits::input_parameter< std::size_t >::type top_n(top_nSEXP);
72+
Rcpp::traits::input_parameter< int >::type top_n(top_nSEXP);
7373
Rcpp::traits::input_parameter< std::string >::type type(typeSEXP);
7474
rcpp_result_gen = Rcpp::wrap(paragraph2vec_nearest(ptr, x, top_n, type));
7575
return rcpp_result_gen;
7676
END_RCPP
7777
}
7878
// paragraph2vec_nearest_sentence
79-
Rcpp::List paragraph2vec_nearest_sentence(SEXP ptr, Rcpp::List x, std::size_t top_n);
79+
Rcpp::List paragraph2vec_nearest_sentence(SEXP ptr, Rcpp::List x, int top_n);
8080
RcppExport SEXP _doc2vec_paragraph2vec_nearest_sentence(SEXP ptrSEXP, SEXP xSEXP, SEXP top_nSEXP) {
8181
BEGIN_RCPP
8282
Rcpp::RObject rcpp_result_gen;
8383
Rcpp::RNGScope rcpp_rngScope_gen;
8484
Rcpp::traits::input_parameter< SEXP >::type ptr(ptrSEXP);
8585
Rcpp::traits::input_parameter< Rcpp::List >::type x(xSEXP);
86-
Rcpp::traits::input_parameter< std::size_t >::type top_n(top_nSEXP);
86+
Rcpp::traits::input_parameter< int >::type top_n(top_nSEXP);
8787
rcpp_result_gen = Rcpp::wrap(paragraph2vec_nearest_sentence(ptr, x, top_n));
8888
return rcpp_result_gen;
8989
END_RCPP

src/doc2vec/common_define.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#define MAX_EXP 6
1414
#define MAX_SENTENCE_LENGTH 1000
1515
#define MAX_CODE_LENGTH 40
16+
#define MAX_DOC2VEC_KNN_R 100
1617
#define MAX_DOC2VEC_KNN 2000
1718
const int vocab_hash_size = 30000000;
1819
const int negtive_sample_table_size = 1e8;

src/rcpp_doc2vec.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ std::vector<std::string> paragraph2vec_dictionary(SEXP ptr, std::string type = "
9292

9393

9494
// [[Rcpp::export]]
95-
Rcpp::DataFrame paragraph2vec_nearest(SEXP ptr, std::string x, std::size_t top_n = 10, std::string type = "doc2doc") {
95+
Rcpp::DataFrame paragraph2vec_nearest(SEXP ptr, std::string x, int top_n = 10, std::string type = "doc2doc") {
9696
Rcpp::XPtr<Doc2Vec> model(ptr);
97-
knn_item_t knn_items[top_n];
97+
knn_item_t knn_items[MAX_DOC2VEC_KNN_R];
9898
if(type == "doc2doc"){
9999
model->doc_knn_docs(x.c_str(), knn_items, top_n);
100100
}else if(type == "word2doc"){
@@ -114,6 +114,9 @@ Rcpp::DataFrame paragraph2vec_nearest(SEXP ptr, std::string x, std::size_t top_n
114114
distance.push_back(kv.similarity);
115115
r = r + 1;
116116
rank.push_back(r);
117+
if(r >= top_n || r >= MAX_DOC2VEC_KNN_R) {
118+
break;
119+
}
117120
}
118121
Rcpp::DataFrame out = Rcpp::DataFrame::create(
119122
Rcpp::Named("term1") = x,
@@ -126,7 +129,7 @@ Rcpp::DataFrame paragraph2vec_nearest(SEXP ptr, std::string x, std::size_t top_n
126129
}
127130

128131
// [[Rcpp::export]]
129-
Rcpp::List paragraph2vec_nearest_sentence(SEXP ptr, Rcpp::List x, std::size_t top_n = 10) {
132+
Rcpp::List paragraph2vec_nearest_sentence(SEXP ptr, Rcpp::List x, int top_n = 10) {
130133
Rcpp::XPtr<Doc2Vec> model(ptr);
131134
real * infer_vector = NULL;
132135
//int errnr = posix_memalign((void **)&infer_vector, 128, model->dim() * sizeof(real));
@@ -146,7 +149,7 @@ Rcpp::List paragraph2vec_nearest_sentence(SEXP ptr, Rcpp::List x, std::size_t to
146149
}
147150
model->infer_doc(&doc, infer_vector);
148151
// Get closest docs to sentence
149-
knn_item_t knn_items[top_n];
152+
knn_item_t knn_items[MAX_DOC2VEC_KNN_R];
150153
model->sent_knn_docs(&doc, knn_items, top_n, infer_vector);
151154
// Collect result in data.frame
152155
std::vector<std::string> keys;
@@ -159,6 +162,9 @@ Rcpp::List paragraph2vec_nearest_sentence(SEXP ptr, Rcpp::List x, std::size_t to
159162
distance.push_back(kv.similarity);
160163
r = r + 1;
161164
rank.push_back(r);
165+
if(r >= top_n || r >= MAX_DOC2VEC_KNN_R) {
166+
break;
167+
}
162168
}
163169
Rcpp::DataFrame out = Rcpp::DataFrame::create(
164170
Rcpp::Named("term1") = Rcpp::as<std::string>(rownames_(i)),

0 commit comments

Comments
 (0)