Learning a World Model With Multitimescale Memory Augmentation

Research output: Contribution to journal › Article › peer-review

4 Scopus citations

Abstract

Model-based reinforcement learning (RL) is regarded as a promising approach to tackling the challenges that hinder model-free RL. The success of model-based RL hinges critically on the quality of the learned dynamics models. However, for many real-world tasks involving high-dimensional state spaces, current dynamics prediction models perform poorly in long-term prediction. To address this, we propose a novel two-branch neural network architecture with multitimescale memory augmentation that handles long-term and short-term memory differently. Specifically, following previous works, we introduce a recurrent neural network to encode historical observation sequences into a latent space, characterizing the agent's long-term memory. Unlike previous works, we treat the most recent observations as the agent's short-term memory and use them to directly reconstruct the next frame, avoiding compounding error. This is achieved by a self-supervised optical-flow prediction structure that models the action-conditional feature transformation at the pixel level. The reconstructed observation is finally augmented with the long-term memory to ensure semantic consistency. Experimental results show that our approach generates visually realistic long-term predictions in DeepMind maze navigation games and outperforms prevalent state-of-the-art methods in prediction accuracy by a large margin. Furthermore, we evaluate the usefulness of our world model by using the predicted frames to drive an imagination-augmented exploration strategy that improves a model-free RL controller.
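The two-branch design described above can be sketched as follows. This is a minimal illustration, not the authors' implementation: all module names, layer sizes, and the specific fusion scheme (`MultiTimescaleWorldModel`, `flow_head`, `fuse`, the GRU latent dimension) are assumptions chosen for clarity. A recurrent branch summarizes the observation history (long-term memory), while a flow branch predicts an action-conditional optical-flow field from the most recent frame and warps it to reconstruct the next frame, which is then fused with the long-term latent.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTimescaleWorldModel(nn.Module):
    """Illustrative two-branch next-frame predictor (hypothetical sketch):
    a recurrent long-term branch plus an optical-flow short-term branch."""

    def __init__(self, obs_channels=3, action_dim=4, latent_dim=64):
        super().__init__()
        # Long-term branch: encode each observation, then update a GRU latent.
        self.encoder = nn.Sequential(
            nn.Conv2d(obs_channels, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.rnn = nn.GRUCell(32 + action_dim, latent_dim)
        # Short-term branch: predict a dense 2-D flow field from the most
        # recent observation and the action, then warp that frame with it.
        self.flow_head = nn.Sequential(
            nn.Conv2d(obs_channels + action_dim, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 2, 3, padding=1),
        )
        # Fuse the warped frame with the long-term latent (semantic consistency).
        self.fuse = nn.Conv2d(obs_channels + latent_dim, obs_channels, 1)

    def warp(self, frame, flow):
        # Displace a normalized sampling grid by the predicted flow
        # (flow is expressed directly in [-1, 1] grid coordinates here).
        b, _, h, w = frame.shape
        ys, xs = torch.meshgrid(
            torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij")
        base = torch.stack((xs, ys), dim=-1).expand(b, h, w, 2)
        grid = base + flow.permute(0, 2, 3, 1)
        return F.grid_sample(frame, grid, align_corners=True)

    def forward(self, obs, action, hidden):
        b, _, h, w = obs.shape
        # Long-term memory update from the current observation and action.
        feat = self.encoder(obs)
        hidden = self.rnn(torch.cat([feat, action], dim=1), hidden)
        # Short-term path: action-conditional flow prediction and warping.
        act_map = action.view(b, -1, 1, 1).expand(b, action.shape[1], h, w)
        flow = self.flow_head(torch.cat([obs, act_map], dim=1))
        warped = self.warp(obs, flow)
        # Augment the reconstructed frame with the long-term latent.
        mem = hidden.view(b, -1, 1, 1).expand(b, hidden.shape[1], h, w)
        next_obs = self.fuse(torch.cat([warped, mem], dim=1))
        return next_obs, hidden
```

In the paper, the flow branch is trained self-supervised and multi-step rollouts feed predicted frames back in; this sketch only shows a single prediction step.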

Original language: English
Pages (from-to): 8493-8502
Number of pages: 10
Journal: IEEE Transactions on Neural Networks and Learning Systems
Volume: 34
Issue number: 11
DOIs
State: Published - 1 Nov 2023

Keywords

  • Model-based exploration
  • multitimescale memory augmentation
  • reinforcement learning (RL)
  • world model

