Extend json report for regression trees
Compute a KWDataGridStats for the discretized target.
Pass that KWDataGridStats into the DTDecisionTreeSpec.
Write the KWDataGridStats to the JSON report when the target is continuous.

New "targetPartition" block in the JSON report, for example:

				"targetPartition": {
					"variable": "PetalLength",
					"type": "Numerical",
					"partitionType": "Intervals",
					"partition": [
						[1,2.4],
						[2.4,4.75],
						[4.75,4.85],
						[4.85,6.9]
					],
					"frequencies": [38,27,3,37]
				},
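As an illustration, here is a minimal sketch (assumed, not code from this commit) of how the "targetPartition" block above maps onto the univariate target data grid introduced by this change: a KWDGSAttributeDiscretization carries the interior interval bounds, and the grid cells carry the per-interval frequencies. It uses only KWDataGridStats and KWDGSAttributeDiscretization calls that appear in the diff below, with the PetalLength values from the example.

	// Illustrative sketch only: the target partition behind the "targetPartition" block
	KWDGSAttributeDiscretization* attribute;
	KWDataGridStats targetPartition;

	attribute = new KWDGSAttributeDiscretization;
	attribute->SetAttributeName("PetalLength");
	attribute->SetInitialValueNumber(4);
	attribute->SetGranularizedValueNumber(4);
	attribute->SetPartNumber(4);
	// Only the three interior bounds are stored; 1 and 6.9 are the target min and max
	attribute->SetIntervalBoundAt(0, 2.4);
	attribute->SetIntervalBoundAt(1, 4.75);
	attribute->SetIntervalBoundAt(2, 4.85);

	// One cell per interval, holding the "frequencies" vector of the report
	targetPartition.AddAttribute(attribute);
	targetPartition.CreateAllCells();
	targetPartition.SetUnivariateCellFrequencyAt(0, 38);
	targetPartition.SetUnivariateCellFrequencyAt(1, 27);
	targetPartition.SetUnivariateCellFrequencyAt(2, 3);
	targetPartition.SetUnivariateCellFrequencyAt(3, 37);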
n-voisine committed Jan 22, 2025
1 parent fc99324 commit c7e88bb
Showing 4 changed files with 253 additions and 16 deletions.
93 changes: 80 additions & 13 deletions src/Learning/DTForest/DTDecisionTreeCreationTask.cpp
@@ -2,7 +2,6 @@
// This software is distributed under the BSD 3-Clause-clear License, the text of which is available
// at https://spdx.org/licenses/BSD-3-Clause-Clear.html or see the "LICENSE" file for more details.

#include "DTDecisionTreeCreationTask.h"
#include "DTDecisionTreeCreationTask.h"
#include "DTDecisionTree.h"
#include "DTDecisionTreeSpec.h"
@@ -11,7 +10,6 @@
#include "DTDecisionBinaryTreeCost.h"
#include "DTStat.h"
#include "DTForestAttributeSelection.h"
#include "DTDecisionTreeSpec.h"
#include "DTGlobalTag.h"
#include "DTDiscretizerMODL.h"
#include "DTGrouperMODL.h"
@@ -645,6 +643,7 @@ boolean DTDecisionTreeCreationTask::SlaveProcess()
ALString sMessage;
KWAttribute* attribute = NULL;
KWAttributeStats* attributeStats = NULL;
KWDataGridStats targetStats;
ObjectArray oaObjects;
ObjectArray oatupletable;
ObjectArray oaOrigineAttributs;
@@ -856,12 +855,13 @@ boolean DTDecisionTreeCreationTask::SlaveProcess()
InitializeMODLDiscretization(
&slaveTupleTableLoader, slaveLearningSpec,
*input_cvIntervalValues.GetConstContinuousVector(),
attributegenerator->GetIndex());
attributegenerator->GetIndex(), &targetStats);
else
InitializeBinaryEQFDiscretization(
&slaveTupleTableLoader, slaveLearningSpec,
*input_cvIntervalValues.GetConstContinuousVector(),
input_ivSplitValues.GetAt(attributegenerator->GetIndex()));
input_ivSplitValues.GetAt(attributegenerator->GetIndex()),
&targetStats);

blOrigine.Initialize(slaveLearningSpec, &slaveTupleTableLoader,
&oaObjects);
@@ -934,6 +934,23 @@ boolean DTDecisionTreeCreationTask::SlaveProcess()
// Creation de la spec de l'arbre a partir de l'arbre calcule
reportTreeSpec = new DTDecisionTreeSpec;
reportTreeSpec->InitFromDecisionTree(dttree);
reportTreeSpec->SetTargetType(
shared_learningSpec.GetLearningSpec()->GetTargetAttributeType());
//prise en compte de la discretisation de la cible pour la regression
if (bRegressionWithMODLDiscretization)
{
reportTreeSpec->SetTargetStats(targetStats.Clone());
reportTreeSpec->SetTargetMin(
cast(KWDescriptiveContinuousStats*,
shared_learningSpec.GetLearningSpec()
->GetTargetDescriptiveStats())
->GetMin());
reportTreeSpec->SetTargetMax(
cast(KWDescriptiveContinuousStats*,
shared_learningSpec.GetLearningSpec()
->GetTargetDescriptiveStats())
->GetMax());
}
// detection des doublons
key = reportTreeSpec->ComputeHashValue();
// filtre les arbre qui
@@ -2202,7 +2219,8 @@ KWLearningSpec* DTDecisionTreeCreationTask::InitializeRegressionLearningSpec(con
newClass->SetName(newClass->GetName() + "_classification");
newTarget = new KWAttribute;

newTarget->SetName(learningSpec->GetTargetAttributeName() + "_categorical");
newTarget->SetName(
learningSpec->GetClass()->BuildAttributeName(learningSpec->GetTargetAttributeName() + "_categorical"));
newTarget->SetType(KWType::Symbol);
newClass->InsertAttribute(newTarget);
KWClassDomain::GetCurrentDomain()->InsertClass(newClass);
@@ -2266,7 +2284,8 @@ void DTDecisionTreeCreationTask::InitializeEqualFreqDiscretization(KWTupleTableL

void DTDecisionTreeCreationTask::InitializeMODLDiscretization(KWTupleTableLoader* tupleTableLoader,
KWLearningSpec* learningSpec,
const ContinuousVector& cvIntervalValues, int nSplitIndex)
const ContinuousVector& cvIntervalValues, int nSplitIndex,
KWDataGridStats* targetStat)
{
// on transforme la cible continue en cible categorielle, en effectuant au prealable une dicretisation MODL sur
// la cible continue
@@ -2276,12 +2295,14 @@ void DTDecisionTreeCreationTask::InitializeMODLDiscretization(KWTupleTableLoader
DTBaseLoader bl;
SymbolVector* svTargetValues = NULL;

require(targetStat != NULL);
assert(learningSpec != NULL);
assert(tupleTableLoader != NULL);
assert(randomForestParameter.GetDiscretizationTargetMethod() == DTForestParameter::DISCRETIZATION_MODL);

svTargetValues = MODLDiscretizeContinuousTarget(
tupleTableLoader, randomForestParameter.GetMaxIntervalsNumberForTarget(), cvIntervalValues, nSplitIndex);
svTargetValues =
MODLDiscretizeContinuousTarget(tupleTableLoader, randomForestParameter.GetMaxIntervalsNumberForTarget(),
cvIntervalValues, nSplitIndex, targetStat);
assert(svTargetValues != NULL);

tupleTableLoader->SetInputExtraAttributeName(learningSpec->GetTargetAttributeName());
@@ -2314,7 +2335,7 @@ void DTDecisionTreeCreationTask::InitializeMODLDiscretization(KWTupleTableLoader
void DTDecisionTreeCreationTask::InitializeBinaryEQFDiscretization(KWTupleTableLoader* tupleTableLoader,
KWLearningSpec* learningSpec,
const ContinuousVector& cvIntervalValues,
int nSplitIndex)
int nSplitIndex, KWDataGridStats* targetStat)
{
// on transforme la cible continue en cible categorielle, en effectuant au prealable une dicretisation MODL sur
// la cible continue
@@ -2329,8 +2350,9 @@ void DTDecisionTreeCreationTask::InitializeBinaryEQFDiscretization(KWTupleTableL
assert(randomForestParameter.GetDiscretizationTargetMethod() ==
DTForestParameter::DISCRETIZATION_BINARY_EQUAL_FREQUENCY);

svTargetValues = MODLDiscretizeContinuousTarget(
tupleTableLoader, randomForestParameter.GetMaxIntervalsNumberForTarget(), cvIntervalValues, nSplitIndex);
svTargetValues =
MODLDiscretizeContinuousTarget(tupleTableLoader, randomForestParameter.GetMaxIntervalsNumberForTarget(),
cvIntervalValues, nSplitIndex, targetStat);
assert(svTargetValues != NULL);

tupleTableLoader->SetInputExtraAttributeName(learningSpec->GetTargetAttributeName());
@@ -2419,20 +2441,27 @@ SymbolVector* DTDecisionTreeCreationTask::EqualFreqDiscretizeContinuousTarget(KW
SymbolVector* DTDecisionTreeCreationTask::MODLDiscretizeContinuousTarget(KWTupleTableLoader* tupleTableLoader,
const int nMaxIntervalsNumber,
const ContinuousVector& cvInput,
int nSplitIndex) const
int nSplitIndex,
KWDataGridStats* targetStat) const
{
SymbolVector* svTargetValues = NULL;
SymbolVector svTargetIn;
ContinuousVector cvChosenIntervals;
IntVector ivFrequencyIntervals;
ContinuousVector cvInputIntervalValues;
ObjectDictionary odIntervals;
Object o;
int nBound, nPartNumber;
boolean bDisplayValues = false;
Continuous cValue;
ALString s;
KWDGSAttributeDiscretization* attribute;
int oldseed;

//initilisation de random seed 2001 + index de l'arbre
require(cvInput.GetSize() > 0);
require(targetStat != NULL);

//initialisation de random seed 2001 + index de l'arbre
oldseed = GetRandomSeed();
SetRandomSeed(2001 + nSplitIndex);

@@ -2466,9 +2495,27 @@ SymbolVector* DTDecisionTreeCreationTask::MODLDiscretizeContinuousTarget(KWTuple
if (bDisplayValues)
cvChosenIntervals.Write(cout);

//initialisation de target Stat d'un arbre
targetStat->DeleteAll();

// Creation de l'attribut
attribute = new KWDGSAttributeDiscretization;
attribute->SetAttributeName(shared_learningSpec.GetLearningSpec()->GetTargetAttributeName());

svTargetValues = new SymbolVector;
if (randomForestParameter.GetDiscretizationTargetMethod() == DTForestParameter::DISCRETIZATION_MODL)
{

// Creation des bornes des intervalles
nPartNumber = cvChosenIntervals.GetSize() - 1;
attribute->SetInitialValueNumber(nPartNumber);
attribute->SetGranularizedValueNumber(nPartNumber);
attribute->SetPartNumber(nPartNumber);
for (nBound = 1; nBound < nPartNumber; nBound++)
attribute->SetIntervalBoundAt(nBound - 1, cvChosenIntervals.GetAt(nBound));
ensure(attribute->Check());
ivFrequencyIntervals.SetSize(nPartNumber);

for (int nInterval = 0; nInterval < cvChosenIntervals.GetSize(); nInterval++)
{
s = "I" + ALString(IntToString(nInterval));
@@ -2486,13 +2533,25 @@ SymbolVector* DTDecisionTreeCreationTask::MODLDiscretizeContinuousTarget(KWTuple
cValue < cvChosenIntervals.GetAt(nInterval + 1)))
{
svTargetValues->Add(svTargetIn.GetAt(nInterval));
ivFrequencyIntervals.SetAt(nInterval,
ivFrequencyIntervals.GetAt(nInterval) + 1);
break;
}
}
}
}
else
{
// Creation des bornes des intervalles
nPartNumber = 2;
attribute->SetInitialValueNumber(nPartNumber);
attribute->SetGranularizedValueNumber(nPartNumber);
attribute->SetPartNumber(nPartNumber);
for (nBound = 1; nBound < nPartNumber; nBound++)
attribute->SetIntervalBoundAt(nBound, cvChosenIntervals.GetAt(nSplitIndex));
ensure(attribute->Check());
ivFrequencyIntervals.SetSize(nPartNumber);

Symbol sI0("I0");
Symbol sI1("I1");

@@ -2504,16 +2563,24 @@ SymbolVector* DTDecisionTreeCreationTask::MODLDiscretizeContinuousTarget(KWTuple
if (cValue <= cvChosenIntervals.GetAt(nSplitIndex))
{
svTargetValues->Add(sI0);
ivFrequencyIntervals.SetAt(0, ivFrequencyIntervals.GetAt(0) + 1);
}
else
{
svTargetValues->Add(sI1);
ivFrequencyIntervals.SetAt(1, ivFrequencyIntervals.GetAt(1) + 1);
}
}
}
assert(svTargetValues != NULL);
assert(tupleTableLoader->GetInputExtraAttributeContinuousValues()->GetSize() == svTargetValues->GetSize());

//initialisation du gridstat
targetStat->AddAttribute(attribute);
targetStat->CreateAllCells();
for (int nInterval = 0; nInterval < nPartNumber; nInterval++)
targetStat->SetUnivariateCellFrequencyAt(nInterval, ivFrequencyIntervals.GetAt(nInterval));

// restitution de l'etat initial :
SetRandomSeed(oldseed);

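For readability, here is a condensed sketch of the new data flow in SlaveProcess, assembled from the hunks above (simplified, not a verbatim excerpt): the per-tree target discretization is captured in a local KWDataGridStats, and for regression it is cloned into the tree spec together with the min and max of the continuous target, so that the report can emit the "targetPartition" block.

	KWDataGridStats targetStats;
	DTDecisionTreeSpec* reportTreeSpec;

	// The target discretization helper now also fills targetStats
	InitializeMODLDiscretization(&slaveTupleTableLoader, slaveLearningSpec,
				     *input_cvIntervalValues.GetConstContinuousVector(),
				     attributegenerator->GetIndex(), &targetStats);

	// ... tree construction ...

	reportTreeSpec = new DTDecisionTreeSpec;
	reportTreeSpec->InitFromDecisionTree(dttree);
	reportTreeSpec->SetTargetType(shared_learningSpec.GetLearningSpec()->GetTargetAttributeType());
	if (bRegressionWithMODLDiscretization)
	{
		// Target partition plus bounds of the continuous target, for the JSON report
		reportTreeSpec->SetTargetStats(targetStats.Clone());
		reportTreeSpec->SetTargetMin(cast(KWDescriptiveContinuousStats*,
						  shared_learningSpec.GetLearningSpec()->GetTargetDescriptiveStats())
						 ->GetMin());
		reportTreeSpec->SetTargetMax(cast(KWDescriptiveContinuousStats*,
						  shared_learningSpec.GetLearningSpec()->GetTargetDescriptiveStats())
						 ->GetMax());
	}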
9 changes: 6 additions & 3 deletions src/Learning/DTForest/DTDecisionTreeCreationTask.h
@@ -128,14 +128,17 @@ class DTDecisionTreeCreationTask : public KDDataPreparationAttributeCreationTask

// Discretisation 'MODL' d'une target continue
SymbolVector* MODLDiscretizeContinuousTarget(KWTupleTableLoader*, int nMaxIntervalsNumber,
const ContinuousVector& cvIntervalValues, int nSplitIndex) const;
const ContinuousVector& cvIntervalValues, int nSplitIndex,
KWDataGridStats* targetStat) const;

// transforme une regression en classification en effectuant au prealable une discretisation MODL
void InitializeMODLDiscretization(KWTupleTableLoader*, KWLearningSpec*,
const ContinuousVector& cvIntervalValues, int nSplitIndex);
const ContinuousVector& cvIntervalValues, int nSplitIndex,
KWDataGridStats* targetStat);

void InitializeBinaryEQFDiscretization(KWTupleTableLoader*, KWLearningSpec*,
const ContinuousVector& cvIntervalValues, int nSplitIndex);
const ContinuousVector& cvIntervalValues, int nSplitIndex,
KWDataGridStats* targetStat);

///////////////////////////////////////////////////////////////////////////////////////////////////
// Reimplementation des methodes virtuelles de tache
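The header diff above amounts to the following consolidated declarations after this commit: each target discretization helper now exposes the computed target partition through an extra KWDataGridStats output parameter.

	SymbolVector* MODLDiscretizeContinuousTarget(KWTupleTableLoader*, int nMaxIntervalsNumber,
						     const ContinuousVector& cvIntervalValues, int nSplitIndex,
						     KWDataGridStats* targetStat) const;

	void InitializeMODLDiscretization(KWTupleTableLoader*, KWLearningSpec*,
					  const ContinuousVector& cvIntervalValues, int nSplitIndex,
					  KWDataGridStats* targetStat);

	void InitializeBinaryEQFDiscretization(KWTupleTableLoader*, KWLearningSpec*,
					       const ContinuousVector& cvIntervalValues, int nSplitIndex,
					       KWDataGridStats* targetStat);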