Что будет, если не вызвать model.eval() на этапе валидации?
Answer from AI
sobes.tech AI
Что хотят услышать интервьюеры:
Если не вызвать model.eval(), модель может вести себя как на обучении, а не как на инференсе. В PyTorch это особенно критично для Dropout и BatchNorm. В результате метрики валидации могут стать нестабильными или завышенными/заниженными, и сравнивать их с реальным качеством будет некорректно.
Определение:
model.eval() переводит модель в режим оценки, в котором слои с разным поведением на train и eval работают как для инференса. Само обучение при этом не прекращается, но меняется поведение некоторых модулей, например Dropout отключается, а BatchNorm использует накопленную статистику вместо статистики текущего батча. Если этот режим не включить, валидация фактически пройдет в train-режиме.
Пример использования:
Типичный сценарий — оценка качества модели на валидационной выборке после каждой эпохи.
Если забыть model.eval(), то Dropout будет случайно занулять активации, а BatchNorm — обновлять свои статистики на валидационных данных, что искажает результат.
import torch
model.train()
for x_train, y_train in train_loader:
optimizer.zero_grad()
pred = model(x_train)
loss = criterion(pred, y_train)
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
total_loss = 0.0
for x_val, y_val in val_loader:
pred = model(x_val)
loss = criterion(pred, y_val)
total_loss += loss.item()
Пояснение кода:
model.train()включает режим обучения.- На обучении выполняются обычные шаги: forward, вычисление loss,
backward(),optimizer.step(). model.eval()переключает модель в режим валидации/инференса.with torch.no_grad()отключает вычисление градиентов, чтобы ускорить расчет и снизить потребление памяти.- Валидационный проход идет без обновления весов и без изменения статистик модели.
Если model.eval() убрать, то:
Dropoutпродолжит работать как на обучении и добавит случайность в предсказания.BatchNormможет использовать статистику текущего батча и обновлять running mean/var.- Метрика на валидации станет менее воспроизводимой и может заметно отличаться от реального качества при деплое.
Ключевые моменты:
model.eval()нужен для переключения слоев в режим инференса, а не для остановки обучения.- Самый заметный эффект — у
DropoutиBatchNorm. - Без
eval()валидационные метрики могут быть шумными и некорректными. model.eval()обычно используют вместе сtorch.no_grad(), но это разные вещи.- После валидации, если обучение продолжается, модель нужно вернуть в
model.train().