訂閱
糾錯
加入自媒體

使用數據增強從頭開始訓練卷積神經網絡(CNN)

2022-11-24 14:21
磐創AI
關注

介紹

該文致力于處理神經網絡中的過度擬合。

過度擬合將是你主要擔心的問題,因為你僅使用 2000 個數據樣本訓練模型。存在一些有助于克服過度擬合的方法,即 dropout 和權重衰減(L2 正則化)。

我們將討論數據增強,這是計算機視覺獨有的,在使用深度學習模型解釋圖像時,數據增強在任何地方都會用到。

數據增強

學習示例不足會阻止你訓練可以泛化到新數據的模型,從而導致過度擬合。如果你有無限的數據,你的模型將暴露于當前數據分布的所有特征,從而防止過度擬合。

通過增加具有不同隨機變化的樣本來產生逼真的圖像,數據增強使用現有的訓練樣本來生成更多的訓練數據。

你的模型不應在訓練期間兩次查看同一圖像。這使模型更加通用并暴露了數據的其他特征。

Keras 可以通過使用ImageDataGenerator函數定義要應用于圖像的各種隨機變換來實現這一點。

讓我們從一個插圖開始。

####-----data augmentation configuration via ImageDataGenerator-------####

datagen = ImageDataGenerator(

rotation=40,

width_shift=0.2,

height_shift=0.2,

shear=0.2,

zoom=0.2,

horizontal_flip=True,

fill_mode='nearest')

讓我們快速回顧一下這段代碼:

· rotation:這是圖像隨機旋轉的范圍。它的容量在(0-180)度之間。

· width_shift 和 height_shift:范圍(作為總寬度或高度的一部分),在其中垂直或水平隨機翻轉圖片。

· shear:用于隨機應用剪切變換。

· zoom:用于隨機縮放圖像。

· Horizontal_flip :用于隨機水平翻轉一半圖像

· fill_mode:是用于填充新生成的像素的方法,這些像素可能在旋轉或寬度/高度變化后出現。

顯示增強圖像

####-----Let's display some randomly augmented training images-------####

from keras.preprocessing import image

fnames = [os.path.join(train_cats_dir, fname) for fname in os.listdir(train_cats_dir)]

img_path = fnames[3]

img = image.load_img(img_path, target_size=(150, 150))

x = image.img_to_array(img)

x = x.reshape((1,) + x.shape)

i = 0

for batch in datagen.flow(x, batch_size=1):

plt.figure(i)

imgplot = plt.imshow(image.array_to_img(batch[0]))

i += 1

if i % 4 == 0:

  break

plt.show()

圖:使用數據增強生成貓圖片

如果你使用數據增強設置訓練新網絡,網絡將永遠不會收到兩次相同的輸入。

然而,因為它只接收來自少量原始照片的輸入,這些輸入仍然是高度相關的;你只能重新混合已經存在的信息。

因此,這可能不足以消除過度擬合。在密集鏈接分類器之前,你應該在算法中包含一個 Dropout 層,以進一步對抗過度擬合。

實時數據增強應用

1. 醫療保健

管理數據集不是醫學成像應用的解決方案,因為獲取大量經過專業標記的樣本需要很長時間和金錢。

通過增強設計的網絡必須比類似 X 射線圖片中的預測變化更可靠和真實。但是,我們可以通過使用數據增強來增加后續插圖中的數據數量。

圖:X 射線圖像中的數據增強

2. 自動駕駛汽車

自動駕駛汽車是一個不同的使用主題,其中數據增強是有益的。

例如,CARLA旨在在物理模擬中產生靈活性和真實感。CARLA 旨在促進自動駕駛系統的結果、指導和驗證。它基于虛幻引擎 4,并提供了一個完整的模擬器環境,用于在安全的環境中測試自動駕駛技術。

當數據稀缺成為問題時,使用強化學習技術創建的模擬環境可以幫助人工智能系統的訓練和測試。對模擬環境進行建模以創建真實場景的能力為數據增強開辟了一個充滿可能性的世界。

從頭開始定義 CNN 模型

####------Defining CNN, including dropout--------####

model = models.Sequential()

model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(128, (3, 3), activation='relu'))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(128, (3, 3), activation='relu'))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Flatten())

model.add(layers.Dropout(0.5))

model.add(layers.Dense(512, activation='relu'))

model.add(layers.Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer=optimizers.RMSprop(lr=1e-4), metrics=['acc'])

讓我們使用數據增強和損失函數來訓練網絡。

####-------Train CNN using data-augmentation--------#####

train_datagen = ImageDataGenerator(rescale=1./255, rotation=40, width_shift=0.2, height_shift=0.2, shear=0.2, zoom=0.2, horizontal_flip=True,)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=32, class_mode='binary')

validation_generator = test_datagen.flow_from_directory(validation_dir, target_size=(150, 150), batch_size=32, class_mode='binary')

history = model.fit_generator(train_generator, steps_per_epoch=100, epochs=100, validation_data=validation_generator, validation_steps=50)

####-------Save the model--------#####

model.save('cats_and_dogs_small_2.h5')

由于數據增強和丟失,模型不再過度擬合。因為訓練曲線和驗證曲線彼此接近。有了這個準確度,你就超過了非正則化模型 15%,達到了 82%。讓我們繪制曲線。

在訓練期間顯示損失曲線和準確度

通過使用其他正則化方法和微調網絡參數(例如每個卷積層的過濾器數量或網絡中的層數),你可以實現更高的準確度,高達 86% 或 87%。

但是,由于你要處理的數據很少,因此僅通過從頭開始訓練自己的 CNN 來達到更高的水平將是一項挑戰。

你必須采用預訓練模型作為進一步的步驟,以提高你在此挑戰中的準確性。

結論

1. 訓練數據的質量、數量和上下文本質會顯著影響深度學習模型的準確性。但開發深度學習模型的最大問題之一是缺乏數據。

2. 在生產使用方法中獲取此類數據可能既昂貴又耗時。公司使用數據增強這一低成本且高效的技術來更快地開發高精度 AI 模型,并減少對收集和準備訓練實例的依賴。

3. 本文解釋了我們如何使用數據增強技術來訓練我們的模型。當收集大量數據具有挑戰性時,會使用數據增強。正如博客中所討論的,醫療保健和無人駕駛汽車是使用這種方法的兩個最著名的領域。

       原文標題 : 使用數據增強從頭開始訓練卷積神經網絡(CNN)

聲明: 本文由入駐維科號的作者撰寫,觀點僅代表作者本人,不代表OFweek立場。如有侵權或其他問題,請聯系舉報。

發表評論

0條評論,0人參與

請輸入評論內容...

請輸入評論/評論長度6~500個字

您提交的評論過于頻繁,請輸入驗證碼繼續

暫無評論

暫無評論

    人工智能 獵頭職位 更多
    掃碼關注公眾號
    OFweek人工智能網
    獲取更多精彩內容
    文章糾錯
    x
    *文字標題:
    *糾錯內容:
    聯系郵箱:
    *驗 證 碼:

    粵公網安備 44030502002758號