Learnable Tree Filter for Structure-preserving Feature Transform

09/27/2019 ∙ by Lin Song, et al. ∙ Megvii Technology Limited Xi'an Jiaotong University 0

Learning discriminative global features plays a vital role in semantic segmentation. And most of the existing methods adopt stacks of local convolutions or non-local blocks to capture long-range context. However, due to the absence of spatial structure preservation, these operators ignore the object details when enlarging receptive fields. In this paper, we propose the learnable tree filter to form a generic tree filtering module that leverages the structural property of minimal spanning tree to model long-range dependencies while preserving the details. Furthermore, we propose a highly efficient linear-time algorithm to reduce resource consumption. Thus, the designed modules can be plugged into existing deep neural networks conveniently. To this end, tree filtering modules are embedded to formulate a unified framework for semantic segmentation. We conduct extensive ablation studies to elaborate on the effectiveness and efficiency of the proposed method. Specifically, it attains better performance with much less overhead compared with the classic PSP block and Non-local operation under the same backbone. Our approach is proved to achieve consistent improvements on several benchmarks without bells-and-whistles. Code and models are available at https://github.com/StevenGrove/TreeFilter-Torch.



There are no comments yet.


page 6

page 7

Code Repositories


Learnable Tree Filter for Structure-preserving Feature Transform

view repo
This week in AI

Get the week's most popular data science and artificial intelligence research sent straight to your inbox every Saturday.

1 Introduction

Scene perception, based on semantic segmentation, is a fundamental yet challenging topic in the vision field. The goal is to assign each pixel in the image with one of several predefined categories. With the developments of convolutional neural networks (CNN), it has achieved promising results using improved feature representations. Recently, numerous approaches have been proposed to capture larger receptive fields for global context aggregation 

long2015fully; chen2018deeplab; zhao2017pyramid; zhang2018context; li2019attention, which can be divided into local and non-local solutions according to their pipelines.

Traditional local approaches enlarge receptive fields by stacking conventional convolutions simonyan2014very; szegedy2015going; he2016deep or their variants (e.g., atrous convolutions chen2014semantic; chen2018deeplab). Moreover, the distribution of impact within a receptive field in deep stacks of convolutions converges to Gaussian luo2016understanding, without detailed structure preservation (the pertinent details, which are proved to be effective in feature representation dai2017deformable; zhu2018deformable). Considering the limitation of local operations, several non-local solutions have been proposed to model long-range feature dependencies directly, such as convolutional methods (e.g., non-local operations wang2018non, PSP zhao2017pyramid and ASPP modules chen2018deeplab; chen2017rethinking; chen2018encoder) and graph-based neural networks liang2017interpretable; li2018beyond; landrieu2018large. However, due to the absence of structure-preserving property, which considers both spatial distance and feature dissimilarity, the object details are still neglected. Going one step further, the abovementioned operations can be viewed as coarse feature aggregation methods, which means they fail to explicitly preserve the original structures when capturing long-range context cues.

In this work, we aim to fix this issue by introducing a novel network component that enables efficient structure-preserving feature transform, called learnable tree filter. Motivated by traditional tree filter yang2015stereo, a widely used image denoising operator, we utilize tree-structured graphs to model long-range dependencies while preserving the object structure. To this end, we first build the low-level guided minimum spanning trees (MST), as illustrated in Fig. 1

. Then the distance between vertices in MST are calculated based on the high-level semantics, which can be optimized in backpropagation. For example, the dissimilarity

between vertex and in Fig. 1 is calculated from semantic-rich feature embeddings. Thus, combined with the structural property of MST, the spatial distance and feature dissimilarity have been modeled into tree-structured graph simultaneously (e.g., the distance between vertex k and its spatially adjacent one n has been enlarged in Fig. 1, for that more edges with dissimilarities are calculated when approaching n). To enable the potential for practical application, we further propose an efficient algorithm which reduces the complexity of brute force implementation to linear-time consumption. Consequently, different from conditional random fields (CRF) chandra2017dense; harley2017segmentation; liu2017learning, the formulated modules can be embedded into several neural network layers for differentiable optimization.

In principle, the proposed tree filtering module is fundamentally different from most CNN based methods. The approach exploits a new dimension: tree-structure graph is utilized for structure-preserving feature transform, bring detailed object structure as well as long-range dependencies. With the designed efficient implementation, the proposed approach can be applied for multi-scale feature aggregation with much less resource consumption. Moreover, extensive ablation studies have been conducted to elaborate on its superiority in both performance and efficiency even comparing with PSP block zhao2017pyramid and Non-local block wang2018non. Experiments on two well-known datasets (PASCAL VOC 2012 everingham2010pascal and Cityscapes cordts2016cityscapes) also prove the effectiveness of the proposed method.

Figure 1: Toy illustration of the tree filtering module. Given a detail-rich feature map from low-level stage, we first measure the dissimilarity between each pixel and its’ quad neighbours. Then, the MST is built upon the 4-connected planar graph to formulate a learnable tree filter. The edge between two vertices denotes the distance calculated from high-level semantics. Red edges indicate the close relation with vertex . The intra-class inconsistency could be alleviated after feature transform.

2 Learnable Tree Filter

To preserve object structures when capturing long-range dependencies, we formulate the proposed learnable tree filter into a generic feature extractor, called tree filtering module. Thus, it can be easily embedded in deep neural networks for end-to-end optimization. In this section, we firstly introduce the learnable tree filtering operator. And then the efficient implementation is presented for practical applications. The constructed framework for semantic segmentation is elaborated at last.

2.1 Formulation

First, we represent the low-level feature as a undirected graph , with the dissimilarity weight for edges. The vertices are the semantic features, and the interconnections of them can be denoted as . Low-level stage feature map, which contains abundant object details, is adopted as the guidance for 4-connected planar graph construction, as illustrated in Fig.1. Thus, a spanning tree can be generated from the graph by performing a pruning algorithm to remove the edges with the substantial dissimilarity. From this perspective, the graph turns out to be the minimum spanning tree (MST) whose sum of dissimilarity weights is minimum out of all spanning trees. The property of MST ensures preferential interaction among similar vertices. Motivated by traditional tree filter yang2015stereo, a generic tree filtering module in the deep neural network can be formulated as:


Where and indicate the index of vertices, denotes the set of all vertices in the tree , represents the input encoded features and means the output signal sharing the same shape with . is a hyperedge which contains a set of vertices traced from vertex to in . The similarity function projects the features of the hyperedge into a positive scalar value, as described in Eq. 2. The unary function represents the feature embedding transformation. is the summation of similarity alone with to normalize the response.


According to the formula in Eq. 1, the tree filtering operation can be considered as one kind of weighted-average filter. The variable indicates the dissimilarity between adjacent vertices ( and ) that can be computed by a pairwise function (here we adopt Euclidean distance). The distance between two vertices ( and ) is defined as summation of dissimilarity along the path in hyperedge . Note that degenerates into spatial distance when is set to a constant matrix. Since actually measures pairwise distance in the embedded space, the aggregation along the pre-generated tree considers spatial distance and feature difference simultaneously.


The tree filtering module can be reformulated to Eq. 3. Obviously, the input feature and dissimilarity take responsibility for the output response . Therefore, the derivative of output with respect to input variables can be derived as Eq. 4 and Eq. 5. in Eq. 5 is defined with the children of vertex in the tree whose root node is the vertex .


In this way, the proposed tree filtering operator can be formulated as a differentiable module, which can be optimized by the backpropagation algorithm in an end-to-end manner.

2.2 Efficient Implementation

Let denotes the number of vertices in the tree . The tree filtering module needs to be accumulated times for each output vertex. For each channel, the computational complexity of brute force implementation is that is prohibitive for practical applications. Definition of the tree determines the absence of loop among the connections of vertices. According to this property, a well-designed dynamic programming algorithm can be used to speed up the optimization and inference process.

We introduce two sequential passes, namely aggregation and propagation, which are performed by traversing the tree structure. Let one vertex to be the root node. In the aggregation pass, the process is traced from the leaf nodes to the root node in the tree. For a vertex, its features do not update until all the children have been visited. In the propagation pass, the features will propagate from the updated vertex to their children recursively.


The sequential passes can be formulated into the recursive operators for the input : the aggregation pass and the propagation pass are respectively illustrated in Eq. 6 and Eq. 7, where par() indicates the parent of vertex in the tree whose root is vertex . is initialized from the updated value of root vertex .

Input: Tree ; Input feature ; Pairwise distance ; Gradient of loss w.r.t. output feature ; channel , vertex ; Set of vertices . Output: Output feature ; Gradient of loss w.r.t. input feature , w.r.t. pairwise distance . Preparation:   

Root vertex sampled with uniform distribution

   Breadth-first topological order for Aggr and Prop    All-ones matrix for normalization coefficient Forward: Aggregation from leaves to root Propagation from root to leaves Normalized output feature Backward: Aggregation from leaves to root Propagation from root to leaves Gradient of loss w.r.t. input feature for  do
           Parent of vertex      Gradient of unnormalized output feature      Gradient of normalization coefficient      Gradient of loss w.r.t. pairwise distance       end for
Algorithm 1 Linear time algorithm for Learnable Tree Filter

As shown in the algorithm 1, we propose a linear-time algorithm for the tree filtering module, whose proofs are provided in the appendix. In the preparation stage, we uniformly sample a vertex as the root and perform breadth-first sorting (BFS) algorithm to obtain the topological order of tree . The BFS algorithm can be accelerated by the parallel version on GPUs and ensure the efficiency of the following operations.

To compute the normalization coefficient, we construct an all-ones matrix as . Since the embedded feature is independent of the matrix

, the forward computation can be factorized into two dynamic programming processes. Furthermore, we propose an efficient implementation for the backward process. To reduce the unnecessary intermediate process, we combine the gradient of module and loss function. Note that the output

and normalization coefficient have been already computed in the inference phase. Thus the key of the backward process is to compute the intermediate variables and . The computation of these variables can be accelerated by the proposed linear time algorithm. And we adopt another iteration process for the gradient of pairwise distance.

Computational complexity. Since the number of batch and channel is much smaller than the vertices in the input feature, we only consider the influence of the vertices. For each channel, the computational complexity of all the processes in algorithm 1, including the construction process of MST and the computation of pairwise distance, is , which is linearly dependent on the number of vertices. It is necessary to point out that MST can be built in linear time using Contractive Borvka algorithm if given a planar graph, as is designed in this paper. Note that the batches and channels are independent of each other. For the practical implementation on GPUs, we can naturally perform the algorithm for batches and channels in parallel. Also, we adopt an effective scheduling scheme to compute vertices of the same depth on the tree parallelly. Consequently, the proposed algorithm reduces the computational complexity and time consumption dramatically.

2.3 Network Architecture for Segmentation

Figure 2: Overview of the proposed framework for semantic segmentation. The network is composed of a backbone encoder and a naive decoder. GAP denotes the extra global average pooling block. Multi-groups means using different feature splits to generate multiple groups of tree weights. The right diagram elaborates on the process details of a single stage tree filtering module, denoted as the green node in the decoder. Red arrows represent upsample operations. Best viewed in color.

Based on the efficient implementation algorithm, the proposed tree filtering module can be easily embedded into deep neural networks for resource-friendly feature aggregation. To illustrate the effectiveness of the proposed module, here we employ ResNet he2016deep

as our encoder to build up a unified network. The encoded features from ResNet are usually computed with output stride 32. To remedy for the resolution damage, we design a naive decoder module following previous works 

chen2017rethinking; chen2018encoder. In details, the features in the decoder are gradually upsampled by a factor of 2 and summed with corresponding low-level features in the encoder, similar to that in FPN lin2017feature. After that, the bottom-up embedding functions in the decoder are replaced by tree filter modules for multi-scale feature transform, as intuitively illustrated in Fig. 2.

To be more specific, given a low-level feature map in top-down pathway, which riches in instance details, a 4-connected planar graph can be constructed easily with the guidance of . Then, the edges with substantial dissimilarity are removed to formulate MST using the Borvka gallager1983distributed algorithm. High level semantic cues contained in are extracted using a simplified embedding function (Conv ). To measure the pairwise dissimilarity in Eq. 2 ( in Fig. 2), the widely used Euclidean distance wang2017normface is adopted. Furthermore, different groups of tree weights are generated to capture component dependent features, which will be analyzed in Sec 3.2. To highlight the effectiveness of the proposed method, the feature transformation is simplified to identity mapping, where . Thus, the learnable tree filter can be formulated by the algorithm elaborated on Sec. 2.1. Finally, the low-level feature map is fused with the operation output using pixel-wise summation. For multi-stage feature aggregation, the building blocks (green nodes in Fig. 1) are applied to different resolutions (Stage 1 to 3 in Fig. 1). An extra global average pooling operation is added to capture global context and construct another tree filtering module (Stage 4). The promotion brought by the extra components will be detailed discussed in ablation studies.

3 Experiments

In this section, we firstly describe the implementation details. Then the proposed approach will be decomposed step-by-step to reveal the effect of each component. Comparisons with several state-of-the-art benchmarks on PASCAL VOC 2012 everingham2010pascal and Cityscapes cordts2016cityscapes are presented at last.

3.1 Implementation Details

Following traditional protocols zhao2017pyramid; chen2018encoder; yu2018learning, ResNet he2016deep

is adopted as our backbone for the following experiments. Specifically, we employ the “poly” schedule with an initial learning rate 0.004 and power 0.9. The networks are optimized for 40K iterations using mini-batch stochastic gradient descent (SGD) with a weight decay of 1e-4 and a momentum of 0.9. We construct each mini-batch for training from 32 random crops (

for PASCAL VOC 2012 everingham2010pascal and for Cityscapes cordts2016cityscapes) after randomly flipping and scaling each image by 0.5 to 2.0.

Figure 3: Visualization of affinity maps in the specific position (marked by the red cross in each input image). TF and NL denotes using the proposed Tree Filtering module and Non-local block wang2018non, respectively. Different positions, resolution stages, and selected groups are explored in (a), (b), and (c), respectively. Our approach preserves more detailed structures than Non-local block. All of the input images are sampled from PASCAL VOC 2012 val set.
Figure 4: Qualitative results on PASCAL VOC 2012 val set. Given an input image from the top row, the structural cues are preserved in the corresponding prediction (the middle row). The generated results contains rich details even compared with its ground truth in the bottom row.

3.2 Ablation Studies

To elaborate on the effectiveness of the proposed approach, we conduct extensive ablation studies. First, we give detailed structure-preserving relation analysis as well as the visualization, as presented in Fig. 3. Next, the equipped stage and group number of the tree filtering module is explored. Different building blocks are compared to illustrate the effectiveness and efficiency of the tree filtering module.

Structure-preserving relations. As intuitively presented in the first row of Fig. 3, given different positions, the corresponding instance details are fully activated with the high response, which means that the proposed module has learned object structures and the long-range intra-class dependencies. Specifically, the object details (e.g., boundaries of the train rather than the coarse regions in Non-local wang2018non blocks) have been highlighted in the affinity maps. Qualitative results are also given to illustrate the preserved structural details, as presented in Fig. 4.

Which stage to equip the module? Tab. 2 presents the results when applying the tree filtering module to different stages (group number is fixed to 1). The convolutional operations are replaced by the building block (green nodes in Fig. 2) to form the equipped stage. As can be seen from Tab. 2, the network performance consistently improves with more tree filtering modules equipped. This can also be concluded from the qualitative results (the second row in Fig. 3), where the higher stage (Stage 3) contains more semantic cues and lower stages (Stage 1 and 2) focus more on complementary details. The maximum gap between the multi-stage equipped network and the raw one even up to 4.5% based on ResNet-50 backbone. Even with the powerful ResNet-101, the proposed approach still attains a 2.1% absolute gain, reaching 77.1% mIoU on PASCAL VOC 2012 val set.

Different group numbers. Different group settings are used to generate weights for the single-stage tree filtering module when fixing the channel number to 256. As shown in Tab. 2, the network reaches top performance (Group Num=16) when the group number approaching the category number (21 in PASCAL VOC 2012), and additional groups afford no more contribution. We guess the reason is that different kinds of tree weights are needed to deal with similar but different components, as shown in the third row of Fig. 3.

Different building blocks. We further compare the proposed tree filtering module with classic context modeling blocks (e.g., PSP block zhao2017pyramid and Non-local block wang2018non) and prove the superiority both in accuracy and efficiency. As illustrated in Tab. 3, the proposed module (TF) achieves better performance than others (PSP and NL) with much less resource consumption. What’s more, the tree filtering module brings consistent improvements in different backbones (5.2% for ResNet-50 with stride 32 and 2.6% for ResNet-101) with additional 0.7M parameters and 1.3G FLOPs overheads. Due to the structural-preserving property, the proposed module achieves additional 1.1% improvement over the PSP block with neglectable consumption, as presented in Tab. 3. And the extra Non-local block contributes no more gain over the tree filtering module, which could be attributed to the already modeled feature dependencies in the tree filter.

Extra components. Same with other works chen2018encoder; yu2018learning, we adopt some simple components for further improvements, including an extra global average pooling operation and additional ResBlocks in the decoder. In details, the global average pooling block combined with the Stage 4 module (see Fig. 2) is added after the backbone to capture global context, and the “Conv11” operations in the decoder (refer to the detailed diagram in Fig. 2) are replaced by ResBlocks (with “Conv3x3”). As presented in Tab. 5, the proposed method achieves consistent improvements and attains 79.4% on PASCAL VOC 2012 val set without applying data augmentation strategies.

Backbone Stage mIoU (%)
ResNet-50 None 70.2
+ Stage 1 72.1
+ Stage 1-2 73.1
+ Stage 1-3 74.7
ResNet-101 None 75.0
+ Stage 1-3 77.1
Table 2: Comparisons among different group settings of tree filtering module on PASCAL VOC 2012 val set when using the proposed decoder structure. Multi-scale and flipping strategy are adopted for testing.
Backbone Group Num mIoU (%)
ResNet-50 0 70.2
1 72.1
4 73.2
8 74.0
16 74.4
32 74.4
Table 1: Comparisons among different stages to equip tree filtering module on PASCAL VOC 2012 val set when using the proposed decoder architecture. Multi-scale and flipping strategy are adopted for testing.
Backbone Block Decoder OS Params (M) FLOPs (G) FLOPs (G) mIoU (%)
ResNet-50 None 8 129.3 162.1 0.0 69.2
NL 8 158.3 199.1 +37.0 74.2
PSP 8 178.3 171.1 +9.0 74.3
TF 8 133.3 163.1 +1.0 74.9
NL+TF 8 162.3 200.1 +38.0 74.9
PSP+TF 8 182.3 172.1 +10.0 75.4
ResNet-50 None 32 102.0 39.2 0.0 67.3
TF 32 102.7 40.5 +1.3 72.5
ResNet-101 None 32 175.0 65.0 0.0 72.8
TF 32 175.7 66.3 +1.3 75.4
Table 3: Comparisons among different building blocks on PASCAL VOC 2012 val set when using ResNet-50 as feature extractor without decoder. TF, PSP, and NL denotes using the proposed Tree Filtering module, PSP block zhao2017pyramid, and Non-local block wang2018non as the building block, respectively. OS represents the output stride used in the backbone. We calculate FLOPs when given a single scale input. All of the data augmentation strategies are dropped.
Backbone TF(multi) Extra mIoU (%)
ResNet-50 67.3
ResNet-101 78.3
Table 5: Comparisons with state-of-the-arts results on Cityscapes test set trained on fine annotation. We adopt vanilla ResNet-101 as our backbone.
Method Backbone mIoU (%)
RefineNet lin2017refinenet ResNet-101 73.6
DSSPN liang2018dynamic ResNet-101 77.8
PSPNet zhao2017pyramid ResNet-101 78.4
BiSeNet yu2018bisenet ResNet-101 78.9
DFN yu2018learning ResNet-101 79.3
PSANet zhao2018psanet ResNet-101 80.1
DenseASPP yang2018denseaspp DenseNet-161 80.6
Ours ResNet-101 80.8
Table 4: Comparisons among different components on PASCAL VOC 2012 val set. TF(multi) denotes multi-stage Tree Filtering modules with decoder. Extra represents extra components. All of the data augmentation strategies are dropped.
without MS-COCO pretrain Method Backbone mIoU (%) FCN long2015fully VGG 16 62.2 Deeplab v2 chen2018deeplab VGG 16 71.6 DPN liu2015semantic VGG 16 74.1 Piecewise lin2016efficient VGG 16 75.3 PSPNet zhao2017pyramid ResNet-101 82.6 DFN yu2018learning ResNet-101 82.7 EncNet zhang2018context ResNet-101 82.9 Ours ResNet-101 84.2 with MS-COCO pretrain Method Backbone mIoU (%) GCN peng2017large ResNet-152 82.2 RefineNet lin2017refinenet ResNet-101 84.2 PSPNet zhao2017pyramid ResNet-101 85.4 Deeplab v3 chen2017rethinking ResNet-101 85.7 EncNet zhang2018context ResNet-101 85.9 DFN yu2018learning ResNet-101 86.2 ExFuse zhang2018exfuse ResNet-101 86.2 Ours ResNet-101 86.3
Table 6: Comparisons with state-of-the-arts results on PASCAL VOC 2012 test set. We adopt vanilla ResNet-101 without atrous convolutions as our backbone.

3.3 Experiments on Cityscapes

To further illustrate the effectiveness of the proposed method, we evaluate the Cityscapes cordts2016cityscapes dataset. Our experiments involve the 2975, 500, 1525 images in train, val, and test set, respectively. With multi-scale and flipping strategy, the proposed method achieves 80.8% mIoU on Cityscapes test set when trained on fine annotation data only. Compared with previous leading algorithms, our method achieves superior performance using ResNet-101 without bells-and-whistles.

3.4 Experiments on PASCAL VOC

We carry out experiments on PASCAL VOC 2012 everingham2010pascal that contains 20 object categories and one background class. Following the procedure of yu2018learning, we use the augmented data hariharan2011semantic with 10,582 images for training and raw train-val set for further fine-tuning. In inference stage, multi-scale and horizontally flipping strategy are adopted for data augmentation. As shown in Tab. 6, the proposed method achieves the state-of-the-art performance on PASCAL VOC 2012 everingham2010pascal test set. In details, our approach reaches 84.2% mIoU without MS-COCO lin2014microsoft pre-train when adopting vanilla ResNet-101 as our backbone. If MS-COCO lin2014microsoft is added for pre-training, our approach achieves the leading performance with 86.3% mIoU.

4 Conclusion

In this work, we propose the learnable tree filter for structure-preserving feature transform. Different from most existing methods, the proposed approach leverages tree-structured graph for long-range dependencies modeling while preserving detailed object structures. We formulate the tree filtering module and give an efficient implementation with linear-time source consumption. Extensive ablation studies have been conducted to elaborate on the effectiveness and efficiency of the proposed method, which is proved to bring consistent improvements on different backbones with little computational overhead. Experiments on PASCAL VOC 2012 and Cityscapes prove the superiority of the proposed approach on semantic segmentation. More potential domains with structure relations (e.g., detection and instance segmentation) remain to be explored in the future.

5 Acknowledgment

We would like to thank Lingxi Xie for his valuable suggestions. This research was supported by the National Key R&D Program of China (No. 2017YFA0700800).


Appendix A Algorithm Proof

In this section, we present the detailed proofs for the formula in Algorithm 1. Note that the symbols follow the definition in the main paper.

Proof 1

Given one vertex as , the Eq. 8 proves the backward process of in Algorithm 1.

Proof 2

Given one edge with a pair of connected vertices and , the Eq. 9-11 proves the backward process of in Algorithm 1, where indicates the set of children of vertex in the tree whose root is and indicates the set of children of vertex in the tree whose root is .


The computational complexity of the component can be reduce to linear by the dynamic programming procedure below.


The same procedure can be easily adapted to obtain the component .