T
TensorFlow
@tensorflowblog1.5K подп.
2.3Kпросмотров
11 сентября 2024 г.
📷 ФотоScore: 2.6K
🌟SALSA: Стабильная адаптация линейного поиска Armijo. SALSA (Stable Armijo Line Search Adaptation) — метод, разработанный для оптимизации Learning Rate (LR) во время обучения. Основная концепция метода построена вокруг выполнения линейного поиска для определения наилучшего возможного LR для каждого шага обучения, что дает быструю сходимость и улучшенное обобщение. Чтобы уменьшить вычислительную нагрузку, Salsa предлагает пошаговый миниатюрный линейный поиск. В нем LR постепенно увеличивается с каждым шагом, а критерий линейного поиска постоянно переоценивается. Дополнительно, Salsa включает экспоненциальное сглаживание в процесс линейного поиска и устанавливает два экспоненциальных скользящих средних для скорости обучения. Это помогает стабилизировать оптимизацию и уменьшить нестабильность от мини-пакетирования. Экспериментальные результаты показывают, что Salsa превосходит другие методы оптимизации: 50% сокращение final loss и 1,25 average rank в языковых и графических задачах. Вычислительные издержки Salsa всего на 3% выше, чем у базового LR метода, что можно воспринимать как незначительным увеличением, учитывая показатели производительности. Salsa достаточно универсален, чтобы использоваться с различными оптимизаторами, и особенно эффективен при обучении современных архитектур, которые чувствительны к скорости обучения. ▶️Локальный запуск: # Clone repository: git clone https://github.com/TheMody/No-learning-rates-needed-Introducing-SALSA-Stable-Armijo-Line-Search-Adaptation.git # Create & activate env: conda env create -f environment.yml conda activate sls3 # Install dependencies: pip install pytorch numpy transformers datasets tensorflow-datasets wandb # NOTE: custom optimizer is in \salsa\SaLSA.py,comparison version are in \salsa\adam_sls.py: from salsa.SaLSA import SaLSA self.optimizer = SaLSA(model.parameters()) # NOTE: typical pytorch forward pass needs to be changed to: def closure(backwards = False): y_pred = model(x) loss = criterion(y_pred, y) if backwards: loss.backward() return loss optimizer.zero_grad() loss = optimizer.step(closure = closure) 📌Лицензирование :  MIT License 🟡Arxiv 🟡Датасет Cifar-10 🟡Youtube video 🖥Github [ Stars: 11 | Issues: 0 | Forks: 0] @ai_machinelearning_big_data #AI #LLM #ML #Train #SALSA
2.3K
просмотров
2326
символов
Да
эмодзи
Да
медиа

Другие посты @tensorflowblog

Все посты канала →
🌟SALSA: Стабильная адаптация линейного поиска Armijo. SALSA — @tensorflowblog | PostSniper