Обучение на реальных данных

Проблемы при работе с реальной задачей машинного обучения¶

В реальных задачах, особенно если вы работаете над новой проблемой, вы столкнетесь с широким спектром трудностей. Приведем часть из них и сопроводим примером на основе задач нахождения клеток крови на фотографии мазка крови.

Злокачественные клетки в мазке периферической крови

Source: Wikipedia

  • нехватка данных — фотографий мазков крови может быть недостаточно для построения сложной модели с нуля

  • недостаток размеченных данных — возможно, существует достаточно большое количество фотографий мазков крови (например, в историях болезни), но очень малая часть из них размечена

  • некачественная разметка — мазок крови могли доверить анализировать студенту-практиканту. Размечать его мог вообще человек не из профессии — например, хотевший таким образом увеличить обучающую выборку для модели на конкурс Kaggle. Даже в широко известных MNIST, CIFAR-10 и ImageNet есть ошибки в разметке (примеры)

Source: Label Errors in ML Test Sets

  • низкое качество данных — не все фотографии будут в хорошем разрешении. Да и не все образцы были правильно подготовлены. А еще могут попадаться фотографии от пациентов, больных чем-нибудь экзотическим, в результате чего их клетки выглядят сильно отличающимися от обычных. Или же в крови пациента просто плавают паразиты

Серповидная клеточная анемия приводит к аномальным эритроцитам

Source: Wikipedia

А так выглядит мазок крови при сонной болезни

Source: Wikipedia

  • несбалансированность датасета — клетки крови встречаются в разных пропорциях. Какие-то классы могут быть плохо представлены (минорные классы). Например, если в вашем датасете будет всего 10 фотографий, на которых присутствуют базофилы, то нейросети будет очень заманчиво вообще не пытаться найти базофилы (всего 10 ошибок).

  • ковариантный сдвиг — явление, когда признаки тренировочной и тестовой выборок распределены по-разному. Ковариантный сдвиг может стать серьезной проблемой для практического применения моделей.

Модель учится сопоставлять целевые значения признакам. В такой ситуации модель не в состоянии делать адекватные предсказания на тесте, так как во время обучения она не видела области пространства, в которой расположены тестовые объекты. Источники ошибок, приводящих к ковариантному сдвигу, обсуждались ранее в лекции №7).

Практический совет: для быстрого обнаружения ковариантного сдвига можно обучить модель, которая будет предсказывать, относится ли объект к train или test выборке. Если модель легко делит данные, то имеет смысл визуализировать значения признаков, по которым она это делает.

  • полные дубликаты — в данных могут быть полные дубликаты. Кто-то до вас агрегировал фотографии из разных источников, и либо вы не обратили на это внимание, либо он забыл об этом сказать. Такие данные надо сразу помечать и использовать только после предварительного размышления, т.к. они могут мешать вам и на этапе обучения модели, и на итоговой валидации ее качества (если один и тот же объект попадет и в обучение, и в валидацию).

  • неполные дубликаты — в данных могут быть данные от одного и того же пациента. Кажется, что если это разные мазки крови, то всё нормально. На самом деле и это уменьшает количество информации, которую может извлечь нейросеть из данных. С такими данными также нужно аккуратно работать и не допускать попадания одного пациента и в обучение, и в тест.

  • малое число источников данных — проблема, родственная предыдущей. В вашем датасете могут быть данные только от одного микроскопа или одной модели микроскопа. Могут быть данные, снятые только одним специалистом, или в одной больнице, или только у взрослых (фотографий мазков детей нет). Это также может влиять на способность вашего алгоритма обобщать полученное решение и требует пристального внимания.

Все это приводит к целому спектру проблем, из которых самой типичной будет переобучение модели — какую бы простую модель вы не взяли, она все равно будет выучивать искажения вашего датасета.

Общие подходы при работе с реальными данными¶

Нехватка данных¶

"Если у вас мало данных, попробуйте найти еще данные для вашей задачи" — совет приводится во многих инструкциях по борьбе с малым количеством данных и может быть воспринят с юмором. Однако часто для вашей задачи действительно существуют данные, собранные другими людьми. Также часто можно найти данные, которые очень похожи на ваши, и их можно использовать в обучении, но, например, учитывать с меньшим весом. Даже 20 дополнительных примеров могут сильно облегчить ситуацию.

Вы также можете использовать данные, которые не совсем похожи на ваши, в качестве внешней валидации. Тем самым вы можете разбивать свой изначальный датасет только для кросс-валидации и не выделять отдельную часть для теста. Тестом послужат как раз "не совсем" похожие данные.

Также можно использовать две техники, которые мы сегодня рассмотрим подробнее: Аугментация и Transfer Learning.

Дисбаланс классов¶

Прежде всего надо убедиться, что датасет сбалансирован:

In [2]:
import torch
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine


def show_class_balance(y, classes):
    _, counts = torch.unique(torch.tensor(y), return_counts=True)
    plt.bar(classes, counts)
    plt.ylabel("n_samples")
    plt.ylim([0, 75])
    plt.show()


wine = load_wine()
classes = wine.target_names

show_class_balance(wine.target, classes)

Разница в 10–20% будет незначительна, поэтому для наглядности мы искусственно разбалансируем наш датасет при помощи метода make_imbalance

In [3]:
from imblearn.datasets import make_imbalance

x, y = make_imbalance(
    wine.data, wine.target, sampling_strategy={0: 10, 1: 70, 2: 40}, random_state=42
)
show_class_balance(y, classes)

Изменение баланса класса сэмплированием¶

Если в данных недостаток именно конкретного класса, то можно бороться с этим при помощи разных способов сэмплирования.

Важно понимать, что в большинстве случаев данные, полученные таким способом, должны использоваться в качестве обучающего набора, но ни в коем случае не в качестве валидации или теста.

Дублирование примеров меньшего класса (oversampling)¶

Мы можем увеличить число объектов меньшего класса за счет дублирования.

Дублирование примеров меньшего класса

В этом случае наша модель будет "вынуждена" обращать внимание на минорный класс.

Такой Resampling может быть выполнен с помощью пакета imbalanced-learn, как показано ниже:

In [4]:
from imblearn.over_sampling import RandomOverSampler

ros = RandomOverSampler(random_state=0)
x_ros, y_ros = ros.fit_resample(x, y)

show_class_balance(y_ros, classes)

Уменьшение числа примеров большего класса (undersampling)¶

Аналогично, можно взять для обучения не всех представителей большего класса.

Удаление примеров преобладающего класса

Это также вынуждает модель обращать внимание на оба класса. Минус подхода очевиден: мы можем выбросить важных представителей большего класса, ответственных за существенное улучшение генерализации, и из-за этого качество модели существенно ухудшится. Можно пытаться выбрасывать объекты большего класса как-то по-умному. Например, кластеризовать объекты большего класса и брать по заданному количеству объектов из каждого класса.

In [5]:
from imblearn.under_sampling import RandomUnderSampler

rus = RandomUnderSampler(random_state=42)
x_res, y_res = rus.fit_resample(x, y)

show_class_balance(y_res, classes)

Ансамбли + undersampling¶

Можно использовать ансамбли вместе с undersampling. В этом случае мы можем, к примеру, делать сэмплирование только большего класса, а объекты минорного класса оставлять как есть.

Или просто сэмплировать объекты и того, и другого класса в равном количестве.

Балансирование представленности объектов в батчах¶

В случае нейросетей можно балансировать встречаемость каждого класса не на уровне датасета, а на уровне батча. Например, собирать каждый батч таким образом, чтобы в нем было поровну всех классов.

Это может улучшать сходимость даже в случае небольшого дисбаланса или его отсутствия, т.к. мы будем избегать шаги обучения нейросети, в которых она просто не увидела какого-то класса в силу случайных причин.

В PyTorch эту функциональность можно получить, используя класс WeightedRandomSampler . Для его инициализации требуется рассчитать вес каждого класса. Сумма весов не обязана быть равной единице.

In [6]:
# https://pytorch.org/docs/stable/generated/torch.unique.html
_, counts = torch.unique(torch.tensor(y), return_counts=True)
weights = counts.max() / counts
print("Classes: ", classes)
print("Weights: ", weights)
Classes:  ['class_0' 'class_1' 'class_2']
Weights:  tensor([7.0000, 1.0000, 1.7500])

Теперь создаем объект WeightedRandomSampler, в конструктор подаем два аргумента:

  • список весов для каждого элемента в датасете;
  • количество элементов (можно использовать не весь датасет).
In [7]:
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

tensor_x = torch.Tensor(x)  # transform to torch tensor
tensor_y = torch.Tensor(y)
dataset = TensorDataset(tensor_x, tensor_y)

batch_size = 8

weight_for_sample = []  # Every sample must have a weight
for l in y:
    weight_for_sample.append(weights[l].item())

sampler = WeightedRandomSampler(torch.tensor(weight_for_sample), len(dataset))
loader = DataLoader(dataset, batch_size=32, drop_last=True, sampler=sampler)

Посмотрим на распределение элементов разных классов по батчам.

In [8]:
batch_labels = []
for data, labels in loader:
    print(
        "Labels:",
        labels.int().tolist(),
        "Classes in batch:",
        torch.unique(labels, return_counts=True)[1].tolist(),
    )
    batch_labels.append(labels.tolist())

show_class_balance(batch_labels, classes)
Labels: [1, 0, 2, 2, 1, 2, 0, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 0, 1, 2, 0, 2, 1, 2, 2, 1, 0, 0, 0, 1, 2] Classes in batch: [7, 12, 13]
Labels: [1, 0, 1, 0, 0, 0, 1, 0, 2, 0, 0, 1, 1, 1, 0, 1, 1, 0, 2, 0, 1, 0, 0, 2, 2, 2, 0, 1, 1, 1, 1, 1] Classes in batch: [13, 14, 5]
Labels: [0, 1, 1, 2, 1, 0, 2, 1, 0, 2, 1, 1, 1, 2, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 1, 1, 2, 1, 0, 0, 2] Classes in batch: [7, 14, 11]

Результаты будут отличаться от запуска к запуску. Но видно, что в батчах объекты каждого класса встречаются почти равномерно.

Стоит отметить, что нужно быть осторожным со взвешиванием объектов в батчах и контролировать состав батчей. Дело в том, что при существенном дисбалансе веса при объектах минорного класса могут оказываться на несколько порядков больше, чем при объектах мажорного класса. Данные веса преобразуются в вероятности для сэмплирования, и может случиться так, что вероятности при объектах мажорного класса станут численно неотличимы от нуля. Тем самым можно получить обратный эффект: батчи будут состоять исключительно из объектов минорного класса. В таком случае нужно намеренно ограничивать веса.

Генерация синтетических данных¶

Другой подход к решению этой проблемы — создание синтетических данных. Делать это можно по-разному.

SMOTE¶

Synthetic Minority Over-sampling Technique (SMOTE) позволяет генерировать синтетические данные за счет реальных объектов из минорного класса.

Алгоритм работает следующим образом:

  1. Для случайной точки из минорного класса выбираем $k$ ближайших соседей из того же класса.
  2. Для первого соседа проводим отрезок, соединяющий его и выбранную точку. На этом отрезке случайно выбираем точку.
  3. Эта точка — новый синтетический объект минорного класса.
  4. Повторяем процедуру для оставшихся соседей.

Число соседей, как и число раз, которое мы запускаем описанную выше процедуру, можно регулировать.

In [9]:
from imblearn.over_sampling import SMOTE

oversample = SMOTE()
x_smote, y_smote = oversample.fit_resample(x, y)

show_class_balance(y_smote, classes)

Количество объектов каждого класса, которое должно получиться после генерации, можно задать явно:

In [10]:
over = SMOTE(sampling_strategy={0: 20, 1: 70, 2: 70})
x_smote, y_smote = over.fit_resample(x, y)

show_class_balance(y_smote, classes)

Подробнее про использование пакета можно прочесть в статье: SMOTE for Imbalanced Classification with Python

Изменение функции потерь¶

Модели в машинном обучении “ленивы”. При работе с несбалансированными классами модель будет чаще сталкиваться с доминирующим классом и вместо того, чтобы разбираться в признаках объектов, может начать ориентироваться на статистическое распределение классов.

Пример: датасет, в котором 95% объектов относятся к классу 1 и 5% к классу 0. Модель может выучиться всегда относить объекты к классу 1, и в 95% случаях она будет права.

Веса классов¶

Чтобы это поправить, можно изменить функцию потерь так, чтобы больше штрафовать модель за ошибки в минорных классах. Для этого можно:

  • Добавлять веса в функцию потерь для компенсации дисбаланса классов. Во многих функциях потерь в PyTorch (например, CrossEntropyLoss) есть параметр weight, который имеет по умолчанию значение None. В него можно передать тензор весов, соответствующий размеру вектора целевых значений, и получить взвешенную функцию ошибок.

Посмотрим, как это работает. Допустим, мы получили от нейросети неверные предсказания:

второй объект должен относиться к классу 1, а не 0

In [11]:
scores = torch.tensor([[30.0, 2.0], [30.0, 2.0]])  # Scores for batch of two samples
target = torch.tensor([0, 1])  # Second sample belongs to class 1
# but logit for class 0 greater: 30 > 2. So it was misclassified

Подсчитаем Cross-Entropy Loss без весов:

$$Loss = \frac{1}{2} \biggr[- \ln\frac{e^{30}}{e^{30}+e^{2}} - \ln\frac{e^{2}}{e^{30}+e^{2}}\biggr]\approx 14.0 $$
In [12]:
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(scores, target)
print(f"Loss = {loss.item():.2f}")
Loss = 14.00

Если у нас есть два класса с соотношением 4:1, можно задать веса weight = [0.2, 0.8]. И, так как сеть ошиблась на классе с большим весом, ошибка вырастет:

$$Loss = \biggr[-0.2 \ln\frac{e^{30}}{e^{30}+e^{2}} -0.8 \ln\frac{e^{2}}{e^{30}+e^{2}}\biggr]\approx 22.4 $$
In [13]:
weights = torch.tensor([0.2, 0.8], dtype=torch.float32)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
loss = criterion(scores, target)
print(f"Loss = {loss.item():.2f}")
Loss = 22.40

Сумма весов может быть не равна единице:

In [14]:
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor([2.0, 8.0]))
loss = criterion(scores, target)
print(f"Loss = {loss.item():.2f}")
Loss = 22.40
  • Иногда качество модели можно улучшить, взяв квадратные корни от полученных таким образом весов (немного снижает штрафы за ошибки на редких классах).

  • Несмотря на интуитивно понятную логику работы способа, он не всегда дает значительный эффект. Тем не менее, на практике стоит пробовать экспериментировать с этим способом наряду с прочими техниками борьбы с дисбалансом.

  • Можно менять форму функции потерь. В 2017 году для работы с несбалансированными классами был предложен Focal Loss. Это — модификация Cross-Entropy Loss, которая модифицирует ее форму для различных классов. С тех пор появились различные модификации функции ошибок, с которыми можно и нужно экспериментировать.

Focal Loss¶

Focal Loss — это функция потерь, используемая в нейронных сетях для решения проблемы классификации сложных объектов (hard examples).

Приведем пример: пусть модель должна классифицировать фрукты на два класса: яблоки и груши. В наборе данных есть много явных представителей того и иного класса: зеленые яблоки и желтые груши. Для модели они будут простыми: она на них не ошибается, и предсказываемая вероятность истинного класса для них велика и равна $0.9$.

В то же время в наборе данных есть малое количество зеленых груш: они по форме похожи на груши, а по цвету — на яблоки. Говоря более общими словами, зеленые груши ближе к границе классов в пространстве признаков. Для модели такие примеры могут оказаться сложными, и она будет на них ошибаться. Пусть для зеленой груши вероятность истинного класса равна $0.2$, то есть модель ошиблась и предсказала класс "яблоко" с вероятностью $0.8$.

Проблема состоит в том, что сумма большого количества малых ошибок на простых объектах может перевешивать сумму малого количества ошибок потерь на сложных объектах. Поэтому модель будет плохо учиться верно классифицировать сложные объекты: ей будет "лень" исправлять незначительные ошибки.

Для решения этой проблемы была предложена специальная функция потерь — Focal Loss. Она немного модифицирует кросс-энтропию для придания большей значимости ошибкам на сложных объектах.

Focal Loss была предложена в статье Focal Loss for Dense Object Detection (Lin et al., 2017) изначально для задачи детектирования объектов на изображениях. Она определяется так:

$$\large\text{FL}(p_t) = -(1 - p_t)^\gamma\text{log}(p_t)$$

Здесь $p_t$ — предсказанная вероятность истинного класса, а $\gamma$ — настраиваемый гиперпараметр.

Focal Loss уменьшает потери на уверенно классифицируемых примерах (где $p_t>0.5$) и больше фокусируется на сложных примерах, которые классифицированы неправильно. Параметр $\gamma$ управляет относительной важностью неправильно классифицируемых примеров. Более высокое значение $\gamma$ увеличивает важность неправильно классифицированных примеров. В экспериментах авторы показали, что параметр $\gamma=2$ показывал себя наилучшим образом в их задаче.

При $\gamma=0$ Focal Loss становится равной Cross-Entropy Loss, которая выражается как обратный логарифм вероятности истинного класса:

$$\large\text{CE}(p_t)=-\text{log}(p_t)$$

Разберем на нашем примере с яблоками и грушами.

Мы имеем 20 простых объектов с вероятностью истинного класса $0.9$: 10 яблок и 10 груш, и один сложный объект с вероятностью истинного класса $0.2$.

$\large{\text{CE} = \overbrace{\sum^{20}-\text{log}(0.9)}^{\large\color{#3C8031}{\text{loss(easy apples and pears)=2.11}}} + \overbrace{(-\text{log}(0.2))}^{\large\color{#F26035}{\text{loss(hard pear)=1.61}}} \approx 3.72}$

$\large{\text{FL}(\gamma=2) = \overbrace{\sum^{20}-\color{#AF3235}{\underbrace{(1-0.9)^2}_{0.01}}\text{log}(0.9)}^{\large\color{#3C8031}{\text{loss(easy apples and pears)=0.02}}} + \overbrace{(-\color{#AF3235}{\underbrace{(1-0.2)^2}_{0.64}}\text{log}(0.2))}^{\large\color{#F26035}{\text{loss(hard pear)=1.03}}} \approx 1.05}$

Фактически, потери для уверенно классифицированных объектов дополнительно занижаются.

Этот эффект достигается путем домножения на коэффициент: $ \large(1-p_{t})^\gamma$

Пока модель ошибается, $p_{t}$ — мала, и значение выражения в скобках соответственно близко к 1.

Когда модель обучилась, значение $p_{t}$ становится близким к 1, а разность в скобках становится маленьким числом, которое возводится в степень $ \gamma \ge 0 $. Таким образом, домножение на это небольшое число нивелирует вклад верно классифицированных объектов.

Это позволяет модели сосредоточиться (сфокусироваться, отсюда и название) на изучении сложных объектов (hard examples).

В примере коэффициент $(1-p_t)^\gamma$ в Focal Loss в 100 раз занизил потери при уверенной классификации простых яблок и груш, и потери при неверной классификации сложной груши стали преобладать.

Source: Focal Loss for Dense Object Detection (Lin et al., 2018)

Давайте посчитаем для различных значений $γ$, сколько понадобится примеров с небольшой ошибкой (высокой вероятностью истинного класса, равной $0.9$), чтобы получить суммарный Focal Loss примерно такой же, как у одного примера с большой ошибкой (низкой вероятностью истинного класса, равной $0.2$).

In [15]:
import numpy as np


def cross_entropy(prob_true):
    return -np.log(prob_true)


def focal_loss(prob_true, gamma=2):
    return (1 - prob_true) ** gamma * cross_entropy(prob_true)


p1 = 0.9  # probability of easy examples predictions
p2 = 0.2  # probability of hard examples predictions
gammas = [0, 0.5, 1, 2, 5, 10, 15]

print(
    f"For probability of easy examples predictions {p1} and probability of hard examples predictions {p2}\n"
)

for gamma in gammas:
    fl1 = focal_loss(p1, gamma)
    fl2 = focal_loss(p2, gamma)

    print(
        f"gamma = {gamma},".ljust(15),
        f"for an equal loss with a problematic prediction, almost correct ones are required {int(fl2 / fl1)}",
    )
For probability of easy examples predictions 0.9 and probability of hard examples predictions 0.2

gamma = 0,      for an equal loss with a problematic prediction, almost correct ones are required 15
gamma = 0.5,    for an equal loss with a problematic prediction, almost correct ones are required 43
gamma = 1,      for an equal loss with a problematic prediction, almost correct ones are required 122
gamma = 2,      for an equal loss with a problematic prediction, almost correct ones are required 977
gamma = 5,      for an equal loss with a problematic prediction, almost correct ones are required 500548
gamma = 10,     for an equal loss with a problematic prediction, almost correct ones are required 16401977428
gamma = 15,     for an equal loss with a problematic prediction, almost correct ones are required 537459996388583

Как видно, при увеличении значения $\gamma$ можно достичь значительного роста "важности" примеров с высокой ошибкой, что по сути позволяет модели обращать внимание на "hard examples".

Этот пример также показывает опасность Focal Loss: если мы имеем ошибки в разметке, то при большом $\gamma$ можно начать очень сильно наказывать модель за ошибки на неверно размеченных примерах, что может привести к переобучению под ошибки в разметке.

Focal Loss может применяться также и в задачах с дисбалансом классов. В этом смысле объекты преобладающего класса могут считаться простыми, а объекты минорного класса — сложными.

Однако для работы с дисбалансом в Focal Loss могут быть добавлены веса для классов. Тогда формула будет выглядеть так:

$$\large\text{FL}(p_t) = -\alpha_t(1 - p_t)^\gamma\text{log}(p_t)$$

Здесь $\alpha_t$ — вес для истинного класса, имеющий такой же смысл, как параметр weight в Cross-Entropy Loss.

Focal Loss не реализована в PyTorch нативно, но существуют сторонние совместимые реализации. Посмотрим, как воспользоваться одной из них.

In [16]:
import random


def set_random_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


set_random_seed(42)
In [17]:
#!wget https://raw.githubusercontent.com/AdeelH/pytorch-multi-class-focal-loss/master/focal_loss.py
!wget -q https://edunet.kea.su/repo/EduNet-web_dependencies/L11/focal_loss.py
In [18]:
import torch
from torch import nn
from focal_loss import FocalLoss


criterion = FocalLoss(alpha=None, gamma=2.0)

model_output = torch.rand(3, 3)  # model output is logits, as in CELoss
print(f"model_output:\n {model_output}")

target = torch.empty(3, dtype=torch.long).random_(3)
print(f"target: {target}")

loss_fl = criterion(model_output, target)
print(f"loss_fl: {loss_fl}")
model_output:
 tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408]])
target: tensor([2, 1, 1])
loss_fl: 0.6864498257637024

Убедимся, что сторонняя реализация вычисляет то, что нужно, и вычислим значение вручную. В первую очередь нужно перевести model_output из логитов в вероятности с помощью softmax.

In [19]:
probs = torch.nn.functional.softmax(model_output, dim=1)

print(f"probabilities after softmax:\n {probs}")
probabilities after softmax:
 tensor([[0.3788, 0.3914, 0.2299],
        [0.4415, 0.2500, 0.3085],
        [0.2131, 0.3646, 0.4224]])

Теперь вручную рассчитаем значение функции потерь.

In [20]:
def focal_loss(prob_true, gamma=2):
    return -((1 - prob_true) ** gamma) * np.log(prob_true)


hand_calculated_loss = 0

for i in range(3):
    hand_calculated_loss += focal_loss(probs[i, target[i]])

hand_calculated_loss /= 3  # average by number of samples

print(f"hand-calculated focal loss:    {hand_calculated_loss.item()}")
print(f"library-calculated focal loss: {loss_fl}")
print(
    f"Are results almost equal? {torch.isclose(loss_fl, hand_calculated_loss).item()}"
)
hand-calculated focal loss:    0.6864497661590576
library-calculated focal loss: 0.6864498257637024
Are results almost equal? True

Действительно, при расчете вручную получили то же значение, что и при расчете с помощью сторонней реализации.

Обнаружение аномалий¶

В случае сильно несбалансированных наборов данных стоит задуматься, могут ли такие примеры рассматриваться как аномалия (выброс) или нет. Если такое событие и впрямь может считаться аномальным, мы можем использовать такие модели, как OneClassSVM, методы кластеризации или методы обнаружения гауссовских аномалий.

Эти методы требуют изменения взгляда на задачу: мы будем рассматривать аномалии как отдельный класс выбросов. Это может помочь нам найти новые способы разделения и классификации.

Если продолжать пример с фруктами, то задача обнаружения аномалий возникла бы, если бы мы предполагали, что среди яблок и груш может вдруг возникнуть мандарин, или любой другой фрукт, и нам бы нужно было не отнести его к одному из известных классов, а пометить как отдельный, отличающийся класс.

Проблемой при работе с аномалиями является то, что аномальных значений может быть очень мало или вообще не быть. В таком случае алгоритм учит паттерны нормального поведения и реагирует на отличия от паттернов.

Разберем примеры обнаружения аномалий с помощью трех алгоритмов из библиотеки Scikit-Learn (там можно найти еще много различных алгоритмов).

Создадим датасет из двух кластеров и случайных значений.

In [21]:
rng = np.random.RandomState(42)

# Train
x = 0.3 * rng.randn(100, 2)  # 100 2D points
x_train = np.r_[x + 2, x - 2]  # split into two clusters

# Test norlmal
x = 0.3 * rng.randn(20, 2)  # 20 2D points
x_test_norlmal = np.r_[x + 2, x - 2]  # split into two clusters

# Test outliers
x_test_outliers = rng.uniform(low=-4, high=4, size=(20, 2))

Напишем функцию визуализации, которая будет изображать созданный датасет на рисунке слева, а результат поиска аномалий — на рисунке справа.

In [22]:
def plot_outliers(x_train, x_test_norlmal, x_test_outliers, model=None):
    fig, (plt_data, plt_model) = plt.subplots(1, 2, figsize=(12, 6))

    plt_data.set_title("Created Dataset (real labels)")
    plot_train = plt_data.scatter(
        x_train[:, 0], x_train[:, 1], c="white", s=40, edgecolor="k"
    )
    plot_test_normal = plt_data.scatter(
        x_test_norlmal[:, 0], x_test_norlmal[:, 1], c="green", s=40, edgecolor="k"
    )
    plot_test_outliers = plt_data.scatter(
        x_test_outliers[:, 0], x_test_outliers[:, 1], c="red", s=40, edgecolor="k"
    )

    plt_data.set_xlim((-5, 5))
    plt_data.set_ylim((-5, 5))

    plt_data.legend(
        [plot_train, plot_test_normal, plot_test_outliers],
        ["train", "test normal", "test outliers"],
        loc="lower right",
    )

    if model:
        plt_model.set_title("Model Results")
        # plot decision function
        xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
        Z = model.decision_function(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)

        plt_model.contourf(xx, yy, Z, cmap=plt.cm.Blues_r)

        # plot prediction
        full_data = np.concatenate((x_train, x_test_norlmal, x_test_outliers), axis=0)
        predicted = model.predict(full_data)

        anom_index = np.where(predicted == -1)
        anom_values = full_data[anom_index]

        plot_all_data = plt_model.scatter(
            full_data[:, 0], full_data[:, 1], c="white", s=40, edgecolor="k"
        )

        plot_anom_data = plt_model.scatter(
            anom_values[:, 0], anom_values[:, 1], c="red", s=40, marker="x"
        )
        plt_model.legend(
            [plot_all_data, plot_anom_data],
            ["normal", "outliers"],
            loc="lower right",
        )
    plt.show()

Посмотрим, как работает на этих данных алгоритм OneClassSVM.

Идея алгоритма состоит в поиске функции, которая положительна для областей с высокой плотностью и отрицательна для областей с малой плотностью. Подробнее об алгоритме можно прочитать в оригинальной статье.

In [23]:
from sklearn.svm import OneClassSVM

gamma = 2.0
contamination = 0.05

model = OneClassSVM(gamma=gamma, kernel="rbf", nu=contamination)
model.fit(x_train)

plot_outliers(x_train, x_test_norlmal, x_test_outliers, model)

Посмотрим, как на этих же данных работает алгоритм IsolationForest.

IsolationForest состоит из деревьев, которые «изолируют» (пытаются отделить от остальной выборки) наблюдения, случайным образом выбирая признак и случайное значение порога для этого признака (между max и min значениями признака). Такой алгоритм чаще и проще отделяет значения аномалии. Если построить по такому принципу множество деревьев, то значения, которые чаще других отделяются раньше, будут аномалиями.

In [24]:
from sklearn.ensemble import IsolationForest

n_estimators = 200
contamination = 0.05

model = IsolationForest(
    n_estimators=n_estimators, contamination=contamination, random_state=rng
)
model.fit(x_train)

plot_outliers(x_train, x_test_norlmal, x_test_outliers, model)

Последним алгоритмом, на который мы посмотрим, будет LocalOutlierFactor.

В нем используется метод k-NN. Расстояние до ближайших соседей используется для оценки расположения точек. Если соседи далеко, то точка с большой вероятностью является аномалией.

In [25]:
from sklearn.neighbors import LocalOutlierFactor

n_neighbors = 10
contamination = 0.05

model = LocalOutlierFactor(
    n_neighbors=n_neighbors, novelty=True, contamination=contamination
)
model.fit(x_train)

plot_outliers(x_train, x_test_norlmal, x_test_outliers, model)

Метрики качества на несбалансированных данных¶

Обращайте внимание на то, какие метрики вы используете. При решении задачи классификации часто используется accuracy (точность), равная доле правильно классифицированных объектов. Эта метрика позволяет адекватно оценить результат классификации в случае сбалансированных классов. В случае дисбаланса классов данная метрика может выдать обманчиво хороший результат.

Пример: датасет, в котором 95% объектов относятся к классу 0 и 5% к классу 1.

In [26]:
from sklearn.datasets import make_classification
from collections import Counter


x, y = make_classification(
    n_samples=1000,
    n_features=2,
    n_redundant=0,
    n_clusters_per_class=1,
    weights=[0.95],
    flip_y=0,
    random_state=42,
)

counter = Counter(y)
print("Class distribution ", Counter(y))

for label, _ in counter.items():
    row_ix = np.where(y == label)[0]
    plt.scatter(x[row_ix, 0], x[row_ix, 1], label=str(label))
plt.legend()
plt.show()
Class distribution  Counter({0: 950, 1: 50})

И модель, которая для всех данных выдает класс 0,

In [27]:
class DummyModel:
    def predict(self, x):
        return np.zeros(x.shape[0])  # always predict class 0

Такая модель будет иметь $accuracy = 0.95$, хотя не выдает никакой полезной информации:

In [28]:
from sklearn.metrics import accuracy_score

dummy_model = DummyModel()
y_pred = dummy_model.predict(x)

accuracy = accuracy_score(y, y_pred)
print("Accuracy", accuracy)
Accuracy 0.95

Для несбалансированных данных лучше выбирать F1 Score, MCC (Matthews correlation coefficient, коэффициент корреляции Мэтьюса) или balanced accuracy (среднее между recall разных классов).

In [29]:
from sklearn.metrics import f1_score, matthews_corrcoef, balanced_accuracy_score

print("F1", f1_score(y, y_pred))
print("MCC", matthews_corrcoef(y, y_pred))
print("Balanced accuracy", balanced_accuracy_score(y, y_pred))
F1 0.0
MCC 0.0
Balanced accuracy 0.5

Все эти метрики оказываются равны нулю, что отражает отсутствие связи предсказаний с данными на входе модели.

Аугментация¶

Другой способ побороть маленькое количество данных для обучения — аугментация.

Аугмента́ция (от лат. augmentatio — увеличение, расширение) — увеличение выборки обучающих данных через модификацию существующих данных.

Модели глубокого обучения обычно требуют большого количества данных для обучения. В целом, чем больше данных, тем лучше для обучения модели. В то же время получение огромных объемов данных сопряжено со своими проблемами (например, с нехваткой размеченных данных или с трудозатратами, сопряженными с разметкой).

Вместо того, чтобы тратить дни на сбор данных вручную, мы можем использовать методы аугментации для автоматической генерации новых примеров из уже имеющихся.

Помимо увеличения размеченных датасетов, многие методы self-supervised learning построены на использовании разных аугментаций одного и того же сэмпла.

Примеры аугментаций изображения

Важный момент: при обучении модели мы используем разбиение данных на train-val-test. Аугментации стоит применять только на train. Почему так? Конечная цель обучения нейросети — это применение на реальных данных, которые сеть не видела. Поэтому для адекватной оценки качества модели валидационные и тестовые данные изменять не нужно.

В любом случае, test должен быть отделен от данных еще до того, как они попали в DataLoader или нейросеть.

Другое дело, что аугментации на тесте можно использовать как метод ансамблирования в случае классификации. Можно взять sample → создать несколько его копий → по-разному их аугментировать → предсказать класс на каждой из этих аугментированных копий → а потом выбрать наиболее вероятный класс голосованием (такой функционал реализован, например, в YOLOv5, о которой речь пойдет в следующих лекциях).

Изображения¶

Загрузим и отобразим пример картинки. Картинку отмасштабируем, чтобы она не занимала весь экран.

In [30]:
# setting random seed for reproducible illustrations
set_random_seed(42)

URL = "https://edunet.kea.su/repo/EduNet-web_dependencies/L11/capybara_image.jpg"
!wget -q $URL -O test.jpg
In [31]:
from IPython.display import display
from PIL import Image
from torchvision import transforms

input_img = Image.open("/content/test.jpg")
input_img = transforms.Resize(size=300)(input_img)
display(input_img)

Рассмотрим несколько примеров аугментаций картинок. С полным списком можно ознакомиться на сайте [doc] документации torchvision.

Random Rotation¶

Трансформация transforms.Random Rotation принимает параметр degrees — диапазон углов, из которого выбирается случайный угол для поворота изображения.

Создадим переменную transform, в которую добавим нашу аугментацию, и применим ее к исходному изображению. Затем запустим следующую ячейку несколько раз подряд

In [32]:
import matplotlib.pyplot as plt


def plot_augmented_img(transform, input_img):
    fig, ax = plt.subplots(1, 2, figsize=(15, 15))
    augmented_img = transform(input_img)
    ax[0].imshow(input_img)
    ax[0].set_title("Original img")
    ax[0].axis("off")

    ax[1].imshow(augmented_img)
    ax[1].set_title("Augmented img")
    ax[1].axis("off")
    plt.show()


transform = transforms.RandomRotation(degrees=(0, 180))

plot_augmented_img(transform, input_img)

Gaussian Blur¶

transforms.GaussianBlur размывает изображение с помощью фильтра Гаусса.

In [33]:
transform = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))

plot_augmented_img(transform, input_img)

Random Erasing¶

transforms.RandomErasing стирает на изображении произвольный прямоугольник. Она имеет параметр p — вероятность, с которой данная трансформация вообще применится к изображению.

Данная трансформация работает только с torch.Tensor, поэтому предварительно нужно применить трансформацию ToTensor, а затем ToPILImage, чтобы воспользоваться нашей функцией для отображения.

In [34]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.RandomErasing(p=1), transforms.ToPILImage()]
)

plot_augmented_img(transform, input_img)

Не лишним будет заметить, что некоторые трансформации могут существенно исказить изображение. Например, здесь, RandomErasing практически полностью стерла основной объект на снимке — капибару. Такая грубая аугментация может только навредить процессу обучения, и на практике нужно быть осторожным.

RandomErasing также имеет параметр scale — диапазон соотношения стираемой области к входному изображению. Попробуем уменьшить этот диапазон относительно значения по умолчанию, чтобы избежать нежелательного эффекта стирания капибары.

In [35]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.RandomErasing(p=1, scale=(0.02, 0.1)),
        transforms.ToPILImage(),
    ]
)

plot_augmented_img(transform, input_img)

ColorJitter¶

transforms.ColorJitter случайным образом меняет яркость, контрастность, насыщенность и оттенок изображения.

In [36]:
transform = transforms.ColorJitter(brightness=0.5, hue=0.3)

plot_augmented_img(transform, input_img)

Совмещаем несколько аугментаций вместе¶

Для этого будем использовать метод transforms.Compose. Нам нужно будет создать list со всеми аугментациями, которые будут применены последовательно.

In [37]:
transform = transforms.Compose(
    [
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        transforms.RandomPerspective(distortion_scale=0.5, p=1.0),
        transforms.ColorJitter(brightness=0.5, hue=0.3),
    ]
)

plot_augmented_img(transform, input_img)

Совмещение нескольких аугментаций случайным образом¶

Random Apply¶

Для того, чтобы применять аугментации случайным образом, можно воспользоваться методом transforms.RandomApply, который на вход принимает список аугментаций и вероятность p, с которой каждая аугментация будет применена.

In [38]:
transform = transforms.RandomApply(
    transforms=[
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        transforms.RandomPerspective(distortion_scale=0.5),
        transforms.ColorJitter(brightness=0.5, hue=0.3),
    ],
    p=0.9,
)

plot_augmented_img(transform, input_img)
Random Choice¶

В других случаях может быть полезен метод transforms.RandomChoice, который на вход принимает список аугментаций transforms, выбирает из него одну случайную аугментацию и применяет ее к изображению. Необязательным параметром является список вероятностей p, который указывает, с какой вероятностью каждая из аугментаций может быть выбрана из списка (по умолчанию каждая может быть выбрана равновероятно).

In [39]:
transform = transforms.RandomChoice(
    transforms=[
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        transforms.RandomPerspective(distortion_scale=0.5, p=1.0),
        transforms.ColorJitter(brightness=0.5, hue=0.3),
    ],
    p=[0.2, 0.4, 0.6],
)

plot_augmented_img(transform, input_img)
plot_augmented_img(transform, input_img)
plot_augmented_img(transform, input_img)

Пример создания собственной аугментации¶

Иногда может оказаться, что среди широкого спектра реализованных аугментаций нет такой, какую вы хотели бы применить к своим данным. В таком случае ее можно описать в виде класса и использовать наравне с реализованными в библиотеке.

Главное, что необходимо описать при создании класса — метод __call__. Он должен принимать изображение (оно может быть представлено в формате PIL.Image, np.array или torch.Tensor), делать с ним интересующие нас видоизменения и возвращать измененное изображение.

Рассмотрим пример добавления на изображение шума "соль и перец". Наш метод аугментации будет и принимать на вход, и возвращать PIL.Image.

In [40]:
from PIL import Image
import numpy as np


class SaltAndPepperNoise:
    """
    Add a "salt and pepper" noise to the PIL image
    __call__ method returns PIL Image with noise
    """

    def __init__(self, p=0.01):
        self.p = p  # noise level

    def __call__(self, pil_image):
        np_image = np.array(pil_image)

        # create random mask for "salt" and "pepper" pixels
        salt_ind = np.random.choice(
            a=[True, False], size=np_image.shape[:2], p=[self.p, 1 - self.p]
        )
        pepper_ind = np.random.choice(
            a=[True, False], size=np_image.shape[:2], p=[self.p, 1 - self.p]
        )

        # add "salt" and "pepper"
        np_image[salt_ind] = 255
        np_image[pepper_ind] = 0

        return Image.fromarray(np_image)
In [41]:
transform = SaltAndPepperNoise(p=0.03)

plot_augmented_img(transform, input_img)

Аугментация внутри Dataset¶

Возьмем папку с картинками.

In [42]:
import os
from zipfile import ZipFile

os.chdir("/content")
# download files
!wget -q --no-check-certificate 'https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/for_transforms.Compose.zip' -O data.zip
with ZipFile(
    "data.zip", "r"
) as folder:  # Create a ZipFile Object and load sample.zip in it
    folder.extractall()  # Extract all the contents of zip file in current directory
In [43]:
os.chdir("/content/for_transforms.Compose")
img_list = os.listdir()
print(img_list)
['horse4.jpg', 'horse2.jpg', 'bicornis5.jpg', 'bicornis1.jpg', 'bicornis3.jpg', 'bicornis2.jpg', 'horse1.jpg', 'horse5.jpg']

Напишем класс Dataset

In [44]:
from torch.utils.data import Dataset


class AugmentationDataset(Dataset):
    def __init__(self, img_list, transforms=None):
        self.img_list = img_list
        self.transforms = transforms

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, i):
        img = plt.imread(self.img_list[i])
        img = Image.fromarray(img).convert("RGB")
        img = np.array(img).astype(np.uint8)

        if self.transforms is not None:
            img = self.transforms(img)
        return img

Напишем вспомогательную функцию для отображения картинок. Напомним, что в PyTorch размерность каналов идет в первом, а не в последнем измерении тензора, описывающего картинку: Channels x Height x Width. Для отображения при помощи Matplotlib необходимо перевести массив в формат Height x Width x Channels.

In [45]:
def show_img(img):
    plt.figure(figsize=(40, 38))
    img_np = img.numpy()
    plt.imshow(np.transpose(img_np, (1, 2, 0)))  # [CxHxW] -> [HxWxC] for imshow
    plt.show()

Создадим list с аугментациями, которые мы хотим применить. Чтобы загрузить аугментации в PyTorch, нам необходимо эти картинки преобразовать в тензоры. Для этого воспользуемся стандартным преобразованием transforms.ToTensor()

In [46]:
tensor_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((164, 164)),
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        transforms.RandomPerspective(distortion_scale=0.5),
        transforms.ToTensor(),
    ]
)

Теперь обернем все в DataLoader и отобразим

In [47]:
from torch.utils.data import DataLoader
import torchvision

Augmentation_dataloader = DataLoader(
    AugmentationDataset(img_list, tensor_transform), batch_size=8, shuffle=True
)

data = iter(Augmentation_dataloader)
show_img(torchvision.utils.make_grid(next(data)))

Нестандартные способы аугментации¶

Существуют и более сложные способы аугментации. Ниже приведена пара примеров таких способов.

Mixup¶

Mixup — это "смешение" признаков двух объектов в определенных пропорциях. Mixup можно представить с помощью простого уравнения:

$\text{New image} = \alpha * \text{image}_1 + (1-\alpha) * \text{image}_2$

Подробнее в статьях:

mixup: Beyond Empirical Risk Minimization

On Mixup Training

Аугментация при помощи генерации данных¶

В ряде случаев возможно расширение набора данных путем синтеза новых данных.

Например, при создании системы распознавания текста можно генерировать новые образцы путем набора распознаваемых фраз или символов различными шрифтами на различных фонах и с добавлением каких-либо шумов и искажений.

В ряде областей для синтеза новых образов могут создаваться 3D-модели распознаваемых объектов. Например, в работе от Microsoft Fake It Till You Make It: Face analysis in the wild using synthetic data alone анализ лиц людей производился на синтетических 3D-моделях лиц. Датасет доступен на GitHub.

Также созданием новых образов, похожих на имеющиеся в датасете, можно заниматься при помощи генеративных моделей. Примером генеративных моделей является GAN (Generative Adversarial Network). Мы познакомимся с такими моделями в одной из следующих лекций.

Source: Progressive growing of GANS for improved quality, stability, and variation (Karras et al., 2018)

Аугментация в реальных задачах¶

Кроме методов, реализованных в PyTorch, существуют и специализированные библиотеки для аугментации изображений, в которых реализованы дополнительные возможности (например, наложение теней, бликов или пятен воды на изображение).

Например:

  • Albumentations
  • imgaug
  • AugLy

Важно: при выборе методов аугментации имеет смысл использовать только те, которые будут в реальной жизни.

Например, нет смысла:

  • делать перевод изображения в черно-белое, если предполагается, что весь входящий поток будет цветным
  • отражать человека вверх ногами, если мы не предполагаем его таким распознавать

Аудио¶

Рассмотрим несколько примеров аугментаций аудио. С полным списком можно ознакомиться здесь: [git] audiomentations.

Импортируем библиотеку и посмотрим на пример

In [48]:
os.chdir("/content")

!pip install -q audiomentations

!wget -q https://edunet.kea.su/repo/EduNet-web_dependencies/L11/audio_example.wav
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 70.6/70.6 kB 5.9 MB/s eta 0:00:00
In [49]:
from IPython.display import Audio

# Get input audio
input_audio = "/content/audio_example.wav"

display(Audio(input_audio))
Your browser does not support the audio element.
In [50]:
import librosa

data, sr = librosa.load("/content/audio_example.wav")  # sr - sampling rate

Background Noise¶

In [51]:
from audiomentations import AddGaussianSNR

augment = AddGaussianSNR(min_snr_in_db=3, max_snr_in_db=7, p=1)

# Augment/transform the audio data
augmented_data = augment(samples=data, sample_rate=sr)

display(Audio(augmented_data, rate=sr))
Your browser does not support the audio element.

Сравним волновые картины и спектрограммы

In [52]:
from scipy.signal import spectrogram


def produce_plots(input_audio_arr, aug_audio, sr):
    f, t, Sxx_in = spectrogram(
        input_audio_arr, fs=sr
    )  # Compute spectrogram for the original signal (f - frequency, t - time)
    f, t, Sxx_aug = spectrogram(aug_audio, fs=sr)

    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(20, 5))

    ax[0, 0].plot(input_audio_arr)
    ax[0, 0].set_xlim(0, len(input_audio_arr))
    ax[0, 0].set_xticks([])
    ax[0, 0].set_title("Original audio")

    ax[0, 1].plot(aug_audio)
    ax[0, 1].set_xlim(0, len(input_audio_arr))
    ax[0, 1].set_xticks([])
    ax[0, 1].set_title("Augmented  audio")

    ax[1, 0].imshow(
        np.log(Sxx_in),
        extent=[t.min(), t.max(), f.min(), f.max()],
        aspect="auto",
        cmap="inferno",
    )
    ax[1, 0].set_ylabel("Frequecny, Hz")
    ax[1, 0].set_xlabel("Time,s")

    ax[1, 1].imshow(
        np.log(Sxx_aug, where=Sxx_aug > 0),
        extent=[t.min(), t.max(), f.min(), f.max()],
        aspect="auto",
        cmap="inferno",
    )
    ax[1, 1].set_ylabel("Frequecny, Hz")
    ax[1, 1].set_xlabel("Time,s")

    plt.subplots_adjust(hspace=0)
    plt.show()


produce_plots(data, augmented_data, sr)

Time Stretch¶

In [53]:
from audiomentations import TimeStretch

augment = TimeStretch(min_rate=0.8, max_rate=1.5, p=1)
augmented_data = augment(data, sample_rate=sr)

display(Audio(augmented_data, rate=sr))
Your browser does not support the audio element.
In [54]:
produce_plots(data, augmented_data, sr)

Pitch Shift¶

Изменение тональности:

In [55]:
from audiomentations import PitchShift

augment = PitchShift(min_semitones=1, max_semitones=12, p=1)
augmented_data = augment(data, sample_rate=sr)

display(Audio(augmented_data, rate=sr))
Your browser does not support the audio element.

Совмещаем несколько аугментаций вместе¶

Как и в случае с картинками, мы можем совмещать несколько аугментаций вместе

In [56]:
from audiomentations import Compose, AddGaussianNoise, Shift

augment = Compose(
    [
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1),
        TimeStretch(min_rate=0.8, max_rate=1.25, p=1),
        PitchShift(min_semitones=-4, max_semitones=4, p=1),
        Shift(min_fraction=-0.5, max_fraction=0.5, p=1),
    ]
)

augmented_data = augment(data, sample_rate=sr)

display(Audio(augmented_data, rate=sr))
Your browser does not support the audio element.

Посмотрим на то, что получилось:

In [57]:
produce_plots(data, augmented_data, sr)

Дополнительные библиотеки для аугментации звука (и волновых функций в целом):

  • torchaudio
  • torch-audiomentations
  • AugLy

Текст¶

Теперь рассмотрим несколько примеров аугментаций текста. С полным списком можно ознакомиться здесь: [git] библиотеки.

In [58]:
!pip install -q nlpaug
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 410.5/410.5 kB 19.0 MB/s eta 0:00:00
In [59]:
# Define input text
text = "Hello, future of AI for Science! How are you today?"
print(f"input text: {text}")
input text: Hello, future of AI for Science! How are you today?

Аугментация символов¶

Заменой на похоже выглядящие:

In [60]:
import nlpaug.augmenter.char as nac

augment = nac.OcrAug()
augmented_text = augment.augment(text)

print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original:
Hello, future of AI for Science! How are you today?
Augmented Texts:
['Hel1u, fotore of AI for 8cience! How are you today?']

С опечатками, которые учитывают расположение символов на клавиатуре:

In [61]:
augment = nac.KeyboardAug()
augmented_text = augment.augment(text)

print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original:
Hello, future of AI for Science! How are you today?
Augmented Texts:
['nel/o, fufuDe of AI for ZcienX2! How are you g8day?']

Аугментация слов¶

С орфографическими ошибками:

In [62]:
import nlpaug.augmenter.word as naw

augment = naw.SpellingAug()
augmented_text = augment.augment(text, n=3)

print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original:
Hello, future of AI for Science! How are you today?
Augmented Texts:
['Hello, futur og AI for Science! Hot are ypi today?', 'Hello, furtuer f AI for Science! Wow are you today?', 'Hello, future for AI fom Scince! Hou are you today?']

С использованием модели для предсказания новых слов в зависимости от контекста:

In [63]:
!pip install -q transformers
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.4/7.4 MB 84.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 268.8/268.8 kB 32.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 108.0 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 66.9 MB/s eta 0:00:00
In [64]:
from IPython.display import clear_output


# model_type: word2vec, glove or fasttext
augment = naw.ContextualWordEmbsAug(model_path="bert-base-uncased", action="insert")
augmented_text = augment.augment(text)

clear_output()
print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original:
Hello, future of AI for Science! How are you today?
Augmented Texts:
['big hello, future of ai sci for all science! now how are you today?']

Аугментация предложений¶

Мы можем перевести текстовые данные на какой-либо язык, а затем перевести их обратно на язык оригинала. Это может помочь сгенерировать текстовые данные с разными словами, сохраняя при этом контекст текстовых данных.

In [65]:
!pip -q install sacremoses
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 880.6/880.6 kB 23.4 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for sacremoses (setup.py) ... done
In [66]:
back_translation_aug = naw.BackTranslationAug(
    from_model_name="facebook/wmt19-en-de", to_model_name="facebook/wmt19-de-en"
)
augmented_text = back_translation_aug.augment(text)

clear_output()
print(f"Original:\n{text}")
print(f"Augmented Texts:\n{augmented_text}")
Original:
Hello, future of AI for Science! How are you today?
Augmented Texts:
['Hello, the future of AI for science! How are you doing today?']

Instance Crossover Augmentation — составление новых объектов класса из отдельных предложений того же класса. Например, есть два объекта одного класса "положительный отзыв":

  • очень удобное приложение. Мне понравилось им пользоваться
  • класс! Отличный интерфейс

Тогда можно составить новый объект того же класса из их частей:

  • очень удобное приложение! Отличный интерфейс

Важно: при любых аугментациях текста на уровне предложений есть шанс создать странные и нелогичные объекты, поэтому использовать их следует с особой осторожностью.

Дополнительные библиотеки для аугментации текста:

  • TextAugment
  • AugLy

Обзор методов аугментации текста с примерами

Transfer learning¶

Как обучить нейросеть на своих данных, когда их мало?

Для такой типовой задачи, как классификация изображений, можно воспользоваться одной из существующих архитектур (AlexNet, VGG, Inception, ResNet и т.д.) и просто обучить нейросеть на своих данных. Реализации таких сетей с помощью различных фреймворков уже существуют, так что на данном этапе можно использовать одну из них как черный ящик, не вникая в принцип её работы. Например, в PyTorch есть множество уже реализованных известных архитектур: torchvision.models.

Однако глубокие нейронные сети требуют больших объемов данных для успешного обучения. И, зачастую, в нашей частной задаче недостаточно данных для того, чтобы хорошо обучить нейросеть с нуля. Transfer learning решает эту проблему. По сути мы пытаемся использовать опыт, полученный нейронной сетью при обучении на некоторой задаче $T_1$, чтобы решать схожую задачу $T_2$.

К примеру, transfer learning можно использовать при решении задачи классификации изображений на небольшом наборе данных. Как уже ранее обсуждалось, при обработке изображений свёрточные нейронные сети в первых слоях "реагируют" на некие простые пространственные шаблоны (к примеру, углы), после чего комбинируют их в сложные осмысленные формы (к примеру, глаза или носы). Вся эта информация извлекается из изображения, на её основе создаются сложные представления данных, которые в результате классифицируются линейной моделью.

Идея заключается в том, что если изначально обучить модель на некоторой сложной и довольно общей задаче, то можно надеяться, что она (как минимум часть ее слоев), в общем случае, будет извлекать важную информацию из изображений, и полученные представления можно будет успешно использовать для классификации линейной моделью.

Таким образом, берем часть модели, которая, по нашему представлению, отвечает за выделение хороших признаков (часто — все слои, кроме последнего) — feature extractor. Присоединяем к этой части один или несколько дополнительных слоёв для решения уже новой задачи. И учим только эти слои. Cлои feature extractor не учим — они "заморожены".

Понятно, что не все фильтры модели будут использованы максимально эффективно — к примеру, если мы работаем с изображениями, связанными с едой, возможно, не все фильтры на скрытых слоях предобученной на ImageNet модели окажутся полезны для нашей задачи. Почему бы не попробовать не только обучить новый классификатор, но и дообучить некоторые промежуточные слои? При использовании этого подхода мы при обучении дополнительно "настраиваем" и промежуточные слои, называется он fine-tuning.

При fine-tuning используют меньший learning rate, чем при обучении нейросети с нуля: мы знаем, что по крайней мере часть весов нейросети выполняет свою задачу хорошо, и не хотим испортить это быстрыми изменениями.

Кроме этого, можно делать комбинации этих методов: сначала учить только последние добавленные нами слои сети, затем учить еще и самые близкие к ним, и после этого учить уже все веса нейросети вместе. То есть мы можем определить свою стратегию fine-tuning.

Иногда fine-tuning считается синонимом Transfer learning, в этом случае часть от предтренированной сети называют backbone ("позвоночник"), а добавленную часть — head ("голова").

Порядок действий при transfer learning¶

Последовательно рассмотрим шаги, необходимые для реализации подхода transfer learning.

Шаг 1. Получение предварительно обученной модели

Первым шагом является выбор предварительно обученной модели, которую мы хотели бы использовать в качестве основы для обучения. Основным предположением является то, что признаки, которые умеет выделять из данных предобученная модель, хорошо подойдут для решения нашей частной задачи. Поэтому эффект от Transfer learning будет тем лучше, чем более схожими будут домены в нашей задаче и в задаче, на которой предварительно обучалась модель.

Для задач обработки изображений очень часто используются модели, предобученные на ImageNet. Такой подход распространен, однако, если ваша задача связана, например, с обработкой снимков клеток под микроскопом, то модель, предобученная на более близком домене (тоже на снимках клеток, пусть и совсем других), может быть лучшим начальным решением.

Шаг 2. Заморозка предобученных слоев

Мы предполагаем, что первые слои модели уже хорошо натренированы выделять какие-то абстрактные признаки из данных. Поэтому мы не хотим "сломать" их, особенно если начнем на этих признаках обучать новые слои, которые инициализируются случайно: на первых шагах обучения ошибка будет большой и мы можем сильно изменить "хорошие" предобученные веса.

Поэтому требуется "заморозить" предобученные веса. На практике заморозка означает отключение подсчета градиентов. Таким образом при последующем обучении параметры с отключенным подсчетом градиентов не будут обновляться.

Шаг 3. Добавление новых обучаемых слоев

В отличие от начальных слоев, которые выделяют достаточно общие признаки из данных, более близкие к выходу слои предобученной модели сильно специфичны конкретно под ту задачу, на которую она обучалась. Для моделей, предобученных на ImageNet, последний слой заточен конкретно под предсказание 1000 классов из этого набора данных. Кроме этого, последние слои могут не подходить под новую задачу архитектурно: в новой задаче может быть меньше классов, 10 вместо 1000. Поэтому, требуется заменить последние один или несколько слоев предобученной модели на новые, подходящие под нашу задачу. При этом, естественно, веса в этих слоях будут инициализированы случайно. Именно эти слои мы и будем обучать на следующем шаге.

Шаг 4. Обучение новых слоев

Все, что нам теперь нужно — обучить новые слои на наших данных. При этом замороженные слои используются лишь как экстрактор высокоуровневых признаков. Обучение такой модели существенно ничем не отличается от обучения любой другой модели: используется обучающая и валидационная выборка, контролируется изменение функции потерь и функционала качества.

Шаг 5. Тонкая настройка модели (fine-tuning)

После того, как мы обучили новые слои модели, и они уже как-то решают задачу, мы можем разморозить ранее замороженные веса, чтобы тонко настроить их под нашу задачу, в надежде, что это позволит еще немного повысить качество.

Нужно быть осторожным на этом этапе, использовать learning rate на порядок или два меньший, чем при основном обучении, и одновременно с этим следить за возникновением переобучения. Переобучение при fine-tuning может возникать из-за того, что мы резко увеличиваем количество настраиваемых параметров модели, но при этом наш датасет остается небольшим, и мощная модель может начать заучивать обучающие данные.

Практический пример transfer learning¶

Давайте рассмотрим пример практической реализации такого подхода (код переработан из этой статьи).

Загрузим датасет EuroSAT и удалим из него 90% файлов. EuroSAT — датасет для классификации спутниковых снимков по типам местности: лес, река, жилая застройка и т. п.

In [67]:
import random
import torch
import numpy as np


def set_random_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


set_random_seed(42)
In [68]:
import os
from random import sample

!wget -qN https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/EuroSAT.zip # http://madm.dfki.de/files/sentinel/EuroSAT.zip
!unzip -qn EuroSAT.zip

os.chdir("/content")
path = "/content/2750/"

for folder in os.listdir(path):
    files = os.listdir(path + folder)
    for file in sample(files, int(len(files) * 0.9)):
        os.remove(path + folder + "/" + file)

Определим аугментации. Для примера будем использовать родные аугментации из библиотеки Torchvision

In [69]:
from torchvision import transforms

# Applying Transforms to the Data
img_transforms = {
    "train": transforms.Compose(
        [
            transforms.Resize(size=224),  # as in ImageNet
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    ),
    # No augmentations on valid data!
    "valid": transforms.Compose(
        [
            transforms.Resize(size=224),
            transforms.ToTensor(),
        ]
    ),
    # No augmentations on test data!
    "test": transforms.Compose(
        [
            transforms.Resize(size=224),
            transforms.ToTensor(),
        ]
    ),
}

Создадим datasets

In [70]:
from torchvision import datasets
from copy import deepcopy

dataset = datasets.ImageFolder(root=path)
# split to train/valid/test
train_set, valid_set, test_set = torch.utils.data.random_split(
    dataset, [int(len(dataset) * 0.8), int(len(dataset) * 0.1), int(len(dataset) * 0.1)]
)

train_set.dataset = deepcopy(dataset)
valid_set.dataset = deepcopy(dataset)
test_set.dataset = deepcopy(dataset)

# define augmentations
train_set.dataset.transform = img_transforms["train"]
valid_set.dataset.transform = img_transforms["valid"]
test_set.dataset.transform = img_transforms["test"]

print(f"Train size: {len(train_set)}")
print(f"Valid size: {len(valid_set)}")
print(f"Test size: {len(test_set)}")
Train size: 2160
Valid size: 270
Test size: 270
In [71]:
from torch.utils.data import DataLoader

# Batch size
batch_size = 64

# Number of classes
num_classes = len(dataset.classes)

# Get a mapping of the indices to the class names, in order to see the output classes of the test images.
idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}

# Size of Data, to be used for calculating Average Loss and Accuracy
train_data_size, valid_data_size = len(train_set), len(valid_set)

# Create iterators for the Data loaded using DataLoader module
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
print("indexes to class: ")
idx_to_class
indexes to class: 
Out[71]:
{0: 'AnnualCrop',
 1: 'Forest',
 2: 'HerbaceousVegetation',
 3: 'Highway',
 4: 'Industrial',
 5: 'Pasture',
 6: 'PermanentCrop',
 7: 'Residential',
 8: 'River',
 9: 'SeaLake'}

В наборе данных не так уж и много изображений. При обучении с нуля нейросеть скорее всего не достигнет высокой точности.

Обучение готовой архитектуры "с нуля"¶

Загрузим MobileNet v2 без весов и попробуем обучить "с нуля", то есть с весов, инициализированных случайно.

In [72]:
from torchvision import models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.mobilenet_v2(weights=None)
print(model)
MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (4): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (7): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (8): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (9): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (10): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (11): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (12): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (13): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (14): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (15): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (16): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (17): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (18): Conv2dNormActivation(
      (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
  )
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=1280, out_features=1000, bias=True)
  )
)

Последний слой MobileNet дает на выходе предсказания для 1000 классов, а в нашем датасете классов всего 10. Поэтому мы должны изменить выход сети так, чтобы он выдавал 10 предсказаний. Поэтому мы заменяем последний слой модели MobileNet слоем с num_classes нейронами, равным числу классов в нашем датасете.

То есть мы "сказали" нашей модели распознавать не 1000, а только num_classes классов.

In [73]:
# Change the final layer of MobileNet Model for Transfer Learning
import torch.nn as nn

# change out classes, from 1000 to 10
model.classifier[1] = nn.Linear(1280, num_classes)
print(model.classifier)
Sequential(
  (0): Dropout(p=0.2, inplace=False)
  (1): Linear(in_features=1280, out_features=10, bias=True)
)

Затем мы определяем функцию потерь и оптимизатор, которые будут использоваться для обучения.

In [74]:
import torch.optim as optim

# Define Optimizer and Loss Function
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
print(optimizer)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0003
    maximize: False
    weight_decay: 0
)

Для тренировки и валидации нашей модели напишем отдельную функцию.

In [75]:
import time


def train_and_validate(model, criterion, optimizer, num_epochs=25, save_state=False):
    """
    Function to train and validate
    Parameters
        :param model: Model to train and validate
        :param criterion: Loss Criterion to minimize
        :param optimizer: Optimizer for computing gradients
        :param epochs: Number of epochs (default=25)

    Returns
        model: Trained Model with best validation accuracy
        history: (dict object): Having training loss, accuracy and validation loss, accuracy
    """

    start = time.time()
    history = []
    best_acc = 0.0

    for epoch in range(num_epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch + 1, num_epochs))

        # Set to training mode
        model.train()

        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0

        valid_loss = 0.0
        valid_acc = 0.0

        train_correct = 0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()  # Clean existing gradients
            outputs = model(
                inputs
            )  # Forward pass - compute outputs on input data using the model
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backpropagate the gradients
            optimizer.step()  # Update the parameters

            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)
            # Compute correct predictions
            train_correct += (torch.argmax(outputs, dim=-1) == labels).float().sum()

        # Compute the mean train accuracy
        train_accuracy = 100 * train_correct / (len(train_loader) * batch_size)

        val_correct = 0
        # Validation - No gradient tracking needed
        with torch.no_grad():
            model.eval()  # Set to evaluation mode

            # Validation loop
            for j, (inputs, labels) in enumerate(valid_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(
                    inputs
                )  # Forward pass - compute outputs on input data using the model
                loss = criterion(outputs, labels)  # Compute loss
                valid_loss += loss.item() * inputs.size(
                    0
                )  # Compute the total loss for the batch and add it to valid_loss

                val_correct += (torch.argmax(outputs, dim=-1) == labels).float().sum()

        # Compute mean val accuracy
        val_accuracy = 100 * val_correct / (len(valid_loader) * batch_size)

        # Find average training loss and training accuracy
        avg_train_loss = train_loss / (len(train_loader) * batch_size)

        # Find average training loss and training accuracy
        avg_valid_loss = valid_loss / (len(valid_loader) * batch_size)

        history.append(
            [
                avg_train_loss,
                avg_valid_loss,
                train_accuracy.detach().cpu(),
                val_accuracy.detach().cpu(),
            ]
        )

        epoch_end = time.time()

        print(
            "Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(
                epoch + 1,
                avg_train_loss,
                train_accuracy.detach().cpu(),
                avg_valid_loss,
                val_accuracy.detach().cpu(),
                epoch_end - epoch_start,
            )
        )
        # Saving state for fine_tuning (because we may overfit)
        if save_state:
            os.makedirs("check_points", exist_ok=True)
            torch.save(model.state_dict(), f"check_points/fine_tuning_{epoch + 1}.pth")

    return model, history

Теперь обучим нашу модель:

In [76]:
num_epochs = 20
trained_model, history = train_and_validate(
    model.to(device), criterion, optimizer, num_epochs
)

torch.save(history, "history_fresh.pt")
Epoch: 1/20
Epoch : 001, Training: Loss: 1.6833, Accuracy: 39.1544%, 
		Validation : Loss : 2.4529, Accuracy: 8.7500%, Time: 20.9827s
Epoch: 2/20
Epoch : 002, Training: Loss: 1.1396, Accuracy: 57.6287%, 
		Validation : Loss : 3.7425, Accuracy: 8.7500%, Time: 12.9362s
Epoch: 3/20
Epoch : 003, Training: Loss: 1.0403, Accuracy: 62.0404%, 
		Validation : Loss : 0.9845, Accuracy: 46.8750%, Time: 12.9204s
Epoch: 4/20
Epoch : 004, Training: Loss: 0.9564, Accuracy: 64.5680%, 
		Validation : Loss : 0.9148, Accuracy: 50.0000%, Time: 12.9076s
Epoch: 5/20
Epoch : 005, Training: Loss: 0.8677, Accuracy: 68.4743%, 
		Validation : Loss : 0.9954, Accuracy: 48.4375%, Time: 13.1667s
Epoch: 6/20
Epoch : 006, Training: Loss: 0.7869, Accuracy: 70.8640%, 
		Validation : Loss : 0.7944, Accuracy: 56.2500%, Time: 13.1206s
Epoch: 7/20
Epoch : 007, Training: Loss: 0.7766, Accuracy: 71.0478%, 
		Validation : Loss : 0.8605, Accuracy: 55.0000%, Time: 13.2928s
Epoch: 8/20
Epoch : 008, Training: Loss: 0.7428, Accuracy: 72.8401%, 
		Validation : Loss : 0.6109, Accuracy: 63.1250%, Time: 13.5449s
Epoch: 9/20
Epoch : 009, Training: Loss: 0.6979, Accuracy: 73.9890%, 
		Validation : Loss : 0.7091, Accuracy: 58.1250%, Time: 13.0479s
Epoch: 10/20
Epoch : 010, Training: Loss: 0.7157, Accuracy: 73.7132%, 
		Validation : Loss : 0.6613, Accuracy: 58.7500%, Time: 13.0548s
Epoch: 11/20
Epoch : 011, Training: Loss: 0.7081, Accuracy: 74.6324%, 
		Validation : Loss : 0.6852, Accuracy: 58.1250%, Time: 13.1791s
Epoch: 12/20
Epoch : 012, Training: Loss: 0.6620, Accuracy: 76.0110%, 
		Validation : Loss : 0.6173, Accuracy: 62.1875%, Time: 13.1287s
Epoch: 13/20
Epoch : 013, Training: Loss: 0.6407, Accuracy: 75.6893%, 
		Validation : Loss : 0.5484, Accuracy: 65.0000%, Time: 13.2165s
Epoch: 14/20
Epoch : 014, Training: Loss: 0.6198, Accuracy: 77.3897%, 
		Validation : Loss : 0.5837, Accuracy: 62.8125%, Time: 13.1706s
Epoch: 15/20
Epoch : 015, Training: Loss: 0.5599, Accuracy: 78.9522%, 
		Validation : Loss : 0.5946, Accuracy: 62.5000%, Time: 13.2250s
Epoch: 16/20
Epoch : 016, Training: Loss: 0.5354, Accuracy: 80.3768%, 
		Validation : Loss : 0.5517, Accuracy: 63.7500%, Time: 13.2174s
Epoch: 17/20
Epoch : 017, Training: Loss: 0.5385, Accuracy: 80.1011%, 
		Validation : Loss : 0.5519, Accuracy: 65.3125%, Time: 13.3673s
Epoch: 18/20
Epoch : 018, Training: Loss: 0.5141, Accuracy: 80.5607%, 
		Validation : Loss : 0.5671, Accuracy: 65.3125%, Time: 13.4340s
Epoch: 19/20
Epoch : 019, Training: Loss: 0.5267, Accuracy: 80.1011%, 
		Validation : Loss : 0.6413, Accuracy: 61.8750%, Time: 13.2895s
Epoch: 20/20
Epoch : 020, Training: Loss: 0.4901, Accuracy: 82.3070%, 
		Validation : Loss : 0.6876, Accuracy: 60.3125%, Time: 13.3084s

Посмотрим на графики:

In [77]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
fig.suptitle("Fresh learning", fontsize=14)

history = np.array(history)
ax[0].plot(history[:, :2])
ax[0].legend(["Train Loss", "Val Loss"])
ax[1].plot(history[:, 2:])
ax[1].legend(["Train Accuracy", "Val Accuracy"])
ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()

Точность на валидационной выборке не превысила 66%. Посмотрим, сможем ли мы добиться большей точности при использовании предобученной модели.

Обучение готовой архитектуры с предобученными весами¶

Обучение классификационной "головы"¶

Теперь давайте попробуем использовать transfer learning.

Загрузим предобученную на ImageNet модель MobileNet v2:

In [78]:
del model
model = models.mobilenet_v2(weights="MobileNet_V2_Weights.DEFAULT")
Downloading: "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-7ebf99e0.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 116MB/s]

В данном случае мы не дообучаем скрытые слои нашей модели, поэтому отключаем подсчёт градиентов ("замораживаем" параметры).

In [79]:
# Freeze model parameters
for param in model.parameters():
    param.requires_grad = False

Нам снова нужно изменить выход сети так, чтобы он выдавал 10 классов вместо 1000.

Мы могли бы изменить количество выходов сети, просто подменив последний линейный слой, как в примере обучения с нуля:

model.classifier[1] = nn.Linear(1280, num_classes)

Но нужно понимать, что мы не ограничены архитектурой готовой сети, и можем как подменять слои, так и добавлять новые. Поэтому в целях демонстрации мы заменим выходной слой исходной сети на два слоя: первый мы добавим "подменой" модуля, а затем добавим активацию и новый выходной слой с num_classes выходами с помощью метода add_module() класса Sequential.

Когда мы подменяем или добавляем слои, по умолчанию подсчет градиентов на них будет включен, и таким образом мы добьемся, что учиться будут только новые слои, веса которых инициализируются случайно.

In [80]:
# Change the final layers of MobileNet Model for Transfer Learning

model.classifier[1] = nn.Linear(
    1280, 500
)  # replace last module to our custom, e.g. with 500 neurons
model.classifier.add_module("2", nn.ReLU())  # add activation
model.classifier.add_module(
    "3", nn.Linear(500, num_classes)
)  # add new output layer with 10  out classes

print(model.classifier)
Sequential(
  (0): Dropout(p=0.2, inplace=False)
  (1): Linear(in_features=1280, out_features=500, bias=True)
  (2): ReLU()
  (3): Linear(in_features=500, out_features=10, bias=True)
)
In [81]:
# Define Optimizer and Loss Function
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
In [82]:
num_epochs = 20
trained_model, history = train_and_validate(
    model.to(device), criterion, optimizer, num_epochs
)

torch.save(history, "history_transfer_learning.pt")
Epoch: 1/20
Epoch : 001, Training: Loss: 1.7921, Accuracy: 51.2868%, 
		Validation : Loss : 1.3201, Accuracy: 52.1875%, Time: 8.7978s
Epoch: 2/20
Epoch : 002, Training: Loss: 1.0062, Accuracy: 73.7592%, 
		Validation : Loss : 0.7506, Accuracy: 67.1875%, Time: 8.5419s
Epoch: 3/20
Epoch : 003, Training: Loss: 0.7081, Accuracy: 80.7445%, 
		Validation : Loss : 0.5708, Accuracy: 70.0000%, Time: 9.4897s
Epoch: 4/20
Epoch : 004, Training: Loss: 0.6006, Accuracy: 80.9743%, 
		Validation : Loss : 0.4894, Accuracy: 72.5000%, Time: 9.0147s
Epoch: 5/20
Epoch : 005, Training: Loss: 0.5416, Accuracy: 82.1691%, 
		Validation : Loss : 0.4478, Accuracy: 72.8125%, Time: 8.1678s
Epoch: 6/20
Epoch : 006, Training: Loss: 0.5156, Accuracy: 82.2610%, 
		Validation : Loss : 0.3888, Accuracy: 74.6875%, Time: 8.9877s
Epoch: 7/20
Epoch : 007, Training: Loss: 0.4730, Accuracy: 83.6397%, 
		Validation : Loss : 0.3852, Accuracy: 72.8125%, Time: 8.5624s
Epoch: 8/20
Epoch : 008, Training: Loss: 0.4319, Accuracy: 85.3401%, 
		Validation : Loss : 0.3676, Accuracy: 74.0625%, Time: 8.3820s
Epoch: 9/20
Epoch : 009, Training: Loss: 0.4202, Accuracy: 85.5239%, 
		Validation : Loss : 0.3469, Accuracy: 73.4375%, Time: 8.8643s
Epoch: 10/20
Epoch : 010, Training: Loss: 0.3945, Accuracy: 85.7996%, 
		Validation : Loss : 0.3489, Accuracy: 73.1250%, Time: 8.3387s
Epoch: 11/20
Epoch : 011, Training: Loss: 0.3992, Accuracy: 86.4430%, 
		Validation : Loss : 0.3235, Accuracy: 75.3125%, Time: 8.6160s
Epoch: 12/20
Epoch : 012, Training: Loss: 0.3893, Accuracy: 85.7077%, 
		Validation : Loss : 0.3286, Accuracy: 75.6250%, Time: 9.0667s
Epoch: 13/20
Epoch : 013, Training: Loss: 0.3859, Accuracy: 86.4890%, 
		Validation : Loss : 0.3055, Accuracy: 75.9375%, Time: 8.0769s
Epoch: 14/20
Epoch : 014, Training: Loss: 0.3672, Accuracy: 87.2243%, 
		Validation : Loss : 0.3039, Accuracy: 74.6875%, Time: 8.7466s
Epoch: 15/20
Epoch : 015, Training: Loss: 0.3547, Accuracy: 87.1783%, 
		Validation : Loss : 0.2932, Accuracy: 76.2500%, Time: 8.9641s
Epoch: 16/20
Epoch : 016, Training: Loss: 0.3558, Accuracy: 87.1783%, 
		Validation : Loss : 0.2980, Accuracy: 75.3125%, Time: 8.0300s
Epoch: 17/20
Epoch : 017, Training: Loss: 0.3423, Accuracy: 87.7298%, 
		Validation : Loss : 0.2877, Accuracy: 75.6250%, Time: 8.9482s
Epoch: 18/20
Epoch : 018, Training: Loss: 0.3026, Accuracy: 89.8438%, 
		Validation : Loss : 0.2859, Accuracy: 76.5625%, Time: 8.5148s
Epoch: 19/20
Epoch : 019, Training: Loss: 0.3376, Accuracy: 87.3162%, 
		Validation : Loss : 0.2783, Accuracy: 76.2500%, Time: 7.8372s
Epoch: 20/20
Epoch : 020, Training: Loss: 0.3159, Accuracy: 88.3732%, 
		Validation : Loss : 0.2793, Accuracy: 77.1875%, Time: 8.7078s
In [83]:
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
fig.suptitle("Transfer learning", fontsize=14)

history = np.array(history)
ax[0].plot(history[:, :2])
ax[0].legend(["Train Loss", "Val Loss"])
ax[1].plot(history[:, 2:])
ax[1].legend(["Train Accuracy", "Val Accuracy"])
ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()

Сравним между собой обучение с нуля и обучение с предобученными весами.

In [84]:
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
fig.suptitle("Fresh Learning (FL) vs Transfer Learning (TL)", fontsize=14)

history_fresh = np.array(torch.load("history_fresh.pt"))
history_transfer_learning = np.array(torch.load("history_transfer_learning.pt"))

ax[0].plot(history_fresh[:, :2], linestyle="--")
ax[0].set_prop_cycle("color", ["tab:blue", "tab:orange"])
ax[0].plot(history_transfer_learning[:, :2])
ax[0].legend(["Train Loss (FL)", "Val Loss (FL)", "Train Loss (TL)", "Val Loss (TL)"])

ax[1].plot(history_fresh[:, 2:], linestyle="--")
ax[1].set_prop_cycle("color", ["tab:blue", "tab:orange"])
ax[1].plot(history_transfer_learning[:, 2:])
ax[1].legend(
    [
        "Train Accuracy (FL)",
        "Val Accuracy (FL)",
        "Train Accuracy (TL)",
        "Val Accuracy (TL)",
    ]
)
ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()

При использовании предобученных весов процесс обучения идет более плавно и модель выдает бо́льшую точность. На валидационной выборке мы получили точность около 74%.

Дообучение всех слоев (Fine-tuning)¶

Посмотрим, сможем ли мы еще немного повысить точность путем тонкой донастройки всех весов сети.

Проведём процедуру fine-tuning. В предыдущем варианте с transfer learning обучался только последний слой, добавленный вручную. Давайте проверим это, выведя те слои, в которых включён градиент.

In [85]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)
features.0.0.weight False
features.0.1.weight False
features.0.1.bias False
features.1.conv.0.0.weight False
features.1.conv.0.1.weight False
features.1.conv.0.1.bias False
features.1.conv.1.weight False
features.1.conv.2.weight False
features.1.conv.2.bias False
features.2.conv.0.0.weight False
features.2.conv.0.1.weight False
features.2.conv.0.1.bias False
features.2.conv.1.0.weight False
features.2.conv.1.1.weight False
features.2.conv.1.1.bias False
features.2.conv.2.weight False
features.2.conv.3.weight False
features.2.conv.3.bias False
features.3.conv.0.0.weight False
features.3.conv.0.1.weight False
features.3.conv.0.1.bias False
features.3.conv.1.0.weight False
features.3.conv.1.1.weight False
features.3.conv.1.1.bias False
features.3.conv.2.weight False
features.3.conv.3.weight False
features.3.conv.3.bias False
features.4.conv.0.0.weight False
features.4.conv.0.1.weight False
features.4.conv.0.1.bias False
features.4.conv.1.0.weight False
features.4.conv.1.1.weight False
features.4.conv.1.1.bias False
features.4.conv.2.weight False
features.4.conv.3.weight False
features.4.conv.3.bias False
features.5.conv.0.0.weight False
features.5.conv.0.1.weight False
features.5.conv.0.1.bias False
features.5.conv.1.0.weight False
features.5.conv.1.1.weight False
features.5.conv.1.1.bias False
features.5.conv.2.weight False
features.5.conv.3.weight False
features.5.conv.3.bias False
features.6.conv.0.0.weight False
features.6.conv.0.1.weight False
features.6.conv.0.1.bias False
features.6.conv.1.0.weight False
features.6.conv.1.1.weight False
features.6.conv.1.1.bias False
features.6.conv.2.weight False
features.6.conv.3.weight False
features.6.conv.3.bias False
features.7.conv.0.0.weight False
features.7.conv.0.1.weight False
features.7.conv.0.1.bias False
features.7.conv.1.0.weight False
features.7.conv.1.1.weight False
features.7.conv.1.1.bias False
features.7.conv.2.weight False
features.7.conv.3.weight False
features.7.conv.3.bias False
features.8.conv.0.0.weight False
features.8.conv.0.1.weight False
features.8.conv.0.1.bias False
features.8.conv.1.0.weight False
features.8.conv.1.1.weight False
features.8.conv.1.1.bias False
features.8.conv.2.weight False
features.8.conv.3.weight False
features.8.conv.3.bias False
features.9.conv.0.0.weight False
features.9.conv.0.1.weight False
features.9.conv.0.1.bias False
features.9.conv.1.0.weight False
features.9.conv.1.1.weight False
features.9.conv.1.1.bias False
features.9.conv.2.weight False
features.9.conv.3.weight False
features.9.conv.3.bias False
features.10.conv.0.0.weight False
features.10.conv.0.1.weight False
features.10.conv.0.1.bias False
features.10.conv.1.0.weight False
features.10.conv.1.1.weight False
features.10.conv.1.1.bias False
features.10.conv.2.weight False
features.10.conv.3.weight False
features.10.conv.3.bias False
features.11.conv.0.0.weight False
features.11.conv.0.1.weight False
features.11.conv.0.1.bias False
features.11.conv.1.0.weight False
features.11.conv.1.1.weight False
features.11.conv.1.1.bias False
features.11.conv.2.weight False
features.11.conv.3.weight False
features.11.conv.3.bias False
features.12.conv.0.0.weight False
features.12.conv.0.1.weight False
features.12.conv.0.1.bias False
features.12.conv.1.0.weight False
features.12.conv.1.1.weight False
features.12.conv.1.1.bias False
features.12.conv.2.weight False
features.12.conv.3.weight False
features.12.conv.3.bias False
features.13.conv.0.0.weight False
features.13.conv.0.1.weight False
features.13.conv.0.1.bias False
features.13.conv.1.0.weight False
features.13.conv.1.1.weight False
features.13.conv.1.1.bias False
features.13.conv.2.weight False
features.13.conv.3.weight False
features.13.conv.3.bias False
features.14.conv.0.0.weight False
features.14.conv.0.1.weight False
features.14.conv.0.1.bias False
features.14.conv.1.0.weight False
features.14.conv.1.1.weight False
features.14.conv.1.1.bias False
features.14.conv.2.weight False
features.14.conv.3.weight False
features.14.conv.3.bias False
features.15.conv.0.0.weight False
features.15.conv.0.1.weight False
features.15.conv.0.1.bias False
features.15.conv.1.0.weight False
features.15.conv.1.1.weight False
features.15.conv.1.1.bias False
features.15.conv.2.weight False
features.15.conv.3.weight False
features.15.conv.3.bias False
features.16.conv.0.0.weight False
features.16.conv.0.1.weight False
features.16.conv.0.1.bias False
features.16.conv.1.0.weight False
features.16.conv.1.1.weight False
features.16.conv.1.1.bias False
features.16.conv.2.weight False
features.16.conv.3.weight False
features.16.conv.3.bias False
features.17.conv.0.0.weight False
features.17.conv.0.1.weight False
features.17.conv.0.1.bias False
features.17.conv.1.0.weight False
features.17.conv.1.1.weight False
features.17.conv.1.1.bias False
features.17.conv.2.weight False
features.17.conv.3.weight False
features.17.conv.3.bias False
features.18.0.weight False
features.18.1.weight False
features.18.1.bias False
classifier.1.weight True
classifier.1.bias True
classifier.3.weight True
classifier.3.bias True

Мы оставим дообученную голову нейронной сети и продолжим обучение всей сети с уменьшением темпа обучения.

Разморозим параметры. criterion остаётся тот же, в optimizer уменьшим параметр lr на порядок.

In [86]:
# Unfreeze model parameters
for param in model.parameters():
    param.requires_grad = True

optimizer = optim.Adam(model.parameters(), lr=3e-5)

Пройдём дополнительные 20 эпох и построим графики.

In [87]:
num_epochs = 20
trained_model, history = train_and_validate(
    model.to(device), criterion, optimizer, num_epochs, save_state=True
)

torch.save(history, "history_finetuning.pt")
Epoch: 1/20
Epoch : 001, Training: Loss: 0.2955, Accuracy: 90.1195%, 
		Validation : Loss : 0.2437, Accuracy: 76.5625%, Time: 13.6207s
Epoch: 2/20
Epoch : 002, Training: Loss: 0.2402, Accuracy: 91.8199%, 
		Validation : Loss : 0.2150, Accuracy: 77.5000%, Time: 13.3512s
Epoch: 3/20
Epoch : 003, Training: Loss: 0.2102, Accuracy: 92.2794%, 
		Validation : Loss : 0.2052, Accuracy: 77.1875%, Time: 13.9947s
Epoch: 4/20
Epoch : 004, Training: Loss: 0.2057, Accuracy: 92.0496%, 
		Validation : Loss : 0.1850, Accuracy: 78.4375%, Time: 13.3778s
Epoch: 5/20
Epoch : 005, Training: Loss: 0.2118, Accuracy: 92.5092%, 
		Validation : Loss : 0.1663, Accuracy: 80.0000%, Time: 13.4142s
Epoch: 6/20
Epoch : 006, Training: Loss: 0.1611, Accuracy: 93.9338%, 
		Validation : Loss : 0.1563, Accuracy: 79.3750%, Time: 13.5197s
Epoch: 7/20
Epoch : 007, Training: Loss: 0.1707, Accuracy: 93.1526%, 
		Validation : Loss : 0.1380, Accuracy: 80.0000%, Time: 13.3344s
Epoch: 8/20
Epoch : 008, Training: Loss: 0.1602, Accuracy: 93.8419%, 
		Validation : Loss : 0.1335, Accuracy: 80.6250%, Time: 13.3374s
Epoch: 9/20
Epoch : 009, Training: Loss: 0.1428, Accuracy: 94.2096%, 
		Validation : Loss : 0.1271, Accuracy: 80.6250%, Time: 13.4488s
Epoch: 10/20
Epoch : 010, Training: Loss: 0.1313, Accuracy: 94.7151%, 
		Validation : Loss : 0.1210, Accuracy: 80.9375%, Time: 13.4156s
Epoch: 11/20
Epoch : 011, Training: Loss: 0.1328, Accuracy: 94.5772%, 
		Validation : Loss : 0.1122, Accuracy: 81.2500%, Time: 13.4008s
Epoch: 12/20
Epoch : 012, Training: Loss: 0.1143, Accuracy: 95.4504%, 
		Validation : Loss : 0.1081, Accuracy: 80.9375%, Time: 13.4811s
Epoch: 13/20
Epoch : 013, Training: Loss: 0.1203, Accuracy: 95.0368%, 
		Validation : Loss : 0.1093, Accuracy: 80.3125%, Time: 13.3948s
Epoch: 14/20
Epoch : 014, Training: Loss: 0.1087, Accuracy: 95.4963%, 
		Validation : Loss : 0.1054, Accuracy: 81.2500%, Time: 13.3669s
Epoch: 15/20
Epoch : 015, Training: Loss: 0.0959, Accuracy: 95.5882%, 
		Validation : Loss : 0.1016, Accuracy: 80.6250%, Time: 13.4678s
Epoch: 16/20
Epoch : 016, Training: Loss: 0.0965, Accuracy: 96.2776%, 
		Validation : Loss : 0.0960, Accuracy: 81.2500%, Time: 13.4115s
Epoch: 17/20
Epoch : 017, Training: Loss: 0.1007, Accuracy: 95.5882%, 
		Validation : Loss : 0.0966, Accuracy: 81.5625%, Time: 14.0048s
Epoch: 18/20
Epoch : 018, Training: Loss: 0.0932, Accuracy: 95.8640%, 
		Validation : Loss : 0.0927, Accuracy: 81.2500%, Time: 13.2979s
Epoch: 19/20
Epoch : 019, Training: Loss: 0.0806, Accuracy: 96.8290%, 
		Validation : Loss : 0.0928, Accuracy: 81.5625%, Time: 13.4152s
Epoch: 20/20
Epoch : 020, Training: Loss: 0.0736, Accuracy: 96.6452%, 
		Validation : Loss : 0.0879, Accuracy: 82.1875%, Time: 13.2985s
In [88]:
fig, ax = plt.subplots(ncols=2, figsize=(16, 5))
fig.suptitle("Transfer Learning (TL) AND Finetuning (FT)", fontsize=14)

history_transfer_learning = np.array(torch.load("history_transfer_learning.pt"))
history_finetuning = np.array(torch.load("history_finetuning.pt"))

train_val_loss = np.concatenate(
    (history_transfer_learning[:, :2], history_finetuning[:, :2]), axis=0
)
ax[0].plot(train_val_loss)
ax[0].vlines(19, -0.1, 2.1, color="tab:green", linewidth=2, linestyle="--")
ax[0].legend(["Train Loss", "Val Loss", "TL/FT boundary"])

train_val_acc = np.concatenate(
    (history_transfer_learning[:, 2:], history_finetuning[:, 2:]), axis=0
)

ax[1].plot(train_val_acc)
ax[1].vlines(19, -5, 105, color="tab:green", linewidth=2, linestyle="--")
ax[1].legend(["Train Accuracy", "Val Accuracy", "TL/FT boundary"])

ax[0].set_xlabel("Epoch Number")
ax[1].set_xlabel("Epoch Number")
ax[0].set_ylabel("Loss")
ax[1].set_ylabel("Accuracy")
plt.savefig("loss_curve.png")
ax[0].grid()
ax[1].grid()
plt.show()

Есть ли эффект от fine-tuning? После дообучения ещё на 20 эпохах мы наблюдаем следующие эффекты:

  • Loss дополнительно снизился, хотя до fine-tuning он стремился к выходу на плато
  • точность на валидации превысила 80%, то есть мы получили дополнительно около 6% точности.

При fine-tuning модель может быть склонна к переобучению, так как мы обучаем сложную модель с большим числом параметров на небольшом количестве данных. Поэтому мы используем learning rate на порядок меньший, чем при обычном обучении. Для контроля переобучения следует следить за метриками и ошибкой на валидационной выборке.

Лучшее качество на валидационных данных мы получили на 38 эпохе. При fine-tuning мы сохраняли состояния нейросети на каждой эпохе. Возьмём состояние с 38 эпохи как наиболее оптимальное.

In [89]:
trained_model.load_state_dict(
    torch.load("check_points/fine_tuning_18.pth")
)  # 38 = 20 (TL) + 18 (FT)
trained_model.eval();

Посмотрим на предсказания

In [90]:
def predict(model, test_img_name, device):
    """
    Function to predict the class of a single test image
    Parameters
        :param model: Model to test
        :param test_img_name: Test image

    """

    transform = img_transforms["test"]
    test_img = torch.tensor(np.asarray(test_img_name))
    test_img = transforms.ToPILImage()(test_img)
    plt.imshow(test_img)

    test_img_tensor = test_img_name.unsqueeze(0).to(device)

    with torch.no_grad():
        model.eval()
        # Model outputs is logits
        out = model(test_img_tensor).to(device)
        probs = torch.softmax(out, dim=1).to(device)
        topk, topclass = probs.topk(3, dim=1)
        for i in range(3):
            print(
                "Predcition",
                i + 1,
                ":",
                idx_to_class[topclass.cpu().numpy()[0][i]],
                ", Score: ",
                round(topk.cpu().numpy()[0][i], 2),
            )
In [91]:
print("Shoud be %s\n" % idx_to_class[0])
predict(
    trained_model.to(device),
    test_set[np.where([x[1] == 0 for x in test_set])[0][0]][0],
    device,
)
Shoud be AnnualCrop

Predcition 1 : AnnualCrop , Score:  1.0
Predcition 2 : PermanentCrop , Score:  0.0
Predcition 3 : SeaLake , Score:  0.0
In [92]:
print("Shoud be %s\n" % idx_to_class[6])
predict(
    trained_model,
    test_set[np.where([x[1] == 6 for x in test_set])[0][0]][0],
    device,
)
Shoud be PermanentCrop

Predcition 1 : PermanentCrop , Score:  0.58
Predcition 2 : AnnualCrop , Score:  0.42
Predcition 3 : SeaLake , Score:  0.0
In [93]:
print("Shoud be %s\n" % idx_to_class[8])
predict(
    trained_model,
    test_set[np.where([x[1] == 8 for x in test_set])[0][0]][0],
    device,
)
Shoud be River

Predcition 1 : River , Score:  1.0
Predcition 2 : Highway , Score:  0.0
Predcition 3 : AnnualCrop , Score:  0.0
  • Мы увидели, как использовать предварительно обученную модель на 1000 классов ImageNet для нашей задачи на 10 классов.

  • Мы сравнили качество обучения с нуля, transfer learning и fine-tuning и научились добиваться максимального качества с помощью этих принципов.

На практике не забывайте о характерной опасности fine-tuning — переобучении. Используйте низкий learning rate и отслеживайте Loss и показатели качества — возможно, вам будет достаточно небольшого количества эпох.

Metric learning¶

Существуют задачи, где не представляется возможным разбить данные на классы так, чтобы в каждом классе было достаточно много объектов.

Рассмотрим, например, задачу распознавания лиц.

На вход системе подается фото лица человека. Требуется сопоставить его с другим изображением или изображениями, например, хранящимися в БД, и таким образом идентифицировать человека на фотографии.

На первый взгляд кажется, что это задача классификации.

Все изображения одного человека будем считать относящимися к одному классу, и модель будет этот класс предсказывать.

Для небольшой организации, в которой всего несколько десятков сотрудников такой подход может сработать. При этом возникнут проблемы:

  1. Чтобы обучить такую ​​систему, нам сначала потребуется много (сотни) разных изображений каждого сотрудника.

  2. Когда человек присоединяется к организации или покидает ее, приходится менять структуру модели и обучать ее заново.

Это практически невозможно для крупных организаций, где набор и увольнение происходит почти каждую неделю. И в принципе невозможно для города масштаба Москвы или Лондона, в котором миллионы жителей и сотни тысяч приезжих.

Формирование векторов-признаков (embedding)¶

Поэтому используется другой подход. Вместо того, чтобы классифицировать изображения, модель учится выделять ключевые признаки и на их основе строить компактный вектор, достаточно точно описывающий лицо.

В англоязычной литературе такие вектора признаков называются embedding, и мы тоже будем использовать это обозначение.

Может возникнуть вопрос: не потеряем ли мы важную информацию, сжав изображение в несколько сотен чисел?

Чтобы ответить на него, вспомним, как работает фоторобот.

Для получения фотореалистичного изображения лица достаточно нескольких ключевых признаков: глаза, волосы, рот, нос... Каждый из них кодируется максимум несколькими сотнями целых значений.

Значит, вектора-признака из 128 вещественных чисел будет более чем достаточно. Правда интерпретировать значения, которые закодирует в него нейросеть, будет не столь просто.

Если нам удастся обучить модель кодировать в embedding признаки, важные для сравнения, то мы сможем сравнивать векторы между собой.

Если расстояние между векторами для лиц, которые похожи друг на друга, будут маленькими, а у непохожих, наоборот, большими, то мы сможем экспериментально подобрать порог $d$ и, сравнивая с ним расстояние между двумя векторами, принимать решение: принадлежат ли они одному человеку или нет.

Можно оценивать не расстояние, а степень схожести similarity. В этом случае неравенства поменяют знак, но логика останется прежней

Теперь, чтобы идентифицировать человека, требуется только одно изображение его лица. Эмбеддинг этого изображения можно сравнить с эмбеддингами других лиц из БД, используя k-NN или иной метод кластеризации.

Такая модель не учится классифицировать изображение напрямую по какому-либо из выходных классов. Она учится выделять признаки, важные при сравнении.

Такой подход решает обе проблемы, о которых мы говорили выше:

  • Для обучения такой сети нам не требуется много экземпляров объектов одного класса, а достаточно лишь нескольких.
  • Но самое большое преимущество в простоте ее обучения в случае появления новых объектов.

Сиамская сеть (Siamese Network)¶

Какая архитектура должна быть у модели, генерирующей векторы признаков?

Можно было бы использовать обычную сеть, обученную для задачи классификации, и затем удалить из нее один или несколько последних слоёв.

Активации последнего слоя представляют собой отклики на некие высокоуровневые признаки, потенциально важные для классификации, и их можно интерпретировать как embedding.

In [94]:
from torchvision.models import alexnet
import torch

face1 = torch.randn((3, 224, 224))
face2 = torch.randn((3, 224, 224))

model = alexnet(weights="AlexNet_Weights.DEFAULT")
# remove classification layer
model.fc = model.classifier[6] = torch.nn.Identity()

# get embeddings
embedding1 = model(face1.unsqueeze(0))
embedding2 = model(face2.unsqueeze(0))

diff = torch.nn.functional.pairwise_distance(embedding1, embedding2)
print("L2 distance: ", diff.item())
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 186MB/s]
L2 distance:  30.66541862487793

Такой подход будет работать. Однако можно заметно улучшить точность, используя функцию потерь, которая оценивает именно качество сравнения, а не классификации.

Рассмотрим подход, основанный на методологии, описанной в статье Siamese Neural Networks for One-shot Image Recognition (Koch et al., 2015).

Используются две копии одной и той же сети, отсюда и название Siamese Networks.

Два входных изображения ($x_1$ и $x_2$) проходят через одну и ту же сверточную сеть, на выходе для каждого изображения генерируется вектор признаков фиксированной длины $h_1$ и $h_2$.

Модель обучается генерировать близкие вектора для изображений одного объекта и далекие для разных.

Оценивая расстояние между двумя векторами признаков, которое будет малым для одних и тех же объектов и большим для различных, мы сможем оценить их сходство.

Это центральная идея сиамских сетей.

Triplet Loss¶

Какую функцию потерь использовать для обучения такой сети?

Очевидно, loss function должна будет учитывать не один выход, а как минимум два.

Популярной на сегодняшний день является Triplet loss, которой требуется три embedding вместо двух.

Чтобы сгенерировать три эмбеддинга, модель должна получать на вход три изображения.

Первые два должны относиться к одному и тому же объекту (человеку), а третье — к другому.

Таким образом, триплет состоит из опорного ("якорного" anchor), положительного (positive) и отрицательного (negative) образцов.

Описание в статье FaceNet: A Unified Embedding for Face Recognition and Clustering

Сама функция потерь будет выглядеть следующим образом:

$$TripletLoss = \sum_{1}^{N} L_i(x_i^{a},x_i^{p},x_i^{n})$$$$L_i(x_i^{a},x_i^{p},x_i^{n})=max(0,\left\| f(x_i^{a}) -f(x_i^{p}) \right\|_2^{2} - \left\| f(x_i^{a}) -f(x_i^{n}) \right\|_2^{2} + margin)$$

Где:

$x_i^{a}$ — базовое изображение (anchor),

$x_i^{p}$ — изображение того же объекта (positive),

$x_i^{n}$ — изображение другого объекта (negative),

$f(x)$ — нормированный выход модели (embedding) для входа $x$,

$\left\| x \right\|_2$ — это L2 (Euclidean norm), соответственно $\left\| a \right\|_2^{2}$ — это L2 в квадрате,

$margin$ — это константа или минимальный "зазор", на который расстояние до эмбеддинга негативного объекта обязано превосходить расстояние до позитивного (идея такая же, что в SVM Loss) .

$$\large L = max(d(a,p)-d(a,n)+margin,0)$$

В ходе обучения с Triplet Loss расстояние между эмбеддингами опорного и позитивного объектов уменьшается, а между эмбеддингами опорного и отрицательного — увеличивается.

Важным дополнением является то, что embedding-и нормируются. В результате нормировки каждый вектор-признак будет иметь единичную длину.

Теперь мы можем рассматривать embedding-и как точки на $n$-мерной сфере с радиусом $1$.

Это удобно, так как все расстояния между embedding будут лежать в интервале $[0 \dots 2]$, и нам будет проще подобрать порог для сравнения.

Кроме того, можно использовать другие меры расстояния, например, косинусное расстояние, которое определяется углом между векторами, лежит в интервале $[-1 \dots 1]$ и соответствует расстоянию между точками на поверхности сферы.

В статье авторы минимизируют Евклидово расстояние, но подход будет работать и для других метрик сходства, например, косинусного расстояния.

В PyTorch есть две реализации TripletLoss

TripletMarginLoss — минимизирует $L_p$ норму

In [95]:
from torch import nn

triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
loss = triplet_loss(anchor, positive, negative)
print(loss)
tensor(1.0878, grad_fn=<MeanBackward0>)

TripletMarginWithDistanceLoss — позволяет задать произвольную функцию расстояния.

In [96]:
import torch.nn.functional as F

triplet_loss = nn.TripletMarginWithDistanceLoss(
    margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
)
loss = triplet_loss(anchor, positive, negative)
print(loss)
tensor(0.9909, grad_fn=<MeanBackward0>)

Другие функции потерь для сиамских сетей:

Исторически первой появилась Contrastive Loss, о ней подробнее в статье Dimensionality Reduction by Learning an Invariant Mapping (Hadsell et al., 2005)

В PyTorch есть реализация CosineEmbeddingLoss, она позволяет обучать модель на парах изображений, минимизировав косинусное расстояние между embedding.

Реализация сиамской сети¶

Загрузка данных¶

Загрузим небольшой фрагмент датасета с лицами. Внутри архива фото лиц сгруппированы по папкам

faces/
├── training/
|   ├── s1/
|   |   ├── 1.pgm
|   |   ├ ...
|   |   └── 9.pgm
|   ├ ... (excluding 5...7)
|   └── s40/
|       ├── 1.pgm
|       ├ ...
|       └── 9.pgm
└── testing/
    ├── s5/
    |   ├── 1.pgm
    |   ├ ...
    |   └── 9.pgm
    ├ ...
    └── s7/
        ├── 1.pgm
        ├ ...
        └── 9.pgm

В каждой папке фото лица одного и того же человека.

In [97]:
!wget -qN https://edunet.kea.su/repo/EduNet-web_dependencies/datasets/small_face_dataset.zip
!unzip -qn small_face_dataset.zip

Чтобы результаты воспроизводились, зафиксируем SEED

In [98]:
import numpy as np
import random


def set_random_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


set_random_seed(42)

Dataset for TripletLoss¶

Для TripletLoss потребуются три изображения: anchor, positive, negative, и метод get_item должен возвращать их нам. Первые два должны принадлежать одному человеку, а третье — другому.

In [99]:
from torch.utils.data import Dataset
from glob import glob
from PIL import Image


class SiameseNetworkDataset(Dataset):
    def __init__(self, dir=None, transform=None, splitter="/"):
        self.dir = dir
        self.splitter = splitter
        self.transform = transform
        self.files = glob(f"{self.dir}/**/*.pgm", recursive=True)
        self.data = self.build_index()

    def build_index(self):
        index = {}
        for f in self.files:
            id = self.path2id(f)
            if not id in index:
                index[id] = []
            index[id].append(f)
        return index

    def path2id(self, path):
        return path.replace(self.dir, "").split(self.splitter)[0]

    def __getitem__(self, index):
        anchor_path = self.files[index]
        positive_path = self.find_positive(anchor_path)
        negative_path = self.find_negative(anchor_path)

        # Loading the images
        anchor = Image.open(anchor_path)
        positive = Image.open(positive_path)
        negative = Image.open(negative_path)

        if self.transform is not None:  # Apply image transformations
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative

    def find_positive(self, path):
        id = self.path2id(path)
        all_exept_my = self.data[id].copy()
        all_exept_my.remove(path)
        return random.choice(all_exept_my)

    def find_negative(self, path):
        all_exept_my_ids = list(self.data.keys())
        id = self.path2id(path)
        all_exept_my_ids.remove(id)
        selected_id = random.choice(all_exept_my_ids)
        return random.choice(self.data[selected_id])

    def __len__(self):
        return len(self.files)

Выведем несколько изображений, чтобы убедиться, что класс датасета функционирует должным образом.

In [100]:
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

# Create dataset instance
siamese_dataset = SiameseNetworkDataset(
    "faces/training/",
    transform=transforms.Compose(
        [
            transforms.Resize((105, 105)),
            transforms.ToTensor(),
        ]
    ),
)

# Create dataloader & extract batch of data from it
vis_dataloader = DataLoader(siamese_dataset, batch_size=8, shuffle=True)
dataiter = iter(vis_dataloader)
example_batch = next(dataiter)  # anc, pos, neg

# Show batch contents
concatenated = torch.cat((example_batch[0], example_batch[1], example_batch[2]), 0)
grid = torchvision.utils.make_grid(concatenated)

plt.axis("off")
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.gcf().set_size_inches(20, 60)
plt.show()

В каждом столбце тройка изображений. Первое и второе принадлежат одному человеку, третье — другому.

Создание модели¶

Нас устроит любая модель для работы с изображениями. Например, ResNet18.

Все, что от нас требуется, это:

  • заменить последний слой
  • отправлять на анализ три изображения вместо одного. Соответственно на выходе тоже будут три вектора признаков (embedding)

Пожалуй, единственный вопрос — это размерность последнего слоя. В промышленных системах распознавания лиц, которые тренируются на датасетах из миллионов изображений, используются embedding размерностью от 128 до 512.

Для демонстрационной задачи нам должно хватить 32 значений. Количество выходов последнего линейного слоя установим равным 32.

In [101]:
from torchvision.models import resnet18


class SiameseNet(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.model = resnet18(weights=None)
        # Because we use grayscale images reduce input channel count to one
        self.model.conv1 = nn.Conv2d(
            1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
        )
        # Replace ImageNet 1000 class classifier with 64- out linear layer
        self.model.fc = nn.Linear(self.model.fc.in_features, latent_dim)

    def _forward(self, x):
        out = self.model(x)
        # normalize embedding to unit vector
        out = torch.nn.functional.normalize(out)
        return out

    def forward(self, anchor, positive, negative):
        output1 = self._forward(anchor)
        output2 = self._forward(positive)
        output3 = self._forward(negative)

        return output1, output2, output3

Dataloaders¶

Загрузчики данных не отличаются от загрузчиков для обычной сети. Единственное отличие — это добавление аугментации в виде случайного отражения по вертикали к обучающим данным.

In [102]:
# Apply augmentations on train data
img_trans_train = transforms.Compose(
    [
        transforms.Resize((105, 105)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

img_trans_test = transforms.Compose(
    [transforms.Resize((105, 105)), transforms.ToTensor()]
)

train_dataset = SiameseNetworkDataset("faces/training/", transform=img_trans_train)
val_dataset = SiameseNetworkDataset("faces/testing/", transform=img_trans_test)

batch_size = 300
train_loader = DataLoader(
    train_dataset, num_workers=2, batch_size=batch_size, shuffle=True
)
val_loader = DataLoader(val_dataset, num_workers=2, batch_size=1, shuffle=False)

Обучение¶

Отличие от сетей для классификации в том, что у модели 3 выхода, и все их надо передать в loss. При этом нет меток в явном виде. Определить, какой embedding относится к позитивному образцу, а какой — к негативному, можно только порядком их следования.

In [103]:
def train(num_epochs, model, criterion, optimizer, train_loader):
    loss_history = []
    model.train()
    for epoch in range(0, num_epochs):
        train_loss = 0
        for i, batch in enumerate(train_loader, 0):
            anc, pos, neg = batch
            output_anc, output_pos, output_neg = model(
                anc.to(device), pos.to(device), neg.to(device)
            )
            loss = criterion(output_anc, output_pos, output_neg)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.detach().cpu().item()

        loss_history.append(train_loss / len(train_loader))
        last_epoch_loss = torch.tensor(loss_history[-1])
        print("Epoch {} with {:.4f} loss".format(epoch, last_epoch_loss))

    return loss_history, last_epoch_loss

В качестве функции расстояния в TripletLoss возьмем косинусное расстояние (величина, обратная к косинусной близости).

In [104]:
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
latent_dim = 32
model = SiameseNet(latent_dim).to(device)
criterion = nn.TripletMarginWithDistanceLoss(
    margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
)

optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 9
loss_history, _ = train(num_epochs, model, criterion, optimizer, train_loader)
Epoch 0 with 0.8090 loss
Epoch 1 with 0.4584 loss
Epoch 2 with 0.4083 loss
Epoch 3 with 0.3571 loss
Epoch 4 with 0.3335 loss
Epoch 5 with 0.3602 loss
Epoch 6 with 0.3168 loss
Epoch 7 with 0.2326 loss
Epoch 8 with 0.2895 loss

Выведем график loss

In [105]:
plt.plot(range(1, len(loss_history) + 1), loss_history)
plt.ylabel("loss")
plt.xlabel("num of epochs")
plt.grid()
plt.show()

Проверка¶

Для начала выведем тройки изображений из проверочного датасета и посмотрим на косинусную близость (схожесть) для позитивных и негативных пар. Если модель обучилась, схожесть для позитивных пар будет больше, чем для негативных.

In [106]:
# Helper method for visualization
def show(img, text=None):
    img_np = img.numpy()
    plt.axis("off")
    plt.text(75, 120, text, fontweight="bold")
    plt.imshow(np.transpose(img_np, (1, 2, 0)))  # [CxHxW] -> [HxWxC] for imshow
    plt.show()


def plot_imgs(model, test_loader):
    similarity_pos = []
    similarity_neg = []
    model.eval()
    with torch.inference_mode():
        for i, batch in enumerate(test_loader, 0):
            anc, pos, neg = batch
            output_anc, output_pos, output_neg = model(
                anc.to(device), pos.to(device), neg.to(device)
            )
            # compute euc. distance
            sim_pos = F.cosine_similarity(output_anc, output_pos).item()
            sim_neg = F.cosine_similarity(output_anc, output_neg).item()

            similarity_pos.append(sim_pos)
            similarity_neg.append(sim_neg)

            if not i % 5:
                concatenated = torch.cat((anc, pos, neg))
                result = "OK" if sim_neg < sim_pos else "BAD"
                show(
                    torchvision.utils.make_grid(concatenated),
                    f"Positive / negative similarities: {sim_pos:.3f} / {sim_neg:.3f} - {result}",
                )

    return similarity_pos, similarity_neg


set_random_seed(42)
similarity_pos, similarity_neg = plot_imgs(model, val_loader)

Но такая оценка субъективна, давайте посмотрим на распределение схожестей по категориям:

In [107]:
import seaborn as sns

similarities = {"The same person": similarity_pos, "Another person": similarity_neg}

ax = sns.histplot(similarities, bins=20)
ax.set(xlabel="Pairwise similarity")
plt.show()

Видно, что схожесть между двумя фото одного и того же человека в среднем больше, чем схожесть между фотографиями разных людей.

Если бы мы проектировали систему распознавания лиц, нужно было бы выбрать порог, чтобы сравнивать с ним схожесть и принимать решение о том, верифицировать фото как подлинное или нет.

Соответственно, для нашего игрушечного датасета такой порог следует выбирать $≈0.75$. При этом мы будем иметь некоторое количество ошибок и ложно распознавать постороннего человека.

Оптимизация гиперпараметров¶

Часто, когда мы пишем и обучаем сети (будь то с нуля или с помощью transfer learning), мы вынуждены угадывать гиперпараметры (lr, betas и т.д). В случае с learning rate нам есть от чего оттолкнуться (маленький lr для transfer learning), но все же такой подход не кажется оптимальным.

Для оптимизации гиперпараметров существуют готовые решения, которые используют различные методы black-box оптимизации. Разберем одну из наиболее популярных библиотек — Optuna.

In [108]:
!pip install --quiet optuna
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 390.6/390.6 kB 13.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 224.5/224.5 kB 22.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.7/78.7 kB 11.5 MB/s eta 0:00:00

Давайте оптимизируем learning rate и latent_dim. Для того, чтобы код не выполнялся очень долго, укажем небольшой диапазон параметров и небольшое число эпох.

In [109]:
import optuna
from optuna.samplers import RandomSampler


# define function which will optimized
def objective(trial):
    # boundaries for the optimizer's
    lr = trial.suggest_float("lr", 1e-4, 1e-2)
    latent_dim = trial.suggest_int("latent_dim", 8, 64, step=8)

    # create new model(and all parameters) every iteration
    model = SiameseNet(latent_dim).to(device)  # latent_dim regulates by optuna
    criterion = nn.TripletMarginWithDistanceLoss(
        margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
    )

    optimizer = optim.Adam(
        model.parameters(), lr=lr
    )  # learning rate regulates by optuna

    # To save time, we will take only 3 epochs
    train_loader = DataLoader(
        train_dataset, num_workers=2, batch_size=batch_size, shuffle=True
    )
    _, last_epoch_loss = train(3, model, criterion, optimizer, train_loader)
    return last_epoch_loss


# Create "exploration"
study = optuna.create_study(
    direction="minimize", study_name="Optimizer", sampler=RandomSampler(42)
)

study.optimize(
    objective, n_trials=10
)  # The more iterations, the higher the chances of catching the most optimal hyperparameters
[I 2023-07-18 20:44:47,744] A new study created in memory with name: Optimizer
Epoch 0 with 0.8298 loss
Epoch 1 with 0.5833 loss
[I 2023-07-18 20:44:53,687] Trial 0 finished with value: 0.5804647207260132 and parameters: {'lr': 0.003807947176588889, 'latent_dim': 64}. Best is trial 0 with value: 0.5804647207260132.
Epoch 2 with 0.5805 loss
Epoch 0 with 0.8369 loss
Epoch 1 with 0.6785 loss
[I 2023-07-18 20:45:00,230] Trial 1 finished with value: 0.7354487180709839 and parameters: {'lr': 0.007346740023932911, 'latent_dim': 40}. Best is trial 0 with value: 0.5804647207260132.
Epoch 2 with 0.7354 loss
Epoch 0 with 0.8462 loss
Epoch 1 with 0.4905 loss
[I 2023-07-18 20:45:06,189] Trial 2 finished with value: 0.36206525564193726 and parameters: {'lr': 0.0016445845403801217, 'latent_dim': 16}. Best is trial 2 with value: 0.36206525564193726.
Epoch 2 with 0.3621 loss
Epoch 0 with 0.7979 loss
Epoch 1 with 0.5225 loss
[I 2023-07-18 20:45:12,623] Trial 3 finished with value: 0.3257625102996826 and parameters: {'lr': 0.0006750277604651748, 'latent_dim': 56}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.3258 loss
Epoch 0 with 0.8698 loss
Epoch 1 with 0.6307 loss
[I 2023-07-18 20:45:18,821] Trial 4 finished with value: 0.5650991201400757 and parameters: {'lr': 0.006051038616257768, 'latent_dim': 48}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.5651 loss
Epoch 0 with 0.9011 loss
Epoch 1 with 0.5483 loss
[I 2023-07-18 20:45:24,987] Trial 5 finished with value: 0.48651862144470215 and parameters: {'lr': 0.00030378649352844425, 'latent_dim': 64}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.4865 loss
Epoch 0 with 0.8338 loss
Epoch 1 with 0.7046 loss
[I 2023-07-18 20:45:31,435] Trial 6 finished with value: 0.6294255256652832 and parameters: {'lr': 0.008341182143924175, 'latent_dim': 16}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.6294 loss
Epoch 0 with 0.7068 loss
Epoch 1 with 0.5217 loss
[I 2023-07-18 20:45:37,403] Trial 7 finished with value: 0.44763171672821045 and parameters: {'lr': 0.0019000671753502962, 'latent_dim': 16}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.4476 loss
Epoch 0 with 0.8477 loss
Epoch 1 with 0.5504 loss
[I 2023-07-18 20:45:44,057] Trial 8 finished with value: 0.39682716131210327 and parameters: {'lr': 0.0031119982052994237, 'latent_dim': 40}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.3968 loss
Epoch 0 with 0.8262 loss
Epoch 1 with 0.7000 loss
[I 2023-07-18 20:45:49,978] Trial 9 finished with value: 0.6107327938079834 and parameters: {'lr': 0.004376255684556947, 'latent_dim': 24}. Best is trial 3 with value: 0.3257625102996826.
Epoch 2 with 0.6107 loss

Как видите, такой упрощенный подбор даже двух гиперпараметров занимает много времени.

In [110]:
# show best params
study.best_params
Out[110]:
{'lr': 0.0006750277604651748, 'latent_dim': 56}

Давайте посмотрим на историю оптимизации наших параметров:

In [111]:
optuna.visualization.plot_optimization_history(study)

Что ж, проверим, станет ли реально лучше. Обучим сеть с лучшими latent_dim и lr, определенными с помощью Optuna.

In [112]:
set_random_seed(42)

model = SiameseNet(study.best_params["latent_dim"]).to(
    device
)  # take latent_dim, which choosen by Optuna
criterion = nn.TripletMarginWithDistanceLoss(
    margin=1.0, distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
)
optimizer = optim.Adam(
    model.parameters(), lr=study.best_params["lr"]
)  # take lr, which choosen by Optuna

num_epochs = 9
train_loader = DataLoader(
    train_dataset, num_workers=2, batch_size=batch_size, shuffle=True
)
l_optim, _ = train(num_epochs, model, criterion, optimizer, train_loader)
Epoch 0 with 0.7609 loss
Epoch 1 with 0.5261 loss
Epoch 2 with 0.4040 loss
Epoch 3 with 0.3303 loss
Epoch 4 with 0.3403 loss
Epoch 5 with 0.3373 loss
Epoch 6 with 0.2967 loss
Epoch 7 with 0.2273 loss
Epoch 8 with 0.2934 loss
In [113]:
plt.plot(range(1, len(loss_history) + 1), loss_history, label="no optimization")
plt.plot(range(1, len(l_optim) + 1), l_optim, label="optimal params")
plt.ylabel("loss")
plt.xlabel("num of epochs")
plt.grid()
plt.legend()
plt.show()
In [114]:
set_random_seed(42)
similarity_pos, similarity_neg = plot_imgs(model, val_loader)
In [115]:
similarities_optim = {
    "The same person": similarity_pos,
    "Another person": similarity_neg,
}

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

sns.histplot(similarities, bins=20, alpha=0.5, ax=axes[0])
sns.histplot(similarities_optim, bins=20, alpha=0.5, ax=axes[1])

axes[0].set(title="No optimization")
axes[1].set(title="Optimization with Optuna")
axes[0].set(xlabel="Pairwise similarity")
axes[1].set(xlabel="Pairwise similarity")

plt.show()

К сожалению, нельзя сказать, что после оптимизации гиперпараметров распределение схожестей стало лучше. При проведении порога по схожести количество ошибок первого и второго рода будет больше, чем до оптимизации.

Это может объясняться тем, что в этой задаче 10 итераций для подбора гиперпараметров все-таки маловато. Зато мы познакомились с API Optuna.

Заключение

Мы затронули проблемы, которые возникают при обучении на реальных данных.

Одна из основных проблем — малые датасеты. Для того, чтобы обучить нейронную сеть на небольшом датасете, можно использовать:

  • аугментации;
  • Transfer learning.

Однако необходимо помнить, что ни один из этих методов не защитит от ситуации, когда реальные данные будут сильно отличаться от тренировочных.

В случае когда у нас не только мало данных, но еще и очень большое (возможно, неизвестное) число классов, можно воспользоваться Metric Learning. В этом случае нейронная сеть обучается не классифицировать изображения, а, наоборот, находить различия между классом и новыми данными. Для этого используются нейронные сети, относящиеся к классу сиамских нейронных сетей.

Также мы разобрали автоматическую оптимизацию гиперпараметров с помощью API Optuna и показали, как пользоваться этим инструментом для оптимизации гиперпараметров модели и процедуры обучения.

Литература

Обучение на реальных данных

How to avoid machine learning pitfalls: a guide for academic researchers (Lones, 2021)

Understanding data augmentation for classification: when to warp? (Wong et al., 2016)

Learning from class-imbalanced data: Review of methods and applications (Haixiang et al., 2017)

Как решить проблему маленького количества данных?

Dealing with very small datasets

Несбалансированные данные

Kaggle Tutorual: Tackling Class imbalance

SMOTE explained for noobs - Synthetic Minority Over-sampling TEchnique line by line

Блог пост про 8 тактик борьбы с несбалансированными классами в наборе данных машинного обучения

Метрики, разработаные для работы с несбалансированными классами.

Transfer Learning

Image Classification using Transfer Learning in PyTorch

How To Do Transfer Learning For Computer Vision | PyTorch Tutorial

Transfer learning for Computer Vision Tutorial

Python Pytorch Tutorials # 2 Transfer Learning : Inference with ImageNet Models

PyTorch - The Basics of Transfer Learning with TorchVision and AlexNet

Augmentation

A survey on Image Data Augmentation for Deep Learning (Shorten and Khoshgoftaar, 2019)

Data augmentation for improving deep learning in image classification problem

Few-shot learning

FaceNet: A Unified Embedding for Face Recognition and Clustering (Schroff et al., 2015)

One-Shot Learning for Face Recognition

Ищем знакомые лица

Siamese Neural Networks for One-shot Image Recognition (Koch et al., 2015)

Dimensionality Reduction by Learning an Invariant Mapping (Hadsell et al., 2005)

Language Models are Few-Shot Learners (Brown et al., 2020)

Hyperparameter optimization

Optuna tutorial for hyperparameter optimization

Optuna: Get the Best out of your Hyperparameters – Easy Tutorial