bart.tree

Module: bart.tree

Inheritance diagram for ISLP.bart.tree:

digraph inheritance1492763d9a { bgcolor=transparent; rankdir=LR; size="8.0, 12.0"; "bart.tree.BaseNode" [URL="#ISLP.bart.tree.BaseNode",fillcolor=white,fontname="Vera Sans, DejaVu Sans, Liberation Sans, Arial, Helvetica, sans",fontsize=10,height=0.25,shape=box,style="setlinewidth(0.5),filled",target="_top"]; "bart.tree.LeafNode" [URL="#ISLP.bart.tree.LeafNode",fillcolor=white,fontname="Vera Sans, DejaVu Sans, Liberation Sans, Arial, Helvetica, sans",fontsize=10,height=0.25,shape=box,style="setlinewidth(0.5),filled",target="_top"]; "bart.tree.BaseNode" -> "bart.tree.LeafNode" [arrowsize=0.5,style="setlinewidth(0.5)"]; "bart.tree.SplitNode" [URL="#ISLP.bart.tree.SplitNode",fillcolor=white,fontname="Vera Sans, DejaVu Sans, Liberation Sans, Arial, Helvetica, sans",fontsize=10,height=0.25,shape=box,style="setlinewidth(0.5),filled",target="_top"]; "bart.tree.BaseNode" -> "bart.tree.SplitNode" [arrowsize=0.5,style="setlinewidth(0.5)"]; "bart.tree.Tree" [URL="#ISLP.bart.tree.Tree",fillcolor=white,fontname="Vera Sans, DejaVu Sans, Liberation Sans, Arial, Helvetica, sans",fontsize=10,height=0.25,shape=box,style="setlinewidth(0.5),filled",target="_top",tooltip="Full binary tree"]; }

Classes

BaseNode

class ISLP.bart.tree.BaseNode(index)#

Bases: object

Methods

get_idx_left_child

get_idx_parent_node

get_idx_right_child

__init__(index)#
get_idx_left_child()#
get_idx_parent_node()#
get_idx_right_child()#

LeafNode

class ISLP.bart.tree.LeafNode(index, value, idx_data_points)#

Bases: BaseNode

Methods

get_idx_left_child

get_idx_parent_node

get_idx_right_child

__init__(index, value, idx_data_points)#
get_idx_left_child()#
get_idx_parent_node()#
get_idx_right_child()#

SplitNode

class ISLP.bart.tree.SplitNode(index, idx_split_variable, split_value)#

Bases: BaseNode

Methods

get_idx_left_child

get_idx_parent_node

get_idx_right_child

__init__(index, idx_split_variable, split_value)#
get_idx_left_child()#
get_idx_parent_node()#
get_idx_right_child()#

Tree

class ISLP.bart.tree.Tree(tree_id=0, num_observations=0)#

Bases: object

Full binary tree

A full binary tree is a tree in which every node has either zero or two children. This structure is the basic component of Bayesian Additive Regression Trees (BART).

Parameters:
tree_id : int, optional
num_observations : int, optional
Attributes:
tree_structure : dict

A dictionary that stores the nodes in breadth-first order, based on the array method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). The dictionary's keys are integers that represent each node's position. The dictionary's values are SplitNode or LeafNode objects that represent the nodes of the tree itself.

num_nodes : int

Total number of nodes.

idx_leaf_nodes : list

List with the indices of the leaf nodes of the tree.

idx_prunable_split_nodes : list

List with the indices of the prunable splitting nodes of the tree. A splitting node is prunable if both of its children are leaf nodes.

tree_id : int

Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.

num_observations : int

Number of observations used to fit BART.

Methods

grow_tree(index_leaf_node, new_split_node, ...)

Grow the tree from a particular node.

init_tree(tree_id, leaf_node_value, ...)

Parameters:

predict_out_of_sample(X)

Predict the output of the tree for an unobserved point X.

copy

delete_node

get_node

predict_output

set_node

__init__(tree_id=0, num_observations=0)#
copy()#
delete_node(index)#
get_node(index)#
grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)#

Grow the tree from a particular node.

Parameters:
index_leaf_node : int
new_split_node : SplitNode
new_left_node : LeafNode
new_right_node : LeafNode
static init_tree(tree_id, leaf_node_value, idx_data_points)#
Parameters:
tree_id
leaf_node_value
idx_data_points
Returns:
predict_out_of_sample(X)#

Predict the output of the tree for an unobserved point X.

Parameters:
X : numpy array

Unobserved point

Returns:
float

Value of the leaf node in which the unobserved point lies.

predict_output()#
set_node(index, node)#