Monte Carlo tree search keeps giving the same result

109 views Asked by At

I wrote a Monte Carlo tree search algorithm (based on https://en.wikipedia.org/wiki/Monte_Carlo_tree_search), and connected it with the "python-chess" library.

Basically, the algorithm gets stuck somewhere, because it keeps printing as output "1/2-1/2" (draw).

There's probably an issue with the expansion function, but I don't really know where.

Here's the code:

class MonteCarloTreeSearch():
    def __init__(self):
        self.board = chess.Board()
        self.results = []
        self.total_matches = 0

def expansion(self, leaf):
    done = False
    if not (self.board.is_game_over()):
        if len(leaf.children != 0):
            possible = np.array([i for i in self.board.legal_moves])
            actual = []
            for i in range(len(possible)):
                test = possible[i]
                cnt = 0
                for j in range(len(leaf.children)):
                    if (not test == leaf.children[j].value) and cnt == 0:
                        actual.append(test)
                        cnt += 1
            move = random.choice(actual)
        else:
            move = random.choice([i for i in self.board.legal_moves])                
        
        self.board.push(move)            
        if self.board.is_game_over():
            done = True
        
        child = Node(move, player="white" if leaf.player=="black" else "black", parent=[leaf], score=0)
        
        leaf.children = np.append(leaf.children, child)
        
        if not done:
            return self.expansion(child)
    return leaf
            
def playout(self, starting_node):        
    childr = self.expansion(starting_node)  
    
    result = self.board.result()
    if result == "1-0":
        result = 1
    elif result == "0-1":
        result = -1
    elif result == "1/2-1/2":
        result = .5
    elif result == "*":
        raise Exception("ERROR: Game was not over, but playout stopped.")
    else:
        raise Exception("ERROR: Playout process error.")
                    
    return childr, result, starting_node

def expansion_playout_backpropagation(self, start):
    node, result, starting_node = self.playout(start)
            
    i = 0
    while i == 0: 
        node.matches += 1
        if node.player == "white":
            if result == 1:
                node.score += 1
                
            if node.matches > 0:
                node.winp = node.score / node.matches
            else:
                node.winp = 0
        elif node.player == "black":
            if result == -1:
                node.score += 1
                
            if node.matches > 0:
                node.winp = node.score / node.matches
            else:
                node.winp = 0
        else:
            raise Exception("ERROR: Invalid player selection.")
            
        if node.is_equal(starting_node):
            i += 1
        else:
            node = node.parent[0]
        
    self.results.append(self.board.result())
    print(self.board.result())
        
    self.board.reset_board()
    self.total_matches += 1

def backpropagation(self, node, result, starting_node):
        i = 0
        while i == 0:
            node.matches += 1
            if node.player == "white":
                if result == 1:
                    node.score += 1
                    
                if node.matches > 0:
                    node.winp = node.score / node.matches
                else:
                    node.winp = 0
            elif node.player == "black":
                if result == -1:
                    node.score += 1
                    
                if node.matches > 0:
                    node.winp = node.score / node.matches
                else:
                    node.winp = 0
            else:
                raise Exception("ERROR: Invalid player selection.")

            if node.is_equal(starting_node):
                i += 1
            else:
                node = node.parent[0]

        self.results.append(self.board.result())
        print(self.board.result())

        self.board.reset_board()
        self.total_matches += 1
        
            
def fitness(self, node):
    p = node.winp
    simulations = node.matches
    parent_simulations = node.parent[0].matches if node.parent[0] != None else self.matches
    c = math.sqrt(2)
    if simulations > 0 and parent_simulations > 0:
        formula = p + c * math.sqrt((np.log(parent_simulations)) / simulations)
    else:
        formula = p
    
    return formula
0

There are 0 answers