{ "cells": [ { "cell_type": "markdown", "metadata": { "toc": "true" }, "source": [ "

Table of Contents

\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# quick start " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "No information on wall, this very nature case which needs the exploration, try to see its relation of generalization to performance " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Resources\n", "\n", "* [*The* Reinforcement learning book from Sutton & Barto](http://incompleteideas.net/sutton/book/the-book-2nd.html)\n", "* [The REINFORCE paper from Ronald J. Williams (1992)](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# FULL MODEL" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Populating the interactive namespace from numpy and matplotlib\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/cruiser/anaconda3/lib/python3.6/site-packages/IPython/core/magics/pylab.py:160: UserWarning: pylab import has clobbered these variables: ['random']\n", "`%matplotlib` prevents importing * from pylab and numpy\n", " \"\\n`%matplotlib` prevents importing * from pylab and numpy\"\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import numpy as np\n", "from itertools import count\n", "import random\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import torch.autograd as autograd\n", "from torch.autograd import Variable\n", "from torch.nn import init\n", "from torch.nn import DataParallel\n", "from torch.utils.data import DataLoader\n", "\n", "import matplotlib.mlab as mlab\n", "import matplotlib.pyplot as plt\n", "import matplotlib.animation\n", "import seaborn as sns\n", "from IPython.display import HTML\n", "\n", "import pretrain\n", "from pretrain import *\n", "\n", "import navigation2\n", "from navigation2 import *\n", "\n", "import Nets \n", "from Nets import * \n", "\n", "%pylab inline\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Qnetwork\n", "\n", "To select actions we take maximum of Q value, corresponding to certain move." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the liquid state approach to work, you need a lot of neurons as surplus or enough hidden to hidden connectivity to make it have an effect." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## POMDP RNN Game\n", "\n", "In this game , we use a new reward function determined by game, if the agent achieves the goal before 50, reward is 1. If time pass 50 reward is 0.5, once time pass 100 agent gets a reward of -0.5 . Practically, this is found to be easier to learn than the rewards as a continous function of time. Tf the agent learns to search in a efficient way, the largest possible way for search is to firstly arrive at corner then goes to the goal, which, takes about 50 steps, it is reasonble to make 50 and 100 as milestone thing. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "For the liquid-state approach to work, you need a large surplus of neurons, or enough hidden-to-hidden connectivity, for it to have an effect." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## POMDP RNN Game\n", "\n", "In this game we use a new reward function determined by the game clock: if the agent reaches the goal within 50 steps, the reward is 1; after 50 steps it drops to 0.5; and after 100 steps the agent receives -0.5. In practice, this is found to be easier to learn than a reward that is a continuous function of time. If the agent has learned to search efficiently, the longest possible search path is to first reach a corner and then go to the goal, which takes about 50 steps, so 50 and 100 are reasonable milestones. Also, in principle, since the game has no timer, it is not clear that the agent can exploit a reward defined as a function of time. A sketch of this stepwise reward is given below." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "There are three conditions for ending an episode; in particular, once the time limit is passed, the game is over." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For the weight update, it seems better to do it after each episode: it makes no sense to evaluate a strategy mid-episode rather than at the end, and it is also much quicker. A sketch of this episodic update is also given below." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The game is programmed as an MDP here, with the hidden state serving as the state of the environment." ] },
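{ "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of the stepwise reward described above, assuming `t` counts steps since the episode started; this is illustrative, not the notebook's actual implementation." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def step_reward(t, reached_goal):\n", "    # Stepwise reward from the text: 1 within 50 steps, 0.5 within\n", "    # 100 steps, -0.5 afterwards; 0 while the goal is not yet reached.\n", "    # Illustrative sketch, not the notebook's implementation.\n", "    if not reached_goal:\n", "        return 0.0\n", "    if t <= 50:\n", "        return 1.0\n", "    if t <= 100:\n", "        return 0.5\n", "    return -0.5" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "And a minimal sketch of the episodic weight update: transitions are collected during play, and a single optimizer step is taken once the episode ends. Here `env`, `net`, `optimizer`, and `compute_loss` are hypothetical stand-ins for the notebook's actual objects." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def run_episode(env, net, optimizer, max_steps=100):\n", "    # Hypothetical episodic update loop; env, net, optimizer and\n", "    # compute_loss stand in for the notebook's actual objects.\n", "    transitions = []\n", "    state = env.reset()\n", "    for t in range(max_steps):\n", "        q_values = net(state)\n", "        action = select_action(q_values, epsilon=0.1)\n", "        state, reward, done = env.step(action)\n", "        transitions.append((q_values, action, reward))\n", "        if done:\n", "            break\n", "    # Evaluate the strategy only once the episode has ended,\n", "    # then take a single gradient step on the accumulated loss.\n", "    loss = compute_loss(transitions)\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()" ] },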
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "pregame = PretrainGame(grid_size = (15, 15), holes = 0, random_seed = 4, set_reward = [(0.5, 0.25), (0.5, 0.75)])\n", "pregame.reset(set_agent=(2,2))\n", "# rls_q = RLS(1)\n", "# rls_sl = RLS(1)\n", "# for i in range(1):\n", "#     pregame.fulltrain(trials = 4)" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "game = ValueMaxGame(pregame.net, grid_size = (15, 15), holes = 0, random_seed = 4, set_reward = [(0.5, 0.25), (0.5, 0.75)])\n", "game.reset()\n", "# game.experiment(rls_q, rls_sl, 20, epsilon = 0.5, lr = 1e-3, train_hidden = False, train_q = False)" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[figure: plt.matshow rendering of the 15x15 game grid]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.matshow(game.grid.grid)\n", "# plt.savefig('g16h3-map')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model Training\n", "\n", "Pretraining is done with a fixed size of 15; training uses sizes between 10 and 15, and testing is on size 19." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Training from scratch seems to be better because it allows the agent to explore anew." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "clear session data 49 8014520320\n", "0 rewards (0.70775816993464047, -0.58743632131773105)\n", "clear session data 49 8014782464\n", "1 rewards (0.84480729166666668, -0.47554287327577194)\n", "clear session data 49 8014856192\n", "2 rewards (0.85838815789473677, -0.57429482910252139)\n", "clear session data 49 8014856192\n", "3 rewards (0.94659763071895431, -0.46225049398009899)\n", "clear session data 49 8014856192\n", "4 rewards (0.88314413484692122, -0.46399090320433678)\n", "clear session data 49 8014856192\n", "5 rewards (0.94693850511695909, -0.42810030108906955)\n", "clear session data 49 8014856192\n", "6 rewards (0.92999999999999994, -0.52071005917159763)\n", "clear session data 49 8014856192\n", "7 rewards (0.90101231325863673, 
-0.43866298300845707)\n", "clear session data 49 8014856192\n", "8 rewards (0.98708639705882351, -0.6272189349112427)\n", "clear session data 49 8496320512\n", "9 rewards (0.9579791666666666, -0.32668934432246671)\n", "clear session data 49 8496418816\n", "0 rewards (0.47891466346153849, -0.61015073248711582)\n", "clear session data 49 8496422912\n", "1 rewards (0.56557783321662003, -0.69832100591715973)\n", "clear session data 49 8496422912\n", "2 rewards (0.75977941176470587, -0.73964497041420119)\n", "clear session data 49 8496422912\n", "3 rewards (0.79722314836865915, -0.64520679558348071)\n", "clear session data 49 8496685056\n", "4 rewards (0.91668969298245617, -0.47439434882582948)\n", "clear session data 49 8496685056\n", "5 rewards (0.8740410861713106, -0.69848901098901106)\n", "clear session data 49 8496685056\n", "6 rewards (0.92713304631062954, -0.72200786921246873)\n", "clear session data 49 8496685056\n", "7 rewards (0.92941666666666667, -0.71597633136094674)\n", "clear session data 49 8496685056\n", "8 rewards (0.92747395833333335, -0.71005917159763321)\n", "clear session data 49 8496685056\n", "9 rewards (0.92909926470588244, -0.70433665524758649)\n" ] } ], "source": [ "for iters, noise in enumerate(3 * [0.0]):\n", " for trial in [83]: \n", " Pretest = PretrainTest(holes = 0, weight_write = 'weights_cpu/rnn_1515tanh512_checkpoint{}'.format(trial))\n", " weight_read = Pretest.weight\n", " weight_write = 'weights2/rnn_1515tanh512_checkpoint{}_{}'.format(trial, iters)\n", " rewards = Pretest.qlearn(weight_read, weight_write, iterations = 10, noise = noise, size_train = np.arange(10, 51, 20), size_test=[10, 50])\n", " np.save('Rewards_l_{}_{}.npy'.format(iters, trial), rewards)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "clear session data 49 6494257152\n", "0 rewards (0.65636478758169936, -0.60744100493960507)\n", "clear session data 49 6494359552\n", "1 rewards (0.78438996388028892, -0.68233885533788063)\n", "clear session data 49 8837349376\n", "2 rewards (0.86496527777777776, -0.62422073352750762)\n", "clear session data 49 8837775360\n", "3 rewards (0.93840277777777781, -0.67579832437450782)\n", "clear session data 49 8837775360\n", "4 rewards (0.92827083333333338, -0.63412211313180222)\n", "clear session data 49 8837513216\n", "5 rewards (0.80653618421052631, -0.59335437887788767)\n", "clear session data 49 8837775360\n", "6 rewards (0.95729961622807025, -0.58738728005862639)\n", "clear session data 49 8837775360\n", "7 rewards (0.98776315789473679, -0.63512413957125502)\n", "clear session data 49 8837775360\n", "8 rewards (0.89785197368421055, -0.66437639333376675)\n", "clear session data 49 8837775360\n", "9 rewards (0.90398326775885796, -0.65900525222640605)\n", "clear session data 49 8837775360\n", "0 rewards (0.85949107142857151, -0.64164788252517924)\n", "clear session data 49 8837775360\n", "1 rewards (0.91624908088235291, -0.64006851237151419)\n", "clear session data 49 8837775360\n", "2 rewards (0.8892916666666667, -0.66272189349112431)\n", "clear session data 49 8837775360\n", "3 rewards (0.89720680147058829, -0.66272189349112431)\n", "clear session data 49 8837775360\n", "4 rewards (0.99577322146807445, -0.72781065088757391)\n", "clear session data 49 8837775360\n", "5 rewards (0.89789522058823534, -0.63905325443786976)\n", "clear session data 49 8837775360\n", "6 rewards (0.80626096491228072, -0.74556213017751483)\n", "clear session data 49 8837775360\n", "7 rewards 
(0.90360135432378064, -0.66863905325443795)\n", "clear session data 49 8837775360\n", "8 rewards (0.95651219040247681, -0.81065088757396442)\n", "clear session data 49 8837775360\n", "9 rewards (0.94561805555555556, -0.70414201183431957)\n", "clear session data 49 8837775360\n", "0 rewards (0.90650483911513313, -0.75153158807004949)\n" ] } ], "source": [ "for iters, noise in enumerate(3 * [0.0]):\n", " for trial in [300]: \n", " Pretest = PretrainTest(holes = 0, weight_write = 'weights_cpu/rnn_1515tanh512_checkpoint{}'.format(trial))\n", " weight_read = Pretest.weight\n", " weight_write = 'weights2/rnn_1515tanh512_checkpoint{}_{}'.format(trial, iters)\n", " rewards = Pretest.qlearn(weight_read, weight_write, iterations = 10, noise = noise, size_train = np.arange(10, 51, 20), size_test=[10, 50])\n", " np.save('Rewards_l_{}_{}.npy'.format(iters, trial), rewards)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "clear session data 49 3544981504\n", "0 rewards (0.67881578947368426, -0.42000000000000004)\n", "clear session data 49 3545014272\n", "1 rewards (0.85294534206226347, -0.42000000000000004)\n", "clear session data 49 3545014272\n", "2 rewards (0.8577552083333333, -0.35113839285714288)\n", "clear session data 49 4291399680\n", "3 rewards (0.98908333333333331, -0.062053146376970569)\n", "clear session data 49 4291399680\n", "4 rewards (0.97919117647058829, 0.020718528033158136)\n", "clear session data 49 4291399680\n", "5 rewards (0.98895833333333338, -0.22017156862745096)\n", "clear session data 49 4291399680\n", "6 rewards (0.94848684210526313, -0.30016022727272729)\n", "clear session data 49 4291399680\n", "7 rewards (0.93812499999999999, -0.17090342925963492)\n", "clear session data 49 4291399680\n", "8 rewards (0.93840525191183088, -0.37)\n", "clear session data 49 4291399680\n", "9 rewards (0.88824479166666659, -0.38)\n", "clear session data 49 4440875008\n", "0 rewards (0.87243165204678363, -0.32212756467439785)\n", "clear session data 49 4440875008\n", "1 rewards (0.99635661764705885, -0.13)\n", "clear session data 49 4702019584\n", "2 rewards (0.99948214285714287, -0.1654194093252325)\n", "clear session data 49 4702019584\n", "3 rewards (0.98895833333333338, -0.28148901845675911)\n", "clear session data 49 4702019584\n", "4 rewards (0.97999999999999998, -0.15492155480713213)\n", "clear session data 49 4702019584\n", "5 rewards (0.99948529411764708, -0.19046610169491526)\n", "clear session data 49 4702019584\n", "6 rewards (0.98725548245614037, 0.097114174020424021)\n", "clear session data 49 4702019584\n", "7 rewards (0.93890196078431365, -0.070250000000000007)\n", "clear session data 49 4702019584\n", "8 rewards (0.97948529411764707, 0.054120075727722178)\n", "clear session data 49 4702019584\n", "9 rewards (0.99776838235294119, -0.061593084802830551)\n", "clear session data 49 4702019584\n", "0 rewards (0.47702027740920927, -0.39000000000000001)\n", "clear session data 49 4702019584\n", "1 rewards (0.96579160216718263, -0.40020833333333328)\n", "clear session data 49 4702019584\n", "2 rewards (0.96895833333333337, -0.43283007948986718)\n", "clear session data 49 4702019584\n", "3 rewards (0.98841094771241833, -0.11334851740262515)\n", "clear session data 49 4702019584\n", "4 rewards (0.99738075657894742, -0.080546568627450979)\n", "clear session data 49 4702019584\n", "5 rewards (0.99869618055555553, -0.16632517350599613)\n", "clear session data 49 4702019584\n", "6 rewards 
(0.97603737745098051, -0.013745689320032589)\n", "clear session data 49 4702019584\n", "7 rewards (1.0, 0.028728448275862063)\n", "clear session data 49 4702019584\n", "8 rewards (1.0, -0.18198892537784345)\n", "clear session data 49 4702019584\n", "9 rewards (0.9784714052287582, -0.16189120926243566)\n" ] } ], "source": [ "for iters, noise in enumerate(3 * [0.0]):\n", " for trial in [300]: \n", " Pretest = PretrainTest(holes = 0, weight_write = 'weights_cpu/rnn_1515tanh512_checkpoint{}'.format(trial))\n", " weight_read = Pretest.weight\n", " weight_write = 'weights2/rnn_1515tanh512_checkpoint{}_{}'.format(trial, iters+5)\n", " rewards = Pretest.qlearn(weight_read, weight_write, iterations = 10, noise = noise, size_train = np.arange(10, 31, 20), size_test=[10, 30])\n", " np.save('Rewards_l_{}_{}.npy'.format(iters, trial), rewards)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Test" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.2" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "759px", "left": "0px", "right": "1228px", "top": "67px", "width": "212px" }, "toc_section_display": "block", "toc_window_display": true }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }