Skip to content

Commit f02fdde

Browse files
committed
FIX: Update .dot() matrix-vector product to use maybe_uninit
Avoid Array::uninitialized and use maybe_uninit.
1 parent 262ef6f commit f02fdde

File tree

1 file changed

+44
-21
lines changed

1 file changed

+44
-21
lines changed

src/linalg/impl_linalg.rs

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2014-2016 bluss and ndarray developers.
1+
// Copyright 2014-2020 bluss and ndarray developers.
22
//
33
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -325,9 +325,9 @@ where
325325

326326
// Avoid initializing the memory in vec -- set it during iteration
327327
unsafe {
328-
let mut c = Array::uninitialized(m);
329-
general_mat_vec_mul(A::one(), self, rhs, A::zero(), &mut c);
330-
c
328+
let mut c = Array1::maybe_uninit(m);
329+
general_mat_vec_mul_impl(A::one(), self, rhs, A::zero(), c.raw_view_mut().cast::<A>());
330+
c.assume_init()
331331
}
332332
}
333333
}
@@ -598,6 +598,30 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
598598
S2: Data<Elem = A>,
599599
S3: DataMut<Elem = A>,
600600
A: LinalgScalar,
601+
{
602+
unsafe {
603+
general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
604+
}
605+
}
606+
607+
/// General matrix-vector multiplication
608+
///
609+
/// Use a raw view for the destination vector, so that it can be uninitalized.
610+
///
611+
/// ## Safety
612+
///
613+
/// The caller must ensure that the raw view is valid for writing.
614+
/// the destination may be uninitialized iff beta is zero.
615+
unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
616+
alpha: A,
617+
a: &ArrayBase<S1, Ix2>,
618+
x: &ArrayBase<S2, Ix1>,
619+
beta: A,
620+
y: RawArrayViewMut<A, Ix1>,
621+
) where
622+
S1: Data<Elem = A>,
623+
S2: Data<Elem = A>,
624+
A: LinalgScalar,
601625
{
602626
let ((m, k), k2) = (a.dim(), x.dim());
603627
let m2 = y.dim();
@@ -626,22 +650,20 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
626650
let x_stride = x.strides()[0] as blas_index;
627651
let y_stride = y.strides()[0] as blas_index;
628652

629-
unsafe {
630-
blas_sys::$gemv(
631-
layout,
632-
a_trans,
633-
m as blas_index, // m, rows of Op(a)
634-
k as blas_index, // n, cols of Op(a)
635-
cast_as(&alpha), // alpha
636-
a.ptr.as_ptr() as *const _, // a
637-
a_stride, // lda
638-
x.ptr.as_ptr() as *const _, // x
639-
x_stride,
640-
cast_as(&beta), // beta
641-
y.ptr.as_ptr() as *mut _, // x
642-
y_stride,
643-
);
644-
}
653+
blas_sys::$gemv(
654+
layout,
655+
a_trans,
656+
m as blas_index, // m, rows of Op(a)
657+
k as blas_index, // n, cols of Op(a)
658+
cast_as(&alpha), // alpha
659+
a.ptr.as_ptr() as *const _, // a
660+
a_stride, // lda
661+
x.ptr.as_ptr() as *const _, // x
662+
x_stride,
663+
cast_as(&beta), // beta
664+
y.ptr.as_ptr() as *mut _, // x
665+
y_stride,
666+
);
645667
return;
646668
}
647669
}
@@ -655,8 +677,9 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
655677
/* general */
656678

657679
if beta.is_zero() {
680+
// when beta is zero, c may be uninitialized
658681
Zip::from(a.outer_iter()).and(y).apply(|row, elt| {
659-
*elt = row.dot(x) * alpha;
682+
elt.write(row.dot(x) * alpha);
660683
});
661684
} else {
662685
Zip::from(a.outer_iter()).and(y).apply(|row, elt| {

0 commit comments

Comments
 (0)