Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results

Main Reference: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results []

Abstract

This paper explores the use of self-ensembling for visual domain adaptation problems. Our technique is derived from the mean teacher variant [29] of temporal ensembling [14], a technique that achieved state-of-the-art results in the area of semi-supervised learning.

Introduction

  • We have adapted the mean teacher semi-supervised learning model proposed by Tarvainen et al. [29] to work in a domain adaptation scenario.
  • To this end, we developed confidence thresholding and class balancing, which allowed us to achieve state-of-the-art results on a variety of benchmarks, with some of our results coming close to those achieved by traditional supervised learning.
  • Our approach is sufficiently flexible to be applicable to a variety of network architectures, both randomly initialized and pre-trained.

Methodology

The structure of the mean teacher model:

  • The student network is trained using gradient descent, while the weights of the teacher network are an exponential moving average of those of the student (a minimal sketch of this update appears after this list).
  • During training each input sample $x_{i}$ is passed through both the student and teacher networks, generating predicted class probability vectors $z_{i}$ (student) and $\tilde{z}_{i}$ (teacher). Different dropout, noise and image translation parameters are used for the student and teacher pathways.
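
A rough sketch of the weight-averaging step, assuming `student` and `teacher` are two PyTorch modules with identical architectures; the decay rate `alpha` is an illustrative value, not necessarily the one used in the papers:

```python
import torch

@torch.no_grad()
def update_teacher(student: torch.nn.Module, teacher: torch.nn.Module, alpha: float = 0.99):
    """EMA update: teacher weights track a smoothed average of the student weights."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)
```

This update is applied after each gradient descent step on the student, so the teacher lags behind the student as a temporally smoothed ensemble of its recent weights.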

The training loss is the sum of a supervised and an unsupervised component:

  • The supervised loss is the cross-entropy loss computed using $z_{i}$ (the student prediction). It is masked to 0 for unlabeled samples for which no ground truth is available.
  • The unsupervised component is the self-ensembling loss. It penalises the difference in class predictions between the student ($z_{i}$) and teacher ($\tilde{z}_{i}$) networks for the same input sample, and is computed as the mean squared difference between the class probability predictions $z_{i}$ and $\tilde{z}_{i}$ (a sketch combining both terms appears after this list).
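
A minimal sketch of the combined loss under these definitions; the convention of marking unlabeled samples with label -1, and the names `student_logits`, `teacher_logits` and `unsup_weight`, are assumptions made for illustration:

```python
import torch.nn.functional as F

def mean_teacher_loss(student_logits, teacher_logits, labels, unsup_weight):
    # Supervised term: cross-entropy on student predictions, masked to 0
    # for unlabeled samples (marked here with label -1).
    labeled = labels >= 0
    if labeled.any():
        sup_loss = F.cross_entropy(student_logits[labeled], labels[labeled])
    else:
        sup_loss = student_logits.new_zeros(())

    # Unsupervised (self-ensembling) term: mean squared difference between
    # the student and teacher class probability vectors for the same inputs.
    student_prob = F.softmax(student_logits, dim=1)
    teacher_prob = F.softmax(teacher_logits, dim=1).detach()
    unsup_loss = F.mse_loss(student_prob, teacher_prob)

    return sup_loss + unsup_weight * unsup_loss
```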

Laine et al. [14] and Tarvainen et al. [29] found that it was necessary to apply a time-dependent weighting to the unsupervised loss during training in order to prevent the network from getting stuck in a degenerate solution that gives poor classification performance. They used a function that follows a Gaussian curve from 0 to 1 during the first 80 epochs.
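
A sketch of such a ramp-up, assuming the Gaussian form $w(t) = \exp(-5(1 - t)^2)$ with $t$ rising linearly from 0 to 1 over the ramp-up period; the constant and the 80-epoch schedule follow the cited papers but should be treated as illustrative:

```python
import math

def rampup_weight(epoch: int, rampup_epochs: int = 80) -> float:
    """Gaussian ramp-up of the unsupervised loss weight from 0 to 1."""
    if epoch >= rampup_epochs:
        return 1.0
    t = epoch / rampup_epochs
    return math.exp(-5.0 * (1.0 - t) ** 2)
```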

Adapting to domain adaptation

Our variant of the mean teacher model – shown in Figure 2b – has separate source ($x_{S_i}$) and target ($x_{T_i}$) paths.

  • Mini-batches from the source and target datasets are processed separately in each iteration, so batch normalization uses different normalization statistics for each domain during training.
  • Unlike the original mean teacher model, our approach must train using both datasets (source and target) simultaneously; a sketch of one such training step follows this list.
  • We also do not maintain separate exponential moving averages of the means and variances for each dataset for use at test time.
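
A sketch of one training step in this setting; `aug_student` and `aug_teacher` stand for hypothetical stochastic augmentation functions, and passing the source and target mini-batches through the networks in separate forward passes means batch normalization computes per-domain statistics:

```python
import torch.nn.functional as F

def domain_adapt_step(student, teacher, aug_student, aug_teacher,
                      x_source, y_source, x_target, unsup_weight):
    # Source mini-batch: supervised cross-entropy on the student only.
    sup_loss = F.cross_entropy(student(aug_student(x_source)), y_source)

    # Target mini-batch: passed through both networks under different
    # stochastic augmentation; only the consistency loss applies here.
    student_prob = F.softmax(student(aug_student(x_target)), dim=1)
    teacher_prob = F.softmax(teacher(aug_teacher(x_target)), dim=1).detach()
    unsup_loss = F.mse_loss(student_prob, teacher_prob)

    return sup_loss + unsup_weight * unsup_loss
```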

Confidence thresholding:

For each unlabeled target sample $x_{T_i}$ the teacher network produces the predicted class probability vector $\tilde{z}_{T_{ij}}$ – where $j$ is the class index drawn from the set of classes $C$ – from which we compute the confidence $\tilde{f}_{T_i} = \max_{j \in C} \tilde{z}_{T_{ij}}$; that is, the predicted probability of the predicted class of the sample. If this confidence falls below a fixed threshold, the sample's contribution to the self-ensembling loss is masked to 0.
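
A sketch of how this confidence can be used to mask the consistency loss; the per-sample masking shown here is an assumption consistent with the description above, not necessarily the authors' exact implementation:

```python
import torch.nn.functional as F

def confidence_masked_consistency(student_logits, teacher_logits, conf_threshold):
    student_prob = F.softmax(student_logits, dim=1)
    teacher_prob = F.softmax(teacher_logits, dim=1).detach()

    # Confidence: predicted probability of the teacher's predicted class.
    confidence, _ = teacher_prob.max(dim=1)
    mask = (confidence > conf_threshold).float()

    # Per-sample squared error, masked to 0 for low-confidence samples.
    per_sample = ((student_prob - teacher_prob) ** 2).mean(dim=1)
    return (per_sample * mask).mean()
```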

Data augmentation:

Class balance loss:

Experiments

Our implementation was developed using PyTorch [3] and is publicly available at http://github.com/Britefury/self-ensemble-visual-domain-adapt.