In many reinforcement learning tasks, the agent has to learn to interact with many objects of different types and to generalize to unseen combinations and numbers of objects. Often a task is a composition of previously learned tasks (e.g. block stacking). These are examples of compositional generalization, in which we compose object-centric representations to solve complex tasks. Recent works have shown the benefits of object-factored representations and hierarchical abstractions for improving sample efficiency in these settings. However, these methods do not fully exploit the benefits of factorization in terms of object attributes. In this paper, we address this opportunity and introduce the Dynamic Attribute FacTored RL (DAFT-RL) framework. In DAFT-RL, we leverage object-centric representation learning to extract objects from visual inputs, classify them into classes, and infer their latent parameters. For each class of object, we learn a class template graph that describes how the dynamics and reward of an object of this class factorize according to its attributes. We also learn an interaction pattern graph that describes how objects of different classes interact with each other at the attribute level. Through these graphs and a dynamic interaction graph that models the interactions between objects, we can learn a policy that can then be directly applied in a new environment simply by estimating the interactions and latent parameters. We evaluate DAFT-RL on three benchmark datasets and show that our framework outperforms the state-of-the-art in generalizing across unseen objects with varying attributes and latent parameters, as well as in the composition of previously learned tasks.
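To make the attribute-level factorization concrete, below is a minimal, hypothetical sketch of one attribute-factored transition for a single object. The names (`template`, `interaction`, `step`), the linear dynamics, and the binary adjacency masks are illustrative assumptions, not the paper's implementation; in DAFT-RL the graphs are learned from data, and the dynamic interaction graph switches the between-object edges on or off per timestep.

```python
import numpy as np

rng = np.random.default_rng(0)
N_ATTR = 3  # attributes per object (e.g., position, velocity, color)

# Class template graph (assumed binary mask): template[i, j] = 1 if
# attribute j at time t influences attribute i of the same object at t+1.
template = np.array([[1, 0, 0],
                     [1, 1, 0],
                     [0, 0, 1]])

# Interaction pattern graph between two classes: interaction[i, j] = 1 if
# attribute j of the other object influences attribute i of this object.
interaction = np.array([[0, 1, 0],
                        [0, 0, 0],
                        [0, 0, 0]])

def step(own_attrs, other_attrs, interacting, w_self, w_inter):
    """One attribute-factored transition for a single object.

    The dynamic interaction graph is reduced here to the boolean
    `interacting`: when it is False, the between-object edges are
    switched off and the other object's attributes are ignored.
    """
    nxt = (template * w_self) @ own_attrs
    if interacting:
        nxt = nxt + (interaction * w_inter) @ other_attrs
    return nxt

w_self = rng.normal(size=(N_ATTR, N_ATTR))   # per-class dynamics weights
w_inter = rng.normal(size=(N_ATTR, N_ATTR))  # per-class-pair weights
x, y = rng.normal(size=N_ATTR), rng.normal(size=N_ATTR)
print(step(x, y, interacting=True, w_self=w_self, w_inter=w_inter))
```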
Figure 1 The graphical representation of DAFT-MDP. The colors denote the attributes of an object or a class; the red dashed lines denote edges that can be switched on or off at different timesteps.
Table 1 Average success rate over 3 random seeds for Push & Switch compositional generalization in terms of combinations of skills (S), changing numbers of objects (O), and changing latent parameters (L) with respect to training. The numbers in bold highlight the top-performing method.
Table 2 Average success rate over 3 random seeds for Spriteworld with unseen object numbers and unseen color and shape combinations. The numbers in bold highlight the top-performing method.
Table 3 Average success rate over 3 random seeds for Block-stacking with unseen object numbers and mass combinations. The numbers in bold highlight the top-performing method.
Figure 2 A. The smoothed learning curves for 2-Push + 2-Switch (L+S) with different friction coefficients for each object (for clarity, we show only the top three methods in terms of success rate); B. The smoothed learning curves for the object comparison task in Spriteworld with unseen object numbers and combinations of colors and shapes (for clarity, we show only the top three methods in terms of success rate); C. Success rate versus the number of blocks in the stacking task, where each block has a distinct mass.
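The caption does not specify how the learning curves are smoothed; a common choice is an exponential moving average. The sketch below is an assumption for illustration only, with `ema_smooth` and the `alpha` value being hypothetical names/choices rather than the paper's procedure.

```python
import numpy as np

def ema_smooth(values, alpha=0.1):
    # Exponential moving average; smaller alpha gives a smoother curve.
    out = np.empty(len(values))
    running = values[0]
    for i, v in enumerate(values):
        running = alpha * v + (1 - alpha) * running
        out[i] = running
    return out

# Example: smooth a noisy success-rate curve over training steps.
rng = np.random.default_rng(0)
raw = np.clip(np.linspace(0, 1, 1000) + rng.normal(0, 0.1, 1000), 0, 1)
print(ema_smooth(raw)[-5:])
```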
Figure 3 3-Push+3-Switch task.
Figure 4 Quality of the learned graphs w.r.t. the number of samples for 3-Push + 3-Switch (L+O+S). We plot the success rate on the RL task, the R² coefficient between the learned representation and the true latent parameters, and the normalized Structural Hamming Distance (nSHD) between the learned graph and the true graph for different numbers of training samples (measured as a ratio of the number used in the main paper), for (a) data collected by a random policy, and (b) data collected by a pre-trained policy.
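For readers who want to reproduce the two evaluation metrics in Figure 4, here is a minimal sketch. The `nshd` helper and its entrywise mismatch convention (a reversed edge counts as two mismatches) are assumptions; the paper may use a different SHD convention or normalization. The R² coefficient is computed here with scikit-learn's `r2_score`.

```python
import numpy as np
from sklearn.metrics import r2_score

def nshd(learned, true):
    """Normalized Structural Hamming Distance between two directed graphs
    given as binary adjacency matrices, normalized by the number of
    entries, so 0 means identical and 1 means fully mismatched."""
    learned = np.asarray(learned, dtype=bool)
    true = np.asarray(true, dtype=bool)
    assert learned.shape == true.shape
    return float(np.sum(learned != true)) / learned.size

# Toy example: the learned graph has one spurious edge.
true_g = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
learned_g = np.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]])
print(nshd(learned_g, true_g))  # 1 mismatch / 9 entries ~= 0.11

# R^2 between inferred and true latent parameters (e.g., friction, mass).
true_latents = np.array([0.2, 0.5, 0.9])
inferred = np.array([0.25, 0.45, 0.95])
print(r2_score(true_latents, inferred))
```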
@inproceedings{feng2023learning,
  title={Learning Dynamic Attribute-factored World Models for Efficient Multi-object Reinforcement Learning},
  author={Fan Feng and Sara Magliacane},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=bsNslV3Ahe}
}