Skip to content

Commit

Permalink
feat:trie tree merge
Browse files Browse the repository at this point in the history
  • Loading branch information
CocaineCong committed Aug 27, 2023
1 parent f25ebb4 commit efe26e0
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 63 deletions.
7 changes: 2 additions & 5 deletions app/search_engine/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"sync"
"sync/atomic"

"github.com/spf13/cast"

"github.com/CocaineCong/tangseng/app/search_engine/query"
"github.com/CocaineCong/tangseng/app/search_engine/segment"
"github.com/CocaineCong/tangseng/app/search_engine/types"
Expand Down Expand Up @@ -65,7 +63,7 @@ func (e *Engine) Text2PostingsLists(text string, docId int64) (err error) {
}

bufInvertedHash := make(segment.InvertedIndexHash)
trieTree := new(trie.Trie)
trieTree := trie.NewTrie()
for _, token := range tokens {
err = segment.Token2PostingsLists(bufInvertedHash, token, docId)
if err != nil {
Expand Down Expand Up @@ -162,8 +160,7 @@ func (e *Engine) FlushInvertedIndex(isEnd ...bool) (err error) {

// FlushDict 刷新dict
func (e *Engine) FlushDict(trieTree *trie.Trie, isEnd ...bool) (err error) {
currSegId := cast.ToInt64(e.CurrSegId)
err = e.Seg[e.CurrSegId].FlushTokenDict(currSegId, trieTree)
err = e.Seg[e.CurrSegId].FlushTokenDict(trieTree)
if err != nil {
log.LogrusObj.Errorln("Flush", err)
return
Expand Down
4 changes: 0 additions & 4 deletions app/search_engine/index/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ func AddDoc(in *Index) {
}

err = in.AddDocument(doc)
if err != nil {
log.LogrusObj.Errorf("index addDoc AddDocument: %v", err)
}
wg.Done()
}(item)
}
wg.Wait()
Expand Down
8 changes: 4 additions & 4 deletions app/search_engine/inputdata/inputdata.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package inputData

import (
"fmt"
"strings"

"github.com/spf13/cast"
Expand All @@ -14,10 +13,11 @@ import (
func Doc2Struct(docStr string) (*types.Document, error) {
docStr = strings.Replace(docStr, "\"", "", -1)
d := strings.Split(docStr, ",")
if len(d) < 3 {
return nil, fmt.Errorf("doc2Struct err: %v", "docStr is not right")
if len(d) <= 5 { // just fix the stupid data
for i := 0; i < 6-len(d); i++ {
d = append(d, "a")
}
}

doc := &types.Document{
DocId: cast.ToInt64(d[0]),
Title: d[1],
Expand Down
15 changes: 15 additions & 0 deletions app/search_engine/inputdata/inputdata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package inputData

import (
"fmt"
"testing"
)

// TestInputDataDoc2Struct feeds Doc2Struct a record that has fewer
// comma-separated fields than a full row and checks that it is still parsed
// into a document whose DocId comes from the first field.
func TestInputDataDoc2Struct(t *testing.T) {
	a := "2,[安乐乡]导演利桑德罗·阿隆索导演将打造下一部影片[尤里卡](Eureka,暂译)。据悉该片探讨美国文化问题,故事发生在1870年到2019年期间,涉及地区包括美国、墨西哥以及亚马逊雨林。故事主角是一个经历波折,辗转各地的女性。本片今年7月已在达科他开拍,预计将在2020年上映。"
	r, err := Doc2Struct(a)
	if err != nil {
		// Previously the error was only printed, so the test could never fail.
		t.Fatalf("Doc2Struct returned an error: %v", err)
	}
	if r == nil {
		t.Fatal("Doc2Struct returned a nil document without an error")
	}
	fmt.Println("r", r)
	if r.DocId != 2 {
		t.Fatalf("DocId = %v, want 2", r.DocId)
	}
}
4 changes: 2 additions & 2 deletions app/search_engine/segment/segment.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ func (e *Segment) FlushInvertedIndex(PostingsHashBuf InvertedIndexHash) (err err
}

// FlushTokenDict 刷新写入 token dict
func (e *Segment) FlushTokenDict(currSegId int64, trieTree *trie.Trie) (err error) {
err = e.StorageDict(currSegId, trieTree)
func (e *Segment) FlushTokenDict(trieTree *trie.Trie) (err error) {
err = e.StorageDict(trieTree)

return
}
Expand Down
28 changes: 21 additions & 7 deletions app/search_engine/storage/dict_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package storage
import (
"bytes"

"github.com/spf13/cast"
bolt "go.etcd.io/bbolt"

"github.com/CocaineCong/tangseng/consts"
Expand All @@ -27,25 +26,40 @@ func NewDictDB(dbName string) (*DictDB, error) {
return &DictDB{db}, nil
}

func (d *DictDB) StorageDict(segId int64, trieTree *trie.Trie) (err error) {
func (d *DictDB) StorageDict(trieTree *trie.Trie) (err error) {
buf := bytes.NewBuffer(nil)
err = codec.BinaryEncoding(buf, trieTree)
if err != nil {
return
}

err = d.PutTrimTreeByKV([]byte(cast.ToString(segId)), buf.Bytes())
err = d.PutTrieTree([]byte(consts.DictBucket), buf.Bytes())

return
}

// PutTrimTreeByKV 通过kv进行存储
func (d *DictDB) PutTrimTreeByKV(key, value []byte) error {
// GetTrieTreeDict loads the serialized trie stored under the consts.DictBucket
// key and decodes it into trieTree.
//
// The buf argument is kept for signature compatibility only: it is rebound to
// a new buffer wrapping the stored bytes, so whatever buffer the caller passes
// in is neither read nor written.
//
// It returns the lookup error from GetTrieTree or the decoding error from
// codec.BinaryDecoding, whichever occurs first.
func (d *DictDB) GetTrieTreeDict(buf *bytes.Buffer, trieTree *trie.Trie) (err error) {
	raw, err := d.GetTrieTree([]byte(consts.DictBucket))
	if err != nil {
		return err
	}

	buf = bytes.NewBuffer(raw)
	return codec.BinaryDecoding(buf, trieTree)
}

// PutTrieTree writes value under key in the consts.DictBucket bucket of the
// underlying database and returns whatever error Put reports.
func (d *DictDB) PutTrieTree(key, value []byte) error {
	err := Put(d.db, consts.DictBucket, key, value)
	return err
}

// GetTrimTree 通过term获取value
func (d *DictDB) GetTrimTree(key []byte) (value []byte, err error) {
// GetTrieTree reads the value stored under key in the consts.DictBucket bucket
// of the underlying database, returning the bytes and error reported by Get.
func (d *DictDB) GetTrieTree(key []byte) (value []byte, err error) {
	value, err = Get(d.db, consts.DictBucket, key)
	return value, err
}

Expand Down
21 changes: 21 additions & 0 deletions app/search_engine/storage/dict_db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package storage

import (
"bytes"
"fmt"
"testing"

"github.com/CocaineCong/tangseng/config"
"github.com/CocaineCong/tangseng/pkg/trie"
)

// TestDictDB_GetTrimTree opens the "0.dict" database under the configured
// storage path, decodes the stored trie and looks up a single token in it.
func TestDictDB_GetTrimTree(t *testing.T) {
	aConfig := config.Conf.SeConfig.StoragePath + "0.dict"
	d, err := NewDictDB(aConfig)
	if err != nil {
		// Continuing with a nil *DictDB would only panic further down.
		t.Fatalf("NewDictDB(%q): %v", aConfig, err)
	}

	buf := bytes.NewBuffer(nil)
	trieTree := trie.NewTrie()
	if err := d.GetTrieTreeDict(buf, trieTree); err != nil {
		t.Fatalf("GetTrieTreeDict: %v", err)
	}

	// The trie's exact-match lookup is Search; the former Find method no
	// longer exists on trie.Trie.
	a := trieTree.Search("导")
	fmt.Println(a)
}
2 changes: 1 addition & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ services:
SeConfig:
Version: "1.0.0"
StoragePath: "/Users/mac/GolandProjects/Go-SearchEngine/app/search_engine/data/db/"
SourceWuKoFile: "/Users/mac/GolandProjects/Go-SearchEngine/app/search_engine/data/movies.csv"
SourceWuKoFile: "/Users/mac/GolandProjects/Go-SearchEngine/app/search_engine/data/movies_data.csv"
MetaPath: "/Users/mac/GolandProjects/Go-SearchEngine/app/search_engine/data/db/segments.json"
SourceFiles:
- "./source"
Expand Down
8 changes: 8 additions & 0 deletions idl/search_engine.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ message SearchEngineResponse{
repeated string data = 4;
}

// WordAssociationResponse is the message returned by the WordAssociation RPC
// of SearchEngineService.
message WordAssociationResponse{
// Numeric result code. NOTE(review): the meaning of individual values is not
// defined in this file — confirm against the service implementation.
int64 code=1;
// Message accompanying the code.
string msg=2;
// List of associated words; presumably the completions for the requested
// term — confirm against the handler.
repeated string word_association_list = 3;
// Additional payload as a single string. NOTE(review): contents are not
// specified here — confirm.
string data = 4;
}


// SearchEngineService exposes the search engine over RPC. Both methods take a
// SearchEngineRequest.
service SearchEngineService{
// SearchEngineSearch answers a request with a SearchEngineResponse.
rpc SearchEngineSearch(SearchEngineRequest) returns(SearchEngineResponse);
// WordAssociation answers a request with a WordAssociationResponse.
rpc WordAssociation(SearchEngineRequest) returns(WordAssociationResponse);
}
134 changes: 94 additions & 40 deletions pkg/trie/trie.go
Original file line number Diff line number Diff line change
@@ -1,64 +1,118 @@
package trie

// TrieNode 树节点
import (
"fmt"
)

type TrieNode struct {
Char string // Unicode 字符
IsEnding bool // 是否是单词结尾
Children map[rune]*TrieNode // 该节点的子节点字典
IsEnd bool `json:"is_end"` // 标记该节点是否为一个单词的末尾
Children map[byte]*TrieNode `json:"children"` // 存储子节点的指针
}

// NewTrieNode 初始化 Trie 树节点
func NewTrieNode(char string) *TrieNode {
func NewTrieNode() *TrieNode {
return &TrieNode{
Char: char,
IsEnding: false,
Children: make(map[rune]*TrieNode),
IsEnd: false,
Children: make(map[byte]*TrieNode),
}
}

// Trie 树结构
type Trie struct {
Root *TrieNode // 根节点指针
Root *TrieNode // 存储 Trie 树的根节点
}

// NewTrie 初始化 Trie 树
func NewTrie() *Trie {
// 初始化根节点
trieNode := NewTrieNode("/")
return &Trie{trieNode}
return &Trie{Root: NewTrieNode()}
}

// Insert 往 Trie 树中插入一个单词
func (t *Trie) Insert(word string) {
node := t.Root // 获取根节点
for _, code := range word { // 以 Unicode 字符遍历该单词
value, ok := node.Children[code] // 获取 code 编码对应子节点
if !ok {
// 不存在则初始化该节点
value = NewTrieNode(string(code))
// 然后将其添加到子节点字典
node.Children[code] = value
func (trie *Trie) Insert(word string) {
node := trie.Root
for i := 0; i < len(word); i++ {
c := word[i]
if _, ok := node.Children[c]; !ok {
node.Children[c] = NewTrieNode()
}
// 当前节点指针指向当前子节点
node = value
node = node.Children[c]
}
node.IsEnding = true // 一个单词遍历完所有字符后将结尾字符打上标记
node.IsEnd = true
}

// Find 在 Trie 树中查找一个单词
func (t *Trie) Find(word string) bool {
node := t.Root
for _, code := range word {
value, ok := node.Children[code] // 获取对应子节点
if !ok {
// 不存在则直接返回
func (trie *Trie) Search(word string) bool {
node := trie.Root
for i := 0; i < len(word); i++ {
c := word[i]
if _, ok := node.Children[c]; !ok {
return false
}
// 否则继续往后遍历
node = value
node = node.Children[c]
}
return node.IsEnd
}

// StartsWith reports whether the trie contains a path spelling out prefix.
// The walk is byte-wise, one child lookup per byte of prefix, and does not
// require the final node to be marked as the end of a word. An empty prefix
// always yields true.
func (trie *Trie) StartsWith(prefix string) bool {
	cur := trie.Root
	for _, b := range []byte(prefix) {
		next, found := cur.Children[b]
		if !found {
			return false
		}
		cur = next
	}
	return true
}

func (trie *Trie) FindAllByPrefix(prefix string) []string {
node := trie.Root
for i := 0; i < len(prefix); i++ {
c := prefix[i]
if _, ok := node.Children[c]; !ok {
return nil
}
node = node.Children[c]
}
if node.IsEnding == false {
return false // 不能完全匹配,只是前缀
words := make([]string, 0)
trie.dfs(node, prefix, &words)
return words
}

// dfs walks the subtree rooted at node depth-first and appends every complete
// word it finds to *words. word is the text accumulated on the path from the
// trie root to node; each child extends it by that child's key byte.
//
// Children are visited in map iteration order, so the order of the collected
// words is not deterministic. A nil node is ignored.
func (trie *Trie) dfs(node *TrieNode, word string, words *[]string) {
	if node == nil {
		return
	}
	if node.IsEnd {
		*words = append(*words, word)
	}
	for c, child := range node.Children {
		// Append the raw byte. string(c) on a byte is a code-point conversion
		// that re-encodes every value >= 0x80 as two UTF-8 bytes, which
		// corrupts any multi-byte (e.g. Chinese) word stored in the trie.
		trie.dfs(child, word+string([]byte{c}), words)
	}
}

// Merge folds every word of other into trie. After the call trie contains the
// union of both word sets; other is left unmodified.
//
// Branches that exist only in other are attached by pointer rather than
// copied, so the two tries share those subtrees afterwards: later insertions
// through one trie under such a branch are visible through the other.
//
// A nil other, or an other without a root, is a no-op. A receiver without a
// root (the zero-value Trie) gets a fresh root instead of panicking.
func (trie *Trie) Merge(other *Trie) {
	if other == nil || other.Root == nil {
		return
	}
	if trie.Root == nil {
		trie.Root = &TrieNode{Children: make(map[byte]*TrieNode)}
	}

	var mergeNodes func(dst, src *TrieNode)
	mergeNodes = func(dst, src *TrieNode) {
		if dst.Children == nil && len(src.Children) > 0 {
			// Writing into a nil map panics; allocate lazily.
			dst.Children = make(map[byte]*TrieNode, len(src.Children))
		}
		for c, child := range src.Children {
			if child == nil {
				continue
			}
			if existing, ok := dst.Children[c]; ok && existing != nil {
				mergeNodes(existing, child)
				continue
			}
			dst.Children[c] = child
		}
		dst.IsEnd = dst.IsEnd || src.IsEnd
	}

	mergeNodes(trie.Root, other.Root)
}

// traverse prints, one per line on standard output, every complete word found
// in the subtree rooted at node. prefix is the text accumulated on the path
// from the trie root to node. Children are visited in map iteration order, so
// the output order is not deterministic. A nil node prints nothing.
func traverse(node *TrieNode, prefix string) {
	if node == nil {
		return
	}
	if node.IsEnd {
		fmt.Println(prefix)
	}

	for c, child := range node.Children {
		// Append the raw byte. string(c) on a byte is a code-point conversion
		// that re-encodes every value >= 0x80 as two UTF-8 bytes, which
		// garbles any multi-byte (e.g. Chinese) word stored in the trie.
		traverse(child, prefix+string([]byte{c}))
	}
}

// Traverse prints every word stored in the trie to standard output by walking
// it from the root with an empty starting prefix.
func (trie *Trie) Traverse() {
	const noPrefix = ""
	traverse(trie.Root, noPrefix)
}
45 changes: 45 additions & 0 deletions pkg/trie/trie_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package trie

import (
"bytes"
"encoding/gob"
"fmt"
"testing"
)

// TestTrieTree builds two tries with one word in common, merges the second
// into the first and verifies both the merged word set and a prefix lookup.
// The tries are also dumped to standard output for manual inspection.
func TestTrieTree(t *testing.T) {
	t1 := NewTrie()
	t1.Insert("hello")
	t1.Insert("world")
	fmt.Println("t1")
	t1.Traverse()

	t2 := NewTrie()
	t2.Insert("hello")
	t2.Insert("golang")
	t2.Insert("programming")
	fmt.Println("t2")
	t2.Traverse()

	t1.Merge(t2)
	fmt.Println("t1 merge")
	t1.Traverse()

	// After the merge t1 must hold the union of both word sets.
	for _, w := range []string{"hello", "world", "golang", "programming"} {
		if !t1.Search(w) {
			t.Errorf("merged trie is missing %q", w)
		}
	}
	if t1.Search("hell") {
		t.Errorf("merged trie reports the bare prefix %q as a word", "hell")
	}

	r := t1.FindAllByPrefix("he")
	fmt.Println(r)
	if len(r) != 1 || r[0] != "hello" {
		t.Fatalf("FindAllByPrefix(%q) = %q, want [%q]", "he", r, "hello")
	}
}

// TestBinaryTree round-trips a trie through encoding/gob and verifies that the
// decoded copy still contains exactly the inserted words.
func TestBinaryTree(t *testing.T) {
	t2 := NewTrie()
	t3 := NewTrie()
	words := []string{"hello", "golang", "programming"}
	for _, w := range words {
		t2.Insert(w)
	}

	buf := new(bytes.Buffer)
	if err := gob.NewEncoder(buf).Encode(t2); err != nil {
		t.Fatalf("gob encode: %v", err)
	}
	if err := gob.NewDecoder(buf).Decode(t3); err != nil {
		t.Fatalf("gob decode: %v", err)
	}

	t3.Traverse()

	for _, w := range words {
		if !t3.Search(w) {
			t.Errorf("decoded trie is missing %q", w)
		}
	}
	if t3.Search("go") {
		t.Errorf("decoded trie reports the bare prefix %q as a word", "go")
	}
}

0 comments on commit efe26e0

Please sign in to comment.