Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/gibberish/gibberish.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func main() {
return
}
score := sequenceProbablity(model.Chain, *username)
normalizedScore := (score - model.Mean) / model.StdDev
normalizedScore := (score - model.Mean) / model.StdDev
isGibberish := normalizedScore < 0
fmt.Printf("Score: %f | Gibberish: %t\n", normalizedScore, isGibberish)
}
Expand Down
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ module github.com/mb-14/gomarkov

go 1.14

require github.com/montanaflynn/stats v0.6.3
require (
github.com/montanaflynn/stats v0.6.3
github.com/stretchr/testify v1.8.4
)
16 changes: 16 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,18 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/montanaflynn/stats v0.6.3 h1:F8446DrvIF5V5smZfZ8K9nrmmix0AFgevPdLruGOmzk=
github.com/montanaflynn/stats v0.6.3/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
107 changes: 83 additions & 24 deletions gomarkov.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,57 @@ import (
"time"
)

//Tokens are wrapped around a sequence of words to maintain the
//start and end transition counts
// StartToken and EndToken are sentinel tokens wrapped around a sequence of
// words to maintain the start and end transition counts. Add builds
// chain.Order copies of each of them for the sequence it is given.
const (
// StartToken marks the beginning of a sequence.
StartToken = "$"
// EndToken marks the end of a sequence.
EndToken = "^"
)

//Chain is a markov chain instance
// preprocessedArray pairs a sparseArray of transition counts with values
// derived from it, so that they do not have to be recomputed on every lookup.
type preprocessedArray struct {
// sparseArray is the embedded table of counts.
sparseArray

// sum is the total of all counts stored in sparseArray.
sum int
// orderedKeys holds the keys of sparseArray in ascending order, which
// gives iteration over the counts a stable order.
orderedKeys []int
}

// Chain is a markov chain instance
type Chain struct {
Order int
statePool *spool
frequencyMat map[int]sparseArray
frequencyMat map[int]preprocessedArray
lock *sync.RWMutex
}

// PRNG is a pseudo-random number generator compatible with math/rand interfaces.
type PRNG interface {
// Intn returns a number in the half-open interval [0,n).
Intn(int) int
}

// chainJSON is the JSON representation of a Chain, used by MarshalJSON and
// UnmarshalJSON.
type chainJSON struct {
// Order is the order of the chain.
// NOTE(review): the JSON key is "int", not "order". Renaming it would make
// previously serialized chains unreadable, so it is left as is.
Order int `json:"int"`
// SpoolMap maps each state string to its integer index.
SpoolMap map[string]int `json:"spool_map"`
// FreqMat maps a state index to the counts of the states that follow it.
FreqMat map[int]sparseArray `json:"freq_mat"`
}

//MarshalJSON ...
// lockedPRNG guards a *rand.Rand with a mutex. A *rand.Rand created from
// rand.NewSource is not safe for concurrent use (only the package-level
// math/rand functions are), so access to the shared generator is serialized.
type lockedPRNG struct {
	mu  sync.Mutex
	rng *rand.Rand
}

// Intn returns a pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0, exactly like (*rand.Rand).Intn.
func (p *lockedPRNG) Intn(n int) int {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.rng.Intn(n)
}

// defaultPrng is the time-seeded generator that Generate passes to
// GenerateDeterministic. It is safe to use from multiple goroutines.
var defaultPrng = &lockedPRNG{rng: rand.New(rand.NewSource(time.Now().UnixNano()))}

// MarshalJSON ...
func (chain Chain) MarshalJSON() ([]byte, error) {
frequencyMat := make(map[int]sparseArray, len(chain.frequencyMat))
for k, v := range chain.frequencyMat {
frequencyMat[k] = v.sparseArray
}
obj := chainJSON{
chain.Order,
chain.statePool.stringMap,
chain.frequencyMat,
frequencyMat,
}
return json.Marshal(obj)
}

//UnmarshalJSON ...
// UnmarshalJSON ...
func (chain *Chain) UnmarshalJSON(b []byte) error {
var obj chainJSON
err := json.Unmarshal(b, &obj)
Expand All @@ -56,24 +75,31 @@ func (chain *Chain) UnmarshalJSON(b []byte) error {
stringMap: obj.SpoolMap,
intMap: intMap,
}
chain.frequencyMat = obj.FreqMat
chain.frequencyMat = make(map[int]preprocessedArray, len(obj.FreqMat))
for k, v := range obj.FreqMat {
chain.frequencyMat[k] = preprocessedArray{
sparseArray: v,
sum: v.sum(),
orderedKeys: v.orderedKeys(),
}
}
chain.lock = new(sync.RWMutex)
return nil
}

//NewChain creates an instance of Chain
// NewChain creates an instance of Chain
func NewChain(order int) *Chain {
chain := Chain{Order: order}
chain.statePool = &spool{
stringMap: make(map[string]int),
intMap: make(map[int]string),
}
chain.frequencyMat = make(map[int]sparseArray, 0)
chain.frequencyMat = make(map[int]preprocessedArray)
chain.lock = new(sync.RWMutex)
return &chain
}

//Add adds the transition counts to the chain for a given sequence of words
// Add adds the transition counts to the chain for a given sequence of words
func (chain *Chain) Add(input []string) {
startTokens := array(StartToken, chain.Order)
endTokens := array(EndToken, chain.Order)
Expand All @@ -87,15 +113,23 @@ func (chain *Chain) Add(input []string) {
currentIndex := chain.statePool.add(pair.CurrentState.key())
nextIndex := chain.statePool.add(pair.NextState)
chain.lock.Lock()
if chain.frequencyMat[currentIndex] == nil {
chain.frequencyMat[currentIndex] = make(sparseArray, 0)
pa, has := chain.frequencyMat[currentIndex]
if !has {
pa = preprocessedArray{
sparseArray: make(sparseArray),
}
}
chain.frequencyMat[currentIndex][nextIndex]++
pa.sparseArray[nextIndex]++
pa.sum++
if len(pa.orderedKeys) != len(pa.sparseArray) {
pa.orderedKeys = pa.sparseArray.orderedKeys()
}
chain.frequencyMat[currentIndex] = pa
chain.lock.Unlock()
}
}

//TransitionProbability returns the transition probability between two states
// TransitionProbability returns the transition probability between two states
func (chain *Chain) TransitionProbability(next string, current NGram) (float64, error) {
if len(current) != chain.Order {
return 0, errors.New("N-gram length does not match chain order")
Expand All @@ -106,13 +140,19 @@ func (chain *Chain) TransitionProbability(next string, current NGram) (float64,
return 0, nil
}
arr := chain.frequencyMat[currentIndex]
sum := float64(arr.sum())
freq := float64(arr[nextIndex])
sum := float64(arr.sum)
freq := float64(arr.sparseArray[nextIndex])
return freq / sum, nil
}

//Generate generates new text based on an initial seed of words
// Generate generates new text based on an initial seed of words.
// It delegates to GenerateDeterministic, passing the package-level
// defaultPrng as the source of randomness, and returns whatever that
// call returns.
func (chain *Chain) Generate(current NGram) (string, error) {
return chain.GenerateDeterministic(current, defaultPrng)
}

// GenerateDeterministic generates new text deterministically, based on an initial seed of words and using a specified PRNG.
// Use it for reproducibly pseudo-random results (i.e. pass the same PRNG and same state every time).
func (chain *Chain) GenerateDeterministic(current NGram, prng PRNG) (string, error) {
if len(current) != chain.Order {
return "", errors.New("N-gram length does not match chain order")
}
Expand All @@ -125,17 +165,36 @@ func (chain *Chain) Generate(current NGram) (string, error) {
return "", fmt.Errorf("Unknown ngram %v", current)
}
arr := chain.frequencyMat[currentIndex]
sum := arr.sum()
randN := rand.Intn(sum)
for i, freq := range arr {
randN := prng.Intn(arr.sum)
for _, key := range arr.orderedKeys {
freq := arr.sparseArray[key]
randN -= freq
if randN <= 0 {
return chain.statePool.intMap[i], nil
return chain.statePool.intMap[key], nil
}
}
return "", nil
}

func init() {
rand.Seed(time.Now().UnixNano())
// GenerateAll produces a complete sequence of tokens from scratch.
//
// It seeds a window with chain.Order start tokens, then repeatedly asks
// Generate for the next token, sliding the window forward by one each time.
// Generation stops when EndToken is produced; EndToken itself is not part of
// the result. If Generate fails, an empty slice and the error are returned.
func (chain *Chain) GenerateAll() ([]string, error) {
	window := NGram{}
	for remaining := chain.Order; remaining > 0; remaining-- {
		window = append(window, StartToken)
	}

	words := []string{}
	for {
		word, err := chain.Generate(window)
		if err != nil {
			return []string{}, err
		}
		if word == EndToken {
			return words, nil
		}

		words = append(words, word)
		// Drop the oldest token so the window keeps its length.
		window = append(window, word)[1:]
	}
}
55 changes: 55 additions & 0 deletions gomarkov_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package gomarkov

import (
"encoding/json"
"io/ioutil"
"math/rand"
"reflect"
"testing"

"github.com/stretchr/testify/require"
)

func TestChain_MarshalJSON(t *testing.T) {
Expand Down Expand Up @@ -240,3 +245,53 @@ func TestChain_Generate(t *testing.T) {
})
}
}

// TestChain_GenerateDeterministic checks that GenerateDeterministic returns
// the same token every time it is given a PRNG freshly created from the same
// seed. Each seed is tried 16 times against its expected token.
func TestChain_GenerateDeterministic(t *testing.T) {
	chain := NewChain(2)
	chain.Add(NGram{"i", "like", "bees"})
	chain.Add(NGram{"i", "like", "cake"})
	chain.Add(NGram{"i", "like", "pizza"})
	chain.Add(NGram{"i", "like", "tacos"})

	pairs := map[int64]string{
		0:    "cake",
		1:    "bees",
		10:   "cake",
		100:  "pizza",
		1000: "bees",
	}
	for seed, expected := range pairs {
		for i := 0; i < 16; i++ {
			prng := rand.New(rand.NewSource(seed))
			got, err := chain.GenerateDeterministic(NGram{"i", "like"}, prng)
			if err != nil {
				// Fail this test through the testing framework instead of
				// panicking, which would abort the whole test binary.
				t.Fatalf("Chain.GenerateDeterministic() unexpected error; seed = %d, err = %v", seed, err)
			}
			if got != expected {
				t.Errorf("Chain.GenerateDeterministic() is not deterministic; seed = %d, got = %q, want %q", seed, got, expected)
				break
			}
		}
	}
}

// BenchmarkChain_GenerateDeterministic measures GenerateDeterministic on the
// chain stored in test_model.json. Every benchmark iteration creates a PRNG
// from a fixed seed and draws 101 tokens one after another, restarting from
// StartToken whenever EndToken is produced. The single-token n-gram passed in
// means the model has to be an order-1 chain for the calls to succeed.
func BenchmarkChain_GenerateDeterministic(b *testing.B) {
	raw, err := ioutil.ReadFile("test_model.json")
	require.NoError(b, err)

	var chain Chain
	err = json.Unmarshal(raw, &chain)
	require.NoError(b, err)

	const seed = 100
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		prng := rand.New(rand.NewSource(seed))
		state := []string{StartToken}
		for step := 0; step <= 100; step++ {
			token, genErr := chain.GenerateDeterministic(state, prng)
			require.NoError(b, genErr)
			if token == EndToken {
				token = StartToken
			}
			state = []string{token}
		}
	}
}
20 changes: 16 additions & 4 deletions helpers.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package gomarkov

import "strings"
import (
"sort"
"strings"
)

//Pair is a pair of consecutive states in a sequece
// Pair is a pair of consecutive states in a sequence
type Pair struct {
CurrentState NGram // n = order of the chain
NextState string // n = 1
}

//NGram is a array of words
// NGram is an array of words
type NGram []string

// sparseArray maps a state index to a count. Only indices that have been
// written are stored.
type sparseArray map[int]int
Expand All @@ -17,6 +20,15 @@ func (ngram NGram) key() string {
return strings.Join(ngram, "_")
}

// orderedKeys returns every key of s in ascending order. The result is a
// freshly allocated slice and is never nil, even when s is empty.
func (s sparseArray) orderedKeys() []int {
	sorted := make([]int, len(s))
	pos := 0
	for key := range s {
		sorted[pos] = key
		pos++
	}
	sort.Ints(sorted)
	return sorted
}

func (s sparseArray) sum() int {
sum := 0
for _, count := range s {
Expand All @@ -40,7 +52,7 @@ func array(value string, count int) []string {
return arr
}

//MakePairs generates n-gram pairs of consecutive states in a sequence
// MakePairs generates n-gram pairs of consecutive states in a sequence
func MakePairs(tokens []string, order int) []Pair {
var pairs []Pair
for i := 0; i < len(tokens)-order; i++ {
Expand Down