BO-QSA
This repository contains the official implementation of the ICLR 2023 paper:
Improving Object-centric Learning With Query Optimization
Baoxiong Jia*, Yu Liu*, Siyuan Huang
Environment Setup
We provide all environment configurations in requirements.txt. To install all packages, you can create a conda environment and install the packages as follows:
conda create -n BO-QSA python=3.8
conda activate BO-QSA
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install -r requirements.txt
In our experiments, we used NVIDIA CUDA 11.3 on Ubuntu 20.04. Similar CUDA versions should also work, provided you install the matching versions of torch and torchvision.
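As a quick sanity check after installation, the installed package versions can be compared against the ones pinned in the install command. The following is a minimal sketch, not part of the repository: the helper name `check_versions` is hypothetical, and it assumes the packages expose standard distribution metadata that `importlib.metadata` can read.

```python
from importlib import metadata

# Versions pinned in the conda install command of this README.
EXPECTED = {"torch": "1.10.1", "torchvision": "0.11.2", "torchaudio": "0.10.1"}


def check_versions(expected=EXPECTED):
    """Return {package: (expected, installed, matches)}.

    `installed` is None when the package metadata cannot be found.
    """
    report = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None
        # Ignore a local build tag such as "+cu113" when comparing.
        ok = have is not None and have.split("+")[0] == want
        report[name] = (want, have, ok)
    return report


if __name__ == "__main__":
    for name, (want, have, ok) in check_versions().items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: expected {want}, found {have} ({status})")
```

A mismatch is not necessarily an error if you deliberately chose a different CUDA version; it only flags that the environment differs from the one used in our experiments.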
Dataset
1. ShapeStacks, ObjectsRoom, CLEVRTex, Flowers
Download the ShapeStacks, ObjectsRoom, CLEVRTex, and Flowers datasets with:
chmod +x scripts/downloads_data.sh
./scripts/downloads_data.sh
For the ObjectsRoom dataset, you need to run objectsroom_process.py to convert the TFRecords dataset to PNG images.
Remember to change DATA_ROOT in downloads_data.sh and objectsroom_process.py to your own paths.
2. PTR, Birds, Dogs, Cars
Download the PTR dataset following the instructions at http://ptr.csail.mit.edu. Download the CUB-Birds, Stanford Dogs, and Cars datasets from here, provided by the authors of DRC. We use birds.zip, cars.tar, and dogs.zip; uncompress them after downloading.
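The uncompressing step can be done with any archive tool. As one option, here is a minimal Python sketch using only the standard library. The helper `extract_archive` and the `/path/to/data` directory are hypothetical placeholders, and the choice to extract each archive into a folder named after it is an assumption, not the layout the repository requires.

```python
import tarfile
import zipfile
from pathlib import Path


def extract_archive(archive, dest):
    """Extract a .zip or .tar archive into `dest` and return `dest` as a Path."""
    archive, dest = Path(archive), Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    if zipfile.is_zipfile(archive):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(dest)
    elif tarfile.is_tarfile(archive):
        with tarfile.open(archive) as tf:
            tf.extractall(dest)
    else:
        raise ValueError(f"Unsupported archive type: {archive}")
    return dest


if __name__ == "__main__":
    data_root = Path("/path/to/data")  # hypothetical; use your own download folder
    for name in ("birds.zip", "cars.tar", "dogs.zip"):
        archive = data_root / name
        if archive.exists():
            # Assumed target: a folder named after the archive, e.g. data_root/birds
            extract_archive(archive, data_root / archive.stem)
```

Only extract archives from sources you trust, since archive members can contain unexpected paths.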
3. YCB, ScanNet, COCO
The YCB, ScanNet, and COCO datasets are available from here, provided by the authors of UnsupObjSeg.
4. Data preparation
Please organize the data as described here before running experiments.
Training
To train the model from scratch, we provide the following model files:
train_trans_dec.py: transformer-based model
train_mixture_dec.py: mixture-based model
train_base_sa.py: original slot-attention
We provide training scripts under scripts/train. Use the following commands, replacing the .sh file with the one for the model you want to experiment with. Taking the transformer-based decoder experiment on Birds as an example, you can run the following:
$ cd scripts
$ cd train
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh
Remember to change the paths in path.json to your own paths.
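Before launching a run, it can help to confirm that every path you entered actually exists. The sketch below is not part of the repository and makes no assumption about the keys in path.json: it simply walks whatever JSON the file contains and, as a loudly stated assumption, treats every string value as a filesystem path. The helper name `missing_paths` and the file location are hypothetical.

```python
import json
from pathlib import Path


def missing_paths(config):
    """Recursively collect string values that do not exist on disk.

    Assumption: every string value in the parsed JSON is a filesystem path.
    """
    missing = []
    if isinstance(config, dict):
        for value in config.values():
            missing.extend(missing_paths(value))
    elif isinstance(config, list):
        for value in config:
            missing.extend(missing_paths(value))
    elif isinstance(config, str) and not Path(config).exists():
        missing.append(config)
    return missing


if __name__ == "__main__":
    config_file = Path("path.json")  # assumed location; adjust to where the file lives
    if config_file.exists():
        for p in missing_paths(json.loads(config_file.read_text())):
            print(f"missing: {p}")
```

If path.json also holds non-path strings, those would be reported as missing and can be ignored.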
Reloading checkpoints & Evaluation
To reload checkpoints and only run inference, we provide the following model files:
test_trans_dec.py: transformer-based model
test_mixture_dec.py: mixture-based model
test_base_sa.py: original slot-attention
Similarly, we provide testing scripts under scripts/test. We provide transformer-based models for real-world datasets (Birds, Dogs, Cars, Flowers, YCB, ScanNet, COCO) and mixture-based models for synthetic datasets (ShapeStacks, ObjectsRoom, CLEVRTex, PTR). We provide all checkpoints here. Use the following commands, replacing the .sh file with the one for the model you want to experiment with:
$ cd scripts
$ cd test
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh
Citation
If you find our paper and/or code helpful, please consider citing:
@inproceedings{jia2023improving,
title={Improving Object-centric Learning with Query Optimization},
author={Jia, Baoxiong and Liu, Yu and Huang, Siyuan},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023}
}
Acknowledgement
This code relies heavily on resources from SLATE, SlotAttention, GENESISv2, DRC, Multi-Object Datasets, and shapestacks. We thank the authors for open-sourcing their awesome projects.
