Skip to content

Commit 809bba0

Browse files
author
nate
committed
Merged master
2 parents 16e0fc2 + 30577d0 commit 809bba0

File tree

1 file changed

+145
-100
lines changed

1 file changed

+145
-100
lines changed

SolomonoffInduction.hs

Lines changed: 145 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,21 @@
1-
module SolomonoffInduction where
21
{-# LANGUAGE ConstraintKinds #-}
3-
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
2+
module SolomonoffInduction (solomonoffInduction) where
43
import Prelude hiding (Real)
54
import Control.Applicative
65
import System.Random (randomRIO)
76

7+
-- Helper functions
88
ifM :: Monad m => m Bool -> m a -> m a -> m a
99
ifM p t e = p >>= \x -> if x then t else e
1010

11-
data Bit = Zero | One deriving (Eq, Ord, Show, Enum)
12-
13-
toBool :: Bit -> Bool
14-
toBool One = True
15-
toBool Zero = False
16-
17-
fromBool :: Bool -> Bit
18-
fromBool True = One
19-
fromBool False = Zero
20-
21-
newtype Machine = M Integer
22-
23-
len :: Machine -> Integer
24-
len (M i) = ceiling (logBase 2 (fromIntegral i) :: Double)
25-
26-
allMachines :: Stream Machine
27-
allMachines = M <$> makeStream (+1) 0
28-
29-
machineEncodingReal :: POM m => Real m -> Machine
30-
machineEncodingReal = undefined
31-
11+
-- An infinite list datatype.
12+
-- We could just use lists, but it's nice to avoid the [] cases.
13+
-- Streams of nested intervals will be used to represent real numbers.
3214
data Stream a = a :! Stream a
15+
3316
instance Functor Stream where
3417
fmap f (x :! xs) = f x :! fmap f xs
18+
3519
instance Applicative Stream where
3620
pure x = x :! pure x
3721
(f :! fs) <*> (x :! xs) = f x :! (fs <*> xs)
@@ -47,129 +31,182 @@ streamTake :: Int -> Stream a -> [a]
4731
streamTake 0 _ = []
4832
streamTake n (x:!xs) = x : streamTake (n-1) xs
4933

50-
streamFind :: (a -> Bool) -> Stream a -> a
51-
streamFind f (x:!xs) = if f x then x else streamFind f xs
52-
5334
streamZipWith :: (a -> b -> c) -> Stream a -> Stream b -> Stream c
5435
streamZipWith f (x :! xs) (y :! ys) = f x y :! streamZipWith f xs ys
5536

5637
streamZip :: Stream a -> Stream b -> Stream (a, b)
5738
streamZip = streamZipWith (,)
5839

59-
first :: Stream (Maybe a) -> a
60-
first (Nothing:!xs) = first xs
61-
first (Just x:!_) = x
40+
-- A bit datatype.
41+
-- We could just use Bool, but it's nice to leverage the type system where possible.
42+
data Bit = Zero | One deriving (Eq, Ord, Show, Enum)
43+
fromBool :: Bool -> Bit
44+
fromBool True = One
45+
fromBool False = Zero
6246

63-
type Bounds = (Rational, Rational)
47+
-- A Machine datatype.
48+
-- This will be used to encode probabilistic oracle machines.
49+
-- Assume any machine can be encoded as an integer.
50+
newtype Machine = M Integer
6451

65-
compareBounds :: Bounds -> Bounds -> Maybe Ordering
66-
compareBounds (a, b) (c, d)
67-
| b > c = Just GT
68-
| a < d = Just LT
69-
| otherwise = Nothing
52+
-- This will be used to define a simplicity prior.
53+
-- We should be careful to pick an encoding of machines such that the sum of
54+
-- 2^(- len m) over all machines m is 1. Right now we won't worry too much about that.
55+
len :: Machine -> Integer
56+
len (M i) = ceiling (logBase 2 (fromIntegral i) :: Double)
7057

58+
-- TODO: this violates the condition that sum [2^(- len m) | m <-
59+
-- allMachines] == 1.
60+
allMachines :: Stream Machine
61+
allMachines = M <$> makeStream (+1) 0
62+
63+
-- Probabilistic oracle machines.
64+
-- Remember, these are Turing machines that can flip coins and call oracles.
65+
-- We will consider oracles that answer questions of the form
66+
-- "Is the probability that M(bits) == 1 > p?", where M is a machine, bits is
67+
-- a finite bitstring used as input, and p is a rational probability.
68+
69+
-- It may be somewhat difficult (read: uncomputable) to implement a reflective
70+
-- oracle, but you can implement other "wrong" oracles if you want to test the
71+
-- code, as seen below.
7172
class OracleMachine m where
7273
oracle :: Machine -> [Bit] -> Rational -> m Bit
7374

75+
newtype OptimisticOracle a = OO a
76+
instance OracleMachine OptimisticOracle where
77+
oracle _ _ _ = OO One
78+
79+
newtype PessimisticOracle a = PO a
80+
instance OracleMachine PessimisticOracle where
81+
oracle _ _ _ = PO Zero
82+
7483
class ProbabilisticMachine m where
7584
tossCoin :: m Bit
7685

86+
-- The IO monad can implement the probabilistic part of POMs.
7787
instance ProbabilisticMachine IO where
7888
tossCoin = fromBool <$> randomRIO (False, True)
7989

90+
-- A probabilistic oracle machine is a monad that lets you toss coins and call oracles.
8091
type POM m = (Functor m, Applicative m, Monad m, OracleMachine m, ProbabilisticMachine m)
8192

82-
genCoinSequence :: POM m => Stream (m Bit)
83-
genCoinSequence = tossCoin :! genCoinSequence
93+
-- Reals will be represented by a series of nested intervals.
94+
type Interval = (Rational, Rational)
95+
compareInterval :: Interval -> Interval -> Maybe Ordering
96+
compareInterval (a, b) (c, d)
97+
| b > c = Just GT
98+
| a < d = Just LT
99+
| otherwise = Nothing
100+
101+
-- Actually, just kidding: reals are represented by a process (read: Monad)
102+
-- which generates successive nested intervals.
103+
-- Well, because this is Haskell, we don't actually require that the intervals
104+
-- be nested. Everything will blow up if you generate a "real" with expanding
105+
-- intervals. So don't do that.
106+
type Real m = Stream (m Interval)
107+
108+
makeReal :: Applicative m => Rational -> Real m
109+
makeReal r = pure (r, r) :! makeReal r
110+
111+
zeroR :: Applicative m => Real m
112+
zeroR = makeReal 0
84113

85-
type Real m = Stream (m Bounds)
114+
oneR :: Applicative m => Real m
115+
oneR = makeReal 1
86116

87-
liftR2 :: POM m => (Rational -> Rational -> Rational) -> Real m -> Real m -> Real m
88-
liftR2 op (x:!xs) (y:!ys) = newBounds :! liftR2 op xs ys where
89-
newBounds = do
117+
liftR2 :: Monad m => (Rational -> Rational -> Rational) -> Real m -> Real m -> Real m
118+
liftR2 op (x:!xs) (y:!ys) = newInterval :! liftR2 op xs ys where
119+
newInterval = do
90120
(a, b) <- x
91121
(c, d) <- y
92122
let (e, f) = (op a c, op b d)
93-
pure (max e f, min e f)
123+
return (max e f, min e f)
94124

95-
liftR1 :: POM m => (Rational -> Rational) -> Real m -> Real m
96-
liftR1 op (x:!xs) = newBounds :! liftR1 op xs where
97-
newBounds = do
125+
liftR1 :: Monad m => (Rational -> Rational) -> Real m -> Real m
126+
liftR1 op (x:!xs) = newInterval :! liftR1 op xs where
127+
newInterval = do
98128
(a, b) <- x
99129
let (c, d) = (op a, op b)
100-
pure (max c d, min c d)
130+
return (max c d, min c d)
101131

102-
zeroR :: POM m => Real m
103-
zeroR = pure (0, 0) :! zeroR
132+
realProduct :: POM m => [Real m] -> Real m
133+
realProduct = foldr (liftR2 (*)) oneR
104134

105-
oneR :: POM m => Real m
106-
oneR = pure (1, 1) :! oneR
135+
oneMinus :: POM m => Real m -> Real m
136+
oneMinus = liftR1 (1-)
107137

108-
------------- Begin.
138+
-- Drops intervals that have 0 as a lower bound.
139+
-- This makes division work. (Without this, division would fail on reals that
140+
-- have zero as a lower-bound at some point, even if they eventually move away
141+
-- from that lower bound.)
142+
-- However, this function loops forever if the real is zero.
143+
dropZeroIntervals :: POM m => Real m -> m (Real m)
144+
dropZeroIntervals r@(x:!xs) = do
145+
(_, lo) <- x
146+
if lo == 0 then dropZeroIntervals xs else pure r
109147

110-
refineR :: POM m => Bounds -> m (Real m)
148+
realDiv :: POM m => Real m -> Real m -> Real m
149+
realDiv = liftR2 (/)
150+
151+
compareR :: Monad m => Real m -> Real m -> m Ordering
152+
compareR (x:!xs) (y:!ys) = do
153+
bx <- x
154+
by <- y
155+
case compareInterval bx by of
156+
Just LT -> return LT
157+
Just GT -> return GT
158+
_ -> compareR xs ys
159+
160+
refineR :: (Monad m, ProbabilisticMachine m) => Interval -> m (Real m)
111161
refineR (hi, lo) = do
112162
bit <- tossCoin
113163
let med = (hi + lo) / 2
114164
let bounds = if bit == One then (hi, med) else (med, lo)
115165
rest <- refineR bounds
116-
pure $ pure bounds :! rest
166+
return $ return bounds :! rest
167+
168+
-- Probabilistic oracle machine functions for manipulating reals:
117169

118-
genRandomReal :: POM m => m (Real m)
170+
-- Generates a real using a sequence of coin flips.
171+
-- Each coin toss halves the interval. On a 1, we keep the top half, on a 0, we
172+
-- keep the bottom half.
173+
genRandomReal :: (Monad m, ProbabilisticMachine m) => m (Real m)
119174
genRandomReal = refineR (1, 0)
120175

121-
flipR :: POM m => Real m -> m Bit
176+
-- This allows probabilistic oracle machines to create a branch that has some
177+
-- real probability of going either way.
178+
-- That is, flipR (real 0.8) picks one bit with probability 0.8 and the other
-- bit with probability 0.2.
179+
flipR :: (Monad m, ProbabilisticMachine m) => Real m -> m Bit
122180
flipR r = do
123181
rand <- genRandomReal
124182
-- Compare r against the uniform random draw (not the other way round):
-- r > rand happens with probability r, which is exactly when the GT -> One
-- branch below must fire, so that flipR r yields One with probability r.
comp <- compareR r rand
125183
case comp of
126-
LT -> pure Zero
127-
GT -> pure One
184+
LT -> return Zero
185+
GT -> return One
128186
EQ -> error "A real generated from coin tosses can never equal another real."
129187

130-
compareR :: POM m => Real m -> Real m -> m Ordering
131-
compareR (x:!xs) (y:!ys) = do
132-
bx <- x
133-
by <- y
134-
case compareBounds bx by of
135-
Just LT -> pure LT
136-
Just GT -> pure GT
137-
_ -> compareR xs ys
138-
139-
restrictBounds :: POM m => Machine -> [Bit] -> m (Rational, Rational) -> m (Rational, Rational)
140-
restrictBounds m bs pbs = do
141-
(hi, lo) <- pbs
142-
let mid = (hi + lo) / 2
143-
ans <- oracle m bs mid
144-
pure $ if ans == One then (hi, mid) else (mid, lo)
145-
188+
-- Finds the probability that a machine, run on a given input, outputs a given bit.
189+
-- Basically does binary refinement using the oracle.
190+
-- Generates a series of nested intervals.
146191
getProb :: POM m => Machine -> [Bit] -> Bit -> Real m
147-
getProb m bs b
148-
| b == One = prob1
149-
| otherwise = oneMinus prob1
150-
where prob1 = makeStream (restrictBounds m bs) (pure (1, 0))
151-
192+
getProb m bits bit = if bit == One then prob1 else oneMinus prob1 where
193+
prob1 = makeStream restrictInterval (pure (1, 0))
194+
restrictInterval pbounds = do
195+
(hi, lo) <- pbounds
196+
let mid = (hi + lo) / 2
197+
ans <- oracle m bits mid
198+
return $ if ans == One then (hi, mid) else (mid, lo)
199+
200+
-- Finds the probability that a machine would have output a given bit sequence.
152201
getStringProb :: POM m => Machine -> [Bit] -> Real m
153202
getStringProb m bs = realProduct [getProb m bs' b' | (bs', b') <- observations bs]
154203
where observations xs = [(take n xs, xs !! n) | n <- [0 .. length xs - 1]]
155204

156-
realProduct :: POM m => [Real m] -> Real m
157-
realProduct = foldr (liftR2 (*)) oneR
158-
159-
realSum :: POM m => [Real m] -> Real m
160-
realSum = foldr (liftR2 (+)) zeroR
161-
162-
oneMinus :: POM m => Real m -> Real m
163-
oneMinus = liftR1 (1-)
164-
165-
dropZeroBounds :: POM m => Real m -> m (Real m)
166-
dropZeroBounds r@(x:!xs) = do
167-
(_, lo) <- x
168-
if lo == 0 then dropZeroBounds xs else pure r
169-
170-
realDiv :: POM m => Real m -> Real m -> Real m
171-
realDiv = liftR2 (/)
172-
205+
-- Given a measure of how likely each machine is to accept x (in some abstract
206+
-- fashion) and x, this function generates the generic probability (over all
207+
-- machines) of "accepting x".
208+
-- Translation: this can be used to figure out the probability of a given
209+
-- string being generated *in general*.
173210
pOverMachines :: POM m => (Machine -> x -> Real m) -> x -> Real m
174211
pOverMachines f x = nthApprox <$> makeStream (+1) 0 where
175212
nthApprox n = do
@@ -180,10 +217,17 @@ pOverMachines f x = nthApprox <$> makeStream (+1) 0 where
180217
let lower = sum [m * lo | (m, (_, lo)) <- zip measures bounds]
181218
pure (1 - sum measures + upper, lower)
182219

220+
-- Finally, the definition of Solomonoff induction.
221+
-- Basically, it selects a machine according to both its simplicity-weighted
222+
-- probability and its probability of generating the bits seen so far, and then
223+
-- acts as that machine acts.
224+
-- Thus, this machine defines a probability distribution over bits that
225+
-- predicts the behavior of each machine in proportion to its posterior
226+
-- probability.
183227
solomonoffInduction :: POM m => [Bit] -> m Bit
184228
solomonoffInduction bs = pickM >>= \m -> flipR (getProb m bs One) where
185229
pickM = do
186-
normalizationFactor <- dropZeroBounds $ pOverMachines getStringProb bs
230+
normalizationFactor <- dropZeroIntervals $ pOverMachines getStringProb bs
187231
rand <- genRandomReal
188232
let likelihood m = getStringProb m bs `realDiv` normalizationFactor
189233
let machineProb m = liftR1 (2 ^ negate (len m) *) (likelihood m)
@@ -194,11 +238,12 @@ solomonoffInduction bs = pickM >>= \m -> flipR (getProb m bs One) where
194238
let findMachine ((m, isSelected):!xs) = ifM isSelected (pure m) (findMachine xs)
195239
findMachine $ streamZip allMachines decisions
196240

241+
197242
type Action = Bit
198243
newtype Observation = O
199-
{ sense :: Int
200-
, reward :: Int
201-
} deriving (Eq, Ord, Enum, Num, Read, Show)
244+
{ sense :: Int
245+
, reward :: Int
246+
} deriving (Eq, Ord, Enum, Num, Read, Show)
202247
obsToBits :: Observation -> [Bit]
203248
-- Placeholder: the bit encoding of an Observation is not implemented yet.
obsToBits = undefined
204249

@@ -215,4 +260,4 @@ getEnvProb m h o = undefined
215260

216261
getHistProb :: POM m => Machine -> History -> Real m
217262
getHistProb m h = realProduct [getEnvProb m h' o | (h', o) <- observations h]
218-
where observations xs = [(take n xs, xs !! n) | n <- [0 .. length xs - 1]]
263+
where observations xs = [(take n xs, xs !! n) | n <- [0 .. length xs - 1]]

0 commit comments

Comments
 (0)