Adversarial Training of Word2Vec for Basket Completion

05/22/2018 ∙ by Ugo Tanielian, et al. ∙ Criteo

In recent years, the Word2Vec model trained with the Negative Sampling loss function has shown state-of-the-art results in a number of machine learning tasks, including language modeling tasks, such as word analogy and word similarity, and in recommendation tasks, through Prod2Vec, an extension that applies to modeling user shopping activity and user preferences. Several methods that aim to improve upon the standard Negative Sampling loss have been proposed. In our paper we pursue more sophisticated Negative Sampling, by leveraging ideas from the field of Generative Adversarial Networks (GANs), and propose Adversarial Negative Sampling. We build upon the recent progress made in stabilizing the training objective of GANs in the discrete data setting, and introduce a new GAN-Word2Vec model. We evaluate our model on the task of basket completion, and show significant improvements in performance over Word2Vec trained using standard loss functions, including Noise Contrastive Estimation and Negative Sampling.




1. Introduction

The recommendation task of basket completion is a key part of many online retail applications. Basket completion involves computing predictions for the next item that should be added to a shopping basket, given a collection of items that the user has already added to the basket.

In this context of basket completion, learning item embedding representations can lead to state-of-the-art results, as shown in (Zaheer et al., 2017). Within this class of approaches, Word2Vec (Mikolov et al., 2013a), and its item-based extension Prod2Vec (Grbovic et al., 2015), have become the de-facto standard approach, due to the conceptual simplicity, implementation simplicity, and state-of-the-art performance of these models.

In terms of training and the use of negatives, there have been many extensions of the classical Word2Vec model based on the Negative Sampling (NS) loss function (Mikolov et al., 2013b), such as Swivel (Shazeer et al., 2016). However, these approaches do not have a dynamic way of sampling the most informative negatives. This shortcoming was addressed by (Chen et al., 2018), which proposes an active sampling heuristic. In our paper we propose GAN-Word2Vec, an extension of Word2Vec that uses ideas from Generative Adversarial Networks (GANs) to create an adversarial negative sampling approach that places dynamic negative sampling on firmer theoretical grounds, and shows significant improvements in performance over the classical training approaches. In our structure the generator is also trained adversarially, and benefits from a better training signal coming from the discriminator. In terms of training stability, which becomes an issue in GAN-like settings, our algorithm builds upon recent advances that make GAN training stable for discrete input data. We evaluate the performance of our GAN-Word2Vec model on a basket-completion task and show that it outperforms classical supervised approaches such as Word2Vec with Negative Sampling by a significant margin.

Overall, the main contributions of this paper are the following:

  • We propose a new dynamic negative sampling scheme for Word2Vec based on ideas from GANs. To the best of our knowledge, we are the first to implement adversarial training for Word2Vec.

  • We introduce a stable training algorithm that implements our adversarial sampling scheme.

  • Through an experimental evaluation on two real-world datasets, we show that our GAN-Word2Vec model outperforms classical sampling schemes for Word2Vec.

We briefly discuss related work on sampling schemes for Word2Vec and the recent developments on GAN training in discrete settings in Section 2 of this paper. In Section 3, we formally introduce our GAN-Word2Vec model and describe the training algorithm. We highlight the performance of our method in Section 4, and conclude with main ideas and directions for future work in Section 5.

2. Related Work

2.1. Basket Completion with Embedding Representations

In recent years, a substantial amount of work has focused on improving the performance of language modeling and text generation. In both tasks, Deep Neural Networks (DNNs) have proven to be extremely effective, and are now considered state-of-the-art. In this paper, we focus on the task of basket completion with learned embedding representations. Very little work has focused on applying DNNs to basket completion, with the notable exception of (Zaheer et al., 2017). In that paper, the authors introduce a new family of networks designed to work on sets, where the output is invariant to any permutation in the order of objects in the input set; they empirically show, on tasks such as point cloud classification, that their method outperforms the state of the art with good generalization properties. As an alternative to a set-based interpretation, basket completion can be approached as a sequence generation task, where one seeks to predict the distribution of the next item conditioned on the items already present in the basket. First, one could use the Skip-Gram model proposed by (Mikolov et al., 2013b), and use the average of the embeddings for the items within a basket to compute the next-item prediction for the basket. Second, building upon work on models for text generation, it is natural to leverage Recurrent Neural Networks (RNNs), particularly Long Short-Term Memory (LSTM) cells, as proposed by (Hochreiter and Schmidhuber, 1997), or bi-directional LSTMs (Graves et al., 2005; Schuster and Paliwal, 1997). Convolutional neural networks could also be used, for example with a Text-CNN-like architecture as proposed by (Kim, 2014).
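The embedding-averaging approach mentioned above can be made concrete. The following is a minimal sketch, assuming a learned item-embedding matrix `E` (for example, from a Skip-Gram/Prod2Vec model); the function and variable names are illustrative, not from the paper:

```python
import numpy as np

def next_item_scores(E, basket):
    """Score every catalog item as the next addition to `basket`.

    E: (n_items, dim) learned item-embedding matrix (hypothetical);
    basket: list of item ids already in the basket.
    Returns a probability distribution over the catalog: the softmax of
    the inner product between the averaged basket embedding and each
    item's embedding, with in-basket items excluded.
    """
    context = E[basket].mean(axis=0)      # average embedding of the basket
    logits = E @ context                  # inner product with every item
    logits[basket] = -np.inf              # never re-predict a basket item
    logits -= logits.max()                # numerical stability for exp
    probs = np.exp(logits)
    return probs / probs.sum()

rng = np.random.default_rng(0)
E = rng.normal(size=(10, 4))              # toy catalog of 10 items
p = next_item_scores(E, [1, 3])
```

The highest-probability entry of `p` is then the basket-completion prediction.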

2.2. Negative sampling schemes for Word2Vec

When training multi-class models with thousands or millions of output classes, candidate sampling algorithms can speed up training by considering only a small, randomly-chosen subset of contrastive candidates for each batch of training examples. Ref. (Gutmann and Hyvärinen, 2010) introduced Noise Contrastive Estimation (NCE) as an unbiased estimator of the softmax loss, which has been proven efficient for learning word embeddings (Mnih and Teh, 2012). In (Mikolov et al., 2013b), the authors propose Negative Sampling, which directly samples candidates from a noise distribution.
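As a concrete illustration, the Negative Sampling objective for a single (context, target) pair can be sketched as below; this is a generic sketch of the standard NS loss, not code from the paper:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def negative_sampling_objective(v_c, v_t, v_negs):
    """NS objective for one pair: log sigma(v_t . v_c) plus the sum of
    log sigma(-v_i . v_c) over the sampled negatives. Training maximizes
    this quantity (equivalently, minimizes its negation)."""
    positive_term = np.log(sigmoid(v_t @ v_c))
    negative_term = np.log(sigmoid(-(v_negs @ v_c))).sum()
    return positive_term + negative_term

v_c = np.array([1.0, 0.0])           # context embedding
v_t = np.array([2.0, 0.0])           # target aligned with the context
v_negs = np.array([[-2.0, 0.0]])     # one anti-aligned negative
good = negative_sampling_objective(v_c, v_t, v_negs)
bad = negative_sampling_objective(v_c, v_negs[0], np.array([v_t]))
```

A well-aligned (context, target) pair with dissimilar negatives scores higher than the reverse, which is what the optimizer exploits.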

More recently, in (Chen et al., 2018), the authors provide an insightful analysis of negative sampling. They show that negative samples with high inner product scores with a context word are more informative in terms of gradients on the loss function. Leveraging this analysis, the authors propose a dynamic sampling method based on inner-product rankings. This result can be intuitively interpreted by seeing that negative samples with a higher inner product will lead to a better approximation of the softmax.
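The dynamic scheme of (Chen et al., 2018) can be sketched as ranking candidate negatives by their inner product with the context; this toy version (names are ours, not the authors') picks the top-scoring candidates outside the positive set:

```python
import numpy as np

def informative_negatives(E, v_c, positives, n_neg):
    """Return the n_neg items whose embeddings have the largest inner
    product with context v_c, excluding known positives: per the
    analysis of (Chen et al., 2018), these carry the largest gradient
    signal under the negative sampling loss."""
    scores = E @ v_c
    scores[list(positives)] = -np.inf    # never sample a positive
    return np.argsort(-scores)[:n_neg]

E = np.array([[1.0], [3.0], [2.0], [5.0]])   # toy 1-d embeddings
negs = informative_negatives(E, np.array([1.0]), positives={3}, n_neg=2)
```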

In our setup, we simultaneously train two neural networks, and use the output distribution coming from one network to generate the negative samples for the second network. This adversarial negative sampling proves to be a dynamic and efficient way to improve the training of our system. This method echoes the recent work on Generative Adversarial Networks.

2.3. GANs

First proposed by (Goodfellow et al., 2014), Generative Adversarial Networks (GANs) have been quite successful at generating realistic images. GANs can be viewed as a framework for training generative models by posing the training procedure as a minimax game between the generative and discriminative models. This adversarial learning framework bypasses many of the difficulties associated with maximum likelihood estimation (MLE), and has demonstrated impressive results in natural image generation tasks.

Theoretical work (Liu et al., 2017; Zhang et al., 2018; Arjovsky and Bottou, 2017; Nowozin et al., 2016) and practical studies (Salimans et al., 2016; Arjovsky et al., 2017; Dziugaite et al., 2015) have stabilized GAN training and enabled many improvements in continuous spaces. However, in discrete settings a number of important limitations have prevented major breakthroughs. A major issue involves the complexity of backpropagating the gradients for the generative model. To bypass this differentiation problem, (Yu et al., 2016; Che et al., 2017) proposed a cost function inspired by reinforcement learning: the discriminator is used as a reward function, and the generator is trained via policy gradient (Sutton et al., 2000). Additionally, the authors propose the use of importance sampling and variance reduction techniques to stabilize training.

3. Our Model: GAN-Word2Vec

We begin this section by formally defining the basket completion task. Let I denote the set of all potential items, products, or words; B_k is the space of all baskets of size k, and B is the space of all possible baskets of any size.

The objective of GANs is to generate candidate, or synthetic, samples from a target data distribution p_data, from which only true samples – in our case baskets of products – are available. Therefore, the collection of true baskets is a subset of B. We are given a basket b, and an element t randomly selected from b. Knowing b_{-t}, which denotes basket b with item t removed, we want to be able to predict the missing item t. We denote by P(I) the space of probability distributions defined on I.

As in (Mikolov et al., 2013b), we are working with embeddings. For a given target item t, we denote c as its context, where c = b_{-t}. The embeddings of t and c are v_t and v_c, respectively.

3.1. Notation

Both generators and discriminators have the form of a parametric family of functions from B_ℓ to P(I), where ℓ is the length of the prefix basket. We denote {G_θ} as the family of generator functions, with parameters θ, and {D_φ} as the family of discriminator functions, with parameters φ. Each function G_θ and D_φ is intended to be applied to an ℓ-long random input basket. Both the generator and the discriminator output probabilities that are conditioned on an input basket. In our setting, G and D may therefore be from the same class of models, and may be parametrized by the same class of functions. Given the context c, we denote p_θ(·|c) and p_φ(·|c) as the conditional probabilities output by G_θ and D_φ, respectively. The samples drawn from p_θ(·|c) are denoted t_1, …, t_N.

3.2. Adversarial Negative Sampling Loss

In the usual GAN setting, D is trained with a binary cross-entropy loss function, with positives coming from p_data and negatives generated by G. In our setting, we modify the discriminator's loss using an extension of the standard approach to Negative Sampling, which has proven to be very efficient for language modeling. We define the Adversarial Negative Sampling Loss by the following objective:

    J_D = log σ(v_t · v_c) + Σ_{i=1..N} E_{t_i ∼ p_θ(·|c)} [ log σ(−v_{t_i} · v_c) ]    (1)

where N is the number of negatives sampled from G, and σ is the sigmoid function.

As in the standard GAN setting, D's task is to learn to discriminate between true samples and synthetic samples coming from G. Compared to standard negative sampling, the main difference is that the negatives are now drawn from a dynamic distribution that is by design more informative than the fixed noise distribution used in standard negative sampling. Ref. (Chen et al., 2018) proposes a dynamic sampling strategy based on self-embedded features, but our approach is a fully adversarial sampling method that is not based on heuristics.
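A sketch of this adversarial loss in code: negatives for the discriminator are drawn from the generator's conditional distribution rather than from a fixed noise distribution. The shapes and names below are our assumptions, not the authors' implementation:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def adversarial_ns_loss(v_c, v_t, E, p_gen, n_neg, rng):
    """Monte-Carlo estimate of the adversarial NS objective: one positive
    (target t with context c) plus n_neg negatives sampled from the
    generator's conditional distribution p_gen over the catalog."""
    negs = rng.choice(len(p_gen), size=n_neg, p=p_gen)
    positive_term = np.log(sigmoid(v_t @ v_c))
    negative_term = np.log(sigmoid(-(E[negs] @ v_c))).sum()
    return positive_term + negative_term

rng = np.random.default_rng(1)
E = rng.normal(size=(6, 3))            # toy discriminator embeddings
p_gen = np.full(6, 1.0 / 6.0)          # a flat generator, for illustration
obj = adversarial_ns_loss(E[0], E[0], E, p_gen, n_neg=2, rng=rng)
```

As the generator sharpens, `p_gen` concentrates on hard negatives, and the same code yields progressively more informative gradients for the discriminator.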

3.3. Training the Generator

In discrete settings, training using the standard GAN architecture is not feasible, due to discontinuities in the space of discrete data that prevent updates of the generator. To address this issue, we sample potential outcomes from G, and use D as a reward function on these outcomes.

In our case, the training of the generator is inspired by (Che et al., 2017). In this paper, the authors define two main loss functions for training the generator. The first loss is called the basic MALIGAN, and only uses signals from D. Adapted to our setting, we have the following formulation for the generator's loss:

    ∇_θ L_G = Σ_{i=1..m} [ r_D(t_i) / Σ_{j=1..m} r_D(t_j) ] ∇_θ log p_θ(t_i | c)    (2)

  • t_i ∼ p_θ(·|c). That is, we draw negative samples in the form of “next-item” samples, where only the missing element for each basket is sampled. This missing element is sampled by conditioning on the context, c, for i = 1, …, m.

  • r_D(t) = p_φ(t|c) / (1 − p_φ(t|c)). This term allows us to incorporate reward from the discriminator into updates for the generator. The expression for r_D comes from a property of the optimal D for a GAN; a full explanation is provided in (Che et al., 2017).

  • m is the number of negatives sampled from p_θ used to compute the gradients.

Unlike (Che et al., 2017), we do not use a baseline parameter b in Eq. 2, as we did not observe the need for the further variance reduction provided by this term in our applications. As both models output probability distributions, computing p_θ(t_i|c) and r_D(t_i) is straightforward.
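The discriminator-derived weights in Eq. 2 can be computed as below; a minimal sketch, assuming `d_probs` holds the discriminator's probabilities for the m sampled items:

```python
import numpy as np

def maligan_weights(d_probs):
    """Normalized importance weights r_D(t_i) / sum_j r_D(t_j), with
    r_D(t) = D(t|c) / (1 - D(t|c)). Samples the discriminator judges
    more realistic receive a larger share of the policy gradient."""
    r = d_probs / (1.0 - d_probs)
    return r / r.sum()

w = maligan_weights(np.array([0.9, 0.5, 0.1]))
```

Each weight then multiplies the corresponding ∇_θ log p_θ(t_i | c) term in the generator update.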

The second loss, Mixed MLE-MALIGAN, mixes the adversarial loss and the standard maximum likelihood estimate (MLE) loss. In our case, this loss mixes the adversarial training loss and a standard sampled softmax loss (negative sampling loss):

    ∇_θ L = ∇_θ L_G + ∇_θ [ log σ(v_t · v_c) + Σ_{j=1..n} log σ(−v_{u_j} · v_c) ]    (3)

where u_1, …, u_n are negatives uniformly sampled among the potential next items. We empirically find that this mixed loss provides more stable gradients than the loss in Eq. 2, leading to faster convergence during training.

A description of our algorithm can be found in Algorithm 1. We pre-train both G and D using a standard negative sampling loss before training these components adversarially. We empirically show improvements with this procedure in the following section.

Require: generator policy G_θ; discriminator D_φ; a basket dataset S
  Initialize G_θ, D_φ with random weights θ, φ.
  Pre-train G_θ and D_φ using sampled softmax on S
  repeat
     for g-steps do
         Get random true sub-baskets c and the targets t.
         Generate negatives by sampling from p_θ(·|c)
         Update generator parameters via policy gradient, Eq. (2)
     end for
     for d-steps do
         Get random true sub-baskets c and the targets t.
         Generate adversarial samples for D_φ from p_θ(·|c)
         Train discriminator by Eq. (1)
     end for
  until GAN-Word2Vec converges
Algorithm 1 GAN-Word2Vec
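The alternating schedule of Algorithm 1 can be sketched as follows, with the model-specific updates abstracted into callables; this skeleton only illustrates the control flow (pre-training and the convergence check are elided):

```python
def train_gan_word2vec(g_update, d_update, n_rounds, g_steps, d_steps):
    """Alternate g_steps policy-gradient updates of the generator with
    d_steps adversarial-NS updates of the discriminator, n_rounds times."""
    for _ in range(n_rounds):
        for _ in range(g_steps):
            g_update()   # sample negatives from p_theta, apply the generator loss
        for _ in range(d_steps):
            d_update()   # negatives from the generator, apply the discriminator loss

calls = []
train_gan_word2vec(lambda: calls.append("g"), lambda: calls.append("d"),
                   n_rounds=2, g_steps=1, d_steps=2)
```

In a real implementation, `g_update` and `d_update` would close over the shared basket dataset and the two models' parameters.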

4. Experiments

All of our experiments were run on the task of basket completion, which is a well-known recommendation task.

4.1. Datasets

In (Gartrell et al., 2016) and (Gartrell et al., 2017), the authors present state-of-the-art results on basket completion datasets. We performed our experiments on two of the datasets used in this prior work: the Amazon Baby Registries and the Belgian Retail datasets.

  • Amazon Baby Registries - This public dataset consists of registries of baby products from 15 different categories (such as 'feeding', 'diapers', and 'toys'), where the item catalog and registries for each category are disjoint. Each category therefore provides a small dataset, with a maximum of 15,000 purchased baskets per category. We use a random split of 80% of the data for training and 20% for testing.

  • Belgian Retail Supermarket - This is a public dataset composed of shopping baskets purchased over three non-consecutive time periods from a Belgian retail supermarket store. There are 88,163 purchased baskets, with a catalog of 16,470 unique items. We use a random split of 80% of the data for training and 20% for testing.

4.2. Task definition and associated metrics

In the following evaluation, we consider two metrics:

  • Mean Percentile Rank (MPR) - For a basket b and one item t randomly removed from this basket, we rank all potential items from the set of candidates C according to their probabilities of completing the basket, p_θ(·|b_{-t}) and p_φ(·|b_{-t}). The Percentile Rank (PR) of the missing item t, with probability p_t, is defined by:

        PR_t = [ Σ_{j ∈ C} 1(p_t ≥ p_j) / |C| ] × 100

    where 1 is the indicator function and |C| is the number of items in the candidate set. The Mean Percentile Rank (MPR) is the average PR over all of the instances in the test set T.

    An MPR of 100 means that the model always places the held-out item for the test instance at the head of the ranked list of predictions, while an MPR of 50 is equivalent to random selection.

  • Precision@k - We define this metric as

        precision@k = [ Σ_{i ∈ T} 1(rank_i ≤ k) / |T| ] × 100

    where rank_i is the predicted rank of the held-out item for test instance i. In other words, precision@k is the fraction of instances in the test set for which the predicted rank of the held-out item falls within the top k predictions.
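Both metrics can be computed as follows; a small sketch under the definitions above (ties are counted in the held-out item's favor, which is an assumption on our part):

```python
import numpy as np

def percentile_rank(scores, held_out):
    """PR of the held-out item: percentage of candidates whose score it
    matches or exceeds (100 = ranked first, about 50 = random)."""
    return 100.0 * np.mean(scores[held_out] >= scores)

def precision_at_k(ranks, k):
    """Percentage of test instances whose held-out item ranks in the
    top k of the model's predictions."""
    return 100.0 * np.mean(np.asarray(ranks) <= k)

scores = np.array([0.1, 0.4, 0.2, 0.3])      # model scores over 4 candidates
pr_top = percentile_rank(scores, held_out=1)  # held-out item scored highest
pr_bottom = percentile_rank(scores, held_out=0)
p_at_2 = precision_at_k([1, 5, 2], k=2)
```

MPR is then the mean of `percentile_rank` over the whole test set.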

4.3. Experimental results

We compare our GAN-Word2Vec model with Word2Vec models trained using classical loss functions, including the Noise Contrastive Estimation (NCE) loss (Gutmann and Hyvärinen, 2010; Mnih and Teh, 2012) and the Negative Sampling (NEG) loss (Mikolov et al., 2013b). We observe that we obtain better results with the Mixed Loss.

We find that pre-training both G and D with a Negative Sampling Loss leads to better predictive quality for GAN-Word2Vec.

After pre-training, we train G and D using Eq. (2) and Eq. (1), respectively. We observe that the discriminator initially benefits from adversarial sampling, and its performance on both MPR and precision@1 increases. However, after convergence, the generator ultimately provides better performance than the discriminator on both metrics. We conjecture that this may be explained by the fact that basket completion is a generative task.

From Table 1, we see that our GAN-Word2Vec model consistently provides statistically-significant improvements over the Word2Vec baseline models on both the Precision@1 and MPR metrics. The experiments also confirm our expectation that the method is more effective on larger datasets.

We also see that Word2Vec trained using Negative Sampling (W2V-NEG) is generally a stronger baseline than Word2Vec trained via NCE.

Method      Precision@1   MPR
Amazon dataset
W2V-NCE     14.80         80.15
W2V-NEG     15.40         80.20
W2V-GANs    16.30         80.50
Belgian retail dataset
W2V-NCE     29.50         87.54
W2V-NEG     34.35         88.55
W2V-GANs    35.82         89.45
Table 1. One-item basket completion task on the Amazon and Belgian retail datasets.

5. Conclusions

In this paper, we have proposed a new adversarial negative sampling algorithm suitable for models such as Word2Vec. Based on recent progress made on GANs in discrete data settings, our solution eliminates much of the complexity of implementing a generative adversarial structure for such models. In particular, our adversarial training approach can be easily applied to models that use standard sampled softmax training, where the generator and discriminator can be of the same family of models.

Regarding future work, we plan to investigate the effectiveness of this training procedure on other models. It is possible that models with more capacity than Word2Vec could benefit even more from using softmax with the adversarial negative sampling loss structure that we have proposed. Therefore, we plan to test this procedure on models such as TextCNN, RNNs, and determinantal point processes (DPPs) (Gartrell et al., 2016, 2017), which are known to be effective in modeling discrete set structures.

GANs have proven to be quite effective in conjunction with deep neural networks when applied to image generation. In this work, we have shown that adversarial training can also be applied to simpler models in discrete settings, and can bring statistically-significant improvements in predictive quality.