20#ifndef TMVA_TREEINFERENCE_BRANCHLESSTREE 
   21#define TMVA_TREEINFERENCE_BRANCHLESSTREE 
   29namespace Experimental {
 
   35void RecursiveFill(
int thisIndex, 
int lastIndex, 
int treeDepth, 
int maxTreeDepth, std::vector<T> &thresholds,
 
   36                   std::vector<int> &inputs)
 
   40   if (inputs[lastIndex] == -1) {
 
   41      thresholds.at(thisIndex) = thresholds.at(lastIndex);
 
   44      if (treeDepth < maxTreeDepth)
 
   45         inputs.at(thisIndex) = -1;
 
   49   if (treeDepth < maxTreeDepth) {
 
   50      Internal::RecursiveFill<T>(2 * thisIndex + 1, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
 
   51      Internal::RecursiveFill<T>(2 * thisIndex + 2, thisIndex, treeDepth + 1, maxTreeDepth, thresholds, inputs);
 
   69   inline std::string 
GetInferenceCode(
const std::string& funcName, 
const std::string& typeName);
 
   80   for (
int level = 0; level < fTreeDepth; ++level) {
 
   83   return fThresholds[
index];
 
   96   Internal::RecursiveFill<T>(1, 0, 1, fTreeDepth, fThresholds, fInputs);
 
   97   Internal::RecursiveFill<T>(2, 0, 1, fTreeDepth, fThresholds, fInputs);
 
  100   std::replace(fInputs.begin(), fInputs.end(), -1.0, 0.0);
 
  112   std::stringstream ss;
 
  115   ss << 
"inline " << typeName << 
" " << funcName << 
"(const " << typeName << 
"* input, const int stride)";
 
  121   ss << 
"   const int inputs[" << fInputs.size() << 
"] = {";
 
  122   int last = 
static_cast<int>(fInputs.size() - 1);
 
  123   for (
int i = 0; i < last + 1; i++) {
 
  125      if (i != last) ss << 
", ";
 
  129   ss << 
"   const " << typeName << 
" thresholds[" << fThresholds.size() << 
"] = {";
 
  130   last = 
static_cast<int>(fThresholds.size() - 1);
 
  131   for (
int i = 0; i < last + 1; i++) {
 
  132      ss << fThresholds[i];
 
  133      if (i != last) ss << 
", ";
 
  138   ss << 
"   int index = 0;\n";
 
  139   for (
int level = 0; level < fTreeDepth; ++level) {
 
  140      ss << 
"   index = 2 * index + 1 + (input[inputs[index] * stride] > thresholds[index]);\n";
 
  142   ss << 
"   return thresholds[index];\n";
 
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
void RecursiveFill(int thisIndex, int lastIndex, int treeDepth, int maxTreeDepth, std::vector< T > &thresholds, std::vector< int > &inputs)
Fill the empty nodes of a sparse tree recursively.
create variable transformations
Branchless representation of a decision tree using topological ordering.
std::vector< int > fInputs
Cut variables / inputs.
std::vector< T > fThresholds
Cut thresholds or scores if corresponding node is a leaf.
void FillSparse()
Fill nodes of a sparse tree forming a full tree.
int fTreeDepth
Depth of the tree.
T Inference(const T *input, const int stride)
Perform inference on a single input vector.
std::string GetInferenceCode(const std::string &funcName, const std::string &typeName)
Get code for compiling the inference function of the branchless tree with the current thresholds and ...