Skip to content

Commit 0341d82

Browse files
authored
Merge pull request #237 from mulimoen/bugfix/serde
Fix serde deserialisation
2 parents 7297305 + 61b3dd5 commit 0341d82

File tree

12 files changed

+287
-54
lines changed

12 files changed

+287
-54
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ members = [
6161
"suitesparse_bindings/suitesparse-src",
6262
"sprs-rand",
6363
"sprs-benches",
64+
"sprs-tests",
6465
]
6566

6667
[package.metadata.docs.rs]

sprs-tests/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[package]
2+
name = "sprs-tests"
3+
version = "0.1.0"
4+
authors = ["Magnus Ulimoen <[email protected]>"]
5+
edition = "2018"
6+
publish = false
7+
8+
[dev-dependencies]
9+
sprs = { path = "..", features = ["serde"], default-features = false }
10+
serde_json = "1.0.58"

sprs-tests/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# sprs-tests
2+
3+
Ancillary crate to test `sprs` by pulling in extra dependencies

sprs-tests/src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fn main() {}

sprs-tests/tests/tests.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
mod serde_tests {
2+
use sprs::*;
3+
#[test]
4+
fn valid_vectors() {
5+
let json_vec =
6+
r#"{ "dim": 100, "indices": [4, 6, 10], "data": [4, 1, 8] }"#;
7+
let _vec: CsVecI<u8, i32> = serde_json::from_str(&json_vec).unwrap();
8+
9+
let json_vec = r#"{ "dim": 200, "indices": [4, 6, 10, 120], "data": [4, 1, 8, 1] }"#;
10+
let _vec: CsVecI<i8, u16> = serde_json::from_str(&json_vec).unwrap();
11+
}
12+
13+
#[test]
14+
fn invalid_vectors() {
15+
// non-sorted indices
16+
let json_vec =
17+
r#"{ "dim": 100, "indices": [4, 6, 5], "data": [4, 1, 8] }"#;
18+
let e: Result<CsVecI<u8, i32>, _> = serde_json::from_str(&json_vec);
19+
assert!(e.is_err());
20+
21+
// max(indices) > dim
22+
let json_vec =
23+
r#"{ "dim": 2, "indices": [4, 6, 8], "data": [4, 1, 8] }"#;
24+
let e: Result<CsVecI<u8, i32>, _> = serde_json::from_str(&json_vec);
25+
assert!(e.is_err());
26+
27+
// indices.len != data.len
28+
let json_vec =
29+
r#"{ "dim": 100, "indices": [4, 6, 8, 10], "data": [4, 1, 8] }"#;
30+
let e: Result<CsVecI<u8, i32>, _> = serde_json::from_str(&json_vec);
31+
assert!(e.is_err());
32+
33+
// indice does not fit in datatype
34+
let json_vec =
35+
r#"{ "dim": 100000, "indices": [4, 6, 32768], "data": [4, 1, 8] }"#;
36+
let e: Result<CsVecI<u8, i16>, _> = serde_json::from_str(&json_vec);
37+
assert!(e.is_err());
38+
}
39+
40+
#[test]
41+
fn valid_matrices() {
42+
let json_mat = r#"{ "storage": "CSR", "ncols": 10, "nrows": 2, "indptr": [0, 2, 3], "indices": [4, 6, 9], "data": [4, 1, 8] }"#;
43+
let _mat: CsMatI<u8, i32, u16> =
44+
serde_json::from_str(&json_mat).unwrap();
45+
let _mat: CsMat<u8> = serde_json::from_str(&json_mat).unwrap();
46+
}
47+
48+
#[test]
49+
fn invalid_matrices() {
50+
// indices not sorted
51+
let json_mat = r#"{ "storage": "CSR", "ncols": 10, "nrows": 2, "indptr": [0, 3, 3], "indices": [4, 9, 6], "data": [4, 1, 8] }"#;
52+
let mat: Result<CsMatI<u8, i32, u16>, _> =
53+
serde_json::from_str(&json_mat);
54+
assert!(mat.is_err());
55+
56+
// data length != indices length
57+
let json_mat = r#"{ "storage": "CSR", "ncols": 10, "nrows": 2, "indptr": [0, 2, 3], "indices": [4, 9, 6], "data": [4, 1, 8, 10] }"#;
58+
let mat: Result<CsMatI<u8, i32, u16>, _> =
59+
serde_json::from_str(&json_mat);
60+
assert!(mat.is_err());
61+
}
62+
63+
#[test]
64+
fn valid_indptr() {
65+
let indptr = r#"{ "storage": [0, 0, 1, 2, 2, 6] }"#;
66+
let indptr: IndPtr<usize> = serde_json::from_str(&indptr).unwrap();
67+
assert_eq!(indptr.raw_storage(), &[0, 0, 1, 2, 2, 6]);
68+
69+
let indptr = r#"{ "storage": [5, 5, 8, 9] }"#;
70+
let indptr: IndPtr<usize> = serde_json::from_str(&indptr).unwrap();
71+
assert_eq!(indptr.raw_storage(), &[5, 5, 8, 9]);
72+
}
73+
74+
#[test]
75+
fn invalid_indptr() {
76+
let indptr = r#"{ "storage": [0, 0, 1, 2, 2, 1] }"#;
77+
let indptr: Result<IndPtr<usize>, _> = serde_json::from_str(&indptr);
78+
assert!(indptr.is_err());
79+
let indptr = r#"{ "storage": [2, 1, 2, 2, 2, 7] }"#;
80+
let indptr: Result<IndPtr<usize>, _> = serde_json::from_str(&indptr);
81+
assert!(indptr.is_err());
82+
// Larger than permitted by i16
83+
let indptr = r#"{ "storage": [0, 32768] }"#;
84+
let indptr: Result<IndPtr<i16>, _> = serde_json::from_str(&indptr);
85+
assert!(indptr.is_err());
86+
}
87+
}

src/sparse.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use crate::IndPtrBase;
55
use std::ops::Deref;
66

77
#[cfg(feature = "serde")]
8-
use serde::{Deserialize, Serialize};
8+
mod serde_traits;
9+
#[cfg(feature = "serde")]
10+
use serde_traits::{CsMatBaseShadow, CsVecBaseShadow, Deserialize, Serialize};
911

1012
pub use self::csmat::CompressedStorage;
1113

@@ -83,6 +85,12 @@ pub use self::csmat::CompressedStorage;
8385
8486
#[derive(Eq, PartialEq, Debug, Copy, Clone, Hash)]
8587
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
88+
#[cfg_attr(
89+
feature = "serde",
90+
serde(
91+
try_from = "CsMatBaseShadow<N, I, IptrStorage, IndStorage, DataStorage, Iptr>"
92+
)
93+
)]
8694
pub struct CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr = I>
8795
where
8896
I: SpIndex,
@@ -150,19 +158,28 @@ pub type CsStructure = CsStructureI<usize>;
150158
151159
#[derive(Eq, PartialEq, Debug, Copy, Clone, Hash)]
152160
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
153-
pub struct CsVecBase<IStorage, DStorage> {
161+
#[cfg_attr(
162+
feature = "serde",
163+
serde(try_from = "CsVecBaseShadow<IStorage, DStorage, N, I>")
164+
)]
165+
pub struct CsVecBase<IStorage, DStorage, N, I: SpIndex = usize>
166+
where
167+
IStorage: Deref<Target = [I]>,
168+
DStorage: Deref<Target = [N]>,
169+
{
154170
dim: usize,
155171
indices: IStorage,
156172
data: DStorage,
157173
}
158174

159-
pub type CsVecI<N, I> = CsVecBase<Vec<I>, Vec<N>>;
160-
pub type CsVecViewI<'a, N, I> = CsVecBase<&'a [I], &'a [N]>;
161-
pub type CsVecViewMutI<'a, N, I> = CsVecBase<&'a [I], &'a mut [N]>;
175+
pub type CsVecI<N, I = usize> = CsVecBase<Vec<I>, Vec<N>, N, I>;
176+
pub type CsVecViewI<'a, N, I = usize> = CsVecBase<&'a [I], &'a [N], N, I>;
177+
pub type CsVecViewMutI<'a, N, I = usize> =
178+
CsVecBase<&'a [I], &'a mut [N], N, I>;
162179

163-
pub type CsVecView<'a, N> = CsVecViewI<'a, N, usize>;
164-
pub type CsVecViewMut<'a, N> = CsVecViewMutI<'a, N, usize>;
165-
pub type CsVec<N> = CsVecI<N, usize>;
180+
pub type CsVecView<'a, N> = CsVecViewI<'a, N>;
181+
pub type CsVecViewMut<'a, N> = CsVecViewMutI<'a, N>;
182+
pub type CsVec<N> = CsVecI<N>;
166183

167184
/// Sparse matrix in the triplet format.
168185
///

src/sparse/compressed.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub trait SpVecView<N, I: SpIndex> {
3939
}
4040

4141
impl<N, I, IndStorage, DataStorage> SpVecView<N, I>
42-
for CsVecBase<IndStorage, DataStorage>
42+
for CsVecBase<IndStorage, DataStorage, N, I>
4343
where
4444
IndStorage: Deref<Target = [I]>,
4545
DataStorage: Deref<Target = [N]>,

src/sparse/csmat.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ where
144144
IStorage: Deref<Target = [I]>,
145145
DStorage: Deref<Target = [N]>,
146146
{
147-
fn new_checked(
147+
pub(crate) fn new_checked(
148148
storage: CompressedStorage,
149149
shape: (usize, usize),
150150
indptr: IptrStorage,
@@ -530,7 +530,7 @@ impl<N, I: SpIndex, Iptr: SpIndex> CsMatI<N, I, Iptr> {
530530
}
531531

532532
/// Append an outer dim to an existing matrix, provided by a sparse vector
533-
pub fn append_outer_csvec(mut self, vec: CsVecBase<&[I], &[N]>) -> Self
533+
pub fn append_outer_csvec(mut self, vec: CsVecViewI<N, I>) -> Self
534534
where
535535
N: Clone,
536536
{

src/sparse/indptr.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
//!
44
//! [`CsMatBase`]: type.CsMatBase.html
55
6+
#[cfg(feature = "serde")]
7+
use super::serde_traits::IndPtrBaseShadow;
68
use crate::errors::SprsError;
79
use crate::indexing::SpIndex;
810
#[cfg(feature = "serde")]
@@ -12,6 +14,10 @@ use std::ops::{Deref, DerefMut};
1214

1315
#[derive(Eq, PartialEq, Debug, Copy, Clone, Hash)]
1416
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17+
#[cfg_attr(
18+
feature = "serde",
19+
serde(try_from = "IndPtrBaseShadow<Iptr, Storage>")
20+
)]
1521
pub struct IndPtrBase<Iptr, Storage>
1622
where
1723
Iptr: SpIndex,

src/sparse/serde_traits.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use super::*;
2+
pub(crate) use serde::{Deserialize, Serialize};
3+
use std::convert::TryFrom;
4+
5+
#[derive(Deserialize)]
6+
pub(crate) struct CsVecBaseShadow<IStorage, DStorage, N, I: SpIndex = usize>
7+
where
8+
IStorage: Deref<Target = [I]>,
9+
DStorage: Deref<Target = [N]>,
10+
{
11+
dim: usize,
12+
indices: IStorage,
13+
data: DStorage,
14+
}
15+
16+
impl<IStorage, DStorage, N, I: SpIndex>
17+
TryFrom<CsVecBaseShadow<IStorage, DStorage, N, I>>
18+
for CsVecBase<IStorage, DStorage, N, I>
19+
where
20+
IStorage: Deref<Target = [I]>,
21+
DStorage: Deref<Target = [N]>,
22+
{
23+
type Error = SprsError;
24+
fn try_from(
25+
val: CsVecBaseShadow<IStorage, DStorage, N, I>,
26+
) -> Result<Self, Self::Error> {
27+
let CsVecBaseShadow { dim, indices, data } = val;
28+
Self::new_(dim, indices, data).map_err(|(_, _, e)| e)
29+
}
30+
}
31+
32+
#[derive(Deserialize)]
33+
pub struct CsMatBaseShadow<N, I, IptrStorage, IndStorage, DataStorage, Iptr = I>
34+
where
35+
I: SpIndex,
36+
Iptr: SpIndex,
37+
IptrStorage: Deref<Target = [Iptr]>,
38+
IndStorage: Deref<Target = [I]>,
39+
DataStorage: Deref<Target = [N]>,
40+
{
41+
storage: CompressedStorage,
42+
nrows: usize,
43+
ncols: usize,
44+
indptr: IptrStorage,
45+
indices: IndStorage,
46+
data: DataStorage,
47+
}
48+
49+
impl<IptrStorage, IndStorage, DStorage, N, I: SpIndex, Iptr: SpIndex>
50+
TryFrom<CsMatBaseShadow<N, I, IptrStorage, IndStorage, DStorage, Iptr>>
51+
for CsMatBase<N, I, IptrStorage, IndStorage, DStorage, Iptr>
52+
where
53+
IndStorage: Deref<Target = [I]>,
54+
IptrStorage: Deref<Target = [Iptr]>,
55+
DStorage: Deref<Target = [N]>,
56+
{
57+
type Error = SprsError;
58+
fn try_from(
59+
val: CsMatBaseShadow<N, I, IptrStorage, IndStorage, DStorage, Iptr>,
60+
) -> Result<Self, Self::Error> {
61+
let CsMatBaseShadow {
62+
storage,
63+
nrows,
64+
ncols,
65+
indptr,
66+
indices,
67+
data,
68+
} = val;
69+
let shape = (nrows, ncols);
70+
Self::new_checked(storage, shape, indptr, indices, data)
71+
.map_err(|(_, _, _, e)| e)
72+
}
73+
}
74+
75+
#[derive(Deserialize)]
76+
pub struct IndPtrBaseShadow<Iptr, Storage>
77+
where
78+
Iptr: SpIndex,
79+
Storage: Deref<Target = [Iptr]>,
80+
{
81+
storage: Storage,
82+
}
83+
84+
impl<Iptr: SpIndex, Storage> TryFrom<IndPtrBaseShadow<Iptr, Storage>>
85+
for IndPtrBase<Iptr, Storage>
86+
where
87+
Storage: Deref<Target = [Iptr]>,
88+
{
89+
type Error = SprsError;
90+
fn try_from(
91+
val: IndPtrBaseShadow<Iptr, Storage>,
92+
) -> Result<Self, Self::Error> {
93+
let IndPtrBaseShadow { storage } = val;
94+
Self::new(storage)
95+
}
96+
}

0 commit comments

Comments
 (0)