Skip to content

Commit

Permalink
Require a field when a seed is provided to the random_score function. (#25594)
Browse files Browse the repository at this point in the history

We currently use fielddata on the `_id` field which is trappy, especially as we
do it implicitly. This changes the `random_score` function to use doc ids when
no seed is provided and to suggest a field when a seed is provided.

For now the change only emits a deprecation warning when no field is supplied,
but this should be replaced by a strict check in 7.0.

Closes #25240
  • Loading branch information
jpountz authored Jul 19, 2017
1 parent f69decf commit f1ff7f2
Show file tree
Hide file tree
Showing 12 changed files with 176 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public void testFunctionScore() {
FilterFunctionBuilder[] functions = {
new FunctionScoreQueryBuilder.FilterFunctionBuilder(
matchQuery("name", "kimchy"), // <1>
randomFunction("ABCDEF")), // <2>
randomFunction()), // <2>
new FunctionScoreQueryBuilder.FilterFunctionBuilder(
exponentialDecayFunction("age", 0L, 1L)) // <3>
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package org.elasticsearch.common.lucene.search.function;

import com.carrotsearch.hppc.BitMixer;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.util.StringHelper;
Expand All @@ -33,17 +35,9 @@
*/
public class RandomScoreFunction extends ScoreFunction {

private int originalSeed;
private int saltedSeed;
private final IndexFieldData<?> uidFieldData;

/**
* Default constructor. Only useful for constructing as a placeholder, but should not be used for actual scoring.
*/
public RandomScoreFunction() {
super(CombineFunction.MULTIPLY);
uidFieldData = null;
}
private final int originalSeed;
private final int saltedSeed;
private final IndexFieldData<?> fieldData;

/**
* Creates a RandomScoreFunction.
Expand All @@ -55,33 +49,43 @@ public RandomScoreFunction() {
public RandomScoreFunction(int seed, int salt, IndexFieldData<?> uidFieldData) {
super(CombineFunction.MULTIPLY);
this.originalSeed = seed;
this.saltedSeed = seed ^ salt;
this.uidFieldData = uidFieldData;
if (uidFieldData == null) throw new NullPointerException("uid missing");
this.saltedSeed = BitMixer.mix(seed, salt);
this.fieldData = uidFieldData;
}

@Override
public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) {
AtomicFieldData leafData = uidFieldData.load(ctx);
final SortedBinaryDocValues uidByteData = leafData.getBytesValues();
if (uidByteData == null) throw new NullPointerException("failed to get uid byte data");
final SortedBinaryDocValues values;
if (fieldData != null) {
AtomicFieldData leafData = fieldData.load(ctx);
values = leafData.getBytesValues();
if (values == null) throw new NullPointerException("failed to get fielddata");
} else {
values = null;
}

return new LeafScoreFunction() {

@Override
public double score(int docId, float subQueryScore) throws IOException {
if (uidByteData.advanceExact(docId) == false) {
throw new AssertionError("Document without a _uid");
int hash;
if (values == null) {
hash = BitMixer.mix(ctx.docBase + docId);
} else if (values.advanceExact(docId)) {
hash = StringHelper.murmurhash3_x86_32(values.nextValue(), saltedSeed);
} else {
// field has no value
hash = saltedSeed;
}
int hash = StringHelper.murmurhash3_x86_32(uidByteData.nextValue(), saltedSeed);
return (hash & 0x00FFFFFF) / (float)(1 << 24); // only use the lower 24 bits to construct a float from 0.0-1.0
}

@Override
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
String field = fieldData == null ? null : fieldData.getFieldName();
return Explanation.match(
CombineFunction.toFloat(score(docId, subQueryScore.getValue())),
"random score function (seed: " + originalSeed + ")");
"random score function (seed: " + originalSeed + ", field: " + field + ")");
}
};
}
Expand All @@ -94,8 +98,8 @@ public boolean needsScores() {
@Override
protected boolean doEquals(ScoreFunction other) {
RandomScoreFunction randomScoreFunction = (RandomScoreFunction) other;
return this.originalSeed == randomScoreFunction.originalSeed &&
this.saltedSeed == randomScoreFunction.saltedSeed;
return this.originalSeed == randomScoreFunction.originalSeed
&& this.saltedSeed == randomScoreFunction.saltedSeed;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ private ParsedQuery toQuery(QueryBuilder queryBuilder, CheckedFunction<QueryBuil
}
}

public final Index index() {
public Index index() {
return indexSettings.getIndex();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
*/
package org.elasticsearch.index.query.functionscore;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.lucene.search.function.RandomScoreFunction;
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.IdFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.UidFieldMapper;
Expand All @@ -38,7 +40,11 @@
* A function that computes a random score for the matched documents
*/
public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScoreFunctionBuilder> {

private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(RandomScoreFunctionBuilder.class));

public static final String NAME = "random_score";
private String field;
private Integer seed;

public RandomScoreFunctionBuilder() {
Expand All @@ -52,6 +58,9 @@ public RandomScoreFunctionBuilder(StreamInput in) throws IOException {
if (in.readBoolean()) {
seed = in.readInt();
}
if (in.getVersion().onOrAfter(Version.V_6_0_0_alpha3)) {
field = in.readOptionalString();
}
}

@Override
Expand All @@ -62,6 +71,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (out.getVersion().onOrAfter(Version.V_6_0_0_alpha3)) {
out.writeOptionalString(field);
}
}

@Override
Expand Down Expand Up @@ -105,12 +117,33 @@ public Integer getSeed() {
return seed;
}

/**
* Set the field to be used for random number generation. Note that documents that
* have the same value for this field will get the same score.
* <p>
* The field is only consulted when a {@link #seed(int) seed} is set; when no seed is
* set, scores are derived from document ids and this value is ignored. A seed without
* a field is still accepted for now, but building the function then logs a deprecation
* warning and falls back to the id/uid metadata field.
*
* @param field name of the field whose values feed the random number generation;
*              may be {@code null} to leave it unset
* @return this builder, to allow call chaining
*/
public RandomScoreFunctionBuilder setField(String field) {
this.field = field;
return this;
}

/**
* Get the field to use for random number generation.
*
* @return the configured field name, or {@code null} if no field has been set
* @see #setField(String)
*/
public String getField() {
return field;
}

/**
* Serializes this function as an object named after {@link #getName()}.
* The {@code seed} and {@code field} entries are optional and are only written
* when they have been set, so an unconfigured builder renders as an empty object.
*/
@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(getName());
// Both entries are omitted when unset rather than written as null.
if (seed != null) {
builder.field("seed", seed);
}
if (field != null) {
builder.field("field", field);
}
builder.endObject();
}

Expand All @@ -126,19 +159,39 @@ protected int doHashCode() {

@Override
protected ScoreFunction doToFunction(QueryShardContext context) {
final MappedFieldType fieldType;
if (context.getIndexSettings().isSingleType()) {
fieldType = context.getMapperService().fullName(IdFieldMapper.NAME);
final int salt = (context.index().getName().hashCode() << 10) | context.getShardId();
if (seed == null) {
// DocID-based random score generation
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null);
} else {
fieldType = context.getMapperService().fullName(UidFieldMapper.NAME);
}
if (fieldType == null) {
// mapper could be null if we are on a shard with no docs yet, so this won't actually be used
return new RandomScoreFunction();
final MappedFieldType fieldType;
if (field != null) {
fieldType = context.getMapperService().fullName(field);
} else {
DEPRECATION_LOGGER.deprecated(
"As of version 7.0 Elasticsearch will require that a [field] parameter is provided when a [seed] is set");
if (context.getIndexSettings().isSingleType()) {
fieldType = context.getMapperService().fullName(IdFieldMapper.NAME);
} else {
fieldType = context.getMapperService().fullName(UidFieldMapper.NAME);
}
}
if (fieldType == null) {
if (context.getMapperService().types().isEmpty()) {
// no mappings: the index is empty anyway
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null);
}
throw new IllegalArgumentException("Field [" + field + "] is not mapped on [" + context.index() +
"] and cannot be used as a source of random numbers.");
}
int seed;
if (this.seed != null) {
seed = this.seed;
} else {
seed = hash(context.nowInMillis());
}
return new RandomScoreFunction(seed, salt, context.getForField(fieldType));
}
final int salt = (context.index().getName().hashCode() << 10) | context.getShardId();
final IndexFieldData<?> uidFieldData = context.getForField(fieldType);
return new RandomScoreFunction(this.seed == null ? hash(context.nowInMillis()) : seed, salt, uidFieldData);
}

private static int hash(long value) {
Expand Down Expand Up @@ -170,6 +223,8 @@ public static RandomScoreFunctionBuilder fromXContent(XContentParser parser)
throw new ParsingException(parser.getTokenLocation(), "random_score seed must be an int/long or string, not '"
+ token.toString() + "'");
}
} else if ("field".equals(currentFieldName)) {
randomScoreFunctionBuilder.setField(parser.text());
} else {
throw new ParsingException(parser.getTokenLocation(), NAME + " query does not support [" + currentFieldName + "]");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,8 @@ public static ScriptScoreFunctionBuilder scriptFunction(String script) {
return (new ScriptScoreFunctionBuilder(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, script, emptyMap())));
}

public static RandomScoreFunctionBuilder randomFunction(int seed) {
return (new RandomScoreFunctionBuilder()).seed(seed);
}

public static RandomScoreFunctionBuilder randomFunction(long seed) {
return (new RandomScoreFunctionBuilder()).seed(seed);
}

public static RandomScoreFunctionBuilder randomFunction(String seed) {
return (new RandomScoreFunctionBuilder()).seed(seed);
public static RandomScoreFunctionBuilder randomFunction() {
return new RandomScoreFunctionBuilder();
}

public static WeightBuilder weightFactorFunction(float weight) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.common.xcontent.XContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.SeqNoFieldMapper;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -191,6 +192,7 @@ private static ScoreFunctionBuilder<?> randomScoreFunction() {
} else {
randomScoreFunctionBuilder.seed(randomAlphaOfLengthBetween(1, 10));
}
randomScoreFunctionBuilder.setField(SeqNoFieldMapper.NAME); // guaranteed to exist
}
functionBuilder = randomScoreFunctionBuilder;
break;
Expand Down Expand Up @@ -270,14 +272,14 @@ public void testIllegalArguments() {
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder((QueryBuilder) null));
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder((ScoreFunctionBuilder<?>) null));
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder((FilterFunctionBuilder[]) null));
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder(null, randomFunction(123)));
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder(null, randomFunction()));
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder(matchAllQuery(), (ScoreFunctionBuilder<?>) null));
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder(matchAllQuery(), (FilterFunctionBuilder[]) null));
expectThrows(IllegalArgumentException.class, () -> new FunctionScoreQueryBuilder(null, new FilterFunctionBuilder[0]));
expectThrows(IllegalArgumentException.class,
() -> new FunctionScoreQueryBuilder(matchAllQuery(), new FilterFunctionBuilder[] { null }));
expectThrows(IllegalArgumentException.class, () -> new FilterFunctionBuilder((ScoreFunctionBuilder<?>) null));
expectThrows(IllegalArgumentException.class, () -> new FilterFunctionBuilder(null, randomFunction(123)));
expectThrows(IllegalArgumentException.class, () -> new FilterFunctionBuilder(null, randomFunction()));
expectThrows(IllegalArgumentException.class, () -> new FilterFunctionBuilder(matchAllQuery(), null));
FunctionScoreQueryBuilder builder = new FunctionScoreQueryBuilder(matchAllQuery());
expectThrows(IllegalArgumentException.class, () -> builder.scoreMode(null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ public void closeAllTheReaders() throws IOException {
public void testExplainFunctionScoreQuery() throws IOException {

Explanation functionExplanation = getFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION);
checkFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)");
checkFunctionScoreExplanation(functionExplanation, "random score function (seed: 0, field: test)");
assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0));

functionExplanation = getFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION);
Expand Down Expand Up @@ -331,7 +331,7 @@ public void checkFunctionScoreExplanation(Explanation randomExplanation, String

public void testExplainFiltersFunctionScoreQuery() throws IOException {
Explanation functionExplanation = getFiltersFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION);
checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)", 0);
checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0, field: test)", 0);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));

functionExplanation = getFiltersFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION);
Expand Down Expand Up @@ -366,7 +366,7 @@ public void testExplainFiltersFunctionScoreQuery() throws IOException {
, LIN_DECAY_FUNCTION
);

checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)", 0);
checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0, field: test)", 0);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));

checkFiltersFunctionScoreExplanation(functionExplanation, "field value function: ln(doc['test'].value?:1.0 * factor=1.0)", 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,18 @@

package org.elasticsearch.index.query.functionscore;

import org.elasticsearch.Version;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper.NumberType;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.script.Script;
import org.elasticsearch.test.ESTestCase;
import org.mockito.Mockito;

public class ScoreFunctionBuilderTests extends ESTestCase {

Expand All @@ -39,4 +49,23 @@ public void testIllegalArguments() {
expectThrows(IllegalArgumentException.class, () -> new ExponentialDecayFunctionBuilder("", "", null, ""));
expectThrows(IllegalArgumentException.class, () -> new ExponentialDecayFunctionBuilder("", "", null, "", randomDouble()));
}

/**
* Checks that building a random score function with a seed but without a field
* emits the deprecation warning announcing that a field will be required in 7.0.
*/
public void testRandomScoreFunctionWithSeed() throws Exception {
RandomScoreFunctionBuilder builder = new RandomScoreFunctionBuilder();
// Seed only, deliberately no field: this is the deprecated combination under test.
builder.seed(42);
QueryShardContext context = Mockito.mock(QueryShardContext.class);
// Minimal index settings so the mocked context can answer index/shard/settings lookups.
Settings indexSettings = Settings.builder().put(IndexMetaData.SETTING_VERSION_CREATED, Version.CURRENT)
.put(IndexMetaData.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetaData.SETTING_NUMBER_OF_REPLICAS, 1).build();
IndexSettings settings = new IndexSettings(IndexMetaData.builder("index").settings(indexSettings).build(), Settings.EMPTY);
Mockito.when(context.index()).thenReturn(settings.getIndex());
Mockito.when(context.getShardId()).thenReturn(0);
Mockito.when(context.getIndexSettings()).thenReturn(settings);
MapperService mapperService = Mockito.mock(MapperService.class);
// Any field-name lookup resolves to this mapped long field, so the fallback
// field lookup performed while building the function does not come back null.
MappedFieldType ft = new NumberFieldMapper.NumberFieldType(NumberType.LONG);
ft.setName("foo");
Mockito.when(mapperService.fullName(Mockito.anyString())).thenReturn(ft);
Mockito.when(context.getMapperService()).thenReturn(mapperService);
builder.toFunction(context);
// The message must match the text logged by RandomScoreFunctionBuilder exactly.
assertWarnings("As of version 7.0 Elasticsearch will require that a [field] parameter is provided when a [seed] is set");
}
}
Loading

0 comments on commit f1ff7f2

Please sign in to comment.