Skip to content

Commit

Permalink
add import for samples and dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Dec 29, 2024
1 parent 44e0d72 commit 4d6a232
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 15 deletions.
65 changes: 65 additions & 0 deletions app/storage/dictionary.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package storage

import (
"bufio"
"context"
"fmt"
"io"
"iter"
"sync"

Expand Down Expand Up @@ -161,6 +163,69 @@ func (d *Dictionary) Iterator(ctx context.Context, t DictionaryType) (iter.Seq[s
}, nil
}

// Import reads phrases from r, one per line, and stores them in the dictionary under type t.
// Empty lines are skipped. If withCleanup is true, all existing entries of type t are removed
// first, inside the same transaction as the inserts, so a failed import leaves the old entries
// untouched. On success it returns the result of d.Stats.
// Returns an error if t is invalid, r is nil, reading from r fails or any database step fails.
func (d *Dictionary) Import(ctx context.Context, t DictionaryType, r io.Reader, withCleanup bool) (*DictionaryStats, error) {
	if err := t.Validate(); err != nil {
		return nil, err
	}
	if r == nil {
		return nil, fmt.Errorf("reader cannot be nil")
	}

	if err := d.importEntries(ctx, t, r, withCleanup); err != nil {
		return nil, err
	}

	// the write lock is already released at this point, before getting stats
	return d.Stats(ctx)
}

// importEntries performs the locked, transactional part of Import: optional cleanup of type t
// followed by one insert per non-empty line of r.
func (d *Dictionary) importEntries(ctx context.Context, t DictionaryType, r io.Reader, withCleanup bool) error {
	d.lock.Lock()
	// registered first so it runs last: the statement is closed and the transaction is
	// rolled back (on failure) while the lock is still held, and the lock is released
	// even if something below panics.
	defer d.lock.Unlock()

	// start transaction
	tx, err := d.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	defer tx.Rollback() // no effect once the transaction has been committed

	// remove all entries with the same type if requested
	if withCleanup {
		if _, err = tx.ExecContext(ctx, `DELETE FROM dictionary WHERE type = ?`, t); err != nil {
			return fmt.Errorf("failed to remove old entries: %w", err)
		}
	}

	// add entries, using INSERT OR REPLACE to handle duplicates
	insertStmt, err := tx.PrepareContext(ctx, `INSERT OR REPLACE INTO dictionary (type, data) VALUES (?, ?)`)
	if err != nil {
		return fmt.Errorf("failed to prepare insert statement: %w", err)
	}
	defer insertStmt.Close()

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		data := scanner.Text()
		if data == "" { // skip empty lines
			continue
		}
		if _, err = insertStmt.ExecContext(ctx, t, data); err != nil {
			return fmt.Errorf("failed to add entry: %w", err)
		}
	}

	// scanner.Scan returns false on both EOF and read errors; tell them apart here
	if err = scanner.Err(); err != nil {
		return fmt.Errorf("error reading input: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}

// String returns t converted to a plain string, implementing the Stringer interface.
func (t DictionaryType) String() string {
	return string(t)
}

Expand Down
119 changes: 119 additions & 0 deletions app/storage/dictionary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package storage
import (
"context"
"fmt"
"strings"
"sync"
"testing"

Expand Down Expand Up @@ -417,3 +418,121 @@ func TestDictionary_Concurrent(t *testing.T) {
actualTotal := stats.TotalStopPhrases + stats.TotalIgnoredWords
assert.Equal(t, expectedTotal, actualTotal, "expected %d total phrases, got %d", expectedTotal, actualTotal)
}

// TestDictionary_Import exercises Dictionary.Import against the database returned by
// setupTestDB. All subtests share a single Dictionary instance; every subtest that checks
// stored phrases first runs an Import with withCleanup=true for the type it asserts on.
func TestDictionary_Import(t *testing.T) {
db, teardown := setupTestDB(t)
defer teardown()
d, err := NewDictionary(db)
require.NoError(t, err)
ctx := context.Background()

// three phrases, no trailing newline: both Read and the returned stats must report 3
t.Run("basic import with cleanup", func(t *testing.T) {
input := strings.NewReader("phrase1\nphrase2\nphrase3")
stats, err := d.Import(ctx, DictionaryTypeStopPhrase, input, true)
require.NoError(t, err)
require.NotNil(t, stats)

phrases, err := d.Read(ctx, DictionaryTypeStopPhrase)
require.NoError(t, err)
assert.Equal(t, 3, len(phrases))
assert.Equal(t, 3, stats.TotalStopPhrases)
})

// withCleanup=false must keep the two existing entries and add the two new ones
t.Run("import without cleanup should append", func(t *testing.T) {
// first import, with cleanup, so the starting state is exactly two ignored words
input1 := strings.NewReader("existing1\nexisting2")
_, err := d.Import(ctx, DictionaryTypeIgnoredWord, input1, true)
require.NoError(t, err)

// second import without cleanup should append
input2 := strings.NewReader("new1\nnew2")
stats, err := d.Import(ctx, DictionaryTypeIgnoredWord, input2, false)
require.NoError(t, err)
require.NotNil(t, stats)

phrases, err := d.Read(ctx, DictionaryTypeIgnoredWord)
require.NoError(t, err)
assert.Equal(t, 4, len(phrases))
assert.Equal(t, 4, stats.TotalIgnoredWords)
})

// withCleanup=true must drop the three old entries, leaving only the two new ones
t.Run("import with cleanup should replace", func(t *testing.T) {
// first import
input1 := strings.NewReader("old1\nold2\nold3")
_, err := d.Import(ctx, DictionaryTypeStopPhrase, input1, true)
require.NoError(t, err)

// second import with cleanup should replace
input2 := strings.NewReader("new1\nnew2")
stats, err := d.Import(ctx, DictionaryTypeStopPhrase, input2, true)
require.NoError(t, err)
require.NotNil(t, stats)

phrases, err := d.Read(ctx, DictionaryTypeStopPhrase)
require.NoError(t, err)
assert.Equal(t, 2, len(phrases))
assert.Equal(t, 2, stats.TotalStopPhrases)
assert.ElementsMatch(t, []string{"new1", "new2"}, phrases)
})

// cleanup for one type must not remove entries of the other type
t.Run("different types preserve independence", func(t *testing.T) {
// import stop phrases
inputStop := strings.NewReader("stop1\nstop2")
_, err := d.Import(ctx, DictionaryTypeStopPhrase, inputStop, true)
require.NoError(t, err)

// import ignored words; its cleanup must leave the stop phrases above in place
inputIgnored := strings.NewReader("ignored1\nignored2\nignored3")
stats, err := d.Import(ctx, DictionaryTypeIgnoredWord, inputIgnored, true)
require.NoError(t, err)
require.NotNil(t, stats)

stopPhrases, err := d.Read(ctx, DictionaryTypeStopPhrase)
require.NoError(t, err)
assert.Equal(t, 2, len(stopPhrases))

ignoredWords, err := d.Read(ctx, DictionaryTypeIgnoredWord)
require.NoError(t, err)
assert.Equal(t, 3, len(ignoredWords))
})

// a type other than the declared DictionaryType constants must be rejected
t.Run("invalid type", func(t *testing.T) {
input := strings.NewReader("phrase")
_, err := d.Import(ctx, "invalid", input, true)
assert.Error(t, err)
})

// an empty reader with cleanup is not an error and leaves no stop phrases behind
t.Run("empty input", func(t *testing.T) {
input := strings.NewReader("")
stats, err := d.Import(ctx, DictionaryTypeStopPhrase, input, true)
require.NoError(t, err)
require.NotNil(t, stats)

phrases, err := d.Read(ctx, DictionaryTypeStopPhrase)
require.NoError(t, err)
assert.Empty(t, phrases)
})

// blank lines between and after phrases must be skipped, not stored as empty entries
t.Run("input with empty lines", func(t *testing.T) {
input := strings.NewReader("phrase1\n\n\nphrase2\n\n")
stats, err := d.Import(ctx, DictionaryTypeStopPhrase, input, true)
require.NoError(t, err)
require.NotNil(t, stats)

phrases, err := d.Read(ctx, DictionaryTypeStopPhrase)
require.NoError(t, err)
assert.Equal(t, 2, len(phrases))
assert.ElementsMatch(t, []string{"phrase1", "phrase2"}, phrases)
})

// a nil io.Reader must produce an error
t.Run("nil reader", func(t *testing.T) {
_, err := d.Import(ctx, DictionaryTypeStopPhrase, nil, true)
assert.Error(t, err)
})

// errorReader is declared elsewhere in the package; presumably its Read returns the
// configured err, which Import is expected to surface — confirm against its definition
t.Run("reader error", func(t *testing.T) {
errReader := &errorReader{err: fmt.Errorf("read error")}
_, err := d.Import(ctx, DictionaryTypeStopPhrase, errReader, true)
assert.Error(t, err)
})
}
84 changes: 71 additions & 13 deletions app/storage/samples.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package storage

import (
"bufio"
"context"
"fmt"
"io"
"iter"
"sync"

Expand Down Expand Up @@ -58,7 +60,7 @@ func NewSamples(db *sqlx.DB) (*Samples, error) {
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
type TEXT CHECK (type IN ('ham', 'spam')),
origin TEXT CHECK (origin IN ('preset', 'user')),
message TEXT NOT NULL
message NOT NULL UNIQUE
);
CREATE INDEX IF NOT EXISTS idx_samples_timestamp ON samples(timestamp);
CREATE INDEX IF NOT EXISTS idx_samples_type ON samples(type);
Expand Down Expand Up @@ -101,18 +103,8 @@ func (s *Samples) Add(ctx context.Context, t SampleType, o SampleOrigin, message
}
defer tx.Rollback()

// check if sample already exists
var count int
query := `SELECT COUNT(*) FROM samples WHERE type = ? AND origin = ? AND message = ?`
if err = tx.QueryRowContext(ctx, query, t, o, message).Scan(&count); err != nil {
return fmt.Errorf("failed to check if sample exists: %w", err)
}
if count > 0 {
return nil // sample already exists, silently return
}

// add new sample
query = `INSERT INTO samples (type, origin, message) VALUES (?, ?, ?)`
// add new sample, replace if exists
query := `INSERT OR REPLACE INTO samples (type, origin, message) VALUES (?, ?, ?)`
if _, err = tx.ExecContext(ctx, query, t, o, message); err != nil {
return fmt.Errorf("failed to add sample: %w", err)
}
Expand Down Expand Up @@ -217,6 +209,72 @@ func (s *Samples) Iterator(ctx context.Context, t SampleType, o SampleOrigin) (i
}, nil
}

// Import reads samples from r, one per line, and stores them with type t and origin o.
// Empty lines are skipped. If withCleanup is true, all existing samples with the same type
// and origin are removed first, inside the same transaction as the inserts, so a failed
// import leaves the old samples untouched. On success it returns the result of s.Stats.
// Returns an error if t or o is invalid, o is SampleOriginAny, r is nil, reading from r
// fails or any database step fails.
func (s *Samples) Import(ctx context.Context, t SampleType, o SampleOrigin, r io.Reader, withCleanup bool) (*SamplesStats, error) {
	if err := t.Validate(); err != nil {
		return nil, err
	}
	if err := o.Validate(); err != nil {
		return nil, err
	}
	if o == SampleOriginAny {
		return nil, fmt.Errorf("can't import samples with origin 'any'")
	}
	if r == nil {
		return nil, fmt.Errorf("reader cannot be nil")
	}

	if err := s.importSamples(ctx, t, o, r, withCleanup); err != nil {
		return nil, err
	}

	// the write lock is already released at this point, before getting stats
	return s.Stats(ctx)
}

// importSamples performs the locked, transactional part of Import: optional cleanup of the
// (type, origin) pair followed by one insert per non-empty line of r.
func (s *Samples) importSamples(ctx context.Context, t SampleType, o SampleOrigin, r io.Reader, withCleanup bool) error {
	s.lock.Lock()
	// registered first so it runs last: the statement is closed and the transaction is
	// rolled back (on failure) while the lock is still held, and the lock is released
	// even if something below panics.
	defer s.lock.Unlock()

	// start transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to start transaction: %w", err)
	}
	defer tx.Rollback() // no effect once the transaction has been committed

	// remove all samples with the same type and origin if requested
	if withCleanup {
		if _, err = tx.ExecContext(ctx, `DELETE FROM samples WHERE type = ? AND origin = ?`, t, o); err != nil {
			return fmt.Errorf("failed to remove old samples: %w", err)
		}
	}

	// add samples; the statement is prepared once and reused for every line
	insertStmt, err := tx.PrepareContext(ctx, `INSERT OR REPLACE INTO samples (type, origin, message) VALUES (?, ?, ?)`)
	if err != nil {
		return fmt.Errorf("failed to prepare insert statement: %w", err)
	}
	defer insertStmt.Close()

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		message := scanner.Text()
		if message == "" { // skip empty lines
			continue
		}
		if _, err = insertStmt.ExecContext(ctx, t, o, message); err != nil {
			return fmt.Errorf("failed to add sample: %w", err)
		}
	}

	// scanner.Scan returns false on both EOF and read errors; tell them apart here
	if err = scanner.Err(); err != nil {
		return fmt.Errorf("error reading input: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}

// String returns t converted to a plain string, implementing the Stringer interface.
func (t SampleType) String() string {
	return string(t)
}

Expand Down
Loading

0 comments on commit 4d6a232

Please sign in to comment.