I am working on a temporal difference learning example (https://www.youtube.com/watch?v=XrxgdpduWOU), and I'm having some trouble with the update equation in my Python implementation: I seem to be double counting the reward and Q.
Suppose I code the grid as a 2D array, my current location is (2, 2), the goal is (2, 3), and the maximum reward is 1. Let Q(t) be the average value of my current location. Then r(t+1) is 1, and if I assume max Q(t+1) is also 1, my Q(t) ends up close to 2 (assuming a gamma of 1). Is this correct, or should I assume that Q(n) is 0, where n is the end point?
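To make the arithmetic concrete, here is a minimal sketch of the two cases for the move from (2, 2) into the goal. The helper name `td_target` is made up for this illustration and is not part of my implementation; gamma is fixed at 1, as assumed above.

```python
def td_target(reward, max_q_next, gamma=1.0):
    """One-step target: r(t+1) + gamma * max Q(t+1)."""
    return reward + gamma * max_q_next


# Stepping from (2, 2) into the goal yields a reward of 1.
# Case 1: the end point's own value is taken to be 1.
print(td_target(reward=1, max_q_next=1))  # 2.0

# Case 2: the end point's own value is taken to be 0.
print(td_target(reward=1, max_q_next=0))  # 1.0
```

In the first case the target that Q(t) moves towards is 2, which is where my suspicion of double counting comes from; in the second it is just the reward.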
Edited to include code: I modified the get_max_q function to return 0 at the end point, and the values are now all below 1, which I assume is more correct since the reward is just 1. I'm not sure this is the right approach, though (previously I had it return 1 at the end point).
# not sure if this is correct
def get_max_q(q, pos):
    # end point
    # not sure if I should set this to 0 or 1
    if pos == (MAX_ROWS - 1, MAX_COLS - 1):
        return 0
    return max([q[pos, am] for am in available_moves(pos)])

def learn(q, old_pos, action, reward):
    new_pos = get_new_pos(old_pos, action)
    max_q_next_move = get_max_q(q, new_pos)
    q[(old_pos, action)] = q[old_pos, action] + alpha * (reward + max_q_next_move - q[old_pos, action]) - 0.04

def move(q, curr_pos):
    moves = available_moves(curr_pos)
    if random.random() < epsilon:
        action = random.choice(moves)
    else:
        index = np.argmax([q[m] for m in moves])
        action = moves[index]
    new_pos = get_new_pos(curr_pos, action)
    # end point
    if new_pos == (MAX_ROWS - 1, MAX_COLS - 1):
        reward = 1
    else:
        reward = 0
    learn(q, curr_pos, action, reward)
    return get_new_pos(curr_pos, action)
=======================
OUTPUT
Average value (after I set Q(end point) to 0)
defaultdict(float,
            {((0, 0), 'DOWN'): 0.5999999999999996,
             ((0, 0), 'RIGHT'): 0.5999999999999996,
             ...
             ((2, 2), 'UP'): 0.7599999999999998})
Average value (after I set Q(end point) to 1)
defaultdict(float,
            {((0, 0), 'DOWN'): 1.5999999999999996,
             ((0, 0), 'RIGHT'): 1.5999999999999996,
             ....
             ((2, 2), 'LEFT'): 1.7599999999999998,
             ((2, 2), 'RIGHT'): 1.92,
             ((2, 2), 'UP'): 1.7599999999999998})
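In case it helps to compare the two settings side by side, here is a cut-down, self-contained sketch. It is not my actual training loop: instead of epsilon-greedy episodes it simply applies the same update to every (position, action) pair over and over. The 3x4 grid size, alpha = 0.5, the `DELTAS` table, the `terminal_q` parameter, the `sweep_until_settled` function and the bodies of `available_moves` and `get_new_pos` are all stand-ins I have assumed for the parts not shown above.

```python
from collections import defaultdict

# Assumed values: none of these constants appear in the code above.
MAX_ROWS, MAX_COLS = 3, 4
ALPHA = 0.5
GOAL = (MAX_ROWS - 1, MAX_COLS - 1)
DELTAS = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1)}


def get_new_pos(pos, action):
    """Stand-in helper: move one cell in the given direction."""
    d_row, d_col = DELTAS[action]
    return (pos[0] + d_row, pos[1] + d_col)


def available_moves(pos):
    """Stand-in helper: actions that keep the agent on the grid."""
    moves = []
    for action in DELTAS:
        row, col = get_new_pos(pos, action)
        if 0 <= row < MAX_ROWS and 0 <= col < MAX_COLS:
            moves.append(action)
    return moves


def get_max_q(q, pos, terminal_q):
    """Best action value at pos; the end point returns terminal_q."""
    if pos == GOAL:
        return terminal_q
    return max(q[(pos, a)] for a in available_moves(pos))


def sweep_until_settled(terminal_q, sweeps=500):
    """Apply the update from learn() to every (position, action) pair."""
    q = defaultdict(float)
    for _ in range(sweeps):
        for row in range(MAX_ROWS):
            for col in range(MAX_COLS):
                pos = (row, col)
                if pos == GOAL:
                    continue  # no moves are taken from the end point
                for action in available_moves(pos):
                    new_pos = get_new_pos(pos, action)
                    reward = 1 if new_pos == GOAL else 0
                    target = reward + get_max_q(q, new_pos, terminal_q)
                    # Same form as learn(): the 0.04 sits outside the alpha term.
                    q[(pos, action)] += ALPHA * (target - q[(pos, action)]) - 0.04
    return q


q_zero = sweep_until_settled(terminal_q=0)
q_one = sweep_until_settled(terminal_q=1)
for key in [((0, 0), "DOWN"), ((2, 2), "UP")]:
    print(key, round(q_zero[key], 2), round(q_one[key], 2))
```

Under these assumptions the two tables should settle exactly 1 apart for every entry, which is the same gap as in my output above (0.6 versus 1.6 at (0, 0), 0.76 versus 1.76 for UP at (2, 2)).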

