Всем привет.
Характерной тенденцией последних нескольких лет в глубоком обучении является проникновение трансформера в различные сферы деятельности, где только можно и нельзя (но если очень хочется, то можно) применить нейронные сети. Универсальность архитектуры позволяет работать с самыми разнообразными данными, предварительно превращая их в последовательность токенов, будь то текст, картинки, аудио, видео или даже состояние среды.
Но за невероятную мощь и гибкость архитектуры приходится платить значительной вычислительной сложностью и расходом памяти, ибо сие многоголовое чудище ненасытно в отношении памяти, особенно для длинных последовательностей, что ограничивает применимость моделей на практике. Да и даже при наличии серьезных вычислительных ресурсов обучение моделей на серьезных задачах - дело отнюдь не быстрое.
В недалеком прошлом вышла целая плеяда работ посвященных удешевлению дорогой операции внимания посредством построения различных приближений, сводящих квадратичную по длине последовательности вычислительную сложность и расход памяти к субквадратичной за счет приближения матрицами более низкого ранга, хэшированием, разреженного внимания, локального внимания, комбинированного и вагон и маленькая тележка других идей. Многие подходы показали себя довольно неплохо, давая небольшую потерю в качестве относительно исходного vanilla attention, но все-таки внимание в его первозданном виде было и остается наиболее распространенным.
И на днях вышла работа Flash Attention, где был предложен способ существенно ускорить вычисление attention на GPU, причем никак не меняя конечный результат. То есть делается то же самое, что и при стандартном вычислении внимания, но по-другому.
Ключевая идея
Ключевым ингредиентом успеха Flash Attention является наблюдение того, что основное время уходит не сколько на сами вычисления, ибо современные карточки с легионом CUDA ядер и всякими наворотами выдают тучу терафлопс, а на обращения к памяти, подрузку матриц и тензоров.
Память видеокарты имеет следующую иерархию - есть быстрая кэш память, которой немного, всего пара десятков мегабайт, и относительно медленная, но которой может быть много HBM (high bandwidth memory). Казалось бы, что величина пропускной способности порядка 1 Тб/с это довольно много, но в случае обучения или инференса модели трансформера, HBM память не поспевает за вычислениями (память преследовала его, но он оказался быстрее).
Как вычисляется старый добрый attention?
В стандартной реализации матрицы (query), (key), (value) загружаются поблочно из HBM памяти и промежуточные результаты вычислений - матрицы и - загружаются из HBM памяти в кэш и обратно. Традиционный подход, таким образом требует порядка обращений к памяти.
А что делает Flash Attention?
Вычисление маленькими блоками
Входные данные - матрицы - нарезаются на блоки некоторого размера, такого, чтобы все влезало в кэш.
Затем перемножение матриц проводится поблочно (довольно древняя тема). И , оказывается, тоже можно вычислять блок за блоком, причем достаточно дополнительно хранить сумму экспонент (нормализационный фактор) от входных данных для каждого блока и максимальное значение входов блока.
Пересчет промежуточных матриц
В стандартной реализации приходится хранить довольно увесистые матрицы и размера в памяти для того чтобы вычислить градиенты весов при обратном проходе. Но зная нормализационные факторы в и выходные градиенты слоя, можно легко вычислить градиенты по ключам (как показано в приложении к статье). Такой подход несколько увеличивает количество вычислений, но так как основное время все равно уходило на обращения к памяти, имеем выигрыш по времени выполнения.
Слияние ядер
Кроме того, поблочная процедура вычисления позволяет выполнять все операции (перемножения матриц, , dropout) разом в одном СUDA kernel, в отличие от стандартной имплементации, где приходилось бы на каждую из операций вызывать отдельное ядро, что приводило к дополнительным накладным расходам.
Алгоритм целиком
Flash attention вычисляет операцию attention за (вычислительная сложность обычного attention) с использованием дополнительной памяти.
Блочно-разреженный Flash Attention
Дополнительного ускорения можно добиться, если при вычислении внимания отбросить некоторые из нарезанных блоков (то есть не вычислять). То есть в некотором смысле структурированный прунинг на уровне активаций. Подобрав удачным или правильным образом маску можно уменьшить число вычислений, существенно не теряя в качестве.
Эксперименты
Авторы протестировали свое детище на трех известных бенчмарках:
Обучение BERTа на Википедии
Обучение GPT-2 на OpenWebtext
Работа с длинными последовательностями на Long-range Arena
Как можно заметить, FlashAttention бьет даже довольно качественную и оптимизированную реализацию BERT от Nvidia. Еще более заметная разница между временем обучения для GPT-2 для двух стандартных реализаций и у Flash Attention. На Long-range arena Flash Attention выступает успешнее всех конкурентов с более дешевым вниманием, будучи при этом еще и быстрее. Блочно-разреженное внимание дает еще некоторый прирост в скорости, не теряя по сути в качестве.
Далее авторы сравнивают время прямого и обратного прохода по сети и расход памяти по сравнению с разными реализациями стандартного и "облегченного" внимания. При достаточно больших длинах последовательной Flash Attention оказывается быстрее и еще экономнее по памяти.
И в качестве вишенки на торте авторы впервые в истории смогли с хоть каким-то успехом решить задачи Path-X (картинка 128x128), Path-256 (256x256) с качеством выше выдаваемого алгоритмом гадания на кофейной гуще. Собственно, задача заключается в следующем - найти, существует ли путь между двумя белыми точками на черно-белой картинке или нет. Казалось бы, что тут сложного, справится и маленький ребенок. Ребенок-то справится, а вот последовательность длины порядка 10000 не влезет даже в очень емкую карточку в традиционном подходе, да и из представления картинки в виде одномерной последовательности черно-белых пикселей извлечь глобальный контекст не так-то просто. Все предыдущие подходы либо падали по памяти, либо выдавали качество уровня случайного классификатора (т.е 50%). Тут же удалось выбить ~60%.
Все эксперименты из статьи были запущены на одной машине с Нвидиевской A100.
Будущие направления
В последующей работе авторы предполагают обобщить метод на случай обучения на нескольких GPU. Кроме учета времени передачи данных от кэша к HBM появляется еще дополнительно время передачи данных между разными карточками и машинами, которое еще на порядок (порядки) продолжительнее. Кроме того, и поточечную нелинейность в трансформере (применение FFN (feed forward net) независимо к каждому токену) можно оптимизировать перейдя к блочным умножениям.
Итог
Кажется, что данный результат позволит расширить и оптимизировать применение трансформеров в различных областях и сократить время на обучение моделей, что влечет за собой экономию денег, элеткроэнергии и выбросов CO2 (куда без зеленой повестки). Интересно, насколько быстро данная идея будет воплощена в стандартных фрейворках глубокого обучения. Полученные результаты выглядят очень сильно, но в зависимости от архитектуры и задачи выигрыш будет более или менее заметен.
Ссылки и литература
Сама статья
Гитхаб-репозиторий
Блог на твиттере от создателей