2.3Kпросмотров
11 сентября 2024 г.
📷 ФотоScore: 2.6K
🌟SALSA: Стабильная адаптация линейного поиска Armijo. SALSA (Stable Armijo Line Search Adaptation) — метод, разработанный для оптимизации Learning Rate (LR) во время обучения. Основная концепция метода построена вокруг выполнения линейного поиска для определения наилучшего возможного LR для каждого шага обучения, что дает быструю сходимость и улучшенное обобщение. Чтобы уменьшить вычислительную нагрузку, Salsa предлагает пошаговый миниатюрный линейный поиск. В нем LR постепенно увеличивается с каждым шагом, а критерий линейного поиска постоянно переоценивается. Дополнительно, Salsa включает экспоненциальное сглаживание в процесс линейного поиска и устанавливает два экспоненциальных скользящих средних для скорости обучения. Это помогает стабилизировать оптимизацию и уменьшить нестабильность от мини-пакетирования. Экспериментальные результаты показывают, что Salsa превосходит другие методы оптимизации: 50% сокращение final loss и 1,25 average rank в языковых и графических задачах. Вычислительные издержки Salsa всего на 3% выше, чем у базового LR метода, что можно воспринимать как незначительным увеличением, учитывая показатели производительности. Salsa достаточно универсален, чтобы использоваться с различными оптимизаторами, и особенно эффективен при обучении современных архитектур, которые чувствительны к скорости обучения. ▶️Локальный запуск: # Clone repository:
git clone https://github.com/TheMody/No-learning-rates-needed-Introducing-SALSA-Stable-Armijo-Line-Search-Adaptation.git # Create & activate env:
conda env create -f environment.yml
conda activate sls3 # Install dependencies:
pip install pytorch numpy transformers datasets tensorflow-datasets wandb # NOTE: custom optimizer is in \salsa\SaLSA.py,comparison version are in \salsa\adam_sls.py:
from salsa.SaLSA import SaLSA
self.optimizer = SaLSA(model.parameters()) # NOTE: typical pytorch forward pass needs to be changed to:
def closure(backwards = False): y_pred = model(x) loss = criterion(y_pred, y) if backwards: loss.backward() return loss
optimizer.zero_grad()
loss = optimizer.step(closure = closure) 📌Лицензирование : MIT License 🟡Arxiv
🟡Датасет Cifar-10
🟡Youtube video
🖥Github [ Stars: 11 | Issues: 0 | Forks: 0] @ai_machinelearning_big_data #AI #LLM #ML #Train #SALSA