Differentiable Network Pruning via Polarization of Probabilistic Channelwise Soft Masks

Channel pruning has been demonstrated to be a highly effective approach for compressing large convolutional neural networks. Existing differentiable channel pruning methods usually use deterministic soft masks to scale the channelwise outputs and search for an appropriate threshold on the masks to remove unimportant channels, which sometimes causes unexpected damage to network accuracy when no sweet spot clearly separates important channels from redundant ones. In this article, we introduce a new differentiable channel pruning method based on polarization of probabilistic channelwise soft masks (PPSMs). We use variational inference to approximate the posterior distributions of the masks and simultaneously exploit a polarization regularization to push the probabilistic masks towards either 0 or 1; thus, the channels with near-zero masks can be safely eliminated with little harm to network accuracy. Our method significantly relieves the difficulty that existing methods face in finding an appropriate threshold on the masks. The joint inference and polarization of probabilistic soft masks enable PPSM to yield better pruning results than the state of the art. For instance, our method prunes 65.91% of the FLOPs of ResNet50 on the ImageNet dataset with only 0.7% degradation in model accuracy.


Introduction
Convolutional neural network (CNN) yields unprecedented success in computer vision tasks [1,2], due to its intrinsic ability of automatically learning meaningful features. To achieve better results on these tasks, the structure of CNN is expanding wider and deeper. However, the performance elevation is often accompanied with extensive consumption of memory and computation footprint, which inhibits the deployment of complex CNNs on resource-constrained devices. Network compression techniques [3][4][5] have relieved the issue through condensing a large CNN into a compact subnetwork (subnet), and channel pruning is deemed as one of the most effective methods for network compression.
Channel pruning aims at removing semantically redundant channels from a pretrained or baseline network with little damage to the model accuracy. Early works on channel pruning mainly employ an iterative search-evaluation scheme comprising the generation of subnets and their evaluation on a validation set. For instance, He et al. [6] propose to sequentially prune each layer of the network based on LASSO regression, and AMC [7] generates the reserved channels in each layer via reinforcement learning. The layerwise pruning may result in a suboptimal solution due to deficient representation of the global structural information of the network. To improve subnet search accuracy and efficiency, Cao et al. [8] leverage Bayesian optimization to generate the priority ordering of subnets for evaluation. Some methods exploit channel importance scores to guide the selection of subnets. For instance, Liebenwein et al. [9] use an importance sampling distribution to yield subnets by giving higher sampling probability to important filters. LeGR [10] generates subnets with different trade-offs between model accuracy and efficiency by inferring a global ranking of the filters. Similarly, HRank [11] sorts filters by the rank of their feature maps and removes low-ranked filters to yield a subnet. These methods all explicitly or implicitly depend on empirically defined metrics to assess the importance of filters and usually need to search for a threshold on the scores to remove unimportant filters or channels.
To learn a well-performing subnet in an end-to-end manner, differentiable pruning has become a popular approach for channel pruning. The basic concept is to attach learnable masks or gate functions behind channels to scale the original outputs and to apply certain regularizations on the masks or gates to obtain an importance ordering of all channels. Lin et al. [12] use binary masks to remove redundant filters by setting the corresponding masks to 0. Such hard pruning makes mask optimization difficult. To relieve the training complexity, many methods based on soft pruning have been proposed in recent years [13][14][15]. For instance, GAL [16] employs soft masks to remove structural redundancy with adversarial learning, GBN [17] estimates the filter importance ranking by exploring the effect on the loss function of setting a channelwise gate to zero, and DMCP [18] formulates channel pruning of each layer as a Markov process that defines the retaining probability of each channel. Kim et al. [19] also introduce the concept of gates for differentiable channel pruning. While these soft pruning methods perform acceptably well, they still need to search for an appropriate threshold on the masks or gates to define unimportant channels, which sometimes causes unexpected damage to network accuracy when no sweet spot clearly separates the channels into two parts. In addition, the learning of deterministic masks or gates may suffer from low stability and deficient convergence in large networks.
In this article, we introduce a new differentiable channel pruning method based on polarization of probabilistic soft masks (PPSMs). Figure 1 provides an intuitive illustration of the idea of mask polarization. PPSM is built on the assumption that the global channel ranking may vary with the input (similar concept is adopted in dynamic pruning); therefore, deterministic masks used by existing methods cannot capture this input-aware property. We use probabilistic channelwise soft masks to implicitly represent the uncertainty on channel ranking. To enable stable learning of the masks, variational inference is exploited to approximate the posterior distributions of the masks given the output features of a baseline network. Meanwhile, a new polarization regularization is utilized to push masks of redundant and important channels towards 0 and 1, respectively; thus, the channels with near-zero masks can be safely eliminated with little hurt on network accuracy. Our method significantly relieves the difficulty faced by the existing methods to find an appropriate threshold on the masks. To the best of our knowledge, PPSM is the first method to make joint inference and polarization of probabilistic soft masks.
Our main contributions are summarized as follows: (1) We propose a differentiable channel pruning method to remove redundant channels from a baseline network through learning input-aware probabilistic channelwise soft masks.
(2) Variational inference and polarization regularization are introduced to learn the probabilistic masks and push them towards the two ends, and therefore clearly separate important channels from redundant ones. (3) Extensive evaluations of PPSM on popular network architectures and datasets show that our method outperforms the state of the art, pruning more FLOPs with less loss of model accuracy.

Related Works
2.1. Network Pruning. Network pruning eliminates the unnecessary weights or structured units, such as filters and neurons, of a pretrained neural network. Fine-grained pruning directly removes redundant weights within a filter or neuron and produces a highly sparse weight matrix. Many works [20,21] apply a sparsity-inducing penalty on the weights to remove insignificant weights. While nonstructured pruning greatly reduces the parameters of the network, it is not hardware-friendly and requires a specially designed sparse matrix multiplication library for acceleration. By comparison, coarse-grained or structured pruning [22][23][24][25] aims at removing structured units such as filters, channels, or layers. The widely used strategy of structured pruning is to attach a learnable scaling factor or mask to each structure and prune the structures under sparsity regularization [26][27][28]. Jung et al. [29] propose a new real-time target tracking meta-learning framework with efficient model adaptation and channel pruning. He et al. [22] propose meta-attribute-based filter pruning (MFP), which adaptively selects the most appropriate pruning criterion through an attribute (meta-attribute) of the current state of the neural network. Li et al. [23] propose a new fusion-catalyzed pruning method called FuPruner to simultaneously optimize parametric and nonparametric operators to accelerate neural networks. Some recent efforts use two or more different techniques for joint optimization. This provides another flexible option for network compression because the techniques complement each other. The joint optimization of pruning and other model compression algorithms (such as quantization, knowledge distillation, and matrix decomposition) [30][31][32] can deal with a larger search space and obtain a more compact network. Recent works like joint-DetNAS [33] and NPAS [34] perform joint optimization of neural architecture search (NAS) and pruning.

Neural Architecture Search
NAS aims at automatically finding a compact neural architecture from a large search space. Early works use either reinforcement learning [35] or genetic algorithms [36] to update model responses for generating architectures with better performance. However, the search space of these methods is very large, and significant computational overhead is required to search and select the best model from thousands of models. To address this problem, Differentiable Architecture Search (DARTS) [37] relaxes the search space to be continuous, which allows optimization algorithms such as gradient descent to find the optimal network structure. Our method can also be seen as a NAS process. Compared with conventional NAS, our method obtains the posterior distribution of the mask given the baseline output features and uses variational inference to learn the soft mask. Then, the polarization regularization of the soft mask is employed to remove the channels with soft masks close to zero, resulting in a compact network.

Notations and Preliminaries.
Given a batch of input images $\{x_i\}_{i=1}^{N}$, the baseline network outputs the corresponding feature maps $\{y_i\}_{i=1}^{N}$ from the last layer. The input and output pairs $\{(x_i, y_i)\}_{i=1}^{N}$ constitute a training dataset for supervised channel pruning. We use $F_i \in \mathbb{R}^{C_i \times W_i \times H_i}$ to denote the feature map derived from the $i$-th layer, where $i$ is the layer index, $C_i$ is the number of channels, and $W_i$ and $H_i$ are the width and height of the feature map, respectively. Suppose the number of filters across the network is $n$; we use an $n$-dimensional variable $m = (m^{(1)}, \ldots, m^{(n)})$ to represent the soft masks, where each element $m^{(i)} \in [0, 1]$. By multiplying the channelwise outputs of the baseline network by the soft masks, we can get a pruned network by setting certain masks to 0. For each input $x_i$, the corresponding soft masks are denoted as $m_i$, and the output feature map of the pruned network is optimized to approximate the baseline output $y_i$.
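The channelwise scaling described above can be illustrated with a minimal sketch. It uses plain nested lists instead of a tensor library, and the function name `apply_channel_masks` and the toy values are assumptions made for illustration, not part of the original method description.

```python
def apply_channel_masks(feature_map, masks):
    """Scale each channel of a C x H x W feature map (nested lists) by its soft mask."""
    if len(feature_map) != len(masks):
        raise ValueError("one mask per channel is required")
    return [
        [[mask * value for value in row] for row in channel]
        for channel, mask in zip(feature_map, masks)
    ]


# Toy example: two channels of size 2 x 2; the second channel is masked out.
features = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
]
scaled = apply_channel_masks(features, [1.0, 0.0])
```

A mask of 1 leaves the channel unchanged, while a mask of 0 zeroes it, which is equivalent to removing the corresponding filter.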

Probabilistic Soft Masks.
The usage of probabilistic masks is motivated by the instance-aware channel ranking used in dynamic pruning. A single deterministic mask cannot capture such dynamics, while a distribution is more effective at characterizing the variance of channel importance in static network pruning. In addition, learning a distribution tends to be more stable than learning a single deterministic value. Therefore, probabilistic soft masks are used to capture the variance of channel importance. Given an input sample $x_i$, we assume the output $y_i$ can be well approximated by cancelling out certain filters. Based on these conceptions, we formulate the dependence of the output feature map $y_i$ on the input $x_i$ and soft masks $m_i$ with a deep conditional generative model (CGM): for a given input $x_i$, $m_i$ is sampled from the prior distribution $p_\theta(m_i \mid x_i)$, and the output $y_i$ is generated from the distribution $p_\theta(y_i \mid x_i, m_i)$. Direct training of the deep CGM to maximize the conditional log-likelihood is intractable; therefore, we employ stochastic gradient variational Bayes (SGVB) [38] and optimize a variational lower bound of the log-likelihood: $\log p_\theta(y_i \mid x_i) = \mathrm{KL}\big(q_\phi(m_i \mid x_i, y_i) \,\|\, p_\theta(m_i \mid x_i, y_i)\big) + \mathcal{L}(\theta, \phi; x_i, y_i)$, with $\mathcal{L}(\theta, \phi; x_i, y_i) = -\mathrm{KL}\big(q_\phi(m_i \mid x_i, y_i) \,\|\, p_\theta(m_i \mid x_i)\big) + \mathbb{E}_{q_\phi(m_i \mid x_i, y_i)}\big[\log p_\theta(y_i \mid x_i, m_i)\big]$, where $\phi$ and $\theta$ are the variational and generative parameters, respectively. The KL divergence in the first expression measures the discrepancy between the approximate and true posteriors; because it is nonnegative, $\mathcal{L}$ is a lower bound of $\log p_\theta(y_i \mid x_i)$.
To simplify the computation, we further assume the posterior distribution of $m_i$ is only conditioned on $y_i$, that is, $q_\phi(m_i \mid x_i, y_i) = q_\phi(m_i \mid y_i)$. We adopt this simplified form of the conditional probability for two reasons: (1) although the baseline network may give the same outputs for different inputs, this will rarely happen given the complex nonlinear property of the network; (2) even if the output features of two images are the same, the images are most likely from the same class and have little semantic difference. Therefore, we remove $x_i$ from the conditional probability $q_\phi(m_i \mid x_i, y_i)$ to simplify the model training. As the element values of $m_i$ are constrained to the interval $[0, 1]$, directly approximating the posterior distribution of $m_i$ is computationally inconvenient; therefore, we introduce an auxiliary $n$-dimensional real-valued variable $z_i$ and calculate $m_i$ by applying the sigmoid function to each element of $z_i$, which we denote as $m_i = S(z_i)$. Based on these definitions, the lower bound can now be formulated as $\mathcal{L}(\theta, \phi; x_i, y_i) = -\mathrm{KL}\big(q_\phi(z_i \mid y_i) \,\|\, p_\theta(z_i \mid x_i)\big) + \mathbb{E}_{q_\phi(z_i \mid y_i)}\big[\log p_\theta(y_i \mid x_i, S(z_i))\big]$. We then leverage a conditional variational auto-encoder (CVAE) to optimize the lower bound with respect to both $\phi$ and $\theta$. Figure 1 shows the proposed PPSM framework, which is built on the CVAE to infer the probabilistic soft masks. The encoder consists of 5 fully connected layers, of which the last two layers output the mean and variance of each $z_i$, and the decoder is the pruned network, which has the same structure as the baseline network. Specifically, we use a centered isotropic multivariate Gaussian for the conditional prior on $z_i$, with $p_\theta(z_i \mid x_i) = \mathcal{N}(z_i; 0, I)$, and also assume the variational posterior is a multivariate normal distribution with a diagonal covariance matrix: $q_\phi(z_i \mid y_i) = \mathcal{N}\big(z_i; \mu_i, \mathrm{diag}(\sigma_i^2)\big)$, where $\mu_i$ and $\sigma_i$ are the outputs of the encoder and represent the mean and standard deviation of the posterior, respectively. We sample $z_i$ from the posterior $q_\phi(z_i \mid y_i)$ using $z_i = \mu_i + \sigma_i \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, where $\odot$ denotes the element-wise product. The soft masks $m_i$ are then calculated using $m_i = S(z_i)$.
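The sampling step can be sketched as follows, assuming the standard reparameterization z = μ + σ ⊙ ε with ε drawn from a standard normal, followed by an element-wise sigmoid. The encoder outputs `mu` and `sigma` below are hypothetical values chosen for illustration; in the actual framework they would come from the fully connected encoder.

```python
import math
import random


def sigmoid(z):
    """Logistic function S(z) mapping a real value into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def sample_soft_masks(mu, sigma, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1) per channel, then apply S."""
    z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
    return [sigmoid(value) for value in z]


rng = random.Random(0)       # fixed seed so the sketch is reproducible
mu = [-4.0, 0.0, 4.0]        # hypothetical posterior means for three channels
sigma = [0.1, 0.1, 0.1]      # hypothetical posterior standard deviations
masks = sample_soft_masks(mu, sigma, rng)
```

Because the randomness enters only through `eps`, the sampled masks remain differentiable with respect to `mu` and `sigma` when the same computation is written in an automatic differentiation framework.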
We use the mean soft masks $m = \frac{1}{N}\sum_{i=1}^{N} m_i$ to scale the channelwise feature maps for each input $x_i$. Here, $m$ represents the average contribution of the inputs to channel importance and shows less variance than $m_i$, $1 \le i \le N$, and is therefore easier to optimize. Given the soft masks $m$ and input $x_i$, the decoder yields the reconstructed feature map $f(x_i, m; \theta)$ to approximate the baseline output $y_i$. Specifically, we use the MSE loss to align the outputs of the pruned and baseline networks: $\mathcal{L}_{\mathrm{MSE}} = \frac{1}{N}\sum_{i=1}^{N} \| f(x_i, m; \theta) - y_i \|_2^2$. The optimization objective now becomes minimizing the following loss function: $\mathcal{L}_{\mathrm{CVAE}} = \frac{1}{N}\sum_{i=1}^{N} \mathrm{KL}\big(q_\phi(z_i \mid y_i) \,\|\, p_\theta(z_i \mid x_i)\big) + \mathcal{L}_{\mathrm{MSE}}$. The CVAE loss function makes it convenient to approximate the posteriors of the soft masks in a differentiable way while effectively recovering the baseline features.
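A minimal sketch of the two loss terms follows. It assumes the Gaussian posterior and standard normal prior stated above, for which the KL divergence has the well-known closed form 0.5 Σ (μ² + σ² − 1 − log σ²). The equal weighting of the KL and MSE terms and the function names are assumptions for illustration, since the exact normalization is not recoverable from the text.

```python
import math


def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) )."""
    return 0.5 * sum(
        m * m + s * s - 1.0 - math.log(s * s) for m, s in zip(mu, sigma)
    )


def mse(pred, target):
    """Mean squared error between two flat feature vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cvae_loss(mu, sigma, pred, target):
    """Assumed unweighted sum of the KL term and the reconstruction term."""
    return kl_to_standard_normal(mu, sigma) + mse(pred, target)


# The KL term vanishes when the posterior equals the prior.
kl_zero = kl_to_standard_normal([0.0, 0.0], [1.0, 1.0])
# Shifting one mean by 1 costs 0.5 nats; the prediction error adds 1.0.
example = cvae_loss([1.0], [1.0], [0.0], [1.0])
```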

Polarization Regularization.
Optimization of $\mathcal{L}_{\mathrm{CVAE}}$ does not guarantee a clear separation of important filters from redundant ones; therefore, appropriate regularization on the soft masks is essential for harmless channel pruning. The conventional strategies that use either L1 or L2 regularization [39] aim to shrink the masks of unimportant filters and need to carefully search for a threshold on the masks to prune filters with masks below the threshold. Inspired by the work in [28], we introduce a polarization regularizer on the probabilistic soft masks to push the posteriors of the masks towards 0 or 1, such that sweet spots that clearly separate the channels into two parts can be easily found. The adopted polarization regularizer is defined as follows: $\mathcal{L}_{P}(m) = t\,\|m\|_1 - \|m - \bar{m}\,\mathbf{1}_n\|_1$, where $\bar{m}$ denotes the mean of $m^{(1)}, \ldots, m^{(n)}$. The effect of the second RHS term $-\|m - \bar{m}\,\mathbf{1}_n\|_1$ is to keep $m^{(i)}$, $1 \le i \le n$, as far away from the mean as possible. The term $-\|m - \bar{m}\,\mathbf{1}_n\|_1$ attains its extrema at vertices of the $n$-dimensional cube $[0, 1]^n$, and its minimum is reached when half of the elements of $m$ are 0 and the other half are 1. The hyperparameter $t$ is introduced to control the weight of the L1 regularization and also determines the sparsity of the soft masks.
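The behavior of the regularizer can be checked numerically with a small sketch. It assumes the form t‖m‖₁ − ‖m − m̄1ₙ‖₁ used in [28], consistent with the second RHS term discussed above; the function name and the toy mask vectors are illustrative only.

```python
def polarization_regularizer(masks, t):
    """Compute t * ||m||_1 - ||m - mean(m) * 1_n||_1 for a list of soft masks."""
    mean = sum(masks) / len(masks)
    l1_norm = sum(abs(m) for m in masks)
    spread = sum(abs(m - mean) for m in masks)
    return t * l1_norm - spread


uniform = [0.5, 0.5, 0.5, 0.5]    # all masks sit at the mean: no polarization
polarized = [0.0, 0.0, 1.0, 1.0]  # half of the masks at 0, half at 1

loss_uniform = polarization_regularizer(uniform, t=0.2)
loss_polarized = polarization_regularizer(polarized, t=0.2)
```

Both vectors have the same L1 norm, so the difference in the regularizer value comes entirely from the spread term, which rewards the polarized configuration with a lower loss.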

Optimization.
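By combining the loss functions associated with the CVAE, the polarization regularizer, and the regularizations on parameters $\phi$ and $\theta$, we derive the following objective function: $\mathcal{L} = \mathcal{L}_{\mathrm{CVAE}} + \lambda\,\mathcal{L}_{P}(m) + \lambda_\phi R(\phi) + \lambda_\theta R(\theta)$, where $\mathcal{L}_{P}(m)$ is the polarization regularizer weighted by $\lambda$, $R(\phi)$ is the L2 regularization on the variational parameters $\phi$, $R(\theta)$ is the L2 regularization on the generative parameters $\theta$, and the weights $\lambda_\phi$ and $\lambda_\theta$ are fixed to $5 \times 10^{-4}$. We can optimize the objective function with respect to $\phi$ and $\theta$ using a gradient-based algorithm such as stochastic gradient descent (SGD).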

Pruning Strategy.
After the model converges, the distribution of soft masks is analyzed to identify unimportant filters. Given a batch of input images, we measure the expected value of the mean soft masks $m$: $\mathbb{E}[m] = \frac{1}{N}\sum_{i=1}^{N} \mathbb{E}_{q_\phi(z_i \mid y_i)}[S(z_i)] = \frac{1}{N}\sum_{i=1}^{N} \int q_\phi(z_i \mid y_i)\, S(z_i)\, dz_i$. As $q_\phi(z_i \mid y_i)\,S(z_i)$ forms a complex function with respect to $z_i$, the integral is intractable to calculate; therefore, we use a Monte Carlo estimate of the expectation of $m_i$ as follows: $\mathbb{E}[m_i] \approx \frac{1}{L}\sum_{l=1}^{L} S\big(z_i^{(l)}\big)$ with $z_i^{(l)} \sim q_\phi(z_i \mid y_i)$, where $L$ is the number of samples. Each element of $m$ denotes the soft mask attached to one of the filters. By utilizing the polarization effect, we do not need to search for a threshold on the soft masks and can directly set the threshold to 0.5 to prune filters. When investigating the distribution histogram of $m$, a bimodal distribution is always observed with two clearly separated peaks: one located close to 0 and the other located close to 1 (illustrated in Figure 2). In addition, the filters are completely separated into two parts with a large margin. We also observe that the batch of inputs has little effect on the distribution of soft masks after the model converges (demonstrated in Figure 3); therefore, only one batch of inputs is required when pruning the filters.
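The Monte Carlo estimate and the fixed 0.5 threshold can be sketched as follows. For brevity the sketch estimates the expectation for a single posterior rather than averaging over a batch, and the posterior parameters `mu` and `sigma` are hypothetical polarized values, not results from the experiments.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def expected_masks(mu, sigma, num_samples, rng):
    """Monte Carlo estimate of E[S(z)] per channel for z ~ N(mu, sigma^2)."""
    totals = [0.0] * len(mu)
    for _ in range(num_samples):
        for k, (m, s) in enumerate(zip(mu, sigma)):
            totals[k] += sigmoid(m + s * rng.gauss(0.0, 1.0))
    return [total / num_samples for total in totals]


def channels_to_keep(masks, threshold=0.5):
    """Return the indices of channels whose expected mask reaches the threshold."""
    return [k for k, m in enumerate(masks) if m >= threshold]


rng = random.Random(0)
mu = [-5.0, 4.0, -6.0, 5.0]      # hypothetical polarized posterior means
sigma = [0.5, 0.5, 0.5, 0.5]     # hypothetical posterior standard deviations
estimate = expected_masks(mu, sigma, num_samples=200, rng=rng)
kept = channels_to_keep(estimate)
```

When the masks are well polarized, the decision is insensitive to the exact threshold, which is why a fixed value of 0.5 suffices.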

Implementation Details.
All networks were trained from scratch. The same data augmentation strategies were used as in the official PyTorch examples [45]. Training was run for 200 and 100 epochs on the CIFAR-10 and ImageNet datasets, respectively, with an initial learning rate of 0.1 and a mini-batch size of 128. The learning rate was multiplied by 0.1 at 50% and 75% of the training epochs on CIFAR-10, and multiplied by 0.1 at epochs 30, 60, and 90 on ImageNet. We utilized an SGD optimizer with a weight decay of 0.0005 and a momentum of 0.9. For MobileNet v2 on ImageNet, we used cosine annealing to automatically reduce the learning rate. All experiments were implemented in PyTorch on two NVIDIA RTX 3090 GPUs and an Intel(R) Xeon(R) Gold 5218 CPU.
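The step schedules described above can be written as a small helper. This is a sketch in plain Python rather than the training code itself; the function name `step_lr` and the zero-based epoch indexing are assumptions.

```python
def step_lr(epoch, total_epochs, base_lr=0.1, milestones=None, gamma=0.1):
    """Learning rate at a zero-based epoch, decayed by gamma at each milestone."""
    if milestones is None:
        # CIFAR-10 style: decay at 50% and 75% of the training epochs.
        milestones = [int(0.5 * total_epochs), int(0.75 * total_epochs)]
    drops = sum(1 for milestone in milestones if epoch >= milestone)
    return base_lr * gamma ** drops


# CIFAR-10 schedule: 200 epochs, decays at epochs 100 and 150.
cifar_lrs = [step_lr(epoch, 200) for epoch in (0, 100, 150)]

# ImageNet schedule: 100 epochs, decays at epochs 30, 60, and 90.
imagenet_lrs = [step_lr(epoch, 100, milestones=[30, 60, 90]) for epoch in (0, 30, 60, 90)]
```

In PyTorch, the same behavior is available through `torch.optim.lr_scheduler.MultiStepLR` with the corresponding milestones and `gamma=0.1`.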

Hyperparameter.
During the pruning process, we need to set two hyperparameters, $\lambda$ and $t$, to achieve the desired FLOPs reduction. The hyperparameter $\lambda$ controls the weight of the polarization regularizer. With a larger $\lambda$, the soft masks are pushed more strongly towards 0 and 1. The hyperparameter $t$ controls the ratio of FLOPs to be reduced. A larger $t$ results in more FLOPs reduction. In our experiments, we empirically set $\lambda = 0.0004$ on CIFAR-10 and $\lambda = 0.00005$ on ImageNet. To obtain the desired FLOPs reduction, different $t$ values need to be tested for different network architectures (as shown in Table 1), and we set the range of $t$ to $[-2, 2]$. For example, when pruning ResNet56 on CIFAR-10, we obtained a FLOPs reduction of 54.6% at $t = 0.2$. In addition, the initial learning rate during fine-tuning was set to 0.01.

Results on CIFAR-10.
We first compared our method to the state of the art on the small-scale CIFAR-10 dataset. Channel pruning was performed on four popular neural networks, VGG16, ResNet32, ResNet56, and ResNet110, and the results are shown in Table 2.
When pruning VGG16, PPSM elevates the accuracy by 0.06% with 66.20% FLOPs pruned and performs better than HRank [11] and SCP [14] by yielding similar FLOP reduction. For ResNet32, when compared to LFPC [46] and Wang et al. [47], our method achieves the best accuracy at similar pruning rates of ∼53%, with an increase of 0.12% over baseline accuracy. In addition, PPSM outperforms LRF [27] and MainDP [15] by pruning 64.35% FLOPs and improves model accuracy by 0.09%. For ResNet56, PPSM was compared to 9 state-of-the-art methods in terms of high pruning rate (∼75% drop in FLOPs) and low pruning rate (∼50% drop in FLOPs), and our method performs better than or comparably to the competitors. For instance, with more FLOPs removed (75.62% vs 73.90%), PPSM exhibits lower accuracy loss (0.22% vs 0.26%) than LRF. Our method also increases the accuracy by 0.13% with 54.6% FLOPs compression, which is better than the results of DPFPS [49] and Wang et al. [47]. Figure 4(a) depicts the change in the test accuracy of different methods with respect to the percent of reduced FLOPs, and the results suggest PPSM achieves a higher accuracy than the competitors across different FLOP reduction rates. For ResNet110, with a similar FLOP reduction rate of ∼68.5%, our method performs much better than HRank in preserving network accuracy (− 0.23% vs 0.85% accuracy loss). In addition, LRF improves model accuracy by 0.58% at 62.6% FLOP reduction, and PPSM prunes more FLOPs (68.7%) with 0.23% increase in model accuracy.
We utilize one batch of inputs to infer the distribution of the soft masks after the model converges and use the inferred distribution to determine the filters to prune. To investigate the effect of different batches of inputs on the distribution of the soft masks, we compared the inferred $m$ of the first 100 filters across 100 batches on VGG16, ResNet32, ResNet56, and ResNet110. The results in Figure 3 imply that our method is robust to the change in batches and outputs a highly consistent soft mask for each filter across different batches, suggesting that the CVAE framework and polarization regularizer adopted in PPSM help stabilize the learning of probabilistic masks and therefore make PPSM adapt well to different network architectures.

Results on ImageNet.
We further evaluated the performance of PPSM on large-scale ImageNet dataset, and also made comparisons to the state-of-the-art methods.

Computational Intelligence and Neuroscience
We evaluated the top-1 and top-5 accuracy and the FLOP reduction rate of PPSM on the ResNet50 and MobileNet v2 networks, and the results are shown in Table 3. To better verify the effectiveness of our method, the competing methods we chose are from recently published works, such as GAL [16], HRank [11], Zhuang et al. [28], GBN [17], DMC [13], DMCP [18], SCP [14], SCOP [50], LRF [27], DPFPS [49], CHIP [51], and SRR-GR [48]. For ResNet50, we conducted experiments at pruning rates of 50%, 60%, and 70%. The results show that PPSM surpasses the other methods in top-1 and top-5 accuracy when FLOPs are reduced by 60% and 70%. Specifically, compared with GAL, HRank, and CHIP, PPSM has the maximum FLOPs reduction rate of 65.91%, while its top-1 accuracy decreases by only 0.7% and its top-5 accuracy by only 0.32%, which is significantly better than the results of the other three methods. Similarly, when FLOPs are reduced by ∼70%, our method delivers higher top-1 and top-5 accuracies than the other methods. With ∼55% FLOP reduction, LRF better recovers model accuracy than DMC, SCP, and SRR-GR. The pruning rate of LRF is slightly higher than that of PPSM (56.40% vs 53.07%), but the top-1 accuracy of PPSM decreases less than that of LRF (0.35% vs 0.50%). As shown in Figure 4(b), PPSM's accuracy is less sensitive to the FLOP reduction rate, whereas the accuracy of the most advanced existing methods decreases significantly as the pruning rate increases. For the lightweight network MobileNet v2, PPSM has the lowest decrease in accuracy after pruning. Our method removes ∼28% of FLOPs with only 0.45% accuracy loss, while MetaPruning [52] causes a 0.80% drop in accuracy when 27% of FLOPs are pruned, and DPFPS prunes ∼25% of FLOPs at a cost of 0.9% accuracy loss.
Taken together, the superior performance of PPSM is attributed to the effective polarization of probabilistic soft masks in a CVAE framework, where the uncertainty on channel importance is well characterized by approximating posterior distribution of the soft masks.

The Effectiveness of the Probabilistic Mask.
To verify the effectiveness of our adopted probabilistic masks, we [...] These results demonstrate that probabilistic masks can effectively capture the uncertainty on channel importance and thus deliver more accurate identification of important channels than deterministic masks.

The Effect of Batch Size on Learning the Probabilistic Masks.
Polarization regularization encourages the masks to move towards both ends and results in a clear boundary between the two groups of separated filters, and thus makes it easier to select a threshold to remove the less important filters. As PPSM gathers statistics of the soft masks from the images within a batch to prune the filters, we further examined the effect of batch size on channel pruning. Specifically, the results based on batch sizes of 64, 128, and 256 were compared. The results in Figure 2 suggest that batch size has little effect on learning the distribution of the masks, and filters are clearly divided into two parts with soft masks close to either 0 or 1 across different batch sizes. In addition, when the batch size increases, the distance between the two peaks of the distribution also increases, suggesting the enhanced statistical strength of PPSM gained by combining the CVAE with the polarization regularization.

Conclusions
In this article, we propose a novel differentiable channel pruning method called polarization of probabilistic soft masks (PPSM). To capture the statistical behavior of channel importance, which is modeled in dynamic pruning in an input-aware manner, PPSM exploits variational inference to learn the posterior distributions of the masks and simultaneously separates the filters into two clearly distinct groups by leveraging a new polarization regularization; thus, the channels with masks close to zero can be safely removed with little effect on network accuracy. We evaluated the performance of PPSM on several popular network architectures using the CIFAR-10 and ImageNet datasets, and the results demonstrate that our method performs competitively with the state of the art. One limitation of PPSM lies in its low efficiency in learning the soft masks via the CVAE framework, and we plan to improve this in the near future.

Conflicts of Interest
The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this article.