diff --git a/common/component/postgresql/postgresql_query.go b/common/component/postgresql/postgresql_query.go index 3a59e8ff55..8dd4dc4d15 100644 --- a/common/component/postgresql/postgresql_query.go +++ b/common/component/postgresql/postgresql_query.go @@ -40,6 +40,46 @@ func (q *Query) VisitEQ(f *query.EQ) (string, error) { return q.whereFieldEqual(f.Key, f.Val), nil } +func (q *Query) VisitNEQ(f *query.NEQ) (string, error) { + return q.whereFieldNotEqual(f.Key, f.Val), nil +} + +func (q *Query) VisitGT(f *query.GT) (string, error) { + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", v) + default: + return q.whereFieldGreaterThan(f.Key, v), nil + } +} + +func (q *Query) VisitGTE(f *query.GTE) (string, error) { + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", v) + default: + return q.whereFieldGreaterThanEqual(f.Key, v), nil + } +} + +func (q *Query) VisitLT(f *query.LT) (string, error) { + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", v) + default: + return q.whereFieldLessThan(f.Key, v), nil + } +} + +func (q *Query) VisitLTE(f *query.LTE) (string, error) { + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", v) + default: + return q.whereFieldLessThanEqual(f.Key, v), nil + } +} + func (q *Query) VisitIN(f *query.IN) (string, error) { if len(f.Vals) == 0 { return "", fmt.Errorf("empty IN operator for key %q", f.Key) @@ -70,6 +110,31 @@ func (q *Query) visitFilters(op string, filters []query.Filter) (string, error) return "", err } arr = append(arr, str) + case *query.NEQ: + if str, err = q.VisitNEQ(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.GT: + if str, err = q.VisitGT(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.GTE: + if str, err = q.VisitGTE(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.LT: + if str, err = q.VisitLT(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.LTE: + if str, err = q.VisitLTE(f); err != nil { + return "", err + } + arr = append(arr, str) case *query.IN: if str, err = q.VisitIN(f); err != nil { return "", err @@ -214,3 +279,38 @@ func (q *Query) whereFieldEqual(key string, value interface{}) string { query := filterField + "=$" + strconv.Itoa(position) return query } + +func (q *Query) whereFieldNotEqual(key string, value interface{}) string { + position := q.addParamValueAndReturnPosition(value) + filterField := translateFieldToFilter(key) + query := filterField + "!=$" + strconv.Itoa(position) + return query +} + +func (q *Query) whereFieldGreaterThan(key string, value interface{}) string { + position := q.addParamValueAndReturnPosition(value) + filterField := translateFieldToFilter(key) + query := filterField + ">$" + strconv.Itoa(position) + return query +} + +func (q *Query) whereFieldGreaterThanEqual(key string, value interface{}) string { + position := q.addParamValueAndReturnPosition(value) + filterField := translateFieldToFilter(key) + query := filterField + ">=$" + strconv.Itoa(position) + return query +} + +func (q *Query) whereFieldLessThan(key string, value interface{}) string { + position := q.addParamValueAndReturnPosition(value) + filterField := translateFieldToFilter(key) + query := filterField + "<$" + strconv.Itoa(position) + return query +} + +func (q *Query) whereFieldLessThanEqual(key string, value interface{}) string { + position := q.addParamValueAndReturnPosition(value) + filterField := translateFieldToFilter(key) + query := filterField + "<=$" + strconv.Itoa(position) + return query +} diff --git a/common/component/postgresql/postgresql_query_test.go b/common/component/postgresql/postgresql_query_test.go index ac88e211c5..1ddd78ac7e 100644 --- a/common/component/postgresql/postgresql_query_test.go +++ b/common/component/postgresql/postgresql_query_test.go @@ -49,10 +49,18 @@ func TestPostgresqlQueryBuildQuery(t *testing.T) { input: "../../../tests/state/query/q4.json", query: "SELECT key, value, xmin as etag FROM state WHERE (value->'person'->>'org'=$1 OR (value->'person'->>'org'=$2 AND (value->>'state'=$3 OR value->>'state'=$4))) ORDER BY value->>'state' DESC, value->'person'->>'name' LIMIT 2", }, + { + input: "../../../tests/state/query/q4-notequal.json", + query: "SELECT key, value, xmin as etag FROM state WHERE (value->'person'->>'org'=$1 OR (value->'person'->>'org'!=$2 AND (value->>'state'=$3 OR value->>'state'=$4))) ORDER BY value->>'state' DESC, value->'person'->>'name' LIMIT 2", + }, { input: "../../../tests/state/query/q5.json", query: "SELECT key, value, xmin as etag FROM state WHERE (value->'person'->>'org'=$1 AND (value->'person'->>'name'=$2 OR (value->>'state'=$3 OR value->>'state'=$4))) ORDER BY value->>'state' DESC, value->'person'->>'name' LIMIT 2", }, + { + input: "../../../tests/state/query/q8.json", + query: "SELECT key, value, xmin as etag FROM state WHERE (value->'person'->>'org'>=$1 OR (value->'person'->>'org'<$2 AND (value->>'state'=$3 OR value->>'state'=$4))) ORDER BY value->>'state' DESC, value->'person'->>'name' LIMIT 2", + }, } for _, test := range tests { data, err := os.ReadFile(test.input) diff --git a/pubsub/gcp/pubsub/pubsub.go b/pubsub/gcp/pubsub/pubsub.go index 41f6455bcd..38f65db5ab 100644 --- a/pubsub/gcp/pubsub/pubsub.go +++ b/pubsub/gcp/pubsub/pubsub.go @@ -54,9 +54,15 @@ type GCPPubSub struct { metadata *metadata logger logger.Logger - closed atomic.Bool - closeCh chan struct{} - wg sync.WaitGroup + closed atomic.Bool + closeCh chan struct{} + wg sync.WaitGroup + topicCache map[string]cacheEntry + lock *sync.RWMutex +} + +type cacheEntry struct { + LastSync time.Time } type GCPAuthJSON struct { @@ -76,9 +82,39 @@ type WhatNow struct { Type string `json:"type"` } +const topicCacheRefreshInterval = 5 * time.Hour + // NewGCPPubSub returns a new GCPPubSub instance. func NewGCPPubSub(logger logger.Logger) pubsub.PubSub { - return &GCPPubSub{logger: logger, closeCh: make(chan struct{})} + client := &GCPPubSub{ + logger: logger, + closeCh: make(chan struct{}), + topicCache: make(map[string]cacheEntry), + lock: &sync.RWMutex{}, + } + return client +} + +func (g *GCPPubSub) periodicCacheRefresh() { + // Run this loop 5 times every topicCacheRefreshInterval, to be able to delete items that are stale + ticker := time.NewTicker(topicCacheRefreshInterval / 5) + defer ticker.Stop() + + for { + select { + case <-g.closeCh: + return + case <-ticker.C: + g.lock.Lock() + for key, entry := range g.topicCache { + // Delete from the cache if the last sync was longer than topicCacheRefreshInterval + if time.Since(entry.LastSync) > topicCacheRefreshInterval { + delete(g.topicCache, key) + } + } + g.lock.Unlock() + } + } } func createMetadata(pubSubMetadata pubsub.Metadata) (*metadata, error) { @@ -110,6 +146,12 @@ func (g *GCPPubSub) Init(ctx context.Context, meta pubsub.Metadata) error { return err } + g.wg.Add(1) + go func() { + defer g.wg.Done() + g.periodicCacheRefresh() + }() + pubsubClient, err := g.getPubSubClient(ctx, metadata) if err != nil { return fmt.Errorf("%s error creating pubsub client: %w", errorMessagePrefix, err) @@ -174,12 +216,22 @@ func (g *GCPPubSub) Publish(ctx context.Context, req *pubsub.PublishRequest) err if g.closed.Load() { return errors.New("component is closed") } + g.lock.RLock() + _, topicExists := g.topicCache[req.Topic] + g.lock.RUnlock() - if !g.metadata.DisableEntityManagement { + // We are not acquiring a write lock before calling ensureTopic, so there's the chance that ensureTopic be called multiple time + // This is acceptable in our case, even is slightly wasteful, as ensureTopic is idempotent + if !g.metadata.DisableEntityManagement && !topicExists { err := g.ensureTopic(ctx, req.Topic) if err != nil { - return fmt.Errorf("%s could not get valid topic %s, %s", errorMessagePrefix, req.Topic, err) + return fmt.Errorf("%s could not get valid topic %s: %w", errorMessagePrefix, req.Topic, err) + } + g.lock.Lock() + g.topicCache[req.Topic] = cacheEntry{ + LastSync: time.Now(), } + g.lock.Unlock() } topic := g.getTopic(req.Topic) @@ -210,12 +262,22 @@ func (g *GCPPubSub) Subscribe(parentCtx context.Context, req pubsub.SubscribeReq if g.closed.Load() { return errors.New("component is closed") } + g.lock.RLock() + _, topicExists := g.topicCache[req.Topic] + g.lock.RUnlock() - if !g.metadata.DisableEntityManagement { + // We are not acquiring a write lock before calling ensureTopic, so there's the chance that ensureTopic be called multiple times + // This is acceptable in our case, even is slightly wasteful, as ensureTopic is idempotent + if !g.metadata.DisableEntityManagement && !topicExists { topicErr := g.ensureTopic(parentCtx, req.Topic) if topicErr != nil { - return fmt.Errorf("%s could not get valid topic - topic:%q, error: %v", errorMessagePrefix, req.Topic, topicErr) + return fmt.Errorf("%s could not get valid topic - topic:%q, error: %w", errorMessagePrefix, req.Topic, topicErr) + } + g.lock.Lock() + g.topicCache[req.Topic] = cacheEntry{ + LastSync: time.Now(), } + g.lock.Unlock() subError := g.ensureSubscription(parentCtx, g.metadata.ConsumerID, req.Topic) if subError != nil { @@ -354,9 +416,24 @@ func (g *GCPPubSub) getTopic(topic string) *gcppubsub.Topic { } func (g *GCPPubSub) ensureSubscription(parentCtx context.Context, subscription string, topic string) error { - err := g.ensureTopic(parentCtx, topic) - if err != nil { - return err + g.lock.RLock() + _, topicOK := g.topicCache[topic] + _, dlTopicOK := g.topicCache[g.metadata.DeadLetterTopic] + g.lock.RUnlock() + if !topicOK { + g.lock.Lock() + // Double-check if the topic still doesn't exist to avoid race condition + if _, ok := g.topicCache[topic]; !ok { + err := g.ensureTopic(parentCtx, topic) + if err != nil { + g.lock.Unlock() + return err + } + g.topicCache[topic] = cacheEntry{ + LastSync: time.Now(), + } + } + g.lock.Unlock() } managedSubscription := subscription + "-" + topic @@ -369,11 +446,20 @@ func (g *GCPPubSub) ensureSubscription(parentCtx context.Context, subscription s EnableMessageOrdering: g.metadata.EnableMessageOrdering, } - if g.metadata.DeadLetterTopic != "" { - subErr = g.ensureTopic(parentCtx, g.metadata.DeadLetterTopic) - if subErr != nil { - return subErr + if g.metadata.DeadLetterTopic != "" && !dlTopicOK { + g.lock.Lock() + // Double-check if the DeadLetterTopic still doesn't exist to avoid race condition + if _, ok := g.topicCache[g.metadata.DeadLetterTopic]; !ok { + subErr = g.ensureTopic(parentCtx, g.metadata.DeadLetterTopic) + if subErr != nil { + g.lock.Unlock() + return subErr + } + g.topicCache[g.metadata.DeadLetterTopic] = cacheEntry{ + LastSync: time.Now(), + } } + g.lock.Unlock() dlTopic := fmt.Sprintf("projects/%s/topics/%s", g.metadata.ProjectID, g.metadata.DeadLetterTopic) subConfig.DeadLetterPolicy = &gcppubsub.DeadLetterPolicy{ DeadLetterTopic: dlTopic, diff --git a/state/README.md b/state/README.md index 4000783edb..969d164c1f 100644 --- a/state/README.md +++ b/state/README.md @@ -80,6 +80,31 @@ type FilterEQ struct { Val interface{} } +type FilterNEQ struct { + Key string + Val interface{} +} + +type FilterGT struct { + Key string + Val interface{} +} + +type FilterGTE struct { + Key string + Val interface{} +} + +type FilterLT struct { + Key string + Val interface{} +} + +type FilterLTE struct { + Key string + Val interface{} +} + type FilterIN struct { Key string Vals []interface{} @@ -100,6 +125,16 @@ To simplify the process of query translation, we leveraged [visitor design patte type Visitor interface { // returns "equal" expression VisitEQ(*FilterEQ) (string, error) + // returns "not equal" expression + VisitNEQ(*FilterNEQ) (string, error) + // returns "greater than" expression + VisitGT(*FilterGT) (string, error) + // returns "greater than equal" expression + VisitGTE(*FilterGTE) (string, error) + // returns "less than" expression + VisitLT(*FilterLT) (string, error) + // returns "less than equal" expression + VisitLTE(*FilterLTE) (string, error) // returns "in" expression VisitIN(*FilterIN) (string, error) // returns "and" expression @@ -152,4 +187,4 @@ func (m *MyComponent) Query(req *state.QueryRequest) (*state.QueryResponse, erro } ``` -Some of the examples of State Query API implementation are [MongoDB](./mongodb/mongodb_query.go) and [CosmosDB](./azure/cosmosdb/cosmosdb_query.go) state store components. +Some of the examples of State Query API implementation are [Redis](./redis/redis_query.go), [MongoDB](./mongodb/mongodb_query.go) and [CosmosDB](./azure/cosmosdb/cosmosdb_query.go) state store components. diff --git a/state/azure/cosmosdb/cosmosdb_query.go b/state/azure/cosmosdb/cosmosdb_query.go index dfe8b1fa3e..fa8ed11d52 100644 --- a/state/azure/cosmosdb/cosmosdb_query.go +++ b/state/azure/cosmosdb/cosmosdb_query.go @@ -50,6 +50,76 @@ func (q *Query) VisitEQ(f *query.EQ) (string, error) { return replaceKeywords("c.value."+f.Key) + " = " + name, nil } +func (q *Query) VisitNEQ(f *query.NEQ) (string, error) { + // != + val, ok := f.Val.(string) + if !ok { + return "", fmt.Errorf("unsupported type of value %#v; expected string", f.Val) + } + name := q.setNextParameter(val) + + return replaceKeywords("c.value."+f.Key) + " != " + name, nil +} + +func (q *Query) VisitGT(f *query.GT) (string, error) { + // > + var name string + switch value := f.Val.(type) { + case int: + name = q.setNextParameterInt(value) + case float64: + name = q.setNextParameterFloat(value) + default: + return "", fmt.Errorf("unsupported type of value %#v; expected number", f.Val) + } + return replaceKeywords("c.value."+f.Key) + " > " + name, nil +} + +func (q *Query) VisitGTE(f *query.GTE) (string, error) { + // >= + var name string + switch value := f.Val.(type) { + case int: + name = q.setNextParameterInt(value) + case float64: + name = q.setNextParameterFloat(value) + default: + return "", fmt.Errorf("unsupported type of value %#v; expected number", f.Val) + } + + return replaceKeywords("c.value."+f.Key) + " >= " + name, nil +} + +func (q *Query) VisitLT(f *query.LT) (string, error) { + // < + var name string + switch value := f.Val.(type) { + case int: + name = q.setNextParameterInt(value) + case float64: + name = q.setNextParameterFloat(value) + default: + return "", fmt.Errorf("unsupported type of value %#v; expected number", f.Val) + } + + return replaceKeywords("c.value."+f.Key) + " < " + name, nil +} + +func (q *Query) VisitLTE(f *query.LTE) (string, error) { + // <= + var name string + switch value := f.Val.(type) { + case int: + name = q.setNextParameterInt(value) + case float64: + name = q.setNextParameterFloat(value) + default: + return "", fmt.Errorf("unsupported type of value %#v; expected number", f.Val) + } + + return replaceKeywords("c.value."+f.Key) + " <= " + name, nil +} + func (q *Query) VisitIN(f *query.IN) (string, error) { // IN ( , , ... , ) if len(f.Vals) == 0 { @@ -80,6 +150,31 @@ func (q *Query) visitFilters(op string, filters []query.Filter) (string, error) return "", err } arr = append(arr, str) + case *query.NEQ: + if str, err = q.VisitNEQ(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.GT: + if str, err = q.VisitGT(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.GTE: + if str, err = q.VisitGTE(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.LT: + if str, err = q.VisitLT(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.LTE: + if str, err = q.VisitLTE(f); err != nil { + return "", err + } + arr = append(arr, str) case *query.IN: if str, err = q.VisitIN(f); err != nil { return "", err @@ -144,6 +239,20 @@ func (q *Query) setNextParameter(val string) string { return pname } +func (q *Query) setNextParameterInt(val int) string { + pname := fmt.Sprintf("@__param__%d__", len(q.query.parameters)) + q.query.parameters = append(q.query.parameters, azcosmos.QueryParameter{Name: pname, Value: val}) + + return pname +} + +func (q *Query) setNextParameterFloat(val float64) string { + pname := fmt.Sprintf("@__param__%d__", len(q.query.parameters)) + q.query.parameters = append(q.query.parameters, azcosmos.QueryParameter{Name: pname, Value: val}) + + return pname +} + func (q *Query) execute(ctx context.Context, client *azcosmos.ContainerClient) ([]state.QueryItem, string, error) { opts := &azcosmos.QueryOptions{} diff --git a/state/azure/cosmosdb/cosmosdb_query_test.go b/state/azure/cosmosdb/cosmosdb_query_test.go index fc44c0c1f7..1649929118 100644 --- a/state/azure/cosmosdb/cosmosdb_query_test.go +++ b/state/azure/cosmosdb/cosmosdb_query_test.go @@ -126,6 +126,54 @@ func TestCosmosDbQuery(t *testing.T) { }, }, }, + { + input: "../../../tests/state/query/q4-notequal.json", + query: InternalQuery{ + query: "SELECT * FROM c WHERE c['value']['person']['org'] = @__param__0__ OR (c['value']['person']['org'] != @__param__1__ AND c['value']['state'] IN (@__param__2__, @__param__3__)) ORDER BY c['value']['state'] DESC, c['value']['person']['name'] ASC", + parameters: []azcosmos.QueryParameter{ + { + Name: "@__param__0__", + Value: "A", + }, + { + Name: "@__param__1__", + Value: "B", + }, + { + Name: "@__param__2__", + Value: "CA", + }, + { + Name: "@__param__3__", + Value: "WA", + }, + }, + }, + }, + { + input: "../../../tests/state/query/q8.json", + query: InternalQuery{ + query: "SELECT * FROM c WHERE c['value']['person']['org'] >= @__param__0__ OR (c['value']['person']['org'] < @__param__1__ AND c['value']['state'] IN (@__param__2__, @__param__3__)) ORDER BY c['value']['state'] DESC, c['value']['person']['name'] ASC", + parameters: []azcosmos.QueryParameter{ + { + Name: "@__param__0__", + Value: 123.0, + }, + { + Name: "@__param__1__", + Value: 10.0, + }, + { + Name: "@__param__2__", + Value: "CA", + }, + { + Name: "@__param__3__", + Value: "WA", + }, + }, + }, + }, } for _, test := range tests { data, err := os.ReadFile(test.input) diff --git a/state/mongodb/mongodb_query.go b/state/mongodb/mongodb_query.go index 66ac899134..27d132eed2 100644 --- a/state/mongodb/mongodb_query.go +++ b/state/mongodb/mongodb_query.go @@ -45,6 +45,56 @@ func (q *Query) VisitEQ(f *query.EQ) (string, error) { } } +func (q *Query) VisitNEQ(f *query.NEQ) (string, error) { + // { : } + switch v := f.Val.(type) { + case string: + return fmt.Sprintf(`{ "value.%s": {"$ne": %q} }`, f.Key, v), nil + default: + return fmt.Sprintf(`{ "value.%s": {"$ne": %v} }`, f.Key, v), nil + } +} + +func (q *Query) VisitGT(f *query.GT) (string, error) { + // { : } + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf(`{ "value.%s": {"$gt": %v} }`, f.Key, v), nil + } +} + +func (q *Query) VisitGTE(f *query.GTE) (string, error) { + // { : } + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf(`{ "value.%s": {"$gte": %v} }`, f.Key, v), nil + } +} + +func (q *Query) VisitLT(f *query.LT) (string, error) { + // { : } + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf(`{ "value.%s": {"$lt": %v} }`, f.Key, v), nil + } +} + +func (q *Query) VisitLTE(f *query.LTE) (string, error) { + // { : } + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf(`{ "value.%s": {"$lte": %v} }`, f.Key, v), nil + } +} + func (q *Query) VisitIN(f *query.IN) (string, error) { // { $in: [ , , ... , ] } if len(f.Vals) == 0 { @@ -81,6 +131,31 @@ func (q *Query) visitFilters(op string, filters []query.Filter) (string, error) return "", err } arr = append(arr, str) + case *query.NEQ: + if str, err = q.VisitNEQ(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.GT: + if str, err = q.VisitGT(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.GTE: + if str, err = q.VisitGTE(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.LT: + if str, err = q.VisitLT(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.LTE: + if str, err = q.VisitLTE(f); err != nil { + return "", err + } + arr = append(arr, str) case *query.IN: if str, err = q.VisitIN(f); err != nil { return "", err diff --git a/state/mongodb/mongodb_query_test.go b/state/mongodb/mongodb_query_test.go index 23de063ecd..814a45ede4 100644 --- a/state/mongodb/mongodb_query_test.go +++ b/state/mongodb/mongodb_query_test.go @@ -49,6 +49,14 @@ func TestMongoQuery(t *testing.T) { input: "../../tests/state/query/q6.json", query: `{ "$or": [ { "value.person.id": 123 }, { "$and": [ { "value.person.org": "B" }, { "value.person.id": { "$in": [ 567, 890 ] } } ] } ] }`, }, + { + input: "../../tests/state/query/q6-notequal.json", + query: `{ "$or": [ { "value.person.id": 123 }, { "$and": [ { "value.person.org": {"$ne": "B"} }, { "value.person.id": { "$in": [ 567, 890 ] } } ] } ] }`, + }, + { + input: "../../tests/state/query/q7.json", + query: `{ "$or": [ { "value.person.id": {"$lt": 123} }, { "$and": [ { "value.person.org": {"$gte": 2} }, { "value.person.id": { "$in": [ 567, 890 ] } } ] } ] }`, + }, } for _, test := range tests { data, err := os.ReadFile(test.input) diff --git a/state/query/filter.go b/state/query/filter.go index 4961f74830..2b0e64f8cb 100644 --- a/state/query/filter.go +++ b/state/query/filter.go @@ -36,6 +36,31 @@ func ParseFilter(obj interface{}) (Filter, error) { f := &EQ{} err := f.Parse(v) + return f, err + case "NEQ": + f := &NEQ{} + err := f.Parse(v) + + return f, err + case "GT": + f := >{} + err := f.Parse(v) + + return f, err + case "GTE": + f := >E{} + err := f.Parse(v) + + return f, err + case "LT": + f := <{} + err := f.Parse(v) + + return f, err + case "LTE": + f := <E{} + err := f.Parse(v) + return f, err case "IN": f := &IN{} @@ -81,6 +106,111 @@ func (f *EQ) Parse(obj interface{}) error { return nil } +type NEQ struct { + Key string + Val interface{} +} + +func (f *NEQ) Parse(obj interface{}) error { + m, ok := obj.(map[string]interface{}) + if !ok { + return fmt.Errorf("NEQ filter must be a map") + } + if len(m) != 1 { + return fmt.Errorf("NEQ filter must contain a single key/value pair") + } + for k, v := range m { + f.Key = k + f.Val = v + } + + return nil +} + +type GT struct { + Key string + Val interface{} +} + +func (f *GT) Parse(obj interface{}) error { + m, ok := obj.(map[string]interface{}) + if !ok { + return fmt.Errorf("GT filter must be a map") + } + if len(m) != 1 { + return fmt.Errorf("GT filter must contain a single key/value pair") + } + for k, v := range m { + f.Key = k + f.Val = v + } + + return nil +} + +type GTE struct { + Key string + Val interface{} +} + +func (f *GTE) Parse(obj interface{}) error { + m, ok := obj.(map[string]interface{}) + if !ok { + return fmt.Errorf("GTE filter must be a map") + } + if len(m) != 1 { + return fmt.Errorf("GTE filter must contain a single key/value pair") + } + for k, v := range m { + f.Key = k + f.Val = v + } + + return nil +} + +type LT struct { + Key string + Val interface{} +} + +func (f *LT) Parse(obj interface{}) error { + m, ok := obj.(map[string]interface{}) + if !ok { + return fmt.Errorf("LT filter must be a map") + } + if len(m) != 1 { + return fmt.Errorf("LT filter must contain a single key/value pair") + } + for k, v := range m { + f.Key = k + f.Val = v + } + + return nil +} + +type LTE struct { + Key string + Val interface{} +} + +func (f *LTE) Parse(obj interface{}) error { + m, ok := obj.(map[string]interface{}) + if !ok { + return fmt.Errorf("LTE filter must be a map") + } + if len(m) != 1 { + return fmt.Errorf("LTE filter must contain a single key/value pair") + } + for k, v := range m { + f.Key = k + f.Val = v + } + + return nil +} + type IN struct { Key string Vals []interface{} diff --git a/state/query/query.go b/state/query/query.go index ea3d874afe..663ad18a3f 100644 --- a/state/query/query.go +++ b/state/query/query.go @@ -53,6 +53,16 @@ type Query struct { type Visitor interface { // returns "equal" expression VisitEQ(*EQ) (string, error) + // returns "not equal" expression + VisitNEQ(*NEQ) (string, error) + // returns "greater than" expression + VisitGT(*GT) (string, error) + // returns "greater than equal" expression + VisitGTE(*GTE) (string, error) + // returns "less than" expression + VisitLT(*LT) (string, error) + // returns "less than equal" expression + VisitLTE(*LTE) (string, error) // returns "in" expression VisitIN(*IN) (string, error) // returns "and" expression @@ -89,6 +99,16 @@ func (h *Builder) buildFilter(filter Filter) (string, error) { switch f := filter.(type) { case *EQ: return h.visitor.VisitEQ(f) + case *NEQ: + return h.visitor.VisitNEQ(f) + case *GT: + return h.visitor.VisitGT(f) + case *GTE: + return h.visitor.VisitGTE(f) + case *LT: + return h.visitor.VisitLT(f) + case *LTE: + return h.visitor.VisitLTE(f) case *IN: return h.visitor.VisitIN(f) case *OR: diff --git a/state/redis/redis_query.go b/state/redis/redis_query.go index a61f77afa6..a5e679f167 100644 --- a/state/redis/redis_query.go +++ b/state/redis/redis_query.go @@ -66,6 +66,81 @@ func (q *Query) VisitEQ(f *query.EQ) (string, error) { } } +func (q *Query) VisitNEQ(f *query.NEQ) (string, error) { + // string: @:() + // numeric: @:[ ] + alias, err := q.getAlias(f.Key) + if err != nil { + return "", err + } + + switch v := f.Val.(type) { + case string: + return fmt.Sprintf("@%s:(%s)", alias, v), nil + default: + return fmt.Sprintf("@%s:[%v %v]", alias, v, v), nil + } +} + +func (q *Query) VisitGT(f *query.GT) (string, error) { + // numeric: @:[( +inf] + alias, err := q.getAlias(f.Key) + if err != nil { + return "", err + } + + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf("@%s:[(%v +inf]", alias, v), nil + } +} + +func (q *Query) VisitGTE(f *query.GTE) (string, error) { + // numeric: @:[ +inf] + alias, err := q.getAlias(f.Key) + if err != nil { + return "", err + } + + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf("@%s:[%v +inf]", alias, v), nil + } +} + +func (q *Query) VisitLT(f *query.LT) (string, error) { + // numeric: @:[-inf )] + alias, err := q.getAlias(f.Key) + if err != nil { + return "", err + } + + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf("@%s:[-inf (%v]", alias, v), nil + } +} + +func (q *Query) VisitLTE(f *query.LTE) (string, error) { + // numeric: @:[-inf ] + alias, err := q.getAlias(f.Key) + if err != nil { + return "", err + } + switch v := f.Val.(type) { + case string: + return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val) + default: + return fmt.Sprintf("@%s:[-inf %v]", alias, v), nil + } +} + func (q *Query) VisitIN(f *query.IN) (string, error) { // string: @:(|...) // numeric: replace with OR @@ -116,6 +191,31 @@ func (q *Query) visitFilters(op string, filters []query.Filter) (string, error) return "", err } arr = append(arr, fmt.Sprintf("(%s)", str)) + case *query.NEQ: + if str, err = q.VisitNEQ(f); err != nil { + return "", err + } + arr = append(arr, fmt.Sprintf("-(%s)", str)) + case *query.GT: + if str, err = q.VisitGT(f); err != nil { + return "", err + } + arr = append(arr, fmt.Sprintf("(%s)", str)) + case *query.GTE: + if str, err = q.VisitGTE(f); err != nil { + return "", err + } + arr = append(arr, fmt.Sprintf("(%s)", str)) + case *query.LT: + if str, err = q.VisitLT(f); err != nil { + return "", err + } + arr = append(arr, fmt.Sprintf("(%s)", str)) + case *query.LTE: + if str, err = q.VisitLTE(f); err != nil { + return "", err + } + arr = append(arr, fmt.Sprintf("(%s)", str)) case *query.IN: if str, err = q.VisitIN(f); err != nil { return "", err diff --git a/state/redis/redis_query_test.go b/state/redis/redis_query_test.go index cc503e62d7..baec62bbf3 100644 --- a/state/redis/redis_query_test.go +++ b/state/redis/redis_query_test.go @@ -46,6 +46,14 @@ func TestMongoQuery(t *testing.T) { input: "../../tests/state/query/q6.json", query: []interface{}{"((@id:[123 123])|((@org:(B)) (((@id:[567 567])|(@id:[890 890])))))", "SORTBY", "id", "LIMIT", "0", "2"}, }, + { + input: "../../tests/state/query/q6-notequal.json", + query: []interface{}{"((@id:[123 123])|(-(@org:(B)) (((@id:[567 567])|(@id:[890 890])))))", "SORTBY", "id", "LIMIT", "0", "2"}, + }, + { + input: "../../tests/state/query/q7.json", + query: []interface{}{"((@id:[-inf (123])|((@org:[2 +inf]) (((@id:[567 567])|(@id:[890 890])))))", "SORTBY", "id", "LIMIT", "0", "2"}, + }, } for _, test := range tests { data, err := os.ReadFile(test.input) diff --git a/state/sqlite/sqlite_dbaccess.go b/state/sqlite/sqlite_dbaccess.go index d01bb1de47..1e45f0b3cf 100644 --- a/state/sqlite/sqlite_dbaccess.go +++ b/state/sqlite/sqlite_dbaccess.go @@ -84,12 +84,15 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error { return err } - db, err := sql.Open("sqlite", connString) + a.db, err = sql.Open("sqlite", connString) if err != nil { return fmt.Errorf("failed to create connection: %w", err) } - a.db = db + // If the database is in-memory, we can't have more than 1 open connection + if a.metadata.IsInMemoryDB() { + a.db.SetMaxOpenConns(1) + } err = a.Ping(ctx) if err != nil { diff --git a/tests/config/state/redis/v6/statestore.yaml b/tests/config/state/redis/v6/statestore.yaml index f8e6afd837..a9e72a7906 100644 --- a/tests/config/state/redis/v6/statestore.yaml +++ b/tests/config/state/redis/v6/statestore.yaml @@ -18,6 +18,14 @@ spec: { "key": "message", "type": "TEXT" + }, + { + "key": "product.value", + "type": "NUMERIC" + }, + { + "key": "status", + "type": "TEXT" } ] } diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 8cc279e0ee..95aaf76474 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -40,6 +40,13 @@ type ValueType struct { Message string `json:"message"` } +type StructType struct { + Product struct { + Value int `json:"value"` + } `json:"product"` + Status string `json:"status"` +} + type intValueType struct { Message int32 `json:"message"` } @@ -119,6 +126,20 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St value: ValueType{Message: fmt.Sprintf("test%s", key)}, contentType: contenttype.JSONContentType, }, + { + key: fmt.Sprintf("%s-struct-operations", key), + value: StructType{Product: struct { + Value int `json:"value"` + }{Value: 15}, Status: "ACTIVE"}, + contentType: contenttype.JSONContentType, + }, + { + key: fmt.Sprintf("%s-struct-operations-inactive", key), + value: StructType{Product: struct { + Value int `json:"value"` + }{Value: 12}, Status: "INACTIVE"}, + contentType: contenttype.JSONContentType, + }, { key: fmt.Sprintf("%s-struct-with-int", key), value: intValueType{Message: 42}, @@ -235,6 +256,67 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St }, }, }, + { + query: ` + { + "filter": { + "AND": [ + { + "GTE": {"product.value": 10} + }, + { + "LT": {"product.value": 20} + }, + { + "NEQ": {"status": "INACTIVE"} + } + ] + } + } + `, + results: []state.QueryItem{ + { + Key: fmt.Sprintf("%s-struct-operations", key), + Data: []byte(fmt.Sprintf(`{"product":{"value":15}, "status":"ACTIVE"}`)), + }, + }, + }, + { + query: ` + { + "filter": { + "OR": [ + { + "AND": [ + { + "GT": {"product.value": 11.1} + }, + { + "EQ": {"status": "INACTIVE"} + } + ] + }, + { + "AND": [ + { + "LTE": {"product.value": 0.5} + }, + { + "EQ": {"status": "ACTIVE"} + } + ] + } + ] + } + } + `, + results: []state.QueryItem{ + { + Key: fmt.Sprintf("%s-struct-operations-inactive", key), + Data: []byte(fmt.Sprintf(`{"product":{"value":12}, "status":"INACTIVE"}`)), + }, + }, + }, } t.Run("init", func(t *testing.T) { @@ -312,6 +394,7 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St metadata.ContentType: contenttype.JSONContentType, metadata.QueryIndexName: "qIndx", } + resp, err := querier.Query(context.Background(), &req) require.NoError(t, err) assert.Equal(t, len(scenario.results), len(resp.Results)) @@ -1241,6 +1324,12 @@ func assertDataEquals(t *testing.T, expect any, actual []byte) { assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual)) } assert.Equal(t, expect, v) + case StructType: + // Custom type requires case mapping + if err := json.Unmarshal(actual, &v); err != nil { + assert.Failf(t, "unmarshal error", "error: %v, json: %s", err, string(actual)) + } + assert.Equal(t, expect, v) case int: // json.Unmarshal to float64 by default, case mapping to int coerces to int type if err := json.Unmarshal(actual, &v); err != nil { diff --git a/tests/state/query/q4-notequal.json b/tests/state/query/q4-notequal.json new file mode 100644 index 0000000000..bd55c29e15 --- /dev/null +++ b/tests/state/query/q4-notequal.json @@ -0,0 +1,37 @@ +{ + "filter": { + "OR": [ + { + "EQ": { + "person.org": "A" + } + }, + { + "AND": [ + { + "NEQ": { + "person.org": "B" + } + }, + { + "IN": { + "state": ["CA", "WA"] + } + } + ] + } + ] + }, + "sort": [ + { + "key": "state", + "order": "DESC" + }, + { + "key": "person.name" + } + ], + "page": { + "limit": 2 + } +} diff --git a/tests/state/query/q6-notequal.json b/tests/state/query/q6-notequal.json new file mode 100644 index 0000000000..f601a7f675 --- /dev/null +++ b/tests/state/query/q6-notequal.json @@ -0,0 +1,33 @@ +{ + "filter": { + "OR": [ + { + "EQ": { + "person.id": 123 + } + }, + { + "AND": [ + { + "NEQ": { + "person.org": "B" + } + }, + { + "IN": { + "person.id": [567, 890] + } + } + ] + } + ] + }, + "sort": [ + { + "key": "person.id" + } + ], + "page": { + "limit": 2 + } +} diff --git a/tests/state/query/q7.json b/tests/state/query/q7.json new file mode 100644 index 0000000000..86f87c6d57 --- /dev/null +++ b/tests/state/query/q7.json @@ -0,0 +1,36 @@ +{ + "filter": { + "OR": [ + { + "LT": { + "person.id": 123 + } + }, + { + "AND": [ + { + "GTE": { + "person.org": 2 + } + }, + { + "IN": { + "person.id": [ + 567, + 890 + ] + } + } + ] + } + ] + }, + "sort": [ + { + "key": "person.id" + } + ], + "page": { + "limit": 2 + } +} diff --git a/tests/state/query/q8.json b/tests/state/query/q8.json new file mode 100644 index 0000000000..e9fb2dab0f --- /dev/null +++ b/tests/state/query/q8.json @@ -0,0 +1,40 @@ +{ + "filter": { + "OR": [ + { + "GTE": { + "person.org": 123 + } + }, + { + "AND": [ + { + "LT": { + "person.org": 10 + } + }, + { + "IN": { + "state": [ + "CA", + "WA" + ] + } + } + ] + } + ] + }, + "sort": [ + { + "key": "state", + "order": "DESC" + }, + { + "key": "person.name" + } + ], + "page": { + "limit": 2 + } +}