// Member-initializer fragment of a constructor (its signature and some
// initializers are not visible in this excerpt). Visible defaults:
//   - 1000 candidate features and 10 thresholds per feature,
//   - null feature handler and statistics estimator (both are dereferenced
//     during training, so they have to be set before training starts),
//   - a value-initialized data provider (later tested in a boolean context),
//   - feature re-generation at every split node switched off.
40 template <
class FeatureType,
48 , num_of_features_(1000)
49 , num_of_thresholds_(10)
50 , feature_handler_(nullptr)
51 , stats_estimator_(nullptr)
55 , decision_tree_trainer_data_provider_()
56 , random_features_at_split_node_(false)
59 template <
class FeatureType,
// Top-level training entry point (signature not visible in this excerpt;
// presumably `train(tree)` -- TODO confirm). It prepares the candidate feature
// set and starts recursive node training at the tree's root.
68 template <
class FeatureType,
78 std::vector<FeatureType> features;
// When features are NOT regenerated per split node, one shared set of random
// features is created here and reused for every node.
80 if (!random_features_at_split_node_)
81 feature_handler_->createRandomFeatures(num_of_features_, features);
// With a data provider set, the data set, labels and example indices are
// fetched from it (overwriting the corresponding members) before training.
87 if (decision_tree_trainer_data_provider_) {
// NOTE(review): unconditional diagnostic output on std::cerr.
88 std::cerr <<
"use decision_tree_trainer_data_provider_" << std::endl;
90 decision_tree_trainer_data_provider_->getDatasetAndLabels(
91 data_set_, label_data_, examples_);
92 trainDecisionTreeNode(
93 features, examples_, label_data_, max_tree_depth_, tree.
getRoot());
// Presumably the branch without a data provider (the `else` line is not
// visible): trains directly on the members as they are currently set.
99 trainDecisionTreeNode(
100 features, examples_, label_data_, max_tree_depth_, tree.
getRoot());
// Trains a single tree node on `examples` / `label_data` and recurses into its
// branches. Parts of the signature and body (closing braces, and presumably
// `return` / `continue` / `else` lines) are not visible in this excerpt.
//
// Visible flow:
//   1. No examples: an error message is emitted.
//   2. max_depth == 0, or fewer than min_examples_for_split_ examples: only the
//      node statistics are computed (presumably the node becomes a leaf).
//   3. Otherwise every candidate feature is evaluated and the (feature,
//      threshold) pair with the largest information gain is tracked.
//   4. If no pair has a strictly positive gain, only node statistics are set.
//   5. Else the examples are partitioned by branch index and each branch is
//      trained by a recursive call.
104 template <
class FeatureType,
112 std::vector<ExampleIndex>& examples,
113 std::vector<LabelType>& label_data,
114 const std::size_t max_depth,
117 const std::size_t num_of_examples = examples.size();
// Guard: a node without examples cannot be trained.
118 if (num_of_examples == 0) {
120 "Reached invalid point in decision tree training: Number of examples is 0!");
// Depth budget exhausted: compute statistics for this node.
124 if (max_depth == 0) {
125 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
// Too few examples to split: compute statistics for this node.
129 if (examples.size() < min_examples_for_split_) {
130 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
// Optionally draw a fresh random feature set for this particular node.
134 if (random_features_at_split_node_) {
136 feature_handler_->createRandomFeatures(num_of_features_, features);
// Scratch buffers passed to every evaluateFeature call below.
139 std::vector<float> feature_results;
140 std::vector<unsigned char> flags;
142 feature_results.reserve(num_of_examples);
143 flags.reserve(num_of_examples);
// Best split found so far. Index -1 is the "no split found" sentinel, and
// because the gain starts at 0 and is compared with `>`, only strictly
// positive information gains can be selected.
146 int best_feature_index = -1;
147 float best_feature_threshold = 0.0f;
148 float best_feature_information_gain = 0.0f;
150 const std::size_t num_of_features = features.size();
// Search over all candidate features.
151 for (std::size_t feature_index = 0; feature_index < num_of_features;
154 feature_handler_->evaluateFeature(
155 features[feature_index], data_set_, examples, feature_results, flags);
// Case A: thresholds_ is non-empty, so exactly those thresholds are tested.
158 if (!thresholds_.empty()) {
161 for (std::size_t threshold_index = 0; threshold_index < thresholds_.size();
164 const float information_gain =
165 stats_estimator_->computeInformationGain(data_set_,
170 thresholds_[threshold_index]);
172 if (information_gain > best_feature_information_gain) {
173 best_feature_information_gain = information_gain;
174 best_feature_index = static_cast<int>(feature_index);
175 best_feature_threshold = thresholds_[threshold_index];
// Case B (presumably the `else` branch, not visible): derive
// num_of_thresholds_ thresholds from this feature's responses.
180 std::vector<float> thresholds;
181 thresholds.reserve(num_of_thresholds_);
182 createThresholdsUniform(num_of_thresholds_, feature_results, thresholds);
// Indexing up to num_of_thresholds_ relies on createThresholdsUniform having
// resized `thresholds` to that many entries.
186 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
188 const float threshold = thresholds[threshold_index];
191 const float information_gain = stats_estimator_->computeInformationGain(
192 data_set_, examples, label_data, feature_results, flags, threshold);
194 if (information_gain > best_feature_information_gain) {
195 best_feature_information_gain = information_gain;
196 best_feature_index = static_cast<int>(feature_index);
197 best_feature_threshold = threshold;
// No feature/threshold improved on zero gain: only node statistics are set.
203 if (best_feature_index == -1) {
204 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
// Re-evaluate the winning feature (the scratch buffers were reused for every
// feature in the search loop, so they presumably hold the last feature's
// results) and map every example to a branch index.
209 std::vector<unsigned char> branch_indices;
210 branch_indices.reserve(num_of_examples);
212 feature_handler_->evaluateFeature(
213 features[best_feature_index], data_set_, examples, feature_results, flags);
215 stats_estimator_->computeBranchIndices(
216 feature_results, flags, best_feature_threshold, branch_indices);
// The split node itself also receives statistics over all of its examples.
219 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
223 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
// Count examples per branch; used below to detect empty branches and to
// pre-size the per-branch containers.
// NOTE(review): assumes every branch index is < num_of_branches -- confirm
// against the stats estimator's contract.
225 std::vector<std::size_t> branch_counts(num_of_branches, 0);
226 for (std::size_t example_index = 0; example_index < num_of_examples;
228 ++branch_counts[branch_indices[example_index]];
// Store the chosen split in the node and create one sub-node per branch.
231 node.feature = features[best_feature_index];
232 node.threshold = best_feature_threshold;
233 node.sub_nodes.resize(num_of_branches);
235 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
// An empty branch gets a node whose statistics come from ALL examples of the
// current node (presumably followed by `continue`, not visible).
236 if (branch_counts[branch_index] == 0) {
237 NodeType branch_node;
238 stats_estimator_->computeAndSetNodeStats(
239 data_set_, examples, label_data, branch_node);
242 node.sub_nodes[branch_index] = branch_node;
// Collect the examples and labels that were routed to this branch.
247 std::vector<LabelType> branch_labels;
248 std::vector<ExampleIndex> branch_examples;
249 branch_labels.reserve(branch_counts[branch_index]);
250 branch_examples.reserve(branch_counts[branch_index]);
252 for (std::size_t example_index = 0; example_index < num_of_examples;
254 if (branch_indices[example_index] == branch_index) {
255 branch_examples.push_back(examples[example_index]);
256 branch_labels.push_back(label_data[example_index]);
// Recurse into the sub-node (the remaining arguments are not visible here).
260 trainDecisionTreeNode(features,
264 node.sub_nodes[branch_index]);
// Fills `thresholds` with `num_of_thresholds` evenly spaced values derived from
// the range of `values` (the function name and first parameter are not visible
// in this excerpt; the call site passes a threshold count first).
//
// thresholds[i] = min_value + step * (i + 1), with
// step = (max_value - min_value) / (num_of_thresholds + 2).
//
// NOTE(review): with the divisor `num_of_thresholds + 2` the largest threshold
// is max_value - 2 * step, so the thresholds are not centred in the value
// range -- confirm this is intended.
// NOTE(review): for empty `values` the loop never runs and the range is
// computed from the untouched sentinels below -- confirm callers never pass an
// empty vector.
269 template <
class FeatureType,
277 std::vector<float>& values,
278 std::vector<float>& thresholds)
// Sentinels so that any real value replaces them during the scan.
281 float min_value = ::std::numeric_limits<float>::max();
282 float max_value = -::std::numeric_limits<float>::max();
284 const std::size_t num_of_values = values.size();
// Single pass tracking the smallest and largest value (the assignments that
// follow each comparison are not visible in this excerpt).
285 for (std::size_t value_index = 0; value_index < num_of_values; ++value_index) {
286 const float value = values[value_index];
288 if (value < min_value)
290 if (value > max_value)
294 const float range = max_value - min_value;
295 const float step = range / static_cast<float>(num_of_thresholds + 2);
// resize (not reserve) so the indexed writes below are valid.
298 thresholds.resize(num_of_thresholds);
300 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds;
302 thresholds[threshold_index] =
303 min_value + step * (static_cast<float>(threshold_index + 1));