KAN нейросеть решение задачи классификации

Автор статьи, участник AI-Лаборатории УИИ О.В. Еклашева
Я - новичок, который интересуется нейросетями. Краем уха я услышала, что появилась новая сеть, построенная на принципах теоремы Колмогорова- Арнольда. В новой сети, если совсем по простому, функции активации располагаются не после слоя нейронной сети, а перед входом в нейрон. Причем функции могут быть разными для каждого нейрона. В принципе сейчас по всем непонятным вопросам я, как и многие, сначала лезу в гугл.

Выдача гугл по поиску «нейросеть kan» на 28 мая 2024 года

  1. Первая в выдаче статья на Harb.com https://habr.com/ru/news/811619/ “Исследователи разработали принципиально новую архитектуру нейросетей, которая работает лучше перцептрона” суть которой заключается в том, что представлена новая архитектура глубокого обучения. И дана ссылка на оригинал статьи на английском языке https://arxiv.org/pdf/2404.19756
  2. Третья в выдаче статья на Harb.com https://habr.com/ru/articles/812147/ “На практике пробуем KAN – принципиально новую архитектуру нейросетей”
3. Четвертая в выдаче на сайте TAdvister.ru статья Kolmogorov-Arnold Networks (KAN)
4. Пятая в выдаче на сайте VC.ru статья Новый убийца нейросетей? Сеть Колмогорова- ...(KAN)
5. Шестая в выдаче на сайте datasecrets.ru статья Принципиально новую архитектуру

И так далее
При внимательном чтении замечаешь, что все статьи опираются на оригинал статьи на английском языке https://arxiv.org/pdf/2404.19756 , содержат рисунки взятые из этой статьи и, что прекрасно, ссылаются на github создателей этой сети.
Есть также перевод этой статьи на русский https://datasecrets.ru/articles/9 .
Сама статья написана восемью учеными (Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljaci, Thomas Y. Hou, Max Tegmark), представляющими Массачусетский технологический институт, Калифорнийский технологический институт, Северо-Восточный университет и Институт искусственного интеллекта и фундаментальных взаимодействий NSF . И на их github хранится вся нужная информация и новые модули, которые позволяют попробовать поработать с новой сетью.

На мой взгляд, лучший способ что-то понять - это самостоятельно сделать что-то руками. Я зашла на githab KindXiaoming/pykan, который писал про KAN и скачала ноут для работы, но работа сразу застопорилась. Строка !pip install kan , с которой началась программа, вроде бы выполнялась, но потом колаб не находил модель KAN ради которого, все и затевалось. Проведя некоторое время в бесплодных попытках заставить работать программу, я написала этому человеку по электронной почте и, о чудо, практически сразу же получила ответ. Отказывается нужно писать !pip install pykan . С такой строкой программа начала работать.

Давайте попробуем разобрать в чем принципиальное отличие архитектуры KAN от архитектуры многослойного персептрона. Если говорить совсем для новичков, то в случае многослойного персептрона сеть состоит из слоев, слой из нейронов и есть связи, которые соединяют нейроны различных слоев. На этих связях располагаются веса. Именно эти веса подбирает механизм обратного распространения ошибки. И внутри каждого слоя находится функция активации, как правило линейная и одинаковая для всего слоя. Каждый нейрон складывает входящие в него веса, и потом передает значение, измененное функцией активации. Итак, есть ядра, веса на ребрах между ядрами и функция активации в ядре. Эта простая конструкция находится в основе систем компьютерного зрения и больших языковых моделей, на основе которых создана ChatGPT, создающая картинки, музыку и программный код, сочиняющая стихи и прочее.

Создатели KAN переместили функцию активации из нейронов на ребра, и теперь механизм обратного распространения ошибки подбирает не просто число, а функцию.
Это, с одной стороны, позволило сделать архитектуру сети намного меньше. В сети KAN намного меньше и слоев и нейронов, с другой повысилась объяснимость моделей. В KAN можно посмотреть функцию, лежащую на каждом ребре.

Объяснимость моделей очень важна, чтобы наглядно наблюдать влияние различных факторов на результат. Когда я изучала эконометрику, нам говорили, что когда выбирают среди двух моделей, пусть первая имеет не слишком хорошие показатели качества, но имеет параметры, которые можно интерпретировать, а другая - хорошие показатели качества модели, но не интерпретируемые параметры, нужно использовать обе. Одну для интерпретации, а другую - для прогнозирования. Сейчас сети, построенные на основе многослойного персептрона, воспринимаются как черный ящик, в котором непонятно что находится, но который более менее успешно справляется с задачей.

Рассмотрим как пользоваться новой технологией на примере решения классической задачи классификации. В качестве данных взята задача классификации грибов с сайта Kaggle.com (https://www.kaggle.com/datasets/prishasawhney/mushroom-dataset) - первая в выдаче с высоким удобством использования и делением на два класса. Данный датасет содержит 9 столбцов - 8 независимых переменных: Cap Diameter (диаметр шляпки), Cap Shape (форма шляпки), Gill Attachment (способ прикрепления нижней части шляпки), Gill Color (цвет нижней части шляпки), Stem Height (высота ножки), Stem Width (ширина ножки), Stem Color (цвет ножки), Season (сезон) и 1 целевая функция - Target Class - Is it edible or not? (целевой класс - съедобно или нет). Целевой класс содержит два значения — 0 или 1, где 0 относится к съедобному, а 1 — к ядовитому. Все данные - числовые. Датасет состоит из 54035 наблюдения, пропусков в данных нет.
Настраиваем взаимодействие между github создателей и python
Инсталлируем модуль pykan
Импортируем из kan модели и создаем модель KAN, другими словами задаем три числа: первое число - количество входов, оно должно равняться числу входных переменных, второе число - количество промежуточных слоев и третье число - количество выходных переменных. В данном случае есть 8 независимых переменных и одна целевая функция. И давайте возьмем 4 промежуточных слоя, кубический сплайн (k=3) и 3 промежуточных интервала (grid=3)
Подключаем библиотеки, которые будем использовать
Получаем доступ к гугл-диску. Я скачала файл с данными на свой диск
Получаем исходные данные из файла
Разделяем данные на независимые переменные (все кроме последнего столбца) и целевую функцию (последний столбец)
Разделяем выборки на тренировочную и тестовую
Нормализуем данные и преобразуем их в тензоры
Готовим датасет для подачи в модель. Для модели KAN данные - это словарь с определенными ключами. Сначала создаем пустой датасет, потом ставит в соответствие определенным ключам нужные значения.
Обучаем модель. На вход подаем подготовленный датасет, задаем специальный оптимизатор LBFGS, который подбирает функции активации на ребрах, соединяющих нейроны различных слоев. Параметр steps определяет количество эпох, которая модель обучается.
После этого модель довольно долго обучается (2 минут 1 секунда) и выдает весьма средние результаты своей работы. Ошибка на тренировочной выборке 0,451, ошибка на тестовой выборке - 0,450. Низкая точность распознавания классов. Вероятность 0,45 съесть ядовитый гриб пугает.
Можно посмотреть график модели
Данная модель имеет такой график. Видимо все ребра не нарисованы, чтобы не загромождать рисунок.
Также можно посмотреть вид функции активации, найденной в результате обучения модели, используя функцию
Эта ячейка выполнялась очень долго, в данном случае - почти 60 минут и выдала следующий результат. Первое число наверное означает номер ряда, второе число - номер нейрона в этом ряду и третье число - номер нейрона в следующем ряду, с которым есть связь, потом есть вид функции, например тангенс, синус и модуль или корень квадратный, и r2, то есть доля вариации, которую объясняет данная функция.
Еще можно посмотреть символьный вид функции
Вот формула
Также можно рассчитать предсказания.
Предсказания получаются в виде вероятности, что гриб окажется ядовитым
Округлим, чтобы можно было построить матрицу confusion
При классическом построении архитектуры многослойного персептрона при 40 эпохах затрачивается примерно такое время и достигается значительно лучшее качество
Эта ячейка выполнялась очень долго, в данном случае - почти 60 минут и выдала следующий результат. Первое число наверное означает номер ряда, второе число - номер нейрона в этом ряду и третье число - номер нейрона в следующем ряду, с которым есть связь, потом есть вид функции, например тангенс, синус и модуль или корень квадратный, и r2, то есть доля вариации, которую объясняет данная функция.
Вот график обучения. Максимальное качество модели 0,97. Доля верных ответов растет и на обучающем наборе и на тестовом. Видно, что предел обучения еще не достигнут.
Для классической модели многослойного персептрона confusion матрица имеет вид
Сравним решение задач, по времени работы (я закомментировала ячейки, корорые отвечают за построение функции для нейросети KAN, так они выполняются очень долго).
Вставив ячейки, после выполнения обучения обеих моделей, можно узнать использования ресурсов . А также я измерила время выполнения обучения обеих сетей.
Для нейросети KAN
Время выполнения: 116.54228949546814 секунд
Загрузка CPU: 3.0%
Количество физических ядер CPU: 1
Частота CPU: 2199.998 MHz
Использование памяти процессом: 1833586688 bytes

Для многослойного персептрона
Время выполнения: 94.9920506477356 секунд
Загрузка CPU: 2.5%
Количество физических ядер CPU: 1
Частота CPU: 2199.998 MHz
Использование памяти процессом: 1846673408 bytes

В данном случае архитектура с многослойным персептроном всухую выиграла у нейросети KAN. Сеть со стандартным многослойным персептроном оказалась быстрее, требовала меньших ресурсов и точнее в разы, и это даже не говоря о том, что количество эпох можно было ограничить 15, после этой эпохи особого роста точности нет.
Я даже расстроилась, поскольку я болела за KAN. Хотя, наверное, не стоило проверять работу новой модели на задачах, с которыми блестяще справляются уже имеющиеся модели. Нужно применять их в ситуации сложности и нелинейности. Для тех задач, с которыми плохо справляется архитектура с многослойным персептроном. Задним умом понимаешь, что все то логично. В нейросети KAN функции активации нелинейные, а с помощью нелинейных функции очень не просто описать линейную функцию. Наверно стоит попробовать использование нейросети KAN для анализа аудиосигналов, причем даже не для выделения основного звука, а для анализа фона. Но это тема уже для другого исследования. Продолжение следует…

Made on
Tilda