diff --git a/.gitignore b/.gitignore index a512ec7..082836e 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,4 @@ dmypy.json .pyre/ .vscode/settings.json jpl.code-workspace +tmp/ diff --git a/examples/pong-RL.ipynb b/examples/pong-RL.ipynb index b2c80bb..d888cf0 100644 --- a/examples/pong-RL.ipynb +++ b/examples/pong-RL.ipynb @@ -87,8 +87,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 3.95 ms, sys: 0 ns, total: 3.95 ms\n", - "Wall time: 3.52 ms\n" + "CPU times: user 3.26 ms, sys: 0 ns, total: 3.26 ms\n", + "Wall time: 2.43 ms\n" ] } ], @@ -105,8 +105,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 4.71 ms, sys: 8.33 ms, total: 13 ms\n", - "Wall time: 722 ms\n" + "CPU times: user 5.71 ms, sys: 5.79 ms, total: 11.5 ms\n", + "Wall time: 834 ms\n" ] } ], @@ -123,25 +123,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1.3 ms, sys: 1.19 ms, total: 2.49 ms\n", - "Wall time: 117 ms\n" - ] - } - ], - "source": [ - "%time _ = pong.step()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.58 ms ± 216 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "2.41 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -151,17 +133,17 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOAAAADgCAIAAACVT/22AAAD3ElEQVR4nO3dv25TVwDAYbu+GVBKn4CtE1KFVGBoqUQeokMnHoKpS16A52DqwENQqWqH/lk6MfMGgBjiKB0sWZWN4mDZNz8n37fFOopOcn46J76+saenz99MoOqL654AXEagpAmUNIGSJlDSBEqaQEkTKGkCJU2gpA0bR/zy6sXxnXsjTIXb5sPHtz/9+PPlYzYHenzn3lfHX+9oSvB5HPGkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNI2f5gsh+L1X8+uOPLk0cu9zmSH7KCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESpq7mW6OA7pH6ersoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0ra5WeSTb7FyI+9UYDs7LMQOSppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNJGene7pyffbRzzz9//vnv3foTJcLm7d7/89uE3G4f9+vqPESZjByVtvzvocLT6/c/m89Uxs9l0Op1MJrNhthg/P1sdwzgWv//ZMFt8eXFxMT8/XxlzNAz/HzzZ83rtN9AnTx6vPPL7b3+uPPL9D48XP/ODB/cXj4xzdrBuZb3m5+fr67X8a205eK/r5YgnTaCk7feId1gfluB62UFJEyhpY38MzSVX7F2oL1tfOBfqQaC0jX3Er58Lywv1y9d/g88lb5VPvhZ/XYtiByVNoKTt94i/rqd+bGdlvc7m8/XX4kdmByVNoKSN9Fr88uxwof6AHA3DVf4Vwu123F4CJW2kC/WevB+E4DLZQUkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEnb5o76k0cvdz4PbpIdFmIHJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkbNo748PHtCPPgFrpKWtPT529GmApsxxFPmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJ+w8thXYc3bPVNQAAAABJRU5ErkJggg==\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOAAAADgCAIAAACVT/22AAAD7UlEQVR4nO3dv25TVwDAYbt2BhToE7B1qlQhURhaKjUP0aFTH4KpCy/Q52Bi4CGohNqhf5ZOzLwBRQxxlA6WLHQNcRrim5/t79tiHUUnOT+dE19fx9Mnj19NoOqzm54AXESgpAmUNIGSJlDSBEqaQEkTKGkCJU2gpM03jnj2/JfjW3dHmAqH5u271z/+8PPFYzYHenzr7ufHX1zTlOD/ccSTJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGStvmT5tgbL/786f0vTx48vamZXJ5A982gwl3niCdNoKQJlDSBkiZQ0gRKmkBJE+i+ufzl9524YipQ0gRKmkBJEyhpbhY5IDtx+9KAQPfQLob4MY540gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKRd5Xa7D77Zap9u8eITXWMhdlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSNtL/B/3+5JuNY/7+6583b/4dYTJc7M6d2/e//mrjsF9f/D7CZOygpG13B50fDb//6WIxHDObTafTyWQym8+W4xenwzGMY/n7n81nyy/Pz88XZ2eDMUfz+fuDJ1ter+0G+ujRw8Ejv738Y/DIt989XP7M9+59uXxknLODdYP1Wpydra/X6q+11eCtrpcjnjSBkrbdI95hvVuC62UHJU2gpI39QV4XXLF3ob5sfeFcqAeB0jb2Eb9+Lqwu1K9e/w0+lzwoH3wt/qYWxQ5KmkBJ2+4Rf1NP/biawXqdLhbrr8WPzA5KmkBJG+m1+NXZ4UL9Djmazy/zVgi323G4BEraSBfqPXnfCcFlsoOSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSdpU76k8ePL32ebBPrrEQOyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRK2nzjiLfvXo8wDw7QZdKaPnn8aoSpwNU44kkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGn/AdnJfCCSO2EvAAAAAElFTkSuQmCC\n", "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -172,10 +154,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "68.7 µs ± 2.79 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit -n100 pong.get('WIDTH')" + ] }, { "cell_type": "code", @@ -190,7 +182,7 @@ "metadata": {}, "outputs": [], "source": [ - "gl = jupylet.rl.Games(['pong'] * 8)" + "games = jupylet.rl.Games(['pong'] * 4)" ] }, { @@ -202,13 +194,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 5.66 ms, sys: 30.2 ms, total: 35.9 ms\n", - "Wall time: 799 ms\n" + "CPU times: user 1.02 ms, sys: 13.7 ms, total: 14.7 ms\n", + "Wall time: 805 ms\n" ] } ], "source": [ - "%time gl.start()" + "%time games.start()" ] }, { @@ -220,12 +212,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "5.17 ms ± 561 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "3.14 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ - "%timeit -n100 gl.step()" + "%timeit -n100 games.step()" ] }, { @@ -237,13 +229,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2.2 ms, sys: 1.7 ms, total: 3.9 ms\n", - "Wall time: 5.49 ms\n" + "CPU times: user 839 µs, sys: 840 µs, total: 1.68 ms\n", + "Wall time: 2.99 ms\n" ] } ], "source": [ - "%time al = gl.step()" + "%time al = games.step()" ] }, { @@ -253,9 +245,9 @@ "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOAAAADgCAIAAACVT/22AAAD3ElEQVR4nO3dv25TVwDAYbu+GVBKn4CtE1KFVGBoqUQeokMnHoKpS16A52DqwENQqWqH/lk6MfMGgBjiKB0sWZWN4mDZNz8n37fFOopOcn46J76+saenz99MoOqL654AXEagpAmUNIGSJlDSBEqaQEkTKGkCJU2gpA0bR/zy6sXxnXsjTIXb5sPHtz/9+PPlYzYHenzn3lfHX+9oSvB5HPGkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZK2+bM6OSCv/3p2xZEnj17udSa7YgclTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNLczXSjHMo9SldnByVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0ra5m+mTbwB08+6jYWs7LMQOSppASRMoaQIlTaCkCZQ0gZImUNIESppASRMoaQIlTaCkCZQ0gZImUNJGevvFpyffbRzzz9//vnv3foTJcLm7d7/89uE3G4f9+vqPESZjByVtvzvocLT6/c/m89Uxs9l0Op1MJrNhthg/P1sdwzgWv//ZMFt8eXFxMT8/XxlzNAz/HzzZ83rtN9AnTx6vPPL7b3+uPPL9D48XP/ODB/cXj4xzdrBuZb3m5+fr67X8a205eK/r5YgnTaCk7feId1gfluB62UFJEyhpY39O0iVX7F2oL1tfOBfqQaC0jX3Er58Lywv1y9d/g88lb5VPvhZ/XYtiByVNoKTt94i/rqd+bGdlvc7m8/XX4kdmByVNoKSN9Fr88uxwof6AHA3DVf4Vwu123F4CJW2kC/WevB+E4DLZQUkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEnb5o76k0cvdz4PbpIdFmIHJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkbNo748PHtCPPgFrpKWtPT529GmApsxxFPmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJ+w/nlXYcY/+SLgAAAABJRU5ErkJggg==\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOAAAADgCAIAAACVT/22AAAD7UlEQVR4nO3dv25TVwDAYbt2BhToE7B1qlQhURhaKjUP0aFTH4KpCy/Q52Bi4CGohNqhf5ZOzLwBRQxxlA6WLHQNcRrim5/t79tiHUUnOT+dE19fx9Mnj19NoOqzm54AXESgpAmUNIGSJlDSBEqaQEkTKGkCJU2gpM03jnj2/JfjW3dHmAqH5u271z/+8PPFYzYHenzr7ufHX1zTlOD/ccSTJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGStvmT5tgbL/786f0vTx48vamZXJ5A982gwl3niCdNoKQJlDSBkiZQ0gRKmkBJE+i+ufzl9524YipQ0gRKmkBJEyhpbhY5IDtx+9KAQPfQLob4MY540gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKRd5Xa7D77Zap9u8eITXWMhdlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSNtL/B/3+5JuNY/7+6583b/4dYTJc7M6d2/e//mrjsF9f/D7CZOygpG13B50fDb//6WIxHDObTafTyWQym8+W4xenwzGMY/n7n81nyy/Pz88XZ2eDMUfz+fuDJ1ter+0G+ujRw8Ejv738Y/DIt989XP7M9+59uXxknLODdYP1Wpydra/X6q+11eCtrpcjnjSBkrbdI95hvVuC62UHJU2gpI39QV4XXLF3ob5sfeFcqAeB0jb2Eb9+Lqwu1K9e/w0+lzwoH3wt/qYWxQ5KmkBJ2+4Rf1NP/biawXqdLhbrr8WPzA5KmkBJG+m1+NXZ4UL9Djmazy/zVgi323G4BEraSBfqPXnfCcFlsoOSJlDSBEqaQEkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGkCJU2gpAmUNIGSdpU76k8ePL32ebBPrrEQOyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRKmkBJEyhpAiVNoKQJlDSBkiZQ0gRK2nzjiLfvXo8wDw7QZdKaPnn8aoSpwNU44kkTKGkCJU2gpAmUNIGSJlDSBEqaQEkTKGn/AdnJfCCSO2EvAAAAAElFTkSuQmCC\n", "text/plain": [ - "" + "" ] }, "execution_count": 14, @@ -282,7 +274,7 @@ { "data": { "text/plain": [ - "(8, 3, 224, 224)" + "(4, 3, 224, 224)" ] }, "execution_count": 15, @@ -295,6 +287,13 @@ "batch.shape" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/pong-start.state b/examples/pong-start.state new file mode 100644 index 0000000..1445e7c Binary files /dev/null and b/examples/pong-start.state differ diff --git a/examples/pong.ipynb b/examples/pong.ipynb index 3404703..450adce 100644 --- a/examples/pong.ipynb +++ b/examples/pong.ipynb @@ -60,7 +60,8 @@ "\n", "from jupylet.app import App\n", "from jupylet.label import Label\n", - "from jupylet.sprite import Sprite" + "from jupylet.sprite import Sprite\n", + "from jupylet.state import State, load_state, save_state" ] }, { @@ -97,6 +98,13 @@ "window = app.window" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 7, @@ -219,16 +227,6 @@ "execution_count": 15, "metadata": {}, "outputs": [], - "source": [ - "sl = 0\n", - "sr = 0" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], "source": [ "@app.event\n", "def on_draw():\n", @@ -254,90 +252,100 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "vyl = 0\n", - "pyl = HEIGHT/2\n", + "state = State(\n", + " \n", + " sl = 0,\n", + " sr = 0,\n", + " \n", + " bvx = 192,\n", + " bvy = 192,\n", + " \n", + " vyl = 0,\n", + " pyl = HEIGHT/2,\n", "\n", - "vyr = 0\n", - "pyr = HEIGHT/2\n", + " vyr = 0,\n", + " pyr = HEIGHT/2,\n", "\n", - "left = False\n", - "right = False\n", + " left = False,\n", + " right = False,\n", "\n", - "key_a = False\n", - "key_d = False" + " key_a = False,\n", + " key_d = False,\n", + ")" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "@app.event\n", "def on_key_press(symbol, modifiers):\n", - " \n", - " global left, right, key_a, key_d\n", - " \n", + " \n", " if symbol == key.LEFT:\n", - " left = True\n", + " state.left = True\n", " \n", " if symbol == key.RIGHT:\n", - " right = True\n", + " state.right = True\n", " \n", " if symbol == key.A:\n", - " key_a = True\n", + " state.key_a = True\n", " \n", " if symbol == key.D:\n", - " key_d = True\n", + " state.key_d = True\n", " \n", "\n", "@app.event\n", "def on_key_release(symbol, modifiers):\n", " \n", - " global left, right, key_a, key_d\n", - " \n", " if symbol == key.LEFT:\n", - " left = False\n", + " state.left = False\n", " \n", " if symbol == key.RIGHT:\n", - " right = False\n", + " state.right = False\n", "\n", " if symbol == key.A:\n", - " key_a = False\n", + " state.key_a = False\n", " \n", " if symbol == key.D:\n", - " key_d = False\n", + " state.key_d = False\n", " \n", "\n", "@app.run_me_again_and_again(1/120)\n", "def update_pads(dt):\n", - " \n", - " global vyl, vyr, pyl, pyr\n", - " \n", - " if right:\n", - " pyr = min(HEIGHT, pyr + dt * 512)\n", " \n", - " if left:\n", - " pyr = max(0, pyr - dt * 512)\n", + " if state.right:\n", + " state.pyr = min(HEIGHT, state.pyr + dt * 512)\n", " \n", - " if key_a:\n", - " pyl = min(HEIGHT, pyl + dt * 512)\n", + " if state.left:\n", + " state.pyr = max(0, state.pyr - dt * 512)\n", " \n", - " if key_d:\n", - " pyl = max(0, pyl - dt * 512)\n", + " if state.key_a:\n", + " state.pyl = min(HEIGHT, state.pyl + dt * 512)\n", " \n", - " ayl = 200 * (pyl - padl.y)\n", - " vyl = vyl * 0.9 + (ayl * dt)\n", + " if state.key_d:\n", + " state.pyl = max(0, state.pyl - dt * 512)\n", + " \n", + " ayl = 200 * (state.pyl - padl.y)\n", + " ayr = 200 * (state.pyr - padr.y)\n", + "\n", + " state.vyl = state.vyl * 0.9 + (ayl * dt)\n", + " state.vyr = state.vyr * 0.9 + (ayr * dt)\n", " \n", - " ayr = 200 * (pyr - padr.y)\n", - " vyr = vyr * 0.9 + (ayr * dt)\n", - " \n", - " padl.y += vyl * dt\n", - " padr.y += vyr * dt\n", + " padl.y += state.vyl * dt\n", + " padr.y += state.vyr * dt\n", " \n", " padr.clip_position(WIDTH, HEIGHT)\n", " padl.clip_position(WIDTH, HEIGHT)" @@ -352,97 +360,85 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "bvx = 192\n", - "bvy = 192" - ] - }, - { - "cell_type": "code", - "execution_count": 20, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "@app.run_me_again_and_again(1/120)\n", "def update_ball(dt):\n", " \n", - " global bvx, bvy, sl, sr\n", - "\n", - " bs0 = bvx ** 2 + bvy ** 2\n", + " bs0 = state.bvx ** 2 + state.bvy ** 2\n", " \n", " ball.rotation += 200 * dt\n", " \n", - " ball.x += bvx * dt\n", - " ball.y += bvy * dt\n", + " ball.x += state.bvx * dt\n", + " ball.y += state.bvy * dt\n", " \n", " if ball.top >= HEIGHT:\n", " app.play_once(sound)\n", " ball.y -= ball.top - HEIGHT\n", - " bvy = -bvy\n", + " state.bvy = -state.bvy\n", " \n", " if ball.bottom <= 0:\n", " app.play_once(sound)\n", " ball.y -= ball.bottom\n", - " bvy = -bvy\n", + " state.bvy = -state.bvy\n", " \n", " if ball.right >= WIDTH:\n", " app.play_once(sound)\n", " ball.x -= ball.right - WIDTH\n", " \n", - " bvx = -192\n", - " bvy = 192 * np.sign(bvy)\n", + " state.bvx = -192\n", + " state.bvy = 192 * np.sign(state.bvy)\n", " bs0 = 0\n", " \n", - " sl += 1\n", - " scorel.text = str(sl)\n", + " state.sl += 1\n", + " scorel.text = str(state.sl)\n", " \n", " if ball.left <= 0:\n", " app.play_once(sound)\n", " ball.x -= ball.left\n", " \n", - " bvx = 192\n", - " bvy = 192 * np.sign(bvy)\n", + " state.bvx = 192\n", + " state.bvy = 192 * np.sign(state.bvy)\n", " bs0 = 0\n", " \n", - " sr += 1\n", - " scorer.text = str(sr)\n", + " state.sr += 1\n", + " scorer.text = str(state.sr)\n", " \n", - " if bvx > 0 and ball.top >= padr.bottom and padr.top >= ball.bottom: \n", + " if state.bvx > 0 and ball.top >= padr.bottom and padr.top >= ball.bottom: \n", " if 0 < ball.right - padr.left < 10:\n", " app.play_once(sound)\n", " ball.x -= ball.right - padr.left\n", - " bvx = -bvx\n", - " bvy += vyr / 2\n", + " state.bvx = -state.bvx\n", + " state.bvy += state.vyr / 2\n", " \n", - " if bvx < 0 and ball.top >= padl.bottom and padl.top >= ball.bottom: \n", + " if state.bvx < 0 and ball.top >= padl.bottom and padl.top >= ball.bottom: \n", " if 0 < padl.right - ball.left < 10:\n", " app.play_once(sound)\n", " ball.x += ball.left - padl.right\n", - " bvx = -bvx\n", - " bvy += vyl / 2\n", + " state.bvx = -state.bvx\n", + " state.bvy += state.vyl / 2\n", " \n", - " bs1 = bvx ** 2 + bvy ** 2\n", + " bs1 = state.bvx ** 2 + state.bvy ** 2\n", " \n", " if bs1 < 0.9 * bs0:\n", - " bvx = (bs0 - bvy ** 2) ** 0.5 * np.sign(bvx)\n", + " state.bvx = (bs0 - state.bvy ** 2) ** 0.5 * np.sign(state.bvx)\n", "\n", " ball.wrap_position(WIDTH, HEIGHT)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "@app.run_me_now()\n", "def highlights(dt):\n", " \n", - " sl0 = sl\n", - " sr0 = sr\n", + " sl0 = state.sl\n", + " sr0 = state.sr\n", " \n", " slc = np.array(scorel.color)\n", " src = np.array(scorer.color)\n", @@ -456,38 +452,23 @@ " scorel.color = np.array(scorel.color) * r0 + (1 - r0) * slc\n", " scorer.color = np.array(scorer.color) * r0 + (1 - r0) * src\n", " \n", - " if sl0 != sl:\n", - " sl0 = sl\n", + " if sl0 != state.sl:\n", + " sl0 = state.sl\n", " scorel.color = 'white'\n", "\n", - " if sr0 != sr:\n", - " sr0 = sr\n", + " if sr0 != state.sr:\n", + " sr0 = state.sr\n", " scorer.color = 'white'\n", " " ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 20, "metadata": { "scrolled": false }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c30db9c85e194e978e03ce07d211339d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Canvas(layout=Layout(height='512px', width='512px'), size=(512, 512))" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "app.run()" ] @@ -499,11 +480,81 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def step(player0=[0, 0], player1=[0, 0], n=1):\n", + " \n", + " state.key_a, state.key_d = player0\n", + " \n", + " state.left, state.right = player1\n", + " \n", + " sl0 = state.sl\n", + " sr0 = state.sr\n", + " \n", + " if app.mode == 'hidden': \n", + " app.step(n)\n", + " \n", + " a = app.array0\n", + " \n", + " return {\n", + " 'screen0': a,\n", + " 'player0': {'reward': state.sl - sl0},\n", + " 'player1': {'reward': state.sr - sr0},\n", + " }" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "START = 'pong-start.state'\n", + "\n", + "\n", + "def reset():\n", + " load(START)\n", + " \n", + " \n", + "def load(path):\n", + " load_state(path, state, ball, padl, padr, scorel, scorer)\n", + " \n", + "\n", + "def save(path=None):\n", + " return save_state('pong', path, state, ball, padl, padr, scorel, scorer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Ignoring call to stop() since it appears to have been done accidentally." + ] + } + ], "source": [ "app.stop()" ] diff --git a/examples/pong.py b/examples/pong.py index 3074e88..798f84c 100644 --- a/examples/pong.py +++ b/examples/pong.py @@ -52,6 +52,7 @@ import jupylet.color from jupylet.app import App from jupylet.label import Label from jupylet.sprite import Sprite +from jupylet.state import State, load_state, save_state if __name__ == '__main__': @@ -100,10 +101,6 @@ scorer = Label( ) -sl = 0 -sr = 0 - - @app.event def on_draw(): @@ -119,151 +116,147 @@ def on_draw(): padr.draw() -vyl = 0 -pyl = HEIGHT/2 +state = State( + + sl = 0, + sr = 0, + + bvx = 192, + bvy = 192, + + vyl = 0, + pyl = HEIGHT/2, -vyr = 0 -pyr = HEIGHT/2 + vyr = 0, + pyr = HEIGHT/2, -left = False -right = False - -key_a = False -key_d = False + left = False, + right = False, + key_a = False, + key_d = False, +) @app.event def on_key_press(symbol, modifiers): - - global left, right, key_a, key_d - + if symbol == key.LEFT: - left = True + state.left = True if symbol == key.RIGHT: - right = True + state.right = True if symbol == key.A: - key_a = True + state.key_a = True if symbol == key.D: - key_d = True + state.key_d = True @app.event def on_key_release(symbol, modifiers): - global left, right, key_a, key_d - if symbol == key.LEFT: - left = False + state.left = False if symbol == key.RIGHT: - right = False + state.right = False if symbol == key.A: - key_a = False + state.key_a = False if symbol == key.D: - key_d = False + state.key_d = False @app.run_me_again_and_again(1/120) def update_pads(dt): - - global vyl, vyr, pyl, pyr - - if right: - pyr = min(HEIGHT, pyr + dt * 512) - if left: - pyr = max(0, pyr - dt * 512) + if state.right: + state.pyr = min(HEIGHT, state.pyr + dt * 512) - if key_a: - pyl = min(HEIGHT, pyl + dt * 512) + if state.left: + state.pyr = max(0, state.pyr - dt * 512) - if key_d: - pyl = max(0, pyl - dt * 512) + if state.key_a: + state.pyl = min(HEIGHT, state.pyl + dt * 512) - ayl = 200 * (pyl - padl.y) - vyl = vyl * 0.9 + (ayl * dt) + if state.key_d: + state.pyl = max(0, state.pyl - dt * 512) + + ayl = 200 * (state.pyl - padl.y) + ayr = 200 * (state.pyr - padr.y) + + state.vyl = state.vyl * 0.9 + (ayl * dt) + state.vyr = state.vyr * 0.9 + (ayr * dt) - ayr = 200 * (pyr - padr.y) - vyr = vyr * 0.9 + (ayr * dt) - - padl.y += vyl * dt - padr.y += vyr * dt + padl.y += state.vyl * dt + padr.y += state.vyr * dt padr.clip_position(WIDTH, HEIGHT) padl.clip_position(WIDTH, HEIGHT) -bvx = 192 -bvy = 192 - - @app.run_me_again_and_again(1/120) def update_ball(dt): - global bvx, bvy, sl, sr - - bs0 = bvx ** 2 + bvy ** 2 + bs0 = state.bvx ** 2 + state.bvy ** 2 ball.rotation += 200 * dt - ball.x += bvx * dt - ball.y += bvy * dt + ball.x += state.bvx * dt + ball.y += state.bvy * dt if ball.top >= HEIGHT: app.play_once(sound) ball.y -= ball.top - HEIGHT - bvy = -bvy + state.bvy = -state.bvy if ball.bottom <= 0: app.play_once(sound) ball.y -= ball.bottom - bvy = -bvy + state.bvy = -state.bvy if ball.right >= WIDTH: app.play_once(sound) ball.x -= ball.right - WIDTH - bvx = -192 - bvy = 192 * np.sign(bvy) + state.bvx = -192 + state.bvy = 192 * np.sign(state.bvy) bs0 = 0 - sl += 1 - scorel.text = str(sl) + state.sl += 1 + scorel.text = str(state.sl) if ball.left <= 0: app.play_once(sound) ball.x -= ball.left - bvx = 192 - bvy = 192 * np.sign(bvy) + state.bvx = 192 + state.bvy = 192 * np.sign(state.bvy) bs0 = 0 - sr += 1 - scorer.text = str(sr) + state.sr += 1 + scorer.text = str(state.sr) - if bvx > 0 and ball.top >= padr.bottom and padr.top >= ball.bottom: + if state.bvx > 0 and ball.top >= padr.bottom and padr.top >= ball.bottom: if 0 < ball.right - padr.left < 10: app.play_once(sound) ball.x -= ball.right - padr.left - bvx = -bvx - bvy += vyr / 2 + state.bvx = -state.bvx + state.bvy += state.vyr / 2 - if bvx < 0 and ball.top >= padl.bottom and padl.top >= ball.bottom: + if state.bvx < 0 and ball.top >= padl.bottom and padl.top >= ball.bottom: if 0 < padl.right - ball.left < 10: app.play_once(sound) ball.x += ball.left - padl.right - bvx = -bvx - bvy += vyl / 2 + state.bvx = -state.bvx + state.bvy += state.vyl / 2 - bs1 = bvx ** 2 + bvy ** 2 + bs1 = state.bvx ** 2 + state.bvy ** 2 if bs1 < 0.9 * bs0: - bvx = (bs0 - bvy ** 2) ** 0.5 * np.sign(bvx) + state.bvx = (bs0 - state.bvy ** 2) ** 0.5 * np.sign(state.bvx) ball.wrap_position(WIDTH, HEIGHT) @@ -271,8 +264,8 @@ def update_ball(dt): @app.run_me_now() def highlights(dt): - sl0 = sl - sr0 = sr + sl0 = state.sl + sr0 = state.sr slc = np.array(scorel.color) src = np.array(scorer.color) @@ -286,14 +279,50 @@ def highlights(dt): scorel.color = np.array(scorel.color) * r0 + (1 - r0) * slc scorer.color = np.array(scorer.color) * r0 + (1 - r0) * src - if sl0 != sl: - sl0 = sl + if sl0 != state.sl: + sl0 = state.sl scorel.color = 'white' - if sr0 != sr: - sr0 = sr + if sr0 != state.sr: + sr0 = state.sr scorer.color = 'white' - + + +def step(player0=[0, 0], player1=[0, 0], n=1): + + state.key_a, state.key_d = player0 + + state.left, state.right = player1 + + sl0 = state.sl + sr0 = state.sr + + if app.mode == 'hidden': + app.step(n) + + a = app.array0 + + return { + 'screen0': a, + 'player0': {'reward': state.sl - sl0}, + 'player1': {'reward': state.sr - sr0}, + } + + +START = 'pong-start.state' + + +def reset(): + load(START) + + +def load(path): + load_state(path, state, ball, padl, padr, scorel, scorer) + + +def save(path=None): + return save_state('pong', path, state, ball, padl, padr, scorel, scorer) + if __name__ == '__main__': app.run() diff --git a/examples/spaceship.ipynb b/examples/spaceship.ipynb index 5b165fb..5c2e6a6 100644 --- a/examples/spaceship.ipynb +++ b/examples/spaceship.ipynb @@ -64,16 +64,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "app = App(mode='window')" + "app = App(mode='jupyter')" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -111,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -138,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -180,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -194,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -226,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -264,11 +264,26 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5c9da688cc944d5793808e0860b034ca", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Canvas(layout=Layout(height='512px', width='512px'), size=(512, 512))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "app.run()" ] diff --git a/jupylet/label.py b/jupylet/label.py index 7dc7d11..562873c 100644 --- a/jupylet/label.py +++ b/jupylet/label.py @@ -29,6 +29,7 @@ import webcolors import pyglet from .color import color2rgb +from .state import State class Label(pyglet.text.Label): @@ -82,3 +83,19 @@ class Label(pyglet.text.Label): self.document.set_style(0, len(self.document.text), {'color': color}) + def get_state(self): + + return State( + x = self.x, + y = self.y, + text = self.text, + color = self.color, + ) + + def set_state(self, s): + + self.x = s.x + self.y = s.y + self.text = s.text + self.color = s.color + diff --git a/jupylet/rl.py b/jupylet/rl.py index 74d277b..72cfe8a 100644 --- a/jupylet/rl.py +++ b/jupylet/rl.py @@ -269,11 +269,11 @@ class ModuleProcess(object): class GameProcess(ModuleProcess): - def start(self, size=224): + def start(self, interval=1/30, size=224): super(GameProcess, self).start() - self.call('app.start') + self.call('app.start', interval) self.call('app.scale_window_to', size) self.call('app.step') @@ -290,7 +290,7 @@ class Games(object): else: self.games = games - def start(self, size=224): + def start(self, interval=1/30, size=224): for g in self.games: if type(g) is GameProcess: @@ -298,7 +298,7 @@ class Games(object): else: g.start() - self.call('app.start') + self.call('app.start', interval) self.call('app.scale_window_to', size) self.call('app.step') diff --git a/jupylet/sprite.py b/jupylet/sprite.py index e5bedac..f748597 100644 --- a/jupylet/sprite.py +++ b/jupylet/sprite.py @@ -34,6 +34,7 @@ import numpy as np from .collision import trbl, hitmap_and_outline_from_alpha, compute_collisions from .resource import image_from, pil_open +from .state import State _empty_array = np.array([]) @@ -270,6 +271,30 @@ class Sprite(pyglet.sprite.Sprite): im0 = PIL.Image.frombytes('RGBA', (id0.width, id0.height), id1) return im0 + def get_state(self): + + return State( + x = self.x, + y = self.y, + scale = self.scale, + opacity = self.opacity, + rotation = self.rotation, + anchor_x = self.anchor_x, + anchor_y = self.anchor_y, + ) + + def set_state(self, s): + + self.x = s.x + self.y = s.y + self.scale = s.scale + self.opacity = s.opacity + self.rotation = s.rotation + self.anchor_x = s.anchor_x + self.anchor_y = s.anchor_y + + self._update_position() + def canvas2sprite(c): diff --git a/jupylet/state.py b/jupylet/state.py new file mode 100644 index 0000000..bcf3844 --- /dev/null +++ b/jupylet/state.py @@ -0,0 +1,78 @@ +""" + jupylet/state.py + + Copyright (c) 2020, Nir Aides - nir@winpdb.org + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + +import hashlib +import random +import pickle + + +def o2h(o, n=12): + return hashlib.sha256(pickle.dumps(o)).hexdigest()[:n] + + +def save_state(name, path, *args): + + if not path: + path = '%s-%s.state' % (name, o2h(random.random())) + + with open(path, 'wb') as f: + sl = [o.get_state() for o in args] + pickle.dump(sl, f) + + return path + + +def load_state(path, *args): + + with open(path, 'rb') as f: + sl = pickle.load(f) + for o, s in zip(args, sl): + o.set_state(s) + + +class State(object): + + def __init__(self, **kwargs): + + for k, v in kwargs.items(): + setattr(self, k, v) + + def __repr__(self): + return repr(self.__dict__) + + def __setitem__(self, key, item): + self.__dict__[key] = item + + def __getitem__(self, key): + return self.__dict__[key] + + def get_state(self): + return self + + def set_state(self, s): + self.__dict__ = vars(s) +