Skip to content

Commit

Permalink
Merge pull request dpkp#268 from se7entyse7en/keyed_message
Browse files Browse the repository at this point in the history
Pass key to message sent by `KeyedProducer`
  • Loading branch information
wizzat committed Nov 26, 2014
2 parents 52ec078 + a9e77bd commit 3689529
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 27 deletions.
18 changes: 13 additions & 5 deletions kafka/producer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,
# timeout is reached
while count > 0 and timeout >= 0:
try:
topic_partition, msg = queue.get(timeout=timeout)
topic_partition, msg, key = queue.get(timeout=timeout)

except Empty:
break
Expand All @@ -67,7 +67,7 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,
# Send collected requests upstream
reqs = []
for topic_partition, msg in msgset.items():
messages = create_message_set(msg, codec)
messages = create_message_set(msg, codec, key)
req = ProduceRequest(topic_partition.topic,
topic_partition.partition,
messages)
Expand Down Expand Up @@ -169,6 +169,10 @@ def send_messages(self, topic, partition, *msg):
All messages produced via this method will set the message 'key' to Null
"""
return self._send_messages(topic, partition, *msg)

def _send_messages(self, topic, partition, *msg, **kwargs):
key = kwargs.pop('key', None)

# Guarantee that msg is actually a list or tuple (should always be true)
if not isinstance(msg, (list, tuple)):
Expand All @@ -178,12 +182,16 @@ def send_messages(self, topic, partition, *msg):
if any(not isinstance(m, six.binary_type) for m in msg):
raise TypeError("all produce message payloads must be type bytes")

# Raise TypeError if the key is not encoded as bytes
if key is not None and not isinstance(key, six.binary_type):
raise TypeError("the key must be type bytes")

if self.async:
for m in msg:
self.queue.put((TopicAndPartition(topic, partition), m))
self.queue.put((TopicAndPartition(topic, partition), m, key))
resp = []
else:
messages = create_message_set(msg, self.codec)
messages = create_message_set(msg, self.codec, key)
req = ProduceRequest(topic, partition, messages)
try:
resp = self.client.send_produce_request([req], acks=self.req_acks,
Expand All @@ -199,7 +207,7 @@ def stop(self, timeout=1):
forcefully cleaning up.
"""
if self.async:
self.queue.put((STOP_ASYNC_PRODUCER, None))
self.queue.put((STOP_ASYNC_PRODUCER, None, None))
self.proc.join(timeout)

if self.proc.is_alive():
Expand Down
2 changes: 1 addition & 1 deletion kafka/producer/keyed.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _next_partition(self, topic, key):

def send(self, topic, key, msg):
    """Send ``msg`` to ``topic``, using ``key`` to select the partition.

    The key is forwarded to the produce request (via ``key=key``) so the
    broker stores it alongside the message, rather than being used for
    partition selection only.

    Arguments:
        topic: topic name to produce to
        key: message key; passed to the partitioner and attached to the message
        msg: message payload (bytes)

    Returns:
        the produce response(s) from the underlying ``_send_messages`` call
        (an empty list when running in async mode).
    """
    partition = self._next_partition(topic, key)
    # NOTE: the diff artifact duplicated the old key-less call here; the
    # merged version routes through _send_messages so the key is retained.
    return self._send_messages(topic, partition, msg, key=key)

def __repr__(self):
    """Debug representation showing whether async batching is enabled."""
    batch_flag = getattr(self, 'async')
    return '<KeyedProducer batch=%s>' % batch_flag
8 changes: 4 additions & 4 deletions kafka/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,17 +597,17 @@ def create_snappy_message(payloads, key=None):
return Message(0, 0x00 | codec, key, snapped)


def create_message_set(messages, codec=CODEC_NONE, key=None):
    """Create a message set using the given codec.

    If codec is CODEC_NONE, return a list of raw Kafka messages. Otherwise,
    return a list containing a single codec-encoded message.

    Arguments:
        messages: iterable of message payloads (bytes)
        codec: CODEC_NONE, CODEC_GZIP, or CODEC_SNAPPY
        key: optional message key (bytes) attached to the produced message(s);
             defaults to None so existing key-less callers are unaffected

    Raises:
        UnsupportedCodecError: if ``codec`` is not a recognized codec constant
    """
    if codec == CODEC_NONE:
        # One raw Kafka message per payload, each carrying the same key.
        return [create_message(m, key) for m in messages]
    elif codec == CODEC_GZIP:
        # All payloads are wrapped into a single gzip-compressed message.
        return [create_gzip_message(messages, key)]
    elif codec == CODEC_SNAPPY:
        return [create_snappy_message(messages, key)]
    else:
        raise UnsupportedCodecError("Codec 0x%02x unsupported" % codec)
42 changes: 25 additions & 17 deletions test/test_producer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def test_round_robin_partitioner(self):
start_offset1 = self.current_offset(self.topic, 1)

producer = KeyedProducer(self.client, partitioner=RoundRobinPartitioner)
resp1 = producer.send(self.topic, "key1", self.msg("one"))
resp2 = producer.send(self.topic, "key2", self.msg("two"))
resp3 = producer.send(self.topic, "key3", self.msg("three"))
resp4 = producer.send(self.topic, "key4", self.msg("four"))
resp1 = producer.send(self.topic, self.key("key1"), self.msg("one"))
resp2 = producer.send(self.topic, self.key("key2"), self.msg("two"))
resp3 = producer.send(self.topic, self.key("key3"), self.msg("three"))
resp4 = producer.send(self.topic, self.key("key4"), self.msg("four"))

self.assert_produce_response(resp1, start_offset0+0)
self.assert_produce_response(resp2, start_offset1+0)
Expand All @@ -220,20 +220,28 @@ def test_hashed_partitioner(self):
start_offset1 = self.current_offset(self.topic, 1)

producer = KeyedProducer(self.client, partitioner=HashedPartitioner)
resp1 = producer.send(self.topic, 1, self.msg("one"))
resp2 = producer.send(self.topic, 2, self.msg("two"))
resp3 = producer.send(self.topic, 3, self.msg("three"))
resp4 = producer.send(self.topic, 3, self.msg("four"))
resp5 = producer.send(self.topic, 4, self.msg("five"))
resp1 = producer.send(self.topic, self.key("1"), self.msg("one"))
resp2 = producer.send(self.topic, self.key("2"), self.msg("two"))
resp3 = producer.send(self.topic, self.key("3"), self.msg("three"))
resp4 = producer.send(self.topic, self.key("3"), self.msg("four"))
resp5 = producer.send(self.topic, self.key("4"), self.msg("five"))

self.assert_produce_response(resp1, start_offset1+0)
self.assert_produce_response(resp2, start_offset0+0)
self.assert_produce_response(resp3, start_offset1+1)
self.assert_produce_response(resp4, start_offset1+2)
self.assert_produce_response(resp5, start_offset0+1)
offsets = {0: start_offset0, 1: start_offset1}
messages = {0: [], 1: []}

self.assert_fetch_offset(0, start_offset0, [ self.msg("two"), self.msg("five") ])
self.assert_fetch_offset(1, start_offset1, [ self.msg("one"), self.msg("three"), self.msg("four") ])
keys = [self.key(k) for k in ["1", "2", "3", "3", "4"]]
resps = [resp1, resp2, resp3, resp4, resp5]
msgs = [self.msg(m) for m in ["one", "two", "three", "four", "five"]]

for key, resp, msg in zip(keys, resps, msgs):
k = hash(key) % 2
offset = offsets[k]
self.assert_produce_response(resp, offset)
offsets[k] += 1
messages[k].append(msg)

self.assert_fetch_offset(0, start_offset0, messages[0])
self.assert_fetch_offset(1, start_offset1, messages[1])

producer.stop()

Expand Down Expand Up @@ -393,7 +401,7 @@ def test_async_keyed_producer(self):

producer = KeyedProducer(self.client, partitioner = RoundRobinPartitioner, async=True)

resp = producer.send(self.topic, "key1", self.msg("one"))
resp = producer.send(self.topic, self.key("key1"), self.msg("one"))
self.assertEquals(len(resp), 0)

self.assert_fetch_offset(0, start_offset0, [ self.msg("one") ])
Expand Down
4 changes: 4 additions & 0 deletions test/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def msg(self, s):

return self._messages[s].encode('utf-8')

def key(self, k):
    """Encode a test key string into the UTF-8 bytes Kafka expects."""
    encoded = k.encode('utf-8')
    return encoded


class Timer(object):
def __enter__(self):
self.start = time.time()
Expand Down

0 comments on commit 3689529

Please sign in to comment.