Sobes.tech
Back to questions
Middle
11

Что будет, если не вызвать model.eval() на этапе валидации?

Companies where asked
Магнит

Answer from AI

sobes.tech AI

Что хотят услышать интервьюеры:

Если не вызвать model.eval(), модель может вести себя как на обучении, а не как на инференсе. В PyTorch это особенно критично для Dropout и BatchNorm. В результате метрики валидации могут стать нестабильными или завышенными/заниженными, и сравнивать их с реальным качеством будет некорректно.

Определение:

model.eval() переводит модель в режим оценки, в котором слои с разным поведением на train и eval работают как для инференса. Само обучение при этом не прекращается, но меняется поведение некоторых модулей, например Dropout отключается, а BatchNorm использует накопленную статистику вместо статистики текущего батча. Если этот режим не включить, валидация фактически пройдет в train-режиме.

Пример использования:

Типичный сценарий — оценка качества модели на валидационной выборке после каждой эпохи.
Если забыть model.eval(), то Dropout будет случайно занулять активации, а BatchNorm — обновлять свои статистики на валидационных данных, что искажает результат.

import torch

model.train()
for x_train, y_train in train_loader:
    optimizer.zero_grad()
    pred = model(x_train)
    loss = criterion(pred, y_train)
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    total_loss = 0.0
    for x_val, y_val in val_loader:
        pred = model(x_val)
        loss = criterion(pred, y_val)
        total_loss += loss.item()

Пояснение кода:

  • model.train() включает режим обучения.
  • На обучении выполняются обычные шаги: forward, вычисление loss, backward(), optimizer.step().
  • model.eval() переключает модель в режим валидации/инференса.
  • with torch.no_grad() отключает вычисление градиентов, чтобы ускорить расчет и снизить потребление памяти.
  • Валидационный проход идет без обновления весов и без изменения статистик модели.

Если model.eval() убрать, то:

  • Dropout продолжит работать как на обучении и добавит случайность в предсказания.
  • BatchNorm может использовать статистику текущего батча и обновлять running mean/var.
  • Метрика на валидации станет менее воспроизводимой и может заметно отличаться от реального качества при деплое.

Ключевые моменты:

  • model.eval() нужен для переключения слоев в режим инференса, а не для остановки обучения.
  • Самый заметный эффект — у Dropout и BatchNorm.
  • Без eval() валидационные метрики могут быть шумными и некорректными.
  • model.eval() обычно используют вместе с torch.no_grad(), но это разные вещи.
  • После валидации, если обучение продолжается, модель нужно вернуть в model.train().