Large language models are great, of course, but sometimes you need something small and fast.
Problem statement
Distillation will be performed on a BERT model trained for binary classification. As data, I chose an open corpus of Russian-language tweets. I drew inspiration from two articles: one on distilling knowledge from BERT into a BiLSTM, and one on distilling BERT itself. I'm not adding anything new; I just want to tidy everything up and turn it into a step-by-step tutorial that is easy to follow. All the code is on github.
Plan of work
Baseline 1: TF-IDF + RandomForest
Baseline 2: BiLSTM
Distillation BERT > BiLSTM
Distillation BERT > tinyBERT
TF-IDF + RandomForest
Everything is standard: lowercasing, lemmatization, stop-word removal. The resulting vectors are classified with a RandomForest. This gives an F1 slightly above 0.75.
How to train TF-IDF + RF
import re
import pandas as pd
from pymystem3 import Mystem
# get data
data = pd.read_csv('data.csv')
texts = list(data['comment'])
labels = list(map(int, data['toxic'].values))
# clean texts
texts = [re.sub('[^а-яё ]', ' ', str(t).lower()) for t in texts]
texts = [re.sub(r" +", " ", t).strip() for t in texts]
# lemmatize
mstm = Mystem()
normalized = [''.join(mstm.lemmatize(t)[:-1]) for t in texts]
# remove stopwords
with open('./stopwords.txt') as f:
    stopwords = [line.rstrip('\n') for line in f]

def drop_stop(text):
    tokens = text.split(' ')
    tokens = [t for t in tokens if t not in stopwords]
    return ' '.join(tokens)

normalized = [drop_stop(text) for text in normalized]
# new dataset
df = pd.DataFrame()
df['text'] = texts
df['norm'] = normalized
df['label'] = labels
# train-valid-test-split
from sklearn.model_selection import train_test_split
train, test = train_test_split(df, test_size=0.3, random_state=42)
valid, test = train_test_split(test, test_size=0.5, random_state=42)
# tf-idf
from sklearn.feature_extraction.text import TfidfVectorizer
model_tfidf = TfidfVectorizer(max_features=5000)
train_tfidf = model_tfidf.fit_transform(train['norm'].values)
valid_tfidf = model_tfidf.transform(valid['norm'].values)
test_tfidf = model_tfidf.transform(test['norm'].values)
# RF
from sklearn.ensemble import RandomForestClassifier
cls = RandomForestClassifier(random_state=42)
cls.fit(train_tfidf, train['label'].values)
# prediction
predictions = cls.predict(test_tfidf)
# score
from sklearn.metrics import f1_score
f1_score(test['label'].values, predictions)
BiLSTM
Let's try to improve on this baseline with a neural network approach. Again, everything is standard: train a tokenizer, train a network. As the base architecture we take a BiLSTM. This gives an F1 slightly above 0.79. A small gain, but a gain nonetheless.
How to train the BiLSTM
# get data
import pandas as pd
train = pd.read_csv('train.csv')
valid = pd.read_csv('valid.csv')
test = pd.read_csv('test.csv')
# create tokenizer
from tokenizers import Tokenizer
from tokenizers import ByteLevelBPETokenizer
from tokenizers.pre_tokenizers import Whitespace
tokenizer = ByteLevelBPETokenizer()
tokenizer.pre_tokenizer = Whitespace()
tokenizer.enable_padding(pad_id=0, pad_token='<pad>')
texts_path = 'texts.txt'
with open(texts_path, 'w') as f:
    for text in list(train['text'].values):
        f.write("%s\n" % text)
tokenizer.train(
files=[texts_path],
vocab_size=5_000,
min_frequency=2,
special_tokens=['<pad>', '<unk>']
)
# create dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import f1_score
class CustomDataset(Dataset):
    def __init__(self, tokens, labels, max_len):
        self.tokens = tokens
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        label = self.labels[idx]
        label = torch.tensor(label)
        tokens = self.tokens[idx]
        out = torch.zeros(self.max_len, dtype=torch.long)
        out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)[:self.max_len]
        return out, label
max_len = 64
BATCH_SIZE = 16
train_labels = list(train['label'])
train_tokens = [tokenizer.encode(text).ids for text in list(train['text'])]
train_dataset = CustomDataset(train_tokens, train_labels, max_len)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
valid_labels = list(valid['label'])
valid_tokens = [tokenizer.encode(text).ids for text in list(valid['text'])]
valid_dataset = CustomDataset(valid_tokens, valid_labels, max_len)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_labels = list(test['label'])
test_tokens = [tokenizer.encode(text).ids for text in list(test['text'])]
test_dataset = CustomDataset(test_tokens, test_labels, max_len)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# create BiLSTM
class LSTM_classifier(nn.Module):
    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        lstm_out, (ht, ct) = self.lstm_layer(embeddings)
        out = ht.transpose(0, 1)
        out = out.reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        out = self.out_layer(out)
        out = torch.squeeze(out, 1)
        out = torch.sigmoid(out)
        return out

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

def eval_nn(model, data_loader):
    predicted = []
    labels = []
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            x, y = data
            x = x.to(device)
            outputs = model(x)
            _, predict = torch.max(outputs.data, 1)
            predict = predict.cpu().detach().numpy().tolist()
            predicted += predict
            labels += y.tolist()
    score = f1_score(labels, predicted, average='binary')
    return score

def train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=20):
    best_score = 0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            loss = loss_function(predict, labels)
            loss.backward()
            optimizer.step()
        score = eval_nn(model, test_loader)
        print(epoch, 'valid:', score)
        if score > best_score:
            torch.save(model.state_dict(), 'lstm.pt')
            best_score = score
    return best_score
# fit NN
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = LSTM_classifier(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)
model.apply(init_weights)
model.to(device)
optimizer = optim.AdamW(model.parameters())
loss_function = nn.CrossEntropyLoss().to(device)
train_nn(model, optimizer, loss_function, train_loader, valid_loader, device, epochs=20)
eval_nn(model, test_loader)
Training BERT
Let's train the teacher model. As the teacher I picked the hero of the distillation article mentioned above, rubert-tiny by @cointegrated. This gives an F1 slightly above 0.91. I didn't play with the training much; I think the metric could be pushed higher, especially with a full-size BERT, but it is illustrative enough as it is. How to train BERT for binary classification is covered in my previous article, or right here:
How to train BERT
import torch
from torch.utils.data import Dataset
class BertDataset(Dataset):
    def __init__(self, texts, targets, tokenizer, max_len=512):
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        target = self.targets[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long)
        }
from tqdm import tqdm
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_recall_fscore_support
class BertClassifier:
    def __init__(self, path, n_classes=2):
        self.path = path
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = 512
        self.out_features = self.model.bert.encoder.layer[1].output.dense.out_features
        self.model.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.model.to(self.device)

    def preparation(self, X_train, y_train, epochs):
        # create datasets
        self.train_set = BertDataset(X_train, y_train, self.tokenizer)
        # create data loaders
        self.train_loader = DataLoader(self.train_set, batch_size=2, shuffle=True)
        # helpers initialization
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=2e-5,
            weight_decay=0.005,
            correct_bias=True
        )
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=500,
            num_training_steps=len(self.train_loader) * epochs
        )
        self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)

    def fit(self):
        self.model = self.model.train()
        losses = []
        correct_predictions = 0
        for data in tqdm(self.train_loader):
            input_ids = data["input_ids"].to(self.device)
            attention_mask = data["attention_mask"].to(self.device)
            targets = data["targets"].to(self.device)
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            preds = torch.argmax(outputs.logits, dim=1)
            loss = self.loss_fn(outputs.logits, targets)
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
        train_acc = correct_predictions.double() / len(self.train_set)
        train_loss = np.mean(losses)
        return train_acc, train_loss

    def train(self, X_train, y_train, X_valid, y_valid, X_test, y_test, epochs=1):
        print('*' * 10)
        print(f'Model: {self.path}')
        self.preparation(X_train, y_train, epochs)
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')
            train_acc, train_loss = self.fit()
            print(f'Train loss {train_loss} accuracy {train_acc}')
            predictions_valid = [self.predict(x) for x in X_valid]
            precision, recall, f1score = precision_recall_fscore_support(y_valid, predictions_valid, average='macro')[:3]
            print('Valid:')
            print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')
        predictions_test = [self.predict(x) for x in X_test]
        precision, recall, f1score = precision_recall_fscore_support(y_test, predictions_test, average='macro')[:3]
        print('Test:')
        print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')
        print('*' * 10)

    def predict(self, text):
        self.model = self.model.eval()
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        out = {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }
        input_ids = out["input_ids"].to(self.device)
        attention_mask = out["attention_mask"].to(self.device)
        outputs = self.model(
            input_ids=input_ids.unsqueeze(0),
            attention_mask=attention_mask.unsqueeze(0)
        )
        prediction = torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]
        return prediction
import pandas as pd
train = pd.read_csv('train.csv')
valid = pd.read_csv('valid.csv')
test = pd.read_csv('test.csv')
classifier = BertClassifier(
path='cointegrated/rubert-tiny',
n_classes=2
)
classifier.train(
X_train=list(train['text']),
y_train=list(train['label']),
X_valid=list(valid['text']),
y_valid=list(valid['label']),
X_test=list(test['text']),
y_test=list(test['label']),
epochs=1
)
path = './trainer'
classifier.model.save_pretrained(path)
classifier.tokenizer.save_pretrained(path)
Distillation BERT > BiLSTM
The main idea is to have the BiLSTM student approximate the output of the BERT teacher. For this we use MSE as the training loss. It can be combined with training on the labels via CrossEntropyLoss. You can read more in the article linked above. On my test data, distillation added only a couple of points: F1 slightly above 0.82.
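Concretely, the objective in the code below mixes the two terms with a weight a (the alpha that is swept over at the end): loss = a * CrossEntropy(student_logits, labels) + (1 - a) * MSE(student_logits, teacher_logits), so a = 1 means training on the labels only and a = 0 means pure distillation against the teacher's logits.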
Distillation code
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from tokenizers import ByteLevelBPETokenizer
from tokenizers.pre_tokenizers import Whitespace
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import f1_score
import numpy as np
import pandas as pd
### data
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')
### tokenizer: train
tokenizer = ByteLevelBPETokenizer()
tokenizer.pre_tokenizer = Whitespace()
tokenizer.enable_padding(pad_id=0, pad_token='<pad>')
texts_path = 'texts.txt'
with open(texts_path, 'w') as f:
    for text in list(train['text'].values):
        f.write("%s\n" % text)
tokenizer.train(
files=[texts_path],
vocab_size=5_000,
min_frequency=2,
special_tokens=['<pad>', '<unk>']
)
### load BERT tokenizer
tokenizer_bert = BertTokenizer.from_pretrained('./rubert-tiny')
### dataset
class CustomDataset(Dataset):
    def __init__(self, tokens, labels, max_len):
        self.tokens = tokens
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        label = self.labels[idx]
        label = torch.tensor(label)
        tokens = self.tokens[idx]
        out = torch.zeros(self.max_len, dtype=torch.long)
        out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)[:self.max_len]
        return out, label
max_len = 64
BATCH_SIZE = 16
train_labels = list(train['label'])
train_tokens = [tokenizer.encode(str(text)).ids for text in list(train['text'])]
train_dataset = CustomDataset(train_tokens, train_labels, max_len)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_labels = list(test['label'])
test_tokens = [tokenizer.encode(str(text)).ids for text in list(test['text'])]
test_dataset = CustomDataset(test_tokens, test_labels, max_len)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
class LSTM_classifier(nn.Module):
    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        lstm_out, (ht, ct) = self.lstm_layer(embeddings)
        out = ht.transpose(0, 1)
        out = out.reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        out = self.out_layer(out)
        out = torch.squeeze(out, 1)
        out = torch.sigmoid(out)
        return out
########
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

def eval_nn(model, data_loader):
    predicted = []
    labels = []
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            x, y = data
            x = x.to(device)
            outputs = model(x)
            _, predict = torch.max(outputs.data, 1)
            predict = predict.cpu().detach().numpy().tolist()
            predicted += predict
            labels += y.tolist()
    score = f1_score(labels, predicted, average='binary')
    return score

def train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=20):
    best_score = 0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            loss = loss_function(predict, labels)
            loss.backward()
            optimizer.step()
        score = eval_nn(model, test_loader)
        print(epoch, 'valid:', score)
        if score > best_score:
            torch.save(model.state_dict(), 'lstm.pt')
            best_score = score
    return best_score
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = LSTM_classifier(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)
model.apply(init_weights);
model.to(device);
optimizer = optim.AdamW(model.parameters())
loss_function = nn.CrossEntropyLoss().to(device)
train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=3)
#####
class DistillDataset(Dataset):
    def __init__(self, texts, labels, tokenizer_bert, tokenizer_lstm, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer_bert = tokenizer_bert
        self.tokenizer_lstm = tokenizer_lstm
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        label = torch.tensor(label)
        # lstm
        tokens_lstm = self.tokenizer_lstm.encode(str(text)).ids
        out_lstm = torch.zeros(self.max_len, dtype=torch.long)
        out_lstm[:len(tokens_lstm)] = torch.tensor(tokens_lstm, dtype=torch.long)[:self.max_len]
        # bert
        encoding = self.tokenizer_bert.encode_plus(
            str(text),
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        out_bert = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }
        return out_lstm, out_bert, label
train_dataset_distill = DistillDataset(
list(train['text']),
list(train['label']),
tokenizer_bert,
tokenizer,
max_len
)
train_loader_distill = DataLoader(train_dataset_distill, batch_size=BATCH_SIZE, shuffle=True)
### BERT-teacher model
class BertTrainer:
    def __init__(self, path_model, n_classes=2):
        self.model = BertForSequenceClassification.from_pretrained(path_model, num_labels=n_classes)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = 512
        self.model.to(self.device)
        self.model = self.model.eval()

    def predict(self, inputs):
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        return outputs.logits
teacher = BertTrainer('./rubert-tiny')
### BiLSTM-student model
class CustomLSTM(nn.Module):
    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        lstm_out, (ht, ct) = self.lstm_layer(embeddings)
        out = ht.transpose(0, 1)
        out = out.reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        out = self.out_layer(out)
        # out = torch.squeeze(out, 1)
        # out = torch.sigmoid(out)
        return out
def loss_function(output, teacher_prob, real_label, a=0.5):
    criterion_mse = torch.nn.MSELoss()
    criterion_ce = torch.nn.CrossEntropyLoss()
    return a * criterion_ce(output, real_label) + (1 - a) * criterion_mse(output, teacher_prob)

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

def eval_nn(model, data_loader):
    predicted = []
    labels = []
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            x, y = data
            x = x.to(device)
            outputs = model(x)
            _, predict = torch.max(outputs.data, 1)
            predict = predict.cpu().detach().numpy().tolist()
            predicted += predict
            labels += y.tolist()
    score = f1_score(labels, predicted, average='binary')
    return labels, predicted, score
def train_distill(model, teacher, optimizer, loss_function, distill_loader, train_loader, test_loader, device, epochs=30, alpha=0.5):
    best_score = 0
    score_list = []
    for epoch in range(epochs):
        model.train()
        for inputs, inputs_teacher, labels in distill_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            teacher_predict = teacher.predict(inputs_teacher)
            loss = loss_function(predict, teacher_predict, labels, alpha)
            loss.backward()
            optimizer.step()
        score_train = round(eval_nn(model, train_loader)[2], 3)
        score_test = round(eval_nn(model, test_loader)[2], 3)
        score_list.append((score_train, score_test))
        print(epoch, score_train, score_test)
        if score_test > best_score:
            best_score = score_test
            best_model = model
            torch.save(best_model.state_dict(), f'./results/lstm_{best_score}.pt')
    return best_model, best_score, score_list
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vocab_size = tokenizer.get_vocab_size()
print(vocab_size)
score_alpha = []
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    # use the student class without the final sigmoid so that the outputs are logits
    model = CustomLSTM(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)
    model.apply(init_weights)
    model.to(device)
    optimizer = optim.AdamW(model.parameters())
    _, _, score_list = train_distill(model, teacher, optimizer, loss_function, train_loader_distill, train_loader, test_loader, device, 30, alpha)
    score_alpha.append(score_list)
import matplotlib.pyplot as plt
import numpy as np
a_list = [0, 0.25, 0.5, 0.75, 1]  # same order as in the training loop above
for i, score in enumerate(score_alpha):
    _, score_test = list(zip(*score))
    plt.plot(score_test, label=f'{a_list[i]}')
plt.grid(True)
plt.legend()
plt.show()
Distillation BERT > tinyBERT
The main idea, as in the previous section, is for the student to mimic the teacher's behavior. There are many options for what to match and how; I took just two:
Matching the [CLS] token with MSE
Distilling the token distributions with Kullback-Leibler divergence
Additionally, during training we solve the MLM task, i.e. predicting masked tokens. The model is shrunk by reducing the vocabulary and the number of attention heads, as well as the number and dimensionality of the hidden layers.
Training the final classifier thus splits into 2 stages:
Training the language model
Training the classification head
I applied distillation only at the first stage; the classification head was then trained directly on top of the distilled model (a minimal sketch of that is given after the implementation notes below). I think an MSE term could also have been thrown in, as in the BiLSTM example, but I left those experiments for later.
Key implementation points:
Vocabulary reduction:
from transformers import BertTokenizerFast, BertForPreTraining, BertModel, BertConfig
from collections import Counter
from tqdm.auto import tqdm, trange
import pandas as pd
train = pd.read_csv('train.csv')
X_train=list(train['text'])
tokenizer = BertTokenizerFast.from_pretrained('./rubert-tiny')
cnt = Counter()
for text in tqdm(X_train):
    cnt.update(tokenizer(str(text))['input_ids'])
resulting_vocab = {
    tokenizer.vocab[k] for k in tokenizer.special_tokens_map.values()
}
for k, v in cnt.items():
    if v > 5:
        resulting_vocab.add(k)
resulting_vocab = sorted(resulting_vocab)
tokenizer.save_pretrained('./bert_distill')
inv_voc = {idx: word for word, idx in tokenizer.vocab.items()}
with open('./bert_distill/vocab.txt', 'w', encoding='utf-8') as f:
    for idx in resulting_vocab:
        f.write(inv_voc[idx] + '\n')
Weight initialization
config = BertConfig(
emb_size=256,
hidden_size=256,
intermediate_size=256,
max_position_embeddings=512,
num_attention_heads=8,
num_hidden_layers=3,
vocab_size=tokenizer_distill.vocab_size
)
model = BertForPreTraining(config)
model.save_pretrained('./bert_distill')
from transformers import BertModel
# load model without CLS-head
teacher = BertForPreTraining.from_pretrained('./rubert-tiny')
tokenizer_teacher = BertTokenizerFast.from_pretrained('./rubert-tiny')
# copy input embeddings accordingly with resulting_vocab
model.bert.embeddings.word_embeddings.weight.data = teacher.bert.embeddings.word_embeddings.weight.data[resulting_vocab, :256].clone()
model.bert.embeddings.position_embeddings.weight.data = teacher.bert.embeddings.position_embeddings.weight.data[:, :256].clone()
# copy output embeddings
model.cls.predictions.decoder.weight.data = teacher.cls.predictions.decoder.weight.data[resulting_vocab, :256].clone()
MLM-loss
inputs = tokenizer_distill(texts, return_tensors='pt', padding=True, truncation=True, max_length=16)
inputs = preprocess_inputs(inputs, tokenizer_distill, data_collator)
outputs = model(**inputs, output_hidden_states=True)
loss += nn.CrossEntropyLoss()(
    outputs.prediction_logits.view(-1, model.config.vocab_size),
    inputs['labels'].view(-1)
)
KL-loss
import torch.nn.functional as F

def loss_kl(inputs, outputs, model, teacher, vocab_mapping, temperature=1.0):
    new_inputs = torch.tensor(
        [[vocab_mapping[i] for i in row] for row in inputs['input_ids']]
    ).to(inputs['input_ids'].device)
    with torch.no_grad():
        teacher_out = teacher(
            input_ids=new_inputs,
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask']
        )
    # the whole batch, all tokens after the [cls], the whole dimension
    kd_loss = torch.nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(outputs.prediction_logits[:, 1:, :] / temperature, dim=1),
        F.softmax(teacher_out.prediction_logits[:, 1:, vocab_mapping] / temperature, dim=1)
    ) / outputs.prediction_logits.shape[-1]
    return kd_loss
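loss_kl takes a vocab_mapping argument that is not defined in these fragments. Given the vocabulary reduction above, where resulting_vocab is the sorted list of kept teacher token ids, the natural choice (my assumption) is a plain list that maps each new student id to the corresponding teacher id:
# student token id i was written to vocab.txt at position i,
# so it corresponds to teacher token id resulting_vocab[i]
vocab_mapping = resulting_vocab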
MSE-loss
input_teacher = {k: v for k, v in tokenizer_teacher(
    texts,
    return_tensors='pt',
    padding=True,
    max_length=16,
    truncation=True
).items()}
with torch.no_grad():
    out_teacher = teacher_mse(**input_teacher)
    embeddings_teacher_norm = torch.nn.functional.normalize(out_teacher.pooler_output)
input_distill = {k: v for k, v in tokenizer_distill(
    texts,
    return_tensors='pt',
    padding=True,
    max_length=16,
    truncation=True
).items()}
out = model(**input_distill, output_hidden_states=True)
embeddings = model.bert.pooler(out.hidden_states[-1])
embeddings_norm = torch.nn.functional.normalize(adapter_emb(embeddings))
loss = torch.nn.MSELoss()(embeddings_norm, embeddings_teacher_norm)
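One way to handle the second stage is to reuse the BertClassifier class from the teacher section and point it at the distilled checkpoint. A minimal sketch, assuming the distilled model and tokenizer were saved to './bert_distill' and the train/valid/test splits from the earlier sections are loaded; the actual stage-2 code in the repository may differ:
# train the classification head on top of the distilled language model
classifier_distill = BertClassifier(
    path='./bert_distill',
    n_classes=2
)
classifier_distill.train(
    X_train=list(train['text']),
    y_train=list(train['label']),
    X_valid=list(valid['text']),
    y_valid=list(valid['label']),
    X_test=list(test['text']),
    y_test=list(test['label']),
    epochs=1
)
# single-string inference with the distilled classifier
print(classifier_distill.predict('какой-то текст'))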
The final model comes out to 16 MB, with an F1 of 0.86. I trained it for 12 hours on a 2019 MacBook Air with an i5 and 8 GB of RAM. I think that with more training time the result would be better.
The code and training data are available on github; comments, additions, and corrections are welcome.