Skip to content

Commit

Permalink
[ML] Rename cross validation splitter package (elastic#59529)
Browse files Browse the repository at this point in the history
Renames and moves the cross validation splitter package.

First, the package and classes are renamed from using
"cross validation splitter" to "train test splitter".
Cross validation as a term is overloaded and encompasses
more concepts than what we are trying to do here.

Second, the package used to be under `process` but it does
not make sense to be there, it can be a top level package
under `dataframe`.
  • Loading branch information
dimitris-athanasiou authored Jul 14, 2020
1 parent ca5476e commit 1fe0ba7
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;

Expand Down Expand Up @@ -66,14 +66,14 @@ public class DataFrameDataExtractor {
private boolean isCancelled;
private boolean hasNext;
private boolean searchHasShardFailure;
private final CachedSupplier<CrossValidationSplitter> crossValidationSplitter;
private final CachedSupplier<TrainTestSplitter> trainTestSplitter;

DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
this.client = Objects.requireNonNull(client);
this.context = Objects.requireNonNull(context);
hasNext = true;
searchHasShardFailure = false;
this.crossValidationSplitter = new CachedSupplier<>(context.crossValidationSplitterFactory::create);
this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
}

public Map<String, String> getHeaders() {
Expand Down Expand Up @@ -207,7 +207,7 @@ private Row createRow(SearchHit hit) {
}
}
}
boolean isTraining = extractedValues == null ? false : crossValidationSplitter.get().isTraining(extractedValues);
boolean isTraining = extractedValues == null ? false : trainTestSplitter.get().isTraining(extractedValues);
return new Row(extractedValues, hit, isTraining);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.dataframe.extractor;

import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;

import java.util.List;
Expand All @@ -23,11 +23,11 @@ public class DataFrameDataExtractorContext {
final Map<String, String> headers;
final boolean includeSource;
final boolean supportsRowsWithMissingValues;
final CrossValidationSplitterFactory crossValidationSplitterFactory;
final TrainTestSplitterFactory trainTestSplitterFactory;

DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List<String> indices, QueryBuilder query, int scrollSize,
Map<String, String> headers, boolean includeSource, boolean supportsRowsWithMissingValues,
CrossValidationSplitterFactory crossValidationSplitterFactory) {
TrainTestSplitterFactory trainTestSplitterFactory) {
this.jobId = Objects.requireNonNull(jobId);
this.extractedFields = Objects.requireNonNull(extractedFields);
this.indices = indices.toArray(new String[indices.size()]);
Expand All @@ -36,6 +36,6 @@ public class DataFrameDataExtractorContext {
this.headers = headers;
this.includeSource = includeSource;
this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory);
this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;

Expand All @@ -33,12 +33,12 @@ public class DataFrameDataExtractorFactory {
private final List<RequiredField> requiredFields;
private final Map<String, String> headers;
private final boolean supportsRowsWithMissingValues;
private final CrossValidationSplitterFactory crossValidationSplitterFactory;
private final TrainTestSplitterFactory trainTestSplitterFactory;

private DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, QueryBuilder sourceQuery,
ExtractedFields extractedFields, List<RequiredField> requiredFields, Map<String, String> headers,
boolean supportsRowsWithMissingValues,
CrossValidationSplitterFactory crossValidationSplitterFactory) {
TrainTestSplitterFactory trainTestSplitterFactory) {
this.client = Objects.requireNonNull(client);
this.analyticsId = Objects.requireNonNull(analyticsId);
this.indices = Objects.requireNonNull(indices);
Expand All @@ -47,7 +47,7 @@ private DataFrameDataExtractorFactory(Client client, String analyticsId, List<St
this.requiredFields = Objects.requireNonNull(requiredFields);
this.headers = headers;
this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory);
this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory);
}

public DataFrameDataExtractor newExtractor(boolean includeSource) {
Expand All @@ -60,7 +60,7 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) {
headers,
includeSource,
supportsRowsWithMissingValues,
crossValidationSplitterFactory
trainTestSplitterFactory
);
return new DataFrameDataExtractor(client, context);
}
Expand Down Expand Up @@ -89,12 +89,12 @@ public static DataFrameDataExtractorFactory createForSourceIndices(Client client
ExtractedFields extractedFields) {
return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()),
config.getSource().getParsedQuery(), extractedFields, config.getAnalysis().getRequiredFields(), config.getHeaders(),
config.getAnalysis().supportsMissingValues(), createCrossValidationSplitterFactory(client, config, extractedFields));
config.getAnalysis().supportsMissingValues(), createTrainTestSplitterFactory(client, config, extractedFields));
}

private static CrossValidationSplitterFactory createCrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields) {
return new CrossValidationSplitterFactory(client, config,
private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields) {
return new TrainTestSplitterFactory(client, config,
extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()));
}

Expand All @@ -118,7 +118,7 @@ public static void createForDestinationIndex(Client client,
DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(),
Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields,
config.getAnalysis().getRequiredFields(), config.getHeaders(), config.getAnalysis().supportsMissingValues(),
createCrossValidationSplitterFactory(client, config, extractedFields));
createTrainTestSplitterFactory(client, config, extractedFields));
listener.onResponse(extractorFactory);
},
listener::onFailure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
Expand All @@ -17,14 +17,14 @@
* is based on the reservoir idea. It randomly picks training docs while
* respecting the exact training percent.
*/
abstract class AbstractReservoirCrossValidationSplitter implements CrossValidationSplitter {
abstract class AbstractReservoirTrainTestSplitter implements TrainTestSplitter {

protected final int dependentVariableIndex;
private final double samplingRatio;
private final Random random;

AbstractReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed) {
AbstractReservoirTrainTestSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed) {
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.samplingRatio = trainingPercent / 100.0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import java.util.List;

public class SingleClassReservoirCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
public class SingleClassReservoirTrainTestSplitter extends AbstractReservoirTrainTestSplitter {

private final SampleInfo sampleInfo;

SingleClassReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed, long classCount) {
SingleClassReservoirTrainTestSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed, long classCount) {
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
sampleInfo = new SampleInfo(classCount);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import java.util.HashMap;
import java.util.List;
Expand All @@ -14,12 +14,12 @@
* Given a dependent variable, randomly splits the dataset trying
* to preserve the proportion of each class in the training sample.
*/
public class StratifiedCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
public class StratifiedTrainTestSplitter extends AbstractReservoirTrainTestSplitter {

private final Map<String, SampleInfo> classSamples;

public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
double trainingPercent, long randomizeSeed) {
public StratifiedTrainTestSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
double trainingPercent, long randomizeSeed) {
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
this.classSamples = new HashMap<>();
classCounts.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new SampleInfo(entry.getValue())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

/**
* Processes rows in order to split the dataset in training and test subsets
*/
public interface CrossValidationSplitter {
public interface TrainTestSplitter {

boolean isTraining(String[] row);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -26,21 +26,21 @@
import java.util.Map;
import java.util.Objects;

public class CrossValidationSplitterFactory {
public class TrainTestSplitterFactory {

private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
private static final Logger LOGGER = LogManager.getLogger(TrainTestSplitterFactory.class);

private final Client client;
private final DataFrameAnalyticsConfig config;
private final List<String> fieldNames;

public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
public TrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
this.client = Objects.requireNonNull(client);
this.config = Objects.requireNonNull(config);
this.fieldNames = Objects.requireNonNull(fieldNames);
}

public CrossValidationSplitter create() {
public TrainTestSplitter create() {
if (config.getAnalysis() instanceof Regression) {
return createSingleClassSplitter((Regression) config.getAnalysis());
}
Expand All @@ -50,7 +50,7 @@ public CrossValidationSplitter create() {
return row -> true;
}

private CrossValidationSplitter createSingleClassSplitter(Regression regression) {
private TrainTestSplitter createSingleClassSplitter(Regression regression) {
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
.setSize(0)
.setAllowPartialSearchResults(false)
Expand All @@ -60,7 +60,7 @@ private CrossValidationSplitter createSingleClassSplitter(Regression regression)
try {
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
searchRequestBuilder::get);
return new SingleClassReservoirCrossValidationSplitter(fieldNames, regression.getDependentVariable(),
return new SingleClassReservoirTrainTestSplitter(fieldNames, regression.getDependentVariable(),
regression.getTrainingPercent(), regression.getRandomizeSeed(), searchResponse.getHits().getTotalHits().value);
} catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Error searching total number of training docs", config.getId());
Expand All @@ -69,7 +69,7 @@ private CrossValidationSplitter createSingleClassSplitter(Regression regression)
}
}

private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
private TrainTestSplitter createStratifiedSplitter(Classification classification) {
String aggName = "dependent_variable_terms";
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
.setSize(0)
Expand All @@ -88,7 +88,7 @@ private CrossValidationSplitter createStratifiedSplitter(Classification classifi
classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
}

return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCounts,
return new StratifiedTrainTestSplitter(fieldNames, classification.getDependentVariable(), classCounts,
classification.getTrainingPercent(), classification.getRandomizeSeed());
} catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
Expand Down Expand Up @@ -67,7 +67,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
private QueryBuilder query;
private int scrollSize;
private Map<String, String> headers;
private CrossValidationSplitterFactory crossValidationSplitterFactory;
private TrainTestSplitterFactory trainTestSplitterFactory;
private ArgumentCaptor<ClearScrollRequest> capturedClearScrollRequests;
private ActionFuture<ClearScrollResponse> clearScrollFuture;

Expand All @@ -87,8 +87,8 @@ public void setUpTests() {
scrollSize = 1000;
headers = Collections.emptyMap();

crossValidationSplitterFactory = mock(CrossValidationSplitterFactory.class);
when(crossValidationSplitterFactory.create()).thenReturn(row -> true);
trainTestSplitterFactory = mock(TrainTestSplitterFactory.class);
when(trainTestSplitterFactory.create()).thenReturn(row -> true);

clearScrollFuture = mock(ActionFuture.class);
capturedClearScrollRequests = ArgumentCaptor.forClass(ClearScrollRequest.class);
Expand Down Expand Up @@ -468,7 +468,7 @@ public void testGetCategoricalFields() {

private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
headers, includeSource, supportsRowsWithMissingValues, crossValidationSplitterFactory);
headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
return new TestExtractor(client, context);
}

Expand Down
Loading

0 comments on commit 1fe0ba7

Please sign in to comment.