Меня зовут Борисов Павел, занимаюсь ML-исследованиями. Последние месяцы ковырялся с архитектурой MoE, где эксперты подключаются поверх замороженной модели. 22 эксперимента на одной RTX 4090, ниже разбор что получилось.
Берём предобученную языковую модель и замораживаем целиком, ни один вес не меняется. К каждому MLP-слою прикручиваем маленький обучаемый модуль, «эксперт». Сверху маршрутизатор — линейный слой на 37 тысяч параметров, который для каждого токена выбирает эксперта.
По сути это MoE [1], но отличие от Mixtral [2] или DeepSeek-MoE [3] принципиальное. Там маршрутизация это часть предобучения на триллионе токенов. Тут базовая модель не трогается, эксперты подключаются как плагины.
Обучение нового эксперта — три шага:
Изоляция. Замораживаем всё кроме одного эксперта. Прогоняем тексты нужной области (математику, код, научные статьи). ~15 минут на GPU.
Интеграция. Размораживаем только маршрутизатор. Показываем тексты из всех областей, он учится направлять токены к нужному эксперту. ~15 минут.
Горячее подключение. Новая область = новый эксперт, повторяем шаги 1-2. Старых не трогаем.
Деградация при подключении нового эксперта: 0.000%. На трёх масштабах. Веса базовой модели заморожены, катастрофическое забывание [4] исключено архитектурно.
Тестировал на GPT-2 (124M), Pythia-410M и Pythia-1B:
|
Базовая модель |
Областей |
Снижение перплексии |
Отставание от идеала |
Деградация |
|---|---|---|---|---|
|
GPT-2 124M |
4 |
33.4% |
6.6% |
— |
|
Pythia-410M |
6 |
34.3% |
5.9% |
0.000% |
|
Pythia-1B |
8 |
31.2% |
3.2% |
0.000% |
Снижение перплексии ~31-34% на всех масштабах. Отставание от оракула сужается с ростом: 6.6% → 5.9% → 3.2%.
Математический эксперт специализировался лучше всех, перплексия -85.1%, междоменный разрыв 64.9x. Разговорный хуже всех (3.1%) — слишком общий стиль, мало специфики.
Вспомогательные лоссы. Четыре варианта штрафов за неравномерную нагрузку (balance, diversity, entropy, importance), все ухудшили на 11-27%. Wang et al. [5] писали о том же, DeepSeek-V3 [6] от штрафов тоже отказался.
Совместное обучение экспертов и маршрутизатора — коллапс, точность с 80.8% до 73.7%. В DeepSeek-MoE [3] пришли к похожему.
Маршрутизация по меткам. Без подсказок маршрутизатор нашёл границы точнее (6.6%) чем с явными метками (7.3%).
Вместо штрафов использовал безлоссовую балансировку [5] — обучаемое смещение для выравнивания нагрузки. 100% экспертов живы на всех масштабах.
Итого: перплексия -31%, маршрутизация 96%, деградация нулевая. Всё отлично, пока не запустил бенчмарки.
31% снижение перплексии на Pythia-1B с 8 экспертами = +0.29 п.п. на MMLU. Почти ничего.
Ладно, Pythia маленькая. Взял Qwen 2.5 3B, она даёт 74.4% на GSM8K из коробки. Обучил математического эксперта. Перплексия на математике -23.9%, междоменный разрыв 64.9x, маршрутизация 100%.
GSM8K после подключения: 65.8%. Минус 8.6 п.п.
Перепроверил три раза. Разморозка верхних слоёв, совместное обучение, двухфазная схема — всё в районе -8.4...-8.6 п.п.
В чём дело: эксперт обучился на учебниках и статьях, выучил статистику языка математики, то есть что после «решим уравнение» скорее идёт формула, а не рецепт. Перплексия от этого снижается, но GSM8K-то требует логику рассуждения, а не знание частотности слов. Hu et al. [7] и Fang et al. [8] показывали околонулевую корреляцию перплексии с бенчмарками, и вот это ровно оно.
Маршрутизатор при этом работал на 0.4% отставания от оракула. На MMLU показал +0.15 п.п. выше оракульного выбора, то есть обходил эксперта на задачах где тот вредит.
Раз проблема в данных, попробовал другой подход: обучить эксперта на пошаговых решениях самой модели вместо сырого текста. По мотивам STaR [10], только STaR дообучивает всю модель, а тут внешний эксперт поверх замороженной.
Взял 750 задач GSM8K, Qwen решил 638 правильно. Получилось 119 тысяч токенов, это в 33 раза меньше чем 4 миллиона токенов сырого текста. Формат «Вопрос/Ответ», как при инференсе.
GSM8K: 75.5%. +1.1 п.п. к базе, +9.7 п.п. к варианту с сырым текстом. При этом перплексия ухудшилась на 17.8%.
Ещё заметил что формат данных важен: «Вопрос/Ответ» (совпадает с форматом инференса) дал +2-3 п.п. по сравнению с «Задача/Решение». Для сравнения, LoRA-вариант (13.4 млн параметров вместо 67.6 млн) показал 74.0%, всего -0.4 п.п. от базы, но маршрутизации там нет.
Дальше захотелось замкнуть цикл: модель решает задачи, обучаем эксперта на решениях, модель решает лучше, обучаем снова...
|
Цикл |
Верных |
Новых |
GSM8K |
|---|---|---|---|
|
0 (исходный) |
638/750 |
— |
75.5% |
|
1 |
658/750 |
+20 |
75.5% |
|
2 |
668/750 |
+10 |
76.0% |
+20, потом +10, затухает. Но 76.0% вроде есть. Проблема в другом.
При создании эксперта через QwenWithMoE() PyTorch инициализировал веса рандомно. Seed я не фиксировал. Разброс от инициализации ~5 п.п., а эффект цикла 0.5 п.п.
После torch.manual_seed(42) и увеличения выборки до 500 задач:
|
Seed |
GSM8K |
|---|---|
|
42 |
76.4% |
|
123 |
75.2% |
|
456 |
76.0% |
|
Среднее |
75.87% ± 0.61 п.п. |
С фиксированным seed и 500 задачами:
Холодный старт (свежий эксперт каждый цикл): 76.4% → 74.6% → 74.6%. Плато.
Тёплый старт (продолжаем обучение): 76.4% → 75.0% → 71.6%.
При тёплом старте перплексия продолжала падать: 1.58 → 1.45 → 1.36. По лоссам всё хорошо, а GSM8K деградирует. Причина: между циклами 85-90% задач повторялись, эксперт на них переобучался.
То «улучшение» 75.5% → 76.0% из таблицы выше — статистический шум. На 200 задачах доверительный интервал ~5 п.п., эффект 0.5 п.п.
Попробовал сглаживание меток, минус 9 п.п. на GSM8K. По Müller et al. [11] сглаживание делает варианты ответа более равновероятными. В классификации картинок это нормально, но в математике «15 минус 7» это 8, а не «скорее 8 чем 7». Каждый промежуточный шаг рассуждения должен быть точным.
С архитектурой всё хорошо: 0.000% деградации, 96% точность маршрутизации, безлоссовая балансировка [5] вместо штрафов. Самодистилляция дала +9.7 п.п. по сравнению с обучением на сыром тексте (119 тысяч токенов собственных решений vs 4 миллиона из учебников). Замкнуть цикл самоулучшения не получилось, задачи повторяются между итерациями и эксперт переобучается. И главное — перплексия оказалась бесполезной для оценки рассуждений, она может падать когда бенчмарки деградируют и расти когда они улучшаются.
Shazeer et al. Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017. arxiv:1701.06538
Jiang et al. Mixtral of Experts. 2024. arxiv:2401.04088
Dai et al. DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models. 2024. arxiv:2401.06066
Kirkpatrick et al. Overcoming Catastrophic Forgetting in Neural Networks. PNAS 2017. arxiv:1612.00796
Wang et al. Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts. 2024. arxiv:2408.15664
DeepSeek-AI. DeepSeek-V3 Technical Report. 2024. arxiv:2412.19437
Hu et al. Can Perplexity Reflect Large Language Model's Ability in Long Text Understanding? 2024. arxiv:2405.06105
Fang et al. What is Wrong with Perplexity for Long-context Language Modeling? 2024. arxiv:2410.23771
Wei et al. Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. NeurIPS 2022. arxiv:2201.11903
Zelikman et al. STaR: Bootstrapping Reasoning With Reasoning. NeurIPS 2022. arxiv:2203.14465
Müller, Kornblith, Hinton. When Does Label Smoothing Help? NeurIPS 2019. arxiv:1906.02629
Код и результаты экспериментов: GitVerse | GitFlic*
Источник


