Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions cpp/modmesh/linalg/kalman_filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,49 @@ class KalmanFilter
check_dimensions();
}

/**
* @brief Construct a Kalman filter with explicit covariance matrices.
*
* @details
* Creates a Kalman filter with the specified system matrices and covariance
* matrices. This overload is useful for examples or physical models whose
* process noise, measurement noise, or initial state uncertainty are not
* scaled identity matrices.
*
* @param x Initial state vector.
* @param f State transition matrix F.
* @param b Control matrix B (empty to disable control input u).
* @param h Measurement matrix H.
* @param q Process noise covariance Q.
* @param r Measurement noise covariance R.
* @param p Initial state covariance P.
* @param jitter Numerical stability jitter.
*/
KalmanFilter(
array_type const & x, // FIXME: NOLINT(modernize-pass-by-value)
array_type const & f, // FIXME: NOLINT(modernize-pass-by-value)
array_type const & b, // FIXME: NOLINT(modernize-pass-by-value)
array_type const & h, // FIXME: NOLINT(modernize-pass-by-value)
array_type const & q, // FIXME: NOLINT(modernize-pass-by-value)
array_type const & r, // FIXME: NOLINT(modernize-pass-by-value)
array_type const & p, // FIXME: NOLINT(modernize-pass-by-value)
Comment on lines +181 to +183
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To determine the covariance matrix at the beginning, I create new constructor in KalmanFilter.

real_type jitter)
: m_state_size(x.shape(0))
, m_measurement_size(h.shape(0))
, m_control_size((b.ndim() == 2) ? b.shape(1) : 0)
, m_f(f)
, m_q(q)
, m_h(h)
, m_r(r)
, m_p(p)
, m_b(b)
, m_x(x)
, m_i(array_type::eye(m_state_size))
, m_jitter(jitter)
{
check_dimensions();
}

array_type const & state() const { return m_x; }
array_type const & covariance() const { return m_p; }

Expand Down
38 changes: 38 additions & 0 deletions cpp/modmesh/linalg/pymod/wrap_kalman_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,44 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapKalmanFilter
py::arg("process_noise"),
py::arg("measurement_noise"),
py::arg("jitter") = static_cast<real_type>(1e-9))
.def(
py::init(
[](array_type const & x,
array_type const & f,
py::object const & b,
array_type const & h,
array_type const & process_noise_covariance,
array_type const & measurement_noise_covariance,
array_type const & covariance,
real_type jitter)
{
array_type b_array;
if (b.is_none())
{
b_array = array_type(small_vector<size_t>{x.shape(0), 0});
}
else
{
b_array = b.cast<array_type>();
}
return wrapped_type(
x,
f,
b_array,
h,
process_noise_covariance,
measurement_noise_covariance,
covariance,
jitter);
}),
py::arg("x"),
py::arg("f"),
py::arg("b") = py::none(),
py::arg("h"),
py::arg("q"),
py::arg("r"),
py::arg("p"),
py::arg("jitter") = static_cast<real_type>(1e-9))
.def_property_readonly(
"state",
&wrapped_type::state)
Expand Down
55 changes: 55 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,61 @@ def sa_from_np(arr: np.ndarray, cls):
raise ValueError("sa_from_np supports only 1D or 2D arrays")


class KalmanFilterRadarExampleTC(unittest.TestCase):

def test_kalmanfilter_net_radar_example(self):
# Reference: https://kalmanfilter.net/
dt = 5.0
x0 = np.array([10000.0, 200.0])
f = np.array([[1.0, dt],
[0.0, 1.0]])
h = np.eye(2)
p0 = np.array([[16.0, 0.0],
[0.0, 0.25]])
q = np.array([[6.25, 2.5],
[2.5, 1.0]])
r = np.array([[36.0, 0.0],
[0.0, 2.25]])
z1 = np.array([11020.0, 202.0])
Comment on lines +411 to +424
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following testcase could be found in the reference.


kf = mm.KalmanFilterFp64(
x=sa_from_np(x0, mm.SimpleArrayFloat64),
f=sa_from_np(f, mm.SimpleArrayFloat64),
h=sa_from_np(h, mm.SimpleArrayFloat64),
q=sa_from_np(q, mm.SimpleArrayFloat64),
r=sa_from_np(r, mm.SimpleArrayFloat64),
p=sa_from_np(p0, mm.SimpleArrayFloat64),
jitter=0.0,
)

kf.predict()

x_pred_expected = np.array([11000.0, 200.0])
p_pred_expected = np.array([[28.5, 3.75],
[3.75, 1.25]])

np.testing.assert_allclose(
kf.state.ndarray, x_pred_expected, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(
kf.covariance.ndarray, p_pred_expected, atol=1e-12, rtol=0.0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the atol here different from the one on line 454 ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error at line 454 is much bigger than line 445, so I set 1e-8 there. I would try to align the absolute difference (atol).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you have increased the precision of the answer to reduce the error.


kf.update(sa_from_np(z1, mm.SimpleArrayFloat64))

x_update_expected = np.array([
11009.371124889283,
201.42604074402126,
])
p_update_expected = np.array([
[14.572187776793623, 1.4348981399468559],
[1.4348981399468559, 0.7074844995571303],
])

np.testing.assert_allclose(
kf.state.ndarray, x_update_expected, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(
kf.covariance.ndarray, p_update_expected, atol=1e-12, rtol=0.0)


class TestKnownIssues603(unittest.TestCase):

@unittest.expectedFailure
Expand Down
Loading