diff --git a/mnist-distributed/mirrored_mnist.py b/mnist-distributed/mirrored_mnist.py
index f410e7f6..9b0fc879 100644
--- a/mnist-distributed/mirrored_mnist.py
+++ b/mnist-distributed/mirrored_mnist.py
@@ -6,11 +6,14 @@
 import numpy as np
 import tensorflow as tf
 import argparse
+import json
 
 MODEL_DIR = "/model/"
 input_shape = (28, 28, 1)
 num_classes = 10
 
+TF_CONFIG = os.environ.get('TF_CONFIG')
+
 def data_loader(hyperparams):
     f = gzip.open('/mnist/mnist.pkl.gz', 'rb')
     dataset = pickle.load(f, encoding='bytes')
@@ -30,7 +33,8 @@ def data_loader(hyperparams):
     )
 
 def model_with_strategy(learning_rate):
-    strategy = tf.distribute.MirroredStrategy()
+    gpus = tf.config.list_logical_devices('GPU')
+    strategy = tf.distribute.MirroredStrategy(gpus)
     with strategy.scope():
         model = keras.Sequential(
             [
@@ -55,6 +59,17 @@ def __init__(self, learning_rate=0.01, batch_size=64, epochs=2):
         self.train_dataset, self.test_dataset = data_loader(hyperparams)
 
     def train(self):
+        """Train the model, save it, and print per-epoch validation metrics.
+
+        A task whose TF_CONFIG task type is 'ps' returns without training;
+        every other task, and a run with TF_CONFIG unset, trains.
+        """
+        # TF_CONFIG is a JSON document, so parse it with json rather than
+        # ast.literal_eval, which rejects JSON's true/false/null.
+        task = (json.loads(TF_CONFIG).get('task') or {}) if TF_CONFIG else {}
+        if task.get('type') == 'ps':
+            return
+
         model = model_with_strategy(self.learning_rate)
         #steps per epoch are reduced here to train on limited resources
         #you are free to remove this argument