-
Notifications
You must be signed in to change notification settings - Fork 19
/
SentenceTable.cs
98 lines (86 loc) · 3.43 KB
/
SentenceTable.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using BrightData;
using BrightData.Types;
using BrightWire;
using BrightWire.Models.Bayesian;
using BrightWire.TrainingData.Helper;
namespace ExampleCode.DataTableTrainers
{
/// <summary>
/// Stores tokenised sentences as a single-column data table of index lists, where each
/// distinct token string is mapped to an integer index, and uses that table to train a
/// Markov model and generate text from it.
/// </summary>
internal class SentenceTable
{
    readonly IDataTable _sentenceTable;

    // token string -> index, and the reverse lookup (index -> token string)
    readonly Dictionary<string, uint> _stringIndex = [];
    readonly List<string> _strings = [];

    // index of the empty string, which is used to represent "null" / no token
    readonly uint _empty;

    /// <summary>
    /// Builds the sentence table from the supplied tokenised sentences.
    /// </summary>
    /// <param name="context">Context used to create the table builder.</param>
    /// <param name="sentences">Sentences, each given as an array of token strings.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="context"/> or <paramref name="sentences"/> is null.
    /// </exception>
    public SentenceTable(BrightDataContext context, IEnumerable<string[]> sentences)
    {
        ArgumentNullException.ThrowIfNull(context);
        ArgumentNullException.ThrowIfNull(sentences);

        // create an empty string to represent null (registered first, so it gets the first index)
        _empty = GetStringIndex("");

        var builder = context.CreateTableBuilder();
        builder.CreateColumn(BrightDataType.IndexList, "Sentences");
        foreach (var sentence in sentences)
            builder.AddRow(IndexList.Create(sentence.Select(GetStringIndex).ToArray()));

        // a constructor cannot await, so the build result is obtained by blocking
        _sentenceTable = builder.BuildInMemory().Result;
    }

    /// <summary>
    /// Returns the index associated with a token string, registering the string with the
    /// next free index if it has not been seen before.
    /// </summary>
    /// <param name="str">Token string to look up.</param>
    /// <returns>The index assigned to <paramref name="str"/>.</returns>
    public uint GetStringIndex(string str)
    {
        if (!_stringIndex.TryGetValue(str, out var ret)) {
            // new token: its index is its position in the reverse lookup list
            _stringIndex.Add(str, ret = (uint)_strings.Count);
            _strings.Add(str);
        }
        return ret;
    }

    /// <summary>
    /// Appends the token with the given index to the string builder, inserting a space
    /// before tokens that start with a letter or digit unless the text so far ends with
    /// an apostrophe or hyphen.
    /// </summary>
    /// <param name="index">Index of the token to append.</param>
    /// <param name="sb">Builder holding the text generated so far.</param>
    /// <returns>The token index together with its string.</returns>
    (uint Index, string String) Append(uint index, StringBuilder sb)
    {
        var str = _strings[(int)index];

        // the empty token has no first character to inspect and contributes no text;
        // without this guard str[0] would throw for the empty string
        if (str.Length == 0)
            return (index, str);

        if (sb.Length > 0 && Char.IsLetterOrDigit(str[0])) {
            var lastChar = sb[^1];
            if (lastChar != '\'' && lastChar != '-')
                sb.Append(' ');
        }
        sb.Append(str);
        return (index, str);
    }

    /// <summary>
    /// Trains a Markov model on every sentence in the table.
    /// </summary>
    /// <param name="writeResults">
    /// True to write text generated from the trained model to the console.
    /// </param>
    /// <returns>The trained model.</returns>
    public MarkovModel3<uint> TrainMarkovModel(bool writeResults = true)
    {
        // create a markov trainer that uses a window of size 3
        var context = _sentenceTable.Context;
        var trainer = context.CreateMarkovTrainer3(_empty);

        var column = _sentenceTable.GetColumn<IndexList>(0);
        foreach (var sentence in column.EnumerateAllTyped().ToBlockingEnumerable())
            trainer.Add(sentence.Indices);
        var ret = trainer.Build();

        if (writeResults) {
            foreach (var sentence in GenerateText(ret))
                Console.WriteLine(sentence);
        }
        return ret;
    }

    /// <summary>
    /// Lazily generates text by repeatedly sampling the next token from the model's
    /// transition probabilities.
    /// </summary>
    /// <param name="model">Trained Markov model to sample from.</param>
    /// <param name="count">Number of strings to generate.</param>
    /// <returns>
    /// A sequence of <paramref name="count"/> generated strings. Each string ends when the
    /// model has no transitions for the current state, when an end of sentence token is
    /// produced, or when the per-string token limit is reached.
    /// </returns>
    public IEnumerable<string> GenerateText(MarkovModel3<uint> model, int count = 50)
    {
        // upper bound on tokens per generated string so that generation always terminates
        const int maxTokens = 1024;

        var context = _sentenceTable.Context;
        var table = model.AsDictionary;

        for (var i = 0; i < count; i++) {
            var sb = new StringBuilder();

            // start from the same "empty" state that the trainer was created with
            uint prevPrev = _empty, prev = _empty, curr = _empty;

            for (var j = 0; j < maxTokens; j++) {
                var transitions = table.GetTransitions(prevPrev, prev, curr);
                if (transitions is null)
                    break;

                // pick the next state at random, weighted by transition probability
                var distribution = context.CreateCategoricalDistribution(transitions.Select(d => d.Probability));
                var next = Append(transitions[distribution.Sample()].NextState, sb);
                if (SimpleTokeniser.IsEndOfSentence(next.String))
                    break;

                // slide the three token window forward
                prevPrev = prev;
                prev = curr;
                curr = next.Index;
            }
            yield return sb.ToString();
        }
    }
}
}