
1 Introduction

Segmentation is one of the most important tasks in medical image analysis. In recent years, with the help of deep learning, many inspiring advances have been made in this field. However, most medical images are of high resolution and cannot be directly processed by mainstream graphics cards. Thus, many previous works use 2D networks, which only segment a single slice at a time [4, 7, 8, 18, 20]. Nevertheless, such methods ignore the valuable information along the z-axis, which limits the improvement of the model’s performance. To better capture the information along all three dimensions, algorithms such as 2D multi-view [23, 24] and 2.5D [19, 28] methods have been developed, alleviating the problem to some extent. However, these methods still mainly use 2D convolutions to extract features and cannot capture the overall 3D spatial information. Therefore, the more thorough solution is to use intuitive 3D convolutions [5, 10, 11, 14, 17]. Since training 3D segmentation models requires more computation, patch sampling, in which only a small part of the original medical image is cropped as the model’s input each time, becomes a necessity [15, 21, 25].

Fig. 1. (a) The Dice Similarity Coefficient, Jaccard Coefficient and (b) 95% Hausdorff Distance of different patch sizes using 3D ResUNet on two datasets. As shown, a bigger patch size with more context tends to yield better performance.

Though widely used, patch sampling also has some flaws. Firstly, patch-based methods ignore the global context, which is important for accurate segmentation [27]. As shown in Fig. 1, our experiments illustrate that the patch size greatly affects the model’s performance: because bigger patches contain more context, they usually achieve higher accuracy. Secondly, if the network is trained with patches, it also has to use patches (e.g., a sliding-window strategy) in the inference stage, which not only severely decreases efficiency but also reduces accuracy due to inconsistencies introduced during the fusion of overlapping patches [9].

To solve these problems, we need a patch-free segmentation method with a moderate computational budget. Motivated by the super-resolution (SR) technique, which recovers high-resolution (HR) images from low-resolution (LR) inputs, we concretize our idea by lowering the resolution of the input image. We propose a novel 3D patch-free segmentation method that realizes HR segmentation with an LR input (i.e., the down-sampled 3D image). We call this kind of task Upsample-Segmentation (US). Inspired by [22], we use SR as an auxiliary task for the US task to restore the HR details lost in the down-sampling procedure. In addition, we introduce a Self-Supervised Guidance Module (SGM), which uses a patch of the original HR image as a subsidiary guidance input. High-frequency features can be directly extracted from it and then concatenated with the features from the down-sampled input image. To further improve the model’s performance, we also propose a novel Task-Fusion Module (TFM) to exploit the connections between the US and SR tasks. Note that TFM as well as the auxiliary SR branch are only used in the training phase; they introduce no extra computational cost at prediction time.

Our contributions can be summarized in three points: (1) We propose a patch-free 3D medical image segmentation method that realizes HR segmentation with an LR input. (2) We propose a Self-Supervised Guidance Module (SGM), which uses an original HR image patch as guidance to preserve the high-frequency representations needed for accurate segmentation. (3) We further design a Task-Fusion Module (TFM) to exploit the interconnections between the US and SR tasks, by which the two tasks can be optimized jointly.

2 Methodology

The proposed method is shown in Fig. 2. For a given image, we first down-sample it by 2\(\times \) to lower the resolution, and then use it as the framework’s main input. The encoder processes it into shared features that serve both the SR and US tasks. In addition, we crop a patch from the original HR image as guidance, using the features extracted from it to provide the network with more high-frequency information. In the training phase, the outputs of the US and SR tasks are also sent into the Task-Fusion Module (TFM), where the two tasks are fused so that they help each other optimize. Note that in the testing phase, only the main segmentation network is used for 3D segmentation, with no extra computational cost.

Fig. 2. A schematic view of the proposed framework.
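To make the input pipeline concrete, the following PyTorch sketch (not the authors' released code; `prepare_inputs` is a hypothetical helper) down-samples an HR volume by 2\(\times \) and takes a central guidance crop covering 1/64 of the volume, i.e., one quarter of each spatial dimension, following Sect. 2.2.

```python
# Minimal sketch of the input preparation: 2x trilinear down-sampling plus a
# central guidance crop (assumptions: tensors are (B, C, D, H, W), crop is central).
import torch
import torch.nn.functional as F

def prepare_inputs(hr_image: torch.Tensor):
    """hr_image: (B, C, D, H, W) high-resolution volume."""
    # Main input: down-sample by 2x along every spatial axis.
    lr_image = F.interpolate(hr_image, scale_factor=0.5,
                             mode="trilinear", align_corners=False)
    # Guidance input: central crop covering 1/64 of the original volume
    # (one quarter of each dimension).
    _, _, D, H, W = hr_image.shape
    d, h, w = D // 4, H // 4, W // 4
    zs, ys, xs = (D - d) // 2, (H - h) // 2, (W - w) // 2
    guidance = hr_image[:, :, zs:zs + d, ys:ys + h, xs:xs + w]
    return lr_image, guidance
```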

2.1 Multi-task Learning

Multi-task learning is the foundation of our framework, which includes a US task and an SR task. The ground truth for US is the labeled segmentation mask at the original high resolution, while that of SR is the HR image itself. Since our goal is to generate an accurate segmentation mask, we treat US as the main task and SR only as an auxiliary one, which can be removed in the testing phase. The two branches are both designed on the basis of ResUNet3D [26] for better consistency, since they share an encoder and will be fused together afterwards. The details are shown in Fig. 3(a).

Fig. 3. Detailed structure of the framework. (a) Multi-task learning as the foundation of the framework. (b) Residual Block used in the framework. (c) Network structure of the Self-Supervised Guidance Module (SGM).

The loss functions for this part consist of a segmentation loss \(L_{seg}\) and an SR loss \(L_{sr}\), one per task. \(L_{seg}\) consists of a Binary Cross Entropy (BCE) loss and a Dice loss, while \(L_{sr}\) is a simple Mean Square Error (MSE) loss.
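A minimal PyTorch sketch of these two terms is given below; the relative weighting of BCE and Dice is not specified in the text, so simply summing them is an assumption, and `pred_mask`/`pred_sr` are hypothetical names for the sigmoid segmentation probabilities and the reconstructed HR image.

```python
# Illustrative L_seg (BCE + Dice) and L_sr (MSE); not the authors' exact code.
import torch
import torch.nn.functional as F

def dice_loss(pred, target, eps=1e-5):
    # Soft Dice over all voxels; eps avoids division by zero.
    inter = (pred * target).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def seg_loss(pred_mask, gt_mask):
    # BCE expects probabilities in [0, 1]; Dice and BCE are summed here (assumption).
    return F.binary_cross_entropy(pred_mask, gt_mask) + dice_loss(pred_mask, gt_mask)

def sr_loss(pred_sr, hr_image):
    return F.mse_loss(pred_sr, hr_image)
```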

2.2 Self-Supervised Guidance Module

To make proper use of the original HR images and further improve the framework’s performance, we propose the Self-Supervised Guidance Module (SGM). This module uses a patch cropped from the original HR image to extract representative high-frequency features. In our experiments, we found that simply cropping the central area performs even better than random cropping; random cropping may cause instability since the content of the guidance patch can vary greatly between testing cases. In our experiments, the size of the guidance patch is set to 1/64 of the original image.

SGM is designed according to the guidance patch size so that the features extracted from it can be correctly concatenated with the shared features. To avoid excessive computational cost, SGM is built to be very concise, as shown in Fig. 3(c). We also introduce a Self-Supervised Guidance Loss (SGL) to evaluate the distance between the guidance patch and the corresponding part of the SR output. The loss function can be described as:

$$\begin{aligned} L_{sgl}=\frac{1}{N} \sum _{i=1}^{N}||SIG(i) \cdot SR(X\downarrow )_i-SIG(i)\cdot X_i||^2, \end{aligned}$$
(1)

where N refers to the total number of voxels, and SIG(i) denotes the signal function that outputs 1 if the i-th voxel is inside the cropping window and 0 otherwise. X and \(X\downarrow \) denote the original medical image and its down-sampled version, while \(SR(\cdot )\) represents the SR output.
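A sketch of Eq. (1) is shown below; `window_mask` plays the role of SIG and is assumed to be a binary tensor with the same shape as the HR image, equal to 1 inside the guidance cropping window.

```python
# Self-Supervised Guidance Loss (Eq. 1): MSE restricted to the guidance window.
# window_mask: binary tensor (1 inside the crop window, 0 elsewhere), same shape as hr_image.
def sgl_loss(sr_out, hr_image, window_mask):
    diff = window_mask * sr_out - window_mask * hr_image
    # .mean() divides by the total number of voxels N, matching the 1/N factor in Eq. (1).
    return (diff ** 2).mean()
```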

2.3 Task-Fusion Module

To better utilize the connections between US and SR, we design a Task-Fusion Module (TFM) that combines the two tasks so that they help each other. This module first calculates the element-wise product of the two tasks’ outputs (the estimated HR mask and HR image), and then optimizes it along two different streams. For the first stream, we propose a Target-Enhanced Loss (TEL), which calculates the average squared Euclidean distance over the target-area voxels. It can be viewed as adding weight to the loss of the segmentation target area; thus, the US task tends to segment more precisely, and the SR task pays more attention to the target part. For the second stream, inspired by the Spatial Attention Mechanism in [6], we propose a Spatial Similarity Loss (SSL) to make the internal differences between prediction voxels similar to those of the ground truth. SSL is calculated using a Spatial Similarity Matrix, which describes the pairwise relationships between voxels. For a \(D\times W\times H\times C\) image I (for medical images C usually equals 1), to compute its Spatial Similarity Matrix, we first reshape it into \(V\times C\), where \(V=D\times W\times H\). Then, by multiplying this matrix with its transpose, we obtain the \(V\times V\) similarity matrix and compute its loss against that of the ground truth. The loss functions for this module are defined as follows.

$$\begin{aligned} L_{tfm}=L_{tel}+L_{ssl}, \end{aligned}$$
(2)
$$\begin{aligned} L_{tel}=\frac{1}{N} \sum _{i=1}^{N}||p_i \cdot SR(X\downarrow )_i-y_i\cdot X_i||^2, \end{aligned}$$
(3)
$$\begin{aligned} L_{ssl}=\frac{1}{D^2W^2H^2} \sum _{i=1}^{D\cdot W\cdot H}\sum _{j=1}^{D\cdot W\cdot H}||S_{ij}^{predict}-S_{ij}^{gt}||^2, \end{aligned}$$
(4)
$$\begin{aligned} S_{ij}=I_i\cdot I_j^\mathrm {T}, \end{aligned}$$
(5)

where \(p_i\) denotes the prediction of the i-th voxel after binarization, and \(y_i\) represents its corresponding ground truth. \(S_{ij}\) refers to the correlation between the i-th and j-th voxels of the fused image I, while \(I_i\) represents the i-th voxel of the image.
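The sketch below illustrates Eqs. (3)–(5) for a single sample (batch dimension omitted); note that the \(V\times V\) similarity matrix is very large for full-resolution volumes, so a down-sampled or block-wise computation would likely be needed in practice, an implementation detail not specified here.

```python
# Illustrative Task-Fusion losses. pred_mask_bin = binarized US prediction (p_i),
# sr_out = SR output, gt_mask = y_i, hr_image = X. All tensors share one shape.
def tel_loss(pred_mask_bin, sr_out, gt_mask, hr_image):
    # Target-Enhanced Loss (Eq. 3): MSE between the two fused (element-wise product) volumes.
    return ((pred_mask_bin * sr_out - gt_mask * hr_image) ** 2).mean()

def ssl_loss(pred_mask_bin, sr_out, gt_mask, hr_image):
    # Spatial Similarity Loss (Eqs. 4-5): compare V x V similarity matrices of the
    # fused prediction and fused ground truth. Caution: V x V can be huge at full
    # resolution; this single-sample sketch is only meant to show the computation.
    fused_pred = (pred_mask_bin * sr_out).reshape(-1, 1)   # (V, C) with C = 1
    fused_gt = (gt_mask * hr_image).reshape(-1, 1)
    s_pred = fused_pred @ fused_pred.t()                    # (V, V) similarity matrix
    s_gt = fused_gt @ fused_gt.t()
    return ((s_pred - s_gt) ** 2).mean()                    # mean matches the 1/V^2 factor

def tfm_loss(pred_mask_bin, sr_out, gt_mask, hr_image):
    # Eq. (2): sum of the two streams.
    return (tel_loss(pred_mask_bin, sr_out, gt_mask, hr_image)
            + ssl_loss(pred_mask_bin, sr_out, gt_mask, hr_image))
```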

2.4 Overall Objective Function

The overall objective function L of the proposed framework is:

$$\begin{aligned} L=L_{seg}+\omega _{sr}L_{sr}+\omega _{tfm}L_{tfm}+\omega _{sgl}L_{sgl}, \end{aligned}$$
(6)

where \(\omega _{sr}\), \(\omega _{tfm}\) and \(\omega _{sgl}\) are hyper-parameters, and are all set to 0.5 by default. The whole objective function can be optimized end-to-end.
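Tying the previous sketches together, a hedged training-step example for Eq. (6) could look as follows; the model interface (returning the US mask and SR output from the LR input and the guidance patch) is an assumption, and the hard 0.5 binarization for \(p_i\) is only one possible choice.

```python
# End-to-end training-step sketch combining the helpers defined above (assumed interface).
def training_step(model, hr_image, gt_mask, window_mask,
                  w_sr=0.5, w_tfm=0.5, w_sgl=0.5):
    lr_image, guidance = prepare_inputs(hr_image)
    pred_mask, sr_out = model(lr_image, guidance)   # US mask and SR output at HR size
    pred_mask_bin = (pred_mask > 0.5).float()       # p_i after binarization (Eq. 3)
    # Overall objective of Eq. (6); all weights default to 0.5.
    return (seg_loss(pred_mask, gt_mask)
            + w_sr * sr_loss(sr_out, hr_image)
            + w_tfm * tfm_loss(pred_mask_bin, sr_out, gt_mask, hr_image)
            + w_sgl * sgl_loss(sr_out, hr_image, window_mask))
```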

3 Experiments

3.1 Datasets

We used the BRATS2020 dataset [2, 3, 16] and a privately-owned liver segmentation dataset in the experiments. The BRATS2020 dataset contains a total of 369 subjects, each with four-modality MRI images (T1, T2, T1ce and FLAIR) of size \(240\times 240\times 155\) and spacing \(1\times 1\times 1\) mm\(^3\). The ground truth includes masks of Tumor Core (TC), Enhanced Tumor (ET) and Whole Tumor (WT). For each image, we trimmed 24 voxels from the edges containing no brain tissue and resized the remaining part to \(192\times 192\times 128\). In our experiments, we used the down-sampled T2-weighted images as input, the original T2-weighted images as the SR ground truth, and the WT masks as the US ground truth.

The privately-owned liver segmentation dataset contains 347 subjects, each with an MRI image and a segmentation ground truth labeled by experienced doctors. In our experiments, the spacing of the images was resampled to \(1.5\times 1.5\times 1.5\) mm\(^3\), and we then cropped the central \(192\times 192\times 128\) area. The cropped MRI image and its segmentation mask are used as the ground truth of SR and US, respectively, while the input is the cropped image after down-sampling.

3.2 Implementation Details

We compared our framework with different patch-based 3D segmentation models. For those methods, we predicted the test images using a sliding-window strategy with strides of 48, 48 and 32 along the x-, y- and z-axes, respectively. Besides, we also compared our method with other patch-free segmentation models (i.e., ResUNet3D\(\uparrow \) and HDResUNet). ResUNet3D\(\uparrow \) conducts ordinary segmentation with a down-sampled image and then enlarges the result by tricubic interpolation [13]; HDResUNet applies Holistic Decomposition [27] to ResUNet3D, with all its down-shuffling factors set to 2.
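For reference, a minimal sketch of such a sliding-window inference routine for the patch-based baselines is given below (96\(\times \)96\(\times \)64 patches, strides 48/48/32). Averaging overlapping predictions is one common fusion choice and an assumption here, as is the single-channel output and the volume sizes being compatible with the strides.

```python
# Sliding-window inference sketch for the patch-based baselines (assumptions noted above).
import torch

def sliding_window_inference(model, volume, patch=(96, 96, 64), stride=(48, 48, 32)):
    """volume: (1, C, D, H, W); patch/stride given as (x, y, z) sizes."""
    _, _, D, H, W = volume.shape
    out = torch.zeros_like(volume[:, :1])     # single-channel prediction accumulator
    count = torch.zeros_like(out)             # how many windows covered each voxel
    for z in range(0, max(D - patch[2], 0) + 1, stride[2]):
        for y in range(0, max(H - patch[1], 0) + 1, stride[1]):
            for x in range(0, max(W - patch[0], 0) + 1, stride[0]):
                crop = volume[:, :, z:z + patch[2], y:y + patch[1], x:x + patch[0]]
                pred = model(crop)
                out[:, :, z:z + patch[2], y:y + patch[1], x:x + patch[0]] += pred
                count[:, :, z:z + patch[2], y:y + patch[1], x:x + patch[0]] += 1
    return out / count.clamp(min=1)           # average the overlapping predictions
```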

We employed three quantitative evaluation indices in the experiments: Dice Similarity Coefficient, 95% Hausdorff Distance and Jaccard Coefficient. Dice and Jaccard mainly focus on the segmentation area, while the 95% Hausdorff Distance pays more attention to the edges.

All the experiments were run on an Nvidia GTX 1080Ti GPU with 11 GB of video memory. For a fair comparison, the input sizes were all set to \(96\times 96\times 64\), except for HDResUNet, which uses the original HR image. Therefore, the patch size for patch-based methods and the input image size for patch-free methods are the same. For both datasets, we used 80% of the cases for training and the rest for testing. Data augmentation includes random cropping (only for patch-based methods), random flipping, random rotation, and random shifting. All the models were optimized by Adam [12] with the initial learning rate set to \(1e^{-4}\). The rate is divided by 10 if the loss does not decrease for 20 consecutive epochs, and training ends when it reaches \(1e^{-7}\).
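One way to realize this learning-rate policy is PyTorch's built-in plateau scheduler, sketched below; `model`, `train_one_epoch` and `max_epochs` are hypothetical placeholders, and the exact stopping logic used by the authors may differ.

```python
# Optimizer and learning-rate schedule sketch: Adam at 1e-4, divide by 10 after
# 20 epochs without improvement, stop when the rate reaches 1e-7.
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=20, min_lr=1e-7)

for epoch in range(max_epochs):
    epoch_loss = train_one_epoch(model, optimizer)  # hypothetical training loop
    scheduler.step(epoch_loss)
    if optimizer.param_groups[0]["lr"] <= 1e-7:
        break
```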

3.3 Ablation Study

We conduct an ablation study on BRATS2020 to investigate how the designed modules affect the framework’s performance. The framework is tested with two different backbones (i.e., UNet3D and ResUNet3D).

Table 1. Ablation study results on BRATS2020. HD95 refers to 95% Hausdorff Distance.

As shown in Table 1, for both backbones, appending SR as the auxiliary task improves the segmentation performance, indicating that the framework successfully rebuilds some high-frequency information with the help of SR. Moreover, the segmentation results after introducing TEL and SSL also prove the effectiveness of TFM, showing that the interconnection between US and SR is very useful for joint optimization. Finally, the increase in all metrics after adding SGM demonstrates that the framework benefits from the high-frequency features brought by the self-supervised guidance.

3.4 Experimental Results

The experimental results are summarized in Table 2. Our framework surpasses the traditional 3D patch-based methods and also outperforms the other patch-free methods. Patch-free methods show the most obvious improvement in 95% Hausdorff Distance: with the global context, the model can more easily segment the target area as a whole, making the segmentation edges smoother and more accurate. Since our framework directly outputs a complete segmentation mask in a single pass, it also has a faster inference speed than most of the other methods.

Table 2. Segmentation results on two datasets. HD95 refers to 95% Hausdorff Distance, and Time denotes the average inference time for each case.
Fig. 4. Typical segmentation results of the experiment. Case1 and Case2 are from BRATS2020, while Case3 is from the liver dataset. For ease of visualization, we show only one slice from each case.

Some typical segmentation results are shown in Fig. 4. As shown, the patch-based results have many obvious flaws (marked in red). In Case1, there is some segmentation noise. This problem mainly results from the limited context within patches: when segmenting the upper-right-corner patch, the model has no information about the real tumor area and is therefore more likely to misdiagnose normal areas as lesions. In Case2 and Case3, there are some failed segmentations in the corner areas due to the padding technique. In [1], the authors pointed out that padding may produce artifacts on the edges of feature maps, and these artifacts may confuse the network. The problem in Case2 and Case3 is commonly seen when the target area spans several patches. Under such circumstances, it is difficult for patch-based models to correctly estimate the voxels on the edges, resulting in inconsistencies during the fusion process. Although patch-free segmentation can avoid the above-mentioned problems, it may suffer significant performance degradation due to the loss of high-frequency information during down-scaling. In our method, we build a multi-task learning framework (US and SR) with two well-designed modules (TFM and SGM) to keep the HR representations, thus avoiding this issue. Therefore, our framework outperforms the other existing patch-free methods.

4 Conclusion

In this work, we propose a novel framework for fast and accurate patch-free segmentation, which is capable of capturing the global context without introducing much extra computational cost. We validate the framework on two datasets to demonstrate its effectiveness, and the results show that it can efficiently generate better segmentation masks than other patch-based and patch-free methods.