Source code for the paper Deep Q-Networks for Accelerating the Training of Deep Neural Networks
Reproduce our results on MNIST
Dependencies
We use Lua/Torch. The DQN component is adapted, with modifications, from DeepMind's Atari DQN code.
You may need to run install_dependencies.sh first.
Tuning learning rates on MNIST
    cd mnist_lr/
    cd mnist
    th train-on-mnist.lua   # get the regression filter, saved in ../save/
    ./run_gpu               # start tuning the learning rate using the DQN

    # To get the test curve, run the following commands
    cd mnist_lr/dqn/logs
    python paint_lr_episode.py
    python paint_lr_vs.py
Tuning mini-batch selection on MNIST
    cd mnist_minibatch
    cd mnist
    th train-on-mnist.lua   # get the regression filter, saved in ../save/
    ./run_gpu               # start selecting mini-batches using the DQN

    # To get the test curve, run the following commands
    cd mnist_minibatch/dqn/logs
    python paint_mini_episode.py
    python paint_mini_vs.py
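The two command sequences above have the same shape, so they can be wrapped in one small helper. The following is only a sketch, not a script shipped with this repository: the names RUN, step and run_experiment are hypothetical, and it assumes you start from the repository root and that the logs path is relative to that root. By default it only prints the commands; set RUN=1 to execute them.

```shell
#!/usr/bin/env bash
# Sketch only: prints (or, with RUN=1, executes) the command sequence listed
# above for one experiment. RUN, step and run_experiment are hypothetical
# names; only the commands themselves come from this README.
set -eu

RUN="${RUN:-0}"   # 0 = print the commands only, 1 = actually execute them

step() {
    # Echo the command, then run it only when RUN=1.
    echo "+ $*"
    if [ "$RUN" = "1" ]; then
        "$@"
    fi
}

run_experiment() {
    local exp_dir="$1"      # mnist_lr or mnist_minibatch
    local plot_prefix="$2"  # lr or mini
    local root="$PWD"       # assumed to be the repository root

    step cd "$exp_dir/mnist"
    step th train-on-mnist.lua       # regression filter, saved in ../save/
    step ./run_gpu                   # start the DQN run, as listed above

    # Plot the test curves from the DQN logs.
    step cd "$root/$exp_dir/dqn/logs"
    step python "paint_${plot_prefix}_episode.py"
    step python "paint_${plot_prefix}_vs.py"
}
```

After sourcing the file, run_experiment mnist_lr lr covers the learning-rate experiment and run_experiment mnist_minibatch mini covers mini-batch selection. Because of set -e, a failing step stops the sequence when RUN=1.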
Different Settings
- GPU device can be set in run_gpu, where gpu=0.
- Learning rate can be set in /ataricifar/dqn/cnnGameEnv.lua, in the step function.
- When to stop doing regression is set in /ataricifar/dqn/cnnGameEnv.lua, at line 250.
TODO
- Experiments on CIFAR-10
- Transfer learning: subset of CIFAR-10 to full CIFAR-10
- Visualization of the actions taken by the DQN. For example, show which categories have been used at every iteration.
Citation
@article{dqn-accelerate-dnn,
title={Deep Q-Networks for Accelerating the Training of Deep Neural Networks},
author={Fu, Jie and Lin, Zichuan and Liu, Miao and Leonard, Nicholas and Feng, Jiashi and Chua, Tat-Seng},
journal={arXiv preprint arXiv:1606.01467},
year={2016}
}
Contact
If you have any problems or suggestions, please contact me: jie.fu A~_~T u.nus.education