bart.tree
Module: bart.tree
[Figure: inheritance diagram for ISLP.bart.tree]
Classes
BaseNode
LeafNode
SplitNode
Tree
- class ISLP.bart.tree.Tree(tree_id=0, num_observations=0)
Bases: object
Full binary tree
A full binary tree is a tree in which each node has exactly zero or two children. This structure is the basic component of Bayesian Additive Regression Trees (BART).
- Parameters:
- tree_id : int, optional
- num_observations : int, optional
- Attributes:
- tree_structure : dict
A dictionary that represents the nodes stored in breadth-first order, based on the array method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). The dictionary’s keys are integers that give each node’s position. The dictionary’s values are objects of type SplitNode or LeafNode that represent the nodes of the tree itself.
- num_nodes : int
Total number of nodes.
- idx_leaf_nodes : list
List of the indices of the leaf nodes of the tree.
- idx_prunable_split_nodes : list
List of the indices of the prunable splitting nodes of the tree. A splitting node is prunable if both of its children are leaf nodes.
- tree_id : int
Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
- num_observations : int
Number of observations used to fit BART.
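The breadth-first layout described for tree_structure can be illustrated with a small sketch. It assumes the zero-based convention from the linked Wikipedia article, where the root sits at position 0 and the children of position i sit at 2i + 1 and 2i + 2; whether ISLP uses exactly this offset is an assumption here. The helper names and the toy dictionary are hypothetical, and plain strings stand in for SplitNode and LeafNode objects.

```python
# Hypothetical helpers for array-style (breadth-first) indexing.
# None of these names are part of ISLP.bart.tree.
def left_child(index):
    return 2 * index + 1


def right_child(index):
    return 2 * index + 2


def parent(index):
    return (index - 1) // 2


# Toy stand-in for tree_structure: keys are positions, values mark node kind.
toy_structure = {0: "split", 1: "split", 2: "leaf", 3: "leaf", 4: "leaf"}


def is_leaf(index):
    return toy_structure[index] == "leaf"


# Indices of all leaves.
idx_leaf_nodes = sorted(i for i in toy_structure if is_leaf(i))

# A split node is prunable when both of its children are leaves.
idx_prunable_split_nodes = sorted(
    i
    for i in toy_structure
    if not is_leaf(i) and is_leaf(left_child(i)) and is_leaf(right_child(i))
)

num_nodes = len(toy_structure)
```

In this toy tree the root is not prunable because its left child is itself a split node, while the split node at position 1 is prunable because positions 3 and 4 are both leaves.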
Methods
- grow_tree(index_leaf_node, new_split_node, ...): Grow the tree from a particular node.
- init_tree(tree_id, leaf_node_value, ...)
- predict_out_of_sample(X): Predict output of tree for an unobserved point x.
- copy
- delete_node
- get_node
- predict_output
- set_node
- __init__(tree_id=0, num_observations=0)
- copy()
- delete_node(index)
- get_node(index)
- grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
Grow the tree from a particular node.
- Parameters:
- index_leaf_node : int
- new_split_node : SplitNode
- new_left_node : LeafNode
- new_right_node : LeafNode
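As a rough illustration of what growing from a leaf involves in a dictionary-backed full binary tree, the sketch below replaces one leaf with a split node and attaches two new leaves. It is a guess at the bookkeeping the signature implies, not ISLP's implementation: ToyLeaf, ToySplit, and grow are hypothetical names, and the child positions 2i + 1 and 2i + 2 are an assumed zero-based array layout.

```python
from dataclasses import dataclass


@dataclass
class ToyLeaf:  # hypothetical stand-in for LeafNode
    index: int
    value: float


@dataclass
class ToySplit:  # hypothetical stand-in for SplitNode
    index: int
    feature: int
    threshold: float


def grow(structure, index_leaf_node, new_split, new_left, new_right):
    """Replace the leaf at index_leaf_node with a split node and two leaves."""
    if not isinstance(structure[index_leaf_node], ToyLeaf):
        raise ValueError("can only grow from a leaf node")
    # Assumed layout: children of position i live at 2i + 1 and 2i + 2.
    expected = (2 * index_leaf_node + 1, 2 * index_leaf_node + 2)
    if (new_left.index, new_right.index) != expected:
        raise ValueError("child indices do not match the array layout")
    structure[index_leaf_node] = new_split
    structure[new_left.index] = new_left
    structure[new_right.index] = new_right


# Start from a single-leaf tree and grow it once.
structure = {0: ToyLeaf(index=0, value=0.0)}
grow(
    structure,
    0,
    ToySplit(index=0, feature=0, threshold=1.5),
    ToyLeaf(index=1, value=-1.0),
    ToyLeaf(index=2, value=1.0),
)
```

Each growth step adds two nodes, so every node still has zero or two children and the tree stays full.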
- static init_tree(tree_id, leaf_node_value, idx_data_points)
- Parameters:
- tree_id
- leaf_node_value
- idx_data_points
- Returns:
- predict_out_of_sample(X)
Predict output of tree for an unobserved point x.
- Parameters:
- X : numpy array
Unobserved point
- Returns:
- float
Value of the leaf node in which the unobserved point lies.
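Predicting for an unobserved point amounts to walking from the root down to a leaf. The sketch below shows one way this could look on a dictionary keyed by breadth-first position. It is not ISLP's code: predict_point is a hypothetical function, split nodes are modelled as (feature, threshold) tuples, leaves as floats, and the rule that values less than or equal to the threshold go left is an assumption.

```python
def predict_point(structure, x):
    """Walk from the root to a leaf and return that leaf's value."""
    index = 0  # assumed root position
    node = structure[index]
    # Split nodes are (feature, threshold) tuples; leaves are plain floats.
    while isinstance(node, tuple):
        feature, threshold = node
        if x[feature] <= threshold:
            index = 2 * index + 1  # go to the left child
        else:
            index = 2 * index + 2  # go to the right child
        node = structure[index]
    return node


# Toy tree: positions 3 and 4 are absent because position 1 is a leaf.
toy_tree = {0: (0, 1.5), 1: -1.0, 2: (1, 0.0), 5: 2.0, 6: 3.0}

left_value = predict_point(toy_tree, [1.0, 9.9])
deep_left_value = predict_point(toy_tree, [2.0, -0.5])
deep_right_value = predict_point(toy_tree, [2.0, 0.5])
```

The missing keys 3 and 4 show why a dictionary suits this layout: positions below a leaf are simply never stored.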
- predict_output()
- set_node(index, node)