Visualizing a TensorFlow graph in Jupyter doesn't work

I saw this question about how to visualise a TensorFlow graph in a Jupyter notebook. The answer there comes from this example, with just one modification: tensor.tensor_content = bytes("<stripped %d bytes>"%size, 'utf-8') is replaced by tensor.tensor_content = "<stripped %d bytes>"%size. However, when I rerun it on tensorflow_inception_graph.pb, the visualisation doesn't work: the iframe is white and no nodes are displayed.

I would highly appreciate it if you could explain what I am doing wrong. Here is a simple example that reproduces the problem.

Import:

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import tensorflow as tf
import numpy as np

from IPython.display import clear_output, Image, display, HTML

Create graph:

graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

x = tf.placeholder(tf.float32, shape=[None, 25, 25, 3], name='x')
y_true = tf.placeholder(tf.float32, shape=[None, 10], name='y_true')
y_true_cls = tf.argmax(y_true, dimension=1, name='y_true_cls')

print(graph.get_operations())

The output:

[<tensorflow.python.framework.ops.Operation at 0x115902850>,
 <tensorflow.python.framework.ops.Operation at 0x115902690>,
 <tensorflow.python.framework.ops.Operation at 0x115902b10>,
 <tensorflow.python.framework.ops.Operation at 0x1159029d0>]

The visualisation functions:

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add() 
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = bytes("<stripped %d bytes>"%size, "utf-8")
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))

The result:

[Screenshot: the iframe renders blank, with no nodes displayed]

UPD I tried a simpler example:

tf.reset_default_graph()
x = tf.ones((), name="x")
y = tf.ones((), name="y")
z = tf.add(x, y, name="z")
show_graph(tf.get_default_graph())

But it still doesn't work. I suspect the problem is related to Javascript/HTML code that is generated:

    <script>
      function load() {
        document.getElementById(&quot;graph0.746875762596&quot;).pbtxt = 'node {\n  name: &quot;x&quot;\n  op: &quot;Const&quot;\n  attr {\n    key: &quot;dtype&quot;\n    value {\n      type: DT_FLOAT\n    }\n  }\n  attr {\n    key: &quot;value&quot;\n    value {\n      tensor {\n        dtype: DT_FLOAT\n        tensor_shape {\n        }\n        float_val: 1.0\n      }\n    }\n  }\n}\nnode {\n  name: &quot;y&quot;\n  op: &quot;Const&quot;\n  attr {\n    key: &quot;dtype&quot;\n    value {\n      type: DT_FLOAT\n    }\n  }\n  attr {\n    key: &quot;value&quot;\n    value {\n      tensor {\n        dtype: DT_FLOAT\n        tensor_shape {\n        }\n        float_val: 1.0\n      }\n    }\n  }\n}\nnode {\n  name: &quot;z&quot;\n  op: &quot;Add&quot;\n  input: &quot;x&quot;\n  input: &quot;y&quot;\n  attr {\n    key: &quot;T&quot;\n    value {\n      type: DT_FLOAT\n    }\n  }\n}\n';
      }
    </script>
    <link rel=&quot;import&quot; href=&quot;https://tensorboard.appspot.com/tf-graph-basic.build.html&quot; onload=load()>
    <div style=&quot;height:600px&quot;>
      <tf-graph-basic id=&quot;graph0.746875762596&quot;></tf-graph-basic>
    </div>

Maybe it has something to do with the &quot; entities and the ' characters?
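
The quoting suspicion can be checked without a browser. The following is only a minimal sketch: an HTML parser decodes character references such as &quot; in an attribute value before srcdoc is treated as a document, and html.unescape from the Python 3 standard library performs the same decoding. The snippet string is a shortened, hypothetical stand-in for the generated markup, not the real output. If the round trip returns the original markup, the quoting is unlikely to be what breaks the page.

```python
import html

# Shortened, hypothetical stand-in for the markup that show_graph() generates.
snippet = """<script>
  function load() {
    document.getElementById("graph0").pbtxt = 'node {\\n  name: "x"\\n}\\n';
  }
</script>
<tf-graph-basic id="graph0"></tf-graph-basic>"""

# Same escaping step as in show_graph(): only double quotes are replaced,
# so that the value can sit inside srcdoc="...".
escaped = snippet.replace('"', '&quot;')

# html.unescape decodes character references, as an HTML parser does
# for attribute values.
decoded = html.unescape(escaped)

print('"' in escaped)      # whether any raw double quote is left
print(decoded == snippet)  # whether the original markup is recovered
```

Single quotes are left untouched by the escaping step, which is fine here because the attribute is delimited by double quotes.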

There are 2 answers

Answer by Pasky (best answer)

The visualisation fails because HTML imports (<link rel="import" ...>) are supported only in Chrome. They fail in Firefox and Safari, and there is no sign that other browsers will adopt them until the Web Components specification is settled. So you'd better run Jupyter in Chrome.

If you're against Chrome, there's good news. You can use a polyfill (a piece of code that implements a feature in browsers that do not support it natively) to make it work:

<script src="//cdnjs.cloudflare.com/ajax/libs/polymer/0.3.3/platform.js"></script>

I've tested it in Firefox and Safari and it works, though not perfectly: loading the polyfill makes it a little slower, and the graph canvas is reduced to one inch wide (I don't know why; something in TensorBoard's internals). I then realised that platform.js has been deprecated, but newer implementations introduce new errors (unhandled events and XML parsing).

Following is the modified code:

# TensorFlow Graph visualizer code
import numpy as np
import tensorflow as tf
from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add() 
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                # tensor_content is a bytes field; encode() works on Python 2 and 3
                tensor.tensor_content = ("<stripped %d bytes>" % size).encode("utf-8")
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script src="//cdnjs.cloudflare.com/ajax/libs/polymer/0.3.3/platform.js"></script>
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))

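Notice that only one line was added, at the beginning of the code = """ block: the <script> tag that loads the polyfill. It has to be there because the polyfill is what provides support for <link rel="import"> in browsers that lack it.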
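
To make the one-line difference easier to see, here is a minimal sketch that builds the same markup without TensorFlow or IPython, so it can be inspected as a plain string. The names build_iframe, POLYFILL_TAG, TEMPLATE, graph_id and use_polyfill are hypothetical and are not part of the original code. The pbtxt argument stands in for str(strip_def).

```python
# Hypothetical helper names; the markup itself mirrors show_graph() above.
POLYFILL_TAG = ('<script src="//cdnjs.cloudflare.com/ajax/libs/polymer/'
                '0.3.3/platform.js"></script>')

TEMPLATE = """
    {polyfill}
    <script>
      function load() {{
        document.getElementById("{id}").pbtxt = {data};
      }}
    </script>
    <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
    <div style="height:600px">
      <tf-graph-basic id="{id}"></tf-graph-basic>
    </div>
"""


def build_iframe(pbtxt, graph_id="graph0", use_polyfill=True):
    """Return the iframe markup for a text-format graph definition."""
    code = TEMPLATE.format(
        # The polyfill tag is the only difference between the two variants.
        polyfill=POLYFILL_TAG if use_polyfill else "",
        data=repr(pbtxt),
        id=graph_id,
    )
    # Escape double quotes so the markup fits inside srcdoc="...".
    return ('<iframe seamless style="width:1200px;height:620px;border:0" '
            'srcdoc="{}"></iframe>').format(code.replace('"', '&quot;'))


sample = 'node {\n  name: "x"\n}\n'
with_polyfill = build_iframe(sample)
without_polyfill = build_iframe(sample, use_polyfill=False)
```

In a notebook, either string could be passed to display(HTML(...)) in the same way as in show_graph. Comparing the two makes it easy to confirm that the polyfill tag comes before the import link.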

The original source code can be found here. You can ask its author to extend it within Google to cover browsers other than Chrome, but I don't think that will happen.

Answer by Yaroslav Bulatov

Here's the version I'm using right now

You should be able to do something like this:

[Image]