Skip to content

Commit ef75c8b

Browse files
committed
brain/braintest: simplify test brain
For #94.
1 parent a35e7cb commit ef75c8b

File tree

1 file changed

+35
-36
lines changed

1 file changed

+35
-36
lines changed

brain/braintest/braintest_test.go

+35-36
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,18 @@ import (
1111

1212
"github.com/zephyrtronium/robot/brain"
1313
"github.com/zephyrtronium/robot/brain/braintest"
14-
"github.com/zephyrtronium/robot/userhash"
1514
)
1615

1716
// membrain is an implementation of braintest.Interface using in-memory maps
1817
// to verify that the integration tests test the correct things.
1918
type membrain struct {
20-
mu sync.Mutex
21-
tups map[string]map[string][][2]string // map of tags to map of prefixes to id and suffix
22-
users map[userhash.Hash][][2]string // map of hashes to tag and id
23-
tms map[string]map[int64][]string // map of tags to map of timestamps to ids
19+
mu sync.Mutex
20+
tups map[string]*memtag
21+
}
22+
23+
type memtag struct {
24+
tups map[string][][2]string // map of prefixes to id and suffix
25+
forgort map[string]bool // set of forgorten ids
2426
}
2527

2628
var _ brain.Interface = (*membrain)(nil)
@@ -30,46 +32,25 @@ func (m *membrain) Learn(ctx context.Context, tag string, msg *brain.Message, tu
3032
defer m.mu.Unlock()
3133
if m.tups[tag] == nil {
3234
if m.tups == nil {
33-
m.tups = make(map[string]map[string][][2]string)
34-
m.users = make(map[userhash.Hash][][2]string)
35-
m.tms = make(map[string]map[int64][]string)
35+
m.tups = make(map[string]*memtag)
3636
}
37-
m.tups[tag] = make(map[string][][2]string)
38-
m.tms[tag] = make(map[int64][]string)
37+
m.tups[tag] = &memtag{tups: make(map[string][][2]string), forgort: make(map[string]bool)}
3938
}
40-
m.users[msg.Sender] = append(m.users[msg.Sender], [2]string{tag, msg.ID})
41-
tms := m.tms[tag]
42-
tms[msg.Timestamp] = append(tms[msg.Timestamp], msg.ID)
4339
r := m.tups[tag]
4440
for _, tup := range tuples {
4541
p := strings.Join(tup.Prefix, "\xff")
46-
r[p] = append(r[p], [2]string{msg.ID, tup.Suffix})
42+
r.tups[p] = append(r.tups[p], [2]string{msg.ID, tup.Suffix})
4743
}
4844
return nil
4945
}
5046

51-
func (m *membrain) forgetIDLocked(tag, id string) {
52-
for p, u := range m.tups[tag] {
53-
for len(u) > 0 {
54-
k := slices.IndexFunc(u, func(v [2]string) bool { return v[0] == id })
55-
if k < 0 {
56-
break
57-
}
58-
u[k], u[len(u)-1] = u[len(u)-1], u[k]
59-
u = u[:len(u)-1]
60-
}
61-
if len(u) != 0 {
62-
m.tups[tag][p] = u
63-
} else {
64-
delete(m.tups[tag], p)
65-
}
66-
}
67-
}
68-
6947
func (m *membrain) Forget(ctx context.Context, tag, id string) error {
7048
m.mu.Lock()
7149
defer m.mu.Unlock()
72-
m.forgetIDLocked(tag, id)
50+
if m.tups[tag] == nil {
51+
m.tups[tag] = &memtag{tups: make(map[string][][2]string), forgort: make(map[string]bool)}
52+
}
53+
m.tups[tag].forgort[id] = true
7354
return nil
7455
}
7556

@@ -78,7 +59,9 @@ func (m *membrain) Recall(ctx context.Context, tag string, page string, out []br
7859
}
7960

8061
// Think implements brain.Interface.
81-
func (m *membrain) Think(ctx context.Context, tag string, prefix []string) iter.Seq[func(id *[]byte, suf *[]byte) error] {
62+
func (m *membrain) Think(ctx context.Context, tag string, prompt []string) iter.Seq[func(id *[]byte, suf *[]byte) error] {
63+
m.mu.Lock()
64+
defer m.mu.Unlock()
8265
panic("unimplemented")
8366
}
8467

@@ -87,7 +70,15 @@ func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w *br
8770
defer m.mu.Unlock()
8871
var s string
8972
if len(prompt) == 0 {
90-
u := m.tups[tag][""]
73+
u := slices.Clone(m.tups[tag].tups[""])
74+
d := 0
75+
for k, v := range u {
76+
if m.tups[tag].forgort[v[0]] {
77+
u[d], u[k] = u[k], u[d]
78+
d++
79+
}
80+
}
81+
u = u[d:]
9182
if len(u) == 0 {
9283
return nil
9384
}
@@ -98,7 +89,15 @@ func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w *br
9889
s = brain.ReduceEntropy(prompt[len(prompt)-1])
9990
}
10091
for range 256 {
101-
u := m.tups[tag][s]
92+
u := slices.Clone(m.tups[tag].tups[s])
93+
d := 0
94+
for k, v := range u {
95+
if m.tups[tag].forgort[v[0]] {
96+
u[d], u[k] = u[k], u[d]
97+
d++
98+
}
99+
}
100+
u = u[d:]
102101
if len(u) == 0 {
103102
break
104103
}

0 commit comments

Comments
 (0)