import Data.List
import System.Random

{- Data Types for the game -}

-- | The two players of the game.
data Player = X | O
  deriving (Eq, Show)

-- | A board is a list of rows; a cell is 'Nothing' when free, or holds the
-- mark of the player who took it.
type Board = [[Maybe Player]]

-- | A node of the search tree.
data Node
  = Node
      { wins     :: Int     -- how many wins
      , plays    :: Int     -- how many plays
      , player   :: Player  -- which player moves next from 'state'
      , state    :: Board   -- the state of the board
      , children :: [Node]  -- the children, one slot per board cell
      }
  | Empty      -- if this is an unexplored node
  | Forbidden  -- if this is a forbidden node (the cell is already taken)
  deriving (Eq, Show)

-- | Data type to zip the tree up: everything needed to rebuild a parent node
-- once we come back from one of its children.
data Choice = Place
  { winsP     :: Int               -- the info for the current node
  , playsP    :: Int
  , playerP   :: Player
  , stateP    :: Board
  , siblingsP :: ([Node], [Node])  -- the paths not taken (before, after)
  }
  deriving Show

-- | The path from the current node back to the root, nearest parent first.
type Thread = [Choice]

-- | A focused node together with the path that leads to it.
type Zipper = (Thread, Node)

{- Tic-Tac-Toe functions -}

-- | Change a cell index (0..8, row-major) into a (row, column) coordinate.
place2coord :: Int -> (Int, Int)
place2coord pos = pos `divMod` 3

-- | Did anyone win?  True when some row, column or diagonal holds three
-- identical marks.  Lines that are not exactly three cells long, and boards
-- that are not 3x3, simply never count as a win instead of crashing.
win :: Board -> Bool
win b = any full (b ++ transpose b ++ diags b)
  where
    -- a line is complete only if it starts with a mark and all three agree
    full [Just a, c2, c3] = Just a == c2 && c2 == c3
    full _                = False

    diags [[a, _, c], [_, m, _], [g, _, i]] = [[a, m, i], [c, m, g]]
    diags _                                 = []

-- | Is it a draw?  No free cell is left and nobody won.
draw :: Board -> Bool
draw b = null (possibleMoves b) && not (win b)

-- | Have we reached a goal state (a win or a draw)?
goal :: Board -> Bool
goal b = win b || draw b

-- | List of possible moves to make: the row-major indices of the free cells.
possibleMoves :: Board -> [Int]
possibleMoves b = [i | (Nothing, i) <- zip (concat b) [0 ..]]

-- | Return the next player.
nextPlayer :: Player -> Player
nextPlayer X = O
nextPlayer O = X

-- | Perform a move: put the mark of @p@ on cell @pos@ of board @b@.
-- The cell is overwritten unconditionally; an index whose row lies outside
-- the board raises the usual '!!' error.
move :: Int -> Player -> Board -> Board
move pos p b = take r b ++ (newRow : drop (r + 1) b)
  where
    (r, c) = place2coord pos
    oldRow = b !! r
    newRow = take c oldRow ++ Just p : drop (c + 1) oldRow

{- Utility functions -}

-- | Calculates the upper bound of the confidence interval (UCB1):
-- the mean reward plus @sqrt (2 * ln n / ni)@.
--
-- Arguments are the wins of the node, the plays of the node (@ni@) and the
-- total number of plays (@n@).  With @ni == 0@ the result is NaN or
-- Infinity, so callers should only pass nodes that were played at least once.
confidence :: Int -> Int -> Int -> Double
confidence w ni n = mu + interval
  where
    mu       = w' / ni'
    interval = sqrt (2 * log n' / ni')
    w'       = fromIntegral w
    ni'      = fromIntegral ni
    n'       = fromIntegral n

-- | Returns the index of the node with maximum confidence.
-- 'Empty' and 'Forbidden' slots are ranked last; ties go to the lowest index.
-- Calling it on an empty list is an error.
maxConfidence :: [Node] -> Int
maxConfidence ns =
  case sort (zip (map conf ns) [0 ..]) of
    ((_, idx) : _) -> idx
    []             -> error "maxConfidence: no children to choose from"
  where
    -- confidences are negated so that an ascending sort puts the best first;
    -- 1000 is a sentinel that sorts after every real (negated) confidence
    conf Forbidden = 1000
    conf Empty     = 1000
    conf n         = negate $ confidence (wins n) (plays n) total
    -- total number of plays over the explored children
    total = sum [plays c | c@Node{} <- ns]

-- | Is there any unexplored child?
anyEmpty :: Node -> Bool
anyEmpty n = any (== Empty) (children n)

-- | Is this a dead end (every child slot is forbidden)?
allForbidden :: Node -> Bool
allForbidden n = all (== Forbidden) (children n)

-- | Zip down the tree: record node @n@ and the siblings of its child at
-- index @pos@, so that the node can be rebuilt on the way back up.
nextStep :: Int -> Node -> Choice
nextStep pos n = Place (wins n) (plays n) (player n) (state n) (before, after)
  where
    before = take pos (children n)
    after  = drop (pos + 1) (children n)

-- | Pick at random one of the list elements equal to @x@ and return its
-- index together with the advanced generator.  Errors when no element of
-- the list equals @x@.
pickRandom :: Eq a => StdGen -> [a] -> a -> (Int, StdGen)
pickRandom g xs x
  | null choices = error "No choices to be made!"
  | otherwise    = (choices !! idx, g')
  where
    choices    = [i | (y, i) <- zip xs [0 ..], y == x]
    (idx, g')  = randomR (0, length choices - 1) g
-- | Pick a child at random among the unexplored ('Empty') ones.
randomChild :: StdGen -> [Node] -> (Int, StdGen)
randomChild g ns = pickRandom g ns Empty

-- | Choose a move at random among the free cells of the board.
randomMove :: StdGen -> Board -> (Int, StdGen)
randomMove g b = pickRandom g (concat b) Nothing

{- MCTS functions -}

-- | Select the next node to expand: walk down from the focused node, always
-- following the child with maximum confidence, until reaching a node that
-- still has an unexplored child or whose children are all forbidden.
select :: Zipper -> Zipper
-- if we reach an empty or forbidden, something went really wrong
select (_, Empty)     = error "Select reached empty node"
select (_, Forbidden) = error "Select reached forbidden node"
select z@(t, n)
  | anyEmpty n     = z                    -- expand this node
  | allForbidden n = z                    -- nothing more to be done
  | otherwise      = select (step : t, n')  -- move forward
  where
    idx  = maxConfidence (children n)
    step = nextStep idx n
    n'   = children n !! idx

-- | Expand the selected node into one of its empty children, chosen at
-- random.  The returned zipper is focused on the freshly created child.
-- The focused node must have at least one 'Empty' child.
expansion :: StdGen -> Zipper -> (Zipper, StdGen)
expansion g (t, n) = ((nextStep pos n : t, child), g')
  where
    -- the child of n to be expanded
    (pos, g') = randomChild g (children n)
    -- next state of the board
    board = move pos (player n) (state n)
    child = Node 0 0 (nextPlayer $ player n) board nextChildren
    -- next children with forbidden positions marked
    nextChildren = map avail (concat board)
    avail Nothing = Empty
    avail _       = Forbidden

-- | Simulate the remainder of the game at random, without expanding the
-- tree.  @p@ is the player about to move on board @b@.
--
-- Returns the score (1 when X won, -1 when O won, 0 for a draw), the final
-- board and the advanced generator.
simulation :: StdGen -> Player -> Board -> (Int, Board, StdGen)
simulation g p b
  | win b     = (score, b, g)
  | draw b    = (0, b, g)
  | otherwise = simulation g' (nextPlayer p) (move pos p b)
  where
    -- the board is already won, so the winner is whoever moved last,
    -- i.e. the opponent of @p@
    score     = if p == X then -1 else 1
    (pos, g') = randomMove g b

-- | Propagates the result of a playout from the focused node back to the
-- root, rebuilding the tree on the way, and returns the new root.
--
-- Every node on the path (focused node, ancestors, root) gets exactly one
-- extra play per call.  A win is credited to a node when the result favours
-- the player who moved INTO that node (the opponent of 'player'), because a
-- parent compares these statistics to decide which of its own moves is best.
-- 'Empty' and 'Forbidden' nodes are left untouched.
backpropagation :: Int -> Zipper -> Node
backpropagation score (thread, node) = climb thread (credit node)
  where
    -- zip up: plug the already-credited child back between its siblings
    climb []       n = n
    climb (t : ts) n = climb ts (credit parent)
      where
        (before, after) = siblingsP t
        parent = Node (winsP t) (playsP t) (playerP t) (stateP t)
                      (before ++ n : after)

    credit n@Node{} =
      n { wins  = wins n + winner score (nextPlayer $ player n)
        , plays = plays n + 1
        }
    credit other = other
-- | Reward of a playout for a given player: 1 when the score favours that
-- player (1 means X won, -1 means O won) and 0 otherwise.
-- Let's count a draw as a win.
winner :: Int -> Player -> Int
winner 0    _ = 1
winner 1    X = 1
winner (-1) O = 1
winner _    _ = 0

-- | One round of the MCTS algorithm: selection, expansion, simulation and
-- backpropagation.  Returns the updated tree and the advanced generator.
--
-- A selected leaf whose board is already decided is never expanded: its
-- result is backpropagated directly.  This keeps the statistics moving, so
-- the next round can select a different path, and it avoids playing moves
-- on a board that is already won.
mcts :: StdGen -> Node -> (Node, StdGen)
mcts g n
  | goal (state leaf)   = (backpropagation terminalScore z, g)
  | not (anyEmpty leaf) = (n, g)  -- undecided board without free slots: nothing to expand
  | otherwise           = (backpropagation sc z', g'')
  where
    z@(_, leaf) = select ([], n)
    -- on a decided board the simulation returns at once and consumes no
    -- randomness
    (terminalScore, _, _) = simulation g (player leaf) (state leaf)
    (z', g')      = expansion g z
    -- the simulation continues from the generator left by the expansion, so
    -- that both steps use independent random draws
    (sc, _, g'')  = simulation g' (player $ snd z') (state $ snd z')

-- | Iterates the algorithm @it@ times, threading the generator through.
iter :: Int -> StdGen -> Node -> Node
iter it g n
  | it <= 0   = n
  -- force each intermediate tree so that no chain of thunks piles up
  | otherwise = n' `seq` iter (it - 1) g' n'
  where
    (n', g') = mcts g n

-- | Play the game with AI: starting from node @n@, keep following the child
-- with maximum confidence until the game is over or no explored child is
-- left, and return the board reached.
play :: Node -> Board
play Empty     = [[]]
play Forbidden = [[]]
play n
  | goal (state n)              = state n
  | all unexplored (children n) = state n
  | otherwise                   = play (children n !! idx)
  where
    unexplored c = c == Empty || c == Forbidden
    idx          = maxConfidence (children n)

{- initial states -}

-- | The empty 3x3 board.
s0 :: Board
s0 = replicate 3 (replicate 3 Nothing)

-- | Children of a node on which every cell is still unexplored.
emptyChildren :: [Node]
emptyChildren = replicate 9 Empty

-- | The root of the search tree: empty board, X to move.
n0 :: Node
n0 = Node 0 0 X s0 emptyChildren

-- | The main function: run 10000 rounds of MCTS from the initial state and
-- print the board reached by following the best moves.
main :: IO ()
main = do
  g <- newStdGen
  let n = iter 10000 g n0
  print $ play n