LLM как оптимизатор для задачи линейной регрессии

Моя цель - предложение широкого ассортимента товаров и услуг на постоянно высоком качестве обслуживания по самым выгодным ценам.

В сентябре 2023 года инженеры из гугла выпустили статью об использовании LLM для различных задач оптимизации. Там нет кода или ссылки на репозиторий, чтобы можно было самому поиграть, поэтому я написал простой оптимизатор с помощью языковой модели (Mistral-7B-Instruct) для задачи линейной регрессии.

Коротко о линейной регрессии

Линейная регрессия — это модель зависимости одной переменной от другой (или нескольких) с линейной функцией зависимости. Она позволяет предсказывать значение одной переменной на основании другой или нескольких.

Решить задачу линейной регрессии с одной переменной - значит нарисовать линию, которая будет максимально точно соответствовать существующим наблюдениям. Линия - это уравнение, подставив в которое значение X, мы получим предсказанное значение Y:

Линейная регрессия
Линейная регрессия

Чтобы оценить, насколько хорошо наша линия подходит под имеющиеся наблюдения, используют различные методы. Самый известный - метод наименьших квадратов (МНК). С его помощью мы определяем насколько далеко реальные наблюдения отдалены от нашей линии. Задача - минимизировать эти расстояния.

Функцию, которая рассчитывает расстояния, называют функцией потерь (loss function или cost function). И мы хотим её минимизировать.

Задача линейной регрессии имеет аналитическое решение. Когда с помощью манипуляций с производными мы получаем явную формулу и находим точное решение (правильную линию). Но если переменных и наблюдений слишком много, то аналитическое решение может быть вычислительно-затратным или даже невозможным.

Тогда на помощь приходят итерационные методы. Самый известный - градиентный спуск.

Изменение функции потерь (cost) в зависимости от значения  переменной (w) 
Изменение функции потерь (cost) в зависимости от значения переменной (w) 

Во время градиентного спуска мы как бы проверяем: если я немного увеличу значение переменной w, то будет ли моя линия лучше подходить под имеющиеся наблюдения? Если да, то я немного увеличиваю w, если нет - уменьшаю. И так двигаюсь до тех пор, пока не окажусь в оптимальном минимуме.

Думаю, что-то похожее будет делать наша языковая модель, когда будет подбирать идеальные коэффициенты на основании только текстовых инструкций.

Больше про линейную регрессию - тут.
Про градиентный спуск - тут.
Функцию потерь - тут.

Оптимизируем с помощью LLM

Пайплайн:

  1. Создадим набор данных со значениями y, x;

  2. Случайно инициируем веса (w, b) для нашей линии y_pred = w*x + b;

  3. Передадим модели инструкцию, в которой скажем, какое значение принимает наша функция потерь при заданных w, b. И попросим её изменить w, b таким образом, чтобы уменьшить функцию потерь. (Модель не будет знать, какую функцию мы оптимизируем. Мы будем подавать ей только значения: w, b, loss);

  4. Возьмём предложенные моделью w, b, посчитаем для них loss и снова подадим модели. (Сначала на входе у модели будет всего один пример - случайно инициированные веса, а затем к нему буду добавляться примеры, которые она сама придумала, но не больше 10 штук);

  5. Дождёмся, когда 3 последних значения loss функции станут меньше 1 и примем это за оптимальное решение.

Загружаем модель Mistral-7B-Instruct-v0.1 с Hugging Face:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",
    device_map=device,
    torch_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

Создадим набор данных:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

x = np.arange(0, 6, 0.5) # создаём истинные значения для x
y = 3*x + np.random.randint(-1, 2, 12) # создаём истинные значения для y + шум

# инициируем случайные веса для нашей линии y_pred = w*x + b
# во время оптимизации мы будем менять веса w, b, рассчитывать y_pred
# и сравнивать их с истинными значениями "y", определёнными выше
w = np.random.uniform(-5, 5) 
b = np.random.uniform(-5, 5)

Построим график:

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.plot(x, w*x + b, c='red')
ax.set_xlabel('x')
ax.set_ylabel('y');
Синие точки - истинные значения y при заданном x.Красная линия - наш ответ со случайно инициированными весами (w, b). Выглядит не очень.
Синие точки - истинные значения y при заданном x.
Красная линия - наш ответ со случайно инициированными весами (w, b).
Выглядит не очень.

Напишем несколько функций для парсинга ответов LLM, расчёта loss:

def is_number_isdigit(s): # парсинг str ответа от LLM
    n1 = s[0].replace('.','',1).replace('-','',1).strip().isdigit()
    n2 = s[1].replace('.','',1).replace('-','',1).strip().isdigit()
    return n1 * n2

  
# останавливаем оптимизацию, когда последние "last_nums" значений loss < 1
def check_last_solutions(loss_list, last_nums):
    if len(loss_list) >= last_nums:
        last = loss_list[-last_nums:]
        return all(num < 1 for num in last)

      
def loss_calc(y, w, x, b):
    return ((y - w*x + b)**2).mean() # функция потерь МНК

  
loss = loss_calc(y, w, x, b) # рассчитаем первый loss для случайных (w, b)

d = {'loss': [loss], 'w': [w], 'b': [b]}
loss_list = [loss] # соберём все loss для построения графика в конце

df = pd.DataFrame(data=d) # датасет c предложеными моделью весами (w, b) и loss
df.sort_values(by=['loss'], ascending=False, inplace=True)

Посмотрим loss со случайно инициированными w, b:

df
Output:

loss	         w	        b
404.096928	-2.683655	1.586905

Создаём промт:

# num_sol - максимальное кол-во наблюдений в промте
def create_prompt_bias(num_sol): 
    meta_prompt_start = f'''Now you will help me minimize a function with two input variables w, b. I have some (w, b) pairs and the function values at those points.
The pairs are arranged in descending order based on their function values, where lower values are better.\n\n'''

    solutions = ''
    if num_sol > len(df.loss):
        num_sol = len(df.loss)

    for i in range(num_sol):
        solutions += f'''input:\nw={df.w.iloc[-num_sol + i]:.3f}, b={df.b.iloc[-num_sol + i]:.3f}\nvalue:\n{df.loss.iloc[-num_sol + i]:.3f}\n\n''' 
    
    meta_prompt_end = f'''Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than
any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values.

w, b ='''

    return meta_prompt_start + solutions + meta_prompt_end
# Вот так будет выглядеть промт для двух решений. 
# Значения сотрируются по loss(value) по убыванию.

Now you will help me minimize a function with two input variables w, b. I have some (w, b) pairs and the function values at those points.
The pairs are arranged in descending order based on their function values, where lower values are better.

input:
w=-0.456, b=0.357
value:
135.314

input:
w=0.700, b=0.450
value:
63.494

Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than
any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values.

w, b =

Запускаем цикл оптимизации:

num_solutions = 10 # кол-во наблюдений, которое будем подавать в промт

for i in range(500):
    
    text = create_prompt(num_solutions)

    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    model.to(device)

    generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=15,
            temperature=0.8,
            do_sample=True,
            pad_token_id=50256
            )

    output = tokenizer.batch_decode(generated_ids)[0]

    response = output.split("w, b =")[1].strip()
    
    if "\n" in response:
        response = response.split("\n")[0].strip()

    if "," in response:
        numbers = response.split(',')
    
    if is_number_isdigit(numbers):
        w, b = float(numbers[0].strip()), float(numbers[1].strip())
        loss = loss_calc(y, w, x, b)
        loss_list.append(loss)
        new_row = {'loss': loss, 'w': w, 'b': b}
        new_row_df = pd.DataFrame(new_row, index=[0])
        df = pd.concat([df, new_row_df], ignore_index=True)
        df.sort_values(by='loss', ascending=False, inplace=True)

    if i % 20 == 0: # принтуем каждый 20-ый шаг 
        print(f'{w=} {b=} loss={loss:.3f}')

    if check_last_solutions(loss_list, 3):
        break
Output:

w=-100.0 b=1.0 loss=112593.792
w=-1.5 b=0.9 loss=245.704
w=2.2 b=1.1 loss=15.197
w=-2.0 b=-1.0 loss=246.792
w=3.5 b=1.2 loss=0.809

Посмотрим последние 10 значений loss:

print(*loss_list[-10:], sep='\n')
44.41708333333333
28.161666666666665
26.42833333333333
21.763333333333335
46.583333333333336
20.939537499999997
20.939537499999997
0.80875
0.80875
0.6437500000000002

А вот так теперь выглядит наша прямая:

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.plot(x, w*x + b, c='red');
Похоже, модель справилась
Похоже, модель справилась

Посмотрим на снижение loos во время оптимизации (ограничил значения 700 единицами, потому что в процессе тренировки было несколько выбросов со значениями больше миллиона).

fig, ax = plt.subplots()
print(f'number of step = {len(loss_list)}')
ax.plot([x for x in loss_list if x < 700]);
Для оптимизации потребовалось чуть больше 60-ти шагов
Для оптимизации потребовалось чуть больше 60-ти шагов

Интересное наблюдение. Температура (temperature), параметр, который отвечает за вариативность ответов модели, играет в нашем случае своеобразную роль шага для градиентного спуска. Чем ниже температура, тем медленнее снижается loss, но в то же время реже встречаются выбросы. И наоборот - чем выше температура, тем более уверенные "шаги" делает модель, быстрее сходится, но и часто отдаёт большие выбросы.

Вот так, например, выглядит снижение loss при temperature=0.5 через каждые 20 итераций:

w=-0.001 b=2.73 loss=153.029
w=0.0 b=2.73 loss=152.950
w=1.0 b=2.73 loss=83.893
w=0.333 b=1.73 loss=108.150
w=0.5 b=1.73 loss=97.242
w=0.94 b=1.73 loss=71.318
w=0.97 b=1.75 loss=69.999
w=0.995 b=1.715 loss=68.143
w=0.999 b=1.719 loss=67.990
w=0.999 b=1.719 loss=67.990
w=0.999 b=1.719 loss=67.990
w=0.999 b=1.719 loss=67.990
w=0.719 b=0.2 loss=61.172
w=0.75 b=0.15 loss=58.963
w=0.852 b=0.1 loss=53.394
w=0.905 b=0.095 loss=50.863
w=0.918 b=0.078 loss=50.063
w=0.922 b=0.068 loss=49.762
w=0.931 b=0.063 loss=49.294
w=0.936 b=0.056 loss=48.985
w=0.935 b=0.057 loss=49.042
w=0.939 b=0.054 loss=48.826
w=0.939 b=0.054 loss=48.826
w=0.946 b=0.051 loss=48.475
w=0.934 b=0.043 loss=48.922

P. S.

Не стоит рассматривать языковую модель, как реальный инструмент для оптимизации в таких задачах. Для решения задачи линейно регрессии существуют куда более простые, быстрые и менее затратные методы (для запуска Mistral-7B-Instruct в формате bfloat16 требуется видеокарта с памятью как минимум 16Gb).

Но в целом тенденция выглядит немного пугающей. Даже относительно небольшие LLM становятся всё более "умными", а люди находят им всё новые применения. Например, в статье, на которую я ссылался вначале, авторы предлагают метод оптимизации промтов - а это уже реальная заявка на то, чтобы отобрать работу у промт инженеров (ну или по крайней мере внести существенные коррективы в их обучение).

Репозиторий с кодом на GitHub - https://github.com/akocherovskiy/LLM_as_optimizer

Google Colab, где можно запустить код на бесплатной Т4 - LLM_as_optimizer

Источник: https://habr.com/ru/articles/767650/


Интересные статьи

Интересные статьи

Для качественного технического обслуживания и ремонта необходимо заранее знать о возможных неисправностях, а также об остаточном ресурсе трансформаторного оборудования. Необходимо разработать модель, ...
Управление продуктом, проектом и персоналом – это важная часть стартапа, от которой зависит его рост и развитие. Общение с командой и внешней аудиторией требует четкости и единого понимания процессов....
Видеоблогер Конор Хекстра использовал разные языки программирования, чтобы решить одну и ту же задачу. Попутно выяснилось, что у Фортрана полно поклонников.
Привет, Habr! На связи отдел аналитики данных X5 Tech.Сегодня мы поговорим об очень интересном разделе прикладной математики — оптимизации.
В этой статье мы поговорим о математике градиентного спуска, почему при обучении нейронных сетей применяется стохастический градиентный спуск и о вариации SGD (Stochastic Gradient Descent...