I am coding an AI project that involves an MCTS search. After each iteration, the search is supposed to backpropagate the result up the tree and increase the visit_count of every ancestor by 1, which is done in this function:
def backpropagate(self, value):
    """Propagate ``value`` from this node up through every ancestor.

    Each node on the path (this node first, then its parent, and so on
    until a node whose ``parent`` is ``None``) has ``value`` added to its
    ``value_sum`` and its ``visit_count`` incremented by one. Before moving
    one level up, ``value`` is passed through that node's
    ``game.get_opponent_value``.
    """
    current = self
    while current is not None:
        current.value_sum += value
        current.visit_count += 1
        # Convert the value for the next level up before stepping to it.
        value = current.game.get_opponent_value(value)
        current = current.parent
The function is called from search():
@torch.no_grad()
def search(self, state):
    """Run MCTS from ``state`` and return the root's visit distribution.

    The root is expanded once with the model's policy mixed with Dirichlet
    noise (``args['dirichlet_epsilon']`` / ``args['dirichlet_alpha']``).
    Then ``args['num_searches']`` iterations are run; each one selects down
    to a node that is not fully expanded, evaluates it, expands it if it is
    not terminal, and backpropagates the value.

    Args:
        state: Game state to search from, in whatever form
            ``self.game.get_encoded_state`` / ``get_valid_moves`` accept.

    Returns:
        A NumPy array of length ``self.game.action_size`` holding the root
        children's visit counts normalised to sum to 1. If no child was
        visited (e.g. ``num_searches`` is 0), the masked, noisy root prior
        is returned instead of an array of NaNs.
    """

    def normalise_over_valid(policy, valid_moves):
        # Mask out illegal moves and renormalise in place. If the masked
        # policy has no mass left, fall back to a uniform distribution over
        # the legal moves rather than dividing by zero.
        policy *= valid_moves
        if np.sum(policy) <= 0:
            policy[...] = valid_moves
        policy /= np.sum(policy)
        return policy

    root = Node(self.game, self.args, state, visit_count=1)

    policy, _ = self.model(
        torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0)
    )
    policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()

    # Mix exploration noise into the root prior only.
    epsilon = self.args['dirichlet_epsilon']
    noise = np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
    policy = (1 - epsilon) * policy + epsilon * noise

    root_policy = normalise_over_valid(policy, self.game.get_valid_moves(state))
    root.expand(root_policy)

    for _ in range(self.args['num_searches']):
        # Selection: descend while the current node is fully expanded.
        node = root
        while node.is_fully_expanded():
            node = node.select()

        value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
        # NOTE(review): presumably the value above is from the point of view
        # of the player who made ``action_taken``, hence the flip - confirm
        # against the game implementation.
        value = self.game.get_opponent_value(value)

        if not is_terminal:
            # Evaluation + expansion: the model's value replaces the game value.
            policy, value = self.model(
                torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0)
            )
            policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()
            policy = normalise_over_valid(policy, self.game.get_valid_moves(node.state))
            value = value.item()
            node.expand(policy)

        node.backpropagate(value)

    action_probs = np.zeros(self.game.action_size)
    for child in root.children:
        action_probs[child.action_taken] = child.visit_count

    total_visits = np.sum(action_probs)
    if total_visits > 0:
        action_probs /= total_visits
        return action_probs
    # No child was ever visited: the prior is the only information available.
    return root_policy
However, when the code runs, the visit counts for all but one child of the root node are always zero.
Can anyone help? This problem has been persisting for a long time and I can't seem to find a solution.
I tried printing visit_count inside the backpropagate function when the grandparent is None (i.e. the parent is the root node), and it printed valid values for visit_count, but they always disappear afterwards:
def backpropagate(self, value):
    """Debug-instrumented backpropagation up the parent chain.

    For this node and then each ancestor in turn: add ``value`` to
    ``value_sum``, increment ``visit_count``, and negate ``value`` before
    moving up. Whenever the node just updated has a parent that itself has
    no parent, the node's ``visit_count`` is printed.
    """
    current = self
    while current is not None:
        current.value_sum += value
        current.visit_count += 1
        value = -value
        above = current.parent
        # Debug output for nodes whose parent has no parent of its own.
        if above is not None and above.parent is None:
            print(current.visit_count)
        current = above