Прежде чем перейти к статье, хочу вам представить, экономическую онлайн игру Brave Knights, в которой вы можете играть и зарабатывать. Регистируйтесь, играйте и зарабатывайте!
Введение
Всем привет. Это мой первый пост и первый обзор на работу. В двух словах опишу, чем это я тут занимался.
Цель проекта заключалась в распознавании достопримечательностей на фотографиях при помощи машинного обучения, а именно свёрточных нейросетей. Данная тема была выбрана из следующих соображений:
у автора уже был некий опыт работы с задачами компьютерного зрения
задача звучала так, как будто её можно сделать очень быстро и не прикладывать большое количество усилий, и, что немаловажно, вычислительных ресурсов (все сетки обучались в колабе или на кагле)
задача может иметь какое-то практическое применение (ну, в теории...)
Сначала он планировался как исключительно учебный проект, но потом я проникся его идеей и решил доработать его до такого состояния, до которого могу.
Далее, я буду рассказывать о том, как я подходил к решению этой задачи, и при этом буду стараться идти по коду из ноутбука, в котором и происходила вся магия, при этом стараться пояснять какие-то свои действия. Возможно, это поможет кому-нибудь преодолеть боязнь "чистого листа" и увидеть, что подобного рода вещи делаются действительно просто!
Инструменты
Ну и первым делом расскажу об инструментах, которые использовались при реализации проекта.
Colab/Kaggle: использовались для того, чтобы обучать сети на ГПУ.
Weights And Biases: сервис, в который я сохранял модели, их описания, добавлял лоссы, значения метрик, параметры обучения, препроцессинга. В общем, вел полный учет. С данными можно ознакомиться по ссылке. В процессе написания кода немного изменялся раздел метадаты, который, по сути, содержит параметры обучения и препроцессинга. В разделе с файлами вы сможете ознакомиться с описанием сети (как устроены её слои), скачать обученные веса сети, а также глянуть на значение лоссов и метрик.
Данные для обучения
Ну, наверное, стоило бы начать с выбора данных для обучения нейронной сети. Для этого я покопался в датасетах на кагле (тык), и вот такой сайт мне еще приглянулся.
Собственно, как выяснилось, существует соревнование от Google, связанное как раз таки с распознаванием достопримечательностей. Здесь появилась первая проблема: датасет весит 100 гб. Понимая, что сетки в дальнейшем я буду учить не на своей пекарне, от данного варианта пришлось отказаться. Полистав еще, я остановился на этом датасете. В нем содержится 210 классов и примерно по 50 фотокарточек на каждый из классов. Картинки все разного размера, снятые с разных ракурсов, с разного расстояния. В общем, датасет совсем не рафинированный, а пока что я работал только с такими. Ну, самому данные размечать не надо, за это уже лайк! Приведу вам парочку особо удачных фотографий:
Хранение и обработка данных (ч.1)
В данном разделе я бы хотел рассказать о том, как данные хранились и обрабатывались.
Первым делом, мы проверим картинки на то, сколько каналов они содержат. Привычнее всего все таки работать с изображениями с тремя каналами (RGB). Но помимо этого формата в данном датасете нам встречаются и черно-белые картинки, и фотографии формата RGBA. Но, к счастью, таких картинок крайне мало (19), поэтому удалим их без угрызений совести.
Для хранения данных я написал несколько классов, которые наследуются от torch.utils.data.Dataset
. При реализации таких классов необходимым условием является переопределение методов __getitem__
и __len__
(то есть добавить классу возможность получать элемент по индексу, и возвращать длину экземпляра класса). Ну, этим я и занялся.
Датасет с быстрым доступом (FastDataset)
Первое, что пришло в голову: давайте просто считывать изображения, приводить их к одному размеру, переводить их в тензоры pytorch
, и хранить тензоры. Далее, когда мы хотим перебрать элементы датасета, мы просто достаем из памяти тензоры, не выполняя больше никакой обработки. Супер, теперь мы сохранили все данные, немножко подождали во время инициализации, но зато доступ моментальный(почти). Казалось бы, что может пойти не так... Но ответ, на самом деле, очевиден: хранить обработанные данные - удовольствие недешевое, и за него приходится платить чеканной монетой памятью. Что же делать...
Датасет с медленным доступом (CustomDataset)
Второе, что пришло в голову: а давайте мы просто будем хранить список, содержащий пути до наших картинок. Таким образом, расходы памяти становятся в разы меньше. Но при таком подходе мы жертвуем временем, за которое происходит обход. Ведь при хранении данных в виде списка путей при каждом обращении мы должны считать изображение по его пути, применить операции ресайзинга и приведения к тензорам, и только после этого мы можем работать с полученным объектом. Долго, да, но ничего не поделать.
Обучение сети
В данном разделе мы немного отойдем от ноутбука.
Разбиение данных
Итак, в нашем арсенале уже имеется два вида датасетов. Давайте уже что-нибудь обучим. Для этого нам нужно написать, наверное, цикл обучения сети, в котором также будем рассчитывать метрики на валидационной выборке для грамотного подбора гиперпараметров сети. Но для того, чтоб таковая выборка появилась, надо научиться разбивать данные. Для этого в каждом классе я реализовал метод разбиения исходного датасета на тренировочный и валидационный. Обернул это все простенькими функциями, и на выходе получил удобно скармливаемые сети экземпляра torch.DataLoader
.
Обучение
При обучении использовался оптимизатор Adam из модуля pytorch
, и в качестве функции потерь была выбрана nn.CrossEntropyLoss
.
Сначала я пробовал писать и обучать совсем простые сети, которые состояли из двух частей: сверточной части, в которой использовались свертки(шок) и пуллинги; полносвязной части, в которой использовались линейные слои и чуть-чуть дропауты (на wandb это - нулевая версия CNN). Стало понятно, что нужно усложнять архитектуру. Добавил слои батч-нормализации, и качество очень приятно подскочило. До этого относился к этому с пренебрежением, потому что не очень понимал, как оно работает (да и сейчас тоже). В общем, используя метод проб и ошибок, удалось поднять качество значение метрики F1
на валидационной выборке до 93%. Тогда я подумал, что цель достигнута, и получилось отделаться малой кровью, но не тут то было. И просто для того, чтобы убедиться в том, что все действительно хорошо, решил погуглить про метрику, которую я использовал. Оказалось, что считалось совсем не то, что я ожидал, и когда я все исправил, значение метрики на валидационной выборке вернулось к 31%, а вот на тренировочной выборке было 96%. Вот это уже по нашему! Сетка хорошо так переобучается. Давайте решать проблему.
Хранение и обработка данных (ч.2)
Первая идея, которая меня посетила: скорее всего, сетка просто не может научиться на 45 изображениях, некоторые из которых еще и не самого лучшего качества. Что можно с этим сделать? Ну давайте применим аугментацию. Не знаю, какой контингент читает эту рукопись, так что дам краткое пояснение. Аугментация, по сути - увеличение объема данных, за счёт которого можно обучить сетку. Давайте попробуем искусственно расширить множество уже имеющихся картинок путем применения к ним неких трансформаций.
Идея следующая: давайте к каждому существующему изображению будем применять набор преобразований, например, поворот на 180 градусов, или осуществлять небольшой поворот изображения.
Мы смогли расширить наш датасет аж в 7 раз! Давайте используем это для обучения сетки.
Далее, я реализовал еще два класса, по аналогии с датасетами из ч.1: с быстрым доступом AugmentedFastDataset
и медленным доступом AugmentedCustomDataset
. Проблема возникла моментально: уже на данном этапе, при применении 7 различных видов трансформаций, датасет с быстрым доступом сжирал всю память, и все падало. Соответственно, пришлось использовать его менее быструю, но более экономичную (в плане памяти) версию.
Ну и что мы видим (посмотреть можно CNN.v9): модель все равно очень сильно переобучается. Что же еще можно такого придумать...
И пришла в голову следующая идея: зачем применять по одной трансформации(назову так операцию изменения исходного изображения) за раз? Можно ведь применять их последовательно. Тогда, делая различные комбинации, мы сможем еще больше расширить датасет. Давайте попробуем реализовать эту задумку в классе AdvancedCustomDataset
. Кратко опишу процесс: в конструктор класса мы передаем аргумент ex_amount
, которая отвечает за то, сколько экземпляров для каждого класса мы хотим получить. Далее, проходимся по каждому классу, и до тех пор, пока не получим нужное число изображений, применяем случайный набор трансформаций к случайному изображению. Ниже, можно увидеть пример того, как работает данная задумка.
Также, произошли еще некоторые минорные изменения, связанные с заменой некоторых функций на их аналоги из других модулей. Причина проста: так как датасет сильно расширился, и доступ к элементам медленный, на один проход уходит уйма времени. Поэтому приятно было бы сэкономить на таких мелочах пару-тройку минут.
открытие изображение раньше производилось при помощи библиотеки
PIL
. Как показали сравнения, открытие изображений при помощи библиотекиcv2
работает гораздо быстрее. Поэтому, в отличии от остальных классов, в последнем датасете используется аналог изcv2
операции по изменению изображений были взяты из модуля
torchvision.transforms
. Как выяснилось позже, аналогичные функции из модуляalbumentations
работают быстрее.
Ну что, данных наделали, метрики наладили, пора учиться. Коль теперь я был волен выбирать, сколько картинок для каждого класса я хочу, я выставил значение 2000. И после длительного процесса обучения и валидации получаем модель со значением F1
= 60% на валидационной выборке. Уже хорошо, я считаю.
Немножко про FineTuning
Ну что же, какое-то приемлемое качество мы уже получили, давайте теперь отвлечемся от самописной архитектуры, и попробуем переобучить уже существующую сеть. В качестве такой сети я взял модель VGG13
с батч-нормализацией с предобученными весами. Далее, заморозил всю свёрточную часть, немного поигрался с классификатором, и поставил это все дело учиться. Получилось еще лучше, чем было до этого: метрика на валидационной выборке равна 70% (тык).
Послесловие
Итак, что мы получили на выходе: две сети, которые работают хорошо, и действительно качественно распознают изображения. Допускаю даже то, что та самая тридцати процентная ошибка возникает из-за кривых фотографий в датасете (примеры приводил выше).
Попытался оформить все это дело в мини-проект, который можно скачать с гитхаба.
Прошу писать любые замечания, связанные с написанным, буду рад набраться опыта!