diff --git a/README.md b/README.md index aec962f..ad728ce 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ Since version 0.4, `GROOT` will also output the variation graphs which had reads align. These graphs are in [GFA format](https://github.com/GFA-spec/GFA-spec), allowing you to visualise graph alignments using [Bandage](https://github.com/rrwick/Bandage) and determine which variants of a given ARG type are dominant in your metagenomes. Read the [documentation](http://groot-documentation.readthedocs.io/en/latest/?badge=latest) for more info. +Since version 0.8.0, `GROOT` can now optionally use an [LSH Ensemble](https://ekzhu.github.io/datasketch/lshensemble.html) index to enable containment searching. This is thanks to the excellent [method](http://www.vldb.org/pvldb/vol9/p1185-zhu.pdf) and [implementation](https://github.com/ekzhu/lshensemble) of Erkang Zhu. This new index allows the reads of varying read length to be queried against **groot graphs**. + ## Installation Check out the [releases](https://github.com/will-rowe/groot/releases) to download a binary. Alternatively, install using Bioconda or compile the software from source. diff --git a/cmd/align.go b/cmd/align.go index 34e5403..70bc5a5 100644 --- a/cmd/align.go +++ b/cmd/align.go @@ -32,7 +32,7 @@ import ( "github.com/pkg/profile" "github.com/spf13/cobra" "github.com/will-rowe/groot/src/graph" - "github.com/will-rowe/groot/src/lshForest" + "github.com/will-rowe/groot/src/lshIndex" "github.com/will-rowe/groot/src/misc" "github.com/will-rowe/groot/src/stream" "github.com/will-rowe/groot/src/version" @@ -187,21 +187,29 @@ func runAlign() { log.Print("loading index information...") info := new(misc.IndexInfo) misc.ErrorCheck(info.Load(*indexDir + "/index.info")) + if info.Containment { + log.Printf("\tindex type: lshEnsemble") + log.Printf("\tcontainment search seeding: enabled") + } else { + log.Printf("\tindex type: lshForest") + log.Printf("\tcontainment search seeding: disabled") + } + log.Printf("\twindow sized used in indexing: %d\n", info.ReadLength) log.Printf("\tk-mer size: %d\n", info.Ksize) log.Printf("\tsignature size: %d\n", info.SigSize) log.Printf("\tJaccard similarity theshold: %0.2f\n", info.JSthresh) - log.Printf("\twindow sized used in indexing: %d\n", info.ReadLength) log.Print("loading the groot graphs...") graphStore := make(graph.GraphStore) misc.ErrorCheck(graphStore.Load(*indexDir + "/index.graph")) log.Printf("\tnumber of variation graphs: %d\n", len(graphStore)) log.Print("loading the MinHash signatures...") - database := lshForest.NewLSHforest(info.SigSize, info.JSthresh) + var database *lshIndex.LshEnsemble + if info.Containment { + database = lshIndex.NewLSHensemble(make([]lshIndex.Partition, lshIndex.PARTITIONS), info.SigSize, lshIndex.MAXK) + } else { + database = lshIndex.NewLSHforest(info.SigSize, info.JSthresh) + } misc.ErrorCheck(database.Load(*indexDir + "/index.sigs")) - database.Index() - numHF, numBucks := database.Settings() - log.Printf("\tnumber of hash functions per bucket: %d\n", numHF) - log.Printf("\tnumber of buckets: %d\n", numBucks) /////////////////////////////////////////////////////////////////////////////////////// // create SAM references from the sequences held in the graphs referenceMap, err := graphStore.GetRefs() @@ -222,10 +230,12 @@ func runAlign() { // add in the process parameters dataStream.InputFile = *fastq + fastqChecker.Containment = info.Containment fastqChecker.WindowSize = info.ReadLength dbQuerier.Db = database dbQuerier.CommandInfo = info dbQuerier.GraphStore = graphStore + dbQuerier.Threshold = info.JSthresh graphAligner.GraphStore = graphStore graphAligner.RefMap = referenceMap graphAligner.MaxClip = *clip diff --git a/cmd/index.go b/cmd/index.go index 5ac9fd8..ec9413a 100644 --- a/cmd/index.go +++ b/cmd/index.go @@ -35,7 +35,7 @@ import ( "github.com/spf13/cobra" "github.com/will-rowe/gfa" "github.com/will-rowe/groot/src/graph" - "github.com/will-rowe/groot/src/lshForest" + "github.com/will-rowe/groot/src/lshIndex" "github.com/will-rowe/groot/src/misc" "github.com/will-rowe/groot/src/seqio" "github.com/will-rowe/groot/src/version" @@ -51,6 +51,7 @@ var ( msaList []string // the collected MSA files outDir *string // directory to save index files and log to defaultOutDir = "./groot-index-" + string(time.Now().Format("20060102150405")) // a default dir to store the index files + containment *bool // use lshEnsemble instead of lshForest -- allows for variable read length ) // the index command (used by cobra) @@ -70,10 +71,11 @@ var indexCmd = &cobra.Command{ func init() { kSize = indexCmd.Flags().IntP("kmerSize", "k", 7, "size of k-mer") sigSize = indexCmd.Flags().IntP("sigSize", "s", 128, "size of MinHash signature") - readLength = indexCmd.Flags().IntP("readLength", "l", 100, "length of query reads (which will be aligned during the align subcommand)") - jsThresh = indexCmd.Flags().Float64P("jsThresh", "j", 0.99, "minimum Jaccard similarity for a seed to be recorded") + readLength = indexCmd.Flags().IntP("readLength", "l", 100, "max length of query reads (which will be aligned during the align subcommand)") + jsThresh = indexCmd.Flags().Float64P("jsThresh", "j", 0.99, "minimum Jaccard similarity for a seed to be recorded (note: this is used as a containment theshold when --containment set") msaDir = indexCmd.Flags().StringP("msaDir", "i", "", "directory containing the clustered references (MSA files) - required") outDir = indexCmd.PersistentFlags().StringP("outDir", "o", defaultOutDir, "directory to save index files to") + containment = indexCmd.Flags().BoolP("containment", "c", false, "use lshEnsemble instead of lshForest (allows for variable read length during alignment)") indexCmd.MarkFlagRequired("msaDir") RootCmd.AddCommand(indexCmd) } @@ -136,6 +138,11 @@ func runIndex() { // check the supplied files and then log some stuff log.Printf("checking parameters...") misc.ErrorCheck(indexParamCheck()) + if *containment { + log.Printf("\tindexing scheme: lshEnsemble (containment search)") + } else { + log.Printf("\tindexing scheme: lshForest") + } log.Printf("\tprocessors: %d", *proc) log.Printf("\tk-mer size: %d", *kSize) log.Printf("\tsignature size: %d", *sigSize) @@ -198,56 +205,57 @@ func runIndex() { }() /////////////////////////////////////////////////////////////////////////////////////// // collect and store the GrootGraph windows - var sigStore = make([]map[int]map[int][][]uint64, len(graphStore)) - for i := range sigStore { - sigStore[i] = make(map[int]map[int][][]uint64) - } + sigStore := []*lshIndex.GraphWindow{} + lookupMap := make(lshIndex.KeyLookupMap) // receive the signatures - var sigCount int = 0 for window := range windowChan { - // initialise the inner map of sigStore if graph has not been seen yet - if _, ok := sigStore[window.GraphID][window.Node]; !ok { - sigStore[window.GraphID][window.Node] = make(map[int][][]uint64) - } + // combine graphID, nodeID and offset to form a string key for signature + stringKey := fmt.Sprintf("g%dn%do%d", window.GraphID, window.Node, window.OffSet) + // convert to a graph window + gw := &lshIndex.GraphWindow{stringKey, *readLength, window.Sig} // store the signature for the graph:node:offset - sigStore[window.GraphID][window.Node][window.OffSet] = append(sigStore[window.GraphID][window.Node][window.OffSet], window.Sig) - sigCount++ + sigStore = append(sigStore, gw) + // add a key to the lookup map + lookupMap[stringKey] = seqio.Key{GraphID: window.GraphID, Node: window.Node, OffSet: window.OffSet} } - log.Printf("\tnumber of signatures generated: %d\n", sigCount) - /////////////////////////////////////////////////////////////////////////////////////// - // run LSH forest - log.Printf("running LSH forest...\n") - database := lshForest.NewLSHforest(*sigSize, *jsThresh) - // range over the nodes in each graph, each node will have one or more signature - for graphID, nodesMap := range sigStore { - // add each signature to the database - for nodeID, offsetMap := range nodesMap { - for offset, signatures := range offsetMap { - for _, signature := range signatures { - // combine graphID, nodeID and offset to form a string key for signature - stringKey := fmt.Sprintf("g%dn%do%d", graphID, nodeID, offset) - // add the key to a lookup map - key := seqio.Key{GraphID: graphID, Node: nodeID, OffSet: offset} - database.KeyLookup[stringKey] = key - // add the signature to the lshForest - database.Add(stringKey, signature) - } - } + numSigs := len(sigStore) + log.Printf("\tnumber of signatures generated: %d\n", numSigs) + var database *lshIndex.LshEnsemble + if *containment == false { + /////////////////////////////////////////////////////////////////////////////////////// + // run LSH forest + log.Printf("running LSH Forest...\n") + database = lshIndex.NewLSHforest(*sigSize, *jsThresh) + // range over the nodes in each graph, each node will have one or more signature + for window := range lshIndex.Windows2Chan(sigStore) { + // add the signature to the lshForest + database.Add(window.Key, window.Signature, 0) } + // print some stuff + numHF, numBucks := database.Lshes[0].Settings() + log.Printf("\tnumber of hash functions per bucket: %d\n", numHF) + log.Printf("\tnumber of buckets: %d\n", numBucks) + } else { + /////////////////////////////////////////////////////////////////////////////////////// + // run LSH ensemble (https://github.com/ekzhu/lshensemble) + log.Printf("running LSH Ensemble...\n") + database = lshIndex.BootstrapLshEnsemble(lshIndex.PARTITIONS, *sigSize, lshIndex.MAXK, numSigs, lshIndex.Windows2Chan(sigStore)) + // print some stuff + log.Printf("\tnumber of LSH Ensemble partitions: %d\n", lshIndex.PARTITIONS) + log.Printf("\tmax no. hash functions per bucket: %d\n", lshIndex.MAXK) } - numHF, numBucks := database.Settings() - log.Printf("\tnumber of hash functions per bucket: %d\n", numHF) - log.Printf("\tnumber of buckets: %d\n", numBucks) + // attach the key lookup map to the index + database.KeyLookup = lookupMap /////////////////////////////////////////////////////////////////////////////////////// // record runtime info - info := &misc.IndexInfo{Version: version.VERSION, Ksize: *kSize, SigSize: *sigSize, JSthresh: *jsThresh, ReadLength: *readLength} + info := &misc.IndexInfo{Version: version.VERSION, Ksize: *kSize, SigSize: *sigSize, JSthresh: *jsThresh, ReadLength: *readLength, Containment: *containment} // save the index files log.Printf("saving index files to \"%v\"...", *outDir) + misc.ErrorCheck(database.Dump(*outDir + "/index.sigs")) + log.Printf("\tsaved MinHash signatures") misc.ErrorCheck(info.Dump(*outDir + "/index.info")) log.Printf("\tsaved runtime info") misc.ErrorCheck(graphStore.Dump(*outDir + "/index.graph")) log.Printf("\tsaved groot graphs") - misc.ErrorCheck(database.Dump(*outDir + "/index.sigs")) - log.Printf("\tsaved MinHash signatures") log.Println("finished") } diff --git a/cmd/report.go b/cmd/report.go index 5e67086..0039687 100644 --- a/cmd/report.go +++ b/cmd/report.go @@ -98,7 +98,7 @@ func reportParamCheck() error { log.Printf("\tBAM file: %v", *bamFile) } if *covCutoff > 1.0 { - return fmt.Errorf("supplied coverage cutoff exceeds 1.0 (100%): %v", *covCutoff) + return fmt.Errorf("supplied coverage cutoff exceeds 1.0 (100%%): %.2f", *covCutoff) } return nil } diff --git a/db/full-ARG-databases/resfinder/aminoglycoside.fsa b/db/full-ARG-databases/resfinder/aminoglycoside.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/beta-lactam.fsa b/db/full-ARG-databases/resfinder/beta-lactam.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/colistin.fsa b/db/full-ARG-databases/resfinder/colistin.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/fosfomycin.fsa b/db/full-ARG-databases/resfinder/fosfomycin.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/fusidicacid.fsa b/db/full-ARG-databases/resfinder/fusidicacid.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/glycopeptide.fsa b/db/full-ARG-databases/resfinder/glycopeptide.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/macrolide.fsa b/db/full-ARG-databases/resfinder/macrolide.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/nitroimidazole.fsa b/db/full-ARG-databases/resfinder/nitroimidazole.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/oxazolidinone.fsa b/db/full-ARG-databases/resfinder/oxazolidinone.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/phenicol.fsa b/db/full-ARG-databases/resfinder/phenicol.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/quinolone.fsa b/db/full-ARG-databases/resfinder/quinolone.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/rifampicin.fsa b/db/full-ARG-databases/resfinder/rifampicin.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/sulphonamide.fsa b/db/full-ARG-databases/resfinder/sulphonamide.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/tetracycline.fsa b/db/full-ARG-databases/resfinder/tetracycline.fsa old mode 100755 new mode 100644 diff --git a/db/full-ARG-databases/resfinder/trimethoprim.fsa b/db/full-ARG-databases/resfinder/trimethoprim.fsa old mode 100755 new mode 100644 diff --git a/src/alignment/alignment_test.go b/src/alignment/alignment_test.go index 202842d..2c544fd 100644 --- a/src/alignment/alignment_test.go +++ b/src/alignment/alignment_test.go @@ -1,8 +1,8 @@ package alignment import ( + "fmt" "io" - "log" "os" "testing" @@ -18,12 +18,12 @@ var ( sigSize = 128 ) -func loadGFA() *gfa.GFA { +func loadGFA() (*gfa.GFA, error) { // load the GFA file fh, err := os.Open(inputFile) reader, err := gfa.NewReader(fh) if err != nil { - log.Fatal("can't read gfa file: %v", err) + return nil, fmt.Errorf("can't read gfa file: %v", err) } // collect the GFA instance myGFA := reader.CollectGFA() @@ -34,13 +34,13 @@ func loadGFA() *gfa.GFA { break } if err != nil { - log.Fatal("error reading line in gfa file: %v", err) + return nil, fmt.Errorf("error reading line in gfa file: %v", err) } if err := line.Add(myGFA); err != nil { - log.Fatal("error adding line to GFA instance: %v", err) + return nil, fmt.Errorf("error adding line to GFA instance: %v", err) } } - return myGFA + return myGFA, nil } func setupMultimapRead() (*seqio.FASTQread, error) { @@ -80,13 +80,16 @@ func TestExactMatchMultiMapper(t *testing.T) { // create the read testRead, err := setupMultimapRead() if err != nil { - log.Fatal(err) + t.Fatal(err) } // create the GrootGraph and graphStore - myGFA := loadGFA() + myGFA, err := loadGFA() + if err != nil { + t.Fatal(err) + } grootGraph, err := graph.CreateGrootGraph(myGFA, 1) if err != nil { - log.Fatal(err) + t.Fatal(err) } graphStore := make(graph.GraphStore) graphStore[grootGraph.GraphID] = grootGraph @@ -113,13 +116,16 @@ func TestExactMatchUniqMapper(t *testing.T) { // create the read testRead, err := setupUniqmapRead() if err != nil { - log.Fatal(err) + t.Fatal(err) } // create the GrootGraph and graphStore - myGFA := loadGFA() + myGFA, err := loadGFA() + if err != nil { + t.Fatal(err) + } grootGraph, err := graph.CreateGrootGraph(myGFA, 1) if err != nil { - log.Fatal(err) + t.Fatal(err) } graphStore := make(graph.GraphStore) graphStore[grootGraph.GraphID] = grootGraph diff --git a/src/graph/graph_test.go b/src/graph/graph_test.go index dc219a3..4eea74a 100644 --- a/src/graph/graph_test.go +++ b/src/graph/graph_test.go @@ -1,9 +1,9 @@ package graph import ( + "fmt" "github.com/will-rowe/gfa" "io" - "log" "os" "testing" ) @@ -17,12 +17,12 @@ var ( blaB10 = []byte("ATGAAAGGATTAAAAGGGCTATTGGTTCTGGCTTTAGGCTTTACAGGACTACAGGTTTTTGGGCAACAGAACCCTGATATTAAAATTGAAAAATTAAAAGATAATTTATACGTCTATACAACCTATAATACCTTCAAAGGAACTAAATATGCGGCTAATGCGGTATATATGGTAACCGATAAAGGAGTAGTGGTTATAGACTCTCCATGGGGAGAAGATAAATTTAAAAGTTTTACAGACGAGATTTATAAAAAGCACGGAAAGAAAGTTATCATGAACATTGCAACCCACTCTCATGATGATAGAGCCGGAGGTCTTGAATATTTTGGTAAACTAGGTGCAAAAACTTATTCTACTAAAATGACAGATTCTATTTTAGCAAAAGAGAATAAGCCAAGAGCAAAGTACACTTTTGATAATAATAAATCTTTTAAAGTAGGAAAGACTGAGTTTCAGGTTTATTATCCGGGAAAAGGTCATACAGCAGATAATGTGGTTGTGTGGTTTCCTAAAGACAAAGTATTAGTAGGAGGCTGCATTGTAAAAAGTGGTGATTCGAAAGACCTTGGGTTTATTGGGGAAGCTTATGTAAACGACTGGACACAGTCCATACACAACATTCAGCAGAAATTTCCCTATGTTCAGTATGTCGTTGCAGGTCATGACGACTGGAAAGATCAAACATCAATACAACATACACTGGATTTAATCAGTGAATATCAACAAAAACAAAAGGCTTCAAATTAA") ) -func loadGFA() *gfa.GFA { +func loadGFA() (*gfa.GFA, error) { // load the GFA file fh, err := os.Open(inputFile) reader, err := gfa.NewReader(fh) if err != nil { - log.Fatal("can't read gfa file: %v", err) + return nil, fmt.Errorf("can't read gfa file: %v", err) } // collect the GFA instance myGFA := reader.CollectGFA() @@ -33,30 +33,30 @@ func loadGFA() *gfa.GFA { break } if err != nil { - log.Fatal("error reading line in gfa file: %v", err) + return nil, fmt.Errorf("error reading line in gfa file: %v", err) } if err := line.Add(myGFA); err != nil { - log.Fatal("error adding line to GFA instance: %v", err) + return nil, fmt.Errorf("error adding line to GFA instance: %v", err) } } - return myGFA + return myGFA, nil } -func loadMSA() *gfa.GFA { +func loadMSA() (*gfa.GFA, error) { // load the MSA msa, _ := gfa.ReadMSA(inputFile2) // convert the MSA to a GFA instance myGFA, err := gfa.MSA2GFA(msa) - if err != nil { - log.Fatal(err) - } - return myGFA + return myGFA, err } // test CreateGrootGraph func TestCreateGrootGraph(t *testing.T) { - myGFA := loadGFA() - _, err := CreateGrootGraph(myGFA, 1) + myGFA, err := loadGFA() + if err != nil { + t.Fatal(err) + } + _, err = CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) } @@ -64,7 +64,10 @@ func TestCreateGrootGraph(t *testing.T) { // test Graph2Seq func TestGraph2Seq(t *testing.T) { - myGFA := loadGFA() + myGFA, err := loadGFA() + if err != nil { + t.Fatal(err) + } grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) @@ -82,8 +85,10 @@ func TestGraph2Seq(t *testing.T) { // test WindowGraph func TestWindowGraph(t *testing.T) { - myGFA := loadMSA() - //myGFA := loadGFA() + myGFA, err := loadMSA() + if err != nil { + t.Fatal(err) + } grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) @@ -99,7 +104,10 @@ func TestWindowGraph(t *testing.T) { // test GraphStore dump/load func TestGraphStore(t *testing.T) { - myGFA := loadGFA() + myGFA, err := loadGFA() + if err != nil { + t.Fatal(err) + } grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) @@ -123,7 +131,10 @@ func TestGraphStore(t *testing.T) { // test DumpGraph to save a gfa func TestGraphDump(t *testing.T) { - myGFA := loadGFA() + myGFA, err := loadGFA() + if err != nil { + t.Fatal(err) + } grootGraph, err := CreateGrootGraph(myGFA, 1) if err != nil { t.Fatal(err) diff --git a/src/lshForest/lshForest.go b/src/lshForest/lshForest.go deleted file mode 100644 index 0426c79..0000000 --- a/src/lshForest/lshForest.go +++ /dev/null @@ -1,318 +0,0 @@ -package lshForest - -import ( - "encoding/binary" - "encoding/gob" - "fmt" - "github.com/will-rowe/groot/src/seqio" - "math" - "os" - "sort" - "sync" -) - -/* - The LSH forest -*/ -type LSHforest struct { - numHashFuncs int - numBands int - initialHashTables []initialHashTable - hashTables []hashTables - hashedSignatureFunc hashedSignatureFunc - hashValueSize int - KeyLookup KeyLookupMap -} - -/* - A function to construct the LSH forest -*/ -func NewLSHforest(sigSize int, jsThresh float64) *LSHforest { - // calculate the optimal number of bands and hash functions to use based on the length of MinHash signature and a Jaccard Similarity theshhold - numHashFuncs, numBands, _, _ := optimise(sigSize, jsThresh) - hashValueSize := 4 - // create the initial hash tables - initialHashTables := make([]initialHashTable, numBands) - for i := range initialHashTables { - initialHashTables[i] = make(initialHashTable) - } - // create the hash tables that will be populated once the LSH forest indexing method has been run - indexedHashTables := make([]hashTables, numBands) - for i := range indexedHashTables { - indexedHashTables[i] = make(hashTables, 0) - } - // create the KeyLookup map to relate signatures to graph locations - KeyLookup := make(KeyLookupMap) - // return the address of the new LSH forest - newLSHforest := new(LSHforest) - newLSHforest.numHashFuncs = numHashFuncs - newLSHforest.numBands = numBands - newLSHforest.hashValueSize = hashValueSize - newLSHforest.initialHashTables = initialHashTables - newLSHforest.hashTables = indexedHashTables - newLSHforest.hashedSignatureFunc = hashedSignatureFuncGen(hashValueSize) - newLSHforest.KeyLookup = KeyLookup - return newLSHforest -} - -/* - The types needed by the LSH forest -*/ -// this map relates the stringified seqio.Key to the original, allowing LSHforest search results to easily be related to graph locations -type KeyLookupMap map[string]seqio.Key - -// graphKeys is a slice containing all the stringified graphKeys for a given hashed signature -type graphKeys []string - -// the initial hash table uses the hashed signature as a key - the values are the corresponding graphKeys -type initialHashTable map[string]graphKeys - -// a band is a single hash table that is stored in the indexedHashTables - it contains the band of a hash signature and the corresponding graphKeys -type band struct { - HashedSignature string - graphKeys graphKeys -} - -// this is populated during indexing -- it is a slice of bands and can be sorted -type hashTables []band - -//methods to satisfy the sort interface -func (h hashTables) Len() int { return len(h) } -func (h hashTables) Swap(i, j int) { h[i], h[j] = h[j], h[i] } -func (h hashTables) Less(i, j int) bool { return h[i].HashedSignature < h[j].HashedSignature } - -// the hashkey function type and the generator function -type hashedSignatureFunc func([]uint64) string - -func hashedSignatureFuncGen(hashValueSize int) hashedSignatureFunc { - return func(sig []uint64) string { - hashedSig := make([]byte, hashValueSize*len(sig)) - buf := make([]byte, 8) - for i, v := range sig { - // use the ByteOrder interface to write binary data - // use the LittleEndian implementation and call the Put method - binary.LittleEndian.PutUint64(buf, v) - copy(hashedSig[i*hashValueSize:(i+1)*hashValueSize], buf[:hashValueSize]) - } - return string(hashedSig) - } -} - -/* - A method to return the number of hash functions and number of bands set by the LSH forest -*/ -func (self *LSHforest) Settings() (numHashFuncs, numBands int) { - return self.numHashFuncs, self.numBands -} - -/* - A method to add a minhash signature and graph key to the LSH forest -*/ -func (self *LSHforest) Add(key string, sig []uint64) { - // split the signature into the right number of bands and then hash each one - hashedSignature := make([]string, self.numBands) - for i := 0; i < self.numBands; i++ { - hashedSignature[i] = self.hashedSignatureFunc(sig[i*self.numHashFuncs : (i+1)*self.numHashFuncs]) - } - // iterate over each band in the LSH forest - for i := 0; i < len(self.initialHashTables); i++ { - // if the current band in the signature isn't in the current band in the LSH forest, add it - if _, ok := self.initialHashTables[i][hashedSignature[i]]; !ok { - self.initialHashTables[i][hashedSignature[i]] = make(graphKeys, 1) - self.initialHashTables[i][hashedSignature[i]][0] = key - // if it is, append the current key (graph location) to this hashed signature band - } else { - self.initialHashTables[i][hashedSignature[i]] = append(self.initialHashTables[i][hashedSignature[i]], key) - } - } -} - -/* - A method to index the graph (transfers contents of each initialHashTable so they can be sorted and searched) -*/ -func (self *LSHforest) Index() { - // iterate over the empty indexed hash tables - for i := range self.hashTables { - // transfer contents from the corresponding band in the initial hash table - for HashedSignature, keys := range self.initialHashTables[i] { - self.hashTables[i] = append(self.hashTables[i], band{HashedSignature, keys}) - } - // sort the new hashtable and store it in the corresponding slot in the indexed hash tables - sort.Sort(self.hashTables[i]) - // clear the initial hashtable that has just been processed - self.initialHashTables[i] = make(initialHashTable) - } -} - -/* - Methods to dump the LSH forest to disk and then load it again -*/ -func (self *LSHforest) Dump(path string) error { - if len(self.hashTables[0]) != 0 { - return fmt.Errorf("cannot dump the LSH Forest after running the indexing method") - } - file, err := os.Create(path) - if err != nil { - return err - } - defer file.Close() - encoder := gob.NewEncoder(file) - for _, bandContents := range self.initialHashTables { - err := encoder.Encode(bandContents) - if err != nil { - return err - } - } - err = encoder.Encode(self.KeyLookup) - if err != nil { - return err - } - return nil -} -func (self *LSHforest) Load(path string) error { - file, err := os.Open(path) - if err != nil { - return err - } - defer file.Close() - decoder := gob.NewDecoder(file) - for _, bandContents := range self.initialHashTables { - err = decoder.Decode(&bandContents) - if err != nil { - return err - } - } - err = decoder.Decode(&self.KeyLookup) - if err != nil { - return err - } - return nil -} - -/* - A method to query a MinHash signature against the LSH forest -*/ -func (self *LSHforest) Query(sig []uint64) []string { - result := make([]string, 0) - // more info on done chans for explicit cancellation in concurrent pipelines: https://blog.golang.org/pipelines - done := make(chan struct{}) - defer close(done) - // collect query results and aggregate in a single array to send back - for key := range self.runQuery(sig, done) { - result = append(result, key) - } - return result -} - -func (self *LSHforest) runQuery(sig []uint64, done <-chan struct{}) <-chan string { - queryResultChan := make(chan string) - go func() { - defer close(queryResultChan) - // hash the query signature - hashedSignature := make([]string, self.numBands) - for i := 0; i < self.numBands; i++ { - hashedSignature[i] = self.hashedSignatureFunc(sig[i*self.numHashFuncs : (i+1)*self.numHashFuncs]) - } - // don't send back multiple copies of the same key - seens := make(map[string]bool) - // compress internal nodes using a prefix - prefixSize := self.hashValueSize * self.numHashFuncs - // run concurrent hashtable queries - keyChan := make(chan string) - var wg sync.WaitGroup - wg.Add(self.numBands) - for i := 0; i < self.numBands; i++ { - go func(band hashTables, queryChunk string) { - defer wg.Done() - // sort.Search uses binary search to find and return the smallest index i in [0, n) at which f(i) is true - index := sort.Search(len(band), func(x int) bool { return band[x].HashedSignature[:prefixSize] >= queryChunk }) - // k is the index returned by the search - if index < len(band) && band[index].HashedSignature[:prefixSize] == queryChunk { - for j := index; j < len(band) && band[j].HashedSignature[:prefixSize] == queryChunk; j++ { - // copies key values from this hashtable to the keyChan until all values from band[j] copied or done is closed - for _, key := range band[j].graphKeys { - select { - case keyChan <- key: - case <-done: - return - } - } - } - } - }(self.hashTables[i], hashedSignature[i]) - } - go func() { - wg.Wait() - close(keyChan) - }() - for key := range keyChan { - if _, seen := seens[key]; seen { - continue - } - queryResultChan <- key - seens[key] = true - } - }() - return queryResultChan -} - -// the following funcs are taken from https://github.com/ekzhu/minhash-lsh - -// optimise returns the optimal number of hash functions and the optimal number of bands for Jaccard similarity search, as well as the false positive and negative probabilities. -func optimise(sigSize int, jsThresh float64) (int, int, float64, float64) { - optimumNumHashFuncs, optimumNumBands := 0, 0 - fp, fn := 0.0, 0.0 - minError := math.MaxFloat64 - for l := 1; l <= sigSize; l++ { - for k := 1; k <= sigSize; k++ { - if l*k > sigSize { - break - } - currFp := probFalsePositive(l, k, jsThresh, 0.01) - currFn := probFalseNegative(l, k, jsThresh, 0.01) - currErr := currFn + currFp - if minError > currErr { - minError = currErr - optimumNumHashFuncs = k - optimumNumBands = l - fp = currFp - fn = currFn - } - } - } - return optimumNumHashFuncs, optimumNumBands, fp, fn -} - -// Compute the integral of function f, lower limit a, upper limit l, and -// precision defined as the quantize step -func integral(f func(float64) float64, a, b, precision float64) float64 { - var area float64 - for x := a; x < b; x += precision { - area += f(x+0.5*precision) * precision - } - return area -} - -// Probability density function for false positive -func falsePositive(l, k int) func(float64) float64 { - return func(j float64) float64 { - return 1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l)) - } -} - -// Probability density function for false negative -func falseNegative(l, k int) func(float64) float64 { - return func(j float64) float64 { - return 1.0 - (1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l))) - } -} - -// Compute the cummulative probability of false negative given threshold t -func probFalseNegative(l, k int, t, precision float64) float64 { - return integral(falseNegative(l, k), t, 1.0, precision) -} - -// Compute the cummulative probability of false positive given threshold t -func probFalsePositive(l, k int, t, precision float64) float64 { - return integral(falsePositive(l, k), 0, t, precision) -} diff --git a/src/lshIndex/README.md b/src/lshIndex/README.md new file mode 100644 index 0000000..88757b4 --- /dev/null +++ b/src/lshIndex/README.md @@ -0,0 +1,9 @@ +# lshIndex package + +Since version 0.8.0, GROOT has two options for indexing the variation graphs - lshForest or lshEnsemble. With the addition of lshEnsemble, GROOT can now receive variable read lengths and seed these against graphs using containment search. + +This has required a significant re-write of the lshForest code, which is all contained in this directory. The majority of this code now comes from the [lshEnsemble package](https://godoc.org/github.com/ekzhu/lshensemble) by ekzhu. I have just made a few changes: + +* removed methods that were unnecessary for GROOT +* added a method to write the index to disk +* added methods to generate a single LSH Forest index using a Jaccard Similarity threshold and signature length parameter diff --git a/src/lshIndex/lshEnsemble.go b/src/lshIndex/lshEnsemble.go new file mode 100644 index 0000000..e87126c --- /dev/null +++ b/src/lshIndex/lshEnsemble.go @@ -0,0 +1,189 @@ +package lshIndex + +import ( + "encoding/gob" + "fmt" + "os" + "sync" + + "github.com/orcaman/concurrent-map" + "github.com/will-rowe/groot/src/seqio" +) + +type param struct { + k int + l int +} + +// Partition represents a domain size partition in the LSH Ensemble index. +type Partition struct { + Lower int `json:"lower"` + Upper int `json:"upper"` +} + +// KeyLookupMap relates the stringified seqio.Key to the original, allowing LSH index search results to easily be related to graph locations +type KeyLookupMap map[string]seqio.Key + +// GraphWindow represents a region of a variation graph +type GraphWindow struct { + // The unique key of this window + Key interface{} + // The window size + Size int + // The MinHash signature of this window + Signature []uint64 +} + +// LshEnsemble represents an LSH Ensemble index. +type LshEnsemble struct { + Partitions []Partition + Lshes []*LshForest + MaxK int + NumHash int + paramCache cmap.ConcurrentMap + Indexed bool + SingleForest bool + KeyLookup KeyLookupMap +} + +// Add a new domain to the index given its partition ID - the index of the partition. +// The added domain won't be searchable until the Index() function is called. +func (e *LshEnsemble) Add(key interface{}, sig []uint64, partInd int) { + e.Lshes[partInd].Add(key, sig) +} + +// Index makes all added domains searchable. +func (e *LshEnsemble) Index() { + for i := range e.Lshes { + e.Lshes[i].Index() + } + e.Indexed = true +} + +// Query returns the candidate domain keys in a channel. +// This function is given the MinHash signature of the query domain, sig, the domain size, +// the containment threshold, and a cancellation channel. +// Closing channel done will cancel the query execution. +// The query signature must be generated using the same seed as the signatures of the indexed domains, +// and have the same number of hash functions. +func (e *LshEnsemble) Query(sig []uint64, size int, threshold float64, done <-chan struct{}) <-chan interface{} { + if e.SingleForest { + return e.queryForest(sig, done) + } + params := e.computeParams(size, threshold) + return e.queryWithParam(sig, params, done) +} + +// +func (e *LshEnsemble) queryWithParam(sig []uint64, params []param, done <-chan struct{}) <-chan interface{} { + // Collect candidates from all partitions + keyChan := make(chan interface{}) + var wg sync.WaitGroup + wg.Add(len(e.Lshes)) + for i := range e.Lshes { + go func(lsh *LshForest, k, l int) { + lsh.Query(sig, k, l, keyChan, done) + wg.Done() + }(e.Lshes[i], params[i].k, params[i].l) + } + go func() { + wg.Wait() + close(keyChan) + }() + return keyChan +} + +// +func (e *LshEnsemble) queryForest(sig []uint64, done <-chan struct{}) <-chan interface{} { + keyChan := make(chan interface{}) + var wg sync.WaitGroup + wg.Add(1) + go func(lsh *LshForest) { + lsh.Query(sig, -1, -1, keyChan, done) + wg.Done() + }(e.Lshes[0]) + go func() { + wg.Wait() + close(keyChan) + }() + return keyChan +} + +// Compute the optimal k and l for each partition +func (e *LshEnsemble) computeParams(size int, threshold float64) []param { + params := make([]param, len(e.Partitions)) + for i, p := range e.Partitions { + x := p.Upper + key := cacheKey(x, size, threshold) + if cached, exist := e.paramCache.Get(key); exist { + params[i] = cached.(param) + } else { + optK, optL, _, _ := e.Lshes[i].OptimalKL(x, size, threshold) + computed := param{optK, optL} + e.paramCache.Set(key, computed) + params[i] = computed + } + } + return params +} + +// Make a cache key with threshold precision to 2 decimal points +func cacheKey(x, q int, t float64) string { + return fmt.Sprintf("%.8x %.8x %.2f", x, q, t) +} + +// Dump an LSH index to disk +func (LshEnsemble *LshEnsemble) Dump(path string) error { + if LshEnsemble.Indexed == true { + return fmt.Errorf("cannot dump the LSH Index after running the indexing method") + } + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + encoder := gob.NewEncoder(file) + if err := encoder.Encode(LshEnsemble); err != nil { + return err + } + for _, lsh := range LshEnsemble.Lshes { + for _, bandContents := range lsh.initHashTables { + err := encoder.Encode(bandContents) + if err != nil { + return err + } + } + } + err = encoder.Encode(LshEnsemble.KeyLookup) + if err != nil { + return err + } + return nil +} + +// Load an LSH index from disk +func (LshEnsemble *LshEnsemble) Load(path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + decoder := gob.NewDecoder(file) + if err := decoder.Decode(&LshEnsemble); err != nil { + return(err) + } + for _, lsh := range LshEnsemble.Lshes { + for _, bandContents := range lsh.initHashTables { + err := decoder.Decode(&bandContents) + if err != nil { + return err + } + } + } + err = decoder.Decode(&LshEnsemble.KeyLookup) + if err != nil { + return err + } + LshEnsemble.Index() + return nil +} diff --git a/src/lshIndex/lshForest.go b/src/lshIndex/lshForest.go new file mode 100644 index 0000000..1c14ee8 --- /dev/null +++ b/src/lshIndex/lshForest.go @@ -0,0 +1,210 @@ +package lshIndex + +import ( + "encoding/binary" + "math" + "sort" +) + +// +type keys []interface{} + +// For initial bootstrapping +type initHashTable map[string]keys + +// +type bucket struct { + hashKey string + keys keys +} + +// +type hashTable []bucket +func (h hashTable) Len() int { return len(h) } +func (h hashTable) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h hashTable) Less(i, j int) bool { return h[i].hashKey < h[j].hashKey } + +// +type hashKeyFunc func([]uint64) string + +// +func hashKeyFuncGen(hashValueSize int) hashKeyFunc { + return func(sig []uint64) string { + s := make([]byte, hashValueSize*len(sig)) + buf := make([]byte, 8) + for i, v := range sig { + // use the ByteOrder interface to write binary data + // use the LittleEndian implementation and call the Put method + binary.LittleEndian.PutUint64(buf, v) + copy(s[i*hashValueSize:(i+1)*hashValueSize], buf[:hashValueSize]) + } + return string(s) + } +} + +// LshForest represents a MinHash LSH implemented using LSH Forest +// (http://ilpubs.stanford.edu:8090/678/1/2005-14.pdf). +// It supports query-time setting of the MinHash LSH parameters +// L (number of bands) and +// K (number of hash functions per band). +type LshForest struct { + k int + l int + initHashTables []initHashTable + hashTables []hashTable + hashKeyFunc hashKeyFunc + hashValueSize int + KeyLookup KeyLookupMap +} + +// +func newLshForest(k, l int) *LshForest { + if k < 0 || l < 0 { + panic("k and l must be positive") + } + hashTables := make([]hashTable, l) + initHashTables := make([]initHashTable, l) + for i := range initHashTables { + initHashTables[i] = make(initHashTable) + } + return &LshForest{ + k: k, + l: l, + hashValueSize: HASH_SIZE, + initHashTables: initHashTables, + hashTables: hashTables, + hashKeyFunc: hashKeyFuncGen(HASH_SIZE), + KeyLookup: make(KeyLookupMap), + } +} + +// Returns the number of hash functions per band and the number of bands +func (f *LshForest) Settings() (int, int) { + return f.k, f.l +} + +// Add a key with MinHash signature into the index. +// The key won't be searchable until Index() is called. +func (f *LshForest) Add(key interface{}, sig []uint64) { + // Generate hash keys + Hs := make([]string, f.l) + for i := 0; i < f.l; i++ { + Hs[i] = f.hashKeyFunc(sig[i*f.k : (i+1)*f.k]) + } + // Insert keys into the bootstrapping tables + for i := range f.initHashTables { + ht := f.initHashTables[i] + hk := Hs[i] + if _, exist := ht[hk]; exist { + ht[hk] = append(ht[hk], key) + } else { + ht[hk] = make(keys, 1) + ht[hk][0] = key + } + } +} + +// Index makes all the keys added searchable. +func (f *LshForest) Index() { + for i := range f.hashTables { + ht := make(hashTable, 0, len(f.initHashTables[i])) + // Build sorted hash table using buckets from init hash tables + for hashKey, keys := range f.initHashTables[i] { + ht = append(ht, bucket{ + hashKey: hashKey, + keys: keys, + }) + } + sort.Sort(ht) + f.hashTables[i] = ht + // Reset the init hash tables + f.initHashTables[i] = make(initHashTable) + } +} + +// Query returns candidate keys given the query signature and parameters. +func (f *LshForest) Query(sig []uint64, K, L int, out chan<- interface{}, done <-chan struct{}) { + if K == -1 { + K = f.k + } + if L == -1 { + L = f.l + } + prefixSize := f.hashValueSize * K + // Generate hash keys + Hs := make([]string, L) + for i := 0; i < L; i++ { + Hs[i] = f.hashKeyFunc(sig[i*f.k : i*f.k+K]) + } + seens := make(map[interface{}]bool) + for i := 0; i < L; i++ { + ht := f.hashTables[i] + hk := Hs[i] + k := sort.Search(len(ht), func(x int) bool { + return ht[x].hashKey[:prefixSize] >= hk + }) + if k < len(ht) && ht[k].hashKey[:prefixSize] == hk { + for j := k; j < len(ht) && ht[j].hashKey[:prefixSize] == hk; j++ { + for _, key := range ht[j].keys { + if _, seen := seens[key]; seen { + continue + } + seens[key] = true + select { + case out <- key: + case <-done: + return + } + } + } + } + } +} + +// OptimalKL returns the optimal K and L for containment search, +// and the false positive and negative probabilities. +// where x is the indexed domain size, q is the query domain size, +// and t is the containment threshold. +func (f *LshForest) OptimalKL(x, q int, t float64) (optK, optL int, fp, fn float64) { + minError := math.MaxFloat64 + for l := 1; l <= f.l; l++ { + for k := 1; k <= f.k; k++ { + currFp := probFalsePositiveC(x, q, l, k, t, PRECISION) + currFn := probFalseNegativeC(x, q, l, k, t, PRECISION) + currErr := currFn + currFp + if minError > currErr { + minError = currErr + optK = k + optL = l + fp = currFp + fn = currFn + } + } + } + return +} + +// optimise returns the optimal number of hash functions and the optimal number of bands for Jaccard similarity search, as well as the false positive and negative probabilities. +func optimise(sigSize int, jsThresh float64) (int, int, float64, float64) { + optimumNumHashFuncs, optimumNumBands := 0, 0 + fp, fn := 0.0, 0.0 + minError := math.MaxFloat64 + for l := 1; l <= sigSize; l++ { + for k := 1; k <= sigSize; k++ { + if l*k > sigSize { + break + } + currFp := probFalsePositive(l, k, jsThresh, PRECISION) + currFn := probFalseNegative(l, k, jsThresh, PRECISION) + currErr := currFn + currFp + if minError > currErr { + minError = currErr + optimumNumHashFuncs = k + optimumNumBands = l + fp = currFp + fn = currFn + } + } + } + return optimumNumHashFuncs, optimumNumBands, fp, fn +} diff --git a/src/lshIndex/lshIndex.go b/src/lshIndex/lshIndex.go new file mode 100644 index 0000000..736a4a3 --- /dev/null +++ b/src/lshIndex/lshIndex.go @@ -0,0 +1,96 @@ +package lshIndex + +import ( + //"errors" + "github.com/orcaman/concurrent-map" +) + +// set to 2/4/8 for 16bit/32bit/64bit hash values +const HASH_SIZE = 8 +// integration precision for optimising number of bands + hash functions in LSH Forest +const PRECISION = 0.01 +// number of partitions and maxK to use in LSH Ensemble (TODO: add these as customisable parameters for GROOT) +const PARTITIONS = 10 +const MAXK = 4 + +// error messages +//var ( + //querySizeError = errors.New("Query size is > +/- 10 bases of reference windows, re-index using --containment") +//) + +// NewLSHensemble initializes a new index consisting of MinHash LSH implemented using LshForest. +// numHash is the number of hash functions in MinHash. +// maxK is the maximum value for the MinHash parameter K - the number of hash functions per "band". +func NewLSHensemble(parts []Partition, numHash, maxK int) *LshEnsemble { + lshes := make([]*LshForest, len(parts)) + for i := range lshes { + lshes[i] = newLshForest(maxK, numHash/maxK) + } + return &LshEnsemble{ + Lshes: lshes, + Partitions: parts, + MaxK: maxK, + NumHash: numHash, + paramCache: cmap.New(), + } +} + +// NewLshForest initializes a new index consisting of MinHash LSH implemented using a single LshForest. +// sigSize is the number of hash functions in MinHash. +// jsThresh is the minimum Jaccard similarity needed for a query to return a match +func NewLSHforest(sigSize int, jsThresh float64) *LshEnsemble { + // calculate the optimal number of bands and hash functions to use + numHashFuncs, numBands, _, _ := optimise(sigSize, jsThresh) + lshes := make([]*LshForest, 1) + lshes[0] = newLshForest(numHashFuncs, numBands) + return &LshEnsemble{ + Lshes: lshes, + Partitions: make([]Partition, 1), + MaxK: numBands, + NumHash: numHashFuncs, + paramCache: cmap.New(), + SingleForest: true, + } +} + +// BoostrapLshEnsemble builds an index from a channel of domains. +// The returned index consists of MinHash LSH implemented using LshForest. +// numPart is the number of partitions to create. +// numHash is the number of hash functions in MinHash. +// maxK is the maximum value for the MinHash parameter K - the number of hash functions per "band". +// GraphWindow is a channel emitting windows (don't need to be sorted by their sizes as windows are constant) TODO: should probably add a check for this +func BootstrapLshEnsemble(numPart, numHash, maxK, totalNumWindows int, windows <-chan *GraphWindow) *LshEnsemble { + index := NewLSHensemble(make([]Partition, numPart), numHash, maxK) + bootstrap(index, totalNumWindows, windows) + return index +} + +// bootstrap +func bootstrap(index *LshEnsemble, totalNumWindows int, windows <-chan *GraphWindow) { + numPart := len(index.Partitions) + depth := totalNumWindows / numPart + var currDepth, currPart int + for rec := range windows { + index.Add(rec.Key, rec.Signature, currPart) + currDepth++ + index.Partitions[currPart].Upper = rec.Size + if currDepth >= depth && currPart < numPart-1 { + currPart++ + index.Partitions[currPart].Lower = rec.Size + currDepth = 0 + } + } + return +} + +// Windows2Chan is a utility function that converts a GraphWindow slice in memory to a GraphWindow channel. +func Windows2Chan(windows []*GraphWindow) <-chan *GraphWindow { + c := make(chan *GraphWindow, 1000) + go func() { + for _, w := range windows { + c <- w + } + close(c) + }() + return c +} \ No newline at end of file diff --git a/src/lshIndex/lshIndex_test.go b/src/lshIndex/lshIndex_test.go new file mode 100644 index 0000000..6d4e17d --- /dev/null +++ b/src/lshIndex/lshIndex_test.go @@ -0,0 +1,168 @@ +// testing is incomplete, more to be added... +package lshIndex + +import ( + "fmt" + "os" + "testing" +) + +var ( + // test graph windows + entry1 = &GraphWindow{ + Key : fmt.Sprintf("g%dn%do%d", 1, 2, 3), + Size : 100, + Signature : []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + } + entry2 = &GraphWindow{ + Key : fmt.Sprintf("g%dn%do%d", 1, 3, 1), + Size : 100, + Signature : []uint64{1, 4, 3, 4, 5, 5, 7, 4, 9, 10}, + } + entry3 = &GraphWindow{ + Key : fmt.Sprintf("g%dn%do%d", 3, 22, 2), + Size : 100, + Signature : []uint64{4, 4, 3, 4, 5, 6, 7, 4, 9, 4}, + } + entries = []*GraphWindow{entry1, entry2, entry3} + // LSH Forest parameters + jsThresh = 0.85 + // LSH Ensemble parameters + numPart = 4 + numHash = 10 + maxK = 4 + // query for LSH Forest + query1 = &GraphWindow{ + Key : fmt.Sprintf("g%dn%do%d", 1, 2, 3), + Size : 100, + Signature : []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + } + // query for LSH Ensemble + query2 = &GraphWindow{ + Key : fmt.Sprintf("g%dn%do%d", 1, 2, 3), + Size : 50, + Signature : []uint64{1, 1, 3, 4, 5, 6, 7, 8, 9, 10}, + } +) + +// test the lshForest constructor, add a record and query it +func Test_lshForestConstructor(t *testing.T) { + index := NewLSHforest(len(entry1.Signature), jsThresh) + numHF, numBucks := index.Lshes[0].Settings() + t.Logf("\tnumber of hash functions per bucket: %d\n", numHF) + t.Logf("\tnumber of buckets: %d\n", numBucks) + index.Add(entry1.Key, entry1.Signature, 0) + index.Index() + done := make(chan struct{}) + defer close(done) + var check string + for result := range index.Query(query1.Signature, query1.Size, jsThresh, done) { + check = result.(string) + if check != entry1.Key { + t.Fatal() + } + } + if check == "" { + t.Fatal("no result from LSH Forest") + } +} + +// test the lshForest constructor and add a set of records, then query +func Test_lshForestBootstrap(t *testing.T) { + index := NewLSHforest(len(entry1.Signature), jsThresh) + for _, i := range entries { + index.Add(i.Key, i.Signature, 0) + } + if len(index.Partitions) != 1 || index.SingleForest != true { + t.Fatal() + } + index.Index() + done := make(chan struct{}) + defer close(done) + var check string + for result := range index.Query(query1.Signature, query1.Size, jsThresh, done) { + check = result.(string) + if check != entry1.Key { + t.Fatal("incorrect result returned from LSH Forest") + } + } + if check == "" { + t.Fatal("no result from LSH Forest") + } +} + +// test the lshForest dump and load methods +func Test_lshForestDump(t *testing.T) { + index := NewLSHforest(len(entry1.Signature), jsThresh) + for _, i := range entries { + index.Add(i.Key, i.Signature, 0) + } + if err := index.Dump("./lsh.index"); err != nil { + t.Fatal(err) + } + index2 := NewLSHforest(len(entry1.Signature), jsThresh) + if err := index2.Load("./lsh.index"); err != nil { + t.Fatal(err) + } + if err := os.Remove("./lsh.index"); err != nil { + t.Fatal(err) + } + done := make(chan struct{}) + defer close(done) + var check string + for result := range index2.Query(query1.Signature, query1.Size, jsThresh, done) { + check = result.(string) + if check != entry1.Key { + t.Fatal(check) + } + } + if check == "" { + t.Fatal("no result from LSH Forest") + } +} + + +// test the lshEnsemble constructor, add the records and query it +func Test_lshEnsembleBootstrap(t *testing.T) { + index := BootstrapLshEnsemble(numPart, numHash, maxK, len(entries), Windows2Chan(entries)) + index.Index() + done := make(chan struct{}) + defer close(done) + var check string + for result := range index.Query(query2.Signature, query2.Size, jsThresh, done) { + check = result.(string) + if check != entry1.Key { + t.Fatal("incorrect result returned from LSH Ensemble") + } + } + if check == "" { + t.Fatal("no result from LSH ensemble") + } +} + +// test the lshEnsemble dump and load methods +func Test_lshEnsembleDump(t *testing.T) { + index := BootstrapLshEnsemble(numPart, numHash, maxK, len(entries), Windows2Chan(entries)) + if err := index.Dump("./lsh.index"); err != nil { + t.Fatal(err) + } + index2 := NewLSHensemble(make([]Partition, numPart), numHash, maxK) + if err := index2.Load("./lsh.index"); err != nil { + t.Fatal(err) + } + if err := os.Remove("./lsh.index"); err != nil { + t.Fatal(err) + } + done := make(chan struct{}) + defer close(done) + var check string + for result := range index2.Query(query2.Signature, query2.Size, jsThresh, done) { + check = result.(string) + if check != entry1.Key { + t.Fatal() + } + } + if check == "" { + t.Fatal("no result from LSH ensemble") + } +} diff --git a/src/lshIndex/probability.go b/src/lshIndex/probability.go new file mode 100644 index 0000000..a2c1e87 --- /dev/null +++ b/src/lshIndex/probability.go @@ -0,0 +1,86 @@ +// copy of https://github.com/ekzhu/lshensemble/blob/0322dae1f4d960f6fb3f9e6e2870786b9f4239ed/probability.go +package lshIndex + +import "math" + +// Compute the integral of function f, lower limit a, upper limit l, and +// precision defined as the quantize step +func integral(f func(float64) float64, a, b, precision float64) float64 { + var area float64 + for x := a; x < b; x += precision { + area += f(x+0.5*precision) * precision + } + return area +} + +/* + The following are using Jaccard similarity +*/ +// Probability density function for false positive +func falsePositive(l, k int) func(float64) float64 { + return func(j float64) float64 { + return 1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l)) + } +} + +// Probability density function for false negative +func falseNegative(l, k int) func(float64) float64 { + return func(j float64) float64 { + return 1.0 - (1.0 - math.Pow(1.0-math.Pow(j, float64(k)), float64(l))) + } +} + +// Compute the cummulative probability of false negative given threshold t +func probFalseNegative(l, k int, t, precision float64) float64 { + return integral(falseNegative(l, k), t, 1.0, precision) +} + +// Compute the cummulative probability of false positive given threshold t +func probFalsePositive(l, k int, t, precision float64) float64 { + return integral(falsePositive(l, k), 0, t, precision) +} + +/* + The following are using Jaccard containment TODO: consolidate these functions with the above +*/ +// Probability density function for false positive +func falsePositiveC(x, q, l, k int) func(float64) float64 { + return func(t float64) float64 { + return 1.0 - math.Pow(1.0-math.Pow(t/(1.0+float64(x)/float64(q)-t), float64(k)), float64(l)) + } +} + +// Probability density function for false negative +func falseNegativeC(x, q, l, k int) func(float64) float64 { + return func(t float64) float64 { + return 1.0 - (1.0 - math.Pow(1.0-math.Pow(t/(1.0+float64(x)/float64(q)-t), float64(k)), float64(l))) + } +} + +// Compute the cummulative probability of false negative +func probFalseNegativeC(x, q, l, k int, t, precision float64) float64 { + fn := falseNegativeC(x, q, l, k) + xq := float64(x) / float64(q) + if xq >= 1.0 { + return integral(fn, t, 1.0, precision) + } + if xq >= t { + return integral(fn, t, xq, precision) + } else { + return 0.0 + } +} + +// Compute the cummulative probability of false positive +func probFalsePositiveC(x, q, l, k int, t, precision float64) float64 { + fp := falsePositiveC(x, q, l, k) + xq := float64(x) / float64(q) + if xq >= 1.0 { + return integral(fp, 0.0, t, precision) + } + if xq >= t { + return integral(fp, 0.0, t, precision) + } else { + return integral(fp, 0.0, xq, precision) + } +} diff --git a/src/misc/misc.go b/src/misc/misc.go index 2a2a542..59c1417 100644 --- a/src/misc/misc.go +++ b/src/misc/misc.go @@ -85,6 +85,7 @@ type IndexInfo struct { SigSize int JSthresh float64 ReadLength int + Containment bool } // method to dump the info to file diff --git a/src/stream/stream.go b/src/stream/stream.go index 0f5e870..7eabd1e 100644 --- a/src/stream/stream.go +++ b/src/stream/stream.go @@ -12,7 +12,7 @@ import ( "github.com/biogo/hts/sam" "github.com/will-rowe/groot/src/alignment" "github.com/will-rowe/groot/src/graph" - "github.com/will-rowe/groot/src/lshForest" + "github.com/will-rowe/groot/src/lshIndex" "github.com/will-rowe/groot/src/misc" "github.com/will-rowe/groot/src/seqio" "github.com/will-rowe/groot/src/version" @@ -164,6 +164,7 @@ type FastqChecker struct { WindowSize int MinReadLength int MinQual int + Containment bool } func NewFastqChecker() *FastqChecker { @@ -206,8 +207,10 @@ func (proc *FastqChecker) Run() { meanRL := float64(lengthTotal) / float64(rawCount) log.Printf("\tmean read length: %.0f\n", meanRL) // check the length is within +/-10 bases of the graph window - if meanRL < float64(proc.WindowSize-10) || meanRL > float64(proc.WindowSize+10) { - misc.ErrorCheck(fmt.Errorf("mean read length is outside the graph window size (+/- 10 bases)\n")) + if proc.Containment == false { + if meanRL < float64(proc.WindowSize-10) || meanRL > float64(proc.WindowSize+10) { + misc.ErrorCheck(fmt.Errorf("read length is too variable (> +/- 10 bases of graph window size), try re-indexing using the --containment option\n")) + } } } @@ -218,9 +221,10 @@ type DbQuerier struct { process Input chan seqio.FASTQread Output chan seqio.FASTQread - Db *lshForest.LSHforest + Db *lshIndex.LshEnsemble CommandInfo *misc.IndexInfo GraphStore graph.GraphStore + Threshold float64 } func NewDbQuerier() *DbQuerier { @@ -247,9 +251,11 @@ func (proc *DbQuerier) Run() { } // get signature for read readMH := read.RunMinHash(proc.CommandInfo.Ksize, proc.CommandInfo.SigSize) - // query the LSH forest - for _, result := range proc.Db.Query(readMH.Signature()) { - seed := proc.Db.KeyLookup[result] + // query the LSH index + done := make(chan struct{}) + defer close(done) + for result := range proc.Db.Query(readMH.Signature(), len(read.Seq), proc.Threshold, done) { + seed := proc.Db.KeyLookup[result.(string)] seed.RC = read.RC seeds = append(seeds, seed) } diff --git a/src/version/version.go b/src/version/version.go index c2b3650..12900c6 100644 --- a/src/version/version.go +++ b/src/version/version.go @@ -1,3 +1,3 @@ package version -const VERSION = "0.7.1" +const VERSION = "0.8.0" diff --git a/testing/full-argannot-perfect-reads-small-variable-rl.fq.gz b/testing/full-argannot-perfect-reads-small-variable-rl.fq.gz new file mode 100644 index 0000000..524a8a0 Binary files /dev/null and b/testing/full-argannot-perfect-reads-small-variable-rl.fq.gz differ diff --git a/testing/run_travis_tests.sh b/testing/run_travis_tests.sh index 0e6a76a..c26094d 100644 --- a/testing/run_travis_tests.sh +++ b/testing/run_travis_tests.sh @@ -1,3 +1,5 @@ +## Due to Travis failing on runtime, I've had to change the db over from arg-annot to groot-core-db, consequently the test reads (derived from arg-annot) won't align as well + #!/bin/bash # install the software @@ -13,4 +15,4 @@ go build ./groot align -p 1 -i test-index -f testing/full-argannot-perfect-reads-small.fq.gz > out.bam # report -./groot report -i out.bam +./groot report -i out.bam \ No newline at end of file