def optimize_graph(graph, output_graph):
    """Optimize a TensorFlow graph with Grappler and pass it to extract_weights.

    The graph is exported to a MetaGraphDef and run through
    ``tf_optimizer.OptimizeGraph`` using a ``RewriterConfig`` that applies
    the 'pruning', 'constfold', 'arithmetic' and 'dependency' optimizers,
    in that order, twice.

    Args:
      graph: tf.Graph tensorflow dataflow graph to optimize.
      output_graph: Forwarded unchanged to ``extract_weights`` together with
        the optimized graph. NOTE(review): presumably the destination
        (e.g. an output path) for the converted model -- confirm against
        ``extract_weights``.

    Returns:
      The optimized graph returned by ``tf_optimizer.OptimizeGraph``.
    """
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    # The same four optimizers are listed twice on purpose; presumably one
    # pass can expose new opportunities for another (e.g. constant folding
    # leaving nodes that a later pruning pass can remove). Order matters.
    rewriter_config.optimizers[:] = [
        'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning',
        'constfold', 'arithmetic', 'dependency'
    ]
    # OptimizeGraph takes a MetaGraphDef, so export the live tf.Graph first.
    meta_graph = tf.train.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)
    optimized_graph = tf_optimizer.OptimizeGraph(
        rewriter_config, meta_graph, cluster=get_cluster())
    extract_weights(optimized_graph, output_graph)
    # Return the optimized graph itself, not the function object.
    return optimized_graph
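However, it is not clear what exactly this optimization step does.
It would be helpful if official documentation were provided describing this optimization (for example, what each of the 'pruning', 'constfold', 'arithmetic' and 'dependency' passes does, and why they are run twice).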