diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java index 8455f154c0f8..fc8dcb49894f 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java @@ -22,7 +22,6 @@ import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.metrics.StringSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -101,11 +100,15 @@ public void add(String value) { if (this.setValue.get().stringSet().contains(value)) { return; } - update(StringSetData.create(ImmutableSet.of(value))); + add(new String[] {value}); } @Override public void add(String... values) { - update(StringSetData.create(ImmutableSet.copyOf(values))); + StringSetData original; + do { + original = setValue.get(); + } while (!setValue.compareAndSet(original, original.addAll(values))); + dirty.afterModification(); } } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java index 466d4ad46eb6..4fc5d3beca31 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java @@ -19,25 +19,49 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; +import java.util.Arrays; +import java.util.HashSet; import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; import org.apache.beam.sdk.metrics.StringSetResult; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * Data describing the StringSet. The {@link StringSetData} hold an immutable copy of the set from - * which it was initially created. This should retain enough detail that it can be combined with - * other {@link StringSetData}. + * Data describing the StringSet. The {@link StringSetData} hold a copy of the set from which it was + * initially created. This should retain enough detail that it can be combined with other {@link + * StringSetData}. + * + *

The underlying set is mutable for {@link #addAll} operation, otherwise a copy set will be + * generated. + * + *

The summation of all string length for a {@code StringSetData} cannot exceed 1 MB. Further + * addition of elements are dropped. */ @AutoValue public abstract class StringSetData implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(StringSetData.class); + // 1 MB + @VisibleForTesting static final long STRING_SET_SIZE_LIMIT = 1_000_000L; public abstract Set stringSet(); + public abstract long stringSize(); + /** Returns a {@link StringSetData} which is made from an immutable copy of the given set. */ public static StringSetData create(Set set) { - return new AutoValue_StringSetData(ImmutableSet.copyOf(set)); + if (set.isEmpty()) { + return empty(); + } + HashSet combined = new HashSet<>(); + long stringSize = addUntilCapacity(combined, 0L, set); + return new AutoValue_StringSetData(combined, stringSize); + } + + /** Returns a {@link StringSetData} which is made from the given set in place. */ + private static StringSetData createInPlace(HashSet set, long stringSize) { + return new AutoValue_StringSetData(set, stringSize); } /** Return a {@link EmptyStringSetData#INSTANCE} representing an empty {@link StringSetData}. */ @@ -45,6 +69,23 @@ public static StringSetData empty() { return EmptyStringSetData.INSTANCE; } + /** + * Add strings into this {@code StringSetData} and return the result {@code StringSetData}. Reuse + * the original StringSetData's set. As a result, current StringSetData will become invalid. + * + *

>Should only be used by {@link StringSetCell#add}. + */ + public StringSetData addAll(String... strings) { + HashSet combined; + if (this.stringSet() instanceof HashSet) { + combined = (HashSet) this.stringSet(); + } else { + combined = new HashSet<>(this.stringSet()); + } + long stringSize = addUntilCapacity(combined, this.stringSize(), Arrays.asList(strings)); + return StringSetData.createInPlace(combined, stringSize); + } + /** * Combines this {@link StringSetData} with other, both original StringSetData are left intact. */ @@ -54,10 +95,9 @@ public StringSetData combine(StringSetData other) { } else if (other.stringSet().isEmpty()) { return this; } else { - ImmutableSet.Builder combined = ImmutableSet.builder(); - combined.addAll(this.stringSet()); - combined.addAll(other.stringSet()); - return StringSetData.create(combined.build()); + HashSet combined = new HashSet<>(this.stringSet()); + long stringSize = addUntilCapacity(combined, this.stringSize(), other.stringSet()); + return StringSetData.createInPlace(combined, stringSize); } } @@ -65,12 +105,12 @@ public StringSetData combine(StringSetData other) { * Combines this {@link StringSetData} with others, all original StringSetData are left intact. */ public StringSetData combine(Iterable others) { - Set combined = - StreamSupport.stream(others.spliterator(), true) - .flatMap(other -> other.stringSet().stream()) - .collect(Collectors.toSet()); - combined.addAll(this.stringSet()); - return StringSetData.create(combined); + HashSet combined = new HashSet<>(this.stringSet()); + long stringSize = this.stringSize(); + for (StringSetData other : others) { + stringSize = addUntilCapacity(combined, stringSize, other.stringSet()); + } + return StringSetData.createInPlace(combined, stringSize); } /** Returns a {@link StringSetResult} representing this {@link StringSetData}. */ @@ -78,6 +118,31 @@ public StringSetResult extractResult() { return StringSetResult.create(stringSet()); } + /** Add strings into set until reach capacity. Return the all string size of added set. */ + private static long addUntilCapacity( + HashSet combined, long currentSize, Iterable others) { + if (currentSize > STRING_SET_SIZE_LIMIT) { + // already at capacity + return currentSize; + } + for (String string : others) { + if (combined.add(string)) { + currentSize += string.length(); + + // check capacity both before insert and after insert one, so the warning only emit once. + if (currentSize > STRING_SET_SIZE_LIMIT) { + LOG.warn( + "StringSet metrics reaches capacity. Further incoming elements won't be recorded." + + " Current size: {}, last element size: {}.", + currentSize, + string.length()); + break; + } + } + } + return currentSize; + } + /** Empty {@link StringSetData}, representing no values reported and is immutable. */ public static class EmptyStringSetData extends StringSetData { @@ -91,6 +156,11 @@ public Set stringSet() { return ImmutableSet.of(); } + @Override + public long stringSize() { + return 0L; + } + /** Return a {@link StringSetResult#empty()} which is immutable empty set. */ @Override public StringSetResult extractResult() { diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java index 665ce3743c51..534db203ff3c 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.util.Collections; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Rule; import org.junit.Test; @@ -81,6 +82,14 @@ public void testStringSetDataEmptyIsImmutable() { assertThrows(UnsupportedOperationException.class, () -> empty.stringSet().add("aa")); } + @Test + public void testStringSetDataEmptyCanAdd() { + ImmutableSet contents = ImmutableSet.of("ab", "cd"); + StringSetData stringSetData = StringSetData.empty(); + stringSetData = stringSetData.addAll(contents.toArray(new String[] {})); + assertEquals(stringSetData.stringSet(), contents); + } + @Test public void testEmptyExtract() { assertTrue(StringSetData.empty().extractResult().getStringSet().isEmpty()); @@ -94,9 +103,26 @@ public void testExtract() { } @Test - public void testExtractReturnsImmutable() { - StringSetData stringSetData = StringSetData.create(ImmutableSet.of("ab", "cd")); - // check that immutable copy is returned - assertThrows(UnsupportedOperationException.class, () -> stringSetData.stringSet().add("aa")); + public void testStringSetAddUntilCapacity() { + StringSetData combined = StringSetData.empty(); + @SuppressWarnings("InlineMeInliner") // Inline representation is Java11+ only + String commonPrefix = Strings.repeat("*", 1000); + long stringSize = 0; + for (int i = 0; i < 1000; ++i) { + String s = commonPrefix + i; + stringSize += s.length(); + combined = combined.addAll(s); + } + assertTrue(combined.stringSize() < stringSize); + assertTrue(combined.stringSize() > StringSetData.STRING_SET_SIZE_LIMIT); + } + + @Test + public void testStringSetAddSizeTrackedCorrectly() { + StringSetData combined = StringSetData.empty(); + combined = combined.addAll("a", "b", "c", "b"); + assertEquals(3, combined.stringSize()); + combined = combined.addAll("c", "d", "e"); + assertEquals(5, combined.stringSize()); } } diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd index a8f4003d8980..98bb5eff0977 100644 --- a/sdks/python/apache_beam/metrics/cells.pxd +++ b/sdks/python/apache_beam/metrics/cells.pxd @@ -45,7 +45,7 @@ cdef class GaugeCell(MetricCell): cdef class StringSetCell(MetricCell): - cdef readonly set data + cdef readonly object data cdef inline bint _update(self, value) except -1 diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index 407106342fb8..63fc9f3f7cc9 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -23,11 +23,14 @@ # pytype: skip-file +import logging import threading import time from datetime import datetime from typing import Any +from typing import Iterable from typing import Optional +from typing import Set from typing import SupportsInt try: @@ -47,6 +50,8 @@ class fake_cython: 'GaugeResult' ] +_LOGGER = logging.getLogger(__name__) + class MetricCell(object): """For internal use only; no backwards-compatibility guarantees. @@ -297,9 +302,9 @@ def _update(self, value): self.data.add(value) def get_cumulative(self): - # type: () -> set + # type: () -> StringSetData with self._lock: - return set(self.data) + return self.data.get_cumulative() def combine(self, other): # type: (StringSetCell) -> StringSetCell @@ -522,6 +527,98 @@ def singleton(value): return DistributionData(value, 1, value, value) +class StringSetData(object): + """For internal use only; no backwards-compatibility guarantees. + + The data structure that holds data about a StringSet metric. + + StringSet metrics are restricted to set of strings only. + + This object is not thread safe, so it's not supposed to be modified + by other than the StringSetCell that contains it. + + The summation of all string length for a StringSetData cannot exceed 1 MB. + Further addition of elements are dropped. + """ + + _STRING_SET_SIZE_LIMIT = 1_000_000 + + def __init__(self, string_set: Optional[Set] = None, string_size: int = 0): + self.string_set = string_set or set() + if not string_size: + string_size = 0 + for s in self.string_set: + string_size += len(s) + self.string_size = string_size + + def __eq__(self, other: object) -> bool: + if isinstance(other, StringSetData): + return ( + self.string_size == other.string_size and + self.string_set == other.string_set) + else: + return False + + def __hash__(self) -> int: + return hash(self.string_set) + + def __repr__(self) -> str: + return 'StringSetData{}:{}'.format(self.string_set, self.string_size) + + def get_cumulative(self) -> "StringSetData": + return StringSetData(set(self.string_set), self.string_size) + + def add(self, *strings): + """ + Add strings into this StringSetData and return the result StringSetData. + Reuse the original StringSetData's set. + """ + self.string_size = self.add_until_capacity( + self.string_set, self.string_size, strings) + return self + + def combine(self, other: "StringSetData") -> "StringSetData": + """ + Combines this StringSetData with other, both original StringSetData are left + intact. + """ + if other is None: + return self + + combined = set(self.string_set) + string_size = self.add_until_capacity( + combined, self.string_size, other.string_set) + return StringSetData(combined, string_size) + + @classmethod + def add_until_capacity( + cls, combined: set, current_size: int, others: Iterable[str]): + """ + Add strings into set until reach capacity. Return the all string size of + added set. + """ + if current_size > cls._STRING_SET_SIZE_LIMIT: + return current_size + + for string in others: + if string not in combined: + combined.add(string) + current_size += len(string) + if current_size > cls._STRING_SET_SIZE_LIMIT: + _LOGGER.warning( + "StringSet metrics reaches capacity. Further incoming elements " + "won't be recorded. Current size: %d, last element size: %d.", + current_size, + len(string)) + break + return current_size + + @staticmethod + def singleton(value): + # type: (int) -> DistributionData + return DistributionData(value, 1, value, value) + + class MetricAggregator(object): """For internal use only; no backwards-compatibility guarantees. @@ -612,17 +709,18 @@ def result(self, x): class StringSetAggregator(MetricAggregator): @staticmethod def identity_element(): - # type: () -> set - return set() + # type: () -> StringSetData + return StringSetData() def combine(self, x, y): - # type: (set, set) -> set - if len(x) == 0: + # type: (StringSetData, StringSetData) -> StringSetData + if len(x.string_set) == 0: return y - elif len(y) == 0: + elif len(y.string_set) == 0: return x else: - return set.union(x, y) + return x.combine(y) def result(self, x): - return x + # type: (StringSetData) -> set + return set(x.string_set) diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py index 052ff051bf96..d1ee37b8ed82 100644 --- a/sdks/python/apache_beam/metrics/cells_test.py +++ b/sdks/python/apache_beam/metrics/cells_test.py @@ -26,6 +26,7 @@ from apache_beam.metrics.cells import GaugeCell from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import StringSetCell +from apache_beam.metrics.cells import StringSetData from apache_beam.metrics.metricbase import MetricName @@ -176,9 +177,9 @@ def test_not_leak_mutable_set(self): c.add('test') c.add('another') s = c.get_cumulative() - self.assertEqual(s, set(('test', 'another'))) + self.assertEqual(s, StringSetData({'test', 'another'}, 11)) s.add('yet another') - self.assertEqual(c.get_cumulative(), set(('test', 'another'))) + self.assertEqual(c.get_cumulative(), StringSetData({'test', 'another'}, 11)) def test_combine_appropriately(self): s1 = StringSetCell() @@ -190,7 +191,16 @@ def test_combine_appropriately(self): s2.add('3') result = s2.combine(s1) - self.assertEqual(result.data, set(('1', '2', '3'))) + self.assertEqual(result.data, StringSetData({'1', '2', '3'})) + + def test_add_size_tracked_correctly(self): + s = StringSetCell() + s.add('1') + s.add('2') + self.assertEqual(s.data.string_size, 2) + s.add('2') + s.add('3') + self.assertEqual(s.data.string_size, 3) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index 37007add9163..fa70d3a4d9c0 100644 --- a/sdks/python/apache_beam/metrics/execution.py +++ b/sdks/python/apache_beam/metrics/execution.py @@ -47,6 +47,7 @@ from apache_beam.metrics.cells import DistributionCell from apache_beam.metrics.cells import GaugeCell from apache_beam.metrics.cells import StringSetCell +from apache_beam.metrics.cells import StringSetData from apache_beam.runners.worker import statesampler from apache_beam.runners.worker.statesampler import get_current_tracker @@ -356,7 +357,7 @@ def __init__( counters=None, # type: Optional[Dict[MetricKey, int]] distributions=None, # type: Optional[Dict[MetricKey, DistributionData]] gauges=None, # type: Optional[Dict[MetricKey, GaugeData]] - string_sets=None, # type: Optional[Dict[MetricKey, set]] + string_sets=None, # type: Optional[Dict[MetricKey, StringSetData]] ): # type: (...) -> None diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py index b157aeb20e9e..38e27f1f3d0c 100644 --- a/sdks/python/apache_beam/metrics/execution_test.py +++ b/sdks/python/apache_beam/metrics/execution_test.py @@ -110,11 +110,12 @@ def test_get_cumulative_or_updates(self): self.assertEqual( set(all_values), {v.value for _, v in cumulative.gauges.items()}) - self.assertEqual({str(i % 7) - for i in all_values}, - functools.reduce( - set.union, - (v for _, v in cumulative.string_sets.items()))) + self.assertEqual( + {str(i % 7) + for i in all_values}, + functools.reduce( + set.union, + (v.string_set for _, v in cumulative.string_sets.items()))) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py index a9540f2846ad..09cb350b3826 100644 --- a/sdks/python/apache_beam/metrics/monitoring_infos.py +++ b/sdks/python/apache_beam/metrics/monitoring_infos.py @@ -31,6 +31,7 @@ from apache_beam.metrics.cells import DistributionResult from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import GaugeResult +from apache_beam.metrics.cells import StringSetData from apache_beam.portability import common_urns from apache_beam.portability.api import metrics_pb2 @@ -305,10 +306,12 @@ def user_set_string(namespace, name, metric, ptransform=None): Args: namespace: User-defined namespace of StringSet. name: Name of StringSet. - metric: The set representing the metrics. + metric: The StringSetData representing the metrics. ptransform: The ptransform id used as a label. """ labels = create_labels(ptransform=ptransform, namespace=namespace, name=name) + if isinstance(metric, StringSetData): + metric = metric.string_set if isinstance(metric, set): metric = list(metric) if isinstance(metric, list):