I have trained a survival tree using the `SurvivalTree` class from scikit-survival (sksurv):
estimator = SurvivalTree().fit(X_train, y_train)
When I call `__getstate__()` on the fitted tree, I get the following:
{'max_depth': 21,
'node_count': 135,
'nodes': array([( 1, 134, 59, 2.62050003e+02, inf, 282, 282.),
( 2, 53, 53, 1.55930004e+01, inf, 279, 279.),
( 3, 50, 20, 1.73500000e+02, inf, 117, 117.),
( 4, 5, 39, 9.45000000e+01, inf, 109, 109.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 6, 11, 12, 4.40250015e+00, inf, 106, 106.),
( 7, 8, 59, 1.35000002e+00, inf, 19, 19.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 9, 10, 50, 2.85999990e+00, inf, 16, 16.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 13, 13.),
( 12, 15, 16, 3.20000008e-01, inf, 87, 87.),
( 13, 14, 56, 2.25000000e+00, inf, 17, 17.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 13, 13.),
( 16, 49, 42, 5.60000014e+00, inf, 70, 70.),
( 17, 18, 11, 5.50000000e+00, inf, 67, 67.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 19, 48, 0, 1.71500000e+02, inf, 64, 64.),
( 20, 33, 56, 4.91345000e+00, inf, 61, 61.),
( 21, 30, 33, 3.50000000e+00, inf, 31, 31.),
( 22, 25, 12, 4.59264994e+00, inf, 25, 25.),
( 23, 24, 43, 3.95000000e+01, inf, 7, 7.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( 26, 29, 18, 1.50000000e+00, inf, 18, 18.),
( 27, 28, 13, 9.55000019e+00, inf, 14, 14.),
( -1, -1, -2, -2.00000000e+00, inf, 11, 11.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( 31, 32, 17, 2.49999994e-02, inf, 6, 6.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 34, 47, 60, 6.30495000e+00, inf, 30, 30.),
( 35, 38, 21, 2.66000004e+01, inf, 27, 27.),
( 36, 37, 45, 3.85000002e+00, inf, 6, 6.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 39, 40, 15, 1.00500000e+02, inf, 21, 21.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 41, 42, 47, 6.45000005e+00, inf, 18, 18.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( 43, 44, 30, 2.96500006e+01, inf, 14, 14.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( 45, 46, 46, 1.41149998e+00, inf, 9, 9.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 51, 52, 52, 1.64999998e+00, inf, 8, 8.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 54, 113, 52, 2.19599998e+00, inf, 162, 162.),
( 55, 110, 37, 9.66499996e+00, inf, 122, 122.),
( 56, 57, 1, 2.45000005e+00, inf, 115, 115.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( 58, 59, 47, 5.25000000e+00, inf, 110, 110.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( 60, 75, 11, 1.13499999e+01, inf, 105, 105.),
( 61, 72, 59, 3.49500008e+01, inf, 41, 41.),
( 62, 71, 5, 1.70000000e+01, inf, 32, 32.),
( 63, 70, 53, 2.04500008e+01, inf, 29, 29.),
( 64, 69, 3, 1.35000000e+01, inf, 25, 25.),
( 65, 68, 34, 6.30000025e-01, inf, 21, 21.),
( 66, 67, 20, 1.16000000e+02, inf, 18, 18.),
( -1, -1, -2, -2.00000000e+00, inf, 13, 13.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 73, 74, 13, 9.05000019e+00, inf, 9, 9.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( 76, 77, 1, 2.75000000e+00, inf, 64, 64.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 78, 109, 17, 4.39999998e-01, inf, 61, 61.),
( 79, 80, 49, 9.55000013e-01, inf, 58, 58.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( 81, 106, 34, 1.22229999e+00, inf, 53, 53.),
( 82, 105, 61, 1.37000000e+02, inf, 45, 45.),
( 83, 84, 54, 1.35500000e+02, inf, 41, 41.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( 85, 100, 47, 6.66149998e+00, inf, 36, 36.),
( 86, 87, 21, 2.34500008e+01, inf, 27, 27.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 88, 89, 38, 9.50000000e+00, inf, 24, 24.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 90, 99, 52, 1.96899998e+00, inf, 21, 21.),
( 91, 98, 19, 3.98680695e+02, inf, 17, 17.),
( 92, 93, 56, 4.43900013e+00, inf, 13, 13.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 94, 95, 55, 1.01465553e+00, inf, 10, 10.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( 96, 97, 57, 8.23000014e-01, inf, 7, 7.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
(101, 104, 44, 9.55000019e+00, inf, 9, 9.),
(102, 103, 44, 9.26199961e+00, inf, 6, 6.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
(107, 108, 44, 9.81599998e+00, inf, 8, 8.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
(111, 112, 39, 9.57099991e+01, inf, 7, 7.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
(114, 115, 29, 1.25000000e+00, inf, 40, 40.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
(116, 131, 47, 6.70000005e+00, inf, 35, 35.),
(117, 118, 37, 2.59999998e-01, inf, 28, 28.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
(119, 122, 29, 1.72220004e+00, inf, 25, 25.),
(120, 121, 42, 3.29999995e+00, inf, 7, 7.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
(123, 124, 52, 2.77400005e+00, inf, 18, 18.),
( -1, -1, -2, -2.00000000e+00, inf, 5, 5.),
(125, 126, 19, 1.59000000e+02, inf, 13, 13.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
(127, 128, 11, 1.05000000e+01, inf, 10, 10.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
(129, 130, 3, 1.11900001e+01, inf, 7, 7.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
(132, 133, 4, 4.25000000e+01, inf, 7, 7.),
( -1, -1, -2, -2.00000000e+00, inf, 4, 4.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.),
( -1, -1, -2, -2.00000000e+00, inf, 3, 3.)],
dtype=[('left_child', '<i8'), ('right_child', '<i8'), ('feature', '<i8'), ('threshold', '<f8'), ('impurity', '<f8'), ('n_node_samples', '<i8'), ('weighted_n_node_samples', '<f8')]),
'values': array([[[0.0035461 , 0.9964539 ],
[0.01068896, 0.98933637],
[0.01429906, 0.98576476],
...,
[1.43303768, 0.23599046],
[1.47651594, 0.22573 ],
[1.52413499, 0.21498096]],
[[0. , 1. ],
[0.00719424, 0.99280576],
[0.01083061, 0.98919555],
...,
[1.42215641, 0.23858111],
[1.46563467, 0.22820802],
[1.51325371, 0.21734097]],
[[0. , 1. ],
[0. , 1. ],
[0. , 1. ],
...,
[1.01336515, 0.35838703],
[1.01336515, 0.35838703],
[1.01336515, 0.35838703]],
...,
[[0. , 1. ],
[0. , 1. ],
[0. , 1. ],
...,
[2.08333333, 0. ],
[2.08333333, 0. ],
[2.08333333, 0. ]],
[[0. , 1. ],
[0. , 1. ],
[0. , 1. ],
...,
[1.83333333, 0. ],
[1.83333333, 0. ],
[1.83333333, 0. ]],
[[0.33333333, 0.66666667],
[0.33333333, 0.66666667],
[0.33333333, 0.66666667],
...,
[1.83333333, 0. ],
[1.83333333, 0. ],
[1.83333333, 0. ]]])}
However, when I call `sklearn.tree.plot_tree` on the fitted estimator, I get the following error:
AttributeError: 'SurvivalTree' object has no attribute 'criterion'
I also tried Graphviz and got the same error.
I would like to plot the survival tree so that it is easier to interpret.
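To plot a fitted `SurvivalTree`, you need a patched version of `plot_tree`, which you can obtain from this gist. The stock scikit-learn `plot_tree` reads the estimator's `criterion` attribute, which `SurvivalTree` does not have; that is why it raises the `AttributeError` above. Then use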