前言
这次训练模型的时候出现了很重的过拟合效果,所以写一个dropout抑制下,顺便做个记录
代码
'''
过拟合:
定义:
在训练数据集上得分很高,但是在测试数据集上得分相对较低
解决方法:
使用dropout方法抑制过拟合:在训练的时候才会发挥作用,在测试的时候还是会使用所有的神经元
欠拟合:
定义:
在训练数据集上得分比较低,在测试数据集上的得分也比较低
解决方法:
增加训练的层数,提高训练的深度
So:
参数选择的原则:
1. 首先开发一个过拟合的模型
(1)添加更多的层数
(2)让每一层变得更大
(3)训练尽量多的次数
2. 然后,调参实现抑制过拟合
(1)!!!增加训练数据!!!
(2)dropout
(3)正则化
(4)图像增强
3. 继续调参直到再次过拟合
'''
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
os.system('cls')
(train_image,train_lable),(test_image,test_lable) = tf.keras.datasets.fashion_mnist.load_data()
# plt.imshow(train_image[0])
# plt.show()
train_lable_onehot = tf.keras.utils.to_categorical(train_lable)
print(train_lable_onehot)
test_lable_onehot = tf.keras.utils.to_categorical(test_lable)
print(test_lable_onehot)
train_image = train_image / 255
test_image = test_image / 255
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28))) # 隐藏层
# for i in range(20):
# model.add(tf.keras.layers.Dense(128,activation='relu')) 因为增加层数之后会增加模型的拟合能力,但过度增加层数只会让模型过拟合,这段话验证一下增加20个的效果
model.add(tf.keras.layers.Dense(128,activation='relu')) # 隐藏层,用relu进行激活
model.add(tf.keras.layers.Dropout(0.5)) # 添加Dropout层抑制过拟合
model.add(tf.keras.layers.Dense(128,activation='relu')) # 隐藏层,用relu进行激活
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128,activation='relu')) # 隐藏层,用relu进行激活
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10,activation='softmax')) # 输出层,用softmax进行激活
# model.compile(optimizer = 'adam',loss='categorical_crossentropy',metrics=['acc'])
# model.fit(train_image,train_lable_onehot,epochs = 100)
model.compile(optimizer = tf.keras.optimizers.Adam(lr=0.001),loss='categorical_crossentropy',metrics=['acc'])
history = model.fit(train_image,train_lable_onehot,epochs = 10,validation_data=(test_image,test_lable_onehot))
# print(history.history.keys())
# print(model.summary())
plt.plot(history.epoch,history.history.get('loss'),label = 'loss')
plt.plot(history.epoch,history.history.get('val_loss'),label = 'val_loss')
plt.show()
退出登录?