Connect 4 with Monte Carlo Tree Search

By Qi Wang


Introduction

A few years ago, DeepMind built a program that beat the world champion at the board game Go. More recently, there have been advances in reinforcement learning for complicated games such as Chess, StarCraft II, and many more. Monte Carlo Tree Search (MCTS) is a technique used in some of these systems that takes a heuristic approach to selecting an optimal action given the current state of the environment. Interestingly, MCTS does not need any sort of evaluation function for states, because it learns from accumulated experience and gameplay patterns. It is suited to games played as a sequence of moves, because each increment in tree depth represents one move. Mostly, the algorithm has been used for two-player games (each level of depth alternates the player to move), but it is sometimes applied to one-player games as well.

Connect 4

I'm pretty sure you all know of the game Connect 4. But if you don't: Connect 4 is a game where two players take turns dropping pieces of their own color into a board 6 rows high and 7 columns wide. To win, a player must be the first to connect 4 pieces of their color horizontally, vertically, or diagonally. A few of the win conditions are shown below.

Horizontal:

xxxx

Vertical:

x
x
x
x

Diagonal:

x
 x
  x
   x

The code for the Connect 4 game class is at the bottom of this post.

Monte Carlo Tree Search (MCTS)

We normally learn to play games by learning which moves lead to a win or a high score. For a machine agent learning how to play games, there are a few options to choose from:

  1. Learn from the reward function
  2. Learn from knowing which states are inherently better

With Monte Carlo Tree Search, we use the second option to choose optimal moves to beat the game.

Intuition of MCTS

One type of tree traversal algorithm that you may be familiar with is Depth First Search (DFS). To beat a game using DFS, we would simulate every single possible outcome; if we can simulate every outcome, it is trivial to pick the best action in each state. However, computing every single possible outcome is extremely expensive. In fact, there are over $10^{120}$ possible games of chess, more than the number of atoms in the universe. Even with the world's fastest supercomputer performing $10^{18}$ floating-point operations per second, and reducing each $\texttt{dfs}()$ call to a single operation, the search would take at least $10^{102}$ seconds, or roughly $10^{95}$ years.
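As a quick sanity check on those numbers, assuming one operation per simulated game:

$$\frac{10^{120}\ \text{games}}{10^{18}\ \text{ops/s}} = 10^{102}\ \text{s} \approx \frac{10^{102}\ \text{s}}{3.15 \times 10^{7}\ \text{s/yr}} \approx 3 \times 10^{94}\ \text{yr}$$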

[Figure: one small part of the game tree for tic-tac-toe]

Obviously, DFS will not work out for us :( This is where Monte Carlo Tree Search comes into play. Let's modify the DFS approach to search the actions that are more likely to win. This dramatically reduces the time spent searching useless moves and focuses on differentiating between the better options. With MCTS, the machine learns a heuristic strategy by simulating many sequences of random gameplay against itself.

Building the Tree

There are four phases of implementing MCTS:

  1. Selection
  2. Expansion
  3. Rollout
  4. Backpropagation

Essentially, the algorithm repeats these four stages until a user-defined $\texttt{time\_limit}$ is reached. Since the algorithm does not allocate many resources to suboptimal actions, it can fully explore the good actions even under a tight time constraint. Below is an illustration demonstrating one cycle of MCTS.

[Figure: one cycle of MCTS: selection, expansion, rollout, and backpropagation]

Selection Phase

This phase selects the node in the tree to expand and simulate from. Starting at the root, we repeatedly descend to the child with the maximum value, as defined below. If we reach a node that has not been explored yet, we select it directly; otherwise, we add all possible actions as its children and randomly select one of them.

To balance exploration and exploitation, we use the Upper Confidence bound applied to Trees (UCT), a variant of which was also used by AlphaGo, to determine the value of a node. Below is the equation:

$$\texttt{UCT}(n_i) = \frac{Q}{N} + c_i \sqrt{\frac{\log N_p}{N}}$$

$n_i$ is the current node. $Q$ is the number of wins recorded over the $N$ simulations run through that node. $c_i$ is the exploration constant, which controls how strongly unexplored nodes are favored (the code below uses $\sqrt{2}$). $N_p$ is the number of simulations run through the parent of $n_i$.

As the equation shows, a node is more likely to be selected if it yields a high win rate (exploitation), or if it has been relatively unexplored compared to its parent (exploration).
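To make the trade-off concrete, below is a small standalone sketch (hypothetical counts, not part of the agent's code) comparing two sibling nodes under a parent with $N_p = 100$ simulations:

import math

def uct(Q, N, N_parent, c=math.sqrt(2)):
    if N == 0:
        return float('inf')  # unvisited nodes are always tried first
    return Q / N + c * math.sqrt(math.log(N_parent) / N)

# Well-explored child with a 75% win rate:
print(uct(Q=60, N=80, N_parent=100))  # ~1.09
# Barely-explored child with a lower 60% win rate:
print(uct(Q=3, N=5, N_parent=100))    # ~1.96, chosen first despite the lower win rate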

Below is the node class:

class Node:
    def __init__(self, move, parent):
        self.move = move
        self.parent = parent
        self.N = 0  # number of simulations run through this node
        self.Q = 0  # number of wins recorded in those simulations
        self.children = {}
        self.outcome = GameMeta.PLAYERS['none']

    def add_children(self, children: dict) -> None:
        for child in children:
            self.children[child.move] = child

    def value(self, explore: float = MCTSMeta.EXPLORATION):
        if self.N == 0:
            # we prioritize nodes that are not explored
            return 0 if explore == 0 else GameMeta.INF
        else:
            return self.Q / self.N + explore * math.sqrt(math.log(self.parent.N) / self.N)

The code for selecting a node is shown below:

def select_node(self) -> tuple:
    node = self.root
    state = deepcopy(self.root_state)

    while len(node.children) != 0:
        children = node.children.values()
        max_value = max(children, key=lambda n: n.value()).value()
        # select nodes with the highest UCT value
        max_nodes = [n for n in children if n.value() == max_value]

        # randomly break ties among the highest-value children
        node = random.choice(max_nodes)
        state.move(node.move)

        if node.N == 0:
            return node, state

    if self.expand(node, state):  # expand returns False if the state is terminal (game over)
        node = random.choice(list(node.children.values()))
        state.move(node.move)

    return node, state

Expansion Phase

After a node is selected, we add every possible action as a child of that node. Below is the code:

def expand(self, parent: Node, state: ConnectState) -> bool:
    if state.game_over():
        return False

    children = [Node(move, parent) for move in state.get_legal_moves()]
    parent.add_children(children)

    return True

Rollout/Simulation Phase

In the rollout phase, we simply simulate a random game starting from the given state. The function returns the winner of the simulated game, which is then used to update the statistics of each node along the path. Below is the code:

def roll_out(self, state: ConnectState) -> int:
    while not state.game_over():
        state.move(random.choice(state.get_legal_moves()))

    return state.get_outcome() # function in the game class shown at the bottom of this blog

Backpropagation Phase

This step propagates the result of the simulated game to all of the ancestors of the selected node. We go through all of its ancestors because the selected node's state was reached through those parents, so the result contributes to the overall "goodness" of the parent states as well. Below is the code:

def back_propagate(self, node: Node, turn: int, outcome: int) -> None:
    # For the current player, not the next player
    reward = 0 if outcome == turn else 1

    while node is not None:
        node.N += 1
        node.Q += reward
        node = node.parent
        if outcome == GameMeta.OUTCOMES['draw']: # we count it as a loss for every state
            reward = 0
        else:
            reward = 1 - reward # alternates between 0 and 1 because each alternate depth represents different player turns
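Note the perspective flip: $Q$ at each node counts wins for the player who made the move into that node. Below is a minimal standalone trace of the reward alternation (hypothetical values, not the agent's API):

# Suppose player one wins the rollout (outcome = 1) and it is player two's
# turn at the selected leaf (turn = 2). The leaf's move was therefore made
# by player one, so the leaf is credited first.
outcome, turn = 1, 2
reward = 0 if outcome == turn else 1  # -> 1

rewards = []
for _ in range(4):            # leaf, parent, grandparent, root
    rewards.append(reward)
    reward = 1 - reward       # perspective flips at each level up the tree
print(rewards)                # [1, 0, 1, 0]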

Combining the Four Phases

Combining all the phases in order, we select a node, perform a rollout from it, and then backpropagate the result onto its parent nodes. We repeat these steps until $\texttt{time\_limit}$ is reached. Below is the code for this step:

def search(self, time_limit: int):
    start_time = time.process_time()

    num_rollouts = 0
    while time.process_time() - start_time < time_limit:
        node, state = self.select_node()
        turn = state.to_play  # capture the player to move at the leaf; roll_out mutates state
        outcome = self.roll_out(state)
        self.back_propagate(node, turn, outcome)
        num_rollouts += 1 # for calculating statistics

    run_time = time.process_time() - start_time
    self.run_time = run_time
    self.num_rollouts = num_rollouts

Depending on what kind of computational power your device has, the appropriate $\texttt{time\_limit}$ will vary. I find that 10 to 15 seconds ($\approx 20{,}000$ rollouts) is sufficient for it to play at a relatively high level.

Choosing Best Action

With the game tree built, choosing the best action is quite simple. From the root, we choose the action leading to the child with the largest $N$. It is important to note that we do not choose the child with the highest $\frac{Q}{N}$, because a high win rate could come from a relatively unexplored node. A node with a high $N$, however, is bound to be a good action, because we do not allocate a lot of time to exploring worse-off options. Below is the code to select the best move:

def best_move(self):
    if self.root_state.game_over():
        return -1

    max_value = max(self.root.children.values(), key=lambda n: n.N).N
    max_nodes = [n for n in self.root.children.values() if n.N == max_value]
    best_child = random.choice(max_nodes)

    return best_child.move
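Putting it all together, a minimal interaction with the agent looks like the sketch below (using only the classes and methods defined in this post; the full game loop is in game.py at the bottom):

state = ConnectState()
mcts = MCTS(state)

mcts.search(8)                       # think for ~8 seconds
rollouts, run_time = mcts.statistics()
print(rollouts, "rollouts in", round(run_time, 2), "seconds")

move = mcts.best_move()
state.move(move)                     # apply the move to the game...
mcts.move(move)                      # ...and advance the tree root to match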

There are a few other functions that I did not walk through, as I find them to be self-explanatory. Below is the full code for playing Connect 4 with MCTS.

Conclusion

Hopefully, you were able to implement the MCTS algorithm correctly and achieve relatively strong play in your own game. I'll see you in the next post!

meta.py - Defines game and MCTS parameters

import math

class GameMeta:
    PLAYERS = {'none': 0, 'one': 1, 'two': 2}
    OUTCOMES = {'none': 0, 'one': 1, 'two': 2, 'draw': 3}
    INF = float('inf')
    ROWS = 6
    COLS = 7


class MCTSMeta:
    EXPLORATION = math.sqrt(2)

ConnectState.py - Game class for Connect 4

from copy import deepcopy
from meta import GameMeta


class ConnectState:
    def __init__(self):
        self.board = [[0] * GameMeta.COLS for _ in range(GameMeta.ROWS)]
        self.to_play = GameMeta.PLAYERS['one']
        self.height = [GameMeta.ROWS - 1] * GameMeta.COLS
        self.last_played = []

    def get_board(self):
        return deepcopy(self.board)

    def move(self, col):
        self.board[self.height[col]][col] = self.to_play
        self.last_played = [self.height[col], col]
        self.height[col] -= 1
        self.to_play = GameMeta.PLAYERS['two'] if self.to_play == GameMeta.PLAYERS['one'] else GameMeta.PLAYERS['one']

    def get_legal_moves(self):
        return [col for col in range(GameMeta.COLS) if self.board[0][col] == 0]

    def check_win(self):
        if len(self.last_played) > 0 and self.check_win_from(self.last_played[0], self.last_played[1]):
            return self.board[self.last_played[0]][self.last_played[1]]
        return 0

    def check_win_from(self, row, col):
        """
        Last played action is at (row, col).
        Check the four lines through (row, col) for four in a row.
        """
        player = self.board[row][col]

        consecutive = 1
        # Check vertical (same column, varying row)
        tmprow = row
        while tmprow + 1 < GameMeta.ROWS and self.board[tmprow + 1][col] == player:
            consecutive += 1
            tmprow += 1
        tmprow = row
        while tmprow - 1 >= 0 and self.board[tmprow - 1][col] == player:
            consecutive += 1
            tmprow -= 1

        if consecutive >= 4:
            return True

        # Check horizontal (same row, varying column)
        consecutive = 1
        tmpcol = col
        while tmpcol + 1 < GameMeta.COLS and self.board[row][tmpcol + 1] == player:
            consecutive += 1
            tmpcol += 1
        tmpcol = col
        while tmpcol - 1 >= 0 and self.board[row][tmpcol - 1] == player:
            consecutive += 1
            tmpcol -= 1

        if consecutive >= 4:
            return True

        # Check diagonal
        consecutive = 1
        tmprow = row
        tmpcol = col
        while tmprow + 1 < GameMeta.ROWS and tmpcol + 1 < GameMeta.COLS and self.board[tmprow + 1][tmpcol + 1] == player:
            consecutive += 1
            tmprow += 1
            tmpcol += 1
        tmprow = row
        tmpcol = col
        while tmprow - 1 >= 0 and tmpcol - 1 >= 0 and self.board[tmprow - 1][tmpcol - 1] == player:
            consecutive += 1
            tmprow -= 1
            tmpcol -= 1

        if consecutive >= 4:
            return True

        # Check anti-diagonal
        consecutive = 1
        tmprow = row
        tmpcol = col
        while tmprow + 1 < GameMeta.ROWS and tmpcol - 1 >= 0 and self.board[tmprow + 1][tmpcol - 1] == player:
            consecutive += 1
            tmprow += 1
            tmpcol -= 1
        tmprow = row
        tmpcol = col
        while tmprow - 1 >= 0 and tmpcol + 1 < GameMeta.COLS and self.board[tmprow - 1][tmpcol + 1] == player:
            consecutive += 1
            tmprow -= 1
            tmpcol += 1

        if consecutive >= 4:
            return True

        return False

    def game_over(self):
        return self.check_win() or len(self.get_legal_moves()) == 0

    def get_outcome(self):
        if len(self.get_legal_moves()) == 0 and self.check_win() == 0:
            return GameMeta.OUTCOMES['draw']

        return GameMeta.OUTCOMES['one'] if self.check_win() == GameMeta.PLAYERS['one'] else GameMeta.OUTCOMES['two']

    def print(self):
        print('=============================')

        for row in range(GameMeta.ROWS):
            for col in range(GameMeta.COLS):
                print('| {} '.format('X' if self.board[row][col] == 1 else 'O' if self.board[row][col] == 2 else ' '), end='')
            print('|')

        print('=============================')

mcts.py - Monte Carlo Tree Search Agent

import random
import time
import math
from copy import deepcopy

from ConnectState import ConnectState
from meta import GameMeta, MCTSMeta


class Node:
    def __init__(self, move, parent):
        self.move = move
        self.parent = parent
        self.N = 0
        self.Q = 0
        self.children = {}
        self.outcome = GameMeta.PLAYERS['none']

    def add_children(self, children: dict) -> None:
        for child in children:
            self.children[child.move] = child

    def value(self, explore: float = MCTSMeta.EXPLORATION):
        if self.N == 0:
            return 0 if explore == 0 else GameMeta.INF
        else:
            return self.Q / self.N + explore * math.sqrt(math.log(self.parent.N) / self.N)


class MCTS:
    def __init__(self, state=ConnectState()):
        self.root_state = deepcopy(state)
        self.root = Node(None, None)
        self.run_time = 0
        self.node_count = 0
        self.num_rollouts = 0

    def select_node(self) -> tuple:
        node = self.root
        state = deepcopy(self.root_state)

        while len(node.children) != 0:
            children = node.children.values()
            max_value = max(children, key=lambda n: n.value()).value()
            max_nodes = [n for n in children if n.value() == max_value]

            node = random.choice(max_nodes)
            state.move(node.move)

            if node.N == 0:
                return node, state

        if self.expand(node, state):
            node = random.choice(list(node.children.values()))
            state.move(node.move)

        return node, state

    def expand(self, parent: Node, state: ConnectState) -> bool:
        if state.game_over():
            return False

        children = [Node(move, parent) for move in state.get_legal_moves()]
        parent.add_children(children)

        return True

    def roll_out(self, state: ConnectState) -> int:
        while not state.game_over():
            state.move(random.choice(state.get_legal_moves()))

        return state.get_outcome()

    def back_propagate(self, node: Node, turn: int, outcome: int) -> None:

        # For the current player, not the next player
        reward = 0 if outcome == turn else 1

        while node is not None:
            node.N += 1
            node.Q += reward
            node = node.parent
            if outcome == GameMeta.OUTCOMES['draw']:
                reward = 0
            else:
                reward = 1 - reward

    def search(self, time_limit: int):
        start_time = time.process_time()

        num_rollouts = 0
        while time.process_time() - start_time < time_limit:
            node, state = self.select_node()
            turn = state.to_play  # capture the player to move at the leaf; roll_out mutates state
            outcome = self.roll_out(state)
            self.back_propagate(node, turn, outcome)
            num_rollouts += 1

        run_time = time.process_time() - start_time
        self.run_time = run_time
        self.num_rollouts = num_rollouts

    def best_move(self):
        if self.root_state.game_over():
            return -1

        max_value = max(self.root.children.values(), key=lambda n: n.N).N
        max_nodes = [n for n in self.root.children.values() if n.N == max_value]
        best_child = random.choice(max_nodes)

        return best_child.move

    def move(self, move):
        if move in self.root.children:
            self.root_state.move(move)
            self.root = self.root.children[move]
            return

        self.root_state.move(move)
        self.root = Node(None, None)

    def statistics(self) -> tuple:
        return self.num_rollouts, self.run_time

game.py - Main class to play with agent

from ConnectState import ConnectState
from mcts import MCTS


def play():
    state = ConnectState()
    mcts = MCTS(state)

    while not state.game_over():
        print("Current state:")
        state.print()

        user_move = int(input("Enter a move: "))
        while user_move not in state.get_legal_moves():
            print("Illegal move")
            user_move = int(input("Enter a move: "))

        state.move(user_move)
        mcts.move(user_move)

        state.print()

        if state.game_over():
            print("Player one won!" if state.check_win() else "It's a draw!")
            break

        print("Thinking...")

        mcts.search(8)
        num_rollouts, run_time = mcts.statistics()
        print("Statistics: ", num_rollouts, "rollouts in", run_time, "seconds")
        move = mcts.best_move()

        print("MCTS chose move: ", move)

        state.move(move)
        mcts.move(move)

        if state.game_over():
            print("Player two won!" if state.check_win() else "It's a draw!")
            break


if __name__ == "__main__":
    play()