Lecture 05: MDP Coding Activity
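In this activity, you will implement iterative policy evaluation to estimate the value function of a given policy in a Markov Decision Process (MDP). The method repeatedly applies the Bellman equation using the environment's known transition probabilities, so it does not require sampling episodes. Follow the steps below to complete the code.
Use the code provided below to understand the setup of the MDP.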
MDP FrozenLake Activity
# %%
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
env = gym.make('FrozenLake-v1', render_mode="rgb_array", map_name="4x4", is_slippery=True) # or you can try '8x8'
env.reset()
plt.imshow(env.render())
plt.show()
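n_state = env.observation_space.n
n_action = env.action_space.n
print("# of actions", n_action)
print("# of states", n_state)
# Transition table of the underlying FrozenLake environment:
# P[state][action] is a list of (prob, next_state, reward, done) tuples.
# env.unwrapped reaches the base environment regardless of how many wrappers gym.make adds.
P = env.unwrapped.P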
# At state 14 apply action 2
for prob, next_state, reward, done in P[14][2]:
    print("If action 2 is applied in state 14, there is a %3.2g probability of transitioning to state %d with reward %i" % (
        prob, next_state, reward))
def compute_policy_v(env, policy, gamma=1.0):
    # The goal of this function is to calculate V(s), the value of each state under the given policy
    eps = 1e-10
    V = np.zeros(n_state)
    while True:
        prev_V = np.copy(V)
        # TODO: Part 2
        # Calculate V until it converges
        # Use the Bellman equation to update V(s) with V(s')
        # Hint: Use the following code
        #     for action, prob_act in enumerate(policy[state]):
        # to iterate over all actions and their probabilities output by the policy for a given state

        # This is used to check whether V has converged.
        if (np.sum((np.fabs(prev_V - V))) <= eps):
            break
    return V
# TODO: Part 1
# Create a random policy that evenly assigns probabilities to all
# possible actions.
#
# As a result, policy[s] will be a 4-entry array giving the probability
# of moving in each of the four directions, each equal to 0.25.
# Change the following line to create such a policy
policy = np.zeros((16, 4))
# The following code is used to verify your V
V = compute_policy_v(env, policy)
true_V = np.array([0.0139398, 0.01163093, 0.02095299, 0.01047649, 0.01624867, 0.,
                   0.04075154, 0., 0.0348062, 0.08816993, 0.14205316, 0.,
                   0., 0.17582037, 0.43929118, 0.])
assert ((V - true_V) ** 2).sum() < 1e-3, "Your V is not correct!"
print("Nice, your V calculation is perfect!")
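For reference, the Bellman equation mentioned in Part 2 can be written as an iterative update. The notation here is an assumption, since the activity does not define symbols: $\pi(a \mid s)$ stands for `policy[state][action]`, and $P(s' \mid s, a)$ and $R(s, a, s')$ stand for the `prob` and `reward` entries of the tuples in `P[state][action]`.

```latex
% Iterative policy evaluation: one sweep computes V_{k+1} from V_k
V_{k+1}(s) = \sum_{a} \pi(a \mid s) \sum_{s'} P(s' \mid s, a)
             \left[ R(s, a, s') + \gamma \, V_k(s') \right]
```

In the code, $V_k$ plays the role of `prev_V` and $V_{k+1}$ the role of `V`; the loop stops once the two differ by less than `eps` in total.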
Copy or download the code above into your WSL2 or local machine, then complete the TODO parts in the code.
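If the structure of the update is unclear, the following minimal sketch may help. It applies the same idea to a small hand-made MDP rather than FrozenLake. The names `toy_P` and `evaluate_policy` are hypothetical and not part of the activity; the table only mimics the `P[state][action]` layout of `(prob, next_state, reward, done)` tuples, and plain Python lists are used so that it runs without extra packages.

```python
# Hypothetical 3-state, 2-action MDP (illustration only, not FrozenLake).
# Layout mirrors the activity: P[state][action] -> list of
# (prob, next_state, reward, done) tuples.
toy_P = {
    0: {0: [(1.0, 0, 0.0, False)],
        1: [(0.5, 1, 0.0, False), (0.5, 2, 0.0, True)]},
    1: {0: [(1.0, 0, 0.0, False)],
        1: [(1.0, 2, 1.0, True)]},
    # State 2 is terminal: every action loops back with zero reward.
    2: {0: [(1.0, 2, 0.0, True)],
        1: [(1.0, 2, 0.0, True)]},
}


def evaluate_policy(P, policy, gamma=0.9, eps=1e-10):
    """Iterative policy evaluation on a tabular MDP (a sketch)."""
    n_states = len(policy)
    V = [0.0] * n_states
    while True:
        prev_V = list(V)
        for state in range(n_states):
            value = 0.0
            # Average over the actions the policy may take in this state...
            for action, prob_act in enumerate(policy[state]):
                # ...and over the possible outcomes of each action.
                for prob, next_state, reward, done in P[state][action]:
                    value += prob_act * prob * (reward + gamma * prev_V[next_state])
            V[state] = value
        # Stop when a full sweep changes V by no more than eps in total.
        if sum(abs(p - v) for p, v in zip(prev_V, V)) <= eps:
            return V


# A uniform random policy: each of the 2 actions has probability 0.5.
uniform_policy = [[0.5, 0.5] for _ in range(3)]
print(evaluate_policy(toy_P, uniform_policy))
```

The terminal state keeps a value of zero because its only transitions loop back to itself with no reward, so no special handling of `done` is needed in this sketch.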