Sobes.tech
Back to questions
Middle — Senior
7

Что такое weight tying между embedding и output projection?

Companies where asked
SberDevices

Answer from AI

sobes.tech AI

Что хотят услышать интервьюеры:

Weight tying — это когда матрица эмбеддингов входных токенов и матрица выходной проекции в словарь используют одни и те же веса. Это уменьшает число параметров и часто улучшает качество, потому что модель учится в одном общем пространстве представлений. В трансформерах и языковых моделях это распространённый приём, особенно при большом словаре.

Определение:

Weight tying между embedding и output projection — это связывание весов так, что один и тот же параметр используется и для преобразования токена в вектор, и для вычисления логитов по словарю на выходе.

Обычно:

  • Embedding переводит индекс токена в dense-вектор.
  • Output projection преобразует hidden state модели в размер словаря.
  • При weight tying матрица выходного слоя берётся как транспонированная или эквивалентно связанная версия матрицы embedding.

Идея в том, что входное и выходное представление токенов должны жить в согласованном пространстве.

Пример использования:

В языковой модели для предсказания следующего слова:

  1. Токен cat берётся из словаря и превращается в embedding.
  2. Этот embedding проходит через Transformer.
  3. Последний hidden state умножается на ту же матрицу, которая использовалась для embedding, чтобы получить logits по всему словарю.
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4)
        self.output = nn.Linear(d_model, vocab_size, bias=False)

        # weight tying
        self.output.weight = self.embedding.weight

    def forward(self, x):
        # x: [batch, seq]
        emb = self.embedding(x)              # [batch, seq, d_model]
        h = self.transformer(emb.transpose(0, 1)).transpose(0, 1)
        logits = self.output(h)              # [batch, seq, vocab_size]
        return logits

Пояснение кода:

В этом примере:

  • self.embedding.weight хранит матрицу размера [vocab_size, d_model].
  • self.output обычно тоже имеет вес размера [vocab_size, d_model].
  • Строка self.output.weight = self.embedding.weight делает веса общими.

Что происходит по шагам:

  1. Входной токен получает embedding из общей матрицы.
  2. Модель строит скрытое представление токена.
  3. На выходе это представление сравнивается со всеми токенами словаря через ту же матрицу весов.
  4. В результате модель не учит две независимые матрицы, а оптимизирует одну общую.

Важно: на практике иногда дополнительно используют отдельный bias в выходном слое, но сами веса проекции остаются связанными.

Ключевые моменты:

  • Weight tying снижает число параметров модели.
  • Часто улучшает обобщение и стабильность обучения.
  • Особенно полезен в NLP-моделях с большим словарём.
  • Входное embedding и выходная проекция должны иметь совместимые размерности.
  • Это не просто оптимизация памяти, а ещё и способ навязать полезную индуктивную связь между входом и выходом.
  • В простейшем виде выходной слой использует ту же матрицу, что и embedding, иногда с транспонированием.