1
- module SolomonoffInduction where
2
1
{-# LANGUAGE ConstraintKinds #-}
3
- {-# LANGUAGE GeneralizedNewtypeDeriving #-}
2
+ module SolomonoffInduction ( solomonoffInduction ) where
4
3
import Prelude hiding (Real )
5
4
import Control.Applicative
6
5
import System.Random (randomRIO )
7
6
7
+ -- Helper functions
8
8
-- | Monadic conditional: run the condition once, then run exactly one of the
-- two branches depending on the 'Bool' it produced.
ifM :: Monad m => m Bool -> m a -> m a -> m a
ifM cond whenTrue whenFalse = do
  holds <- cond
  if holds then whenTrue else whenFalse
10
10
11
- data Bit = Zero | One deriving (Eq , Ord , Show , Enum )
12
-
13
- toBool :: Bit -> Bool
14
- toBool One = True
15
- toBool Zero = False
16
-
17
- fromBool :: Bool -> Bit
18
- fromBool True = One
19
- fromBool False = Zero
20
-
21
- newtype Machine = M Integer
22
-
23
- len :: Machine -> Integer
24
- len (M i) = ceiling (logBase 2 (fromIntegral i) :: Double )
25
-
26
- allMachines :: Stream Machine
27
- allMachines = M <$> makeStream (+ 1 ) 0
28
-
29
- machineEncodingReal :: POM m => Real m -> Machine
30
- machineEncodingReal = undefined
31
-
11
-- | An infinite list: every value is a head element followed by another
-- stream, so there is no empty case to handle.
-- Streams of nested intervals are used to represent real numbers.
data Stream a = a :! Stream a

-- | Maps a function over every element of the stream.
instance Functor Stream where
  fmap g (h :! t) = mapped
    where
      mapped = g h :! fmap g t

-- | Zip-like applicative: 'pure' repeats one value forever, and '<*>' applies
-- the n-th function to the n-th argument.
instance Applicative Stream where
  pure v = let forever = v :! forever in forever
  (g :! gs) <*> (a :! rest) = g a :! (gs <*> rest)
@@ -47,129 +31,182 @@ streamTake :: Int -> Stream a -> [a]
47
31
streamTake 0 _ = []
48
32
streamTake n (x:! xs) = x : streamTake (n- 1 ) xs
49
33
50
- streamFind :: (a -> Bool ) -> Stream a -> a
51
- streamFind f (x:! xs) = if f x then x else streamFind f xs
52
-
53
34
-- | Combines two streams element by element with the given function.
streamZipWith :: (a -> b -> c) -> Stream a -> Stream b -> Stream c
streamZipWith combine (l :! ls) (r :! rs) = combined :! remainder
  where
    combined = combine l r
    remainder = streamZipWith combine ls rs
55
36
56
37
-- | Pairs up the elements of two streams position by position.
streamZip :: Stream a -> Stream b -> Stream (a, b)
streamZip = streamZipWith (\l r -> (l, r))
58
39
59
- first :: Stream (Maybe a ) -> a
60
- first (Nothing :! xs) = first xs
61
- first (Just x:! _) = x
40
-- | A single bit.
-- 'Bool' would work too, but a dedicated type lets the type checker keep bits
-- and truth values apart.
data Bit = Zero | One deriving (Eq, Ord, Show, Enum)

-- | Converts a 'Bool' to a 'Bit': 'True' becomes 'One', 'False' becomes 'Zero'.
fromBool :: Bool -> Bit
fromBool flag
  | flag = One
  | otherwise = Zero
62
46
63
- type Bounds = (Rational , Rational )
47
-- | An encoding of a probabilistic oracle machine.
-- Assumes any machine can be encoded as an integer.
newtype Machine = M Integer

-- | Description length of a machine, used to define a simplicity prior:
-- the smallest @k@ such that @2^k >= i@ for the encoding @M i@, i.e. the
-- ceiling of the base-2 logarithm of @i@.
--
-- The length is computed with exact integer arithmetic, so it does not suffer
-- from floating-point rounding and works for encodings too large for a
-- 'Double'. Encodings @<= 1@ (including @M 0@, whose logarithm is undefined)
-- have length 0.
--
-- We should be careful to pick an encoding of machines such that the sum of
-- @2^(- len m)@ over all machines is 1. Right now we don't worry too much
-- about that.
len :: Machine -> Integer
len (M i)
  | i <= 1 = 0
  | otherwise = go 0 1
  where
    -- Invariant: power == 2^bits. Stop at the first power of two >= i.
    go bits power
      | power >= i = bits
      | otherwise = go (bits + 1) (2 * power)
70
57
58
-- | The stream of all machines, built by wrapping successive integer
-- encodings in 'M', starting from the seed 0 (presumably 'makeStream'
-- iterates the step function from the seed -- TODO confirm).
-- TODO: this violates the assumption that
-- @sum [2^(- len m) | m <- allMachines] == 1@.
allMachines :: Stream Machine
allMachines = fmap M encodings
  where
    encodings = makeStream (\n -> n + 1) 0
62
+
63
+ -- Probabilistic oracle machines.
64
+ -- Remember, these are Turing machines that can flip coins and call oracles.
65
+ -- We will consider oracles that answer questions of the form
66
+ -- "Is the probability that M(bits) == 1 > p?", where M is a machine, bits is
67
+ -- a finite bitstring used as input, and p is a rational probability.
68
+
69
+ -- It may be somewhat difficult (read: uncomputable) to implement a reflective
70
+ -- oracle, but you can implement other "wrong" oracles if you want to test the
71
+ -- code, as seen below.
71
72
-- | Monads in which an oracle can be consulted.
class OracleMachine m where
  -- | @oracle machine input p@ queries the oracle about @machine@ run on the
  -- finite bitstring @input@ with the rational threshold @p@; the answer is a
  -- single 'Bit'.
  oracle
    :: Machine
    -> [Bit]
    -> Rational
    -> m Bit
73
74
75
-- | A deliberately "wrong" oracle for testing: it answers 'One' to every
-- query, whatever the machine, input, or threshold.
newtype OptimisticOracle a = OO a

instance OracleMachine OptimisticOracle where
  oracle _machine _input _threshold = OO One
78
+
79
-- | A deliberately "wrong" oracle for testing: it answers 'Zero' to every
-- query, whatever the machine, input, or threshold.
newtype PessimisticOracle a = PO a

instance OracleMachine PessimisticOracle where
  oracle _machine _input _threshold = PO Zero
82
+
74
83
-- | Monads that offer a coin toss.
class ProbabilisticMachine m where
  -- | Tosses a coin and reports the outcome as a 'Bit'.
  tossCoin
    :: m Bit
76
85
86
-- | 'IO' can implement the probabilistic part of POMs: each toss draws a
-- 'Bool' with 'randomRIO' and converts it to a 'Bit'.
instance ProbabilisticMachine IO where
  tossCoin = do
    outcome <- randomRIO (False, True)
    return (fromBool outcome)
79
89
90
-- | A probabilistic oracle machine is a monad that lets you both toss coins
-- and call oracles.
type POM m =
  ( Functor m
  , Applicative m
  , Monad m
  , OracleMachine m
  , ProbabilisticMachine m
  )
81
92
82
- genCoinSequence :: POM m => Stream (m Bit )
83
- genCoinSequence = tossCoin :! genCoinSequence
93
-- | An interval of rationals, stored as @(upper bound, lower bound)@.
-- Reals will be represented by a series of nested intervals.
type Interval = (Rational, Rational)

-- | Orders two intervals when they are separated: 'GT' if the first lies
-- strictly above the second, 'LT' if it lies strictly below, and 'Nothing'
-- when they overlap or touch and so cannot be ordered yet.
compareInterval :: Interval -> Interval -> Maybe Ordering
compareInterval (hi1, lo1) (hi2, lo2) =
  if lo1 > hi2
    then Just GT
    else
      if hi1 < lo2
        then Just LT
        else Nothing
100
+
101
-- | Actually, a real is represented by a process (read: monad) that generates
-- successive nested intervals, one monadic action per approximation.
-- Because this is Haskell, nothing enforces that the intervals are nested.
-- Everything will blow up if you generate a "real" with expanding intervals,
-- so don't do that.
type Real m = Stream (m Interval)

-- | The exact real for a rational: every approximation is the degenerate
-- interval @(r, r)@.
makeReal :: Applicative m => Rational -> Real m
makeReal r = approximations
  where
    approximations = pure (r, r) :! approximations
110
+
111
-- | The real number zero.
zeroR :: Applicative m => Real m
zeroR = makeReal zero
  where
    zero = 0
84
113
85
- type Real m = Stream (m Bounds )
114
-- | The real number one.
oneR :: Applicative m => Real m
oneR = makeReal one
  where
    one = 1
86
116
87
- liftR2 :: POM m => (Rational -> Rational -> Rational ) -> Real m -> Real m -> Real m
88
- liftR2 op (x:! xs) (y:! ys) = newBounds :! liftR2 op xs ys where
89
- newBounds = do
117
-- | Lifts a binary operation on rationals to reals, approximation by
-- approximation.
--
-- Each resulting interval is the span of @op@ applied to all four pairings of
-- the operands' endpoints. This encloses the true result whenever the
-- extremes of @op@ over the two intervals occur at endpoints, which holds for
-- @+@, @-@, @*@, and for @/@ when the denominator interval excludes zero.
-- Pairing only upper-with-upper and lower-with-lower is not enough: for @/@
-- the extremes are upper/lower and lower/upper.
--
-- Note that a denominator endpoint equal to zero still makes @/@ fail.
liftR2 :: Monad m => (Rational -> Rational -> Rational) -> Real m -> Real m -> Real m
liftR2 op (x :! xs) (y :! ys) = newInterval :! liftR2 op xs ys
  where
    newInterval = do
      (a, b) <- x
      (c, d) <- y
      -- All four endpoint pairings; the list is never empty, so
      -- 'maximum' and 'minimum' are safe here.
      let corners = [op a c, op a d, op b c, op b d]
      return (maximum corners, minimum corners)
94
124
95
- liftR1 :: POM m => (Rational -> Rational ) -> Real m -> Real m
96
- liftR1 op (x:! xs) = newBounds :! liftR1 op xs where
97
- newBounds = do
125
-- | Lifts a unary operation on rationals to reals by applying it to both
-- endpoints of every interval and storing the larger result first.
liftR1 :: Monad m => (Rational -> Rational) -> Real m -> Real m
liftR1 op (x :! xs) = mapped :! liftR1 op xs
  where
    mapped =
      x >>= \(hi, lo) ->
        let p = op hi
            q = op lo
        in return (max p q, min p q)
101
131
102
- zeroR :: POM m => Real m
103
- zeroR = pure ( 0 , 0 ) :! zeroR
132
-- | Multiplies a list of reals together; the empty product is 'oneR'.
realProduct :: POM m => [Real m] -> Real m
realProduct = foldr (\factor acc -> liftR2 (*) factor acc) oneR
104
134
105
- oneR :: POM m => Real m
106
- oneR = pure (1 , 1 ) :! oneR
135
-- | The complement @1 - r@ of a real.
oneMinus :: POM m => Real m -> Real m
oneMinus = liftR1 (\p -> 1 - p)
107
137
108
- ------------- Begin.
138
-- | Drops leading intervals whose lower bound is 0.
-- This makes division work: without it, division would fail on reals that
-- have zero as a lower bound at some point, even if they eventually move away
-- from that lower bound.
-- However, this function loops forever if the real is zero.
dropZeroIntervals :: POM m => Real m -> m (Real m)
dropZeroIntervals whole@(current :! later) =
  current >>= \interval ->
    case snd interval of
      0 -> dropZeroIntervals later
      _ -> pure whole
109
147
110
- refineR :: POM m => Bounds -> m (Real m )
148
-- | Divides one real by another by lifting @/@ with 'liftR2'.
realDiv :: POM m => Real m -> Real m -> Real m
realDiv = liftR2 (\numerator denominator -> numerator / denominator)
150
+
151
-- | Compares two reals by walking their approximations in lockstep until a
-- pair of intervals is clearly ordered. Keeps going for as long as the
-- intervals cannot be separated, so it does not terminate if they never are.
compareR :: Monad m => Real m -> Real m -> m Ordering
compareR (x :! xs) (y :! ys) =
  x >>= \ix ->
    y >>= \iy ->
      case compareInterval ix iy of
        Just verdict | verdict /= EQ -> return verdict
        _ -> compareR xs ys
159
+
160
-- | Narrows an interval with coin tosses, halving it each time: on a 'One'
-- the upper half is kept, otherwise the lower half. The result is the stream
-- of the successively narrowed intervals.
--
-- NOTE(review): the recursive call is bound before anything is returned, so
-- in a monad whose bind is strict (such as 'IO') this looks like it never
-- produces a result -- confirm which monads this is meant to run in.
refineR :: (Monad m, ProbabilisticMachine m) => Interval -> m (Real m)
refineR (hi, lo) =
  tossCoin >>= \coin ->
    let midpoint = (hi + lo) / 2
        narrowed
          | coin == One = (hi, midpoint)
          | otherwise = (midpoint, lo)
    in refineR narrowed >>= \remaining ->
         return (return narrowed :! remaining)
167
+
168
+ -- Probabilistic oracle machine functions for manipulating reals:
117
169
118
- genRandomReal :: POM m => m (Real m )
170
-- | Generates a real using a sequence of coin flips, starting from the
-- interval with upper bound 1 and lower bound 0.
-- Each coin toss halves the interval: on a 'One' we keep the top half, on a
-- 'Zero' we keep the bottom half.
genRandomReal :: (Monad m, ProbabilisticMachine m) => m (Real m)
genRandomReal = refineR unitInterval
  where
    unitInterval = (1, 0)
120
175
121
- flipR :: POM m => Real m -> m Bit
176
-- | Lets a probabilistic oracle machine branch with a real-valued
-- probability: @flipR r@ returns 'One' with probability @r@ and 'Zero' with
-- probability @1 - r@ (for @r@ between 0 and 1, assuming fair coin tosses).
-- That is, @flipR (makeReal 0.8)@ returns 'One' 80% of the time.
--
-- A random real is drawn from the interval between 0 and 1; it falls below
-- @r@ with probability @r@, so that is the case that must yield 'One'.
-- Yielding 'Zero' there would give 'One' probability @1 - r@, which is
-- backwards for callers that pass in the probability of a 'One'.
flipR :: (Monad m, ProbabilisticMachine m) => Real m -> m Bit
flipR r = do
  rand <- genRandomReal
  comp <- compareR rand r
  case comp of
    LT -> return One
    GT -> return Zero
    EQ -> error " A real generated from coin tosses can never equal another real."
129
187
130
- compareR :: POM m => Real m -> Real m -> m Ordering
131
- compareR (x:! xs) (y:! ys) = do
132
- bx <- x
133
- by <- y
134
- case compareBounds bx by of
135
- Just LT -> pure LT
136
- Just GT -> pure GT
137
- _ -> compareR xs ys
138
-
139
- restrictBounds :: POM m => Machine -> [Bit ] -> m (Rational , Rational ) -> m (Rational , Rational )
140
- restrictBounds m bs pbs = do
141
- (hi, lo) <- pbs
142
- let mid = (hi + lo) / 2
143
- ans <- oracle m bs mid
144
- pure $ if ans == One then (hi, mid) else (mid, lo)
145
-
188
-- | The probability that a machine, run on a given input, outputs a given bit.
-- The probability of a 'One' is built with 'makeStream' from the seed interval
-- (1, 0) and a refinement step: the step asks the oracle about the midpoint of
-- the current interval and keeps the upper half on a 'One' answer, the lower
-- half otherwise. The probability of a 'Zero' is one minus that.
getProb :: POM m => Machine -> [Bit] -> Bit -> Real m
getProb m bits bit
  | bit == One = probOfOne
  | otherwise = oneMinus probOfOne
  where
    probOfOne = makeStream narrow (pure (1, 0))
    narrow current =
      current >>= \(hi, lo) ->
        let mid = (hi + lo) / 2
        in oracle m bits mid >>= \answer ->
             return (if answer == One then (hi, mid) else (mid, lo))
199
+
200
-- | The probability that a machine would have output a given bit sequence:
-- the product, over every position, of the probability that the machine
-- outputs that bit given the bits before it.
getStringProb :: POM m => Machine -> [Bit] -> Real m
getStringProb m bs = realProduct (map stepProb observations)
  where
    -- Each bit paired with the prefix that precedes it.
    observations = zipWith (\n b -> (take n bs, b)) [0 ..] bs
    stepProb (prefix, b) = getProb m prefix b
155
204
156
- realProduct :: POM m => [Real m ] -> Real m
157
- realProduct = foldr (liftR2 (*) ) oneR
158
-
159
- realSum :: POM m => [Real m ] -> Real m
160
- realSum = foldr (liftR2 (+) ) zeroR
161
-
162
- oneMinus :: POM m => Real m -> Real m
163
- oneMinus = liftR1 (1 - )
164
-
165
- dropZeroBounds :: POM m => Real m -> m (Real m )
166
- dropZeroBounds r@ (x:! xs) = do
167
- (_, lo) <- x
168
- if lo == 0 then dropZeroBounds xs else pure r
169
-
170
- realDiv :: POM m => Real m -> Real m -> Real m
171
- realDiv = liftR2 (/)
172
-
205
+ -- Given a measure of how likely each machine is to accept x (in some abstract
206
+ -- fashion) and x, this function generates the generic probability (over all
207
+ -- machines) of "accepting x."
208
+ -- Translation: this can be used to figure out the probability of a given
209
+ -- string being generated *in general*.
173
210
pOverMachines :: POM m => (Machine -> x -> Real m ) -> x -> Real m
174
211
pOverMachines f x = nthApprox <$> makeStream (+ 1 ) 0 where
175
212
nthApprox n = do
@@ -180,10 +217,17 @@ pOverMachines f x = nthApprox <$> makeStream (+1) 0 where
180
217
let lower = sum [m * lo | (m, (_, lo)) <- zip measures bounds]
181
218
pure (1 - sum measures + upper, lower)
182
219
220
+ -- Finally, the definition of Solomonoff induction.
221
+ -- Basically, it selects a machine according to both its simplicity-weighted
222
+ -- probability and its probability of generating the bits seen so far, and then
223
+ -- acts as that machine acts.
224
+ -- Thus, this machine defines a probability distribution over bits that
225
+ -- predicts the behavior of each machine in proportion to its posterior
226
+ -- probability.
183
227
solomonoffInduction :: POM m => [Bit ] -> m Bit
184
228
solomonoffInduction bs = pickM >>= \ m -> flipR (getProb m bs One ) where
185
229
pickM = do
186
- normalizationFactor <- dropZeroBounds $ pOverMachines getStringProb bs
230
+ normalizationFactor <- dropZeroIntervals $ pOverMachines getStringProb bs
187
231
rand <- genRandomReal
188
232
let likelihood m = getStringProb m bs `realDiv` normalizationFactor
189
233
let machineProb m = liftR1 (2 ^ negate (len m) * ) (likelihood m)
@@ -194,11 +238,12 @@ solomonoffInduction bs = pickM >>= \m -> flipR (getProb m bs One) where
194
238
let findMachine ((m, isSelected):! xs) = ifM isSelected (pure m) (findMachine xs)
195
239
findMachine $ streamZip allMachines decisions
196
240
241
+
197
242
-- | Alias: an action is represented as a single 'Bit'.
type Action = Bit
198
243
-- | An observation: a sensed value together with a reward.
--
-- This is a @data@ declaration because a @newtype@ may have only one field.
-- With two fields, 'Enum' and 'Num' cannot be derived, so only the
-- stock-derivable 'Eq', 'Ord', 'Read' and 'Show' instances are provided.
data Observation = O
  { sense :: Int
  , reward :: Int
  } deriving (Eq, Ord, Read, Show)
202
247
-- | Encodes an observation as a bitstring.
-- Placeholder: not implemented yet, so evaluating the result raises an error.
obsToBits :: Observation -> [Bit]
obsToBits _ = error "obsToBits: not implemented"
204
249
@@ -215,4 +260,4 @@ getEnvProb m h o = undefined
215
260
216
261
-- | The probability a machine assigns to a whole history: the product, over
-- every position, of 'getEnvProb' for that entry given the history before it.
getHistProb :: POM m => Machine -> History -> Real m
getHistProb m h = realProduct (map stepProb observations)
  where
    -- Each entry paired with the part of the history that precedes it.
    observations = zipWith (\n o -> (take n h, o)) [0 ..] h
    stepProb (h', o) = getEnvProb m h' o
0 commit comments