如何为Tensorflow中的未知单词添加新的嵌入(训练和预设测试)
nlp
python
tensorflow
5
0

我很好奇,当遇到预训练词汇表中未知的单词时,如何添加正态随机化的300维向量(元素类型= tf.float32)。我正在使用经过预先训练的GloVe词嵌入,但是在某些情况下,我意识到遇到了未知词,并且我想为这个新发现的未知词创建一个正规随机的词向量。

问题是在当前设置下,我使用tf.contrib.lookup.index_table_from_tensor根据已知的词汇将单词转换为整数。此函数可以创建新令牌,并为某些预定义数量的词汇表单词散列它们,但是我的embed将不包含此新的未知散列值的嵌入。我不确定是否可以简单地将随机嵌入添加到embed列表的末尾。

我也想以一种有效的方式做到这一点,因此预构建的tensorflow函数或涉及tensorflow函数的方法可能是最有效的。我定义了一些众所周知的特殊标记,例如句子结尾标记和默认未知标记(如空字符串)(位于索引0),但这在学习各种不同未知单词的能力上受到了限制。我目前使用tf.nn .embedding_lookup()作为最后的嵌入步骤。

我希望能够为训练数据中的每个未知单词添加新的随机300d向量,并且我还想为测试期间可能遇到的训练中未发现的任何未知标记添加预制的随机单词向量。最有效的方法是什么?

def embed_tensor(string_tensor, trainable=True):
    """    
    Convert List of strings into list of indicies then into 300d vectors
    """
    # ordered lists of vocab and corresponding (by index) 300d vector
    vocab, embed = load_pretrained_glove()

    # Set up tensorflow look up from string word to unique integer
    vocab_lookup = tf.contrib.lookup.index_table_from_tensor(
        mapping=tf.constant(vocab),
        default_value = 0)
    string_tensor = vocab_lookup.lookup(string_tensor)

    # define the word embedding 
    embedding_init = tf.Variable(tf.constant(np.asarray(embed),
                                 dtype=tf.float32),
                                 trainable=trainable,
                                 name="embed_init")

    # return the word embedded version of the sentence (300d vectors/word)
    return tf.nn.embedding_lookup(embedding_init, string_tensor)
参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

我从没有尝试过,但是我可以尝试提供一种使用您代码的相同机器的可能方法,但是以后会再考虑。

index_table_from_tensor方法接受一个num_oov_buckets参数,该参数将您所有的oov单词混洗到预定义数量的存储桶中。

如果将此参数设置为某个“足够大”的值,您将看到数据在这些存储区之间传播(每个存储区的ID>最后一个语音中单词的ID)。

所以,

  • 如果(在每次查找时)将embedding_init变量的最后一行(与存储桶相对应)设置(即assign )为随机值
  • 如果将num_oov_buckets得足够大,以使冲突最小化

您可以以非常有效的方式获得与您所要求的行为(近似)的行为。

可以通过类似于哈希表的理论来证明随机行为的合理性:如果存储桶的数量足够大,则字符串的哈希方法会以很高的概率将每个oov词分配给不同的存储桶(即,将相同的碰撞减到最少)桶)。由于您要为每个不同的存储桶分配不同的随机数,因此您可以获得(几乎)每个oov单词的映射。

收藏
评论

下面的代码示例适应您的embed_tensor函数,以便将单词嵌入如下:

  • 对于具有预训练的嵌入的单词,使用预训练的嵌入来初始化嵌入。如果trainableFalse则可以在训练期间将嵌入保持固定。
  • 对于训练数据中没有预训练嵌入的单词,将随机初始化嵌入。如果trainableFalse则可以在训练期间将嵌入保持固定。
  • 对于测试数据中未出现且未进行预训练嵌入的单词,将使用单个随机初始化的嵌入向量。此向量无法训练。
import tensorflow as tf
import numpy as np

EMB_DIM = 300
def load_pretrained_glove():
    return ["a", "cat", "sat", "on", "the", "mat"], np.random.rand(6, EMB_DIM)

def get_train_vocab():
    return ["a", "dog", "sat", "on", "the", "mat"]

def embed_tensor(string_tensor, trainable=True):
  """
  Convert List of strings into list of indices then into 300d vectors
  """
  # ordered lists of vocab and corresponding (by index) 300d vector
  pretrained_vocab, pretrained_embs = load_pretrained_glove()
  train_vocab = get_train_vocab()
  only_in_train = list(set(train_vocab) - set(pretrained_vocab))
  vocab = pretrained_vocab + only_in_train

  # Set up tensorflow look up from string word to unique integer
  vocab_lookup = tf.contrib.lookup.index_table_from_tensor(
    mapping=tf.constant(vocab),
    default_value=len(vocab))
  string_tensor = vocab_lookup.lookup(string_tensor)

  # define the word embedding
  pretrained_embs = tf.get_variable(
      name="embs_pretrained",
      initializer=tf.constant_initializer(np.asarray(pretrained_embs), dtype=tf.float32),
      shape=pretrained_embs.shape,
      trainable=trainable)
  train_embeddings = tf.get_variable(
      name="embs_only_in_train",
      shape=[len(only_in_train), EMB_DIM],
      initializer=tf.random_uniform_initializer(-0.04, 0.04),
      trainable=trainable)
  unk_embedding = tf.get_variable(
      name="unk_embedding",
      shape=[1, EMB_DIM],
      initializer=tf.random_uniform_initializer(-0.04, 0.04),
      trainable=False)

  embeddings = tf.concat([pretrained_embs, train_embeddings, unk_embedding], axis=0)

  return tf.nn.embedding_lookup(embeddings, string_tensor)

仅供参考,为了对不在训练数据中且没有预先训练的单词提供明智的,非随机的表示,您可以考虑将训练数据中频率较低的单词映射到无效令牌 (不在您的词汇范围之内),并使unk_embedding可训练。这样,您将为训练数据中看不到的单词学习原型。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号