Thanks for the quick response. I will do profiling on the chief node. I also suspect I might have messed up big time with PS, so I will paste the PS-specific code here.
def dist_train(args):
    """Run one node of a TensorFlow parameter-server training job.

    The node's role is read from the TF_CONFIG environment variable via
    TFConfigClusterResolver:
      - 'ps' / 'worker' tasks start a tf.distribute.Server and block in
        server.join() until the job is killed.
      - the 'chief' task builds a ParameterServerStrategy (sharding
        variables across the PS tasks) and creates the model under its
        scope.

    Args:
        args: configuration object forwarded to make_model — its schema
            is not visible in this snippet; TODO confirm expected fields.
    """
    # Recommended for PS training: make gRPC calls fail fast rather than
    # block/retry indefinitely when a remote task is unreachable.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"
    cluster_resolver = TFConfigClusterResolver()

    if cluster_resolver.task_type in ('ps', 'worker'):
        logging.info(
            "[{}] Start {}({})...".format(datetime.now(), cluster_resolver.task_type, cluster_resolver.task_id))
        server = tf.distribute.Server(
            cluster_resolver.cluster_spec(),
            job_name=cluster_resolver.task_type,
            task_index=cluster_resolver.task_id,
            protocol=cluster_resolver.rpc_layer or "grpc",
            start=True)
        print(cluster_resolver.cluster_spec())
        # PS and worker processes serve coordinator requests forever;
        # this call never returns under normal operation.
        server.join()

    if cluster_resolver.task_type == 'chief':
        # Number of PS tasks drives the variable-sharding factor.
        # NOTE(review): if TF_CONFIG lists no "ps" job this is 0 and
        # FixedShardsPartitioner will reject it — verify cluster spec.
        NUM_PS = len(cluster_resolver.cluster_spec().as_dict().get("ps", ()))
        # wait for workers to be ready
        # NOTE(review): a fixed sleep is racy; an explicit readiness check
        # would be safer if workers can take longer than 5s to come up.
        time.sleep(5)
        logging.info(
            "[{}] Start {}({})...".format(datetime.now(), cluster_resolver.task_type, cluster_resolver.task_id))
        variable_partitioner = (
            tf.distribute.experimental.partitioners.FixedShardsPartitioner(
                num_shards=NUM_PS))
        strategy = tf.distribute.experimental.ParameterServerStrategy(
            cluster_resolver,
            variable_partitioner=variable_partitioner)
        with strategy.scope():
            # Variables created inside the scope are partitioned across
            # the PS tasks according to variable_partitioner.
            keras_model = make_model(args)