How to plot a survival tree in Python

878 views Asked by At

I have fitted a survival tree using `SurvivalTree` from scikit-survival (`sksurv`):

estimator = SurvivalTree().fit(X_train, y_train)

Calling `estimator.__getstate__()` returns the following:

{'max_depth': 21,
 'node_count': 135,
 'nodes': array([(  1, 134, 59,  2.62050003e+02, inf, 282, 282.),
        (  2,  53, 53,  1.55930004e+01, inf, 279, 279.),
        (  3,  50, 20,  1.73500000e+02, inf, 117, 117.),
        (  4,   5, 39,  9.45000000e+01, inf, 109, 109.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        (  6,  11, 12,  4.40250015e+00, inf, 106, 106.),
        (  7,   8, 59,  1.35000002e+00, inf,  19,  19.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        (  9,  10, 50,  2.85999990e+00, inf,  16,  16.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,  13,  13.),
        ( 12,  15, 16,  3.20000008e-01, inf,  87,  87.),
        ( 13,  14, 56,  2.25000000e+00, inf,  17,  17.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,  13,  13.),
        ( 16,  49, 42,  5.60000014e+00, inf,  70,  70.),
        ( 17,  18, 11,  5.50000000e+00, inf,  67,  67.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 19,  48,  0,  1.71500000e+02, inf,  64,  64.),
        ( 20,  33, 56,  4.91345000e+00, inf,  61,  61.),
        ( 21,  30, 33,  3.50000000e+00, inf,  31,  31.),
        ( 22,  25, 12,  4.59264994e+00, inf,  25,  25.),
        ( 23,  24, 43,  3.95000000e+01, inf,   7,   7.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( 26,  29, 18,  1.50000000e+00, inf,  18,  18.),
        ( 27,  28, 13,  9.55000019e+00, inf,  14,  14.),
        ( -1,  -1, -2, -2.00000000e+00, inf,  11,  11.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( 31,  32, 17,  2.49999994e-02, inf,   6,   6.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 34,  47, 60,  6.30495000e+00, inf,  30,  30.),
        ( 35,  38, 21,  2.66000004e+01, inf,  27,  27.),
        ( 36,  37, 45,  3.85000002e+00, inf,   6,   6.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 39,  40, 15,  1.00500000e+02, inf,  21,  21.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 41,  42, 47,  6.45000005e+00, inf,  18,  18.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( 43,  44, 30,  2.96500006e+01, inf,  14,  14.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( 45,  46, 46,  1.41149998e+00, inf,   9,   9.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 51,  52, 52,  1.64999998e+00, inf,   8,   8.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 54, 113, 52,  2.19599998e+00, inf, 162, 162.),
        ( 55, 110, 37,  9.66499996e+00, inf, 122, 122.),
        ( 56,  57,  1,  2.45000005e+00, inf, 115, 115.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( 58,  59, 47,  5.25000000e+00, inf, 110, 110.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( 60,  75, 11,  1.13499999e+01, inf, 105, 105.),
        ( 61,  72, 59,  3.49500008e+01, inf,  41,  41.),
        ( 62,  71,  5,  1.70000000e+01, inf,  32,  32.),
        ( 63,  70, 53,  2.04500008e+01, inf,  29,  29.),
        ( 64,  69,  3,  1.35000000e+01, inf,  25,  25.),
        ( 65,  68, 34,  6.30000025e-01, inf,  21,  21.),
        ( 66,  67, 20,  1.16000000e+02, inf,  18,  18.),
        ( -1,  -1, -2, -2.00000000e+00, inf,  13,  13.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 73,  74, 13,  9.05000019e+00, inf,   9,   9.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( 76,  77,  1,  2.75000000e+00, inf,  64,  64.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 78, 109, 17,  4.39999998e-01, inf,  61,  61.),
        ( 79,  80, 49,  9.55000013e-01, inf,  58,  58.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( 81, 106, 34,  1.22229999e+00, inf,  53,  53.),
        ( 82, 105, 61,  1.37000000e+02, inf,  45,  45.),
        ( 83,  84, 54,  1.35500000e+02, inf,  41,  41.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( 85, 100, 47,  6.66149998e+00, inf,  36,  36.),
        ( 86,  87, 21,  2.34500008e+01, inf,  27,  27.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 88,  89, 38,  9.50000000e+00, inf,  24,  24.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 90,  99, 52,  1.96899998e+00, inf,  21,  21.),
        ( 91,  98, 19,  3.98680695e+02, inf,  17,  17.),
        ( 92,  93, 56,  4.43900013e+00, inf,  13,  13.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 94,  95, 55,  1.01465553e+00, inf,  10,  10.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( 96,  97, 57,  8.23000014e-01, inf,   7,   7.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        (101, 104, 44,  9.55000019e+00, inf,   9,   9.),
        (102, 103, 44,  9.26199961e+00, inf,   6,   6.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        (107, 108, 44,  9.81599998e+00, inf,   8,   8.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        (111, 112, 39,  9.57099991e+01, inf,   7,   7.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        (114, 115, 29,  1.25000000e+00, inf,  40,  40.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        (116, 131, 47,  6.70000005e+00, inf,  35,  35.),
        (117, 118, 37,  2.59999998e-01, inf,  28,  28.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        (119, 122, 29,  1.72220004e+00, inf,  25,  25.),
        (120, 121, 42,  3.29999995e+00, inf,   7,   7.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        (123, 124, 52,  2.77400005e+00, inf,  18,  18.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   5,   5.),
        (125, 126, 19,  1.59000000e+02, inf,  13,  13.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        (127, 128, 11,  1.05000000e+01, inf,  10,  10.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        (129, 130,  3,  1.11900001e+01, inf,   7,   7.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        (132, 133,  4,  4.25000000e+01, inf,   7,   7.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   4,   4.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.),
        ( -1,  -1, -2, -2.00000000e+00, inf,   3,   3.)],
       dtype=[('left_child', '<i8'), ('right_child', '<i8'), ('feature', '<i8'), ('threshold', '<f8'), ('impurity', '<f8'), ('n_node_samples', '<i8'), ('weighted_n_node_samples', '<f8')]),
 'values': array([[[0.0035461 , 0.9964539 ],
         [0.01068896, 0.98933637],
         [0.01429906, 0.98576476],
         ...,
         [1.43303768, 0.23599046],
         [1.47651594, 0.22573   ],
         [1.52413499, 0.21498096]],
 
        [[0.        , 1.        ],
         [0.00719424, 0.99280576],
         [0.01083061, 0.98919555],
         ...,
         [1.42215641, 0.23858111],
         [1.46563467, 0.22820802],
         [1.51325371, 0.21734097]],
 
        [[0.        , 1.        ],
         [0.        , 1.        ],
         [0.        , 1.        ],
         ...,
         [1.01336515, 0.35838703],
         [1.01336515, 0.35838703],
         [1.01336515, 0.35838703]],
 
        ...,
 
        [[0.        , 1.        ],
         [0.        , 1.        ],
         [0.        , 1.        ],
         ...,
         [2.08333333, 0.        ],
         [2.08333333, 0.        ],
         [2.08333333, 0.        ]],
 
        [[0.        , 1.        ],
         [0.        , 1.        ],
         [0.        , 1.        ],
         ...,
         [1.83333333, 0.        ],
         [1.83333333, 0.        ],
         [1.83333333, 0.        ]],
 
        [[0.33333333, 0.66666667],
         [0.33333333, 0.66666667],
         [0.33333333, 0.66666667],
         ...,
         [1.83333333, 0.        ],
         [1.83333333, 0.        ],
         [1.83333333, 0.        ]]])}

However, when I call `sklearn.tree.plot_tree` on the fitted estimator, I get the following error:

AttributeError: 'SurvivalTree' object has no attribute 'criterion'

I also tried exporting the tree with Graphviz and got the same error.

Essentially, I want to plot the survival tree so that it is easier to interpret.

1

There is 1 answer

0
sebp On BEST ANSWER

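The stock `sklearn.tree.plot_tree` reads the estimator's `criterion` attribute. `SurvivalTree` does not define that attribute, which is why you get the `AttributeError`. To plot a fitted `SurvivalTree`, you need a patched version of `plot_tree`, which you can obtain from this gist. Node impurity is not meaningful for a survival tree (it is `inf` for every node in the state shown above), so turn it off when plotting. Then call: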

plot_tree(survival_tree,
          feature_names=feature_names,
          impurity=False,
          label="none")