Rich Representation Learning 

What is a Rich Representation?

Creating features, either hand-crafted or learned from data, has long been treated as a crucial and mysterious step toward intelligence. The purpose of the created feature set (the representation) is to help solve both seen (in-distribution) and unseen (out-of-distribution) problems easily. Of these two, the out-of-distribution (OOD) generalization problem is, of course, the more charming, useful, and difficult one. Since the test distributions are unseen while the representation is being created, one question is: what is the most desirable property of a representation for OOD generalization? One answer is a Rich Representation, a representation with rich and diverse features that are ready to solve problems in a simple way (e.g., linearly).


In recent years, various out-of-distribution (OOD) generalization methods have been proposed to avoid spurious correlations and find the invariant correlations by introducing additional penalty terms into the training objective.

Unfortunately, [1] shows that none of these methods works without a proper representation initialization, that is, a representation with rich and diverse features (a Rich Representation). [1] attributes this to an "optimization-generalization dilemma" in OOD generalization: on one hand, penalty terms strong enough to enforce the invariance constraints are too hard to optimize reliably; on the other hand, penalty terms weak enough to optimize reliably fail to enforce the invariance constraints.
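To make the structure of these penalty-based objectives concrete, here is a minimal sketch of an ERM loss combined with an IRMv1-style invariance penalty computed per training environment. The environment batches, the `model`, and the penalty weight `lam` are illustrative assumptions, not the specific methods evaluated in [1].

```python
import torch
import torch.nn.functional as F

def irm_penalty(logits, y):
    # IRMv1-style penalty: squared gradient of the risk with respect to a
    # dummy classifier scale fixed at 1.0.
    scale = torch.ones(1, device=logits.device, requires_grad=True)
    loss = F.cross_entropy(logits * scale, y)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()

def ood_objective(model, env_batches, lam=1.0):
    # env_batches: list of (x, y) batches, one per training environment (assumed).
    erm, penalty = 0.0, 0.0
    for x, y in env_batches:
        logits = model(x)
        erm = erm + F.cross_entropy(logits, y)
        penalty = penalty + irm_penalty(logits, y)
    # The weight `lam` embodies the dilemma above: large values make the
    # optimization unreliable, small values fail to enforce invariance.
    return (erm + lam * penalty) / len(env_batches)
```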

How to construct a Rich Representation? 

How can we create such a rich representation from training data so that it potentially helps on many unseen, shifted distributions? One principle is to:

 "seek features that are useful for at least some subset of the training examples"[1]. 

[1] first performs multiple training episodes using adversarially reweighted training data, in a manner that ensures these training episodes yield substantially different representations. Then a multi-head distillation process constructs a single feature extractor that combines all these distinct representations into a single representation vector of equivalent size.
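A minimal sketch of the distillation step is shown below, assuming the k distinct teacher representations have already been discovered. The shared `backbone`, the per-head linear projections, and the plain L2 distillation loss are illustrative assumptions rather than the exact recipe of [1].

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadStudent(nn.Module):
    """One feature extractor distilled to mimic k distinct teacher representations."""
    def __init__(self, backbone, feat_dim, teacher_dims):
        super().__init__()
        self.backbone = backbone  # shared student feature extractor (assumed)
        # one linear head per teacher representation
        self.heads = nn.ModuleList([nn.Linear(feat_dim, d) for d in teacher_dims])

    def forward(self, x):
        z = self.backbone(x)  # the single combined representation
        return z, [head(z) for head in self.heads]

def distill_step(student, teachers, x, optimizer):
    # teachers: frozen feature extractors from the earlier training episodes
    _, head_outs = student(x)
    with torch.no_grad():
        targets = [t(x) for t in teachers]
    # each head regresses onto one teacher's representation (L2 loss assumed)
    loss = sum(F.mse_loss(h, t) for h, t in zip(head_outs, targets))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```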

There are also other, simpler ways to construct rich representations, such as simply concatenating multiple independently trained representations [2].

Why Doesn't the Dominant Representation Learning Approach Work?

The phenomenon above implies that the dominant representation learning approach (learning a representation as a side effect of optimizing an expected cost on a single training distribution) refuses to learn rich and diverse signals. This lack of representation richness is potentially harmful for OOD (so-called shifted) tasks. [2] investigates the reason behind it and gives a simple yet general approach for constructing rich representations in various areas, such as transfer learning, self-supervised learning, few-shot learning, and OOD robustness.

[2] trains the same architecture on the same dataset with the same hyperparameters multiple times, varying only the random seed, to discover different representations, and finally concatenates these representations to construct a rich one. The rich representation is far better on OOD tasks than the representation learned by the dominant approach, regardless of whether the comparison is aligned by number of parameters or by representation dimension.

One supervised transfer learning example is shown below. Ten resnet50 models are pre-trained on ImageNet(1k) with different random seeds. A resnet50_wide2 and a resnet50_wide4, with 2x and 4x channels respectively, are also pre-trained on ImageNet for comparison. [2] concatenates k pre-trained resnet50 representations ("catk") and trains a linear classifier on top (linear probing). On the in-distribution ImageNet task, catk is slightly better than the wide networks (left plot). On OOD tasks, however, the constructed rich representation is far better (top row). This confirms that the dominant representation learning approach refuses to learn rich and diverse signals.
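A rough sketch of the catk construction with frozen torchvision resnet50 backbones and a linear probe is given below; how the k checkpoints in `state_dicts` are produced and loaded is assumed, and the details differ from the exact setup of [2].

```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

class CatK(nn.Module):
    """Concatenate the penultimate features of k independently trained resnet50s."""
    def __init__(self, state_dicts, num_classes):
        super().__init__()
        backbones = []
        for sd in state_dicts:           # k checkpoints, each trained with a different seed
            net = resnet50()
            net.load_state_dict(sd)
            net.fc = nn.Identity()       # drop the classifier, keep the 2048-d features
            for p in net.parameters():   # freeze the backbone: linear probing only
                p.requires_grad_(False)
            backbones.append(net.eval())
        self.backbones = nn.ModuleList(backbones)
        self.linear = nn.Linear(2048 * len(backbones), num_classes)

    def forward(self, x):
        with torch.no_grad():
            feats = torch.cat([b(x) for b in self.backbones], dim=1)  # the "catk" representation
        return self.linear(feats)
```

The blue curve discussed below corresponds to the same construction, except that each backbone is first fine-tuned on the target task before being frozen and concatenated.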

What if the dominant approach starts from an already constructed rich representation? The answer is negative again. The gray curve (bottom row) uses the ImageNet pre-trained representation (catk) as initialization and fine-tunes it on the OOD tasks, but it still performs poorly, because the dominant approach destroys the representation richness. On the other hand, the blue curve (bottom row) fine-tunes each resnet50 on the target task before concatenating the representations and then trains a linear classifier on top. This time it performs surprisingly well. (Please check [2] for more examples in other areas.)

As for why the dominant approach refuses to learn rich and diverse signals, [2] concludes:

 "Once the representation contains enough information to fulfill the training task, the optimization process has no reason to create and accumulate features that no longer help improve the training objective but might yet become useful when the data distribution changes."

Implicit Rich Representation Learning

Rich representation learning is not limited to the explicit representation manipulation above. Under mild conditions, averaging the weights of (nonlinear) models has been shown to be beneficial for OOD generalization. [3] treats weight averaging as an implicit way to enrich the representation, and it works well on various OOD tasks.
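A minimal sketch of plain uniform weight averaging over several checkpoints of the same architecture is shown below; the uniform weighting and the handling of non-float buffers are simplifying assumptions, not the exact procedure of [3].

```python
import copy
import torch

def average_weights(models):
    """Uniformly average the parameters of models sharing one architecture."""
    avg = copy.deepcopy(models[0])
    avg_sd = avg.state_dict()
    all_sds = [m.state_dict() for m in models]
    for key in avg_sd:
        if avg_sd[key].is_floating_point():
            # mean of this parameter/buffer across all checkpoints
            avg_sd[key] = torch.stack([sd[key] for sd in all_sds], dim=0).mean(dim=0)
        # integer buffers (e.g. BatchNorm's num_batches_tracked) are kept from the first model
    avg.load_state_dict(avg_sd)
    return avg
```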

Conclusion

Researchers have tried various approaches to create "System 2 AI" for decades. Some believe in "big models" and "large data". Rich Representation Learning reflects our (Jan 2023) position: we still cannot fully utilize the power of big models, large data, and the invariance principle, and thus we are still far from System 2 AI. Rich Representation Learning could be a bridge from System 1 AI to System 2 AI, a toolkit to fully utilize the power of big models, large data, and the invariance principle.

Jan 2023

References