Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion R/as.textmodel2.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@

#' @rdname as.textmodel_lss
#' @export
#' @param spatial \[experimental\] if `FALSE`, return a probabilistic model.
#' See the details.
#' @details
#' If `spatial = TRUE`, it return a spatial model; otherwise a probabilistic model.
#' While the polarity scores of words are their cosine similarity to seed words in
#' spatial models, they are predicted probability that the seed words to occur in
#' their contexts.
#'
#' @method as.textmodel_lss textmodel_word2vec
as.textmodel_lss.textmodel_word2vec <- function(x, seeds,
terms = NULL,
nested_weight = TRUE,
spatial = FALSE,
verbose = FALSE,
spatial = TRUE,
...) {

#args <- list(terms = terms, seeds = seeds)
Expand Down
7 changes: 2 additions & 5 deletions R/textmodel_lss.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,8 @@
#' in `seed_weighted` in the object.
#'
#' When `x` is a tokens or tokens_xptr object, [wordvector::textmodel_word2vec]
#' is called internally with `type = "skip-gram"` and other arguments passed via `...`.
#' If `spatial = TRUE`, it return a spatial model; otherwise a probabilistic model.
#' While the polarity scores of words are their cosine similarity to seed words in
#' spatial models, they are predicted probability that the seed words to occur in
#' their contexts. The probabilistic models are still experimental, so use them with caution.
#' is called internally with `type = "skip-gram"` and other arguments passed via
#' `...`.
#'
#' `nested_weight = TRUE` to limit the impact of glob patterns used in seed words.
#' When it is `FALSE`, the weights of the seed words are all equal being the inverse of
Expand Down
6 changes: 2 additions & 4 deletions R/textmodel_lss2.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#' @rdname textmodel_lss
#' @param spatial \[experimental\] if `FALSE`, return a probabilistic model. See the details.
#' @export
#' @inheritParams wordvector::textmodel_word2vec
#' @importFrom quanteda dfm dfm_group
Expand All @@ -10,7 +9,6 @@ textmodel_lss.tokens <- function(x, seeds, terms = NULL, k = 200,
nested_weight = TRUE,
include_data = FALSE,
group_data = FALSE,
spatial = TRUE,
verbose = FALSE, ...) {


Expand All @@ -25,8 +23,8 @@ textmodel_lss.tokens <- function(x, seeds, terms = NULL, k = 200,
w2v <- wordvector::textmodel_word2vec(x, dim = k, min_count = min_count,
type = "skip-gram", tolower = tolower,
normalize = FALSE, verbose = verbose, ...)
result <- as.textmodel_lss(w2v, seeds = seeds, terms = terms, spatial = spatial,
nested_weight = nested_weight, verbose = FALSE)
result <- as.textmodel_lss(w2v, seeds = seeds, terms = terms,
nested_weight = nested_weight, verbose = FALSE, ...)
result$type <- "word2vec"
result$call <- try(match.call(sys.function(-1), call = sys.call(-1)), silent = TRUE)

Expand Down
10 changes: 8 additions & 2 deletions man/as.textmodel_lss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/as.textmodel_lss.textmodel_wordvector.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 2 additions & 8 deletions man/textmodel_lss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-as.textmodel2.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ test_that("as.textmodel_lss works with textmodel_wordvector", {

# spatial
wdv <- readRDS("../data/word2vec.RDS")
lss <- as.textmodel_lss(wdv, seed)
lss <- as.textmodel_lss(wdv, seed, spatial = TRUE)

expect_equal(lss$beta_type, "similarity")
expect_equal(lss$embedding, t(wdv$values))
Expand Down
14 changes: 8 additions & 6 deletions tests/testthat/test-textmodel_lss2.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ test_that("textmodel_lss works when spatial = TRUE", {
skip_on_cran()

# without data
lss1 <- textmodel_lss(toks_test, seed_test, k = 10)
lss1 <- textmodel_lss(toks_test, seed_test, k = 10, spatial = TRUE)

expect_s3_class(lss1, "textmodel_lss")
expect_equal(lss1$k, 10)
Expand All @@ -30,7 +30,8 @@ test_that("textmodel_lss works when spatial = TRUE", {
)

# with data
lss2 <- textmodel_lss(toks_test, seed_test, k = 10, include_data = TRUE)
lss2 <- textmodel_lss(toks_test, seed_test, k = 10, include_data = TRUE,
spatial = TRUE)

expect_s3_class(lss2, "textmodel_lss")
expect_equal(lss2$concatenator, concatenator(toks_test))
Expand All @@ -42,7 +43,7 @@ test_that("textmodel_lss works when spatial = TRUE", {

# with terms
lss3 <- textmodel_lss(toks_test, seed_test, k = 10, terms = feat_test,
include_data = TRUE, group_data = TRUE)
include_data = TRUE, group_data = TRUE, spatial = TRUE)

expect_s3_class(lss3, "textmodel_lss")
expect_true(all(names(lss3$beta) %in% feat_test))
Expand All @@ -53,15 +54,15 @@ test_that("textmodel_lss works when spatial = TRUE", {

# with tokens_xptr
lss4 <- textmodel_lss(as.tokens_xptr(toks_test), seed_test, k = 10,
include_data = TRUE)
include_data = TRUE, spatial = TRUE)

expect_s3_class(lss4, "textmodel_lss")
expect_equal(docnames(lss4$data), docnames(toks_test))

# warning
expect_warning(
textmodel_lss(toks_test, seed_test, k = 10,
include_data = FALSE, group_data = TRUE),
include_data = FALSE, group_data = TRUE, spatial = TRUE),
"group_data is ignored when include_data = FALSE"
)

Expand Down Expand Up @@ -99,7 +100,8 @@ test_that("textmodel_lss works when spatial = FALSE", {
)

# with data
lss2 <- textmodel_lss(toks_test, seed_test, k = 10, include_data = TRUE, spatial = FALSE)
lss2 <- textmodel_lss(toks_test, seed_test, k = 10, include_data = TRUE,
spatial = FALSE)

expect_s3_class(lss2, "textmodel_lss")
expect_equal(lss2$concatenator, concatenator(toks_test))
Expand Down
Loading