Дообучение модели MobileNet V3

MobileNet V3 ИИ дообучение

Приобщаемся к современным технологиям 🙂 не просто задаём вопросы в чате с ИИ-агентом, а пробуем кастомизировать (дообучить) уже обученную ML-модель, чтоб пощупать данный подход и возможно в будущем применять для решения каких либо прикладных задач.

Будем дообучать на «своих» данных(dataset-е) уже обученную модель MobileNetV3 распознания изображений, а именно научим MobileNet V3 на изображениях автомобилей определять марку и модель.

Выбрал именно авто, потому что такой dataset попался под руку на hugging face.

Я это проделал в Jyputer notebook, но можно и в обычном python-файлике. Использовал фреймворк TensorFlow 2.10+.

Это эксперимент не имеет практического применения для меня, это просто опробование технологии дообучения уже предобученой модели.

Ну приступим…

Установка зависимостей

pip install datasets tensorflow keras scikit-learn

Подготовка данных для до обучения

Загружаем dataset stanford_cars c HuggingFace

from datasets import load_dataset

#Папка где будем хранить dataset (кеш)
custom_cache_dir = "/tf/notebook/additional_train/cars_dataset"

# Загрузим датасет для дообучения и проверки
dataset = load_dataset("tanganke/stanford_cars", cache_dir=custom_cache_dir, split="train")

Подготовка данных в формате TensorFlow

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, applications
import numpy as np
import os
from sklearn.model_selection import train_test_split

os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
input_shape=(224, 224, 3)

# Функция для ресайза изображений в 224x224
def resize_image(img):
    img = img.convert('RGB')  # Убедимся, что изображение в RGB
    return np.array(img.resize((224, 224)))  # Ресайз под MobileNetV3

# Преобразуем в numpy массивы (для данных обучения)
images = np.array([resize_image(img) for img in dataset['image']])  # Теперь все изображения 224x224x3
labels = np.array(dataset['label'])

# Нормализация и разделение данных на обучающие и проверочные
X_train, X_val, y_train, y_val = train_test_split(images, labels, test_size=0.2, random_state=42)

# Преобразуем метки в one-hot encoding
num_classes = len(np.unique(labels))
y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
y_val = tf.keras.utils.to_categorical(y_val, num_classes=num_classes)

ImageDataGenerator для аугментации изображений

ImageDataGenerator — это мощный инструмент в библиотеке Keras (TensorFlow), предназначенный, чтобы на лету производить аугментацию изображений во время обучения нейронных сетей.

Аугментация данных (от англ. data augmentation) — это техника искусственного увеличения размера обучающей выборки путём применения различных преобразований к исходным данным, сохраняющих их смысловое содержание, но изменяющих их форму или представление.

Для изображений и генерация дополнительных изображений на основе исходного, т.е. добавление деформаций и изменений в копии изображения.

ImageDataGenerator позволяет:

  • Генерировать новые варианты изображений из исходного набора данных
  • Автоматически применять различные преобразования к изображениям
  • Увеличивать разнообразие обучающей выборки без ручного создания новых изображений
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v3.preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

val_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v3.preprocess_input
)

class_names = dataset.features['label'].names

Создание модели на базе MobileNet V3

Компиляция модели

from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow import keras
from tensorflow.keras.optimizers import Adam

# Загрузка предобученной MobileNetV3 Large (можно попробовать Small для меньшего размера)
base_model = applications.MobileNetV3Small(
    input_shape=input_shape,
    include_top=False,
    weights='imagenet'  # используем веса imageNet
)

# Замораживаем базовую модель (опционально)
#base_model.trainable = False
base_model.trainable = True
for layer in base_model.layers[:20]:
    layer.trainable = False

inputs = keras.Input(shape=input_shape)
x = base_model(inputs, training=False)
x = GlobalAveragePooling2D()(x)
x = Dense(8096, activation='relu')(x)
outputs = layers.Dense(num_classes, activation='softmax')(x)

model = Model(inputs, outputs)

model.compile(
        optimizer=Adam(learning_rate=0.0001),
        loss='categorical_crossentropy',
        metrics=['accuracy']

)

Обучение модели

import datetime
import os

os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

epochs = 50
batch_size = 32

model.fit(
    train_datagen.flow(X_train, y_train, batch_size=batch_size),
    validation_data=val_datagen.flow(X_val, y_val),
    epochs=epochs,
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint('/tf/notebook/additional_train/best_model.keras', save_best_only=True)
    ]
)

model.summary()
model.save('/tf/notebook/additional_train/cars_model.keras')

Проверка модели

from PIL import Image
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import matplotlib.pyplot as plt

model_path = '/tf/notebook/additional_train/cars_model.keras'
model = tf.keras.models.load_model(model_path)

# 2. Функция для предобработки изображения
def preprocess_image(image_path, target_size=(224, 224)):
    """
    Загружает изображение любого размера, преобразует к нужному формату для MobileNetV3
    """
    img = load_img(image_path, target_size=target_size)

    # Конвертация в RGB (на случай если CMYK/grayscale)
    if img.mode != 'RGB':
        img = img.convert('RGB')

    # Нормализация и добавление batch-размерности
    img = tf.keras.applications.mobilenet_v3.preprocess_input(img)
    img = np.expand_dims(img, axis=0)

    return img

# 3. Пример использования
def predict_car(model, image_path, class_names=None):
    """
    Предсказывает марку автомобиля на изображении
    """
    # Загрузка и предобработка
    img = preprocess_image(image_path)

    # Предсказание
    predictions = model.predict(img)
    predicted_class = np.argmax(predictions[0])
    confidence = np.max(predictions[0])

    # Визуализация
    plt.imshow(Image.open(image_path))
    plt.axis('off')

    if class_names:
        plt.title(f"Predicted: {class_names[predicted_class]}\nConfidence: {confidence:.2%}\nClass: {predicted_class}")
    else:
        plt.title(f"Class: {predicted_class}\nConfidence: {confidence:.2%}")

    plt.show()

    return predicted_class, confidence

# 4. Запуск предсказания (пример)
test_image_path = "/tf/notebook/additional_train/test_imgs/"  # путь к изображениям

pred_class, confidence = predict_car(model, test_image_path + 'Fisker_Karma_sedan_2012.jpg', class_names)
pred_class, confidence = predict_car(model, test_image_path + 'honda_accord.jpg', class_names)
pred_class, confidence = predict_car(model, test_image_path + 'hammer.jpg', class_names)
pred_class, confidence = predict_car(model, test_image_path + 'bmw2012.jpg', class_names) 

Примеры работы

Оставить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *