diff --git a/examples/gibberish/gibberish.go b/examples/gibberish/gibberish.go index b4e3d54..90a2270 100644 --- a/examples/gibberish/gibberish.go +++ b/examples/gibberish/gibberish.go @@ -40,7 +40,7 @@ func main() { return } score := sequenceProbablity(model.Chain, *username) - normalizedScore := (score - model.Mean) / model.StdDev + normalizedScore := (score - model.Mean) / model.StdDev isGibberish := normalizedScore < 0 fmt.Printf("Score: %f | Gibberish: %t\n", normalizedScore, isGibberish) } diff --git a/go.mod b/go.mod index 874731f..5703706 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/mb-14/gomarkov go 1.14 -require github.com/montanaflynn/stats v0.6.3 +require ( + github.com/montanaflynn/stats v0.6.3 + github.com/stretchr/testify v1.8.4 +) diff --git a/go.sum b/go.sum index bad6a1b..b012c43 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,18 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/montanaflynn/stats v0.6.3 h1:F8446DrvIF5V5smZfZ8K9nrmmix0AFgevPdLruGOmzk= github.com/montanaflynn/stats v0.6.3/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gomarkov.go b/gomarkov.go index 5384949..0787c62 100644 --- a/gomarkov.go +++ b/gomarkov.go @@ -9,38 +9,57 @@ import ( "time" ) -//Tokens are wrapped around a sequence of words to maintain the -//start and end transition counts +// Tokens are wrapped around a sequence of words to maintain the +// start and end transition counts const ( StartToken = "$" EndToken = "^" ) -//Chain is a markov chain instance +type preprocessedArray struct { + sparseArray + + sum int + orderedKeys []int +} + +// Chain is a markov chain instance type Chain struct { Order int statePool *spool - frequencyMat map[int]sparseArray + frequencyMat map[int]preprocessedArray lock *sync.RWMutex } +// PRNG is a pseudo-random number generator compatible with math/rand interfaces. +type PRNG interface { + // Intn returns a number number in the half-open interval [0,n) + Intn(int) int +} + type chainJSON struct { Order int `json:"int"` SpoolMap map[string]int `json:"spool_map"` FreqMat map[int]sparseArray `json:"freq_mat"` } -//MarshalJSON ... +var defaultPrng = rand.New(rand.NewSource(time.Now().UnixNano())) + +// MarshalJSON ... func (chain Chain) MarshalJSON() ([]byte, error) { + frequencyMat := make(map[int]sparseArray, len(chain.frequencyMat)) + for k, v := range chain.frequencyMat { + frequencyMat[k] = v.sparseArray + } obj := chainJSON{ chain.Order, chain.statePool.stringMap, - chain.frequencyMat, + frequencyMat, } return json.Marshal(obj) } -//UnmarshalJSON ... +// UnmarshalJSON ... func (chain *Chain) UnmarshalJSON(b []byte) error { var obj chainJSON err := json.Unmarshal(b, &obj) @@ -56,24 +75,31 @@ func (chain *Chain) UnmarshalJSON(b []byte) error { stringMap: obj.SpoolMap, intMap: intMap, } - chain.frequencyMat = obj.FreqMat + chain.frequencyMat = make(map[int]preprocessedArray, len(obj.FreqMat)) + for k, v := range obj.FreqMat { + chain.frequencyMat[k] = preprocessedArray{ + sparseArray: v, + sum: v.sum(), + orderedKeys: v.orderedKeys(), + } + } chain.lock = new(sync.RWMutex) return nil } -//NewChain creates an instance of Chain +// NewChain creates an instance of Chain func NewChain(order int) *Chain { chain := Chain{Order: order} chain.statePool = &spool{ stringMap: make(map[string]int), intMap: make(map[int]string), } - chain.frequencyMat = make(map[int]sparseArray, 0) + chain.frequencyMat = make(map[int]preprocessedArray) chain.lock = new(sync.RWMutex) return &chain } -//Add adds the transition counts to the chain for a given sequence of words +// Add adds the transition counts to the chain for a given sequence of words func (chain *Chain) Add(input []string) { startTokens := array(StartToken, chain.Order) endTokens := array(EndToken, chain.Order) @@ -87,15 +113,23 @@ func (chain *Chain) Add(input []string) { currentIndex := chain.statePool.add(pair.CurrentState.key()) nextIndex := chain.statePool.add(pair.NextState) chain.lock.Lock() - if chain.frequencyMat[currentIndex] == nil { - chain.frequencyMat[currentIndex] = make(sparseArray, 0) + pa, has := chain.frequencyMat[currentIndex] + if !has { + pa = preprocessedArray{ + sparseArray: make(sparseArray), + } } - chain.frequencyMat[currentIndex][nextIndex]++ + pa.sparseArray[nextIndex]++ + pa.sum++ + if len(pa.orderedKeys) != len(pa.sparseArray) { + pa.orderedKeys = pa.sparseArray.orderedKeys() + } + chain.frequencyMat[currentIndex] = pa chain.lock.Unlock() } } -//TransitionProbability returns the transition probability between two states +// TransitionProbability returns the transition probability between two states func (chain *Chain) TransitionProbability(next string, current NGram) (float64, error) { if len(current) != chain.Order { return 0, errors.New("N-gram length does not match chain order") @@ -106,13 +140,19 @@ func (chain *Chain) TransitionProbability(next string, current NGram) (float64, return 0, nil } arr := chain.frequencyMat[currentIndex] - sum := float64(arr.sum()) - freq := float64(arr[nextIndex]) + sum := float64(arr.sum) + freq := float64(arr.sparseArray[nextIndex]) return freq / sum, nil } -//Generate generates new text based on an initial seed of words +// Generate generates new text based on an initial seed of words func (chain *Chain) Generate(current NGram) (string, error) { + return chain.GenerateDeterministic(current, defaultPrng) +} + +// GenerateDeterministic generates new text deterministically, based on an initial seed of words and using a specified PRNG. +// Use it for reproducibly pseudo-random results (i.e. pass the same PRNG and same state every time). +func (chain *Chain) GenerateDeterministic(current NGram, prng PRNG) (string, error) { if len(current) != chain.Order { return "", errors.New("N-gram length does not match chain order") } @@ -125,17 +165,36 @@ func (chain *Chain) Generate(current NGram) (string, error) { return "", fmt.Errorf("Unknown ngram %v", current) } arr := chain.frequencyMat[currentIndex] - sum := arr.sum() - randN := rand.Intn(sum) - for i, freq := range arr { + randN := prng.Intn(arr.sum) + for _, key := range arr.orderedKeys { + freq := arr.sparseArray[key] randN -= freq if randN <= 0 { - return chain.statePool.intMap[i], nil + return chain.statePool.intMap[key], nil } } return "", nil } -func init() { - rand.Seed(time.Now().UnixNano()) +// GenerateAll generates whole chain of text from scratch. +func (chain *Chain) GenerateAll() ([]string, error) { + generatedText := []string{} + current := make(NGram, 0) + for i := 0; i < chain.Order; i++ { + current = append(current, StartToken) + } + + for { + next, err := chain.Generate(current) + if err != nil { + return []string{}, err + } + if next == EndToken { + break + } + + current = append(current, next)[1:] + generatedText = append(generatedText, next) + } + return generatedText, nil } diff --git a/gomarkov_test.go b/gomarkov_test.go index 35b94a6..396ec14 100644 --- a/gomarkov_test.go +++ b/gomarkov_test.go @@ -1,8 +1,13 @@ package gomarkov import ( + "encoding/json" + "io/ioutil" + "math/rand" "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestChain_MarshalJSON(t *testing.T) { @@ -240,3 +245,53 @@ func TestChain_Generate(t *testing.T) { }) } } + +func TestChain_GenerateDeterministic(t *testing.T) { + chain := NewChain(2) + chain.Add(NGram{"i", "like", "bees"}) + chain.Add(NGram{"i", "like", "cake"}) + chain.Add(NGram{"i", "like", "pizza"}) + chain.Add(NGram{"i", "like", "tacos"}) + + pairs := map[int64]string{ + 0: "cake", + 1: "bees", + 10: "cake", + 100: "pizza", + 1000: "bees", + } + for seed, expected := range pairs { + for i := 0; i < 16; i++ { + prng := rand.New(rand.NewSource(seed)) + got, err := chain.GenerateDeterministic(NGram{"i", "like"}, prng) + if err != nil { + panic(err) // you wrote a bad test and should feel bad + } + if got != expected { + t.Errorf("Chain.GenerateDeterministic() is not deterministic; seed = %d, got = %q, want %q", seed, got, expected) + break + } + } + } +} + +func BenchmarkChain_GenerateDeterministic(b *testing.B) { + data, err := ioutil.ReadFile("test_model.json") + require.NoError(b, err) + var chain Chain + require.NoError(b, json.Unmarshal(data, &chain)) + b.ResetTimer() + const seed = 100 + for i := 0; i < b.N; i++ { + prng := rand.New(rand.NewSource(seed)) + tokens := []string{StartToken} + for count := 0; count <= 100; count++ { + next, err := chain.GenerateDeterministic(tokens, prng) + require.NoError(b, err) + if next == EndToken { + next = StartToken + } + tokens = []string{next} + } + } +} diff --git a/helpers.go b/helpers.go index 644faac..7c18840 100644 --- a/helpers.go +++ b/helpers.go @@ -1,14 +1,17 @@ package gomarkov -import "strings" +import ( + "sort" + "strings" +) -//Pair is a pair of consecutive states in a sequece +// Pair is a pair of consecutive states in a sequece type Pair struct { CurrentState NGram // n = order of the chain NextState string // n = 1 } -//NGram is a array of words +// NGram is a array of words type NGram []string type sparseArray map[int]int @@ -17,6 +20,15 @@ func (ngram NGram) key() string { return strings.Join(ngram, "_") } +func (s sparseArray) orderedKeys() []int { + keys := make([]int, 0, len(s)) + for k := range s { + keys = append(keys, k) + } + sort.Ints(keys) + return keys +} + func (s sparseArray) sum() int { sum := 0 for _, count := range s { @@ -40,7 +52,7 @@ func array(value string, count int) []string { return arr } -//MakePairs generates n-gram pairs of consecutive states in a sequence +// MakePairs generates n-gram pairs of consecutive states in a sequence func MakePairs(tokens []string, order int) []Pair { var pairs []Pair for i := 0; i < len(tokens)-order; i++ {