Generalization on Unseen Domains via Inference-time Label-Preserving Target Projections



This article describes our recent CVPR 2021 Oral paper on Domain Generalization. Domain shift refers to a significant divergence between the distributions of the training data and the test data. It causes machine learning models trained only on the training (source) data to perform poorly on the test (target) data. A naive way of handling this problem is to fine-tune the model on new data, which is often infeasible because labelled data is difficult to acquire for every new target domain. The class of Domain Adaptation (DA) methods tackles this problem by using (unlabeled) target data to minimize the domain shift; however, these methods cannot be used when unlabeled target data is unavailable.

Domain generalization (DG), on the other hand, views the problem from the following perspective: how can a model trained on a single source domain or on multiple source domains generalize to completely unseen target domains? Existing methods do so by (i) learning feature representations that are invariant to the data domains using methods such as adversarial learning, (ii) simulating the domain shift during training through meta-learning approaches, and (iii) augmenting the source dataset with synthesized data from fictitious target domains. These methods have been shown to be effective in dealing with domain shift. However, most of them use the test sample from the target distribution, which is available at inference time, only as an input to be classified. In contrast, it is a common experience that when humans encounter an unseen object, they often relate it to a similar object they have perceived before.

Motivated by this intuition, in this paper, we make the following contributions towards addressing the problem of DG:

(a) Given samples from multiple source distributions, we propose to learn a source domain invariant representation that also preserves the class labels.

(b) We propose to ‘project’ the target samples to the manifold of the source-data features before classification through an inference-time label-preserving optimization procedure over the input of a generative model (learned during training) that maps an arbitrary distribution (Normal) to the source-feature manifold.

(c) We demonstrate through extensive experimentation that our method achieves new state-of-the-art performance on standard DG tasks while also outperforming other methods in terms of robustness and data efficiency.

Fig. 1

As shown in Fig. 1, A) We begin by designing a function f to learn a label-preserving metric: the cosine similarity (‘sim’) between the features of a pair of images should be 1 when their ground-truth labels (given by the function g) match and -1 otherwise. B) f is implemented as a neural network f_\theta. During training, examples from the source domains are used to create a source manifold Z_s using the loss L_A, such that the features on the manifold are implicitly clustered by label. C) A classifier C_\psi and a generative model G_\phi are trained on the label-preserving features from the manifold Z_s, such that G_\phi learns to map a Gaussian vector ‘u’ to a point on Z_s. D) During inference, the trained network f_\theta* maps a target sample x_t to a point z_t in the label-preserving feature space. Our inference-time procedure then projects this target feature to a point z_t* on the source manifold, which is finally classified to predict its label \hat{y}_t.
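
As a rough illustration of the pairwise target described in step A, the sketch below scores a small batch of features against the +1/-1 similarity targets. The squared-error form of the loss and the names `pair_target` and `label_preserving_loss` are assumptions made for illustration; this is not the exact loss L_A from the paper, and plain Python lists stand in for the outputs of f_\theta.

```python
import math


def cosine_sim(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def pair_target(label_a, label_b):
    """Target similarity: +1 when the ground-truth labels match, -1 otherwise."""
    return 1.0 if label_a == label_b else -1.0


def label_preserving_loss(features, labels):
    """Mean squared gap between cosine similarity and the +1/-1 target.

    Averaged over all distinct pairs in the batch. The squared-error
    form is an illustrative assumption, not the paper's L_A.
    """
    total, count = 0.0, 0
    for i in range(len(features)):
        for j in range(i + 1, len(features)):
            target = pair_target(labels[i], labels[j])
            total += (cosine_sim(features[i], features[j]) - target) ** 2
            count += 1
    return total / count


# Features that are already clustered by label give a loss near zero;
# the same features with mismatched labels are penalized.
features = [[1.0, 0.0], [2.0, 0.0], [-1.0, 0.0]]
clustered = label_preserving_loss(features, [0, 0, 1])
scrambled = label_preserving_loss(features, [0, 1, 0])
```

A loss of this kind is low only when same-label features point in the same direction and different-label features point in opposite directions, regardless of which source domain each sample came from.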

Interested readers can watch the YouTube video for more clarity.

The motivation for our method comes from the following observation: DG methods that learn domain-invariant representations do so using only the source data. Classifiers trained on such representations are therefore not guaranteed to perform well on target data that lies outside the source-data manifold. Performance on the target data can be improved if, before classification, the target sample is projected onto the manifold of the source features in a way that preserves its ground-truth label. To this end, we propose a three-part procedure for domain generalization:

1. Learn a label-preserving, domain-invariant representation using source data. We first transform the data from multiple source domains into a space where they are clustered according to class labels, irrespective of domain, and build a classifier on these features.

2. Learn to generate features from the domain invariant feature manifold created from the source data by constructing a generative model on it.

3. Given a test target sample, project it onto the source-feature manifold in a label-preserving manner. This is done by solving an inference-time optimization problem over the input space of the aforementioned generative model. Finally, classify the projected target feature.
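
The third step above can be sketched on a toy problem. Everything in this example is a stand-in assumed for illustration: `toy_generator` is a fixed linear map rather than the trained G_\phi, the objective is taken to be 1 minus the cosine similarity between the generated feature and z_t, the optimizer is plain gradient descent with finite-difference gradients, and `classify` is a nearest-prototype rule rather than the trained C_\psi.

```python
import math


def cosine_sim(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def toy_generator(u):
    """Hypothetical stand-in for G_phi: maps a 2-D latent onto a plane in 3-D."""
    return [u[0], u[1], u[0] + u[1]]


def projection_loss(u, z_t, generator):
    """Assumed objective: 1 - cosine similarity between generator(u) and z_t."""
    return 1.0 - cosine_sim(generator(u), z_t)


def project_to_manifold(z_t, generator, u_init, steps=200, lr=0.1, eps=1e-4):
    """Optimize the generator input u so that generator(u) aligns with z_t.

    Uses central finite differences for the gradient to stay dependency-free;
    a real implementation would backpropagate through the generator.
    """
    u = list(u_init)
    for _ in range(steps):
        grad = []
        for k in range(len(u)):
            up, down = u[:], u[:]
            up[k] += eps
            down[k] -= eps
            diff = projection_loss(up, z_t, generator) - projection_loss(down, z_t, generator)
            grad.append(diff / (2 * eps))
        u = [uk - lr * gk for uk, gk in zip(u, grad)]
    return generator(u), projection_loss(u, z_t, generator)


def classify(z, prototypes):
    """Hypothetical classifier: pick the class prototype most similar to z."""
    return max(prototypes, key=lambda name: cosine_sim(z, prototypes[name]))


z_t = [1.0, 0.0, 0.2]   # target feature lying off the toy source manifold
u_init = [0.5, 0.5]     # fixed starting latent, chosen for repeatability (an assumption)
z_star, final_loss = project_to_manifold(z_t, toy_generator, u_init)

prototypes = {"class_a": [1.0, 0.0, 1.0], "class_b": [0.0, 1.0, 1.0]}
label = classify(z_star, prototypes)
```

Because the optimization runs over u rather than over the feature itself, the result `z_star` is always an output of the generator and therefore always lies on the (toy) source manifold; only then is it handed to the classifier.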

Fig. 2
Fig. 3

Fig. 2 depicts the training procedure used to learn f, the generative model G and the classifier C, whereas Fig. 3 shows the proposed inference-time optimization method for projecting the target data points.

In conclusion, we propose a novel Domain Generalization technique in which the source domains are used to learn a domain-invariant, label-preserving metric space. During inference, every target sample is projected onto the source-feature manifold in this space so that the classifier trained on the source features can generalize well to the projected target sample. We have demonstrated that this method yields state-of-the-art results in the multi-source, single-source and robust domain generalization settings. In addition, the data efficiency of the method makes it suitable for low-resource settings. Future work could attempt to extend this method to domain generalization for segmentation and to zero-shot learning.
