
1 Introduction

Convolutional Neural Networks (CNNs) have, during the last few years, been used with great success on a variety of computer vision problems such as image classification [12] and object detection [8]. The capability of CNNs to learn high-level abstractions of data makes them well suited for the task of image classification. Following this development, there have been several successful attempts to extend CNN-based methods to pixel-level tasks such as semantic segmentation [9, 16, 17].

A drawback of CNNs is that they cannot directly model statistical dependencies between output variables. Hence they cannot explicitly enforce smoothness constraints or encourage spatial consistency of the output, something that is arguably important for semantic segmentation. To deal with this, a Markov Random Field (MRF), or its variant the Conditional Random Field (CRF), can be used as a refinement step. This was done by Chen et al. in [4], where CNNs form the unary potentials of the dense CRF model presented by Krähenbühl et al. in [11]. However, the CNN and the CRF models are trained separately in [4], meaning that the parameters of the CRF are learnt while holding the CNN weights fixed. In other words, the deep features are learnt disregarding the statistical dependencies of the output variables. In response, several approaches for jointly training deep structured models, combining CNNs and CRFs, have recently been proposed [5, 14, 15, 24, 28]. In these approaches, as in the one presented in this paper, the parameters of the CRF and the weights of the CNN are trained jointly, making it possible to learn deep image features that take dependencies between the output variables into account.

1.1 Contributions

What differentiates this paper from previous work on learning deep structured models is mainly the joint learning algorithm. We apply a max-margin based learning framework inspired by [23]. This removes the need to calculate, or approximate, the partition function present in learning algorithms that try to maximize the log-likelihood. For instance, in [14, 28], the inference step is solved approximately using a few iterations of the mean-field algorithm or gradient descent, respectively. Similarly, in [5], sampling techniques are used to approximate the partition function. In our learning framework, standard graph cut methods can be used to perform optimal inference in the CRF model. We also show how the CNN weights can be trained to optimize the max-margin criterion via standard back-propagation. To our knowledge, we are the first to present a method for jointly training deep structured models with a max-margin objective for semantic segmentation.

Our experiments show that training deep structured models using our method gives better results than piecewise training, where the CNN and CRF models are trained separately. This indicates that jointly training deep structured models enables the model to learn deep features that take output dependencies into account, which in turn gives better segmentations. We tested our method on the Weizmann Horse dataset [2] as a proof of concept. In addition, we applied it to two medical datasets, one for heart ventricle segmentation in ultrasound images and one for pericardium segmentation in CTA slices.

1.2 Related Work

The concept of deep structured models has been examined extensively in recent work. In [20], Ning et al. combine a CNN with an energy based model, similar to an MRF, for segmentation of cell nuclei, and in [4] a dense CRF with unary potentials from a CNN is used to achieve state-of-the-art results on several semantic segmentation benchmarks.

Methods for jointly training these deep structured models have also received a lot of attention lately. In [24], Tompson et al. present a single learning framework unifying a novel ConvNet Part-Detector and an MRF-inspired Spatial-Model, achieving state-of-the-art performance on the task of human body pose recognition. Further, Chen et al. present a more general framework for joint learning of deep structured models, which they apply to image tagging and word-from-image problems in [5]. Zheng et al. [28] show that the mean-field inference algorithm with Gaussian pairwise potentials from [11] can be modeled as a Recurrent Neural Network. This enabled them to train their model within a standard deep learning framework using a log-likelihood loss. In [15], a CRF model whose unary and pairwise potentials are estimated by CNNs is trained in a piecewise manner.

In the field of medical image analysis, methods based on CNNs have also received increasing interest during the last few years, with promising results [6, 19, 22]. Recently, more intricate deep learning approaches have been proposed. Ronneberger et al. [21] proposed the U-Net, a network based on the idea of “fully convolutional networks” [16]. A similar network structure was also proposed by Brosch et al. in [3]. However, to our knowledge, methods utilizing end-to-end training of deep structured models have yet to be presented for medical image segmentation tasks.

2 A Deep Conditional Random Field Model

The deep structured model proposed in this paper consists of a CNN coupled with a CRF. This setup allows the model to learn deep features while still taking dependencies in the output data into account. Denote the set of input instances by \(X = \{\varvec{x}^{(n)}\}_n\) and their corresponding labelings by \(Y = \{\varvec{y}^{(n)}\}_n\). Each input and output instance is an image with N pixels, \(\varvec{x}^{(n)} = (x^{(n)}_1, \ldots , x^{(n)}_N)\) and \(\varvec{y}^{(n)} = (y^{(n)}_1, \ldots , y^{(n)}_N)\) respectively. We only consider the binary labeling case, hence \(y^{(n)}_i \in \{0,1\}\). Our deep structured model is described by a CRF of the form

$$\begin{aligned} P(Y|X;\varvec{w},\varvec{\theta }) = \frac{1}{Z}e^{-\sum _n E(\varvec{y}^{(n)},\varvec{x}^{(n)};\varvec{w},\varvec{\theta })}, \end{aligned}$$
(1)

where \(\varvec{w}\) are the weights of the CRF, \(\varvec{\theta }\) are the weights of the CNN and Z is the partition function. The energy E decomposes into unary and pairwise terms of the following form

$$\begin{aligned} E(\varvec{y},\varvec{x};\varvec{w},\varvec{\theta }) = \sum _{i \in \mathcal {V}} \varvec{\phi }_i(y_i,\varvec{x};\varvec{w},\varvec{\theta }) + \sum _{(i,j) \in \mathcal {E}} \varvec{\phi }_{ij}(y_i,y_j,\varvec{x};\varvec{w}), \end{aligned}$$
(2)

where \(\mathcal {V}\) is the set of nodes (i.e. pixels) and \(\mathcal {E}\) is the set of edges connecting neighbouring pixels.

The unary term of the energy E has the following form

$$\begin{aligned} \varvec{\phi }_i(y_i,\varvec{x};\varvec{w},\varvec{\theta }) = w_1 \log (\varPhi _i(y_i,\varvec{x};\varvec{\theta })), \end{aligned}$$
(3)

where \(\varPhi _i(y_i,\varvec{x};\varvec{\theta })\) denotes the output of the neural network for pixel i. There are no explicit requirements on the CNN except that it should output, for each pixel, an estimate of the probability of being either foreground or background.
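
For concreteness, the following minimal sketch shows the only interface the CRF needs from the CNN: a map of per-pixel foreground probabilities. The paper uses the FCN-8 network of [16] (see Sect. 3); here torchvision's FCN-ResNet50 serves purely as an illustrative stand-in, and the names are assumptions rather than the paper's implementation.

```python
import torch
from torchvision.models.segmentation import fcn_resnet50

# Any per-pixel classifier works; FCN-ResNet50 is only a stand-in here.
net = fcn_resnet50(num_classes=2)            # two classes: background/foreground
net.eval()
x = torch.randn(1, 3, 224, 224)              # a dummy RGB input image
with torch.no_grad():
    logits = net(x)["out"]                   # shape (1, 2, H, W)
probs = torch.softmax(logits, dim=1)[0, 1]   # (H, W) foreground probabilities
```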

The pairwise term consists of two parts, both penalizing neighbouring pixels that are labeled differently. The first part adds a constant cost while the second adds a cost based on the contrast of the neighbouring pixels. If \(\mathbbm {1}_{y_i \ne y_j}\) denotes the indicator function, equaling one if \(y_i \ne y_j\) and zero otherwise, the pairwise term has the following form

$$\begin{aligned} \varvec{\phi }_{ij}(y_i,y_j,\varvec{x};\varvec{w}) = \mathbbm {1}_{y_i \ne y_j} \left( w_2 + w_3 \; e^{- \frac{(x_i-x_j)^2}{2}}\right) . \end{aligned}$$
(4)

Note that, given these unary and pairwise terms, the energy is linear with respect to the weights \(\varvec{w}\).
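
Since the model is fully specified by (2)-(4), the energy of a candidate labeling can be evaluated directly. The numpy sketch below does this for a 4-connected grid, assuming `probs` holds the CNN output \(\varPhi _i(y_i = 1,\varvec{x};\varvec{\theta })\); the function and variable names are illustrative, not the paper's code.

```python
import numpy as np

def energy(y, x, probs, w):
    """E(y, x; w) from eq. (2); y, x, probs: (H, W) arrays, w = (w1, w2, w3)."""
    eps = 1e-12                                   # guard against log(0)
    p = np.where(y == 1, probs, 1.0 - probs)      # Phi_i(y_i, x; theta)
    E = w[0] * np.sum(np.log(p + eps))            # unary terms, eq. (3)
    # Pairwise terms, eq. (4), over horizontal and vertical neighbour pairs.
    for a, b, xa, xb in [(y[:, :-1], y[:, 1:], x[:, :-1], x[:, 1:]),
                         (y[:-1, :], y[1:, :], x[:-1, :], x[1:, :])]:
        diff = (a != b)                           # indicator 1_{y_i != y_j}
        E += np.sum(diff * (w[1] + w[2] * np.exp(-(xa - xb) ** 2 / 2)))
    return E
```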

2.1 Inference

Given an input instance \(\varvec{x}\), the inference problem amounts to finding the maximum a posteriori labeling \(\varvec{y}^*\) given the model in (1). This is equivalent to finding a minimizer of the energy E in (2):

$$\begin{aligned} \varvec{y}^* = \mathop {\text {arg min}}\limits _{\varvec{y}}~E(\varvec{y},\varvec{x};\varvec{w},\varvec{\theta }). \end{aligned}$$
(5)

For our deep structured model the inference is done in two steps. First, an estimate of the probability of each pixel being either foreground or background is computed by a forward pass of the CNN. Second, problem (5) is solved. We add the constraints \(w_i \ge 0\), \(i=2,3\), when learning the weights to make the energy E submodular. This means that a graph cut algorithm can be used to efficiently find a global optimum [10].
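
Below is a minimal sketch of this two-step inference, assuming the PyMaxflow library for the graph cut (the paper's implementation is in Matlab, so this is an illustration, not the original code). Because a constant shift per pixel does not change the minimizer, the unaries are shifted so that all capacities are non-negative.

```python
import numpy as np
import maxflow  # PyMaxflow

def infer(x, probs, w):
    """Solve (5) exactly for one image; x, probs: (H, W) arrays."""
    eps = 1e-12
    E1 = w[0] * np.log(probs + eps)        # unary cost of y_i = 1, eq. (3)
    E0 = w[0] * np.log(1.0 - probs + eps)  # unary cost of y_i = 0
    base = np.minimum(E0, E1)              # shift so all capacities are >= 0

    g = maxflow.Graph[float]()
    nodes = g.add_grid_nodes(x.shape)
    g.add_grid_tedges(nodes, E0 - base, E1 - base)

    H, W = x.shape
    for i in range(H):                     # 4-connected pairwise edges, eq. (4)
        for j in range(W):
            for di, dj in ((0, 1), (1, 0)):
                if i + di < H and j + dj < W:
                    cap = w[1] + w[2] * np.exp(-(x[i, j] - x[i + di, j + dj]) ** 2 / 2)
                    g.add_edge(nodes[i, j], nodes[i + di, j + dj], cap, cap)

    g.maxflow()
    # Nodes left in the source segment pay the sink capacity E1, i.e. y_i = 1.
    return np.logical_not(g.get_grid_segments(nodes)).astype(int)
```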

2.2 Max-Margin Learning

There are two sets of learnable parameters: the weights of the CRF, \(\varvec{w}\), and the weights of the CNN, \(\varvec{\theta }\). The learning method is based on an algorithm proposed by Szummer et al. [23], where the goal is to find a set of parameters \(\varvec{w},\varvec{\theta }\) such that

$$\begin{aligned} E(\varvec{y}^{(n)},\varvec{x}^{(n)};\varvec{w},\varvec{\theta }) \le E(\varvec{y},\varvec{x}^{(n)};\varvec{w},\varvec{\theta }) \quad \forall \varvec{y} \ne \varvec{y}^{(n)}, \end{aligned}$$
(6)

i.e. we want to learn a set of weights that assigns the ground truth labeling an energy lower than or equal to that of any other labeling. Since this problem might have multiple solutions, or none, we introduce a margin \(\zeta \) and try to maximize it according to

$$\begin{aligned} \begin{aligned}&\max _{\varvec{w}:|\varvec{w}| = 1} \; \; \; \; \zeta \\ s.t. \; \; \; \; E(\varvec{y},\varvec{x}^{(n)};\varvec{w},\varvec{\theta })&- E(\varvec{y}^{(n)},\varvec{x}^{(n)};\varvec{w},\varvec{\theta }) \ge \zeta \quad \forall \varvec{y} \ne \varvec{y}^{(n)}. \\ \end{aligned} \end{aligned}$$
(7)

Finding the set of parameters that gives the largest margin regularizes the problem and tends to generalize well to unseen data. However, for the final objective we make a few changes suggested by Szummer et al. [23]. First, a slack variable \(\xi _n\) for each training sample is introduced to make the method more robust to noisy data. In addition, we use a rescaled margin, demanding a larger energy margin for labelings that differ much from the ground truth. Finally, the program described in (7) includes an exponential number of constraints, which makes solving it directly intractable; we therefore perform the optimization over a much smaller set \(S^{(n)}\). These changes, together with the variable transformation \(\Vert \varvec{w}\Vert \leftarrow 1/\zeta \), give rise to the following problem

$$\begin{aligned} \begin{aligned} \min _{\varvec{w},\varvec{\xi }} \quad \frac{1}{2} \Vert \varvec{w}\Vert ^2 + \frac{C}{N} \sum _{n=1}^N \xi _n&\\ \text { s.t. } \quad E(\varvec{y},\varvec{x}^{(n)};\varvec{w},\varvec{\theta }) - E(\varvec{y}^{(n)},\varvec{x}^{(n)};\varvec{w},\varvec{\theta }) \ge \varDelta (\varvec{y}^{(n)},\varvec{y}) - \xi _n&\quad \forall \varvec{y} \in S^{(n)}, \\ \xi _n \ge 0, \quad w_2 \ge 0, \quad w_3 \ge 0,&\end{aligned} \end{aligned}$$
(8)

where N is the number of training samples, C is a hyperparameter regulating the slack penalty and \(\varDelta (\varvec{y}^{(n)},\varvec{y})\) is the Hamming loss \(\varDelta (\varvec{y}^{(n)},\varvec{y}) = \sum _i \mathbbm {1}_{y_i^{(n)} \ne y_i}\).

The constraint set \(S^{(n)}\) is iteratively grown by adding labelings that violate the constraints in (8) the most. For each iteration, the weights are then updated to satisfy the new, larger constraint set. This weight update is repeated until the weights no longer change. The complete learning algorithm is summarized in Algorithm 1.

[Algorithm 1: max-margin learning of the CRF weights, as described above]
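
The sketch below makes Algorithm 1 concrete under stated assumptions: since the energy is linear in \(\varvec{w}\) it is written as \(E = \varvec{w} \cdot \psi (\varvec{y},\varvec{x})\), the QP (8) is solved with cvxpy, and `loss_augmented_inference` is a hypothetical helper (the graph cut routine of Sect. 2.1 with the Hamming loss folded into the unary terms, which finds the most violated constraint). This is not the paper's Matlab implementation.

```python
import numpy as np
import cvxpy as cp

def joint_features(x, y, probs, eps=1e-12):
    """psi(y, x) such that E(y, x; w) = w . psi; see eqs. (2)-(4)."""
    p = np.where(y == 1, probs, 1.0 - probs)
    f1 = np.sum(np.log(p + eps))                       # unary statistic
    f2 = f3 = 0.0
    for a, b, xa, xb in [(y[:, :-1], y[:, 1:], x[:, :-1], x[:, 1:]),
                         (y[:-1, :], y[1:, :], x[:-1, :], x[1:, :])]:
        d = (a != b)
        f2 += d.sum()                                  # constant pairwise part
        f3 += np.sum(d * np.exp(-(xa - xb) ** 2 / 2))  # contrast part
    return np.array([f1, f2, f3])

def learn_crf_weights(data, C=1.0, max_iters=20, tol=1e-6):
    """data: list of (x, y_gt, probs) triples; returns the CRF weights w
    and the constraint sets S. `loss_augmented_inference` is assumed."""
    N, D = len(data), 3
    w = np.zeros(D)
    S = [[] for _ in range(N)]                         # constraint sets S^(n)

    for _ in range(max_iters):
        # Grow each S^(n) with the currently most violating labeling.
        for n, (x, y_gt, probs) in enumerate(data):
            S[n].append(loss_augmented_inference(x, y_gt, probs, w))

        # Re-solve the QP (8) over the grown constraint sets.
        w_var, xi = cp.Variable(D), cp.Variable(N)
        cons = [xi >= 0, w_var[1] >= 0, w_var[2] >= 0]  # submodularity
        for n, (x, y_gt, probs) in enumerate(data):
            psi_gt = joint_features(x, y_gt, probs)
            for y in S[n]:
                delta = np.sum(y != y_gt)               # Hamming loss
                cons.append(w_var @ (joint_features(x, y, probs) - psi_gt)
                            >= delta - xi[n])
        prob = cp.Problem(
            cp.Minimize(0.5 * cp.sum_squares(w_var) + C / N * cp.sum(xi)),
            cons)
        prob.solve()

        if np.linalg.norm(w_var.value - w) < tol:       # weights unchanged
            break
        w = w_var.value
    return w, S
```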

2.3 Back-Propagation of Error Derivatives

In this section, we show how the max-margin objective from the previous section can be optimized for our coupled CNN and CRF model. Our main goal during learning is to maximize the margin, or equivalently, minimize the objective \(\gamma \) as defined in (8). To be able to perform a gradient based weight update we need to calculate the derivative of this objective with respect to the weights of the network

$$\begin{aligned} \frac{\partial \gamma }{\partial \theta _j} = \sum _n \sum _i \frac{\partial \gamma }{\partial \varPhi _i} \frac{\partial \varPhi _i}{\partial \theta _j}, \end{aligned}$$
(9)

where the two sums are over the training instances, n, and the pixels, i. As previously, \(\varPhi _i\) is the output of the network. Given a well-defined network structure, the term \(\frac{\partial \varPhi _i}{\partial \theta _j}\) can be easily calculated using standard back-propagation. Henceforth we will focus on calculating the term \(\frac{\partial \gamma }{\partial \varPhi _i}\). To simplify notation we introduce \(z_i\) as the output of the network for pixel i being foreground, \(z_i = \varPhi _i(y_i = 1,\varvec{x};\varvec{\theta })\). We start by expressing (8) in the following compact form

$$\begin{aligned} \begin{aligned} \gamma (\varvec{z})&= \min \limits _{\varvec{w},\varvec{\xi }} f (\varvec{w},\varvec{\xi }), \\ \text{ s.t. } \quad h_k(\varvec{w},\varvec{\xi },\varvec{z})&\le 0, \quad k=1,\ldots , M , \end{aligned} \end{aligned}$$
(10)

where f is the objective function, \(h_k\) characterize the constraints and M is the total number of constraints. We will treat \(\gamma \) as a function depending on the network output, \(\gamma (\varvec{z})\).

In addition, the minimizers \(\varvec{w}^*\) and \(\varvec{\xi }^*\) can also be seen as functions of \(\varvec{z}\), that is, \(\varvec{w}^* = \varvec{w}^*(\varvec{z})\) and \(\varvec{\xi }^* = \varvec{\xi }^*(\varvec{z})\), which gives that

$$\begin{aligned} \begin{aligned} \gamma (\varvec{z})&= f(\varvec{w}^*(\varvec{z}),\varvec{\xi }^*(\varvec{z})) = \frac{1}{2} \Vert \varvec{w}^*\Vert ^2 + \frac{C}{N} \sum _{n=1}^N \xi ^*_n, \\ \frac{\partial \gamma }{\partial z_i}&= \sum _{j=1}^D f_{w_j} \frac{\partial w_j}{\partial z_i} + \sum _{n=1}^N f_{\xi _n} \frac{\partial \xi _n}{\partial z_i} = \sum _{j=1}^D w_j \frac{\partial w_j}{\partial z_i} + \frac{C}{N}\sum _{n=1}^N \frac{\partial \xi _n}{\partial z_i}, \end{aligned} \end{aligned}$$
(11)

where D is the number of weights and N is the number of slack variables. To be able to calculate \(\frac{\partial \gamma }{\partial z_i}\) we need \(\frac{\partial w_j}{\partial z_i}\) and \(\frac{\partial \xi _n}{\partial z_i}\). These derivatives are found by creating and solving a system of equations from the optimality conditions of the problem. The Lagrangian for the constrained minimization problem in (10) is

$$\begin{aligned} L(\varvec{w},\varvec{\xi },\varvec{\lambda }) = f(\varvec{w},\varvec{\xi }) + \sum _{k=1}^{M} \lambda _k h_k(\varvec{w},\varvec{\xi }), \end{aligned}$$

where \(\varvec{\lambda }\) is the vector of Lagrangian multipliers with elements \(\lambda _k\). At the optimum, the first-order optimality conditions are satisfied:

$$\begin{aligned} \nabla _{\varvec{w}} L = \varvec{w} + \sum _{k=1}^{M} \lambda _k \nabla _{\varvec{w}} h_k = \varvec{0} \text{ and } \nabla _{\varvec{\xi }} L = \frac{C}{N}\varvec{1} + \sum _{k=1}^{M} \lambda _k \nabla _{\varvec{\xi }} h_k = \varvec{0}. \end{aligned}$$
(12)

Now, the conditions for the implicit function theorem are satisfied and we also get that

$$\begin{aligned} \frac{\partial ( \nabla _{\varvec{w}}L)}{\partial z_i}&= \frac{\partial \varvec{w}}{\partial z_i} + \sum _{k=1}^{M} \left( \frac{\partial \lambda _k}{\partial z_i} \nabla _{\varvec{w}}h_k +\lambda _k \frac{\partial \nabla _{\varvec{w}} h_k}{\partial z_i} \right) = \varvec{0}, \end{aligned}$$
(13)
$$\begin{aligned} \frac{\partial ( \nabla _{\varvec{\xi }}L)}{\partial z_i}&= \sum _{k=1}^{M} \left( \frac{\partial \lambda _k}{\partial z_i} \nabla _{\varvec{\xi }}h_k +\lambda _k \frac{\partial \nabla _{\varvec{\xi }} h_k}{\partial z_i} \right) = \varvec{0}. \end{aligned}$$
(14)

Note that \(\lambda _k\) is a function of \(\varvec{z}\). For the active constraints, where \(h_k = 0\), it holds that \(\frac{\partial h_k}{\partial z_i} = 0\), since these constraints remain active under small perturbations of \(\varvec{z}\). For the passive constraints, \(h_k < 0\), we use the following identities:

$$\begin{aligned} \lambda _k = 0 \quad \text {and} \quad \frac{\partial (\lambda _k h_k)}{\partial z_i} = 0. \end{aligned}$$
(15)

The equations in (12) to (15) give a linear system of equations with the unknowns \(\frac{\partial w_j}{\partial z_i}, \frac{\partial \xi _n}{\partial z_i},\lambda _k\) and \(\frac{\partial \lambda _k}{\partial z_i}\). Solving this system enables us to calculate \(\frac{\partial \gamma }{\partial z_i}\) from (11) and finally \(\frac{\partial \gamma }{\partial \theta _j}\) according to (9). Having this derivative makes it possible to learn CNN weights that optimize the max-margin objective in (8) using gradient-based methods. For more details, see the supplementary material.
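
To make the last step concrete, here is a PyTorch-style sketch of how the hand-derived gradient can be injected into standard back-propagation (PyTorch is an assumption for illustration; the paper's implementation uses MatConvNet). The helper `solve_kkt_system`, which assembles and solves the linear system given by (12)-(15) and returns the array of derivatives \(\frac{\partial \gamma }{\partial z_i}\), is hypothetical.

```python
import torch

def backward_step(net, x, optimizer, solve_kkt_system):
    """One gradient-based update of the CNN weights theta."""
    z = net(x)                              # forward pass: estimates z_i
    dgamma_dz = solve_kkt_system(z.detach().cpu().numpy())  # assumed helper
    optimizer.zero_grad()
    # Chain rule (9): seeding backward() with dgamma/dz makes autograd
    # accumulate dgamma/dtheta_j in each parameter's .grad field.
    z.backward(torch.as_tensor(dgamma_dz, dtype=z.dtype))
    optimizer.step()                        # theta <- theta - eta * grad
```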

2.4 End-to-End Training in Batches

We have now derived all the theoretical tools needed to train our deep structured model in an end-to-end manner. The joint training is done in epochs, where all training samples are utilized in each epoch. In every training epoch, new CRF weights are computed and the CNN weights are updated by gradient descent: \(\theta _j \leftarrow \theta _j - \eta \frac{\partial \gamma }{\partial \theta _j}\) for all j.

To facilitate the process of learning deep image features for the CNN we first pretrain the weights \(\varvec{\theta }\) on the dataset without the CRF part of the model. Note that the CNN we used is based on a network pretrained on the ImageNet dataset [7]. The pretraining is done using stochastic gradient descent with a standard pixelwise log-likelihood error function.

The original learning method involves the entire training set when computing the CRF weights. However, since the linear system of equations that needs to be solved grows with the number of training instances, the learning process quickly becomes impractical as the number of images increases. Hence we propose to compute the derivatives in batches. In batch mode we apply the CRF learning method from Algorithm 1 to each batch separately and calculate \(\frac{\partial \gamma _b}{\partial \theta _j}\) following the steps described in Sect. 2.3. Note that the objective \(\gamma _b\) that we actually minimize here is an approximation of the true objective, since not all images are included. For each batch, the constraint set \(S_b^{(n)}\) is saved. These sets are, at the end of the epoch, merged into a set \(S^{(n)}\) containing the low-energy labelings for all training instances. Finally, the optimization problem in (8) is solved with this \(S^{(n)}\) to obtain the CRF weights. When solving for the CRF weights we also obtain the current value of the objective \(\gamma \), which should decrease during training. The algorithm is summarized in Algorithm 2.

[Algorithm 2: end-to-end training in batches, as described above]
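
Tying the pieces together, the following compact sketch outlines Algorithm 2 under the same assumptions as the earlier sketches; `learn_crf_weights` and `backward_step` refer to the sketches in Sects. 2.2 and 2.3, while `predict` (a CNN forward pass returning per-pixel probabilities) and `solve_qp_over` (the QP step of (8) re-run on the merged constraint sets) are hypothetical helpers.

```python
def make_batches(data, batch_size):
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

def train_joint(net, data, optimizer, solve_kkt_system,
                epochs=50, batch_size=10, C=1.0):
    """data: list of (x, y_gt) pairs; end-to-end training as in Algorithm 2."""
    for epoch in range(epochs):
        merged_S = []
        for batch in make_batches(data, batch_size):
            probs = [predict(net, x) for x, _ in batch]        # assumed helper
            triples = [(x, y, p) for (x, y), p in zip(batch, probs)]
            # Algorithm 1 on this batch only; gamma_b approximates gamma.
            w_b, S_b = learn_crf_weights(triples, C=C)         # Sect. 2.2 sketch
            for x, _ in batch:                                 # CNN weight update
                backward_step(net, x, optimizer, solve_kkt_system)  # Sect. 2.3
            merged_S.extend(S_b)                               # save S_b^(n)
        # End of epoch: re-solve (8) on the merged constraint sets; the
        # resulting objective gamma should decrease over the epochs.
        w, gamma = solve_qp_over(merged_S, C)                  # assumed helper
    return net, w
```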

3 Experiments and Results

We now present the performance of our method on three different segmentation tasks, including comparisons to two baselines. For the first baseline, “CNN (only)”, the segmentation is created by thresholding the output of a pretrained CNN. For the second baseline, “CNN + CRF (piecewise)”, a CNN coupled with a CRF is trained in a piecewise manner, meaning that the network weights are kept fixed while learning the CRF weights. The results for the joint learning are denoted “CNN + CRF (joint)”. For all experiments the CNN had the same structure as the FCN-8 network introduced by Long et al. [16]. The parameter settings were the same for all three segmentation tasks (learning rate = \(10^{-4}\), batch size = 10 and \(C = 1\)). All routines for training and testing were implemented in Matlab on top of MatConvNet [25].

3.1 Weizmann Horse Dataset

The Weizmann Horse dataset [2] is widely used for benchmarking object segmentation algorithms. The dataset contains 328 images of horses in different environments. We divide these images into a training set of 150 images, a validation set of 50 images and a test set of 128 images.

Our algorithm is compared to the, to our knowledge, best previously published results on the dataset: ReSeg [26], CRF-Grad [14] and PatchCut [27]. There are a few variations of the Weizmann Horse dataset available; we used the same one as in PatchCut [27]. Our algorithm is also compared to the two baselines, “CNN (only)” and “CNN + CRF (piecewise)”. Quantitative results (mean Jaccard index) are shown in Table 1 for the test images. In Fig. 1 some qualitative results are presented.

Table 1. Mean Jaccard index for the Weizmann Horse dataset (test set).
Fig. 1.

Qualitative results on the Weizmann Horse dataset. “Piecewise” denotes “CNN + CRF (piecewise)” and “Joint” denotes “CNN + CRF (joint)”. The number shown in the upper right corner is the Jaccard index (%).

3.2 Cardiac Ultrasound Dataset

The second dataset we consider consists of 2D cardiac ultrasound images (2-chamber view, i.e. the left atrium and the left ventricle are visible). The ground truth consists of manual annotations of the left ventricle made by an experienced cardiologist according to the protocol in [13]. The dataset contains 66 images which are divided into a training set of 33 images, a validation set of 17 images and a test set of 16 images. See Fig. 2 and Table 2 for qualitative and quantitative results respectively.

Table 2. Quantitative results for the Cardiac Ultrasound dataset (US) and the Cardiac CTA dataset (CTA). For the CTA dataset, the different types of slices are evaluated separately (ax - axial, cor - coronal and sag - sagittal). The mean Jaccard index (%) for the test sets is reported.
Fig. 2.

Qualitative results on the Cardiac ultrasound dataset (US) and the Cardiac CTA dataset (CTA). For the CTA dataset, the different types of slices are evaluated separately (ax - axial, cor - coronal and sag - sagittal). “Piecewise” denotes “CNN + CRF (piecewise)” and “Joint” denotes “CNN + CRF (joint)”. The red number shown in the upper right corner is the Jaccard index (%).

3.3 Cardiac CTA Dataset

The third dataset we consider consists of 2D slices of cardiac CTA volumes originating from the SCAPIS pilot study [1]. The ground truth consists of slice-wise manual annotations of the pericardium made by a specialist in thoracic radiology according to the protocol in [18]. The dataset includes a total of 1500 2D slices, divided into three subsets of equal size representing three different views (axial, coronal and sagittal), which are evaluated separately. For each view the 2D slices were divided into a training set of 300 images, a validation set of 100 images and a test set of 100 images. Some of the 2D slices originate from regions where the pericardium is not visible; these images were excluded from the quantitative results since the Jaccard index is undefined when the ground truth and the segmentation are both empty. Some qualitative results of the joint training process are visualized in Fig. 2 and quantitative results are presented in Table 2.

4 Conclusion and Future Work

In this paper, we have proposed a segmentation algorithm based on a deep structured model consisting of a CNN paired with a CRF. We also presented a method for jointly learning the parameters of the CNN and the CRF using a max-margin approach. Conveniently, the max-margin objective could be optimized with standard back-propagation thanks to the theoretical results derived in Sect. 2.3.

We achieve superior results on two smaller medical datasets compared to using a CNN only and using a CNN paired with a separately trained CRF. Note that the CNN we used is based on a network pretrained on the ImageNet dataset [7]. It has hence learnt image features for standard RGB images and for classification tasks, which makes it more challenging to learn CNN weights well-adjusted for medical image segmentation. In spite of this, we still achieve good results on the two medical datasets. A future continuation of this work would be to combine the CRF with a CNN trained on a larger set of medical images. Also, implementing the framework in 3D would increase its usability for medical applications.

In addition, other types of CRFs could be used. The ones considered in this paper only include pairwise terms depending on neighbouring pixels. One possible extension would be to consider longer-distance relationships or higher-order energy terms. Also, the pairwise terms could be learned with a trainable CNN in the same way as the unary terms. A trainable regularization term could enable the model to learn even more sophisticated relationships between the output pixels.

We gratefully acknowledge funding from SSF (Semantic Mapping and Visual Navigation for Smart Robots) and VR (project no. 2016-04445).