Policy iteration¶
In policy iteration, you start with an arbitrary policy. Then, the policy is improved at every iteration by first creating the DTMC induced by the previous policy, and then, in every state, picking whichever choice would be best according to the values of that DTMC to obtain the updated policy.
[1]:
from stormvogel import *
from stormvogel.visualization import JSVisualization
from time import sleep
def arg_max(funcs, args):
    """Return the argument whose paired callable yields the highest value.

    The i-th callable is applied to the i-th argument. Pairs are formed with
    ``zip``, so surplus entries of the longer input are ignored. When several
    arguments share the highest value, the earliest one is returned.

    Args:
        funcs: Iterable of one-argument callables.
        args: Iterable of arguments, one per callable.

    Returns:
        The element of ``args`` whose callable produced the largest result.

    Raises:
        ValueError: If there is no (callable, argument) pair to evaluate.
    """
    # Materialize once so that generators and other non-indexable iterables
    # work, and so a one-shot iterator is not exhausted before the lookup.
    candidates = list(args)
    scores = [f(x) for f, x in zip(funcs, candidates)]
    if not scores:
        raise ValueError("arg_max() needs at least one callable/argument pair")
    return candidates[scores.index(max(scores))]
def _map_result_to_model(model, dtmc, dtmc_result, scheduler):
    """Re-key the induced DTMC's values onto the states of ``model`` by position."""
    mapped_values = {
        model.states[i]: dtmc_result.values.get(dtmc.states[i])
        for i in range(len(model.states))
    }
    return Result(model, mapped_values, scheduler)


def _expected_action_value(model, dtmc, dtmc_result, state, action):
    """Return the probability-weighted DTMC value of the successors of ``action``."""
    total = 0
    for probability, successor in state.get_outgoing_transitions(action):
        # The successor's index in the original model is used to find its
        # counterpart in the DTMC; this relies on both sharing a state order.
        dtmc_successor = dtmc.states[model.get_state_index(successor)]
        total += probability * dtmc_result.get_result_of_state(dtmc_successor)
    return total


def policy_iteration(
    model: Model,
    prop: str,
    visualize: bool = True,
    layout: Layout = stormvogel.layout.DEFAULT(),
    delay: int = 2,
    clear: bool = False,
) -> Result:
    """Performs policy iteration on the given mdp.

    Starting from a random scheduler, each round model-checks the DTMC induced
    by the current scheduler and then builds a new scheduler that picks, in
    every state, the action with the highest expected value in that DTMC. The
    loop stops once a round leaves the scheduler unchanged.

    Args:
        model (Model): MDP.
        prop (str): PRISM property string to maximize. Remember that this is a
            property on the induced DTMC, not the MDP.
        visualize (bool): Whether the intermediate and final results should be
            visualized. Defaults to True.
        layout (Layout): Layout to use to show the intermediate and final
            results.
        delay (int): Seconds to wait between each iteration. Only applies when
            ``visualize`` is True.
        clear (bool): Whether to clear the visualization of each previous
            iteration.

    Returns:
        Result: The model checking result of the DTMC induced by the last
        scheduler that was evaluated (note: a result on that DTMC, not on
        ``model``).
    """
    old = None
    new = random_scheduler(model)
    while not old == new:
        old = new
        dtmc = old.generate_induced_dtmc()
        dtmc_result = model_checking(dtmc, prop=prop)

        if visualize:
            mapped_result = _map_result_to_model(model, dtmc, dtmc_result, old)
            # Forward the caller's layout; it was previously accepted but ignored.
            vis = JSVisualization(
                model, scheduler=old, result=mapped_result, layout=layout
            )
            vis.show()
            sleep(delay)
            if clear:
                vis.clear()

        # Policy improvement: greedily pick the best action in every state.
        # max() keeps the first action among equally good ones.
        choices = {}
        for state in model.states:
            choices[state] = max(
                state.available_actions(),
                key=lambda action, state=state: _expected_action_value(
                    model, dtmc, dtmc_result, state, action
                ),
            )
        new = Scheduler(model, choices)

    if visualize:
        # NOTE(review): the message says "Value iteration" although this is
        # policy iteration; left unchanged to match the recorded output.
        print("Value iteration done:")
        mapped_result = _map_result_to_model(model, dtmc, dtmc_result, new)
        show(model, scheduler=new, result=mapped_result, layout=layout)
    return dtmc_result
[2]:
# Build the lion MDP from stormvogel's bundled examples.
lion = examples.create_lion_mdp()
# PRISM query: the probability of eventually reaching a state labelled "full".
prop = 'P=?[F "full"]'
# Run policy iteration with the default visualization settings.
res = policy_iteration(model=lion, prop=prop)
Value iteration done:
Policy iteration is also available under stormvogel.extensions.visual_algos.