Close

Сохранение весов нейронной сети

Процесс обучения нейросети часто происходит часами и днями, и не будет лишним периодически сохранять веса во избежание потери времени.
Для сохранения нам понадобится пакет h5py:

 

Теперь импортируем класс ModelCheckpoint и указываем имя файла в котором сохранятся веса.

Здесь указывается то, что мы измеряем и каким образом. В нашем примере, в случае роста(mode=’max’) точности на тестовом сете (monitor=’acc’) модель сохранит свои веса в файл weights.hdf5.
Этот чекпоинт добавляется в список колбеков в вызове метода fit модели.

Теперь каждый раз при улучшении точности, веса нейронной сети будут сохраняться.

Загрузить веса еще проще. Достаточно вызвать метод load_weights перед вызовом compile.

После вызова compile модели можно делать предсказания на тестовом сете на расчитанных и сохраненных ранее весах.

Как сохранить саму модель, я описал здесь а сама задача классификации описана тут.

Ниже приведён полный код примера.

 

Поделиться: