Sobes.tech
Назад к вопросам
Middle — Senior
10

Как файнтюнить BERT под задачу классификации? Что брать из выхода?

Компании, где спрашивали
ПАО Газпром нефть

Ответ от нейросети

sobes.tech AI

Что хотят услышать интервьюеры:

Для классификации BERT обычно дообучают end-to-end, добавляя поверх выходов модели небольшой классификационный head. Для задачи из одного текста чаще всего берут pooled representation специального токена [CLS] и подают его в линейный слой. Если нужен более качественный результат, важно правильно подготовить данные, выбрать loss и контролировать переобучение.

Определение:

Fine-tuning BERT для классификации — это дообучение предобученной языковой модели на размеченной выборке под конкретные классы. Обычно к BERT добавляют head: dropout + linear layer, а обучение идет по целевой функции классификации, например cross-entropy.

Что брать из выхода BERT:

  • для single-sentence / sentence-pair classification чаще всего используют embedding токена [CLS] как представление всего входа;
  • иногда вместо [CLS] усредняют hidden states всех токенов, если это дает лучший результат на конкретной задаче;
  • для token classification берут не [CLS], а выходы по каждому токену.

Пример использования:

Допустим, нужно классифицировать отзывы как позитивные или негативные. Берется предобученный BERT, на выходе для [CLS] ставится линейный слой на 2 класса, и вся модель дообучается на размеченных отзывах.

import torch
import torch.nn as nn
from transformers import BertModel

class BertForClassification(nn.Module):
    def __init__(self, model_name="bert-base-uncased", num_classes=2, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # [CLS]
        x = self.dropout(cls_embedding)
        logits = self.classifier(x)
        return logits

В этом примере outputs.last_hidden_state[:, 0, :] — это вектор первого токена [CLS], который используют как агрегированное представление всей последовательности.

Пояснение кода:

  1. Загружается предобученный BertModel.
  2. Сверху добавляется dropout для регуляризации и linear layer для предсказания класса.
  3. На вход подаются input_ids и attention_mask.
  4. Из выхода BERT берется первый токен последовательности, то есть [CLS].
  5. Этот вектор проходит через dropout и классификатор.
  6. Во время обучения обычно считаются logits, а затем применяется CrossEntropyLoss.

Ключевые моменты:

  • Для классификации BERT чаще всего дообучают целиком, а не только head.
  • Самый распространенный вариант входа в классификатор — вектор [CLS].
  • Альтернатива [CLS] — pooling по токенам, если на задаче это дает лучший quality.
  • Для нескольких классов обычно используют CrossEntropyLoss, для бинарной классификации — тоже часто через 2 logits или через 1 logit с sigmoid.
  • Важно использовать attention_mask, чтобы padding не влиял на представление.
  • При fine-tuning полезны маленький learning rate, dropout и ранняя остановка, чтобы не переобучиться.