We were very impressed with the simplicity it offers from the user's perspective, and we would like to discuss using it as the preferred method of distributed training for TensorFlow.
Compare these two code fragments.
1) Distributed TensorFlow:
import argparse
import sys

import tensorflow as tf

FLAGS = None


def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      # Build model...
      loss = ...
      global_step = tf.contrib.framework.get_or_create_global_step()

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

    # The StopAtStepHook handles stopping after running given steps.
    hooks = [tf.train.StopAtStepHook(last_step=1000000)]

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0),
                                           checkpoint_dir="/tmp/train_logs",
                                           hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.
        # mon_sess.run handles AbortedError in case of preempted PS.
        mon_sess.run(train_op)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  # Flags for defining the tf.train.ClusterSpec
  parser.add_argument(
      "--ps_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--worker_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--job_name",
      type=str,
      default="",
      help="One of 'ps', 'worker'"
  )
  # Flags for defining the tf.train.Server
  parser.add_argument(
      "--task_index",
      type=int,
      default=0,
      help="Index of task within the job"
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
2) MPI:
import argparse
import sys

import tensorflow as tf
import tensorflow.contrib.mpi as mpi

FLAGS = None


def main(_):
  # Build model...
  loss = ...
  global_step = tf.contrib.framework.get_or_create_global_step()

  opt = tf.train.AdagradOptimizer(0.01)

  # Add MPI Distributed Optimizer
  opt = mpi.DistributedOptimizer(opt)
  train_op = opt.minimize(loss, global_step=global_step)

  # The StopAtStepHook handles stopping after running given steps.
  hooks = [tf.train.StopAtStepHook(last_step=1000000)]

  # The MonitoredTrainingSession takes care of session initialization,
  # restoring from a checkpoint, saving to a checkpoint, and closing when done
  # or an error occurs.
  with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/train_logs",
                                         hooks=hooks,
                                         session_creator=mpi.SessionCreator()) as mon_sess:
    while not mon_sess.should_stop():
      # Perform synchronous training.
      mon_sess.run(train_op)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # Your other flags go here.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
The code above looks exactly like what somebody would write for single-GPU training. All of the distribution is done by adding mpi.DistributedOptimizer() and mpi.SessionCreator(). The latter doesn't exist in Baidu's MPI patch, because it predates MonitoredTrainingSession, but it could easily be added along the lines of:
class MPISessionCreator(SessionCreator):
  def create_session(self):
    gpu_to_use = mpi.local_rank().eval()
    return mpi.Session(gpu=gpu_to_use)
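The mpi.local_rank() call above is there to pin each process to a different GPU on its host, so that one training process can be launched per GPU.

For intuition, here is a rough sketch of what the mpi.DistributedOptimizer() wrapper does under the hood: each rank computes gradients on its own mini-batch, the gradients are averaged across all ranks with an all-reduce, and every rank applies the same averaged update. This is a conceptual sketch only; it assumes the tf and mpi imports from the fragment above and hypothetical mpi.allreduce() and mpi.size() helpers, not the exact API of the Baidu patch:

class DistributedOptimizer(tf.train.Optimizer):
  """Wraps another optimizer and averages its gradients across MPI ranks."""

  def __init__(self, optimizer, name="DistributedOptimizer", use_locking=False):
    self._optimizer = optimizer
    super(DistributedOptimizer, self).__init__(use_locking, name)

  def compute_gradients(self, *args, **kwargs):
    # Each rank computes gradients on its own mini-batch...
    grads_and_vars = self._optimizer.compute_gradients(*args, **kwargs)
    # ...then sums them across all ranks with all-reduce and divides by the
    # number of ranks to obtain the average gradient.
    averaged = []
    for grad, var in grads_and_vars:
      if grad is not None:
        grad = mpi.allreduce(grad) / mpi.size()  # hypothetical helpers
      averaged.append((grad, var))
    return averaged

  def apply_gradients(self, *args, **kwargs):
    # Every rank applies identical averaged gradients, so the replicas stay
    # in sync without any parameter servers.
    return self._optimizer.apply_gradients(*args, **kwargs)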
Additionally, NCCL 2.0 is around the corner, which should improve all-reduce performance even further.
I realize that it takes some effort to set up MPI, but once that platform-level investment is made, the benefit to users is very significant.
How does the community feel about doing distributed training via the all-reduce route?