"""꽃분류데이터셋.py

꽃 이미지 5종 분류 - CNN + 데이터 증강

핵심 개념

  • image_dataset_from_directory 활용
  • 데이터 증강 (RandomFlip, RandomRotation, RandomZoom)
  • Conv2D + MaxPooling2D 구조
  • 다중 분류 (5개 클래스)
  • 모델 및 학습 히스토리 저장
"""
# https://www.tensorflow.org/tutorials/images/classification

import os, shutil
import pickle
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from tensorflow.keras import optimizers

import numpy as np
import os
import random
import PIL.Image as pilimg
import imghdr
import pandas as pd
import tensorflow as tf

def _build_model(num_classes, img_height, img_width):
    """Build and compile the augmented CNN classifier.

    Args:
        num_classes: number of output classes (units of the softmax layer).
        img_height: height in pixels of the RGB input images.
        img_width: width in pixels of the RGB input images.

    Returns:
        A compiled ``Sequential`` model trained with integer (sparse) labels.
    """
    # Data augmentation: random horizontal flip, small rotation and zoom.
    # Keras applies these random layers only in training mode.
    data_augmentation = keras.Sequential(
        [
            layers.RandomFlip("horizontal"),
            layers.RandomRotation(0.1),
            layers.RandomZoom(0.1),
        ]
    )

    model = models.Sequential()
    # Declare the input shape with an explicit Input instead of passing
    # input_shape= to a non-input layer (deprecated in Keras 3).
    model.add(keras.Input(shape=(img_height, img_width, 3)))
    model.add(data_augmentation)
    model.add(layers.Rescaling(1.0 / 255))  # scale pixels from [0, 255] to [0, 1]

    model.add(layers.Conv2D(32, (3, 3), activation="relu"))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation="relu"))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dropout(0.2))  # reduce overfitting
    model.add(layers.Dense(256, activation="relu"))
    model.add(layers.Dense(128, activation="relu"))
    # One output per class; sized from the dataset rather than hard-coded.
    model.add(layers.Dense(num_classes, activation="softmax"))

    # Labels are integer class indices (label_mode="int"), hence the
    # sparse variant of categorical cross-entropy.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


def _load_datasets(train_dir, img_height, img_width, batch_size, seed=123):
    """Split the images under ``train_dir`` 80/20 into training/validation sets.

    Both calls use the same ``validation_split`` and ``seed`` so that the
    "training" and "validation" subsets are complementary.

    Returns:
        ``(train_ds, val_ds)`` as returned by ``image_dataset_from_directory``
        (batched datasets that still carry the ``class_names`` attribute).
    """
    common = dict(
        validation_split=0.2,
        seed=seed,
        image_size=(img_height, img_width),
        label_mode="int",
        batch_size=batch_size,
    )
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        train_dir, subset="training", **common
    )
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        train_dir, subset="validation", **common
    )
    return train_ds, val_ds


def study(
    epochs=10,
    train_dir="./dataset/flowers/train",
    batch_size=32,
    img_height=180,
    img_width=180,
):
    """Train the flower CNN and save the model and its training history.

    Trains on an 80/20 split of ``train_dir`` (one sub-directory per class),
    then writes the model to ``flowers_model.keras`` and the pickled
    ``history.history`` dict to ``flowers_hist.hist``.

    Args:
        epochs: number of training epochs.
        train_dir: root directory of the class-labelled images.
        batch_size: number of images loaded per batch.
        img_height: images are resized to this height.
        img_width: images are resized to this width.
    """
    train_ds, val_ds = _load_datasets(train_dir, img_height, img_width, batch_size)

    # Read class_names before any tf.data transformation: derived datasets
    # (e.g. after prefetch) no longer expose this attribute.
    num_classes = len(train_ds.class_names)
    model = _build_model(num_classes, img_height, img_width)

    # Overlap input loading with training; does not change the data.
    autotune = tf.data.AUTOTUNE
    train_ds = train_ds.prefetch(buffer_size=autotune)
    val_ds = val_ds.prefetch(buffer_size=autotune)

    history = model.fit(
        train_ds,
        epochs=epochs,
        validation_data=val_ds,
    )

    # Save the model and the history dict (read back by drawChart()).
    model.save("flowers_model.keras")
    with open("flowers_hist.hist", "wb") as f:
        pickle.dump(history.history, file=f)

def drawChart():
    """Plot the training/validation accuracy and loss saved by study().

    Loads the history dict pickled to ``flowers_hist.hist`` and shows two
    figures: accuracy per epoch and loss per epoch. Training values are drawn
    as dots and validation values as a line. If the history file does not
    exist yet, prints a notice and returns without plotting, so the menu loop
    keeps running.
    """
    hist_path = "flowers_hist.hist"
    try:
        with open(hist_path, "rb") as f:
            # NOTE: pickle.load can execute arbitrary code; only load history
            # files written by this program.
            history = pickle.load(f)
    except FileNotFoundError:
        print(f"{hist_path} 파일이 없습니다. 먼저 1번(기본학습)을 실행하세요.")
        return
    print(history.keys())

    acc = history["accuracy"]
    val_acc = history["val_accuracy"]
    loss = history["loss"]
    val_loss = history["val_loss"]

    epochs = range(1, len(acc) + 1)  # 1-based epoch numbers on the x axis

    plt.plot(epochs, acc, "bo", label="Training acc")
    plt.plot(epochs, val_acc, "b", label="Validation acc")
    plt.title("Training and validation accuracy")
    plt.xlabel("Epoch")
    plt.legend()

    plt.figure()
    plt.plot(epochs, loss, "bo", label="Training loss")
    plt.plot(epochs, val_loss, "b", label="Validation loss")
    plt.title("Training and validation loss")
    plt.xlabel("Epoch")
    plt.legend()

    plt.show()

if __name__ == "__main__":
    # Menu choice -> name of the module-level function that handles it.
    # Handlers are resolved by name when selected. A feature whose function
    # has not been written yet (e.g. Predict / Evaluate) then prints a notice
    # instead of crashing the whole menu with a NameError.
    MENU = {
        "1": "study",
        "2": "drawChart",
        "3": "Predict",
        "4": "Evaluate",
    }
    while True:
        print("1. 기본학습   ")
        print("2. 차트")
        print("3. 예측하기 ")
        print("4. 평가하기 ")
        sel = input("선택 : ").strip()
        handler_name = MENU.get(sel)
        if handler_name is None:  # any other input exits the menu
            break
        handler = globals().get(handler_name)
        if handler is None:
            print(f"{handler_name}() 기능은 아직 구현되지 않았습니다.")
            continue
        handler()