1. Introduction
In recent years, with the rapid development of artificial intelligence, model compression has received a great deal of attention from researchers, especially for deep neural networks [1]. Knowledge distillation, one of the most commonly used model compression methods, transfers the dark knowledge of a complex model to a simple model. Hinton et al. [2] introduced the idea of knowledge distillation in neural networks, which uses a teacher model’s high-level features as supervision for a smaller, more efficient student model. The teacher model is highly capable, while the student model is designed to achieve comparable predictions with less complexity through metric learning, ideally approaching or even surpassing the teacher’s performance. In knowledge distillation, the metric learning objective is a linear combination of two distinct losses: the cross-entropy (CE) loss with “hard” targets, and the distillation loss, a divergence between “soft” distributions. In multi-task learning scenarios involving local features, such as image segmentation and object detection, the predicted probability distribution captures more informative and intricate geometric features. As a result, the Kullback-Leibler (KL) divergence and the L2 loss used as knowledge distillation losses [3] may not effectively convey significant geometric information in such contexts.
In this paper, we propose the robust sliced Wasserstein distance [4], [5] with geometric median [6] projection (GMSW), a new knowledge distillation loss that reduces the generalization gap between teacher and student to achieve better knowledge transfer. Firstly, the sliced Wasserstein distance has a geometric interpretation, which makes it a suitable metric for comparing distributions with structural properties. Secondly, the sliced Wasserstein distance with geometric median is more resistant to outliers and noise than alternative distance metrics, including the Euclidean distance and the KL divergence. GMSW maps the geometric features of the highly capable teacher model’s distribution to the student model through robust geometric median projections, which improves the performance of knowledge transfer. The method is validated on multiple student models for object detection, where it outperforms the KL divergence and L2 losses.
The remainder of the paper is organized as follows. Section 2 gives a brief description of knowledge distillation and the Wasserstein distance. Section 3 describes the proposed method. Section 4 presents the performance of the proposed method. Finally, Sect. 5 provides the conclusion.
2. Knowledge Distillation and Wasserstein Distance
2.1 Knowledge Distillation
Knowledge distillation transfers the dark knowledge of a complex model to a simple model: a student network is trained with additional supervision from a trained teacher network. Consider an input sample \((\mathrm{x}, \mathrm{y})\), where \(\mathrm{x}\) is the network input and \(\mathrm{y}\) is the one-hot label. The output distributions of the teacher and student networks are
\[\begin{align} &P_t=\mathit{softmax}(Z_t(x)),\ P_s=\mathit{softmax}(Z_s(x)) \tag{1} \end{align}\]
Here \(Z_{t}\) and \(Z_{s}\) are the logit representations (before the softmax layer) of the teacher and student networks, and \(P_{t}\) and \(P_{s}\) are the corresponding output distributions in Eq. (1). The distillation objective encourages the student’s output probability distribution to be similar to the teacher’s by minimizing the sum of the cross-entropy loss with the label and the knowledge distillation loss between the predictions of the teacher and the student:
\[\begin{align} &\mathit{Loss}=H(P_s,y)+\lambda KD(P^\tau_s,P^\tau_t) \tag{2} \end{align}\]
In Eq. (2), \(H\) denotes the cross-entropy loss and \(KD\) the knowledge distillation loss, \(\tau\) is a relaxation hyperparameter (referred to as the temperature) for softening the output of the teacher network, and \(\lambda\) is a hyperparameter that balances the cross-entropy and knowledge distillation losses. The idea of knowledge distillation is to let the student mimic the teacher’s behavior by adding a strong congruence constraint on the predictions through the knowledge distillation loss.
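The objective in Eq. (2) can be sketched in a few lines of plain Python. This is a minimal illustration, not the implementation used in the experiments: the KD term is assumed here to be the KL divergence between the temperature-softened teacher and student distributions, and the function names and default values (`tau=4.0`, `lam=0.5`) are hypothetical.

```python
import math

def softmax(logits, tau=1.0):
    """Temperature-scaled softmax of a list of logits."""
    scaled = [z / tau for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(p, y):
    """H(p, y) for a predicted distribution p and a one-hot label y."""
    return -sum(yi * math.log(pi) for pi, yi in zip(p, y) if yi > 0)

def kl_divergence(p_t, p_s):
    """KL(p_t || p_s); an assumed choice for the KD term in Eq. (2)."""
    return sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)

def distillation_loss(z_s, z_t, y, tau=4.0, lam=0.5):
    """Loss = H(P_s, y) + lam * KD(P_s^tau, P_t^tau), following Eq. (2)."""
    p_s = softmax(z_s)            # student prediction at temperature 1
    p_s_tau = softmax(z_s, tau)   # softened student distribution
    p_t_tau = softmax(z_t, tau)   # softened teacher distribution
    return cross_entropy(p_s, y) + lam * kl_divergence(p_t_tau, p_s_tau)

# Hypothetical logits for a three-class problem.
z_teacher = [2.0, 1.0, 0.1]
z_student = [1.5, 0.5, 0.2]
label = [1, 0, 0]
loss = distillation_loss(z_student, z_teacher, label)
```

When the student logits equal the teacher logits, the KD term vanishes and only the cross-entropy term remains.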
2.2 Wasserstein Distance
In this section, we review the basic concepts and optimality conditions for computing the p-Wasserstein distance between two discrete probability measures in the Monge and Kantorovich formulations [7]. Let \(P(R^{d})\) be the set of Borel probability measures on \(R^{d}\), and let \(P_{2}(R^{d})\) be the subset of \(P(R^{d})\) consisting of probability measures that have finite second moments. For \(p\in [1,\infty)\), the p-Wasserstein distance between two probability measures \(\mu\) and \(\nu\) is defined as the solution of the optimal mass transportation problem. For \(\mu,\ \nu \in P_{2}(R^{d})\), let \(\Pi (\mu,\nu)\) denote the set of couplings between \(\mu\) and \(\nu\), that is, the probability measures \(\pi\) on \(R^{d}\times R^{d}\) that satisfy the marginal constraints in Eq. (3):
\[\begin{align} &\Pi(\mu,\nu)=\left\{\pi:\ \begin{array}{l} \pi(A\times R^d)=\mu(A)\quad\mathrm{for\ any\ Borel}\ A\subseteq R^d\\ \pi(R^d\times B)=\nu(B)\quad\mathrm{for\ any\ Borel}\ B\subseteq R^d \end{array}\right\} \tag{3} \end{align}\]
In model compression applications one often deals with compact d-dimensional Euclidean spaces, hence \(\mathrm{X}=\mathrm{Y}= [0,1]^{d}\). The p-Wasserstein distance for \(p\in [1,\infty)\) is defined in Eq. (4) as
\[\begin{align} W_p(\mu,\nu):=\left(\inf\limits_{\pi\in\Pi(\mu,\nu)}\int_{X\times Y}||x-y||^p d\pi(x,y)\right)^{\frac{1}{p}} \tag{4} \end{align}\]
Because we only consider low level costs in the rest of this paper, we write \(W_{2}\) for the 2-Wasserstein distance. For \(X,\ Y\subseteq R^{d}\) and \(T\colon X\to Y\), the push-forward of \(\mu \in P(X)\) by \(T\) is the measure \(T_{\#}\mu \in P(Y)\) satisfying \(T_{\#}\mu (A)=\mu (T^{-1}(A))\) for any Borel set \(A\) in \(Y\). The 2-Wasserstein distance is defined in Eq. (5) by
\[\begin{align} W_2(\mu,\nu):=\left(\inf\limits_{\pi\in\Pi(\mu,\nu)}\int||x-y||^2 d\pi(x,y)\right)^{\frac{1}{2}} \tag{5} \end{align}\]
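For intuition, Eqs. (4) and (5) can be evaluated exactly in a very small discrete case. The sketch below assumes two uniform empirical measures with the same number of points, for which an optimal coupling can be taken to be a one-to-one matching, and it simply tries every matching. The function name is hypothetical, and the brute-force search is only practical for a handful of points.

```python
import itertools
import math

def wasserstein_p_uniform(xs, ys, p=2):
    """Exact W_p between two uniform empirical measures with equally many points.

    Tries every one-to-one matching (permutation), so it is only meant
    for very small point sets.
    """
    n = len(xs)
    if n == 0 or n != len(ys):
        raise ValueError("both point sets must be non-empty and of equal size")
    best = math.inf
    for perm in itertools.permutations(range(n)):
        # Average transport cost of this matching, as in the integral of Eq. (4).
        cost = sum(math.dist(xs[i], ys[j]) ** p for i, j in enumerate(perm)) / n
        best = min(best, cost)
    return best ** (1.0 / p)

# Two points shifted by one unit along the second axis.
source = [(0.0, 0.0), (1.0, 0.0)]
target = [(0.0, 1.0), (1.0, 1.0)]
w2 = wasserstein_p_uniform(source, target, p=2)
```

In this example the target is a translate of the source by a unit vector, so the optimal matching moves every point by distance 1.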
3. The GMSW Loss
The sliced Wasserstein distance with geometric median projection is designed as the knowledge distillation loss. The sliced Wasserstein distance is calculated via linear slicing of the probability distributions. To evaluate the sliced Wasserstein distance efficiently, a more informative projection is selected from these slices by means of the geometric median. The geometric median is a measure of central tendency defined as the point that minimizes the sum of distances to all other points. It is also referred to as the spatial median or the L1-median. The geometric median is less influenced by outliers or extreme values than the arithmetic mean and other means. Figure 1 illustrates our overall knowledge distillation loss based on the sliced Wasserstein distance with geometric median.
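The robustness claim can be checked on a toy one-dimensional example, where the geometric median reduces to the ordinary median. The values below are made up for illustration, and the helper name is hypothetical.

```python
import statistics

def sum_of_distances(center, values):
    """Objective minimized by the geometric median (one-dimensional case)."""
    return sum(abs(center - v) for v in values)

clean = [1.0, 2.0, 3.0, 4.0, 5.0]
contaminated = clean + [1000.0]  # one extreme value added

# How far each statistic moves when the outlier is added.
mean_shift = abs(statistics.mean(contaminated) - statistics.mean(clean))
median_shift = abs(statistics.median(contaminated) - statistics.median(clean))

# The median attains a smaller sum of distances than the mean.
cost_at_median = sum_of_distances(statistics.median(contaminated), contaminated)
cost_at_mean = sum_of_distances(statistics.mean(contaminated), contaminated)
```

The single outlier moves the mean by more than a hundred units, while the median moves by half a unit.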
3.1 Sliced Wasserstein Distance for Knowledge Distillation
The sliced Wasserstein distance maps a high-dimensional probability distribution to one-dimensional representations through projections, and then calculates the distance between the prediction distributions of the teacher model and the student model as a functional of the p-Wasserstein distance between their one-dimensional representations. The p-Wasserstein distance has a closed-form solution for one-dimensional continuous probability measures. The slicing process is related to the field of integral geometry, specifically the Radon transform. The result relevant to our discussion is that a d-dimensional probability density can be uniquely represented by the set of its one-dimensional marginal distributions, following the Radon transform and the Fourier slice theorem. \(\delta (\cdot)\) denotes the one-dimensional Dirac delta function, and \(\langle\cdot,\cdot\rangle\) denotes the Euclidean inner product.
Definition 1\(\quad\)For any \(\mu,\ \nu \in P_{2}(R^{d})\), the SW distance of order 1 between them is defined as
\[\begin{align} SW_1(\mu,\nu):=\int_{S^{d-1}}W_1(u^\ast_\#\mu,u^\ast_\#\nu)\,d\delta(u) \tag{6} \end{align}\]
For any \(u \in S^{d-1}\), \(u^{\ast}\) denotes the linear form associated with \(u\), i.e., the projection \(u^{\ast}(\theta)= \langle u,\theta \rangle\) for \(\theta \in R^{d}\), and \(\delta\) represents the uniform distribution on \(S^{d-1}\). In the knowledge distillation application, the sliced Wasserstein distance needs to be used for discrete measures. Since the expectation in Definition 1 is intractable, it is approximated by Monte Carlo estimation with \(L\) sampled projection directions:
\[\begin{align} \begin{cases} \{u^\ast_l\}^L_{l=1}=(u^\ast_1,u^\ast_2,\ldots,u^\ast_l,\ldots,u^\ast_L) \\ \widehat{SW}_1(\mu,\nu)\approx\frac{1}{L}\displaystyle\sum\nolimits^L_{l=1}W_1(u^\ast_{l\#}\mu,u^\ast_{l\#}\nu)\\ \end{cases} \tag{7} \end{align}\]
\(W_{1}(u^\ast_{l\#}\mu,u^\ast_{l\#}\nu)\) in Eq. (7) is the empirical Wasserstein distance between \(u^\ast_{l\#}\mu\) and \(u^\ast_{l\#}\nu\), which can be calculated simply by first sorting both projected samples and then calculating the distance between the sorted samples.
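The Monte Carlo estimate in Eq. (7) can be sketched as follows. The sketch assumes two empirical measures with the same number of equally weighted samples, draws uniform directions by normalizing Gaussian vectors, and uses the sort-based formula for the one-dimensional distance. The function names, the seed, and the default of 100 slices are illustrative choices.

```python
import math
import random

def random_direction(d, rng):
    """Sample a unit vector uniformly on S^{d-1} by normalizing a Gaussian vector."""
    while True:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(c * c for c in v))
        if norm > 0.0:
            return [c / norm for c in v]

def project(points, u):
    """Apply the linear form u*(theta) = <u, theta> to every point."""
    return [sum(ui * xi for ui, xi in zip(u, x)) for x in points]

def wasserstein1_1d(a, b):
    """W_1 between equal-size 1D samples: sort both, then average |differences|."""
    if len(a) != len(b) or not a:
        raise ValueError("samples must be non-empty and of equal size")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

def sliced_wasserstein1(xs, ys, num_slices=100, seed=0):
    """Average of the per-slice W_1 distances over random directions, as in Eq. (7)."""
    rng = random.Random(seed)
    d = len(xs[0])
    total = 0.0
    for _ in range(num_slices):
        u = random_direction(d, rng)
        total += wasserstein1_1d(project(xs, u), project(ys, u))
    return total / num_slices

# Toy samples: ys is xs translated by one unit along the first axis.
xs = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
ys = [(x + 1.0, y) for x, y in xs]
sw = sliced_wasserstein1(xs, ys)
```

For a pure translation, each slice contributes the absolute value of the shift projected onto that direction, so the estimate can never exceed the length of the shift.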
3.2 Sliced Wasserstein Loss with Geometric Median
In this section, we discuss our approach, which uses the Sliced Wasserstein Distance with Geometric Median (GMSW) to train a knowledge distillation model. In knowledge distillation, the role of the distillation loss is to aggregate information about soft features, and its main contribution is to provide important sources of geometric knowledge through projection. We can capture the major discrepancy between two measures by considering a relatively small number of “important” slices. Averaging over all random directions, however, lets uninformative or noisy slices influence the estimate. This problem can be alleviated by using the geometric median, which is taken as the most representative projection among the L projection directions. In GMSW, the calculation is simplified to picking the “best direction”, namely the one whose projected distance is the geometric median, instead of using the mean over the random projection directions generated from the d dimensions. Based on this, the Sliced Wasserstein Distance with Geometric Median can be used as a more robust distance metric in place of the Sliced Wasserstein Distance.
Definition 2\(\quad\)Given a set of n points \(\{\mathbf{x}_{1},\mathbf{x}_{2},\ldots,\mathbf{x}_{i},\ldots,\mathbf{x}_{n}\}\), the geometric median is defined as
\[\begin{align} &\mathrm{Geometric\ Median}=\mathop{\rm argmin}\limits_\mathbf{x}\sum\limits^n_{i=1}||\mathbf{x}-\mathbf{x}_i||_2 \tag{8} \end{align}\]
Here, argmin denotes the value of the argument \(\mathbf{x}\) that minimizes the sum, that is, the point \(\mathbf{x}\) in Euclidean space for which the sum of the Euclidean distances to all the \(\mathbf{x}_{i}\) is minimal. In GMSW, the points \(\mathbf{x}_{i}\) are taken to be the projected distances of all slices, so the geometric median of all slices is the projected distance that minimizes the sum of its Euclidean distances to the other projected distances. More formally, following the notation of Eq. (8) and writing \(u^\ast_l\), \(l=1,\ldots,L\), for the sampled projections, the GMSW distance of order 1 is defined as
\[\begin{align} &\widehat{GMSW}(\mu,\nu)\approx\mathop{\rm argmin}\limits_{w}\sum\limits^{L}_{l=1}\left|\left|w-W_1(u^\ast_{l\#}\mu,u^\ast_{l\#}\nu)\right|\right|_2 \tag{9} \end{align}\]
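One way to read Eq. (9) is sketched below. It assumes that each slice yields a single scalar distance, in which case the geometric median of the L per-slice distances is the ordinary median, and that both measures are empirical with the same number of equally weighted samples. This is an interpretation for illustration and not a transcription of Algorithm 1; the function names are hypothetical.

```python
import math
import random
import statistics

def slice_distances(xs, ys, num_slices=100, seed=0):
    """Per-slice W_1 distances between equal-size point sets along random unit directions."""
    rng = random.Random(seed)
    d = len(xs[0])
    distances = []
    for _ in range(num_slices):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(c * c for c in g)) or 1.0  # guard against a zero vector
        u = [c / norm for c in g]
        px = sorted(sum(ui * xi for ui, xi in zip(u, x)) for x in xs)
        py = sorted(sum(ui * yi for ui, yi in zip(u, y)) for y in ys)
        distances.append(sum(abs(a - b) for a, b in zip(px, py)) / len(px))
    return distances

def gmsw(xs, ys, num_slices=100, seed=0):
    """Median of the per-slice distances (assumed scalar reading of Eq. (9))."""
    return statistics.median(slice_distances(xs, ys, num_slices, seed))

# Toy teacher and student samples in two dimensions.
teacher = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
student = [(0.2, 0.1), (0.9, 0.3), (0.1, 0.8), (1.2, 1.1)]
per_slice = slice_distances(teacher, student)
value = gmsw(teacher, student)
```

Replacing `statistics.median` with `statistics.mean` in `gmsw` would recover the plain Monte Carlo average of Eq. (7), which makes the two aggregations easy to compare.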
The typical approach for computing the geometric median is the Weiszfeld algorithm, an iterative algorithm for solving the geometric median problem. Its fundamental idea is to start from an initial estimate such as the mean point and to repeatedly update it to a weighted average of the points, with each weight inversely proportional to the distance from the point to the current estimate, so that the estimate approaches the geometric median. The pseudocode of the GMSW loss is shown in Algorithm 1.
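A compact sketch of the Weiszfeld iteration is given below for reference. It is a generic version written for this illustration rather than the paper's Algorithm 1. The stopping tolerance, iteration cap, and the choice to skip points that coincide with the current estimate are simplifying assumptions.

```python
import math

def total_distance(center, points):
    """Sum of Euclidean distances from center to all points (objective of Eq. (8))."""
    return sum(math.dist(center, p) for p in points)

def arithmetic_mean(points):
    """Coordinate-wise mean, used as the starting estimate."""
    dim = len(points[0])
    return [sum(p[k] for p in points) / len(points) for k in range(dim)]

def weiszfeld(points, max_iter=1000, tol=1e-9):
    """Approximate the geometric median by inverse-distance re-weighting."""
    dim = len(points[0])
    x = arithmetic_mean(points)
    for _ in range(max_iter):
        num = [0.0] * dim
        den = 0.0
        for p in points:
            dist = math.dist(x, p)
            if dist < 1e-12:
                # Simplification: skip a point that coincides with the estimate
                # to avoid dividing by zero.
                continue
            w = 1.0 / dist
            den += w
            for k in range(dim):
                num[k] += w * p[k]
        if den == 0.0:
            return x  # every point coincides with the estimate
        new_x = [c / den for c in num]
        if math.dist(new_x, x) < tol:
            return new_x
        x = new_x
    return x

# Four corners of the unit square plus one distant outlier.
pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (50.0, 50.0)]
gm = weiszfeld(pts)
mean_pt = arithmetic_mean(pts)
```

On this example the arithmetic mean is dragged far outside the unit square by the outlier, while the estimated geometric median stays inside it.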
4. Experimental Results
We performed experiments on a large-scale dataset, namely the Cityscapes dataset [8]. We select three types of targets with thousands of images per class (pedestrians, vehicles, and motorbikes) and use the bounding rectangle of each instance annotation as the bounding box of the target. In the experiments, the GMSW loss is compared with the KL divergence and the L2 loss on object detection.
Common Settings. The backbone network of the teacher model in all experiments is ResNet-50. For the student, we use the more compact ResNet-18 and MobileNet, as well as variants with different FLOPs, since ResNet-18 and MobileNet have proved highly effective at keeping high accuracy while maintaining low FLOPs in many tasks. We conduct all the experiments on a computer with one NVIDIA V100 GPU, and the object detection framework in our experiments is based on CenterNet. The same number of slices (\(L=100\)) is set for all comparable scores in the GMSW loss. In Table 1, “Type” denotes the training approach, where “Normal” indicates a model obtained through standard training and “Distillation” refers to a student model trained with ResNet-50 as the teacher model, and “Loss” denotes the method of supervision in knowledge distillation. We adopt mean Average Precision (mAP) as the evaluation metric.
Our results show that the GMSW loss outperformed the L2 loss and the KL divergence in terms of accuracy. Specifically, the GMSW loss achieved an average mAP of 27.2%, which significantly surpassed the L2 loss’s 3.1% and the KL loss’s 1.8% for the ResNet-18 student model. Figure 2 shows the visualization results obtained from the ResNet-18 student model. We observe that the geometric attributes of the foreground features in the third row align more closely with the object regions in the first row than those obtained with the KL loss. Furthermore, the GMSW mechanism helps the detector emphasize the geometric properties of the objects of interest, such as angle, scale, and size.
To study robustness, we analyzed the impact of the number of projections by adjusting L. As shown in Fig. 3, increasing the number of projections L did not significantly improve mAP, and it led to training instability and longer training time. This is because, in object detection distillation, the foreground typically constitutes a relatively small proportion compared with the background. It is therefore necessary to limit the number of projections to avoid noise interference and improve robustness.
5. Conclusions
In this paper, we studied the key learning aspects of the Geometric Median Sliced Wasserstein (GMSW) loss for knowledge distillation. Our work provides an enhanced understanding of sliced Wasserstein distances with geometric median and the associated minimal distance estimators under knowledge distillation. Our results suggest that the GMSW loss can significantly improve the robustness and accuracy of knowledge distillation. Further research is needed to investigate the applicability of the SW distance to high-dimensional data and to explore its optimal parameters.
Acknowledgments
This work is supported by MOE Planned Project of Humanities and Social Sciences (No.20YJA870014) and Beijing Natural Science Foundation (L223022).
References
[1] J. Gou, B. Yu, S.J. Maybank, and D. Tao, “Knowledge distillation: A survey,” International Journal of Computer Vision, vol.129, pp.1789-1819, 2021.
[2] G. Hinton, O. Vinyals, and J. Dean, “Distilling the knowledge in a neural network,” arXiv preprint arXiv:1503.02531, 2015.
[3] T. Kim, J. Oh, N. Kim, S. Cho, and S.-Y. Yun, “Comparing Kullback-Leibler divergence and mean squared error loss in knowledge distillation,” arXiv preprint arXiv:2105.08919, 2021.
[4] I. Deshpande, Y.-T. Hu, R. Sun, A. Pyrros, N. Siddiqui, S. Koyejo, Z. Zhao, D. Forsyth, and A.G. Schwing, “Max-sliced Wasserstein distance and its use for GANs,” Proc. IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.10648-10656, 2019.
[5] S. Kolouri, K. Nadjahi, U. Simsekli, R. Badeau, and G. Rohde, “Generalized sliced Wasserstein distances,” Advances in Neural Information Processing Systems, 2019.
[6] Y. He, P. Liu, Z. Wang, Z. Hu, and Y. Yang, “Filter pruning via geometric median for deep convolutional neural networks acceleration,” Proc. IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.4340-4349, 2019.
[7] C. Villani, Optimal transport: Old and new, Springer, Berlin, 2009.
[8] M. Cordts, M. Omran, S. Ramos, T. Scharwächter, M. Enzweiler, R. Benenson, U. Franke, S. Roth, and B. Schiele, “The Cityscapes dataset,” CVPR Workshop on the Future of Datasets in Vision, vol.2, 2015.