Как файнтюнить BERT под задачу классификации? Что брать из выхода?
Answer from AI
sobes.tech AI
Что хотят услышать интервьюеры:
Для классификации BERT обычно дообучают end-to-end, добавляя поверх выходов модели небольшой классификационный head. Для задачи из одного текста чаще всего берут pooled representation специального токена [CLS] и подают его в линейный слой. Если нужен более качественный результат, важно правильно подготовить данные, выбрать loss и контролировать переобучение.
Определение:
Fine-tuning BERT для классификации — это дообучение предобученной языковой модели на размеченной выборке под конкретные классы. Обычно к BERT добавляют head: dropout + linear layer, а обучение идет по целевой функции классификации, например cross-entropy.
Что брать из выхода BERT:
- для single-sentence / sentence-pair classification чаще всего используют embedding токена
[CLS]как представление всего входа; - иногда вместо
[CLS]усредняют hidden states всех токенов, если это дает лучший результат на конкретной задаче; - для token classification берут не
[CLS], а выходы по каждому токену.
Пример использования:
Допустим, нужно классифицировать отзывы как позитивные или негативные. Берется предобученный BERT, на выходе для [CLS] ставится линейный слой на 2 класса, и вся модель дообучается на размеченных отзывах.
import torch
import torch.nn as nn
from transformers import BertModel
class BertForClassification(nn.Module):
def __init__(self, model_name="bert-base-uncased", num_classes=2, dropout=0.1):
super().__init__()
self.bert = BertModel.from_pretrained(model_name)
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0, :] # [CLS]
x = self.dropout(cls_embedding)
logits = self.classifier(x)
return logits
В этом примере outputs.last_hidden_state[:, 0, :] — это вектор первого токена [CLS], который используют как агрегированное представление всей последовательности.
Пояснение кода:
- Загружается предобученный
BertModel. - Сверху добавляется dropout для регуляризации и linear layer для предсказания класса.
- На вход подаются
input_idsиattention_mask. - Из выхода BERT берется первый токен последовательности, то есть
[CLS]. - Этот вектор проходит через dropout и классификатор.
- Во время обучения обычно считаются
logits, а затем применяетсяCrossEntropyLoss.
Ключевые моменты:
- Для классификации BERT чаще всего дообучают целиком, а не только head.
- Самый распространенный вариант входа в классификатор — вектор
[CLS]. - Альтернатива
[CLS]— pooling по токенам, если на задаче это дает лучший quality. - Для нескольких классов обычно используют
CrossEntropyLoss, для бинарной классификации — тоже часто через 2 logits или через 1 logit с sigmoid. - Важно использовать
attention_mask, чтобы padding не влиял на представление. - При fine-tuning полезны маленький learning rate, dropout и ранняя остановка, чтобы не переобучиться.