Обучение на реальных данных
В реальных задачах, особенно если вы работаете над новой проблемой, вы столкнетесь с широким спектром трудностей. Приведем часть из них и сопроводим примером на основе задач нахождения клеток крови на фотографии мазка крови.
нехватка данных — фотографий мазков крови может быть недостаточно для построения сложной модели с нуля
недостаток размеченных данных — возможно, существует достаточно большое количество фотографий мазков крови (например, в историях болезни), но очень малая часть из них размечена
некачественная разметка — мазок крови могли доверить анализировать студенту-практиканту. Размечать его мог вообще человек не из профессии — например, хотевший таким образом увеличить обучающую выборку для модели на конкурс Kaggle. Даже в широко известных MNIST, CIFAR-10 и ImageNet есть ошибки в разметке (примеры)
Source: Label Errors in ML Test Sets
Серповидная клеточная анемия приводит к аномальным эритроцитам
Source: Wikipedia
А так выглядит мазок крови при сонной болезни
Source: Wikipedia
Модель учится сопоставлять целевые значения признакам. В такой ситуации модель не в состоянии делать адекватные предсказания на тесте, так как во время обучения она не видела области пространства, в которой расположены тестовые объекты. Источники ошибок, приводящих к ковариантному сдвигу, обсуждались ранее в лекции №7).
Практический совет: для быстрого обнаружения ковариантного сдвига можно обучить модель, которая будет предсказывать, относится ли объект к train или test выборке. Если модель легко делит данные, то имеет смысл визуализировать значения признаков, по которым она это делает.
полные дубликаты — в данных могут быть полные дубликаты. Кто-то до вас агрегировал фотографии из разных источников, и либо вы не обратили на это внимание, либо он забыл об этом сказать. Такие данные надо сразу помечать и использовать только после предварительного размышления, т.к. они могут мешать вам и на этапе обучения модели, и на итоговой валидации ее качества (если один и тот же объект попадет и в обучение, и в валидацию).
неполные дубликаты — в данных могут быть данные от одного и того же пациента. Кажется, что если это разные мазки крови, то всё нормально. На самом деле и это уменьшает количество информации, которую может извлечь нейросеть из данных. С такими данными также нужно аккуратно работать и не допускать попадания одного пациента и в обучение, и в тест.
малое число источников данных — проблема, родственная предыдущей. В вашем датасете могут быть данные только от одного микроскопа или одной модели микроскопа. Могут быть данные, снятые только одним специалистом, или в одной больнице, или только у взрослых (фотографий мазков детей нет). Это также может влиять на способность вашего алгоритма обобщать полученное решение и требует пристального внимания.
Все это приводит к целому спектру проблем, из которых самой типичной будет переобучение модели — какую бы простую модель вы не взяли, она все равно будет выучивать искажения вашего датасета.
"Если у вас мало данных, попробуйте найти еще данные для вашей задачи" — совет приводится во многих инструкциях по борьбе с малым количеством данных и может быть воспринят с юмором. Однако часто для вашей задачи действительно существуют данные, собранные другими людьми. Также часто можно найти данные, которые очень похожи на ваши, и их можно использовать в обучении, но, например, учитывать с меньшим весом. Даже 20 дополнительных примеров могут сильно облегчить ситуацию.
Вы также можете использовать данные, которые не совсем похожи на ваши, в качестве внешней валидации. Тем самым вы можете разбивать свой изначальный датасет только для кросс-валидации и не выделять отдельную часть для теста. Тестом послужат как раз "не совсем" похожие данные.
Также можно использовать две техники, которые мы сегодня рассмотрим подробнее: Аугментация и Transfer Learning.
Прежде всего надо убедиться, что датасет сбалансирован:
import torch
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
def show_class_balance(y, classes):
_, counts = torch.unique(torch.tensor(y), return_counts=True)
plt.bar(classes, counts)
plt.ylabel("n_samples")
plt.ylim([0, 75])
plt.show()
wine = load_wine()
classes = wine.target_names
show_class_balance(wine.target, classes)
Разница в 10–20% будет незначительна, поэтому для наглядности мы искусственно разбалансируем наш датасет при помощи метода make_imbalance
from imblearn.datasets import make_imbalance
x, y = make_imbalance(
wine.data, wine.target, sampling_strategy={0: 10, 1: 70, 2: 40}, random_state=42
)
show_class_balance(y, classes)
Если в данных недостаток именно конкретного класса, то можно бороться с этим при помощи разных способов сэмплирования.
Важно понимать, что в большинстве случаев данные, полученные таким способом, должны использоваться в качестве обучающего набора, но ни в коем случае не в качестве валидации или теста.
Мы можем увеличить число объектов меньшего класса за счет дублирования.
В этом случае наша модель будет "вынуждена" обращать внимание на минорный класс.
Такой Resampling
может быть выполнен с помощью пакета imbalanced-learn, как показано ниже:
from imblearn.over_sampling import RandomOverSampler
ros = RandomOverSampler(random_state=0)
x_ros, y_ros = ros.fit_resample(x, y)
show_class_balance(y_ros, classes)
Аналогично, можно взять для обучения не всех представителей большего класса.
Это также вынуждает модель обращать внимание на оба класса. Минус подхода очевиден: мы можем выбросить важных представителей большего класса, ответственных за существенное улучшение генерализации, и из-за этого качество модели существенно ухудшится. Можно пытаться выбрасывать объекты большего класса как-то по-умному. Например, кластеризовать объекты большего класса и брать по заданному количеству объектов из каждого класса.
from imblearn.under_sampling import RandomUnderSampler
rus = RandomUnderSampler(random_state=42)
x_res, y_res = rus.fit_resample(x, y)
show_class_balance(y_res, classes)
Можно использовать ансамбли вместе с undersampling. В этом случае мы можем, к примеру, делать сэмплирование только большего класса, а объекты минорного класса оставлять как есть.
Или просто сэмплировать объекты и того, и другого класса в равном количестве.
В случае нейросетей можно балансировать встречаемость каждого класса не на уровне датасета, а на уровне батча. Например, собирать каждый батч таким образом, чтобы в нем было поровну всех классов.
Это может улучшать сходимость даже в случае небольшого дисбаланса или его отсутствия, т.к. мы будем избегать шаги обучения нейросети, в которых она просто не увидела какого-то класса в силу случайных причин.
В PyTorch эту функциональность можно получить, используя класс WeightedRandomSampler . Для его инициализации требуется рассчитать вес каждого класса. Сумма весов не обязана быть равной единице.
# https://pytorch.org/docs/stable/generated/torch.unique.html
_, counts = torch.unique(torch.tensor(y), return_counts=True)
weights = counts.max() / counts
print("Classes: ", classes)
print("Weights: ", weights)
Classes: ['class_0' 'class_1' 'class_2'] Weights: tensor([7.0000, 1.0000, 1.7500])
Теперь создаем объект WeightedRandomSampler, в конструктор подаем два аргумента:
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
tensor_x = torch.Tensor(x) # transform to torch tensor
tensor_y = torch.Tensor(y)
dataset = TensorDataset(tensor_x, tensor_y)
batch_size = 8
weight_for_sample = [] # Every sample must have a weight
for l in y:
weight_for_sample.append(weights[l].item())
sampler = WeightedRandomSampler(torch.tensor(weight_for_sample), len(dataset))
loader = DataLoader(dataset, batch_size=32, drop_last=True, sampler=sampler)
Посмотрим на распределение элементов разных классов по батчам.
batch_labels = []
for data, labels in loader:
print(
"Labels:",
labels.int().tolist(),
"Classes in batch:",
torch.unique(labels, return_counts=True)[1].tolist(),
)
batch_labels.append(labels.tolist())
show_class_balance(batch_labels, classes)
Labels: [1, 0, 2, 2, 1, 2, 0, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 0, 1, 2, 0, 2, 1, 2, 2, 1, 0, 0, 0, 1, 2] Classes in batch: [7, 12, 13] Labels: [1, 0, 1, 0, 0, 0, 1, 0, 2, 0, 0, 1, 1, 1, 0, 1, 1, 0, 2, 0, 1, 0, 0, 2, 2, 2, 0, 1, 1, 1, 1, 1] Classes in batch: [13, 14, 5] Labels: [0, 1, 1, 2, 1, 0, 2, 1, 0, 2, 1, 1, 1, 2, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 1, 1, 2, 1, 0, 0, 2] Classes in batch: [7, 14, 11]
Результаты будут отличаться от запуска к запуску. Но видно, что в батчах объекты каждого класса встречаются почти равномерно.
Стоит отметить, что нужно быть осторожным со взвешиванием объектов в батчах и контролировать состав батчей. Дело в том, что при существенном дисбалансе веса при объектах минорного класса могут оказываться на несколько порядков больше, чем при объектах мажорного класса. Данные веса преобразуются в вероятности для сэмплирования, и может случиться так, что вероятности при объектах мажорного класса станут численно неотличимы от нуля. Тем самым можно получить обратный эффект: батчи будут состоять исключительно из объектов минорного класса. В таком случае нужно намеренно ограничивать веса.
Другой подход к решению этой проблемы — создание синтетических данных. Делать это можно по-разному.
Synthetic Minority Over-sampling Technique (SMOTE) позволяет генерировать синтетические данные за счет реальных объектов из минорного класса.
Алгоритм работает следующим образом:
Число соседей, как и число раз, которое мы запускаем описанную выше процедуру, можно регулировать.
from imblearn.over_sampling import SMOTE
oversample = SMOTE()
x_smote, y_smote = oversample.fit_resample(x, y)
show_class_balance(y_smote, classes)
Количество объектов каждого класса, которое должно получиться после генерации, можно задать явно:
over = SMOTE(sampling_strategy={0: 20, 1: 70, 2: 70})
x_smote, y_smote = over.fit_resample(x, y)
show_class_balance(y_smote, classes)
Подробнее про использование пакета можно прочесть в статье: SMOTE for Imbalanced Classification with Python
Модели в машинном обучении “ленивы”. При работе с несбалансированными классами модель будет чаще сталкиваться с доминирующим классом и вместо того, чтобы разбираться в признаках объектов, может начать ориентироваться на статистическое распределение классов.
Пример: датасет, в котором 95% объектов относятся к классу 1 и 5% к классу 0. Модель может выучиться всегда относить объекты к классу 1, и в 95% случаях она будет права.
Чтобы это поправить, можно изменить функцию потерь так, чтобы больше штрафовать модель за ошибки в минорных классах. Для этого можно:
weight
, который имеет по умолчанию значение None
. В него можно передать тензор весов, соответствующий размеру вектора целевых значений, и получить взвешенную функцию ошибок.Посмотрим, как это работает. Допустим, мы получили от нейросети неверные предсказания:
второй объект должен относиться к классу 1, а не 0
scores = torch.tensor([[30.0, 2.0], [30.0, 2.0]]) # Scores for batch of two samples
target = torch.tensor([0, 1]) # Second sample belongs to class 1
# but logit for class 0 greater: 30 > 2. So it was misclassified
Подсчитаем Cross-Entropy Loss без весов:
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(scores, target)
print(f"Loss = {loss.item():.2f}")
Loss = 14.00
Если у нас есть два класса с соотношением 4:1, можно задать веса weight = [0.2, 0.8]
. И, так как сеть ошиблась на классе с большим весом, ошибка вырастет:
weights = torch.tensor([0.2, 0.8], dtype=torch.float32)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
loss = criterion(scores, target)
print(f"Loss = {loss.item():.2f}")
Loss = 22.40
Сумма весов может быть не равна единице:
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor([2.0, 8.0]))
loss = criterion(scores, target)
print(f"Loss = {loss.item():.2f}")
Loss = 22.40
Иногда качество модели можно улучшить, взяв квадратные корни от полученных таким образом весов (немного снижает штрафы за ошибки на редких классах).
Несмотря на интуитивно понятную логику работы способа, он не всегда дает значительный эффект. Тем не менее, на практике стоит пробовать экспериментировать с этим способом наряду с прочими техниками борьбы с дисбалансом.
Можно менять форму функции потерь. В 2017 году для работы с несбалансированными классами был предложен Focal Loss. Это — модификация Cross-Entropy Loss, которая модифицирует ее форму для различных классов. С тех пор появились различные модификации функции ошибок, с которыми можно и нужно экспериментировать.
Focal Loss — это функция потерь, используемая в нейронных сетях для решения проблемы классификации сложных объектов (hard examples).
Приведем пример: пусть модель должна классифицировать фрукты на два класса: яблоки и груши. В наборе данных есть много явных представителей того и иного класса: зеленые яблоки и желтые груши. Для модели они будут простыми: она на них не ошибается, и предсказываемая вероятность истинного класса для них велика и равна $0.9$.
В то же время в наборе данных есть малое количество зеленых груш: они по форме похожи на груши, а по цвету — на яблоки. Говоря более общими словами, зеленые груши ближе к границе классов в пространстве признаков. Для модели такие примеры могут оказаться сложными, и она будет на них ошибаться. Пусть для зеленой груши вероятность истинного класса равна $0.2$, то есть модель ошиблась и предсказала класс "яблоко" с вероятностью $0.8$.
Проблема состоит в том, что сумма большого количества малых ошибок на простых объектах может перевешивать сумму малого количества ошибок потерь на сложных объектах. Поэтому модель будет плохо учиться верно классифицировать сложные объекты: ей будет "лень" исправлять незначительные ошибки.
Для решения этой проблемы была предложена специальная функция потерь — Focal Loss. Она немного модифицирует кросс-энтропию для придания большей значимости ошибкам на сложных объектах.
Focal Loss была предложена в статье Focal Loss for Dense Object Detection (Lin et al., 2017) изначально для задачи детектирования объектов на изображениях. Она определяется так:
$$\large\text{FL}(p_t) = -(1 - p_t)^\gamma\text{log}(p_t)$$Здесь $p_t$ — предсказанная вероятность истинного класса, а $\gamma$ — настраиваемый гиперпараметр.
Focal Loss уменьшает потери на уверенно классифицируемых примерах (где $p_t>0.5$) и больше фокусируется на сложных примерах, которые классифицированы неправильно. Параметр $\gamma$ управляет относительной важностью неправильно классифицируемых примеров. Более высокое значение $\gamma$ увеличивает важность неправильно классифицированных примеров. В экспериментах авторы показали, что параметр $\gamma=2$ показывал себя наилучшим образом в их задаче.
При $\gamma=0$ Focal Loss становится равной Cross-Entropy Loss, которая выражается как обратный логарифм вероятности истинного класса:
$$\large\text{CE}(p_t)=-\text{log}(p_t)$$Разберем на нашем примере с яблоками и грушами.
Мы имеем 20 простых объектов с вероятностью истинного класса $0.9$: 10 яблок и 10 груш, и один сложный объект с вероятностью истинного класса $0.2$.
$\large{\text{CE} = \overbrace{\sum^{20}-\text{log}(0.9)}^{\large\color{#3C8031}{\text{loss(easy apples and pears)=2.11}}} + \overbrace{(-\text{log}(0.2))}^{\large\color{#F26035}{\text{loss(hard pear)=1.61}}} \approx 3.72}$
$\large{\text{FL}(\gamma=2) = \overbrace{\sum^{20}-\color{#AF3235}{\underbrace{(1-0.9)^2}_{0.01}}\text{log}(0.9)}^{\large\color{#3C8031}{\text{loss(easy apples and pears)=0.02}}} + \overbrace{(-\color{#AF3235}{\underbrace{(1-0.2)^2}_{0.64}}\text{log}(0.2))}^{\large\color{#F26035}{\text{loss(hard pear)=1.03}}} \approx 1.05}$
Фактически, потери для уверенно классифицированных объектов дополнительно занижаются.
Этот эффект достигается путем домножения на коэффициент: $ \large(1-p_{t})^\gamma$
Пока модель ошибается, $p_{t}$ — мала, и значение выражения в скобках соответственно близко к 1.
Когда модель обучилась, значение $p_{t}$ становится близким к 1, а разность в скобках становится маленьким числом, которое возводится в степень $ \gamma \ge 0 $. Таким образом, домножение на это небольшое число нивелирует вклад верно классифицированных объектов.
Это позволяет модели сосредоточиться (сфокусироваться, отсюда и название) на изучении сложных объектов (hard examples).
В примере коэффициент $(1-p_t)^\gamma$ в Focal Loss в 100 раз занизил потери при уверенной классификации простых яблок и груш, и потери при неверной классификации сложной груши стали преобладать.
Давайте посчитаем для различных значений $γ$, сколько понадобится примеров с небольшой ошибкой (высокой вероятностью истинного класса, равной $0.9$), чтобы получить суммарный Focal Loss примерно такой же, как у одного примера с большой ошибкой (низкой вероятностью истинного класса, равной $0.2$).
import numpy as np
def cross_entropy(prob_true):
return -np.log(prob_true)
def focal_loss(prob_true, gamma=2):
return (1 - prob_true) ** gamma * cross_entropy(prob_true)
p1 = 0.9 # probability of easy examples predictions
p2 = 0.2 # probability of hard examples predictions
gammas = [0, 0.5, 1, 2, 5, 10, 15]
print(
f"For probability of easy examples predictions {p1} and probability of hard examples predictions {p2}\n"
)
for gamma in gammas:
fl1 = focal_loss(p1, gamma)
fl2 = focal_loss(p2, gamma)
print(
f"gamma = {gamma},".ljust(15),
f"for an equal loss with a problematic prediction, almost correct ones are required {int(fl2 / fl1)}",
)
For probability of easy examples predictions 0.9 and probability of hard examples predictions 0.2 gamma = 0, for an equal loss with a problematic prediction, almost correct ones are required 15 gamma = 0.5, for an equal loss with a problematic prediction, almost correct ones are required 43 gamma = 1, for an equal loss with a problematic prediction, almost correct ones are required 122 gamma = 2, for an equal loss with a problematic prediction, almost correct ones are required 977 gamma = 5, for an equal loss with a problematic prediction, almost correct ones are required 500548 gamma = 10, for an equal loss with a problematic prediction, almost correct ones are required 16401977428 gamma = 15, for an equal loss with a problematic prediction, almost correct ones are required 537459996388583
Как видно, при увеличении значения $\gamma$ можно достичь значительного роста "важности" примеров с высокой ошибкой, что по сути позволяет модели обращать внимание на "hard examples".
Этот пример также показывает опасность Focal Loss: если мы имеем ошибки в разметке, то при большом $\gamma$ можно начать очень сильно наказывать модель за ошибки на неверно размеченных примерах, что может привести к переобучению под ошибки в разметке.
Focal Loss может применяться также и в задачах с дисбалансом классов. В этом смысле объекты преобладающего класса могут считаться простыми, а объекты минорного класса — сложными.
Однако для работы с дисбалансом в Focal Loss могут быть добавлены веса для классов. Тогда формула будет выглядеть так:
$$\large\text{FL}(p_t) = -\alpha_t(1 - p_t)^\gamma\text{log}(p_t)$$Здесь $\alpha_t$ — вес для истинного класса, имеющий такой же смысл, как параметр weight
в Cross-Entropy Loss.
Focal Loss не реализована в PyTorch нативно, но существуют сторонние совместимые реализации. Посмотрим, как воспользоваться одной из них.
import random
def set_random_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)
set_random_seed(42)
#!wget https://raw.githubusercontent.com/AdeelH/pytorch-multi-class-focal-loss/master/focal_loss.py
!wget -q https://edunet.kea.su/repo/EduNet-web_dependencies/L11/focal_loss.py
import torch
from torch import nn
from focal_loss import FocalLoss
criterion = FocalLoss(alpha=None, gamma=2.0)
model_output = torch.rand(3, 3) # model output is logits, as in CELoss
print(f"model_output:\n {model_output}")
target = torch.empty(3, dtype=torch.long).random_(3)
print(f"target: {target}")
loss_fl = criterion(model_output, target)
print(f"loss_fl: {loss_fl}")
model_output: tensor([[0.8823, 0.9150, 0.3829], [0.9593, 0.3904, 0.6009], [0.2566, 0.7936, 0.9408]]) target: tensor([2, 1, 1]) loss_fl: 0.6864498257637024
Убедимся, что сторонняя реализация вычисляет то, что нужно, и вычислим значение вручную. В первую очередь нужно перевести model_output
из логитов в вероятности с помощью softmax.
probs = torch.nn.functional.softmax(model_output, dim=1)
print(f"probabilities after softmax:\n {probs}")
probabilities after softmax: tensor([[0.3788, 0.3914, 0.2299], [0.4415, 0.2500, 0.3085], [0.2131, 0.3646, 0.4224]])
Теперь вручную рассчитаем значение функции потерь.
def focal_loss(prob_true, gamma=2):
return -((1 - prob_true) ** gamma) * np.log(prob_true)
hand_calculated_loss = 0
for i in range(3):
hand_calculated_loss += focal_loss(probs[i, target[i]])
hand_calculated_loss /= 3 # average by number of samples
print(f"hand-calculated focal loss: {hand_calculated_loss.item()}")
print(f"library-calculated focal loss: {loss_fl}")
print(
f"Are results almost equal? {torch.isclose(loss_fl, hand_calculated_loss).item()}"
)
hand-calculated focal loss: 0.6864497661590576 library-calculated focal loss: 0.6864498257637024 Are results almost equal? True
Действительно, при расчете вручную получили то же значение, что и при расчете с помощью сторонней реализации.
В случае сильно несбалансированных наборов данных стоит задуматься, могут ли такие примеры рассматриваться как аномалия (выброс) или нет. Если такое событие и впрямь может считаться аномальным, мы можем использовать такие модели, как OneClassSVM
, методы кластеризации или методы обнаружения гауссовских аномалий.
Эти методы требуют изменения взгляда на задачу: мы будем рассматривать аномалии как отдельный класс выбросов. Это может помочь нам найти новые способы разделения и классификации.
Если продолжать пример с фруктами, то задача обнаружения аномалий возникла бы, если бы мы предполагали, что среди яблок и груш может вдруг возникнуть мандарин, или любой другой фрукт, и нам бы нужно было не отнести его к одному из известных классов, а пометить как отдельный, отличающийся класс.
Проблемой при работе с аномалиями является то, что аномальных значений может быть очень мало или вообще не быть. В таком случае алгоритм учит паттерны нормального поведения и реагирует на отличия от паттернов.
Разберем примеры обнаружения аномалий с помощью трех алгоритмов из библиотеки Scikit-Learn (там можно найти еще много различных алгоритмов).
Создадим датасет из двух кластеров и случайных значений.
rng = np.random.RandomState(42)
# Train
x = 0.3 * rng.randn(100, 2) # 100 2D points
x_train = np.r_[x + 2, x - 2] # split into two clusters
# Test norlmal
x = 0.3 * rng.randn(20, 2) # 20 2D points
x_test_norlmal = np.r_[x + 2, x - 2] # split into two clusters
# Test outliers
x_test_outliers = rng.uniform(low=-4, high=4, size=(20, 2))
Напишем функцию визуализации, которая будет изображать созданный датасет на рисунке слева, а результат поиска аномалий — на рисунке справа.
def plot_outliers(x_train, x_test_norlmal, x_test_outliers, model=None):
fig, (plt_data, plt_model) = plt.subplots(1, 2, figsize=(12, 6))
plt_data.set_title("Created Dataset (real labels)")
plot_train = plt_data.scatter(
x_train[:, 0], x_train[:, 1], c="white", s=40, edgecolor="k"
)
plot_test_normal = plt_data.scatter(
x_test_norlmal[:, 0], x_test_norlmal[:, 1], c="green", s=40, edgecolor="k"
)
plot_test_outliers = plt_data.scatter(
x_test_outliers[:, 0], x_test_outliers[:, 1], c="red", s=40, edgecolor="k"
)
plt_data.set_xlim((-5, 5))
plt_data.set_ylim((-5, 5))
plt_data.legend(
[plot_train, plot_test_normal, plot_test_outliers],
["train", "test normal", "test outliers"],
loc="lower right",
)
if model:
plt_model.set_title("Model Results")
# plot decision function
xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
Z = model.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt_model.contourf(xx, yy, Z, cmap=plt.cm.Blues_r)
# plot prediction
full_data = np.concatenate((x_train, x_test_norlmal, x_test_outliers), axis=0)
predicted = model.predict(full_data)
anom_index = np.where(predicted == -1)
anom_values = full_data[anom_index]
plot_all_data = plt_model.scatter(
full_data[:, 0], full_data[:, 1], c="white", s=40, edgecolor="k"
)
plot_anom_data = plt_model.scatter(
anom_values[:, 0], anom_values[:, 1], c="red", s=40, marker="x"
)
plt_model.legend(
[plot_all_data, plot_anom_data],
["normal", "outliers"],
loc="lower right",
)
plt.show()
Посмотрим, как работает на этих данных алгоритм OneClassSVM.
Идея алгоритма состоит в поиске функции, которая положительна для областей с высокой плотностью и отрицательна для областей с малой плотностью. Подробнее об алгоритме можно прочитать в оригинальной статье.
from sklearn.svm import OneClassSVM
gamma = 2.0
contamination = 0.05
model = OneClassSVM(gamma=gamma, kernel="rbf", nu=contamination)
model.fit(x_train)
plot_outliers(x_train, x_test_norlmal, x_test_outliers, model)
Посмотрим, как на этих же данных работает алгоритм IsolationForest.
IsolationForest состоит из деревьев, которые «изолируют» (пытаются отделить от остальной выборки) наблюдения, случайным образом выбирая признак и случайное значение порога для этого признака (между max и min значениями признака). Такой алгоритм чаще и проще отделяет значения аномалии. Если построить по такому принципу множество деревьев, то значения, которые чаще других отделяются раньше, будут аномалиями.
from sklearn.ensemble import IsolationForest
n_estimators = 200
contamination = 0.05
model = IsolationForest(
n_estimators=n_estimators, contamination=contamination, random_state=rng
)
model.fit(x_train)
plot_outliers(x_train, x_test_norlmal, x_test_outliers, model)
Последним алгоритмом, на который мы посмотрим, будет LocalOutlierFactor.
В нем используется метод k-NN. Расстояние до ближайших соседей используется для оценки расположения точек. Если соседи далеко, то точка с большой вероятностью является аномалией.
from sklearn.neighbors import LocalOutlierFactor
n_neighbors = 10
contamination = 0.05
model = LocalOutlierFactor(
n_neighbors=n_neighbors, novelty=True, contamination=contamination
)
model.fit(x_train)
plot_outliers(x_train, x_test_norlmal, x_test_outliers, model)
Обращайте внимание на то, какие метрики вы используете. При решении задачи классификации часто используется accuracy (точность), равная доле правильно классифицированных объектов. Эта метрика позволяет адекватно оценить результат классификации в случае сбалансированных классов. В случае дисбаланса классов данная метрика может выдать обманчиво хороший результат.
Пример: датасет, в котором 95% объектов относятся к классу 0 и 5% к классу 1.
from sklearn.datasets import make_classification
from collections import Counter
x, y = make_classification(
n_samples=1000,
n_features=2,
n_redundant=0,
n_clusters_per_class=1,
weights=[0.95],
flip_y=0,
random_state=42,
)
counter = Counter(y)
print("Class distribution ", Counter(y))
for label, _ in counter.items():
row_ix = np.where(y == label)[0]
plt.scatter(x[row_ix, 0], x[row_ix, 1], label=str(label))
plt.legend()
plt.show()
Class distribution Counter({0: 950, 1: 50})
И модель, которая для всех данных выдает класс 0,
class DummyModel:
def predict(self, x):
return np.zeros(x.shape[0]) # always predict class 0
Такая модель будет иметь $accuracy = 0.95$, хотя не выдает никакой полезной информации:
from sklearn.metrics import accuracy_score
dummy_model = DummyModel()
y_pred = dummy_model.predict(x)
accuracy = accuracy_score(y, y_pred)
print("Accuracy", accuracy)
Accuracy 0.95
Для несбалансированных данных лучше выбирать F1 Score, MCC (Matthews correlation coefficient, коэффициент корреляции Мэтьюса) или balanced accuracy (среднее между recall разных классов).
from sklearn.metrics import f1_score, matthews_corrcoef, balanced_accuracy_score
print("F1", f1_score(y, y_pred))
print("MCC", matthews_corrcoef(y, y_pred))
print("Balanced accuracy", balanced_accuracy_score(y, y_pred))
F1 0.0 MCC 0.0 Balanced accuracy 0.5
Все эти метрики оказываются равны нулю, что отражает отсутствие связи предсказаний с данными на входе модели.
Другой способ побороть маленькое количество данных для обучения — аугментация.
Аугмента́ция (от лат. augmentatio — увеличение, расширение) — увеличение выборки обучающих данных через модификацию существующих данных.
Модели глубокого обучения обычно требуют большого количества данных для обучения. В целом, чем больше данных, тем лучше для обучения модели. В то же время получение огромных объемов данных сопряжено со своими проблемами (например, с нехваткой размеченных данных или с трудозатратами, сопряженными с разметкой).
Вместо того, чтобы тратить дни на сбор данных вручную, мы можем использовать методы аугментации для автоматической генерации новых примеров из уже имеющихся.
Помимо увеличения размеченных датасетов, многие методы self-supervised learning построены на использовании разных аугментаций одного и того же сэмпла.
Важный момент: при обучении модели мы используем разбиение данных на train-val-test
. Аугментации стоит применять только на train
. Почему так? Конечная цель обучения нейросети — это применение на реальных данных, которые сеть не видела. Поэтому для адекватной оценки качества модели валидационные и тестовые данные изменять не нужно.
В любом случае, test
должен быть отделен от данных еще до того, как они попали в DataLoader
или нейросеть.
Другое дело, что аугментации на тесте можно использовать как метод ансамблирования в случае классификации. Можно взять sample → создать несколько его копий → по-разному их аугментировать → предсказать класс на каждой из этих аугментированных копий → а потом выбрать наиболее вероятный класс голосованием (такой функционал реализован, например, в YOLOv5, о которой речь пойдет в следующих лекциях).
Загрузим и отобразим пример картинки. Картинку отмасштабируем, чтобы она не занимала весь экран.
# setting random seed for reproducible illustrations
set_random_seed(42)
URL = "https://edunet.kea.su/repo/EduNet-web_dependencies/L11/capybara_image.jpg"
!wget -q $URL -O test.jpg
from IPython.display import display
from PIL import Image
from torchvision import transforms
input_img = Image.open("/content/test.jpg")
input_img = transforms.Resize(size=300)(input_img)
display(input_img)
Рассмотрим несколько примеров аугментаций картинок. С полным списком можно ознакомиться на сайте [doc] документации torchvision.
Трансформация transforms.Random Rotation
принимает параметр degrees
— диапазон углов, из которого выбирается случайный угол для поворота изображения.
Создадим переменную transform
, в которую добавим нашу аугментацию, и применим ее к исходному изображению. Затем запустим следующую ячейку несколько раз подряд
import matplotlib.pyplot as plt
def plot_augmented_img(transform, input_img):
fig, ax = plt.subplots(1, 2, figsize=(15, 15))
augmented_img = transform(input_img)
ax[0].imshow(input_img)
ax[0].set_title("Original img")
ax[0].axis("off")
ax[1].imshow(augmented_img)
ax[1].set_title("Augmented img")
ax[1].axis("off")
plt.show()
transform = transforms.RandomRotation(degrees=(0, 180))
plot_augmented_img(transform, input_img)
transforms.GaussianBlur
размывает изображение с помощью фильтра Гаусса.
transform = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
plot_augmented_img(transform, input_img)
transforms.RandomErasing
стирает на изображении произвольный прямоугольник. Она имеет параметр p
— вероятность, с которой данная трансформация вообще применится к изображению.
Данная трансформация работает только с torch.Tensor
, поэтому предварительно нужно применить трансформацию ToTensor
, а затем ToPILImage
, чтобы воспользоваться нашей функцией для отображения.
transform = transforms.Compose(
[transforms.ToTensor(), transforms.RandomErasing(p=1), transforms.ToPILImage()]
)
plot_augmented_img(transform, input_img)
Не лишним будет заметить, что некоторые трансформации могут существенно исказить изображение. Например, здесь, RandomErasing
практически полностью стерла основной объект на снимке — капибару. Такая грубая аугментация может только навредить процессу обучения, и на практике нужно быть осторожным.
RandomErasing
также имеет параметр scale
— диапазон соотношения стираемой области к входному изображению. Попробуем уменьшить этот диапазон относительно значения по умолчанию, чтобы избежать нежелательного эффекта стирания капибары.
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.RandomErasing(p=1, scale=(0.02, 0.1)),
transforms.ToPILImage(),
]
)
plot_augmented_img(transform, input_img)
transforms.ColorJitter
случайным образом меняет яркость, контрастность, насыщенность и оттенок изображения.
transform = transforms.ColorJitter(brightness=0.5, hue=0.3)
plot_augmented_img(transform, input_img)
Для этого будем использовать метод transforms.Compose
. Нам нужно будет создать list
со всеми аугментациями, которые будут применены последовательно.
transform = transforms.Compose(
[
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
transforms.RandomPerspective(distortion_scale=0.5, p=1.0),
transforms.ColorJitter(brightness=0.5, hue=0.3),
]
)
plot_augmented_img(transform, input_img)
Для того, чтобы применять аугментации случайным образом, можно воспользоваться методом transforms.RandomApply
, который на вход принимает список аугментаций и вероятность p
, с которой каждая аугментация будет применена.
transform = transforms.RandomApply(
transforms=[
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
transforms.RandomPerspective(distortion_scale=0.5),
transforms.ColorJitter(brightness=0.5, hue=0.3),
],
p=0.9,
)
plot_augmented_img(transform, input_img)
В других случаях может быть полезен метод transforms.RandomChoice
, который на вход принимает список аугментаций transforms
, выбирает из него одну случайную аугментацию и применяет ее к изображению. Необязательным параметром является список вероятностей p
, который указывает, с какой вероятностью каждая из аугментаций может быть выбрана из списка (по умолчанию каждая может быть выбрана равновероятно).
transform = transforms.RandomChoice(
transforms=[
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
transforms.RandomPerspective(distortion_scale=0.5, p=1.0),
transforms.ColorJitter(brightness=0.5, hue=0.3),
],
p=[0.2, 0.4, 0.6],
)
plot_augmented_img(transform, input_img)
plot_augmented_img(transform, input_img)
plot_augmented_img(transform, input_img)
Иногда может оказаться, что среди широкого спектра реализованных аугментаций нет такой, какую вы хотели бы применить к своим данным. В таком случае ее можно описать в виде класса и использовать наравне с реализованными в библиотеке.
Главное, что необходимо описать при создании класса — метод __call__
. Он должен принимать изображение (оно может быть представлено в формате PIL.Image
, np.array
или torch.Tensor
), делать с ним интересующие нас видоизменения и возвращать измененное изображение.
Рассмотрим пример добавления на изображение шума "соль и перец". Наш метод аугментации будет и принимать на вход, и возвращать PIL.Image
.
from PIL import Image
import numpy as np
class SaltAndPepperNoise:
"""
Add a "salt and pepper" noise to the PIL image
__call__ method returns PIL Image with noise
"""
def __init__(self, p=0.01):
self.p = p # noise level
def __call__(self, pil_image):
np_image = np.array(pil_image)
# create random mask for "salt" and "pepper" pixels
salt_ind = np.random.choice(
a=[True, False], size=np_image.shape[:2], p=[self.p, 1 - self.p]
)
pepper_ind = np.random.choice(
a=[True, False], size=np_image.shape[:2], p=[self.p, 1 - self.p]
)
# add "salt" and "pepper"
np_image[salt_ind] = 255
np_image[pepper_ind] = 0
return Image.fromarray(np_image)
transform = SaltAndPepperNoise(p=0.03)
plot_augmented_img(transform, input_img)
Dataset
¶Возьмем папку с картинками.
import os
from zipfile import ZipFile
os.chdir("/content")
# download files
!wget -q --no-check-certificate 'https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/for_transforms.Compose.zip' -O data.zip
with ZipFile(
"data.zip", "r"
) as folder: # Create a ZipFile Object and load sample.zip in it
folder.extractall() # Extract all the contents of zip file in current directory
os.chdir("/content/for_transforms.Compose")
img_list = os.listdir()
print(img_list)
['horse4.jpg', 'horse2.jpg', 'bicornis5.jpg', 'bicornis1.jpg', 'bicornis3.jpg', 'bicornis2.jpg', 'horse1.jpg', 'horse5.jpg']
Напишем класс Dataset
from torch.utils.data import Dataset
class AugmentationDataset(Dataset):
def __init__(self, img_list, transforms=None):
self.img_list = img_list
self.transforms = transforms
def __len__(self):
return len(self.img_list)
def __getitem__(self, i):
img = plt.imread(self.img_list[i])
img = Image.fromarray(img).convert("RGB")
img = np.array(img).astype(np.uint8)
if self.transforms is not None:
img = self.transforms(img)
return img
Напишем вспомогательную функцию для отображения картинок. Напомним, что в PyTorch размерность каналов идет в первом, а не в последнем измерении тензора, описывающего картинку: Channels x Height x Width
. Для отображения при помощи Matplotlib необходимо перевести массив в формат Height x Width x Channels
.
def show_img(img):
plt.figure(figsize=(40, 38))
img_np = img.numpy()
plt.imshow(np.transpose(img_np, (1, 2, 0))) # [CxHxW] -> [HxWxC] for imshow
plt.show()
Создадим list
с аугментациями, которые мы хотим применить. Чтобы загрузить аугментации в PyTorch, нам необходимо эти картинки преобразовать в тензоры. Для этого воспользуемся стандартным преобразованием transforms.ToTensor()
tensor_transform = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((164, 164)),
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
transforms.RandomPerspective(distortion_scale=0.5),
transforms.ToTensor(),
]
)
Теперь обернем все в DataLoader
и отобразим
from torch.utils.data import DataLoader
import torchvision
Augmentation_dataloader = DataLoader(
AugmentationDataset(img_list, tensor_transform), batch_size=8, shuffle=True
)
data = iter(Augmentation_dataloader)
show_img(torchvision.utils.make_grid(next(data)))
Существуют и более сложные способы аугментации. Ниже приведена пара примеров таких способов.
Mixup — это "смешение" признаков двух объектов в определенных пропорциях. Mixup можно представить с помощью простого уравнения:
$\text{New image} = \alpha * \text{image}_1 + (1-\alpha) * \text{image}_2$
В ряде случаев возможно расширение набора данных путем синтеза новых данных.
Например, при создании системы распознавания текста можно генерировать новые образцы путем набора распознаваемых фраз или символов различными шрифтами на различных фонах и с добавлением каких-либо шумов и искажений.
В ряде областей для синтеза новых образов могут создаваться 3D-модели распознаваемых объектов. Например, в работе от Microsoft Fake It Till You Make It: Face analysis in the wild using synthetic data alone анализ лиц людей производился на синтетических 3D-моделях лиц. Датасет доступен на GitHub.
Также созданием новых образов, похожих на имеющиеся в датасете, можно заниматься при помощи генеративных моделей. Примером генеративных моделей является GAN (Generative Adversarial Network). Мы познакомимся с такими моделями в одной из следующих лекций.
Кроме методов, реализованных в PyTorch, существуют и специализированные библиотеки для аугментации изображений, в которых реализованы дополнительные возможности (например, наложение теней, бликов или пятен воды на изображение).
Например:
Важно: при выборе методов аугментации имеет смысл использовать только те, которые будут в реальной жизни.
Например, нет смысла:
Рассмотрим несколько примеров аугментаций аудио. С полным списком можно ознакомиться здесь: [git] audiomentations.
Импортируем библиотеку и посмотрим на пример
os.chdir("/content")
!pip install -q audiomentations
!wget -q https://edunet.kea.su/repo/EduNet-web_dependencies/L11/audio_example.wav
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.6/70.6 kB 5.9 MB/s eta 0:00:00
from IPython.display import Audio
# Get input audio
input_audio = "/content/audio_example.wav"
display(Audio(input_audio))
import librosa
data, sr = librosa.load("/content/audio_example.wav") # sr - sampling rate
from audiomentations import AddGaussianSNR
augment = AddGaussianSNR(min_snr_in_db=3, max_snr_in_db=7, p=1)
# Augment/transform the audio data
augmented_data = augment(samples=data, sample_rate=sr)
display(Audio(augmented_data, rate=sr))
Сравним волновые картины и спектрограммы
from scipy.signal import spectrogram
def produce_plots(input_audio_arr, aug_audio, sr):
f, t, Sxx_in = spectrogram(
input_audio_arr, fs=sr
) # Compute spectrogram for the original signal (f - frequency, t - time)
f, t, Sxx_aug = spectrogram(aug_audio, fs=sr)
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(20, 5))
ax[0, 0].plot(input_audio_arr)
ax[0, 0].set_xlim(0, len(input_audio_arr))
ax[0, 0].set_xticks([])
ax[0, 0].set_title("Original audio")
ax[0, 1].plot(aug_audio)
ax[0, 1].set_xlim(0, len(input_audio_arr))
ax[0, 1].set_xticks([])
ax[0, 1].set_title("Augmented audio")
ax[1, 0].imshow(
np.log(Sxx_in),
extent=[t.min(), t.max(), f.min(), f.max()],
aspect="auto",
cmap="inferno",
)
ax[1, 0].set_ylabel("Frequecny, Hz")
ax[1, 0].set_xlabel("Time,s")
ax[1, 1].imshow(
np.log(Sxx_aug, where=Sxx_aug > 0),
extent=[t.min(), t.max(), f.min(), f.max()],
aspect="auto",
cmap="inferno",
)
ax[1, 1].set_ylabel("Frequecny, Hz")
ax[1, 1].set_xlabel("Time,s")
plt.subplots_adjust(hspace=0)
plt.show()
produce_plots(data, augmented_data, sr)
from audiomentations import TimeStretch
augment = TimeStretch(min_rate=0.8, max_rate=1.5, p=1)
augmented_data = augment(data, sample_rate=sr)
display(Audio(augmented_data, rate=sr))
produce_plots(data, augmented_data, sr)
Изменение тональности:
from audiomentations import PitchShift
augment = PitchShift(min_semitones=1, max_semitones=12, p=1)
augmented_data = augment(data, sample_rate=sr)
display(Audio(augmented_data, rate=sr))
Как и в случае с картинками, мы можем совмещать несколько аугментаций вместе
from audiomentations import Compose, AddGaussianNoise, Shift
augment = Compose(
[
AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1),
TimeStretch(min_rate=0.8, max_rate=1.25, p=1),
PitchShift(min_semitones=-4, max_semitones=4, p=1),
Shift(min_fraction=-0.5, max_fraction=0.5, p=1),
]
)
augmented_data = augment(data, sample_rate=sr)
display(Audio(augmented_data, rate=sr))
Посмотрим на то, что получилось:
produce_plots(data, augmented_data, sr)
Дополнительные библиотеки для аугментации звука (и волновых функций в целом):
Теперь рассмотрим несколько примеров аугментаций текста. С полным списком можно ознакомиться здесь: [git] библиотеки.
!pip install -q nlpaug
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.5/410.5 kB 19.0 MB/s eta 0:00:00
# Define input text
text = "Hello, future of AI for Science! How are you today?"
print(f"input text: {text}")
input text: Hello, future of AI for Science! How are you today?
Заменой на похоже выглядящие:
import nlpaug.augmenter.char as nac
augment = nac.OcrAug()
augmented_text = augment.augment(text)
print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original: Hello, future of AI for Science! How are you today? Augmented Texts: ['Hel1u, fotore of AI for 8cience! How are you today?']
С опечатками, которые учитывают расположение символов на клавиатуре:
augment = nac.KeyboardAug()
augmented_text = augment.augment(text)
print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original: Hello, future of AI for Science! How are you today? Augmented Texts: ['nel/o, fufuDe of AI for ZcienX2! How are you g8day?']
С орфографическими ошибками:
import nlpaug.augmenter.word as naw
augment = naw.SpellingAug()
augmented_text = augment.augment(text, n=3)
print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original: Hello, future of AI for Science! How are you today? Augmented Texts: ['Hello, futur og AI for Science! Hot are ypi today?', 'Hello, furtuer f AI for Science! Wow are you today?', 'Hello, future for AI fom Scince! Hou are you today?']
С использованием модели для предсказания новых слов в зависимости от контекста:
!pip install -q transformers
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.4/7.4 MB 84.8 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 268.8/268.8 kB 32.9 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 108.0 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 66.9 MB/s eta 0:00:00
from IPython.display import clear_output
# model_type: word2vec, glove or fasttext
augment = naw.ContextualWordEmbsAug(model_path="bert-base-uncased", action="insert")
augmented_text = augment.augment(text)
clear_output()
print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original: Hello, future of AI for Science! How are you today? Augmented Texts: ['big hello, future of ai sci for all science! now how are you today?']
Мы можем перевести текстовые данные на какой-либо язык, а затем перевести их обратно на язык оригинала. Это может помочь сгенерировать текстовые данные с разными словами, сохраняя при этом контекст текстовых данных.
!pip -q install sacremoses
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 880.6/880.6 kB 23.4 MB/s eta 0:00:00 Preparing metadata (setup.py) ... done Building wheel for sacremoses (setup.py) ... done
back_translation_aug = naw.BackTranslationAug(
from_model_name="facebook/wmt19-en-de", to_model_name="facebook/wmt19-de-en"
)
augmented_text = back_translation_aug.augment(text)
clear_output()
print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original: Hello, future of AI for Science! How are you today? Augmented Texts: ['Hello, the future of AI for science! How are you doing today?']
Instance Crossover Augmentation — составление новых объектов класса из отдельных предложений того же класса. Например, есть два объекта одного класса "положительный отзыв":
Тогда можно составить новый объект того же класса из их частей:
Важно: при любых аугментациях текста на уровне предложений есть шанс создать странные и нелогичные объекты, поэтому использовать их следует с особой осторожностью.
Дополнительные библиотеки для аугментации текста:
Как обучить нейросеть на своих данных, когда их мало?
Для такой типовой задачи, как классификация изображений, можно воспользоваться одной из существующих архитектур (AlexNet, VGG, Inception, ResNet и т.д.) и просто обучить нейросеть на своих данных. Реализации таких сетей с помощью различных фреймворков уже существуют, так что на данном этапе можно использовать одну из них как черный ящик, не вникая в принцип её работы. Например, в PyTorch есть множество уже реализованных известных архитектур: torchvision.models.
Однако глубокие нейронные сети требуют больших объемов данных для успешного обучения. И, зачастую, в нашей частной задаче недостаточно данных для того, чтобы хорошо обучить нейросеть с нуля. Transfer learning решает эту проблему. По сути мы пытаемся использовать опыт, полученный нейронной сетью при обучении на некоторой задаче $T_1$, чтобы решать схожую задачу $T_2$.
К примеру, transfer learning можно использовать при решении задачи классификации изображений на небольшом наборе данных. Как уже ранее обсуждалось, при обработке изображений свёрточные нейронные сети в первых слоях "реагируют" на некие простые пространственные шаблоны (к примеру, углы), после чего комбинируют их в сложные осмысленные формы (к примеру, глаза или носы). Вся эта информация извлекается из изображения, на её основе создаются сложные представления данных, которые в результате классифицируются линейной моделью.
Идея заключается в том, что если изначально обучить модель на некоторой сложной и довольно общей задаче, то можно надеяться, что она (как минимум часть ее слоев), в общем случае, будет извлекать важную информацию из изображений, и полученные представления можно будет успешно использовать для классификации линейной моделью.
Таким образом, берем часть модели, которая, по нашему представлению, отвечает за выделение хороших признаков (часто — все слои, кроме последнего) — feature extractor. Присоединяем к этой части один или несколько дополнительных слоёв для решения уже новой задачи. И учим только эти слои. Cлои feature extractor не учим — они "заморожены".
Понятно, что не все фильтры модели будут использованы максимально эффективно — к примеру, если мы работаем с изображениями, связанными с едой, возможно, не все фильтры на скрытых слоях предобученной на ImageNet модели окажутся полезны для нашей задачи. Почему бы не попробовать не только обучить новый классификатор, но и дообучить некоторые промежуточные слои? При использовании этого подхода мы при обучении дополнительно "настраиваем" и промежуточные слои, называется он fine-tuning.
При fine-tuning используют меньший learning rate, чем при обучении нейросети с нуля: мы знаем, что по крайней мере часть весов нейросети выполняет свою задачу хорошо, и не хотим испортить это быстрыми изменениями.
Кроме этого, можно делать комбинации этих методов: сначала учить только последние добавленные нами слои сети, затем учить еще и самые близкие к ним, и после этого учить уже все веса нейросети вместе. То есть мы можем определить свою стратегию fine-tuning.
Иногда fine-tuning считается синонимом Transfer learning, в этом случае часть от предтренированной сети называют backbone ("позвоночник"), а добавленную часть — head ("голова").
Последовательно рассмотрим шаги, необходимые для реализации подхода transfer learning.
Шаг 1. Получение предварительно обученной модели
Первым шагом является выбор предварительно обученной модели, которую мы хотели бы использовать в качестве основы для обучения. Основным предположением является то, что признаки, которые умеет выделять из данных предобученная модель, хорошо подойдут для решения нашей частной задачи. Поэтому эффект от Transfer learning будет тем лучше, чем более схожими будут домены в нашей задаче и в задаче, на которой предварительно обучалась модель.
Для задач обработки изображений очень часто используются модели, предобученные на ImageNet. Такой подход распространен, однако, если ваша задача связана, например, с обработкой снимков клеток под микроскопом, то модель, предобученная на более близком домене (тоже на снимках клеток, пусть и совсем других), может быть лучшим начальным решением.
Шаг 2. Заморозка предобученных слоев
Мы предполагаем, что первые слои модели уже хорошо натренированы выделять какие-то абстрактные признаки из данных. Поэтому мы не хотим "сломать" их, особенно если начнем на этих признаках обучать новые слои, которые инициализируются случайно: на первых шагах обучения ошибка будет большой и мы можем сильно изменить "хорошие" предобученные веса.
Поэтому требуется "заморозить" предобученные веса. На практике заморозка означает отключение подсчета градиентов. Таким образом при последующем обучении параметры с отключенным подсчетом градиентов не будут обновляться.
Шаг 3. Добавление новых обучаемых слоев
В отличие от начальных слоев, которые выделяют достаточно общие признаки из данных, более близкие к выходу слои предобученной модели сильно специфичны конкретно под ту задачу, на которую она обучалась. Для моделей, предобученных на ImageNet, последний слой заточен конкретно под предсказание 1000 классов из этого набора данных. Кроме этого, последние слои могут не подходить под новую задачу архитектурно: в новой задаче может быть меньше классов, 10 вместо 1000. Поэтому, требуется заменить последние один или несколько слоев предобученной модели на новые, подходящие под нашу задачу. При этом, естественно, веса в этих слоях будут инициализированы случайно. Именно эти слои мы и будем обучать на следующем шаге.
Шаг 4. Обучение новых слоев
Все, что нам теперь нужно — обучить новые слои на наших данных. При этом замороженные слои используются лишь как экстрактор высокоуровневых признаков. Обучение такой модели существенно ничем не отличается от обучения любой другой модели: используется обучающая и валидационная выборка, контролируется изменение функции потерь и функционала качества.
Шаг 5. Тонкая настройка модели (fine-tuning)
После того, как мы обучили новые слои модели, и они уже как-то решают задачу, мы можем разморозить ранее замороженные веса, чтобы тонко настроить их под нашу задачу, в надежде, что это позволит еще немного повысить качество.
Нужно быть осторожным на этом этапе, использовать learning rate на порядок или два меньший, чем при основном обучении, и одновременно с этим следить за возникновением переобучения. Переобучение при fine-tuning может возникать из-за того, что мы резко увеличиваем количество настраиваемых параметров модели, но при этом наш датасет остается небольшим, и мощная модель может начать заучивать обучающие данные.
Давайте рассмотрим пример практической реализации такого подхода (код переработан из этой статьи).
Загрузим датасет EuroSAT и удалим из него 90% файлов. EuroSAT — датасет для классификации спутниковых снимков по типам местности: лес, река, жилая застройка и т. п.
import random
import torch
import numpy as np
def set_random_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)
set_random_seed(42)
import os
from random import sample
!wget -qN https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/EuroSAT.zip # http://madm.dfki.de/files/sentinel/EuroSAT.zip
!unzip -qn EuroSAT.zip
os.chdir("/content")
path = "/content/2750/"
for folder in os.listdir(path):
files = os.listdir(path + folder)
for file in sample(files, int(len(files) * 0.9)):
os.remove(path + folder + "/" + file)
Определим аугментации. Для примера будем использовать родные аугментации из библиотеки Torchvision
from torchvision import transforms
# Applying Transforms to the Data
img_transforms = {
"train": transforms.Compose(
[
transforms.Resize(size=224), # as in ImageNet
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
]
),
# No augmentations on valid data!
"valid": transforms.Compose(
[
transforms.Resize(size=224),
transforms.ToTensor(),
]
),
# No augmentations on test data!
"test": transforms.Compose(
[
transforms.Resize(size=224),
transforms.ToTensor(),
]
),
}
Создадим datasets
from torchvision import datasets
from copy import deepcopy
dataset = datasets.ImageFolder(root=path)
# split to train/valid/test
train_set, valid_set, test_set = torch.utils.data.random_split(
dataset, [int(len(dataset) * 0.8), int(len(dataset) * 0.1), int(len(dataset) * 0.1)]
)
train_set.dataset = deepcopy(dataset)
valid_set.dataset = deepcopy(dataset)
test_set.dataset = deepcopy(dataset)
# define augmentations
train_set.dataset.transform = img_transforms["train"]
valid_set.dataset.transform = img_transforms["valid"]
test_set.dataset.transform = img_transforms["test"]
print(f"Train size: {len(train_set)}")
print(f"Valid size: {len(valid_set)}")
print(f"Test size: {len(test_set)}")
Train size: 2160 Valid size: 270 Test size: 270
from torch.utils.data import DataLoader
# Batch size
batch_size = 64
# Number of classes
num_classes = len(dataset.classes)
# Get a mapping of the indices to the class names, in order to see the output classes of the test images.
idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}
# Size of Data, to be used for calculating Average Loss and Accuracy
train_data_size, valid_data_size = len(train_set), len(valid_set)
# Create iterators for the Data loaded using DataLoader module
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
print("indexes to class: ")
idx_to_class
indexes to class:
{0: 'AnnualCrop', 1: 'Forest', 2: 'HerbaceousVegetation', 3: 'Highway', 4: 'Industrial', 5: 'Pasture', 6: 'PermanentCrop', 7: 'Residential', 8: 'River', 9: 'SeaLake'}
В наборе данных не так уж и много изображений. При обучении с нуля нейросеть скорее всего не достигнет высокой точности.
Загрузим MobileNet v2 без весов и попробуем обучить "с нуля", то есть с весов, инициализированных случайно.
from torchvision import models
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = models.mobilenet_v2(weights=None)
print(model)
MobileNetV2( (features): Sequential( (0): Conv2dNormActivation( (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (2): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False) (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (3): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False) (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (4): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False) (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (5): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False) (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (6): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False) (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (7): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False) (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (8): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (9): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (10): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (11): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False) (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (12): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (13): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (14): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False) (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (15): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False) (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (16): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False) (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (17): InvertedResidual( (conv): Sequential( (0): Conv2dNormActivation( (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (1): Conv2dNormActivation( (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False) (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False) (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (18): Conv2dNormActivation( (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace=True) ) ) (classifier): Sequential( (0): Dropout(p=0.2, inplace=False) (1): Linear(in_features=1280, out_features=1000, bias=True) ) )
Последний слой MobileNet дает на выходе предсказания для 1000 классов, а в нашем датасете классов всего 10. Поэтому мы должны изменить выход сети так, чтобы он выдавал 10 предсказаний. Поэтому мы заменяем последний слой модели MobileNet слоем с num_classes
нейронами, равным числу классов в нашем датасете.
То есть мы "сказали" нашей модели распознавать не 1000, а только num_classes
классов.
# Change the final layer of MobileNet Model for Transfer Learning
import torch.nn as nn
# change out classes, from 1000 to 10
model.classifier[1] = nn.Linear(1280, num_classes)
print(model.classifier)
Sequential( (0): Dropout(p=0.2, inplace=False) (1): Linear(in_features=1280, out_features=10, bias=True) )
Затем мы определяем функцию потерь и оптимизатор, которые будут использоваться для обучения.
import torch.optim as optim
# Define Optimizer and Loss Function
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
print(optimizer)
Adam ( Parameter Group 0 amsgrad: False betas: (0.9, 0.999) capturable: False differentiable: False eps: 1e-08 foreach: None fused: None lr: 0.0003 maximize: False weight_decay: 0 )
Для тренировки и валидации нашей модели напишем отдельную функцию.
import time
def train_and_validate(model, criterion, optimizer, num_epochs=25, save_state=False):
"""
Function to train and validate
Parameters
:param model: Model to train and validate
:param criterion: Loss Criterion to minimize
:param optimizer: Optimizer for computing gradients
:param epochs: Number of epochs (default=25)
Returns
model: Trained Model with best validation accuracy
history: (dict object): Having training loss, accuracy and validation loss, accuracy
"""
start = time.time()
history = []
best_acc = 0.0
for epoch in range(num_epochs):
epoch_start = time.time()
print("Epoch: {}/{}".format(epoch + 1, num_epochs))
# Set to training mode
model.train()
# Loss and Accuracy within the epoch
train_loss = 0.0
train_acc = 0.0
valid_loss = 0.0
valid_acc = 0.0
train_correct = 0
for i, (inputs, labels) in enumerate(train_loader):
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad() # Clean existing gradients
outputs = model(
inputs
) # Forward pass - compute outputs on input data using the model
loss = criterion(outputs, labels) # Compute loss
loss.backward() # Backpropagate the gradients
optimizer.step() # Update the parameters
# Compute the total loss for the batch and add it to train_loss
train_loss += loss.item() * inputs.size(0)
# Compute correct predictions
train_correct += (torch.argmax(outputs, dim=-1) == labels).float().sum()
# Compute the mean train accuracy
train_accuracy = 100 * train_correct / (len(train_loader) * batch_size)
val_correct = 0
# Validation - No gradient tracking needed
with torch.no_grad():
model.eval() # Set to evaluation mode
# Validation loop
for j, (inputs, labels) in enumerate(valid_loader):
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(
inputs
) # Forward pass - compute outputs on input data using the model
loss = criterion(outputs, labels) # Compute loss
valid_loss += loss.item() * inputs.size(
0
) # Compute the total loss for the batch and add it to valid_loss
val_correct += (torch.argmax(outputs, dim=-1) == labels).float().sum()
# Compute mean val accuracy
val_accuracy = 100 * val_correct / (len(valid_loader) * batch_size)
# Find average training loss and training accuracy
avg_train_loss = train_loss / (len(train_loader) * batch_size)
# Find average training loss and training accuracy
avg_valid_loss = valid_loss / (len(valid_loader) * batch_size)
history.append(
[
avg_train_loss,
avg_valid_loss,
train_accuracy.detach().cpu(),
val_accuracy.detach().cpu(),
]
)
epoch_end = time.time()
print(
"Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(
epoch + 1,
avg_train_loss,
train_accuracy.detach().cpu(),
avg_valid_loss,
val_accuracy.detach().cpu(),
epoch_end - epoch_start,
)
)
# Saving state for fine_tuning (because we may overfit)
if save_state:
os.makedirs("check_points", exist_ok=True)
torch.save(model.state_dict(), f"check_points/fine_tuning_{epoch + 1}.pth")
return model, history
Теперь обучим нашу модель:
num_epochs = 20
trained_model, history = train_and_validate(
model.to(device), criterion, optimizer, num_epochs
)
torch.save(history, "history_fresh.pt")
Epoch: 1/20 Epoch : 001, Training: Loss: 1.6833, Accuracy: 39.1544%, Validation : Loss : 2.4529, Accuracy: 8.7500%, Time: 20.9827s Epoch: 2/20 Epoch : 002, Training: Loss: 1.1396, Accuracy: 57.6287%, Validation : Loss : 3.7425, Accuracy: 8.7500%, Time: 12.9362s Epoch: 3/20 Epoch : 003, Training: Loss: 1.0403, Accuracy: 62.0404%, Validation : Loss : 0.9845, Accuracy: 46.8750%, Time: 12.9204s Epoch: 4/20 Epoch : 004, Training: Loss: 0.9564, Accuracy: 64.5680%, Validation : Loss : 0.9148, Accuracy: 50.0000%, Time: 12.9076s Epoch: 5/20 Epoch : 005, Training: Loss: 0.8677, Accuracy: 68.4743%, Validation : Loss : 0.9954, Accuracy: 48.4375%, Time: 13.1667s Epoch: 6/20 Epoch : 006, Training: Loss: 0.7869, Accuracy: 70.8640%, Validation : Loss : 0.7944, Accuracy: 56.2500%, Time: 13.1206s Epoch: 7/20 Epoch : 007, Training: Loss: 0.7766, Accuracy: 71.0478%, Validation : Loss : 0.8605, Accuracy: 55.0000%, Time: 13.2928s Epoch: 8/20 Epoch : 008, Training: Loss: 0.7428, Accuracy: 72.8401%, Validation : Loss : 0.6109, Accuracy: 63.1250%, Time: 13.5449s Epoch: 9/20 Epoch : 009, Training: Loss: 0.6979, Accuracy: 73.9890%, Validation : Loss : 0.7091, Accuracy: 58.1250%, Time: 13.0479s Epoch: 10/20 Epoch : 010, Training: Loss: 0.7157, Accuracy: 73.7132%, Validation : Loss : 0.6613, Accuracy: 58.7500%, Time: 13.0548s Epoch: 11/20 Epoch : 011, Training: Loss: 0.7081, Accuracy: 74.6324%, Validation : Loss : 0.6852, Accuracy: 58.1250%, Time: 13.1791s Epoch: 12/20 Epoch : 012, Training: Loss: 0.6620, Accuracy: 76.0110%, Validation : Loss : 0.6173, Accuracy: 62.1875%, Time: 13.1287s Epoch: 13/20 Epoch : 013, Training: Loss: 0.6407, Accuracy: 75.6893%, Validation : Loss : 0.5484, Accuracy: 65.0000%, Time: 13.2165s Epoch: 14/20 Epoch : 014, Training: Loss: 0.6198, Accuracy: 77.3897%, Validation : Loss : 0.5837, Accuracy: 62.8125%, Time: 13.1706s Epoch: 15/20 Epoch : 015, Training: Loss: 0.5599, Accuracy: 78.9522%, Validation : Loss : 0.5946, Accuracy: 62.5000%, Time: 13.2250s Epoch: 16/20 Epoch : 016, Training: Loss: 0.5354, Accuracy: 80.3768%, Validation : Loss : 0.5517, Accuracy: 63.7500%, Time: 13.2174s Epoch: 17/20 Epoch : 017, Training: Loss: 0.5385, Accuracy: 80.1011%, Validation : Loss : 0.5519, Accuracy: 65.3125%, Time: 13.3673s Epoch: 18/20 Epoch : 018, Training: Loss: 0.5141, Accuracy: 80.5607%, Validation : Loss : 0.5671, Accuracy: 65.3125%, Time: 13.4340s Epoch: 19/20 Epoch : 019, Training: Loss: 0.5267, Accuracy: 80.1011%, Validation : Loss : 0.6413, Accuracy: 61.8750%, Time: 13.2895s Epoch: 20/20 Epoch : 020, Training: Loss: 0.4901, Accuracy: 82.3070%, Validation : Loss : 0.6876, Accuracy: 60.3125%, Time: 13.3084s
Посмотрим на графики:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
fig.suptitle("Fresh learning", fontsize=14)
history = np.array(history)
ax[0].plot(history[:, :2])
ax[0].legend(["Train Loss", "Val Loss"])
ax[1].plot(history[:, 2:])
ax[1].legend(["Train Accuracy", "Val Accuracy"])
ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()
Точность на валидационной выборке не превысила 66%. Посмотрим, сможем ли мы добиться большей точности при использовании предобученной модели.
Теперь давайте попробуем использовать transfer learning.
Загрузим предобученную на ImageNet модель MobileNet v2:
del model
model = models.mobilenet_v2(weights="MobileNet_V2_Weights.DEFAULT")
Downloading: "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-7ebf99e0.pth 100%|██████████| 13.6M/13.6M [00:00<00:00, 116MB/s]
В данном случае мы не дообучаем скрытые слои нашей модели, поэтому отключаем подсчёт градиентов ("замораживаем" параметры).
# Freeze model parameters
for param in model.parameters():
param.requires_grad = False
Нам снова нужно изменить выход сети так, чтобы он выдавал 10 классов вместо 1000.
Мы могли бы изменить количество выходов сети, просто подменив последний линейный слой, как в примере обучения с нуля:
model.classifier[1] = nn.Linear(1280, num_classes)
Но нужно понимать, что мы не ограничены архитектурой готовой сети, и можем как подменять слои, так и добавлять новые. Поэтому в целях демонстрации мы заменим выходной слой исходной сети на два слоя: первый мы добавим "подменой" модуля, а затем добавим активацию и новый выходной слой с num_classes
выходами с помощью метода add_module()
класса Sequential
.
Когда мы подменяем или добавляем слои, по умолчанию подсчет градиентов на них будет включен, и таким образом мы добьемся, что учиться будут только новые слои, веса которых инициализируются случайно.
# Change the final layers of MobileNet Model for Transfer Learning
model.classifier[1] = nn.Linear(
1280, 500
) # replace last module to our custom, e.g. with 500 neurons
model.classifier.add_module("2", nn.ReLU()) # add activation
model.classifier.add_module(
"3", nn.Linear(500, num_classes)
) # add new output layer with 10 out classes
print(model.classifier)
Sequential( (0): Dropout(p=0.2, inplace=False) (1): Linear(in_features=1280, out_features=500, bias=True) (2): ReLU() (3): Linear(in_features=500, out_features=10, bias=True) )
# Define Optimizer and Loss Function
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
num_epochs = 20
trained_model, history = train_and_validate(
model.to(device), criterion, optimizer, num_epochs
)
torch.save(history, "history_transfer_learning.pt")
Epoch: 1/20 Epoch : 001, Training: Loss: 1.7921, Accuracy: 51.2868%, Validation : Loss : 1.3201, Accuracy: 52.1875%, Time: 8.7978s Epoch: 2/20 Epoch : 002, Training: Loss: 1.0062, Accuracy: 73.7592%, Validation : Loss : 0.7506, Accuracy: 67.1875%, Time: 8.5419s Epoch: 3/20 Epoch : 003, Training: Loss: 0.7081, Accuracy: 80.7445%, Validation : Loss : 0.5708, Accuracy: 70.0000%, Time: 9.4897s Epoch: 4/20 Epoch : 004, Training: Loss: 0.6006, Accuracy: 80.9743%, Validation : Loss : 0.4894, Accuracy: 72.5000%, Time: 9.0147s Epoch: 5/20 Epoch : 005, Training: Loss: 0.5416, Accuracy: 82.1691%, Validation : Loss : 0.4478, Accuracy: 72.8125%, Time: 8.1678s Epoch: 6/20 Epoch : 006, Training: Loss: 0.5156, Accuracy: 82.2610%, Validation : Loss : 0.3888, Accuracy: 74.6875%, Time: 8.9877s Epoch: 7/20 Epoch : 007, Training: Loss: 0.4730, Accuracy: 83.6397%, Validation : Loss : 0.3852, Accuracy: 72.8125%, Time: 8.5624s Epoch: 8/20 Epoch : 008, Training: Loss: 0.4319, Accuracy: 85.3401%, Validation : Loss : 0.3676, Accuracy: 74.0625%, Time: 8.3820s Epoch: 9/20 Epoch : 009, Training: Loss: 0.4202, Accuracy: 85.5239%, Validation : Loss : 0.3469, Accuracy: 73.4375%, Time: 8.8643s Epoch: 10/20 Epoch : 010, Training: Loss: 0.3945, Accuracy: 85.7996%, Validation : Loss : 0.3489, Accuracy: 73.1250%, Time: 8.3387s Epoch: 11/20 Epoch : 011, Training: Loss: 0.3992, Accuracy: 86.4430%, Validation : Loss : 0.3235, Accuracy: 75.3125%, Time: 8.6160s Epoch: 12/20 Epoch : 012, Training: Loss: 0.3893, Accuracy: 85.7077%, Validation : Loss : 0.3286, Accuracy: 75.6250%, Time: 9.0667s Epoch: 13/20 Epoch : 013, Training: Loss: 0.3859, Accuracy: 86.4890%, Validation : Loss : 0.3055, Accuracy: 75.9375%, Time: 8.0769s Epoch: 14/20 Epoch : 014, Training: Loss: 0.3672, Accuracy: 87.2243%, Validation : Loss : 0.3039, Accuracy: 74.6875%, Time: 8.7466s Epoch: 15/20 Epoch : 015, Training: Loss: 0.3547, Accuracy: 87.1783%, Validation : Loss : 0.2932, Accuracy: 76.2500%, Time: 8.9641s Epoch: 16/20 Epoch : 016, Training: Loss: 0.3558, Accuracy: 87.1783%, Validation : Loss : 0.2980, Accuracy: 75.3125%, Time: 8.0300s Epoch: 17/20 Epoch : 017, Training: Loss: 0.3423, Accuracy: 87.7298%, Validation : Loss : 0.2877, Accuracy: 75.6250%, Time: 8.9482s Epoch: 18/20 Epoch : 018, Training: Loss: 0.3026, Accuracy: 89.8438%, Validation : Loss : 0.2859, Accuracy: 76.5625%, Time: 8.5148s Epoch: 19/20 Epoch : 019, Training: Loss: 0.3376, Accuracy: 87.3162%, Validation : Loss : 0.2783, Accuracy: 76.2500%, Time: 7.8372s Epoch: 20/20 Epoch : 020, Training: Loss: 0.3159, Accuracy: 88.3732%, Validation : Loss : 0.2793, Accuracy: 77.1875%, Time: 8.7078s
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
fig.suptitle("Transfer learning", fontsize=14)
history = np.array(history)
ax[0].plot(history[:, :2])
ax[0].legend(["Train Loss", "Val Loss"])
ax[1].plot(history[:, 2:])
ax[1].legend(["Train Accuracy", "Val Accuracy"])
ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()
Сравним между собой обучение с нуля и обучение с предобученными весами.
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
fig.suptitle("Fresh Learning (FL) vs Transfer Learning (TL)", fontsize=14)
history_fresh = np.array(torch.load("history_fresh.pt"))
history_transfer_learning = np.array(torch.load("history_transfer_learning.pt"))
ax[0].plot(history_fresh[:, :2], linestyle="--")
ax[0].set_prop_cycle("color", ["tab:blue", "tab:orange"])
ax[0].plot(history_transfer_learning[:, :2])
ax[0].legend(["Train Loss (FL)", "Val Loss (FL)", "Train Loss (TL)", "Val Loss (TL)"])
ax[1].plot(history_fresh[:, 2:], linestyle="--")
ax[1].set_prop_cycle("color", ["tab:blue", "tab:orange"])
ax[1].plot(history_transfer_learning[:, 2:])
ax[1].legend(
[
"Train Accuracy (FL)",
"Val Accuracy (FL)",
"Train Accuracy (TL)",
"Val Accuracy (TL)",
]
)
ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()
При использовании предобученных весов процесс обучения идет более плавно и модель выдает бо́льшую точность. На валидационной выборке мы получили точность около 74%.
Посмотрим, сможем ли мы еще немного повысить точность путем тонкой донастройки всех весов сети.
Проведём процедуру fine-tuning. В предыдущем варианте с transfer learning обучался только последний слой, добавленный вручную. Давайте проверим это, выведя те слои, в которых включён градиент.
for name, param in model.named_parameters():
print(name, param.requires_grad)
features.0.0.weight False features.0.1.weight False features.0.1.bias False features.1.conv.0.0.weight False features.1.conv.0.1.weight False features.1.conv.0.1.bias False features.1.conv.1.weight False features.1.conv.2.weight False features.1.conv.2.bias False features.2.conv.0.0.weight False features.2.conv.0.1.weight False features.2.conv.0.1.bias False features.2.conv.1.0.weight False features.2.conv.1.1.weight False features.2.conv.1.1.bias False features.2.conv.2.weight False features.2.conv.3.weight False features.2.conv.3.bias False features.3.conv.0.0.weight False features.3.conv.0.1.weight False features.3.conv.0.1.bias False features.3.conv.1.0.weight False features.3.conv.1.1.weight False features.3.conv.1.1.bias False features.3.conv.2.weight False features.3.conv.3.weight False features.3.conv.3.bias False features.4.conv.0.0.weight False features.4.conv.0.1.weight False features.4.conv.0.1.bias False features.4.conv.1.0.weight False features.4.conv.1.1.weight False features.4.conv.1.1.bias False features.4.conv.2.weight False features.4.conv.3.weight False features.4.conv.3.bias False features.5.conv.0.0.weight False features.5.conv.0.1.weight False features.5.conv.0.1.bias False features.5.conv.1.0.weight False features.5.conv.1.1.weight False features.5.conv.1.1.bias False features.5.conv.2.weight False features.5.conv.3.weight False features.5.conv.3.bias False features.6.conv.0.0.weight False features.6.conv.0.1.weight False features.6.conv.0.1.bias False features.6.conv.1.0.weight False features.6.conv.1.1.weight False features.6.conv.1.1.bias False features.6.conv.2.weight False features.6.conv.3.weight False features.6.conv.3.bias False features.7.conv.0.0.weight False features.7.conv.0.1.weight False features.7.conv.0.1.bias False features.7.conv.1.0.weight False features.7.conv.1.1.weight False features.7.conv.1.1.bias False features.7.conv.2.weight False features.7.conv.3.weight False features.7.conv.3.bias False features.8.conv.0.0.weight False features.8.conv.0.1.weight False features.8.conv.0.1.bias False features.8.conv.1.0.weight False features.8.conv.1.1.weight False features.8.conv.1.1.bias False features.8.conv.2.weight False features.8.conv.3.weight False features.8.conv.3.bias False features.9.conv.0.0.weight False features.9.conv.0.1.weight False features.9.conv.0.1.bias False features.9.conv.1.0.weight False features.9.conv.1.1.weight False features.9.conv.1.1.bias False features.9.conv.2.weight False features.9.conv.3.weight False features.9.conv.3.bias False features.10.conv.0.0.weight False features.10.conv.0.1.weight False features.10.conv.0.1.bias False features.10.conv.1.0.weight False features.10.conv.1.1.weight False features.10.conv.1.1.bias False features.10.conv.2.weight False features.10.conv.3.weight False features.10.conv.3.bias False features.11.conv.0.0.weight False features.11.conv.0.1.weight False features.11.conv.0.1.bias False features.11.conv.1.0.weight False features.11.conv.1.1.weight False features.11.conv.1.1.bias False features.11.conv.2.weight False features.11.conv.3.weight False features.11.conv.3.bias False features.12.conv.0.0.weight False features.12.conv.0.1.weight False features.12.conv.0.1.bias False features.12.conv.1.0.weight False features.12.conv.1.1.weight False features.12.conv.1.1.bias False features.12.conv.2.weight False features.12.conv.3.weight False features.12.conv.3.bias False features.13.conv.0.0.weight False features.13.conv.0.1.weight False features.13.conv.0.1.bias False features.13.conv.1.0.weight False features.13.conv.1.1.weight False features.13.conv.1.1.bias False features.13.conv.2.weight False features.13.conv.3.weight False features.13.conv.3.bias False features.14.conv.0.0.weight False features.14.conv.0.1.weight False features.14.conv.0.1.bias False features.14.conv.1.0.weight False features.14.conv.1.1.weight False features.14.conv.1.1.bias False features.14.conv.2.weight False features.14.conv.3.weight False features.14.conv.3.bias False features.15.conv.0.0.weight False features.15.conv.0.1.weight False features.15.conv.0.1.bias False features.15.conv.1.0.weight False features.15.conv.1.1.weight False features.15.conv.1.1.bias False features.15.conv.2.weight False features.15.conv.3.weight False features.15.conv.3.bias False features.16.conv.0.0.weight False features.16.conv.0.1.weight False features.16.conv.0.1.bias False features.16.conv.1.0.weight False features.16.conv.1.1.weight False features.16.conv.1.1.bias False features.16.conv.2.weight False features.16.conv.3.weight False features.16.conv.3.bias False features.17.conv.0.0.weight False features.17.conv.0.1.weight False features.17.conv.0.1.bias False features.17.conv.1.0.weight False features.17.conv.1.1.weight False features.17.conv.1.1.bias False features.17.conv.2.weight False features.17.conv.3.weight False features.17.conv.3.bias False features.18.0.weight False features.18.1.weight False features.18.1.bias False classifier.1.weight True classifier.1.bias True classifier.3.weight True classifier.3.bias True
Мы оставим дообученную голову нейронной сети и продолжим обучение всей сети с уменьшением темпа обучения.
Разморозим параметры. criterion
остаётся тот же, в optimizer
уменьшим параметр lr
на порядок.
# Unfreeze model parameters
for param in model.parameters():
param.requires_grad = True
optimizer = optim.Adam(model.parameters(), lr=3e-5)
Пройдём дополнительные 20 эпох и построим графики.
num_epochs = 20
trained_model, history = train_and_validate(
model.to(device), criterion, optimizer, num_epochs, save_state=True
)
torch.save(history, "history_finetuning.pt")
Epoch: 1/20 Epoch : 001, Training: Loss: 0.2955, Accuracy: 90.1195%, Validation : Loss : 0.2437, Accuracy: 76.5625%, Time: 13.6207s Epoch: 2/20 Epoch : 002, Training: Loss: 0.2402, Accuracy: 91.8199%, Validation : Loss : 0.2150, Accuracy: 77.5000%, Time: 13.3512s Epoch: 3/20 Epoch : 003, Training: Loss: 0.2102, Accuracy: 92.2794%, Validation : Loss : 0.2052, Accuracy: 77.1875%, Time: 13.9947s Epoch: 4/20 Epoch : 004, Training: Loss: 0.2057, Accuracy: 92.0496%, Validation : Loss : 0.1850, Accuracy: 78.4375%, Time: 13.3778s Epoch: 5/20 Epoch : 005, Training: Loss: 0.2118, Accuracy: 92.5092%, Validation : Loss : 0.1663, Accuracy: 80.0000%, Time: 13.4142s Epoch: 6/20 Epoch : 006, Training: Loss: 0.1611, Accuracy: 93.9338%, Validation : Loss : 0.1563, Accuracy: 79.3750%, Time: 13.5197s Epoch: 7/20 Epoch : 007, Training: Loss: 0.1707, Accuracy: 93.1526%, Validation : Loss : 0.1380, Accuracy: 80.0000%, Time: 13.3344s Epoch: 8/20 Epoch : 008, Training: Loss: 0.1602, Accuracy: 93.8419%, Validation : Loss : 0.1335, Accuracy: 80.6250%, Time: 13.3374s Epoch: 9/20 Epoch : 009, Training: Loss: 0.1428, Accuracy: 94.2096%, Validation : Loss : 0.1271, Accuracy: 80.6250%, Time: 13.4488s Epoch: 10/20 Epoch : 010, Training: Loss: 0.1313, Accuracy: 94.7151%, Validation : Loss : 0.1210, Accuracy: 80.9375%, Time: 13.4156s Epoch: 11/20 Epoch : 011, Training: Loss: 0.1328, Accuracy: 94.5772%, Validation : Loss : 0.1122, Accuracy: 81.2500%, Time: 13.4008s Epoch: 12/20 Epoch : 012, Training: Loss: 0.1143, Accuracy: 95.4504%, Validation : Loss : 0.1081, Accuracy: 80.9375%, Time: 13.4811s Epoch: 13/20 Epoch : 013, Training: Loss: 0.1203, Accuracy: 95.0368%, Validation : Loss : 0.1093, Accuracy: 80.3125%, Time: 13.3948s Epoch: 14/20 Epoch : 014, Training: Loss: 0.1087, Accuracy: 95.4963%, Validation : Loss : 0.1054, Accuracy: 81.2500%, Time: 13.3669s Epoch: 15/20 Epoch : 015, Training: Loss: 0.0959, Accuracy: 95.5882%, Validation : Loss : 0.1016, Accuracy: 80.6250%, Time: 13.4678s Epoch: 16/20 Epoch : 016, Training: Loss: 0.0965, Accuracy: 96.2776%, Validation : Loss : 0.0960, Accuracy: 81.2500%, Time: 13.4115s Epoch: 17/20 Epoch : 017, Training: Loss: 0.1007, Accuracy: 95.5882%, Validation : Loss : 0.0966, Accuracy: 81.5625%, Time: 14.0048s Epoch: 18/20 Epoch : 018, Training: Loss: 0.0932, Accuracy: 95.8640%, Validation : Loss : 0.0927, Accuracy: 81.2500%, Time: 13.2979s Epoch: 19/20 Epoch : 019, Training: Loss: 0.0806, Accuracy: 96.8290%, Validation : Loss : 0.0928, Accuracy: 81.5625%, Time: 13.4152s Epoch: 20/20 Epoch : 020, Training: Loss: 0.0736, Accuracy: 96.6452%, Validation : Loss : 0.0879, Accuracy: 82.1875%, Time: 13.2985s
fig, ax = plt.subplots(ncols=2, figsize=(16, 5))
fig.suptitle("Transfer Learning (TL) AND Finetuning (FT)", fontsize=14)
history_transfer_learning = np.array(torch.load("history_transfer_learning.pt"))
history_finetuning = np.array(torch.load("history_finetuning.pt"))
train_val_loss = np.concatenate(
(history_transfer_learning[:, :2], history_finetuning[:, :2]), axis=0
)
ax[0].plot(train_val_loss)
ax[0].vlines(19, -0.1, 2.1, color="tab:green", linewidth=2, linestyle="--")
ax[0].legend(["Train Loss", "Val Loss", "TL/FT boundary"])
train_val_acc = np.concatenate(
(history_transfer_learning[:, 2:], history_finetuning[:, 2:]), axis=0
)
ax[1].plot(train_val_acc)
ax[1].vlines(19, -5, 105, color="tab:green", linewidth=2, linestyle="--")
ax[1].legend(["Train Accuracy", "Val Accuracy", "TL/FT boundary"])
ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()
Есть ли эффект от fine-tuning? После дообучения ещё на 20 эпохах мы наблюдаем следующие эффекты:
При fine-tuning модель может быть склонна к переобучению, так как мы обучаем сложную модель с большим числом параметров на небольшом количестве данных. Поэтому мы используем learning rate на порядок меньший, чем при обычном обучении. Для контроля переобучения следует следить за метриками и ошибкой на валидационной выборке.
Лучшее качество на валидационных данных мы получили на 38 эпохе. При fine-tuning мы сохраняли состояния нейросети на каждой эпохе. Возьмём состояние с 38 эпохи как наиболее оптимальное.
trained_model.load_state_dict(
torch.load("check_points/fine_tuning_18.pth")
) # 38 = 20 (TL) + 18 (FT)
trained_model.eval();
Посмотрим на предсказания
def predict(model, test_img_name, device):
"""
Function to predict the class of a single test image
Parameters
:param model: Model to test
:param test_img_name: Test image
"""
transform = img_transforms["test"]
test_img = torch.tensor(np.asarray(test_img_name))
test_img = transforms.ToPILImage()(test_img)
plt.imshow(test_img)
test_img_tensor = test_img_name.unsqueeze(0).to(device)
with torch.no_grad():
model.eval()
# Model outputs is logits
out = model(test_img_tensor).to(device)
probs = torch.softmax(out, dim=1).to(device)
topk, topclass = probs.topk(3, dim=1)
for i in range(3):
print(
"Predcition",
i + 1,
":",
idx_to_class[topclass.cpu().numpy()[0][i]],
", Score: ",
round(topk.cpu().numpy()[0][i], 2),
)
print("Shoud be %s\n" % idx_to_class[0])
predict(
trained_model.to(device),
test_set[np.where([x[1] == 0 for x in test_set])[0][0]][0],
device,
)
Shoud be AnnualCrop Predcition 1 : AnnualCrop , Score: 1.0 Predcition 2 : PermanentCrop , Score: 0.0 Predcition 3 : SeaLake , Score: 0.0
print("Shoud be %s\n" % idx_to_class[6])
predict(
trained_model,
test_set[np.where([x[1] == 6 for x in test_set])[0][0]][0],
device,
)
Shoud be PermanentCrop Predcition 1 : PermanentCrop , Score: 0.58 Predcition 2 : AnnualCrop , Score: 0.42 Predcition 3 : SeaLake , Score: 0.0
print("Shoud be %s\n" % idx_to_class[8])
predict(
trained_model,
test_set[np.where([x[1] == 8 for x in test_set])[0][0]][0],
device,
)
Shoud be River Predcition 1 : River , Score: 1.0 Predcition 2 : Highway , Score: 0.0 Predcition 3 : AnnualCrop , Score: 0.0
Мы увидели, как использовать предварительно обученную модель на 1000 классов ImageNet для нашей задачи на 10 классов.
Мы сравнили качество обучения с нуля, transfer learning и fine-tuning и научились добиваться максимального качества с помощью этих принципов.
На практике не забывайте о характерной опасности fine-tuning — переобучении. Используйте низкий learning rate и отслеживайте Loss и показатели качества — возможно, вам будет достаточно небольшого количества эпох.
Существуют задачи, где не представляется возможным разбить данные на классы так, чтобы в каждом классе было достаточно много объектов.
Рассмотрим, например, задачу распознавания лиц.
На вход системе подается фото лица человека. Требуется сопоставить его с другим изображением или изображениями, например, хранящимися в БД, и таким образом идентифицировать человека на фотографии.
На первый взгляд кажется, что это задача классификации.
Все изображения одного человека будем считать относящимися к одному классу, и модель будет этот класс предсказывать.
Для небольшой организации, в которой всего несколько десятков сотрудников такой подход может сработать. При этом возникнут проблемы:
Чтобы обучить такую систему, нам сначала потребуется много (сотни) разных изображений каждого сотрудника.
Когда человек присоединяется к организации или покидает ее, приходится менять структуру модели и обучать ее заново.
Это практически невозможно для крупных организаций, где набор и увольнение происходит почти каждую неделю. И в принципе невозможно для города масштаба Москвы или Лондона, в котором миллионы жителей и сотни тысяч приезжих.
Поэтому используется другой подход. Вместо того, чтобы классифицировать изображения, модель учится выделять ключевые признаки и на их основе строить компактный вектор, достаточно точно описывающий лицо.
В англоязычной литературе такие вектора признаков называются embedding, и мы тоже будем использовать это обозначение.
Может возникнуть вопрос: не потеряем ли мы важную информацию, сжав изображение в несколько сотен чисел?
Чтобы ответить на него, вспомним, как работает фоторобот.
Для получения фотореалистичного изображения лица достаточно нескольких ключевых признаков: глаза, волосы, рот, нос... Каждый из них кодируется максимум несколькими сотнями целых значений.
Значит, вектора-признака из 128 вещественных чисел будет более чем достаточно. Правда интерпретировать значения, которые закодирует в него нейросеть, будет не столь просто.
Если нам удастся обучить модель кодировать в embedding признаки, важные для сравнения, то мы сможем сравнивать векторы между собой.
Если расстояние между векторами для лиц, которые похожи друг на друга, будут маленькими, а у непохожих, наоборот, большими, то мы сможем экспериментально подобрать порог $d$ и, сравнивая с ним расстояние между двумя векторами, принимать решение: принадлежат ли они одному человеку или нет.
Можно оценивать не расстояние, а степень схожести similarity. В этом случае неравенства поменяют знак, но логика останется прежней
Теперь, чтобы идентифицировать человека, требуется только одно изображение его лица. Эмбеддинг этого изображения можно сравнить с эмбеддингами других лиц из БД, используя k-NN или иной метод кластеризации.
Такая модель не учится классифицировать изображение напрямую по какому-либо из выходных классов. Она учится выделять признаки, важные при сравнении.
Такой подход решает обе проблемы, о которых мы говорили выше:
Какая архитектура должна быть у модели, генерирующей векторы признаков?
Можно было бы использовать обычную сеть, обученную для задачи классификации, и затем удалить из нее один или несколько последних слоёв.
Активации последнего слоя представляют собой отклики на некие высокоуровневые признаки, потенциально важные для классификации, и их можно интерпретировать как embedding.
from torchvision.models import alexnet
import torch
face1 = torch.randn((3, 224, 224))
face2 = torch.randn((3, 224, 224))
model = alexnet(weights="AlexNet_Weights.DEFAULT")
# remove classification layer
model.fc = model.classifier[6] = torch.nn.Identity()
# get embeddings
embedding1 = model(face1.unsqueeze(0))
embedding2 = model(face2.unsqueeze(0))
diff = torch.nn.functional.pairwise_distance(embedding1, embedding2)
print("L2 distance: ", diff.item())
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth 100%|██████████| 233M/233M [00:01<00:00, 186MB/s]
L2 distance: 30.66541862487793
Такой подход будет работать. Однако можно заметно улучшить точность, используя функцию потерь, которая оценивает именно качество сравнения, а не классификации.
Рассмотрим подход, основанный на методологии, описанной в статье Siamese Neural Networks for One-shot Image Recognition (Koch et al., 2015).
Два входных изображения ($x_1$ и $x_2$) проходят через одну и ту же сверточную сеть, на выходе для каждого изображения генерируется вектор признаков фиксированной длины $h_1$ и $h_2$.
Модель обучается генерировать близкие вектора для изображений одного объекта и далекие для разных.
Оценивая расстояние между двумя векторами признаков, которое будет малым для одних и тех же объектов и большим для различных, мы сможем оценить их сходство.
Это центральная идея сиамских сетей.
Какую функцию потерь использовать для обучения такой сети?
Очевидно, loss function должна будет учитывать не один выход, а как минимум два.
Популярной на сегодняшний день является Triplet loss
, которой требуется три embedding вместо двух.
Чтобы сгенерировать три эмбеддинга, модель должна получать на вход три изображения.
Первые два должны относиться к одному и тому же объекту (человеку), а третье — к другому.
Таким образом, триплет состоит из опорного ("якорного" anchor
), положительного (positive
) и отрицательного (negative
) образцов.
Описание в статье FaceNet: A Unified Embedding for Face Recognition and Clustering
Сама функция потерь будет выглядеть следующим образом:
$$TripletLoss = \sum_{1}^{N} L_i(x_i^{a},x_i^{p},x_i^{n})$$$$L_i(x_i^{a},x_i^{p},x_i^{n})=max(0,\left\| f(x_i^{a}) -f(x_i^{p}) \right\|_2^{2} - \left\| f(x_i^{a}) -f(x_i^{n}) \right\|_2^{2} + margin)$$Где:
$x_i^{a}$ — базовое изображение (anchor),
$x_i^{p}$ — изображение того же объекта (positive),
$x_i^{n}$ — изображение другого объекта (negative),
$f(x)$ — нормированный выход модели (embedding) для входа $x$,
$\left\| x \right\|_2$ — это L2 (Euclidean norm), соответственно $\left\| a \right\|_2^{2}$ — это L2 в квадрате,
$margin$ — это константа или минимальный "зазор", на который расстояние до эмбеддинга негативного объекта обязано превосходить расстояние до позитивного (идея такая же, что в SVM Loss) .
В ходе обучения с Triplet Loss расстояние между эмбеддингами опорного и позитивного объектов уменьшается, а между эмбеддингами опорного и отрицательного — увеличивается.
Важным дополнением является то, что embedding-и нормируются. В результате нормировки каждый вектор-признак будет иметь единичную длину.
Теперь мы можем рассматривать embedding-и как точки на $n$-мерной сфере с радиусом $1$.
Это удобно, так как все расстояния между embedding будут лежать в интервале $[0 \dots 2]$, и нам будет проще подобрать порог для сравнения.
Кроме того, можно использовать другие меры расстояния, например, косинусное расстояние, которое определяется углом между векторами, лежит в интервале $[-1 \dots 1]$ и соответствует расстоянию между точками на поверхности сферы.
В статье авторы минимизируют Евклидово расстояние, но подход будет работать и для других метрик сходства, например, косинусного расстояния.
В PyTorch есть две реализации TripletLoss
TripletMarginLoss — минимизирует $L_p$ норму
from torch import nn
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
loss = triplet_loss(anchor, positive, negative)
print(loss)
tensor(1.0878, grad_fn=<MeanBackward0>)
TripletMarginWithDistanceLoss — позволяет задать произвольную функцию расстояния.
import torch.nn.functional as F
triplet_loss = nn.TripletMarginWithDistanceLoss(
margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
)
loss = triplet_loss(anchor, positive, negative)
print(loss)
tensor(0.9909, grad_fn=<MeanBackward0>)
Другие функции потерь для сиамских сетей:
Исторически первой появилась Contrastive Loss
, о ней подробнее в статье Dimensionality Reduction by Learning an Invariant Mapping (Hadsell et al., 2005)
В PyTorch есть реализация CosineEmbeddingLoss, она позволяет обучать модель на парах изображений, минимизировав косинусное расстояние между embedding.
Загрузим небольшой фрагмент датасета с лицами. Внутри архива фото лиц сгруппированы по папкам
faces/
├── training/
| ├── s1/
| | ├── 1.pgm
| | ├ ...
| | └── 9.pgm
| ├ ... (excluding 5...7)
| └── s40/
| ├── 1.pgm
| ├ ...
| └── 9.pgm
└── testing/
├── s5/
| ├── 1.pgm
| ├ ...
| └── 9.pgm
├ ...
└── s7/
├── 1.pgm
├ ...
└── 9.pgm
В каждой папке фото лица одного и того же человека.
!wget -qN https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/small_face_dataset.zip
!unzip -qn small_face_dataset.zip
Чтобы результаты воспроизводились, зафиксируем SEED
import numpy as np
import random
def set_random_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)
set_random_seed(42)
Для TripletLoss потребуются три изображения: anchor, positive, negative, и метод get_item должен возвращать их нам. Первые два должны принадлежать одному человеку, а третье — другому.
from torch.utils.data import Dataset
from glob import glob
from PIL import Image
class SiameseNetworkDataset(Dataset):
def __init__(self, dir=None, transform=None, splitter="/"):
self.dir = dir
self.splitter = splitter
self.transform = transform
self.files = glob(f"{self.dir}/**/*.pgm", recursive=True)
self.data = self.build_index()
def build_index(self):
index = {}
for f in self.files:
id = self.path2id(f)
if not id in index:
index[id] = []
index[id].append(f)
return index
def path2id(self, path):
return path.replace(self.dir, "").split(self.splitter)[0]
def __getitem__(self, index):
anchor_path = self.files[index]
positive_path = self.find_positive(anchor_path)
negative_path = self.find_negative(anchor_path)
# Loading the images
anchor = Image.open(anchor_path)
positive = Image.open(positive_path)
negative = Image.open(negative_path)
if self.transform is not None: # Apply image transformations
anchor = self.transform(anchor)
positive = self.transform(positive)
negative = self.transform(negative)
return anchor, positive, negative
def find_positive(self, path):
id = self.path2id(path)
all_exept_my = self.data[id].copy()
all_exept_my.remove(path)
return random.choice(all_exept_my)
def find_negative(self, path):
all_exept_my_ids = list(self.data.keys())
id = self.path2id(path)
all_exept_my_ids.remove(id)
selected_id = random.choice(all_exept_my_ids)
return random.choice(self.data[selected_id])
def __len__(self):
return len(self.files)
Выведем несколько изображений, чтобы убедиться, что класс датасета функционирует должным образом.
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
# Create dataset instance
siamese_dataset = SiameseNetworkDataset(
"faces/training/",
transform=transforms.Compose(
[
transforms.Resize((105, 105)),
transforms.ToTensor(),
]
),
)
# Create dataloader & extract batch of data from it
vis_dataloader = DataLoader(siamese_dataset, batch_size=8, shuffle=True)
dataiter = iter(vis_dataloader)
example_batch = next(dataiter) # anc, pos, neg
# Show batch contents
concatenated = torch.cat((example_batch[0], example_batch[1], example_batch[2]), 0)
grid = torchvision.utils.make_grid(concatenated)
plt.axis("off")
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.gcf().set_size_inches(20, 60)
plt.show()
В каждом столбце тройка изображений. Первое и второе принадлежат одному человеку, третье — другому.
Нас устроит любая модель для работы с изображениями. Например, ResNet18.
Все, что от нас требуется, это:
Пожалуй, единственный вопрос — это размерность последнего слоя. В промышленных системах распознавания лиц, которые тренируются на датасетах из миллионов изображений, используются embedding размерностью от 128 до 512.
Для демонстрационной задачи нам должно хватить 32 значений. Количество выходов последнего линейного слоя установим равным 32.
from torchvision.models import resnet18
class SiameseNet(nn.Module):
def __init__(self, latent_dim):
super().__init__()
self.model = resnet18(weights=None)
# Because we use grayscale images reduce input channel count to one
self.model.conv1 = nn.Conv2d(
1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
)
# Replace ImageNet 1000 class classifier with 64- out linear layer
self.model.fc = nn.Linear(self.model.fc.in_features, latent_dim)
def _forward(self, x):
out = self.model(x)
# normalize embedding to unit vector
out = torch.nn.functional.normalize(out)
return out
def forward(self, anchor, positive, negative):
output1 = self._forward(anchor)
output2 = self._forward(positive)
output3 = self._forward(negative)
return output1, output2, output3
Загрузчики данных не отличаются от загрузчиков для обычной сети. Единственное отличие — это добавление аугментации в виде случайного отражения по вертикали к обучающим данным.
# Apply augmentations on train data
img_trans_train = transforms.Compose(
[
transforms.Resize((105, 105)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]
)
img_trans_test = transforms.Compose(
[transforms.Resize((105, 105)), transforms.ToTensor()]
)
train_dataset = SiameseNetworkDataset("faces/training/", transform=img_trans_train)
val_dataset = SiameseNetworkDataset("faces/testing/", transform=img_trans_test)
batch_size = 300
train_loader = DataLoader(
train_dataset, num_workers=2, batch_size=batch_size, shuffle=True
)
val_loader = DataLoader(val_dataset, num_workers=2, batch_size=1, shuffle=False)
Отличие от сетей для классификации в том, что у модели 3 выхода, и все их надо передать в loss. При этом нет меток в явном виде. Определить, какой embedding относится к позитивному образцу, а какой — к негативному, можно только порядком их следования.
def train(num_epochs, model, criterion, optimizer, train_loader):
loss_history = []
model.train()
for epoch in range(0, num_epochs):
train_loss = 0
for i, batch in enumerate(train_loader, 0):
anc, pos, neg = batch
output_anc, output_pos, output_neg = model(
anc.to(device), pos.to(device), neg.to(device)
)
loss = criterion(output_anc, output_pos, output_neg)
loss.backward()
optimizer.step()
optimizer.zero_grad()
train_loss += loss.detach().cpu().item()
loss_history.append(train_loss / len(train_loader))
last_epoch_loss = torch.tensor(loss_history[-1])
print("Epoch {} with {:.4f} loss".format(epoch, last_epoch_loss))
return loss_history, last_epoch_loss
В качестве функции расстояния в TripletLoss возьмем косинусное расстояние (величина, обратная к косинусной близости).
import torch.optim as optim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
latent_dim = 32
model = SiameseNet(latent_dim).to(device)
criterion = nn.TripletMarginWithDistanceLoss(
margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
)
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 9
loss_history, _ = train(num_epochs, model, criterion, optimizer, train_loader)
Epoch 0 with 0.8090 loss Epoch 1 with 0.4584 loss Epoch 2 with 0.4083 loss Epoch 3 with 0.3571 loss Epoch 4 with 0.3335 loss Epoch 5 with 0.3602 loss Epoch 6 with 0.3168 loss Epoch 7 with 0.2326 loss Epoch 8 with 0.2895 loss
Выведем график loss
plt.plot(range(1, len(loss_history) + 1), loss_history)
plt.ylabel("loss")
plt.xlabel("num of epochs")
plt.grid()
plt.show()
Для начала выведем тройки изображений из проверочного датасета и посмотрим на косинусную близость (схожесть) для позитивных и негативных пар. Если модель обучилась, схожесть для позитивных пар будет больше, чем для негативных.
# Helper method for visualization
def show(img, text=None):
img_np = img.numpy()
plt.axis("off")
plt.text(75, 120, text, fontweight="bold")
plt.imshow(np.transpose(img_np, (1, 2, 0))) # [CxHxW] -> [HxWxC] for imshow
plt.show()
def plot_imgs(model, test_loader):
similarity_pos = []
similarity_neg = []
model.eval()
with torch.inference_mode():
for i, batch in enumerate(test_loader, 0):
anc, pos, neg = batch
output_anc, output_pos, output_neg = model(
anc.to(device), pos.to(device), neg.to(device)
)
# compute euc. distance
sim_pos = F.cosine_similarity(output_anc, output_pos).item()
sim_neg = F.cosine_similarity(output_anc, output_neg).item()
similarity_pos.append(sim_pos)
similarity_neg.append(sim_neg)
if not i % 5:
concatenated = torch.cat((anc, pos, neg))
result = "OK" if sim_neg < sim_pos else "BAD"
show(
torchvision.utils.make_grid(concatenated),
f"Positive / negative similarities: {sim_pos:.3f} / {sim_neg:.3f} - {result}",
)
return similarity_pos, similarity_neg
set_random_seed(42)
similarity_pos, similarity_neg = plot_imgs(model, val_loader)
Но такая оценка субъективна, давайте посмотрим на распределение схожестей по категориям:
import seaborn as sns
similarities = {"The same person": similarity_pos, "Another person": similarity_neg}
ax = sns.histplot(similarities, bins=20)
ax.set(xlabel="Pairwise similarity")
plt.show()
Видно, что схожесть между двумя фото одного и того же человека в среднем больше, чем схожесть между фотографиями разных людей.
Если бы мы проектировали систему распознавания лиц, нужно было бы выбрать порог, чтобы сравнивать с ним схожесть и принимать решение о том, верифицировать фото как подлинное или нет.
Соответственно, для нашего игрушечного датасета такой порог следует выбирать $≈0.75$. При этом мы будем иметь некоторое количество ошибок и ложно распознавать постороннего человека.
Часто, когда мы пишем и обучаем сети (будь то с нуля или с помощью transfer learning), мы вынуждены угадывать гиперпараметры (lr, betas и т.д). В случае с learning rate нам есть от чего оттолкнуться (маленький lr для transfer learning), но все же такой подход не кажется оптимальным.
Для оптимизации гиперпараметров существуют готовые решения, которые используют различные методы black-box оптимизации. Разберем одну из наиболее популярных библиотек — Optuna.
!pip install --quiet optuna
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 390.6/390.6 kB 13.5 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 224.5/224.5 kB 22.8 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.7/78.7 kB 11.5 MB/s eta 0:00:00
Давайте оптимизируем learning rate
и latent_dim
. Для того, чтобы код не выполнялся очень долго, укажем небольшой диапазон параметров и небольшое число эпох.
import optuna
from optuna.samplers import RandomSampler
# define function which will optimized
def objective(trial):
# boundaries for the optimizer's
lr = trial.suggest_float("lr", 1e-4, 1e-2)
latent_dim = trial.suggest_int("latent_dim", 8, 64, step=8)
# create new model(and all parameters) every iteration
model = SiameseNet(latent_dim).to(device) # latent_dim regulates by optuna
criterion = nn.TripletMarginWithDistanceLoss(
margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
)
optimizer = optim.Adam(
model.parameters(), lr=lr
) # learning rate regulates by optuna
# To save time, we will take only 3 epochs
train_loader = DataLoader(
train_dataset, num_workers=2, batch_size=batch_size, shuffle=True
)
_, last_epoch_loss = train(3, model, criterion, optimizer, train_loader)
return last_epoch_loss
# Create "exploration"
study = optuna.create_study(
direction="minimize", study_name="Optimizer", sampler=RandomSampler(42)
)
study.optimize(
objective, n_trials=10
) # The more iterations, the higher the chances of catching the most optimal hyperparameters
[I 2023-07-18 20:44:47,744] A new study created in memory with name: Optimizer
Epoch 0 with 0.8298 loss Epoch 1 with 0.5833 loss
[I 2023-07-18 20:44:53,687] Trial 0 finished with value: 0.5804647207260132 and parameters: {'lr': 0.003807947176588889, 'latent_dim': 64}. Best is trial 0 with value: 0.5804647207260132.
Epoch 2 with 0.5805 loss Epoch 0 with 0.8369 loss Epoch 1 with 0.6785 loss
[I 2023-07-18 20:45:00,230] Trial 1 finished with value: 0.7354487180709839 and parameters: {'lr': 0.007346740023932911, 'latent_dim': 40}. Best is trial 0 with value: 0.5804647207260132.
Epoch 2 with 0.7354 loss Epoch 0 with 0.8462 loss Epoch 1 with 0.4905 loss
[I 2023-07-18 20:45:06,189] Trial 2 finished with value: 0.36206525564193726 and parameters: {'lr': 0.0016445845403801217, 'latent_dim': 16}. Best is trial 2 with value: 0.36206525564193726.
Epoch 2 with 0.3621 loss Epoch 0 with 0.7979 loss Epoch 1 with 0.5225 loss
[I 2023-07-18 20:45:12,623] Trial 3 finished with value: 0.3257625102996826 and parameters: {'lr': 0.0006750277604651748, 'latent_dim': 56}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.3258 loss Epoch 0 with 0.8698 loss Epoch 1 with 0.6307 loss
[I 2023-07-18 20:45:18,821] Trial 4 finished with value: 0.5650991201400757 and parameters: {'lr': 0.006051038616257768, 'latent_dim': 48}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.5651 loss Epoch 0 with 0.9011 loss Epoch 1 with 0.5483 loss
[I 2023-07-18 20:45:24,987] Trial 5 finished with value: 0.48651862144470215 and parameters: {'lr': 0.00030378649352844425, 'latent_dim': 64}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.4865 loss Epoch 0 with 0.8338 loss Epoch 1 with 0.7046 loss
[I 2023-07-18 20:45:31,435] Trial 6 finished with value: 0.6294255256652832 and parameters: {'lr': 0.008341182143924175, 'latent_dim': 16}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.6294 loss Epoch 0 with 0.7068 loss Epoch 1 with 0.5217 loss
[I 2023-07-18 20:45:37,403] Trial 7 finished with value: 0.44763171672821045 and parameters: {'lr': 0.0019000671753502962, 'latent_dim': 16}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.4476 loss Epoch 0 with 0.8477 loss Epoch 1 with 0.5504 loss
[I 2023-07-18 20:45:44,057] Trial 8 finished with value: 0.39682716131210327 and parameters: {'lr': 0.0031119982052994237, 'latent_dim': 40}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.3968 loss Epoch 0 with 0.8262 loss Epoch 1 with 0.7000 loss
[I 2023-07-18 20:45:49,978] Trial 9 finished with value: 0.6107327938079834 and parameters: {'lr': 0.004376255684556947, 'latent_dim': 24}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.6107 loss
Как видите, такой упрощенный подбор даже двух гиперпараметров занимает много времени.
# show best params
study.best_params
{'lr': 0.0006750277604651748, 'latent_dim': 56}
Давайте посмотрим на историю оптимизации наших параметров:
optuna.visualization.plot_optimization_history(study)
Что ж, проверим, станет ли реально лучше. Обучим сеть с лучшими latent_dim
и lr
, определенными с помощью Optuna.
set_random_seed(42)
model = SiameseNet(study.best_params["latent_dim"]).to(
device
) # take latent_dim, which choosen by Optuna
criterion = nn.TripletMarginWithDistanceLoss(
margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
)
optimizer = optim.Adam(
model.parameters(), lr=study.best_params["lr"]
) # take lr, which choosen by Optuna
num_epochs = 9
train_loader = DataLoader(
train_dataset, num_workers=2, batch_size=batch_size, shuffle=True
)
l_optim, _ = train(num_epochs, model, criterion, optimizer, train_loader)
Epoch 0 with 0.7609 loss Epoch 1 with 0.5261 loss Epoch 2 with 0.4040 loss Epoch 3 with 0.3303 loss Epoch 4 with 0.3403 loss Epoch 5 with 0.3373 loss Epoch 6 with 0.2967 loss Epoch 7 with 0.2273 loss Epoch 8 with 0.2934 loss
plt.plot(range(1, len(loss_history) + 1), loss_history, label="no optimization")
plt.plot(range(1, len(l_optim) + 1), l_optim, label="optimal params")
plt.ylabel("loss")
plt.xlabel("num of epochs")
plt.grid()
plt.legend()
plt.show()
set_random_seed(42)
similarity_pos, similarity_neg = plot_imgs(model, val_loader)
similarities_optim = {
"The same person": similarity_pos,
"Another person": similarity_neg,
}
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
sns.histplot(similarities, bins=20, alpha=0.5, ax=axes[0])
sns.histplot(similarities_optim, bins=20, alpha=0.5, ax=axes[1])
axes[0].set(title="No optimization")
axes[1].set(title="Optimization with Optuna")
axes[0].set(xlabel="Pairwise similarity")
axes[1].set(xlabel="Pairwise similarity")
plt.show()
К сожалению, нельзя сказать, что после оптимизации гиперпараметров распределение схожестей стало лучше. При проведении порога по схожести количество ошибок первого и второго рода будет больше, чем до оптимизации.
Это может объясняться тем, что в этой задаче 10 итераций для подбора гиперпараметров все-таки маловато. Зато мы познакомились с API Optuna.
Заключение
Мы затронули проблемы, которые возникают при обучении на реальных данных.
Одна из основных проблем — малые датасеты. Для того, чтобы обучить нейронную сеть на небольшом датасете, можно использовать:
Однако необходимо помнить, что ни один из этих методов не защитит от ситуации, когда реальные данные будут сильно отличаться от тренировочных.
В случае когда у нас не только мало данных, но еще и очень большое (возможно, неизвестное) число классов, можно воспользоваться Metric Learning. В этом случае нейронная сеть обучается не классифицировать изображения, а, наоборот, находить различия между классом и новыми данными. Для этого используются нейронные сети, относящиеся к классу сиамских нейронных сетей.
Также мы разобрали автоматическую оптимизацию гиперпараметров с помощью API Optuna и показали, как пользоваться этим инструментом для оптимизации гиперпараметров модели и процедуры обучения.
Литература
Обучение на реальных данных
How to avoid machine learning pitfalls: a guide for academic researchers (Lones, 2021)
Understanding data augmentation for classification: when to warp? (Wong et al., 2016)
Learning from class-imbalanced data: Review of methods and applications (Haixiang et al., 2017)
Как решить проблему маленького количества данных?
Dealing with very small datasets
Несбалансированные данные
Kaggle Tutorual: Tackling Class imbalance
SMOTE explained for noobs - Synthetic Minority Over-sampling TEchnique line by line
Блог пост про 8 тактик борьбы с несбалансированными классами в наборе данных машинного обучения
Метрики, разработаные для работы с несбалансированными классами.
Transfer Learning
Image Classification using Transfer Learning in PyTorch
How To Do Transfer Learning For Computer Vision | PyTorch Tutorial
Transfer learning for Computer Vision Tutorial
Python Pytorch Tutorials # 2 Transfer Learning : Inference with ImageNet Models
PyTorch - The Basics of Transfer Learning with TorchVision and AlexNet
Augmentation
A survey on Image Data Augmentation for Deep Learning (Shorten and Khoshgoftaar, 2019)
Data augmentation for improving deep learning in image classification problem
Few-shot learning
FaceNet: A Unified Embedding for Face Recognition and Clustering (Schroff et al., 2015)
One-Shot Learning for Face Recognition
Siamese Neural Networks for One-shot Image Recognition (Koch et al., 2015)
Dimensionality Reduction by Learning an Invariant Mapping (Hadsell et al., 2005)
Language Models are Few-Shot Learners (Brown et al., 2020)
Hyperparameter optimization
Optuna tutorial for hyperparameter optimization
Optuna: Get the Best out of your Hyperparameters – Easy Tutorial