"""Train a ResNet on CIFAR-10 with the MaSS optimizer in three stages.

Whenever the learning rate is reduced, the MaSS optimizer is restarted
from the latest learned weights (if the learning rate is never reduced,
a single stage suffices):

  Stage 1: epochs   1-150, lr = 0.1
  Stage 2: epochs 151-225, lr = 0.01  (lr / 10 after 150 epochs)
  Stage 3: epochs 226-300, lr = 0.001 (lr / 10 after 225 epochs)
"""
import os

import numpy as np
import tensorflow as tf
import keras
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator

from optimizers import MaSS
from cifar10 import load
from resnet import resnet_v1, resnet_v2

# Training configuration.
batch_size = 64
epochs = 200            # NOTE(review): unused; per-stage epoch counts are set below
data_augmentation = True
num_classes = 10

# Load CIFAR-10 data.
(x_train, y_train), (x_test, y_test) = load()
input_shape = x_train.shape[1:]

# Model parameters: depth = 6n+2 for ResNet v1, 9n+2 for v2.
n = 5
version = 1
if version == 1:
    depth = n * 6 + 2
elif version == 2:
    depth = n * 9 + 2

# Model name, depth and version (e.g. 'ResNet32v1').
model_type = 'ResNet%dv%d' % (depth, version)

# Directory where per-epoch checkpoints are written.
save_dir = os.path.join(os.getcwd(), 'saved_models')
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

# Build the augmentation pipeline once; it is reused by every stage.
datagen = None
if data_augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,             # set input mean to 0 over the dataset
        samplewise_center=False,              # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of dataset
        samplewise_std_normalization=False,   # divide each input by its std
        zca_whitening=False,                  # apply ZCA whitening
        rotation_range=0,                     # randomly rotate images (deg 0 to 180)
        width_shift_range=0.1,                # randomly shift images horizontally
        height_shift_range=0.1,               # randomly shift images vertically
        horizontal_flip=True,                 # randomly flip images
        vertical_flip=False)                  # randomly flip images
    # fit() returns None; it is only needed for the featurewise/ZCA options
    # (all disabled above), but is kept for parity with the original recipe.
    datagen.fit(x_train)


def _run_stage(lr, aug_epochs, checkpoint_name, init_weights=None):
    """Run one learning-rate stage of MaSS training.

    Builds a fresh ResNet-v1 + MaSS optimizer (restarting the optimizer
    state), optionally restores weights from a previous stage, trains,
    and checkpoints every epoch.

    Args:
        lr: learning rate for the MaSS optimizer in this stage.
        aug_epochs: number of epochs when data augmentation is enabled.
        checkpoint_name: checkpoint filename template containing
            ``{epoch:03d}``; saved under ``save_dir``.
        init_weights: optional path to an .h5 weights file from the
            previous stage; ``None`` for the first stage.

    Returns:
        (model, history) for the trained stage.
    """
    model = resnet_v1(input_shape=input_shape, depth=depth)
    mass = MaSS(lr=lr, alpha=0.05, kappa_t=2)
    model.compile(loss='categorical_crossentropy',
                  optimizer=mass,
                  metrics=['accuracy'])
    model.summary()

    if init_weights is not None:
        # Restart MaSS at the latest learned weights from the prior stage.
        model.load_weights(init_weights)

    filepath = os.path.join(save_dir, checkpoint_name)
    checkpoint = ModelCheckpoint(filepath=filepath, monitor='val_acc',
                                 verbose=1, save_best_only=False)
    callbacks = [checkpoint]

    if not data_augmentation:
        history = model.fit(x_train, y_train,
                            batch_size=batch_size,
                            epochs=100,
                            validation_data=(x_test, y_test),
                            shuffle=True,
                            callbacks=callbacks)
    else:
        print('Using real-time data augmentation.')
        # Fit the model on the batches generated by datagen.flow().
        history = model.fit_generator(
            datagen.flow(x_train, y_train, batch_size=batch_size),
            validation_data=(x_test, y_test),
            epochs=aug_epochs, verbose=1, workers=4,
            callbacks=callbacks)
    return model, history


# Checkpoint filename templates are derived from model_type so that the
# weight files loaded by later stages stay consistent if n/version change
# (the original hard-coded 'ResNet32v1' here).
stage1_name = 'cifar10_%s_model_mass.{epoch:03d}.h5' % model_type
stage2_name = 'cifar10_%s_model_mass_150+.{epoch:03d}.h5' % model_type
stage3_name = 'cifar10_%s_model_mass_225+.{epoch:03d}.h5' % model_type

# Stage 1: epochs 1-150, lr = 0.1.
model, result1 = _run_stage(0.1, 150, stage1_name)

# Stage 2: epochs 151-225, lr = 0.01, warm-started from epoch 150.
model, result2 = _run_stage(
    0.01, 75, stage2_name,
    init_weights=os.path.join(save_dir, stage1_name.format(epoch=150)))

# Stage 3: epochs 226-300, lr = 0.001, warm-started from stage-2 epoch 75.
model, result3 = _run_stage(
    0.001, 75, stage3_name,
    init_weights=os.path.join(save_dir, stage2_name.format(epoch=75)))

# Score the final trained model on the test set.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])