Skip to content

Commit

Permalink
pool: support CountMany() using hyperloglog.
Browse files Browse the repository at this point in the history
  • Loading branch information
fiatjaf committed Nov 16, 2024
1 parent 99e4503 commit 0d40b40
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 19 deletions.
13 changes: 12 additions & 1 deletion envelopes.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nostr

import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"strconv"
Expand Down Expand Up @@ -142,7 +143,8 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) {
type CountEnvelope struct {
SubscriptionID string
Filters
Count *int64
Count *int64
HyperLogLog []byte
}

func (_ CountEnvelope) Label() string { return "COUNT" }
Expand All @@ -161,9 +163,11 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {

var countResult struct {
Count *int64 `json:"count"`
HLL string `json:"hll"`
}
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
v.Count = countResult.Count
v.HyperLogLog, _ = hex.DecodeString(countResult.HLL)
return nil
}

Expand All @@ -189,6 +193,13 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
if v.Count != nil {
w.RawString(`,{"count":`)
w.RawString(strconv.FormatInt(*v.Count, 10))
if v.HyperLogLog != nil {
w.RawString(`,"hll":"`)
hllHex := make([]byte, 0, 512)
hex.Encode(hllHex, v.HyperLogLog)
w.Buffer.AppendBytes(hllHex)
w.RawString(`"`)
}
w.RawString(`}`)
} else {
for _, filter := range v.Filters {
Expand Down
2 changes: 1 addition & 1 deletion nip45/helpers.go → nip45/hyperloglog/helpers.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package nip45
package hyperloglog

import (
"math"
Expand Down
20 changes: 10 additions & 10 deletions nip45/hll.go → nip45/hyperloglog/hll.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package nip45
package hyperloglog

import (
"encoding/binary"
Expand All @@ -18,13 +18,14 @@ func New() *HyperLogLog {
return hll
}

// Encode returns the contents of the register slice as a lowercase
// hexadecimal string, two characters per register byte.
func (hll *HyperLogLog) Encode() string {
	encoded := make([]byte, hex.EncodedLen(len(hll.registers)))
	hex.Encode(encoded, hll.registers)
	return string(encoded)
}

func (hll *HyperLogLog) Decode(enc string) error {
_, err := hex.Decode(hll.registers, []byte(enc))
return err
// GetRegisters returns the register slice held by hll. The slice itself is
// returned, not a copy, so writes through it are visible to hll.
func (hll *HyperLogLog) GetRegisters() []byte {
	registers := hll.registers
	return registers
}
// SetRegisters makes enc the register slice of hll. The slice is stored
// directly; no copy is made, so hll and the caller share the same bytes.
func (hll *HyperLogLog) SetRegisters(enc []byte) {
	hll.registers = enc
}
// MergeRegisters folds a raw register array into hll, keeping for every index
// the larger of the two values.
//
// Only the indexes present in both slices are compared: entries of other that
// lie beyond the length of hll's own register slice are ignored, so an
// oversized input cannot trigger an index-out-of-range panic.
func (hll *HyperLogLog) MergeRegisters(other []byte) {
	if len(other) > len(hll.registers) {
		other = other[:len(hll.registers)]
	}
	for i, v := range other {
		if v > hll.registers[i] {
			hll.registers[i] = v
		}
	}
}

func (hll *HyperLogLog) Clear() {
Expand All @@ -45,13 +46,12 @@ func (hll *HyperLogLog) Add(id string) {
}
}

func (hll *HyperLogLog) Merge(other *HyperLogLog) error {
func (hll *HyperLogLog) Merge(other *HyperLogLog) {
for i, v := range other.registers {
if v > hll.registers[i] {
hll.registers[i] = v
}
}
return nil
}

func (hll *HyperLogLog) Count() uint64 {
Expand Down
2 changes: 1 addition & 1 deletion nip45/hll_test.go → nip45/hyperloglog/hll_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package nip45
package hyperloglog

import (
"encoding/hex"
Expand Down
34 changes: 34 additions & 0 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"time"

"github.com/nbd-wtf/go-nostr/nip45/hyperloglog"
"github.com/puzpuzpuz/xsync/v3"
)

Expand Down Expand Up @@ -468,6 +469,39 @@ func (pool *SimplePool) subManyEose(
return events
}

// CountMany asks every relay in urls to count the events matching filter and
// merges the HyperLogLog registers carried by their answers into a single
// estimate, which is returned as an int.
//
// One goroutine is started per URL and the call blocks until all of them have
// finished. A relay contributes nothing to the result when EnsureRelay fails,
// when the count request returns an error, or when the answer does not carry
// exactly 256 bytes of HyperLogLog registers.
func (pool *SimplePool) CountMany(
	ctx context.Context,
	urls []string,
	filter Filter,
	opts []SubscriptionOption,
) int {
	hll := hyperloglog.New()

	// MergeRegisters mutates the shared register slice, so merges coming from
	// different relay goroutines have to be serialized.
	var mu sync.Mutex

	wg := sync.WaitGroup{}
	wg.Add(len(urls))
	for _, url := range urls {
		go func(nm string) {
			defer wg.Done()

			// Use the goroutine's own argument rather than the loop variable:
			// before Go 1.22 the loop variable is shared by all iterations, so
			// the closure could observe a different URL than the one it was
			// started for.
			relay, err := pool.EnsureRelay(nm)
			if err != nil {
				return
			}
			ce, err := relay.countInternal(ctx, Filters{filter}, opts...)
			if err != nil {
				return
			}
			if len(ce.HyperLogLog) != 256 {
				return
			}

			mu.Lock()
			hll.MergeRegisters(ce.HyperLogLog)
			mu.Unlock()
		}(NormalizeURL(url))
	}

	wg.Wait()
	return int(hll.Count())
}

// QuerySingle returns the first event returned by the first relay, cancels everything else.
func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent {
ctx, cancel := context.WithCancel(ctx)
Expand Down
16 changes: 12 additions & 4 deletions relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
}
case *CountEnvelope:
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
subscription.countResult <- *env.Count
subscription.countResult <- *env
}
case *OKEnvelope:
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
Expand Down Expand Up @@ -478,11 +478,19 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
}

// Count runs countInternal with the given filters and options and returns the
// Count value of the envelope it yields. When countInternal fails, 0 and that
// error are returned.
//
// NOTE(review): the envelope's Count pointer is dereferenced without a nil
// check; this relies on countInternal only ever yielding envelopes whose
// Count is set — confirm against the code that feeds the result channel.
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
	envelope, err := r.countInternal(ctx, filters, opts...)
	if err == nil {
		return *envelope.Count, nil
	}
	return 0, err
}

func (r *Relay) countInternal(ctx context.Context, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
sub := r.PrepareSubscription(ctx, filters, opts...)
sub.countResult = make(chan int64)
sub.countResult = make(chan CountEnvelope)

if err := sub.Fire(); err != nil {
return 0, err
return CountEnvelope{}, err
}

defer sub.Unsub()
Expand All @@ -499,7 +507,7 @@ func (r *Relay) Count(ctx context.Context, filters Filters, opts ...Subscription
case count := <-sub.countResult:
return count, nil
case <-ctx.Done():
return 0, ctx.Err()
return CountEnvelope{}, ctx.Err()
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Subscription struct {
Filters Filters

// for this to be treated as a COUNT and not a REQ this must be set
countResult chan int64
countResult chan CountEnvelope

// the Events channel emits all EVENTs that come in a Subscription
// will be closed when the subscription ends
Expand Down Expand Up @@ -152,7 +152,7 @@ func (sub *Subscription) Fire() error {
if sub.countResult == nil {
reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON()
} else {
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil}.MarshalJSON()
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil, nil}.MarshalJSON()
}

sub.live.Store(true)
Expand Down

0 comments on commit 0d40b40

Please sign in to comment.