Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

Permalink
add: knowledge askdir command for adhoc-querying directories (no create-dataset and separate ingest and retrieve required)
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed May 8, 2024
1 parent c6aaf67 commit 5c43b7f
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 14 deletions.
1 change: 1 addition & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Client interface {
ListDatasets(ctx context.Context) ([]types.Dataset, error)
Ingest(ctx context.Context, datasetID string, data []byte, opts datastore.IngestOpts) ([]string, error)
IngestPaths(ctx context.Context, datasetID string, opts *IngestPathsOpts, paths ...string) (int, error) // returns number of files ingested
AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *RetrieveOpts) ([]vectorstore.Document, error)
DeleteDocuments(ctx context.Context, datasetID string, documentIDs ...string) error
Retrieve(ctx context.Context, datasetID string, query string, opts RetrieveOpts) ([]vectorstore.Document, error)
}
75 changes: 63 additions & 12 deletions pkg/client/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,25 @@ package client

import (
"context"
"crypto/sha1"
"encoding/hex"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/vectorstore"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"log/slog"
"os"
"path/filepath"
"slices"
)

// checkIgnored reports whether the file extension of path is listed in
// ignoreExtensions. The extension is the one returned by filepath.Ext, so it
// includes the leading dot (e.g. ".txt") and is empty for extension-less names.
func checkIgnored(path string, ignoreExtensions []string) bool {
	ext := filepath.Ext(path)
	slog.Debug("checking path", "path", path, "ext", ext, "ignore", ignoreExtensions)
	return slices.Contains(ignoreExtensions, ext)
}

func ingestPaths(ctx context.Context, opts *IngestPathsOpts, ingestionFunc func(path string) error, paths ...string) (int, error) {

ingestedFilesCount := 0

if opts.Concurrency < 1 {
Expand All @@ -35,11 +33,6 @@ func ingestPaths(ctx context.Context, opts *IngestPathsOpts, ingestionFunc func(
for _, p := range paths {
path := p

if checkIgnored(path, opts.IgnoreExtensions) {
slog.Debug("Skipping ingestion of file", "path", path, "reason", "extension ignored")
continue
}

fileInfo, err := os.Stat(path)
if err != nil {
return ingestedFilesCount, fmt.Errorf("failed to get file info for %s: %w", path, err)
Expand Down Expand Up @@ -81,6 +74,10 @@ func ingestPaths(ctx context.Context, opts *IngestPathsOpts, ingestionFunc func(
return ingestedFilesCount, err
}
} else {
if checkIgnored(path, opts.IgnoreExtensions) {
slog.Debug("Skipping ingestion of file", "path", path, "reason", "extension ignored")
continue
}
// Process a file directly
g.Go(func() error {
if err := sem.Acquire(ctx, 1); err != nil {
Expand All @@ -97,3 +94,57 @@ func ingestPaths(ctx context.Context, opts *IngestPathsOpts, ingestionFunc func(
// Wait for all goroutines to finish
return ingestedFilesCount, g.Wait()
}

// hashPath returns the lowercase hexadecimal SHA-1 digest of path.
// AskDir uses it to derive a dataset ID from a directory's absolute path.
func hashPath(path string) string {
	digest := sha1.Sum([]byte(path))
	return hex.EncodeToString(digest[:])
}

// AskDir ingests the directory at path into a dataset derived from that
// directory and then retrieves documents matching query from it.
//
// The dataset ID is the SHA-1 hex digest of the directory's absolute path, so
// repeated calls for the same directory reuse the same dataset. The dataset is
// created if c.GetDataset reports that it does not exist yet.
//
// opts and ropts may be nil; in that case their zero values are used.
//
// An error is returned if path cannot be resolved, does not exist, is not a
// directory, or if any of the dataset lookup, creation, ingestion or retrieval
// calls on c fails.
func AskDir(ctx context.Context, c Client, path string, query string, opts *IngestPathsOpts, ropts *RetrieveOpts) ([]vectorstore.Document, error) {
	abspath, err := filepath.Abs(path)
	if err != nil {
		return nil, fmt.Errorf("failed to get absolute path from %q: %w", path, err)
	}

	finfo, err := os.Stat(abspath)
	if err != nil {
		if os.IsNotExist(err) {
			return nil, fmt.Errorf("path %q does not exist", abspath)
		}
		return nil, fmt.Errorf("failed to get file info for %q: %w", abspath, err)
	}
	if !finfo.IsDir() {
		return nil, fmt.Errorf("path %q is not a directory", abspath)
	}

	datasetID := hashPath(abspath)
	slog.Debug("Directory Dataset ID hashed", "path", abspath, "id", datasetID)

	// Create the dataset on first use of this directory.
	dataset, err := c.GetDataset(ctx, datasetID)
	if err != nil {
		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetID, err)
	}
	if dataset == nil {
		if _, err := c.CreateDataset(ctx, datasetID); err != nil {
			return nil, fmt.Errorf("failed to create dataset %q: %w", datasetID, err)
		}
	}

	// Fall back to zero-value options so that nil pointers never reach the
	// client calls below (ropts is dereferenced for Retrieve).
	if opts == nil {
		opts = &IngestPathsOpts{}
	}
	if ropts == nil {
		ropts = &RetrieveOpts{}
	}

	ingested, err := c.IngestPaths(ctx, datasetID, opts, path)
	if err != nil {
		return nil, fmt.Errorf("failed to ingest files: %w", err)
	}
	slog.Debug("Ingested files", "count", ingested, "path", abspath)

	return c.Retrieve(ctx, datasetID, query, *ropts)
}
4 changes: 4 additions & 0 deletions pkg/client/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ func (c *DefaultClient) Retrieve(_ context.Context, datasetID string, query stri
return docs, nil
}

// AskDirectory implements the Client interface method of the same name by
// delegating to AskDir with this client as the backend.
func (c *DefaultClient) AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *RetrieveOpts) ([]vectorstore.Document, error) {
return AskDir(ctx, c, path, query, opts, ropts)
}

func (c *DefaultClient) request(method, path string, body io.Reader) ([]byte, error) {
url := c.ServerURL + path

Expand Down
4 changes: 4 additions & 0 deletions pkg/client/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ func (c *StandaloneClient) DeleteDocuments(ctx context.Context, datasetID string
// Retrieve asks the client's Datastore for documents in datasetID that match
// query, forwarding opts.TopK as the query's TopK.
func (c *StandaloneClient) Retrieve(ctx context.Context, datasetID string, query string, opts RetrieveOpts) ([]vectorstore.Document, error) {
	q := types.Query{
		Prompt: query,
		TopK:   z.Pointer(opts.TopK),
	}
	return c.Datastore.Retrieve(ctx, datasetID, q)
}

// AskDirectory implements the Client interface method of the same name by
// delegating to AskDir with this client as the backend.
func (c *StandaloneClient) AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *RetrieveOpts) ([]vectorstore.Document, error) {
return AskDir(ctx, c, path, query, opts, ropts)
}
61 changes: 61 additions & 0 deletions pkg/cmd/askdir.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package cmd

import (
"encoding/json"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/client"
"github.com/spf13/cobra"
"strings"
)

// ClientAskDir is the "askdir" command: it combines the shared client settings
// with the ingestion and retrieval options and the directory to query.
type ClientAskDir struct {
Client
// Path is the directory that is ingested and queried.
Path string `usage:"Path to the directory to query" short:"p" default:"./knowledge"`
ClientIngestOpts
ClientRetrieveOpts
}

// Customize sets the usage line and short description of the askdir command
// and requires exactly one positional argument (the query).
func (s *ClientAskDir) Customize(cmd *cobra.Command) {
	cmd.Args = cobra.ExactArgs(1)
	cmd.Short = "Retrieve sources for a query from a dataset generated from a directory"
	cmd.Use = "askdir [--path <path>] <query>"
}

// Run executes the askdir command: it ingests the directory s.Path and prints
// the sources retrieved for the query given as the single positional argument.
//
// The sources are printed to stdout as JSON. If none are found, a message
// saying so is printed instead. Errors from obtaining the client, from
// AskDirectory, or from JSON encoding are returned.
func (s *ClientAskDir) Run(cmd *cobra.Command, args []string) error {
	c, err := s.getClient()
	if err != nil {
		return err
	}

	path := s.Path
	query := args[0]

	// strings.Split("", ",") returns [""], which would match the empty
	// extension of files such as "Makefile" and silently skip them. Only build
	// the ignore list when the option was actually set.
	var ignoreExtensions []string
	if s.IgnoreExtensions != "" {
		ignoreExtensions = strings.Split(s.IgnoreExtensions, ",")
	}

	ingestOpts := &client.IngestPathsOpts{
		IgnoreExtensions: ignoreExtensions,
		Concurrency:      s.Concurrency,
		Recursive:        s.Recursive,
	}

	retrieveOpts := &client.RetrieveOpts{
		TopK: s.TopK,
	}

	sources, err := c.AskDirectory(cmd.Context(), path, query, ingestOpts, retrieveOpts)
	if err != nil {
		return fmt.Errorf("failed to retrieve sources: %w", err)
	}

	if len(sources) == 0 {
		fmt.Printf("No sources found for the query %q from path %q\n", query, path)
		return nil
	}

	jsonSources, err := json.Marshal(sources)
	if err != nil {
		return err
	}

	fmt.Printf("Retrieved the following %d sources for the query %q from path %q: %s\n", len(sources), query, path, jsonSources)

	return nil
}
6 changes: 5 additions & 1 deletion pkg/cmd/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ import (

// ClientIngest bundles the shared client settings, the target dataset and the
// ingestion options.
type ClientIngest struct {
	Client
	// Dataset is the ID of the dataset that receives the ingested data.
	Dataset string `usage:"Target Dataset ID" short:"d" default:"default" env:"KNOW_TARGET_DATASET"`
	ClientIngestOpts
}

type ClientIngestOpts struct {
IgnoreExtensions string `usage:"Comma-separated list of file extensions to ignore" env:"KNOW_INGEST_IGNORE_EXTENSIONS"`
Concurrency int `usage:"Number of concurrent ingestion processes" short:"c" default:"10" env:"KNOW_INGEST_CONCURRENCY"`
Recursive bool `usage:"Recursively ingest directories" short:"r" default:"false" env:"KNOW_INGEST_RECURSIVE"`
Expand Down
9 changes: 9 additions & 0 deletions pkg/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@ package cmd
import (
"github.com/acorn-io/cmd"
"github.com/spf13/cobra"
"log/slog"
"os"
)

func init() {
if os.Getenv("DEBUG") != "" {
_ = slog.SetLogLoggerLevel(slog.LevelDebug)
}
}

func New() *cobra.Command {
return cmd.Command(
&Knowledge{},
Expand All @@ -16,6 +24,7 @@ func New() *cobra.Command {
new(ClientDeleteDataset),
new(ClientRetrieve),
new(ClientResetDatastore),
new(ClientAskDir),
)
}

Expand Down
6 changes: 5 additions & 1 deletion pkg/cmd/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ import (
// ClientRetrieve bundles the shared client settings, the target dataset and
// the retrieval options. TopK is provided by the embedded ClientRetrieveOpts,
// so it is not declared a second time here.
type ClientRetrieve struct {
	Client
	// Dataset is the ID of the dataset to retrieve from.
	Dataset string `usage:"Target Dataset ID" short:"d" default:"default" env:"KNOW_TARGET_DATASET"`
	ClientRetrieveOpts
}

// ClientRetrieveOpts holds the retrieval flags shared by the commands that
// embed it.
type ClientRetrieveOpts struct {
// TopK is the number of sources to retrieve (flag default: 5).
TopK int `usage:"Number of sources to retrieve" short:"k" default:"5"`
}

func (s *ClientRetrieve) Customize(cmd *cobra.Command) {
Expand Down

0 comments on commit 5c43b7f

Please sign in to comment.