{ "cells": [ { "cell_type": "markdown", "id": "8966a060", "metadata": {}, "source": [ "# Policy iteration\n", "In policy iteration, you start with an arbitrary policy.\n", "Then, the the policy is improved at every iteration by first creating a DTMC for the previous policy, and then applying whichever choice would be best in that DTMC for the updated policy." ] }, { "cell_type": "code", "execution_count": 1, "id": "2b549cc1", "metadata": { "execution": { "iopub.execute_input": "2026-03-26T10:47:37.903974Z", "iopub.status.busy": "2026-03-26T10:47:37.903805Z", "iopub.status.idle": "2026-03-26T10:47:38.100973Z", "shell.execute_reply": "2026-03-26T10:47:38.100342Z" } }, "outputs": [], "source": [ "from stormvogel import *\n", "from stormvogel.visualization import JSVisualization\n", "from time import sleep\n", "\n", "\n", "def arg_max(funcs, args):\n", " \"\"\"Takes a list of callables and arguments and return the argument that yields the highest value.\"\"\"\n", " executed = [f(x) for f, x in zip(funcs, args)]\n", " index = executed.index(max(executed))\n", " return args[index]\n", "\n", "\n", "def policy_iteration(\n", " model: Model,\n", " prop: str,\n", " visualize: bool = True,\n", " layout: Layout = stormvogel.layout.DEFAULT(),\n", " delay: int = 2,\n", " clear: bool = False,\n", ") -> Result:\n", " \"\"\"Performs policy iteration on the given mdp.\n", " Args:\n", " model (Model): MDP.\n", " prop (str): PRISM property string to maximize. Rembember that this is a property on the induced DTMC, not the MDP.\n", " visualize (bool): Whether the intermediate and final results should be visualized. Defaults to True.\n", " layout (Layout): Layout to use to show the intermediate results.\n", " delay (int): Seconds to wait between each iteration.\n", " clear (bool): Whether to clear the visualization of each previous iteration.\n", " \"\"\"\n", " old = None\n", " new = random_scheduler(model)\n", "\n", " while not old == new:\n", " old = new\n", "\n", " dtmc = old.generate_induced_dtmc()\n", " dtmc_result = model_checking(dtmc, prop=prop)\n", "\n", " if visualize:\n", " vis = JSVisualization(\n", " model, layout=layout, scheduler=old, result=dtmc_result\n", " )\n", " vis.show()\n", " sleep(delay)\n", " if clear:\n", " vis.clear()\n", "\n", " choices = {\n", " i: arg_max(\n", " [\n", " lambda a: sum(\n", " [\n", " (p * dtmc_result.get_result_of_state(s2.id))\n", " for p, s2 in s1.get_outgoing_transitions(a)\n", " ]\n", " )\n", " for _ in s1.available_actions()\n", " ],\n", " s1.available_actions(),\n", " )\n", " for i, s1 in model.states.items()\n", " }\n", " new = Scheduler(model, choices)\n", " if visualize:\n", " print(\"Value iteration done:\")\n", " show(model, layout=layout, scheduler=new, result=dtmc_result)\n", " return dtmc_result" ] }, { "cell_type": "code", "execution_count": 2, "id": "3145cda1", "metadata": { "execution": { "iopub.execute_input": "2026-03-26T10:47:38.103135Z", "iopub.status.busy": "2026-03-26T10:47:38.102886Z", "iopub.status.idle": "2026-03-26T10:47:42.368891Z", "shell.execute_reply": "2026-03-26T10:47:42.368324Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ac073ccc56ef48bcada58d68ca36eb05", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "87ae253a10c7404e8e97722c351a6214", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Value iteration done:\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "
\n", "