1 Introduction

Gradient descent is an iterative algorithm for function optimization, and it is by far one of the most popular methods used to train neural networks. The gradient descent (GD) algorithm has three variants: vanilla gradient descent (a.k.a. batch gradient descent), stochastic gradient descent (SGD), and mini-batch gradient descent, which differ in the amount of training data used to compute the gradients (see [9] for further details). However, the GD algorithm and all its variants may present slow convergence or heavy oscillations in the cost function [10]. As a result, many proposals have been made to improve the conventional gradient descent algorithms.

Momentum [8] is one of these methods; it accelerates SGD in a relevant direction, although its weakness is that it behaves like a blind ball rolling down a hill. AdaGrad [2] was designed to solve this blindness by using non-constant, per-parameter learning rates. Unfortunately, it presents a radically diminishing learning rate, at which point the algorithm is no longer able to keep learning. Later, Adadelta [12] and RMSProp [11] both introduced adaptive learning rates that solve AdaGrad's problem. Adam [4] (described by its authors as a combination of AdaGrad and RMSProp) is one of the most popular optimizers in current neural network frameworks such as [1, 7]. Adam is commonly used because it performs well, is straightforward to implement, works well with sparse gradients and in online and non-stationary settings, and is robust to rescaling of the gradient. These properties make Adam a strong choice for problems with non-stationary objectives, very noisy gradients, and large data inputs.

The methods mentioned above (see [13] for a detailed introduction) have shown strong empirical results. However, we propose to enhance the Adam algorithm using the Kalman filter [3], because using the estimated gradients instead of the computed ones introduces significant variations. This change may help to explore and reach better solutions on the cost function, as other works have done by adding Gaussian noise to the gradients [6]. Hence, this paper introduces the KAdam algorithm, an extension of Adam using the Kalman filter.

The structure of this paper is as follows. Section 2 describes the Adam algorithm. Section 3 provides a brief introduction to the Kalman filter. Section 4 describes the KAdam algorithm. Section 5 presents the experiments carried out and performance comparisons between the proposed method and other gradient-based optimizers. Finally, Sect. 6 shares the authors' conclusions and future work.

2 Adam

The first step of the Adaptive Moment Estimation (Adam) algorithm [4] is to keep exponentially decaying averages of past gradients (first moment) and past squared gradients (second moment). This is done by computing the first moment estimate \(m_t\) (the mean) and the second moment estimate \(\upsilon _t\) (the uncentered variance) with the following equations:

$$\begin{aligned} m_t = \beta _1 m_{t-1} + (1 - \beta _1)g_t \end{aligned}$$
(1)
$$\begin{aligned} \upsilon _t = \beta _2 \upsilon _{t-1} + (1 - \beta _2)g_t^2 \end{aligned}$$
(2)

where \(\beta _1\) and \(\beta _2\) are the decay rates for the first and second moment (which the authors of Adam suggest setting to 0.9 and 0.999, respectively), \(g_t \in \mathbb {R}^n\) is the computed gradient of the cost function and \(g_t^2\) is the element-wise square of the gradient.

As \(\upsilon _t\) and \(m_t\) are initialized as zero vectors, the authors of Adam observe that they are biased towards zero, especially during the initial time steps and when the decay rates are small. Thus, they counteract these biases by computing bias-corrected first and second moment estimates.

$$\begin{aligned} \hat{m_t} = \dfrac{m_t}{1-\beta _1^t} \end{aligned}$$
(3)
$$\begin{aligned} \hat{\upsilon _t} = \dfrac{\upsilon _t}{1-\beta _2^t} \end{aligned}$$
(4)

Finally, the parameters update rule is given by:

$$\begin{aligned} \theta _{t+1} = \theta _t - \dfrac{\eta }{\sqrt{\hat{\upsilon _t}}+\epsilon } \hat{m_t} \end{aligned}$$
(5)

where \(\eta \) is the learning rate and \(\epsilon \) is a small smoothing term (used to avoid division by zero), for which the authors of Adam suggest the values 0.001 and \(10^{-8}\), respectively.
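
As a minimal sketch of one Adam iteration following Eqs. (1)–(5) (written in NumPy; the function and variable names here are illustrative and not part of the original paper):

```python
import numpy as np

def adam_step(theta, g, m, v, t, eta=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam parameter update, following Eqs. (1)-(5)."""
    m = beta1 * m + (1 - beta1) * g                       # first moment, Eq. (1)
    v = beta2 * v + (1 - beta2) * g ** 2                  # second moment, Eq. (2)
    m_hat = m / (1 - beta1 ** t)                          # bias correction, Eq. (3)
    v_hat = v / (1 - beta2 ** t)                          # bias correction, Eq. (4)
    theta = theta - eta * m_hat / (np.sqrt(v_hat) + eps)  # update rule, Eq. (5)
    return theta, m, v

# m and v start as zero vectors of the same shape as theta; t counts from 1.
```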

3 Kalman Filter

The Kalman filter [3] is a recursive state estimator for linear systems. The algorithm consists of a set of equations that work in a two-step process: prediction and update. The prediction phase is described by the following equations.

$$\begin{aligned} \hat{\mathbf {x}}_{k|k-1} = \mathbf {F}_k \hat{\mathbf {x}}_{k-1|k-1} + \mathbf {B}_k \mathbf {u}_k \end{aligned}$$
(6)
$$\begin{aligned} \mathbf {P}_{k|k-1} = \mathbf {F}_k \mathbf {P}_{k-1|k-1} \mathbf {F}_k^\mathsf {T} + \mathbf {Q}_k \end{aligned}$$
(7)

These equations give a prediction of the state estimate and the error covariance based only on information from the previous time step. In Eq. (6), the Kalman filter computes an a priori state estimate \(\hat{\mathbf {x}}_{k|k-1}\), where \(\hat{\mathbf {x}}_{k-1|k-1}\) is the previous (a posteriori) state estimate, \(\mathbf {F}_k\) is the state transition model and \(\mathbf {B}_k\) is the control-input model with its respective input vector \(\mathbf {u}_k\). In Eq. (7), the a priori error covariance \(\mathbf {P}_{k|k-1}\) is computed, where \(\mathbf {P}_{k-1|k-1}\) is the previous error covariance and \(\mathbf {Q}_{k}\) is the covariance of the process noise. The update phase is described by the following equations.

$$\begin{aligned} \mathbf {K}_k = \mathbf {P}_{k|k-1} \mathbf {H}_k^\mathsf {T} \ (\mathbf {H}_k \mathbf {P}_{k|k-1} \mathbf {H}_k^\mathsf {T} + \mathbf {R}_k)^{-1} \end{aligned}$$
(8)
$$\begin{aligned} \hat{\mathbf {x}}_{k|k} = \hat{\mathbf {x}}_{k|k-1} + \mathbf {K}_k (\mathbf {z}_k - \mathbf {H}_k \mathbf {\hat{x}}_{k|k-1}) \end{aligned}$$
(9)
$$\begin{aligned} \mathbf {P}_{k|k} = (\mathbf {I} - \mathbf {K}_k \mathbf {H}_k) \mathbf {P}_{k|k-1} \end{aligned}$$
(10)

These equations give an updated estimate of the state and the error covariance, computed with a correction based on the measurement \(\mathbf {z}_k\) of the true state at the current time step. In Eq. (8), the optimal Kalman gain matrix \(\mathbf {K}_k\) is computed, where \(\mathbf {H}_k\) is the observation (measurement) matrix and \(\mathbf {R}_k\) is the covariance of the observation noise. In Eqs. (9) and (10), the Kalman filter computes the updated (a posteriori) state estimate \(\hat{\mathbf {x}}_{k|k}\) and the updated (a posteriori) error covariance \(\mathbf {P}_{k|k}\), respectively.
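
As a minimal sketch of one predict/update cycle following Eqs. (6)–(10) (in NumPy; the function names are illustrative and all matrices are assumed to have consistent dimensions):

```python
import numpy as np

def kalman_predict(x, P, F, Q, B=None, u=None):
    """Prediction phase, Eqs. (6) and (7)."""
    x_pred = F @ x if B is None else F @ x + B @ u    # a priori state estimate
    P_pred = F @ P @ F.T + Q                          # a priori error covariance
    return x_pred, P_pred

def kalman_update(x_pred, P_pred, z, H, R):
    """Update phase, Eqs. (8)-(10)."""
    S = H @ P_pred @ H.T + R                          # innovation covariance
    K = P_pred @ H.T @ np.linalg.inv(S)               # Kalman gain, Eq. (8)
    x = x_pred + K @ (z - H @ x_pred)                 # a posteriori state, Eq. (9)
    P = (np.eye(P_pred.shape[0]) - K @ H) @ P_pred    # a posteriori covariance, Eq. (10)
    return x, P
```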

4 KAdam

The KAdam algorithm uses a Kalman filter to estimate the gradients of the cost function. Since the dynamics of the gradients are considered unknown, the matrices \(\mathbf {F}_k\), \(\mathbf {H}_k\), \(\mathbf {Q}_k\) and \(\mathbf {R}_k\) are set to identity matrices and the state vector \(\hat{\mathbf {x}}_{k|k}\) is initialized as a zero vector, with dimensions matching the gradient vector. Moreover, the gradient \(g_t\) of the cost function is used as the measurement \(\mathbf {z}_k\) of the true state in the Kalman filter. Thus, the estimated gradient \(\hat{g}_t\) is the post-fit measurement \(\mathbf {H}_k\hat{\mathbf {x}}_{k|k}\) from the filter. The steps to calculate the estimated gradient \(\hat{g}_t\) with the Kalman filter are summarized as a function \(K(\bullet )\).

$$\begin{aligned} \hat{g}_t = K(g_t) \end{aligned}$$
(11)

Hence, the equations to calculate the first and second moment are the following:

$$\begin{aligned} m_t = \beta _1 m_{t-1} + (1 - \beta _1) \hat{g}_t \end{aligned}$$
(12)
$$\begin{aligned} \upsilon _t = \beta _2 \upsilon _{t-1} + (1 - \beta _2) \hat{g}_t^2 \end{aligned}$$
(13)

The original equations from Adam to compute the bias-correction of the moments (see Eqs. (3) and (4)) and the update rule (see Eq. (5)) were not modified.
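
Putting the two pieces together, one KAdam iteration could be sketched as follows, under the simplifications stated above (\(\mathbf {F}_k = \mathbf {H}_k = \mathbf {Q}_k = \mathbf {R}_k = \mathbf {I}\) and a zero initial state). This is an illustrative reading of Eqs. (11)–(13) combined with Eqs. (3)–(5), not the authors' reference implementation; the class and variable names are chosen here for illustration.

```python
import numpy as np

class KAdam:
    """Illustrative KAdam step: Kalman-filtered gradients fed into Adam."""
    def __init__(self, n, eta=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
        self.eta, self.beta1, self.beta2, self.eps = eta, beta1, beta2, eps
        self.m = np.zeros(n)          # first moment
        self.v = np.zeros(n)          # second moment
        self.t = 0
        # Kalman filter state with F = H = Q = R = I and x0 = 0
        self.x = np.zeros(n)
        self.P = np.eye(n)
        self.I = np.eye(n)

    def _kalman(self, g):                                 # K(g_t), Eq. (11)
        P_pred = self.P + self.I                          # F P F^T + Q with F = Q = I
        K = P_pred @ np.linalg.inv(P_pred + self.I)       # Eq. (8) with H = R = I
        self.x = self.x + K @ (g - self.x)                # Eq. (9), z_k = g_t
        self.P = (self.I - K) @ P_pred                    # Eq. (10)
        return self.x                                     # H x = x, the estimated gradient

    def step(self, theta, g):
        self.t += 1
        g_hat = self._kalman(g)                                        # estimated gradient
        self.m = self.beta1 * self.m + (1 - self.beta1) * g_hat        # Eq. (12)
        self.v = self.beta2 * self.v + (1 - self.beta2) * g_hat ** 2   # Eq. (13)
        m_hat = self.m / (1 - self.beta1 ** self.t)                    # Eq. (3)
        v_hat = self.v / (1 - self.beta2 ** self.t)                    # Eq. (4)
        return theta - self.eta * m_hat / (np.sqrt(v_hat) + self.eps)  # Eq. (5)
```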

5 Experiments

To empirically evaluate the accuracy and efficiency of the proposal, two experiments (each with two types of training) were carried out using feed-forward neural networks on some of the most popular benchmark problems in machine learning. For each experiment, the proposed algorithm is compared against the following algorithms: GD, Momentum, RMSProp, and Adam. We include both stochastic and batch training: stochastic training adapts the parameters of the model for every pattern in the training set, while batch training uses the full training set to compute one adaptation of the parameters, as sketched below. The comparison criterion is the cost reduction, measured with the mean squared error (MSE), through the training phase and the test phase.
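
A sketch of the two training regimes follows; the squared-error linear model and random data below are stand-ins chosen only for illustration, while the experiments in this paper use feed-forward networks and the optimizers above.

```python
import numpy as np

def grad_fn(theta, X, y):
    """Gradient of a squared-error linear model, used as a stand-in."""
    X = np.atleast_2d(X)
    y = np.atleast_1d(y)
    return 2 * X.T @ (X @ theta - y) / len(y)

theta = np.zeros(2)
eta = 0.01
X_train = np.random.randn(100, 2)
y_train = np.random.randn(100)

# Stochastic training: one parameter update per training pattern.
for x_i, y_i in zip(X_train, y_train):
    theta = theta - eta * grad_fn(theta, x_i, y_i)

# Batch training: one parameter update per pass over the full training set.
for epoch in range(100):
    theta = theta - eta * grad_fn(theta, X_train, y_train)
```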

In the experiments, each neural network was configured with the same architecture (experimentally selected) and the same weight initialization. The hyper-parameter settings used in the experiments are listed in Table 1, except for the learning rate, which is \(\eta =0.01\) for all the experiments.

Table 1. Algorithm hyper-parameters.

5.1 Experiment: Moons

This experiment deals with the classification of two interleaving half circles, using a dataset with 12,000 samples (10,000 for training and 2,000 for testing) generated by a function from the scikit-learn Python package [7]. The neural network architecture was fixed to two layers with (10, 1) units, a \(\tanh (\bullet )\) activation in the hidden layer and a sigmoid activation in the output layer.
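
Such a dataset can be generated with scikit-learn, presumably via sklearn.datasets.make_moons; the noise level and random seed below are illustrative assumptions, as the paper does not report them.

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

# 12,000 samples split into 10,000 training / 2,000 test samples.
# The noise level and random seeds are assumed values for illustration.
X, y = make_moons(n_samples=12000, noise=0.2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=2000, random_state=0)
```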

In the stochastic training, the parameters of the neural network are adapted with each pattern. Figure 1 shows the error function during the training phase. Notice that GD and Momentum behave differently from RMSProp, Adam, and KAdam due to the second-moment dynamics. The second moment allows these algorithms to speed up in the early stage of the training, as shown in the left plot, where KAdam has the fastest descent. On the other hand, these algorithms show noisy behavior in the long run, where low costs can occasionally be reached by chance.

Fig. 1. Cost reduction comparison - Moons experiment (stochastic training). Left: the first one hundred iterations. Right: full experiment with \(10^6\) iterations.

In Table 2, we show the results of this experiment, where Adam has the best performance. Notice that, in the long run, all the algorithms reach similar results.

Table 2. Moons experiment - Stochastic training results.

In the batch training, the algorithms update the weights using all the samples from the dataset at each iteration. Figure 2 shows that RMSProp, Adam, and KAdam no longer present the noisy stochastic behavior in the batch training. Moreover, the proposed method shows an improvement over Adam and the other presented methods. Table 3 shows the results of the experiment.

Fig. 2. Cost reduction comparison - Moons experiment (batch training).

Table 3. Moons experiment - Batch training results.

5.2 Experiment: MNIST

This experiment deals with the MNIST classification problem. Before the training, the entire dataset was embedded into a 2D space (see Fig. 3) using a t-SNE [5] implementation.
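
A sketch of this preprocessing step using scikit-learn's TSNE follows; loading MNIST through fetch_openml and the subsample size are assumptions made here for illustration, since the paper only states that the dataset was embedded into 2D with t-SNE.

```python
from sklearn.datasets import fetch_openml
from sklearn.manifold import TSNE

# Load MNIST (70,000 x 784) and embed into 2 dimensions. In practice t-SNE
# is often run on a subsample (or after PCA) because it scales poorly.
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
X_2d = TSNE(n_components=2, random_state=0).fit_transform(X[:10000])
```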

The neural network architecture was fixed to two layers with (10, 10) units, a \(\tanh (\bullet )\) activation in the hidden layer and a sigmoid activation in the output layer.

Fig. 3. MNIST 2D visualization using the t-SNE implementation from scikit-learn.

Fig. 4. Cost reduction comparison - MNIST experiment (stochastic training).

In Table 4, we show the results, where KAdam has the best result in the training phase.

Table 4. MNIST experiment - Stochastic training results.

Fig. 5. Cost reduction comparison - MNIST experiment (batch training).

Table 5. MNIST experiment - Batch training results.

In Fig. 4, we show the cost function during the stochastic training on the MNIST dataset. The KAdam algorithm has a performance comparable with RMSProp and Adam. The MNIST dataset is noisier than the moons dataset; therefore, the gradient-based algorithms tend to oscillate around the local minimum during the stochastic training.

We also present the batch training experiment for the MNIST dataset. In Fig. 5 and Table 5, we present the results for this experiment, where Adam and KAdam compete closely and both outperform the other algorithms.

6 Conclusion

In this work, we presented a proposal to improve the performance of the Adam optimizer. As we have shown, when the Kalman filter is used, the estimated gradients keep following the original ones while adding variations large enough to explore new and possibly better solutions on the cost function.

We presented empirical results on two classical datasets, moons and MNIST, with both stochastic and batch training. Our approach shows strong performance in both the training and the testing phases. We believe this algorithm opens the door to new developments in the research of better optimization algorithms for artificial neural networks. In future work, we will explore in more depth the impact of varying the Kalman parameters used to estimate the gradients.