{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "de94f659",
   "metadata": {},
   "source": [
    "# Policy iteration\n",
    "In policy iteration, you start with an arbitrary policy.\n",
    "Then, the policy is improved at every iteration by first creating the DTMC induced by the previous policy, and then applying whichever choice would be best in that DTMC for the updated policy.\n",
    "This repeats until the policy no longer changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "232685bd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-03-26T10:42:03.361929Z",
     "iopub.status.busy": "2026-03-26T10:42:03.361678Z",
     "iopub.status.idle": "2026-03-26T10:42:03.596106Z",
     "shell.execute_reply": "2026-03-26T10:42:03.595584Z"
    }
   },
   "outputs": [],
   "source": [
    "from stormvogel import *\n",
    "from stormvogel.visualization import JSVisualization\n",
    "from time import sleep\n",
    "\n",
    "\n",
    "def arg_max(funcs, args):\n",
    "    \"\"\"Return the argument that yields the highest value.\n",
    "\n",
    "    Each callable in `funcs` is applied to the argument at the same position in `args`.\n",
    "    Ties are broken in favour of the earliest argument.\n",
    "\n",
    "    Args:\n",
    "        funcs: Callables, one per argument.\n",
    "        args: Non-empty, indexable sequence of arguments to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        The element of `args` whose paired callable returned the largest value.\n",
    "    \"\"\"\n",
    "    executed = [f(x) for f, x in zip(funcs, args)]\n",
    "    index = executed.index(max(executed))\n",
    "    return args[index]\n",
    "\n",
    "\n",
    "def _lift_result(model, dtmc, dtmc_result, scheduler):\n",
    "    \"\"\"Map the per-state values of an induced DTMC back onto the states of the MDP.\n",
    "\n",
    "    Relies on state i of `dtmc` corresponding to state i of `model`.\n",
    "    \"\"\"\n",
    "    mapped_values = {\n",
    "        model.states[i]: dtmc_result.values.get(dtmc.states[i])\n",
    "        for i in range(len(model.states))\n",
    "    }\n",
    "    return Result(model, mapped_values, scheduler)\n",
    "\n",
    "\n",
    "def policy_iteration(\n",
    "    model: Model,\n",
    "    prop: str,\n",
    "    visualize: bool = True,\n",
    "    layout: Layout = stormvogel.layout.DEFAULT(),\n",
    "    delay: int = 2,\n",
    "    clear: bool = False,\n",
    ") -> Result:\n",
    "    \"\"\"Performs policy iteration on the given MDP.\n",
    "\n",
    "    Starting from a random scheduler, every iteration model-checks `prop` on the DTMC\n",
    "    induced by the current scheduler and then picks, in every state, the action with the\n",
    "    highest expected value under those results. Iteration stops as soon as this no longer\n",
    "    changes the scheduler.\n",
    "\n",
    "    Args:\n",
    "        model (Model): MDP.\n",
    "        prop (str): PRISM property string to maximize. Remember that this is a property on the induced DTMC, not the MDP.\n",
    "        visualize (bool): Whether the intermediate and final results should be visualized. Defaults to True.\n",
    "        layout (Layout): Layout to use to show the intermediate and final results.\n",
    "        delay (int): Seconds to wait between each iteration (only used when visualizing).\n",
    "        clear (bool): Whether to clear the visualization of each previous iteration.\n",
    "\n",
    "    Returns:\n",
    "        Result: The model-checking result on the DTMC induced by the final scheduler.\n",
    "            Its values are keyed by that DTMC's states, not by the MDP's states.\n",
    "    \"\"\"\n",
    "    scheduler = random_scheduler(model)\n",
    "\n",
    "    while True:\n",
    "        dtmc = scheduler.generate_induced_dtmc()\n",
    "        dtmc_result = model_checking(dtmc, prop=prop)\n",
    "\n",
    "        if visualize:\n",
    "            vis = JSVisualization(\n",
    "                model,\n",
    "                scheduler=scheduler,\n",
    "                result=_lift_result(model, dtmc, dtmc_result, scheduler),\n",
    "                layout=layout,\n",
    "            )\n",
    "            vis.show()\n",
    "            sleep(delay)\n",
    "            if clear:\n",
    "                vis.clear()\n",
    "\n",
    "        choices = {}\n",
    "        for state in model.states:\n",
    "            actions = state.available_actions()\n",
    "\n",
    "            def action_value(action, state=state):\n",
    "                # Expected value of taking `action` in `state` under the current DTMC's results.\n",
    "                val = 0\n",
    "                for p, successor in state.get_outgoing_transitions(action):\n",
    "                    # The successor has the same index in the MDP and in the induced DTMC.\n",
    "                    dtmc_successor = dtmc.states[model.get_state_index(successor)]\n",
    "                    val += p * dtmc_result.get_result_of_state(dtmc_successor)\n",
    "                return val\n",
    "\n",
    "            # arg_max pairs one callable with each action; ties go to the first action.\n",
    "            choices[state] = arg_max([action_value] * len(actions), actions)\n",
    "\n",
    "        improved = Scheduler(model, choices)\n",
    "        if scheduler == improved:\n",
    "            break  # converged: the improvement step no longer changes the policy\n",
    "        scheduler = improved\n",
    "\n",
    "    if visualize:\n",
    "        print(\"Policy iteration done:\")\n",
    "        show(\n",
    "            model,\n",
    "            scheduler=improved,\n",
    "            result=_lift_result(model, dtmc, dtmc_result, improved),\n",
    "            layout=layout,\n",
    "        )\n",
    "    return dtmc_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f94eba4",
   "metadata": {
    "execution": {
     "iopub.execute_input":
"2026-03-26T10:42:03.598031Z", "iopub.status.busy": "2026-03-26T10:42:03.597795Z", "iopub.status.idle": "2026-03-26T10:42:07.884831Z", "shell.execute_reply": "2026-03-26T10:42:07.884212Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9b152e0ddea94aecb74349e651615ee0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "02674914776a43dcab267e0c463c76e4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Value iteration done:\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "
\n", "