# Monte Carlo Tree Search in Java

Complete implementation of Monte Carlo Tree Search in Java.

import java.util.Collections;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Random;

/**
 * Monte Carlo Tree Search (MCTS) is a heuristic search algorithm used in
 * decision-making problems, especially games.
 *
 * <p>This class is a demonstration of the algorithm's skeleton: the tree is
 * abstract (nodes carry no game state), every expansion adds a fixed number of
 * children, and the outcome of a play-out is decided by a random draw in
 * {@link #simulateRandomPlay(Node)}.
 */
public class MonteCarloTreeSearch {

    /**
     * A node of the game tree.
     *
     * <p>This is an inner (non-static) class, so instances are created through
     * an enclosing {@code MonteCarloTreeSearch} instance, e.g.
     * {@code mcts.new Node(null, true)}.
     */
    public class Node {

        Node parent;
        // Initialised here so that every constructor yields a usable (childless) node.
        ArrayList<Node> childNodes = new ArrayList<>();
        boolean isPlayersTurn; // True if it is the player's turn.
        boolean playerWon; // True if the player won; false if the opponent won.
        int score;
        int visitCount;

        /** Creates a parentless, childless node with all counters at zero. */
        public Node() {
        }

        /**
         * Creates a childless node with all counters at zero.
         *
         * @param parent        Parent of this node, or {@code null} for a root.
         * @param isPlayersTurn {@code true} if it is the player's turn at this node.
         */
        public Node(Node parent, boolean isPlayersTurn) {
            this.parent = parent;
            this.isPlayersTurn = isPlayersTurn;
        }
    }

    static final int WIN_SCORE = 10;
    static final int TIME_LIMIT = 500; // Time the algorithm will be running for (in milliseconds).

    // Number of children attached to a node each time it is expanded.
    private static final int CHILDREN_PER_EXPANSION = 10;
    // Weight of the exploration term in the UCT formula (approximately sqrt(2)).
    private static final double EXPLORATION_CONSTANT = 1.41;

    // Shared by all simulations so a new generator is not allocated per play-out.
    private final Random random = new Random();

    public static void main(String[] args) {
        MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();

        mcts.monteCarloTreeSearch(mcts.new Node(null, true));
    }

    /**
     * Explores a game tree using Monte Carlo Tree Search (MCTS) for
     * {@link #TIME_LIMIT} milliseconds and returns the most promising node.
     *
     * @param rootNode Root node of the game tree.
     * @return The most promising child of the root node.
     */
    public Node monteCarloTreeSearch(Node rootNode) {
        return monteCarloTreeSearch(rootNode, TIME_LIMIT);
    }

    /**
     * Explores a game tree using Monte Carlo Tree Search (MCTS) for the given
     * amount of time, prints the score table of the root's children and
     * returns the most promising one.
     *
     * @param rootNode        Root node of the game tree.
     * @param timeLimitMillis How long to keep exploring, in milliseconds.
     * @return The most promising child of the root node.
     */
    public Node monteCarloTreeSearch(Node rootNode, long timeLimitMillis) {
        // Expand the root node so that there is always something to choose from.
        if (rootNode.childNodes.isEmpty()) {
            addChildNodes(rootNode, CHILDREN_PER_EXPANSION);
        }

        long deadline = System.currentTimeMillis() + timeLimitMillis;

        // Explore the tree until the time limit is reached.
        while (System.currentTimeMillis() < deadline) {
            // Get a promising node using UCT.
            Node promisingNode = getPromisingNode(rootNode);

            // Expand the promising node.
            if (promisingNode.childNodes.isEmpty()) {
                addChildNodes(promisingNode, CHILDREN_PER_EXPANSION);
            }

            simulateRandomPlay(promisingNode);
        }

        Node winnerNode = getWinnerNode(rootNode);
        printScores(rootNode);
        System.out.format("\nThe optimal node is: %02d\n", rootNode.childNodes.indexOf(winnerNode) + 1);

        return winnerNode;
    }

    /**
     * Appends new child nodes to a node. Each child has {@code node} as its
     * parent and the opposite turn, since the turn alternates at every level.
     *
     * @param node       Node that receives the children.
     * @param childCount Number of children to add; nothing is added if it is
     *                   zero or negative.
     */
    public void addChildNodes(Node node, int childCount) {
        for (int i = 0; i < childCount; i++) {
            node.childNodes.add(new Node(node, !node.isPlayersTurn));
        }
    }

    /**
     * Uses UCT to find a promising child node to be explored.
     *
     * UCT: Upper Confidence bounds applied to Trees.
     *
     * @param rootNode Root node of the tree.
     * @return The most promising node according to UCT; {@code rootNode}
     *         itself if it has no children.
     */
    public Node getPromisingNode(Node rootNode) {
        Node promisingNode = rootNode;

        // Descend until a node that hasn't been expanded is found.
        while (!promisingNode.childNodes.isEmpty()) {
            // NEGATIVE_INFINITY, not Double.MIN_VALUE: the latter is the smallest
            // positive double and would not be beaten by a zero or negative value.
            double bestUct = Double.NEGATIVE_INFINITY;
            int bestIndex = 0;

            // Iterate through child nodes and pick the most promising one
            // using UCT (Upper Confidence bounds applied to Trees).
            for (int i = 0; i < promisingNode.childNodes.size(); i++) {
                Node childNode = promisingNode.childNodes.get(i);

                // A child that has never been visited is treated as having
                // the highest UCT value and is selected immediately.
                if (childNode.visitCount == 0) {
                    bestIndex = i;
                    break;
                }

                // Average score (exploitation) plus a bonus for rarely visited children (exploration).
                double uct = ((double) childNode.score / childNode.visitCount)
                        + EXPLORATION_CONSTANT
                                * Math.sqrt(Math.log(promisingNode.visitCount) / (double) childNode.visitCount);

                if (uct > bestUct) {
                    bestUct = uct;
                    bestIndex = i;
                }
            }

            promisingNode = promisingNode.childNodes.get(bestIndex);
        }

        return promisingNode;
    }

    /**
     * Simulates a random play from a node's current state and back propagates
     * the result up to the root.
     *
     * @param promisingNode Node that will be simulated.
     */
    public void simulateRandomPlay(Node promisingNode) {
        // The following line randomly determines whether the simulated play is a win or loss
        // (the player wins one time out of six on average).
        // To use the MCTS algorithm correctly this should be a simulation of the node's current
        // state of the game until it finishes (if possible) and use an evaluation function to
        // determine how good or bad the play was.
        // e.g. Play tic tac toe choosing random squares until the game ends.
        promisingNode.playerWon = (random.nextInt(6) == 0);

        boolean isPlayerWinner = promisingNode.playerWon;

        // Back propagation of the random play.
        Node tempNode = promisingNode;
        while (tempNode != null) {
            tempNode.visitCount++;

            // Add the winning score to both player and opponent nodes depending on the turn:
            // a node is credited when the side to move at that node is the winner.
            if (tempNode.isPlayersTurn == isPlayerWinner) {
                tempNode.score += WIN_SCORE;
            }

            tempNode = tempNode.parent;
        }
    }

    /**
     * Returns the child of the given node with the highest score.
     *
     * @param rootNode Node whose children are compared.
     * @return The child with the highest score.
     * @throws java.util.NoSuchElementException if {@code rootNode} has no children.
     */
    public Node getWinnerNode(Node rootNode) {
        return Collections.max(rootNode.childNodes, Comparator.comparingInt((Node child) -> child.score));
    }

    /**
     * Prints a table with the score and visit count of every child of the
     * given node to standard output.
     *
     * @param rootNode Node whose children are listed.
     */
    public void printScores(Node rootNode) {
        System.out.println("N.\tScore\t\tVisits");

        for (int i = 0; i < rootNode.childNodes.size(); i++) {
            Node child = rootNode.childNodes.get(i);
            System.out.println(String.format("%02d\t%d\t\t%d", i + 1, child.score, child.visitCount));
        }
    }
}