diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java
index 8455f154c0f8..fc8dcb49894f 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetCell.java
@@ -22,7 +22,6 @@
import org.apache.beam.sdk.metrics.MetricName;
import org.apache.beam.sdk.metrics.MetricsContainer;
import org.apache.beam.sdk.metrics.StringSet;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.checkerframework.checker.nullness.qual.Nullable;
/**
@@ -101,11 +100,15 @@ public void add(String value) {
if (this.setValue.get().stringSet().contains(value)) {
return;
}
- update(StringSetData.create(ImmutableSet.of(value)));
+ add(new String[] {value});
}
@Override
public void add(String... values) {
- update(StringSetData.create(ImmutableSet.copyOf(values)));
+ StringSetData original;
+ do {
+ original = setValue.get();
+ } while (!setValue.compareAndSet(original, original.addAll(values)));
+ dirty.afterModification();
}
}
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java
index 466d4ad46eb6..4fc5d3beca31 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java
@@ -19,25 +19,49 @@
import com.google.auto.value.AutoValue;
import java.io.Serializable;
+import java.util.Arrays;
+import java.util.HashSet;
import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.StreamSupport;
import org.apache.beam.sdk.metrics.StringSetResult;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
- * Data describing the StringSet. The {@link StringSetData} hold an immutable copy of the set from
- * which it was initially created. This should retain enough detail that it can be combined with
- * other {@link StringSetData}.
+ * Data describing the StringSet. The {@link StringSetData} holds a copy of the set from which it
+ * was initially created. This should retain enough detail that it can be combined with other
+ * {@link StringSetData}.
+ *
+ * <p>The underlying set is mutated in place by the {@link #addAll} operation; every other
+ * operation works on a copy of the set.
+ *
+ * <p>The sum of the lengths of all strings in a {@code StringSetData} is capped at about 1 MB:
+ * once the total exceeds the limit, further elements are dropped.
*/
@AutoValue
public abstract class StringSetData implements Serializable {
+ private static final Logger LOG = LoggerFactory.getLogger(StringSetData.class);
+ // 1 MB
+ @VisibleForTesting static final long STRING_SET_SIZE_LIMIT = 1_000_000L;
public abstract Set stringSet();
+ public abstract long stringSize();
+
/** Returns a {@link StringSetData} which is made from an immutable copy of the given set. */
public static StringSetData create(Set set) {
- return new AutoValue_StringSetData(ImmutableSet.copyOf(set));
+ if (set.isEmpty()) {
+ return empty();
+ }
+ HashSet combined = new HashSet<>();
+ long stringSize = addUntilCapacity(combined, 0L, set);
+ return new AutoValue_StringSetData(combined, stringSize);
+ }
+
+ /** Returns a {@link StringSetData} which is made from the given set in place. */
+ private static StringSetData createInPlace(HashSet set, long stringSize) {
+ return new AutoValue_StringSetData(set, stringSize);
}
/** Return a {@link EmptyStringSetData#INSTANCE} representing an empty {@link StringSetData}. */
@@ -45,6 +69,23 @@ public static StringSetData empty() {
return EmptyStringSetData.INSTANCE;
}
+ /**
+ * Adds strings to this {@code StringSetData} and returns the result. A backing {@link HashSet}
+ * is reused and mutated in place, so this instance becomes invalid; other sets are copied first.
+ *
+ * <p>Should only be used by {@link StringSetCell#add}.
+ */
+ public StringSetData addAll(String... strings) {
+ HashSet combined;
+ if (this.stringSet() instanceof HashSet) {
+ combined = (HashSet) this.stringSet();
+ } else {
+ combined = new HashSet<>(this.stringSet());
+ }
+ long stringSize = addUntilCapacity(combined, this.stringSize(), Arrays.asList(strings));
+ return StringSetData.createInPlace(combined, stringSize);
+ }
+
/**
* Combines this {@link StringSetData} with other, both original StringSetData are left intact.
*/
@@ -54,10 +95,9 @@ public StringSetData combine(StringSetData other) {
} else if (other.stringSet().isEmpty()) {
return this;
} else {
- ImmutableSet.Builder combined = ImmutableSet.builder();
- combined.addAll(this.stringSet());
- combined.addAll(other.stringSet());
- return StringSetData.create(combined.build());
+ HashSet combined = new HashSet<>(this.stringSet());
+ long stringSize = addUntilCapacity(combined, this.stringSize(), other.stringSet());
+ return StringSetData.createInPlace(combined, stringSize);
}
}
@@ -65,12 +105,12 @@ public StringSetData combine(StringSetData other) {
* Combines this {@link StringSetData} with others, all original StringSetData are left intact.
*/
public StringSetData combine(Iterable others) {
- Set combined =
- StreamSupport.stream(others.spliterator(), true)
- .flatMap(other -> other.stringSet().stream())
- .collect(Collectors.toSet());
- combined.addAll(this.stringSet());
- return StringSetData.create(combined);
+ HashSet combined = new HashSet<>(this.stringSet());
+ long stringSize = this.stringSize();
+ for (StringSetData other : others) {
+ stringSize = addUntilCapacity(combined, stringSize, other.stringSet());
+ }
+ return StringSetData.createInPlace(combined, stringSize);
}
/** Returns a {@link StringSetResult} representing this {@link StringSetData}. */
@@ -78,6 +118,31 @@ public StringSetResult extractResult() {
return StringSetResult.create(stringSet());
}
+ /** Adds strings to the set until the size limit is crossed. Returns the updated total size. */
+ private static long addUntilCapacity(
+ HashSet combined, long currentSize, Iterable others) {
+ if (currentSize > STRING_SET_SIZE_LIMIT) {
+ // already over capacity: nothing is added and no further warning is logged
+ return currentSize;
+ }
+ for (String string : others) {
+ if (combined.add(string)) {
+ currentSize += string.length();
+
+ // Checked before the loop and after each insert, so the warning is logged only on crossing.
+ if (currentSize > STRING_SET_SIZE_LIMIT) {
+ LOG.warn(
+ "StringSet metrics reaches capacity. Further incoming elements won't be recorded."
+ + " Current size: {}, last element size: {}.",
+ currentSize,
+ string.length());
+ break;
+ }
+ }
+ }
+ return currentSize;
+ }
+
/** Empty {@link StringSetData}, representing no values reported and is immutable. */
public static class EmptyStringSetData extends StringSetData {
@@ -91,6 +156,11 @@ public Set stringSet() {
return ImmutableSet.of();
}
+ @Override
+ public long stringSize() {
+ return 0L;
+ }
+
/** Return a {@link StringSetResult#empty()} which is immutable empty set. */
@Override
public StringSetResult extractResult() {
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java
index 665ce3743c51..534db203ff3c 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetDataTest.java
@@ -22,6 +22,7 @@
import static org.junit.Assert.assertTrue;
import java.util.Collections;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.junit.Rule;
import org.junit.Test;
@@ -81,6 +82,14 @@ public void testStringSetDataEmptyIsImmutable() {
assertThrows(UnsupportedOperationException.class, () -> empty.stringSet().add("aa"));
}
+ @Test
+ public void testStringSetDataEmptyCanAdd() {
+ ImmutableSet contents = ImmutableSet.of("ab", "cd");
+ StringSetData stringSetData = StringSetData.empty();
+ stringSetData = stringSetData.addAll(contents.toArray(new String[] {}));
+ assertEquals(stringSetData.stringSet(), contents);
+ }
+
@Test
public void testEmptyExtract() {
assertTrue(StringSetData.empty().extractResult().getStringSet().isEmpty());
@@ -94,9 +103,26 @@ public void testExtract() {
}
@Test
- public void testExtractReturnsImmutable() {
- StringSetData stringSetData = StringSetData.create(ImmutableSet.of("ab", "cd"));
- // check that immutable copy is returned
- assertThrows(UnsupportedOperationException.class, () -> stringSetData.stringSet().add("aa"));
+ public void testStringSetAddUntilCapacity() {
+ StringSetData combined = StringSetData.empty();
+ @SuppressWarnings("InlineMeInliner") // Inline representation is Java11+ only
+ String commonPrefix = Strings.repeat("*", 1000);
+ long stringSize = 0;
+ for (int i = 0; i < 1000; ++i) {
+ String s = commonPrefix + i;
+ stringSize += s.length();
+ combined = combined.addAll(s);
+ }
+ assertTrue(combined.stringSize() < stringSize);
+ assertTrue(combined.stringSize() > StringSetData.STRING_SET_SIZE_LIMIT);
+ }
+
+ @Test
+ public void testStringSetAddSizeTrackedCorrectly() {
+ StringSetData combined = StringSetData.empty();
+ combined = combined.addAll("a", "b", "c", "b");
+ assertEquals(3, combined.stringSize());
+ combined = combined.addAll("c", "d", "e");
+ assertEquals(5, combined.stringSize());
}
}
diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd
index a8f4003d8980..98bb5eff0977 100644
--- a/sdks/python/apache_beam/metrics/cells.pxd
+++ b/sdks/python/apache_beam/metrics/cells.pxd
@@ -45,7 +45,7 @@ cdef class GaugeCell(MetricCell):
cdef class StringSetCell(MetricCell):
- cdef readonly set data
+ cdef readonly object data
cdef inline bint _update(self, value) except -1
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index 407106342fb8..63fc9f3f7cc9 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -23,11 +23,14 @@
# pytype: skip-file
+import logging
import threading
import time
from datetime import datetime
from typing import Any
+from typing import Iterable
from typing import Optional
+from typing import Set
from typing import SupportsInt
try:
@@ -47,6 +50,8 @@ class fake_cython:
'GaugeResult'
]
+_LOGGER = logging.getLogger(__name__)
+
class MetricCell(object):
"""For internal use only; no backwards-compatibility guarantees.
@@ -297,9 +302,9 @@ def _update(self, value):
self.data.add(value)
def get_cumulative(self):
- # type: () -> set
+ # type: () -> StringSetData
with self._lock:
- return set(self.data)
+ return self.data.get_cumulative()
def combine(self, other):
# type: (StringSetCell) -> StringSetCell
@@ -522,6 +527,98 @@ def singleton(value):
return DistributionData(value, 1, value, value)
+class StringSetData(object):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  The data structure that holds data about a StringSet metric.
+
+  StringSet metrics are restricted to sets of strings only.
+
+  This object is not thread safe, so it is not supposed to be modified
+  by anything other than the StringSetCell that contains it.
+
+  The total length of all strings in a StringSetData is capped at about 1 MB.
+  Once the total exceeds the limit, further elements are dropped.
+  """
+
+  _STRING_SET_SIZE_LIMIT = 1_000_000
+
+  def __init__(self, string_set: Optional[Set] = None, string_size: int = 0):
+    self.string_set = string_set or set()
+    if not string_size:
+      string_size = 0
+      for s in self.string_set:
+        string_size += len(s)
+    self.string_size = string_size
+
+  def __eq__(self, other: object) -> bool:
+    if isinstance(other, StringSetData):
+      return (
+          self.string_size == other.string_size and
+          self.string_set == other.string_set)
+    else:
+      return False
+
+  def __hash__(self) -> int:
+    return hash(frozenset(self.string_set))  # a plain set is unhashable
+
+  def __repr__(self) -> str:
+    return 'StringSetData{}:{}'.format(self.string_set, self.string_size)
+
+  def get_cumulative(self) -> "StringSetData":
+    return StringSetData(set(self.string_set), self.string_size)
+
+  def add(self, *strings):
+    """
+    Adds strings to this StringSetData in place and returns this same object.
+    The underlying set is reused, not copied.
+    """
+    self.string_size = self.add_until_capacity(
+        self.string_set, self.string_size, strings)
+    return self
+
+  def combine(self, other: "StringSetData") -> "StringSetData":
+    """
+    Combines this StringSetData with other into a new StringSetData; both
+    originals are left intact. Returns self unchanged when other is None.
+    """
+    if other is None:
+      return self
+
+    combined = set(self.string_set)
+    string_size = self.add_until_capacity(
+        combined, self.string_size, other.string_set)
+    return StringSetData(combined, string_size)
+
+  @classmethod
+  def add_until_capacity(
+      cls, combined: set, current_size: int, others: Iterable[str]):
+    """
+    Adds strings to ``combined`` in place until the size limit is crossed.
+    Returns the updated total string size of the set.
+    """
+    if current_size > cls._STRING_SET_SIZE_LIMIT:
+      return current_size
+
+    for string in others:
+      if string not in combined:
+        combined.add(string)
+        current_size += len(string)
+        if current_size > cls._STRING_SET_SIZE_LIMIT:
+          _LOGGER.warning(
+              "StringSet metrics reaches capacity. Further incoming elements "
+              "won't be recorded. Current size: %d, last element size: %d.",
+              current_size,
+              len(string))
+          break
+    return current_size
+
+  @staticmethod
+  def singleton(value):
+    # type: (str) -> StringSetData
+    return StringSetData({value}, len(value))
+
+
class MetricAggregator(object):
"""For internal use only; no backwards-compatibility guarantees.
@@ -612,17 +709,18 @@ def result(self, x):
class StringSetAggregator(MetricAggregator):
@staticmethod
def identity_element():
- # type: () -> set
- return set()
+ # type: () -> StringSetData
+ return StringSetData()
def combine(self, x, y):
- # type: (set, set) -> set
- if len(x) == 0:
+ # type: (StringSetData, StringSetData) -> StringSetData
+ if len(x.string_set) == 0:
return y
- elif len(y) == 0:
+ elif len(y.string_set) == 0:
return x
else:
- return set.union(x, y)
+ return x.combine(y)
def result(self, x):
- return x
+ # type: (StringSetData) -> set
+ return set(x.string_set)
diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py
index 052ff051bf96..d1ee37b8ed82 100644
--- a/sdks/python/apache_beam/metrics/cells_test.py
+++ b/sdks/python/apache_beam/metrics/cells_test.py
@@ -26,6 +26,7 @@
from apache_beam.metrics.cells import GaugeCell
from apache_beam.metrics.cells import GaugeData
from apache_beam.metrics.cells import StringSetCell
+from apache_beam.metrics.cells import StringSetData
from apache_beam.metrics.metricbase import MetricName
@@ -176,9 +177,9 @@ def test_not_leak_mutable_set(self):
c.add('test')
c.add('another')
s = c.get_cumulative()
- self.assertEqual(s, set(('test', 'another')))
+ self.assertEqual(s, StringSetData({'test', 'another'}, 11))
s.add('yet another')
- self.assertEqual(c.get_cumulative(), set(('test', 'another')))
+ self.assertEqual(c.get_cumulative(), StringSetData({'test', 'another'}, 11))
def test_combine_appropriately(self):
s1 = StringSetCell()
@@ -190,7 +191,16 @@ def test_combine_appropriately(self):
s2.add('3')
result = s2.combine(s1)
- self.assertEqual(result.data, set(('1', '2', '3')))
+ self.assertEqual(result.data, StringSetData({'1', '2', '3'}))
+
+ def test_add_size_tracked_correctly(self):
+ s = StringSetCell()
+ s.add('1')
+ s.add('2')
+ self.assertEqual(s.data.string_size, 2)
+ s.add('2')
+ s.add('3')
+ self.assertEqual(s.data.string_size, 3)
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 37007add9163..fa70d3a4d9c0 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -47,6 +47,7 @@
from apache_beam.metrics.cells import DistributionCell
from apache_beam.metrics.cells import GaugeCell
from apache_beam.metrics.cells import StringSetCell
+from apache_beam.metrics.cells import StringSetData
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.statesampler import get_current_tracker
@@ -356,7 +357,7 @@ def __init__(
counters=None, # type: Optional[Dict[MetricKey, int]]
distributions=None, # type: Optional[Dict[MetricKey, DistributionData]]
gauges=None, # type: Optional[Dict[MetricKey, GaugeData]]
- string_sets=None, # type: Optional[Dict[MetricKey, set]]
+ string_sets=None, # type: Optional[Dict[MetricKey, StringSetData]]
):
# type: (...) -> None
diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py
index b157aeb20e9e..38e27f1f3d0c 100644
--- a/sdks/python/apache_beam/metrics/execution_test.py
+++ b/sdks/python/apache_beam/metrics/execution_test.py
@@ -110,11 +110,12 @@ def test_get_cumulative_or_updates(self):
self.assertEqual(
set(all_values), {v.value
for _, v in cumulative.gauges.items()})
- self.assertEqual({str(i % 7)
- for i in all_values},
- functools.reduce(
- set.union,
- (v for _, v in cumulative.string_sets.items())))
+ self.assertEqual(
+ {str(i % 7)
+ for i in all_values},
+ functools.reduce(
+ set.union,
+ (v.string_set for _, v in cumulative.string_sets.items())))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py
index a9540f2846ad..09cb350b3826 100644
--- a/sdks/python/apache_beam/metrics/monitoring_infos.py
+++ b/sdks/python/apache_beam/metrics/monitoring_infos.py
@@ -31,6 +31,7 @@
from apache_beam.metrics.cells import DistributionResult
from apache_beam.metrics.cells import GaugeData
from apache_beam.metrics.cells import GaugeResult
+from apache_beam.metrics.cells import StringSetData
from apache_beam.portability import common_urns
from apache_beam.portability.api import metrics_pb2
@@ -305,10 +306,12 @@ def user_set_string(namespace, name, metric, ptransform=None):
Args:
namespace: User-defined namespace of StringSet.
name: Name of StringSet.
- metric: The set representing the metrics.
+ metric: The StringSetData representing the metrics.
ptransform: The ptransform id used as a label.
"""
labels = create_labels(ptransform=ptransform, namespace=namespace, name=name)
+ if isinstance(metric, StringSetData):
+ metric = metric.string_set
if isinstance(metric, set):
metric = list(metric)
if isinstance(metric, list):