Word Vectors with tidy data principles
Oct 30, 2017

Last week I saw Chris Moody’s post on the Stitch Fix blog about calculating word vectors from a corpus of text using word counts and matrix factorization, and I was so excited! This blog post illustrates how to implement that approach to find word vector representations in R using tidy data principles and sparse matrices.

Word vectors, or word embeddings, are typically calculated using neural networks; word2vec is the best-known example. (GloVe embeddings are trained a little differently, by fitting word vectors to global word co-occurrence counts.) By contrast, the approach from Chris’s post that I’m implementing here uses only counting and some linear algebra. Deep learning is great, but I am super excited about this approach because it allows practitioners to find word vectors for their own collections of text (no need to rely on pre-trained vectors) using familiar techniques that are not difficult to understand. And it doesn’t take too long computationally!

Getting some data

Let’s download half a million observations from… the Hacker News corpus.


I know, right? But it’s the dataset that Chris uses in his blog post and it gives me an opportunity to use the bigrquery package for the first time.

library(bigrquery)
library(tidyverse)

project <- "my-first-project-184003"    ## replace with your own Google Cloud project ID

sql <- "#legacySQL
SELECT
  stories.title AS title,
  stories.text AS text
FROM
  [bigquery-public-data:hacker_news.full] AS stories
WHERE
  stories.deleted IS NULL
LIMIT
  500000"

hacker_news_raw <- query_exec(sql, project = project, max_pages = Inf)

Next, let’s clean this text up to take care of some of the messy ways it has gotten encoded.

library(stringr)

hacker_news_text <- hacker_news_raw %>%
    as_tibble() %>%
    mutate(title = na_if(title, ""),
           text = coalesce(title, text)) %>%
    select(-title) %>%
    mutate(text = str_replace_all(text, "&quot;|&#x27;", "'"),    ## hex encoding
           text = str_replace_all(text, "&#x2F;", "/"),           ## more hex
           text = str_replace_all(text, "<a(.*?)>", " "),         ## links 
           text = str_replace_all(text, "&gt;|&lt;", " "),        ## html yuck
           text = str_replace_all(text, "<[^>]*>", " "),          ## mmmmm, more html yuck
           postID = row_number())
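
To see what each of these substitutions does, here is a small sketch that runs the same patterns, in the same order, over a made-up string. It assumes the first pattern is meant to catch `&quot;` and `&#x27;` (the HTML entity for an apostrophe); the input string is invented for illustration.

```r
library(stringr)

## a made-up string with the kinds of encoding debris found in the raw data
raw <- "I&#x27;d say <a href=\"https:&#x2F;&#x2F;example.com\">this</a> is &quot;fine&quot; &gt; ok<p>bye"

clean <- str_replace_all(raw, "&quot;|&#x27;", "'")   ## quote and apostrophe entities
clean <- str_replace_all(clean, "&#x2F;", "/")         ## slash entity
clean <- str_replace_all(clean, "<a(.*?)>", " ")       ## opening link tags
clean <- str_replace_all(clean, "&gt;|&lt;", " ")      ## angle-bracket entities
clean <- str_replace_all(clean, "<[^>]*>", " ")        ## any remaining tags, like </a> and <p>

str_squish(clean)
## [1] "I'd say this is 'fine' ok bye"
```

Order matters here: if the slash entity were turned into an apostrophe first, the dedicated `&#x2F;` step would never find anything to replace.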

Unigram probabilities

First, let’s calculate the unigram probabilities: for each word, the fraction of all the words in this corpus that are that word. This is straightforward using unnest_tokens() from the tidytext package and then just count() and mutate() from dplyr.

library(tidytext)

unigram_probs <- hacker_news_text %>%
    unnest_tokens(word, text) %>%
    count(word, sort = TRUE) %>%
    mutate(p = n / sum(n))

unigram_probs
## # A tibble: 314,615 x 3
##     word       n          p
##    <chr>   <int>      <dbl>
##  1   the 1101051 0.04059959
##  2    to  761660 0.02808506
##  3     a  669142 0.02467360
##  4    of  541939 0.01998318
##  5   and  532594 0.01963860
##  6     i  464917 0.01714311
##  7   x27  437330 0.01612588
##  8    is  431849 0.01592378
##  9  that  429755 0.01584657
## 10    it  402667 0.01484774
## # ... with 314,605 more rows
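
The same pipeline on two made-up posts makes the arithmetic easy to check by hand: there are nine words in total and three of them are "the", so p for "the" is 3/9. This is a sketch with invented text, not Hacker News data.

```r
library(dplyr)
library(tidytext)

## two made-up posts
toy_text <- tibble(postID = 1:2,
                   text = c("The cat sat on the mat", "the dog sat"))

toy_unigrams <- toy_text %>%
    unnest_tokens(word, text) %>%    ## one lowercased word per row
    count(word, sort = TRUE) %>%     ## n = number of times each word appears
    mutate(p = n / sum(n))           ## fraction of all 9 tokens

toy_unigrams
```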

Skipgram probabilities

Next, we need to calculate the skipgram probabilities, how often we find each word near each other word. We do this by defining a fixed-size window that moves through the text one word at a time. Do we see word1 and word2 together within this window? I take the approach here of using unnest_tokens() once with token = "ngrams" to find all the windows I need, then using unnest_tokens() again to tidy these n-grams. After that, I can use pairwise_count() from the widyr package to count up co-occurring pairs within each n-gram/sliding window.

I’m not sure what the ideal window size is here for the skipgrams. This value determines the sliding window that we move through the text, counting up the pairs of words that we find together within the window. When this window is bigger, counting skipgrams takes longer, obviously. I experimented a bit and windows of 8 words seem to work pretty well. Probably more work needed here! I’d be happy to be pointed to more resources on this topic.

Finding all the skipgrams is a computationally expensive part of this process. Not something that just runs instantly!

library(widyr)

tidy_skipgrams <- hacker_news_text %>%
    unnest_tokens(ngram, text, token = "ngrams", n = 8) %>%
    filter(!is.na(ngram)) %>%              ## posts shorter than 8 words can give NA
    mutate(ngramID = row_number()) %>% 
    unite(skipgramID, postID, ngramID) %>%
    unnest_tokens(word, ngram)

tidy_skipgrams
## # A tibble: 190,151,488 x 2
##    skipgramID   word
##         <chr>  <chr>
##  1        1_1      i
##  2        1_1    bet
##  3        1_1 taking
##  4        1_1      a
##  5        1_1    few
##  6        1_1 months
##  7        1_1    off
##  8        1_1   from
##  9        1_2    bet
## 10        1_2 taking
## # ... with 190,151,478 more rows

skipgram_probs <- tidy_skipgrams %>%
    ## count the windows each pair of words shares; diag = TRUE keeps word/word pairs
    pairwise_count(word, skipgramID, diag = TRUE, sort = TRUE) %>%
    mutate(p = n / sum(n))
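
Here is the windowing on a tiny scale, as a sketch with one made-up post of four words and a window of 3 instead of 8. The two windows are "apple banana cherry" and "banana cherry date", so banana and cherry share two windows while apple and date never meet. The `filter(!is.na(ngram))` line drops the NA n-gram that recent versions of tidytext can return for a post shorter than the window.

```r
library(dplyr)
library(tidyr)
library(tidytext)
library(widyr)

## one made-up post, with a window of 3 words instead of 8
toy_text <- tibble(postID = 1, text = "apple banana cherry date")

toy_skipgrams <- toy_text %>%
    unnest_tokens(ngram, text, token = "ngrams", n = 3) %>%
    filter(!is.na(ngram)) %>%                  ## drop empty windows from short posts
    mutate(ngramID = row_number()) %>%
    unite(skipgramID, postID, ngramID) %>%     ## window IDs like "1_1", "1_2"
    unnest_tokens(word, ngram)

toy_skipgram_probs <- toy_skipgrams %>%
    pairwise_count(word, skipgramID, diag = TRUE, sort = TRUE) %>%
    mutate(p = n / sum(n))

toy_skipgram_probs %>% filter(item1 == "banana")
```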

Normalized skipgram probability

We now know how often words occur on their own, and how often words occur together with other words. We can calculate which words occurred together more often than expected based on how often they occurred on their own, by dividing each skipgram probability by the unigram probabilities of its two words. When this ratio is high (greater than 1), the two words are associated with each other, likely to occur together. When it is low (less than 1), the two words are not associated with each other, unlikely to occur together.
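
Written out, with n(x) the number of times word x appears and n(x, y) the number of windows that words x and y share (the n columns computed above), the ratio that the code below calls p_together is:

$$
p(x) = \frac{n(x)}{\sum_{w} n(w)}, \qquad
p(x, y) = \frac{n(x, y)}{\sum_{u, v} n(u, v)}, \qquad
\texttt{p\_together}(x, y) = \frac{p(x, y)}{p(x)\,p(y)}
$$

If the two words occurred independently of each other, p(x, y) would equal p(x) p(y) and the ratio would be exactly 1.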

normalized_prob <- skipgram_probs %>%
    filter(n > 20) %>%
    rename(word1 = item1, word2 = item2) %>%
    left_join(unigram_probs %>%
                  select(word1 = word, p1 = p),
              by = "word1") %>%
    left_join(unigram_probs %>%
                  select(word2 = word, p2 = p),
              by = "word2") %>%
    mutate(p_together = p / p1 / p2)

What are the words most associated with Facebook on Hacker News?

normalized_prob %>% 
    filter(word1 == "facebook") %>%
    arrange(-p_together)
## # A tibble: 1,767 x 7
##       word1            word2     n            p           p1           p2 p_together
##       <chr>            <chr> <dbl>        <dbl>        <dbl>        <dbl>      <dbl>
##  1 facebook         facebook 54505 3.737944e-05 0.0003310502 3.310502e-04  341.07126
##  2 facebook        messenger   364 2.496306e-07 0.0003310502 1.098830e-05   68.62360
##  3 facebook         statuses    40 2.743194e-08 0.0003310502 1.474940e-06   56.18086
##  4 facebook       zuckerburg    23 1.577336e-08 0.0003310502 8.480903e-07   56.18086
##  5 facebook          myspace   327 2.242561e-07 0.0003310502 1.290572e-05   52.48898
##  6 facebook         newsfeed    32 2.194555e-08 0.0003310502 1.327446e-06   49.93854
##  7 facebook           hiphop    29 1.988815e-08 0.0003310502 1.438066e-06   41.77551
##  8 facebook     mashable.com    25 1.714496e-08 0.0003310502 1.253699e-06   41.30946
##  9 facebook            gtalk    22 1.508757e-08 0.0003310502 1.216825e-06   37.45391
## 10 facebook www.facebook.com    47 3.223253e-08 0.0003310502 2.765512e-06   35.20667
## # ... with 1,757 more rows

What about the programming language Scala?

normalized_prob %>% 
    filter(word1 == "scala") %>%
    arrange(-p_together)
## # A tibble: 453 x 7
##    word1        word2     n            p           p1           p2 p_together
##    <chr>        <chr> <dbl>        <dbl>        <dbl>        <dbl>      <dbl>
##  1 scala        scala  9418 6.458850e-06 5.394592e-05 5.394592e-05 2219.41246
##  2 scala      odersky    54 3.703312e-08 5.394592e-05 1.216825e-06  564.16154
##  3 scala          sbt    36 2.468874e-08 5.394592e-05 1.401193e-06  326.61984
##  4 scala         akka    74 5.074908e-08 5.394592e-05 2.913006e-06  322.94479
##  5 scala       groovy    88 6.035026e-08 5.394592e-05 5.494150e-06  203.61983
##  6 scala           mu    23 1.577336e-08 5.394592e-05 1.769928e-06  165.20008
##  7 scala       kotlin    88 6.035026e-08 5.394592e-05 7.448445e-06  150.19482
##  8 scala constructors    23 1.577336e-08 5.394592e-05 2.728638e-06  107.15681
##  9 scala      clojure   475 3.257543e-07 5.394592e-05 5.678518e-05  106.33997
## 10 scala    idiomatic    53 3.634732e-08 5.394592e-05 7.337825e-06   91.82194
## # ... with 443 more rows

Looks good!
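
As a check on the arithmetic, here is the Facebook/Messenger row from the table above recomputed by hand from its printed p, p1, and p2 values. The base 10 logarithm at the end is the same transformation applied when casting to a matrix in the next step.

```r
## values copied from the facebook/messenger row printed above
p  <- 2.496306e-07   ## skipgram probability of the pair
p1 <- 0.0003310502   ## unigram probability of "facebook"
p2 <- 1.098830e-05   ## unigram probability of "messenger"

p_together <- p / p1 / p2
pmi <- log10(p_together)

round(p_together, 2)
## [1] 68.62
round(pmi, 3)
## [1] 1.836
```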

Cast to a sparse matrix

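We want to do matrix factorization, so we should probably make a matrix. First we take the base 10 logarithm of p_together; this log ratio is the pointwise mutual information (PMI) of each pair of words. Then we can use cast_sparse() from the tidytext package to transform our tidy data frame to a sparse matrix.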

pmi_matrix <- normalized_prob %>%
    mutate(pmi = log10(p_together)) %>%
    cast_sparse(word1, word2, pmi)

What is the type of this object?

class(pmi_matrix)
## [1] "dgCMatrix"
## attr(,"package")
## [1] "Matrix"

The dgCMatrix class is a class of sparse numeric matrices in R. Text data like this represented in matrix form usually has lots and lots of zeroes, so we want to make use of sparse data structures to save us time and memory and all that.
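
A toy version shows what cast_sparse() does with pairs that are missing from the tidy data frame: they become the implicit zeros of the sparse matrix, which are never stored. One consequence worth keeping in mind is that an unobserved pair (or one dropped by the n > 20 filter) is treated the same as a pair with a PMI of exactly 0. The words and PMI values below are made up.

```r
library(dplyr)
library(tidytext)
library(Matrix)

## made-up PMI values for three word pairs
toy_pmi <- tibble(word1 = c("scala", "scala", "haskell"),
                  word2 = c("akka", "kotlin", "lisp"),
                  pmi   = c(2.5, 2.2, 1.9))

toy_matrix <- cast_sparse(toy_pmi, word1, word2, pmi)

class(toy_matrix)                  ## a dgCMatrix, like pmi_matrix
dim(toy_matrix)                    ## 2 rows (word1) by 3 columns (word2)
toy_matrix["haskell", "akka"]      ## a pair we never saw: an implicit 0
nnzero(toy_matrix)                 ## only the 3 observed values are stored
```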

Reduce the matrix dimensionality

We want to get information out of this giant matrix in a more useful form, so it’s time for singular value decomposition. Since we have a sparse matrix, we don’t want to use base R’s svd function, which casts the input to a plain old matrix (not sparse) first thing. Instead we will use the irlba package, which works directly on sparse matrices and quickly computes a truncated SVD: only the first few singular values and vectors, rather than all of them.

library(irlba)

pmi_svd <- irlba(pmi_matrix, nv = 256, maxit = 1e3)    ## first 256 singular vectors

The number 256 here means that we are finding 256-dimensional vectors for the words. This is another value where I am not sure what the best choice is, but it will be easy to experiment with. Doing the matrix factorization is another part of this process that is a bit time intensive, but certainly not slow compared to training word2vec on a big corpus. In my experimenting here, it takes less time than counting up the skipgrams.
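
To see that irlba() returns the leading part of the same decomposition that base R’s svd() would compute, here is a sketch on a small random sparse matrix (a stand-in for the PMI matrix, with made-up dimensions), asking for 5 singular vectors instead of 256.

```r
library(Matrix)
library(irlba)

set.seed(1234)
## a small random sparse matrix standing in for the PMI matrix
toy_sparse <- rsparsematrix(200, 150, density = 0.05)

toy_svd  <- irlba(toy_sparse, nv = 5)        ## only the first 5 singular vectors
full_svd <- svd(as.matrix(toy_sparse))       ## full dense SVD, for comparison

## the leading singular values agree
cbind(irlba = toy_svd$d, svd = full_svd$d[1:5])

## each row of u is a 5-dimensional vector for one row of the input
dim(toy_svd$u)
```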

Once we have the singular value decomposition, we can get out the word vectors! Each row of the left singular vectors, pmi_svd$u, is the vector for one word. Let’s set the row names from the rows of our PMI matrix, so we can find out what is what.

word_vectors <- pmi_svd$u
rownames(word_vectors) <- rownames(pmi_matrix)

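Now we can search our matrix of word vectors to find synonyms, or at least words used in similar contexts. I want to get back to a tidy data structure at this point, so I’ll write a new little function that scores every word by the dot product of its vector with a selected vector and returns the results as a sorted tibble.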

search_synonyms <- function(word_vectors, selected_vector) {
    
    ## dot product of every word vector with the selected vector
    dat <- word_vectors %*% selected_vector
    
    similarities <- tibble(token = rownames(dat),
                           similarity = dat[, 1])
    
    similarities %>%
        arrange(-similarity)    
}
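
The similarity in this function is a plain dot product, so a quick sketch with hand-made two-dimensional vectors (hypothetical, not from the Hacker News data) shows how the ranking works. Because the rows of word_vectors are not normalized to unit length, a dot product rewards longer vectors as well as closer directions; dividing by the vector lengths gives cosine similarity instead, shown here only as an alternative.

```r
## hypothetical 2-dimensional word vectors, one row per word
toy_vectors <- rbind(facebook = c(0.9, 0.1),
                     twitter  = c(0.8, 0.2),
                     haskell  = c(0.1, 0.9),
                     lisp     = c(0.2, 0.8))

## dot product of every row with the "facebook" row, as search_synonyms() computes
sims <- drop(toy_vectors %*% toy_vectors["facebook", ])
sort(sims, decreasing = TRUE)

## alternative, not used in this post: cosine similarity divides by vector lengths
norms <- sqrt(rowSums(toy_vectors^2))
cosine <- sims / (norms * norms[["facebook"]])
sort(cosine, decreasing = TRUE)
```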

facebook <- search_synonyms(word_vectors, word_vectors["facebook",])
facebook
## # A tibble: 68,664 x 2
##        token similarity
##        <chr>      <dbl>
##  1  facebook 0.07622816
##  2   twitter 0.05477529
##  3    google 0.04833987
##  4    social 0.04367801
##  5        fb 0.03597795
##  6   account 0.03304327
##  7 instagram 0.02955428
##  8     users 0.02581671
##  9    photos 0.02502522
## 10   friends 0.02412458
## # ... with 68,654 more rows

haskell <- search_synonyms(word_vectors, word_vectors["haskell",])
haskell
## # A tibble: 68,664 x 2
##          token similarity
##          <chr>      <dbl>
##  1     haskell 0.04100067
##  2        lisp 0.03796378
##  3   languages 0.03710574
##  4    language 0.03176253
##  5  functional 0.03022514
##  6 programming 0.02743172
##  7       scala 0.02667198
##  8     clojure 0.02652761
##  9      python 0.02462822
## 10      erlang 0.02452040
## # ... with 68,654 more rows

That’s… pretty darn amazing. Let’s visualize the word vectors most similar to Facebook and Haskell from this dataset of Hacker News posts.

facebook %>%
    mutate(selected = "facebook") %>%
    bind_rows(haskell %>%
                  mutate(selected = "haskell")) %>%
    group_by(selected) %>%
    top_n(15, similarity) %>%
    ungroup %>%
    mutate(token = reorder(token, similarity)) %>%
    ggplot(aes(token, similarity, fill = selected)) +
    geom_col(show.legend = FALSE) +
    facet_wrap(~selected, scales = "free") +
    coord_flip() +
    theme(strip.text=element_text(hjust=0, family="Roboto-Bold", size=12)) +
    scale_y_continuous(expand = c(0,0)) +
    labs(x = NULL, title = "What word vectors are most similar to Facebook or Haskell?",
         subtitle = "Based on the Hacker News corpus, calculated using counts and matrix factorization")

We can also do the familiar WORD MATH that is so fun with the output of word2vec; you have probably seen examples such as King - Man + Woman = Queen. We can just add and subtract our word vectors, and then search the matrix we built!
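
Because the word vectors are just rows of a numeric matrix, this arithmetic is ordinary vector addition and subtraction followed by the same dot-product search. Here is a sketch with made-up two-dimensional vectors, where the first coordinate stands loosely for "royalty" and the second for "male vs. female" (labels chosen only for illustration).

```r
## hypothetical word vectors: (royalty, male vs. female)
toy_vectors <- rbind(king  = c(1,  1),
                     queen = c(1, -1),
                     man   = c(0,  1),
                     woman = c(0, -1))

## king - man + woman
target <- toy_vectors["king", ] - toy_vectors["man", ] + toy_vectors["woman", ]

## score every word against the result with a dot product, as search_synonyms() does
sims <- drop(toy_vectors %*% target)
names(which.max(sims))
## [1] "queen"
```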

If the iPhone is an important product associated with Apple, as discussed on Hacker News, what is an important product associated with Microsoft?

mystery_product <- word_vectors["iphone",] - word_vectors["apple",] + word_vectors["microsoft",]
search_synonyms(word_vectors, mystery_product)
## # A tibble: 68,664 x 2
##        token similarity
##        <chr>      <dbl>
##  1   windows 0.03861497
##  2     phone 0.02818951
##  3    iphone 0.02444977
##  4    mobile 0.02427238
##  5   android 0.02321326
##  6 microsoft 0.02233597
##  7       ios 0.02138345
##  8    office 0.02000720
##  9         7 0.01997405
## 10         8 0.01970687
## # ... with 68,654 more rows

We even see some mobile phone and Android terms in this list, below Windows.

What about an important product associated with Google?

mystery_product <- word_vectors["iphone",] - word_vectors["apple",] + word_vectors["google",]
search_synonyms(word_vectors, mystery_product)
## # A tibble: 68,664 x 2
##       token similarity
##       <chr>      <dbl>
##  1   google 0.10194467
##  2   search 0.05937512
##  3    phone 0.03947618
##  4      app 0.03716764
##  5   engine 0.03590006
##  6 facebook 0.03513982
##  7  android 0.03289141
##  8        q 0.02998453
##  9    gmail 0.02980005
## 10    using 0.02944208
## # ... with 68,654 more rows

Google itself is at the top of the list. This often happens to me when I try this word vector arithmetic, no matter how I train the vectors: one of the positive vectors in the “equation” ends up at the top. Does anyone know what that means? Anyway, “search” is next on the list.

What about Amazon?

mystery_product <- word_vectors["iphone",] - word_vectors["apple",] + word_vectors["amazon",]
search_synonyms(word_vectors, mystery_product)
## # A tibble: 68,664 x 2
##      token similarity
##      <chr>      <dbl>
##  1  amazon 0.04757609
##  2     aws 0.04389115
##  3      s3 0.03356846
##  4    book 0.03273032
##  5     ec2 0.03151250
##  6   cloud 0.03083856
##  7   books 0.03008677
##  8  iphone 0.02749843
##  9 storage 0.02549876
## 10  kindle 0.02505824
## # ... with 68,654 more rows

For Amazon, we get AWS, S3, and EC2, as well as book. Nice!
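
On the question of why one of the query words so often comes out on top (Google, and then Amazon, above): one likely reason is that the query vector is built by adding that word’s own vector, so unless the other two terms happen to cancel out cleanly, the sum still points largely in that word’s direction. A common convention when evaluating word analogies is simply to drop the query words from the candidate list. Here is a sketch of that idea with hypothetical, hand-made vectors; with the tidy output of search_synonyms(), it would be one more filter() step.

```r
## hypothetical 2-dimensional vectors, chosen by hand for illustration
toy_vectors <- rbind(apple  = c(1.0, 0.0),
                     iphone = c(1.0, 0.5),
                     google = c(0.0, 1.0),
                     search = c(0.3, 0.9))

query_words <- c("iphone", "apple", "google")
mystery <- toy_vectors["iphone", ] - toy_vectors["apple", ] + toy_vectors["google", ]

## dot product of every word with the query vector
sims <- drop(toy_vectors %*% mystery)
names(which.max(sims))
## [1] "google"

## drop the query words themselves before ranking
candidates <- sims[!names(sims) %in% query_words]
names(which.max(candidates))
## [1] "search"
```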

The End

I am so excited about this approach! Like Chris said in his blog post, for all the applications in the kind of work I do (non-academic, industry NLP) this type of word vector will work great. No need for neural networks! This approach is still not lightning fast (I have to sit and wait for parts of it to run) but I can easily implement it with the tools I am familiar with. I would imagine there are vast swaths of data science practitioners for whom this is also true. I am considering the idea of bundling some of these types of functions up into an R package, and Dave has just built a pairwise_pmi() function in the development version of widyr that simplifies this approach even more. Tidy word vectors, perhaps? Maybe I’ll also look into the higher rank extension of this technique to get at word and document vectors!

Let me know if you have feedback or questions.


