diff --git a/src/matrix.rs b/src/dynamic_matrix.rs similarity index 99% rename from src/matrix.rs rename to src/dynamic_matrix.rs index 215be90..b499a05 100644 --- a/src/matrix.rs +++ b/src/dynamic_matrix.rs @@ -3,11 +3,11 @@ use std::ops::{Add, Deref, DerefMut, Div, Index, IndexMut, Mul, Sub}; use crate::axes::Axes; use crate::coord; use crate::coordinate::Coordinate; +use crate::dynamic_vector::DynamicVector; use crate::error::ShapeError; use crate::shape; use crate::shape::Shape; use crate::tensor::DynamicTensor; -use crate::vector::DynamicVector; use num::{Float, Num}; pub struct DynamicMatrix { diff --git a/src/vector.rs b/src/dynamic_vector.rs similarity index 99% rename from src/vector.rs rename to src/dynamic_vector.rs index ff2b939..ace4a7d 100644 --- a/src/vector.rs +++ b/src/dynamic_vector.rs @@ -1,8 +1,8 @@ use std::ops::{Add, Deref, DerefMut, Div, Index, IndexMut, Mul, Sub}; use crate::coord; +use crate::dynamic_matrix::DynamicMatrix; use crate::error::ShapeError; -use crate::matrix::DynamicMatrix; use crate::shape; use crate::shape::Shape; use crate::tensor::DynamicTensor; diff --git a/src/lib.rs b/src/lib.rs index 718f954..d03ddf2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ pub mod axes; pub mod coordinate; +pub mod dynamic_matrix; +pub mod dynamic_vector; pub mod error; pub mod iter; -pub mod matrix; pub mod shape; +pub mod static_matrix; +pub mod static_vector; pub mod storage; pub mod tensor; -pub mod vector; diff --git a/src/static_matrix.rs b/src/static_matrix.rs new file mode 100644 index 0000000..21baaa1 --- /dev/null +++ b/src/static_matrix.rs @@ -0,0 +1,673 @@ +use std::ops::{Add, Deref, DerefMut, Div, Index, IndexMut, Mul, Sub}; + +use crate::dynamic_vector::DynamicVector; +use crate::error::ShapeError; +use crate::shape; +use crate::static_vector::StaticVector; +use crate::tensor::Tensor; +use num::Float; +use num::Num; + +pub struct StaticMatrix { + data: [[T; N]; M], +} + +impl StaticMatrix { + pub fn new(data: [[T; N]; M]) -> Result, ShapeError> { + assert!(M > 0 && N > 0, "M and N must be greater than 0"); + if data.len() != M || data[0].len() != N { + return Err(ShapeError::new("Data dimensions must be equal to M and N")); + } + Ok(Self { data }) + } + + pub fn to_tensor(&self) -> Tensor { + Tensor::new(&shape![M, N].unwrap(), &self.data.concat()).unwrap() + } + + pub fn fill(value: T) -> StaticMatrix { + let data = [[value; N]; M]; + StaticMatrix::new(data).unwrap() + } + + pub fn ones() -> StaticMatrix { + StaticMatrix::fill(T::one()) + } + + pub fn zeros() -> StaticMatrix { + StaticMatrix::fill(T::zero()) + } + + pub fn vecmul(&self, rhs: &StaticVector) -> StaticVector { + let mut result = StaticVector::zeros(); + for i in 0..M { + for j in 0..N { + result[i] = result[i] + self.data[i][j] * rhs[j]; + } + } + result + } + + pub fn matmul(&self, rhs: &StaticMatrix) -> StaticMatrix { + let mut result = StaticMatrix::::zeros(); + for i in 0..M { + for j in 0..P { + for k in 0..N { + result.data[i][j] = result.data[i][j] + self.data[i][k] * rhs.data[k][j]; + } + } + } + result + } + + pub fn sum(&self, axis: Option) -> DynamicVector { + match axis { + Some(0) => { + let mut result = vec![T::zero(); N]; + for row in self.data.iter() { + for (j, &item) in row.iter().enumerate() { + result[j] = result[j] + item; + } + } + DynamicVector::new(&result).unwrap() + } + Some(1) => { + let mut result = vec![T::zero(); M]; + for (i, row) in self.data.iter().enumerate() { + for &item in row.iter() { + result[i] = result[i] + item; + } + } + DynamicVector::new(&result).unwrap() + } + None => { + let mut sum = T::zero(); + for row in self.data.iter() { + for &item in row.iter() { + sum = sum + item; + } + } + DynamicVector::new(&[sum]).unwrap() + } + _ => panic!("Axis out of bounds"), + } + } + + pub fn mean(&self, axis: Option) -> DynamicVector { + let sum = self.sum(axis); + match axis { + Some(0) => sum / self.size().0, + Some(1) => sum / self.size().1, + None => sum / (self.size().0 * self.size().1), + _ => panic!("Axis out of bounds"), + } + } + + pub fn variance(&self, axis: Option) -> DynamicVector { + let mean = self.mean(axis); + match axis { + Some(0) => { + let mut result = vec![T::zero(); N]; + for row in self.data.iter() { + for (j, &item) in row.iter().enumerate() { + let diff = item - mean[j]; + result[j] = result[j] + diff * diff; + } + } + DynamicVector::new(&result).unwrap() / self.size().0 + } + Some(1) => { + let mut result = vec![T::zero(); M]; + for (i, row) in self.data.iter().enumerate() { + for &item in row.iter() { + let diff = item - mean[i]; + result[i] = result[i] + diff * diff; + } + } + DynamicVector::new(&result).unwrap() / self.size().1 + } + None => { + let mut result = T::zero(); + for row in self.data.iter() { + for &item in row.iter() { + let diff = item - mean[0]; + result = result + diff * diff; + } + } + DynamicVector::new(&[result / (self.size().0 * self.size().1)]).unwrap() + } + _ => panic!("Axis out of bounds"), + } + } + + pub fn max(&self, axis: Option) -> DynamicVector { + let mut min_value = self.data[0][0]; + for row in self.data.iter() { + for &item in row.iter() { + if item < min_value { + min_value = item; + } + } + } + match axis { + Some(0) => { + let mut result = vec![min_value; N]; + for row in self.data.iter() { + for (j, &item) in row.iter().enumerate() { + if item > result[j] { + result[j] = item; + } + } + } + DynamicVector::new(&result).unwrap() + } + Some(1) => { + let mut result = vec![min_value; M]; + for (i, row) in self.data.iter().enumerate() { + for &item in row.iter() { + if item > result[i] { + result[i] = item; + } + } + } + DynamicVector::new(&result).unwrap() + } + None => { + let mut max_value = min_value; + for row in self.data.iter() { + for &item in row.iter() { + if item > max_value { + max_value = item; + } + } + } + DynamicVector::new(&[max_value]).unwrap() + } + _ => panic!("Axis out of bounds"), + } + } + + pub fn min(&self, axis: Option) -> DynamicVector { + let mut max_value = self.data[0][0]; + for row in self.data.iter() { + for &item in row.iter() { + if item > max_value { + max_value = item; + } + } + } + match axis { + Some(0) => { + let mut result = vec![max_value; N]; + for row in self.data.iter() { + for (j, &item) in row.iter().enumerate() { + if item < result[j] { + result[j] = item; + } + } + } + DynamicVector::new(&result).unwrap() + } + Some(1) => { + let mut result = vec![max_value; M]; + for (i, row) in self.data.iter().enumerate() { + for &item in row.iter() { + if item < result[i] { + result[i] = item; + } + } + } + DynamicVector::new(&result).unwrap() + } + None => { + let mut min_value = max_value; + for row in self.data.iter() { + for &item in row.iter() { + if item < min_value { + min_value = item; + } + } + } + DynamicVector::new(&[min_value]).unwrap() + } + _ => panic!("Axis out of bounds"), + } + } + + pub fn dims(&self) -> (usize, usize) { + (M, N) + } + + pub fn size(&self) -> (T, T) { + let mut n = T::zero(); + let mut m = T::zero(); + for _ in 0..M { + m = m + T::one(); + } + for _ in 0..N { + n = n + T::one(); + } + (m, n) + } +} + +impl StaticMatrix { + pub fn pow(&self, power: T) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item.powf(power); + } + } + StaticMatrix { data: result } + } +} + +// Scalar Addition +impl Add for StaticMatrix { + type Output = StaticMatrix; + + fn add(self, rhs: T) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item + rhs; + } + } + StaticMatrix { data: result } + } +} + +// Matrix Addition +impl Add> + for StaticMatrix +{ + type Output = StaticMatrix; + + fn add(self, rhs: StaticMatrix) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item + rhs.data[i][j]; + } + } + StaticMatrix { data: result } + } +} + +// Scalar Subtraction +impl Sub for StaticMatrix { + type Output = StaticMatrix; + + fn sub(self, rhs: T) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item - rhs; + } + } + StaticMatrix { data: result } + } +} + +// Matrix Subtraction +impl Sub> + for StaticMatrix +{ + type Output = StaticMatrix; + + fn sub(self, rhs: StaticMatrix) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item - rhs.data[i][j]; + } + } + StaticMatrix { data: result } + } +} + +// Scalar Multiplication +impl Mul for StaticMatrix { + type Output = StaticMatrix; + + fn mul(self, rhs: T) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item * rhs; + } + } + StaticMatrix { data: result } + } +} + +// Matrix Multiplication +impl Mul> + for StaticMatrix +{ + type Output = StaticMatrix; + + fn mul(self, rhs: StaticMatrix) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item * rhs.data[i][j]; + } + } + StaticMatrix { data: result } + } +} + +// Scalar Division +impl Div for StaticMatrix { + type Output = StaticMatrix; + + fn div(self, rhs: T) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item / rhs; + } + } + StaticMatrix { data: result } + } +} + +// Matrix Division +impl Div> + for StaticMatrix +{ + type Output = StaticMatrix; + + fn div(self, rhs: StaticMatrix) -> StaticMatrix { + let mut result = [[T::zero(); N]; M]; + for (i, row) in self.data.iter().enumerate() { + for (j, &item) in row.iter().enumerate() { + result[i][j] = item / rhs.data[i][j]; + } + } + StaticMatrix { data: result } + } +} + +impl Index<(usize, usize)> + for StaticMatrix +{ + type Output = T; + + fn index(&self, index: (usize, usize)) -> &Self::Output { + &self.data[index.0][index.1] + } +} + +impl IndexMut<(usize, usize)> + for StaticMatrix +{ + fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output { + &mut self.data[index.0][index.1] + } +} + +impl Deref for StaticMatrix { + type Target = [[T; N]; M]; + + fn deref(&self) -> &[[T; N]; M] { + &self.data + } +} + +impl DerefMut for StaticMatrix { + fn deref_mut(&mut self) -> &mut [[T; N]; M] { + &mut self.data + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new() { + let data = [[1, 2], [3, 4]]; + let matrix = StaticMatrix::new(data).unwrap(); + assert_eq!(matrix.data, data); + } + + #[test] + #[should_panic(expected = "M and N must be greater than 0")] + fn test_new_failure_zero_length() { + let _ = StaticMatrix::::new([[]]).unwrap(); + } + + #[test] + fn test_fill() { + let matrix = StaticMatrix::fill(5); + assert_eq!(matrix.data, [[5; 2]; 2]); + } + + #[test] + fn test_ones() { + let matrix = StaticMatrix::::ones(); + assert_eq!(matrix.data, [[1; 2]; 2]); + } + + #[test] + fn test_zeros() { + let matrix = StaticMatrix::::zeros(); + assert_eq!(matrix.data, [[0; 2]; 2]); + } + + #[test] + fn test_vecmul() { + let matrix = StaticMatrix::new([[1, 2], [3, 4], [5, 6]]).unwrap(); + let vector = StaticVector::new([1, 2]).unwrap(); + let result = matrix.vecmul(&vector); + assert_eq!(*result, [5, 11, 17]); + } + + #[test] + fn test_matmul() { + let matrix1 = StaticMatrix::new([[1, 2, 3], [4, 5, 6]]).unwrap(); + let matrix2 = StaticMatrix::new([[7, 8], [9, 10], [11, 12]]).unwrap(); + let result = matrix1.matmul(&matrix2); + assert_eq!(result.data, [[58, 64], [139, 154]]); + } + + #[test] + fn test_sum() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let result = matrix.sum(None); + assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result[0], 10); + + let result = matrix.sum(Some(0)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 4); + assert_eq!(result[1], 6); + + let result = matrix.sum(Some(1)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 3); + assert_eq!(result[1], 7); + } + + #[test] + fn test_mean() { + let matrix = StaticMatrix::new([[1.0, 2.0], [3.0, 4.0]]).unwrap(); + let result = matrix.mean(None); + assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result[0], 2.5); + + let result = matrix.mean(Some(0)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 2.0); + assert_eq!(result[1], 3.0); + + let result = matrix.mean(Some(1)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 1.5); + assert_eq!(result[1], 3.5); + } + + #[test] + fn test_variance() { + let matrix = StaticMatrix::new([[1.0, 2.0], [3.0, 4.0]]).unwrap(); + let result = matrix.variance(None); + assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result[0], 1.25); + + let result = matrix.variance(Some(0)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 1.0); + assert_eq!(result[1], 1.0); + + let result = matrix.variance(Some(1)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 0.25); + assert_eq!(result[1], 0.25); + } + + #[test] + fn test_max() { + let matrix = StaticMatrix::new([[-1, -2], [-3, -4]]).unwrap(); + let result = matrix.max(None); + assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result[0], -1); + + let result = matrix.max(Some(0)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], -1); + assert_eq!(result[1], -2); + + let result = matrix.max(Some(1)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], -1); + assert_eq!(result[1], -3); + } + + #[test] + fn test_min() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let result = matrix.min(None); + assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result[0], 1); + + let result = matrix.min(Some(0)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 1); + assert_eq!(result[1], 2); + + let result = matrix.min(Some(1)); + assert_eq!(result.shape(), &shape![2].unwrap()); + assert_eq!(result[0], 1); + assert_eq!(result[1], 3); + } + + #[test] + fn test_dims() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let result = matrix.dims(); + assert_eq!(result, (2, 2)); + } + + #[test] + fn test_size() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let result = matrix.size(); + assert_eq!(result, (2, 2)); + } + + #[test] + fn test_add() { + let matrix1 = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let matrix2 = StaticMatrix::new([[5, 6], [7, 8]]).unwrap(); + let result = matrix1 + matrix2; + assert_eq!(result.data, [[6, 8], [10, 12]]); + } + + #[test] + fn test_add_scalar() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let result = matrix + 2; + assert_eq!(result.data, [[3, 4], [5, 6]]); + } + + #[test] + fn test_sub() { + let matrix1 = StaticMatrix::new([[5, 6], [7, 8]]).unwrap(); + let matrix2 = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let result = matrix1 - matrix2; + assert_eq!(result.data, [[4, 4], [4, 4]]); + } + + #[test] + fn test_sub_scalar() { + let matrix = StaticMatrix::new([[5, 6], [7, 8]]).unwrap(); + let result = matrix - 2; + assert_eq!(result.data, [[3, 4], [5, 6]]); + } + + #[test] + fn test_mul() { + let matrix1 = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let matrix2 = StaticMatrix::new([[2, 0], [1, 2]]).unwrap(); + let result = matrix1 * matrix2; + assert_eq!(result.data, [[2, 0], [3, 8]]); + } + + #[test] + fn test_mul_scalar() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + let result = matrix * 2; + assert_eq!(result.data, [[2, 4], [6, 8]]); + } + + #[test] + fn test_div() { + let matrix1 = StaticMatrix::new([[4, 8], [12, 16]]).unwrap(); + let matrix2 = StaticMatrix::new([[2, 2], [3, 4]]).unwrap(); + let result = matrix1 / matrix2; + assert_eq!(result.data, [[2, 4], [4, 4]]); + } + + #[test] + fn test_div_scalar() { + let matrix = StaticMatrix::new([[2, 4], [6, 8]]).unwrap(); + let result = matrix / 2; + assert_eq!(result.data, [[1, 2], [3, 4]]); + } + + #[test] + fn test_pow() { + let matrix = StaticMatrix::new([[1.0, 2.0], [3.0, 4.0]]).unwrap(); + let result = matrix.pow(2.0); + assert_eq!(result.data, [[1.0, 4.0], [9.0, 16.0]]); + } + + #[test] + fn test_index() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + assert_eq!(matrix[(0, 1)], 2); + } + + #[test] + fn test_index_mut() { + let mut matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + matrix[(0, 1)] = 5; + assert_eq!(matrix[(0, 1)], 5); + } + + #[test] + fn test_deref() { + let matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + assert_eq!(*matrix, [[1, 2], [3, 4]]); + } + + #[test] + fn test_deref_mut() { + let mut matrix = StaticMatrix::new([[1, 2], [3, 4]]).unwrap(); + matrix[(0, 1)] = 5; + assert_eq!(*matrix, [[1, 5], [3, 4]]); + } +} diff --git a/src/static_vector.rs b/src/static_vector.rs new file mode 100644 index 0000000..4a7ebf9 --- /dev/null +++ b/src/static_vector.rs @@ -0,0 +1,444 @@ +use std::ops::{Add, Deref, DerefMut, Div, Index, IndexMut, Mul, Sub}; + +use crate::error::ShapeError; +use crate::shape; +use crate::static_matrix::StaticMatrix; +use crate::tensor::Tensor; +use num::Float; +use num::Num; + +pub struct StaticVector { + data: [T; N], +} + +impl StaticVector { + pub fn new(data: [T; N]) -> Result, ShapeError> { + assert!(N > 0, "N must be greater than 0"); + Ok(Self { data }) + } + + pub fn to_tensor(&self) -> Tensor { + Tensor::new(&shape![N].unwrap(), &self.data).unwrap() + } + + pub fn fill(value: T) -> StaticVector { + let data = [value; N]; + StaticVector::new(data).unwrap() + } + + pub fn ones() -> StaticVector { + StaticVector::fill(T::one()) + } + + pub fn zeros() -> StaticVector { + StaticVector::fill(T::zero()) + } + + pub fn matmul(&self, rhs: &StaticMatrix) -> StaticVector { + let mut result = StaticVector::zeros(); + for i in 0..P { + for j in 0..N { + result.data[i] = result.data[i] + self.data[j] * rhs[(j, i)]; + } + } + result + } + + pub fn dot(&self, rhs: &Self) -> T { + let mut result = T::zero(); + for i in 0..N { + result = result + self.data[i] * rhs.data[i]; + } + result + } + + pub fn sum(&self) -> T { + let mut sum = T::zero(); + for &item in self.data.iter() { + sum = sum + item; + } + sum + } + + pub fn mean(&self) -> T { + self.sum() / self.size() + } + + pub fn variance(&self) -> T { + let mean = self.mean(); + let mut result = T::zero(); + for i in 0..N { + let diff = self.data[i] - mean; + result = result + diff * diff; + } + result / self.size() + } + + pub fn max(&self) -> T { + let mut max_value = self.data[0]; + for i in 1..N { + if self.data[i] > max_value { + max_value = self.data[i]; + } + } + max_value + } + + pub fn min(&self) -> T { + let mut min_value = self.data[0]; + for i in 1..N { + if self.data[i] < min_value { + min_value = self.data[i]; + } + } + min_value + } + + pub fn dims(&self) -> usize { + N + } + + pub fn size(&self) -> T { + let mut n = T::zero(); + for _ in 0..N { + n = n + T::one(); + } + n + } +} + +impl StaticVector { + pub fn pow(&self, power: T) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item.powf(power); + } + StaticVector { data: result } + } +} + +// Scalar Addition +impl Add for StaticVector { + type Output = StaticVector; + + fn add(self, rhs: T) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item + rhs; + } + StaticVector { data: result } + } +} + +// Vector Addition +impl Add> for StaticVector { + type Output = StaticVector; + + fn add(self, rhs: StaticVector) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item + rhs.data[i]; + } + StaticVector { data: result } + } +} + +// Scalar Subtraction +impl Sub for StaticVector { + type Output = StaticVector; + + fn sub(self, rhs: T) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item - rhs; + } + StaticVector { data: result } + } +} + +// Vector Subtraction +impl Sub> for StaticVector { + type Output = StaticVector; + + fn sub(self, rhs: StaticVector) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item - rhs.data[i]; + } + StaticVector { data: result } + } +} + +// Scalar Multiplication +impl Mul for StaticVector { + type Output = StaticVector; + + fn mul(self, rhs: T) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item * rhs; + } + StaticVector { data: result } + } +} + +// Vector Multiplication +impl Mul> for StaticVector { + type Output = StaticVector; + + fn mul(self, rhs: StaticVector) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item * rhs.data[i]; + } + StaticVector { data: result } + } +} + +// Scalar Division +impl Div for StaticVector { + type Output = StaticVector; + + fn div(self, rhs: T) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item / rhs; + } + StaticVector { data: result } + } +} + +// Vector Division +impl Div> for StaticVector { + type Output = StaticVector; + + fn div(self, rhs: StaticVector) -> StaticVector { + let mut result = [T::zero(); N]; + for (i, &item) in self.data.iter().enumerate() { + result[i] = item / rhs.data[i]; + } + StaticVector { data: result } + } +} + +impl Index for StaticVector { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + &self.data[index] + } +} + +impl IndexMut for StaticVector { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.data[index] + } +} + +impl Deref for StaticVector { + type Target = [T; N]; + + fn deref(&self) -> &[T; N] { + &self.data + } +} + +impl DerefMut for StaticVector { + fn deref_mut(&mut self) -> &mut [T; N] { + &mut self.data + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new() { + let data = [1, 2, 3]; + let vector = StaticVector::new(data).unwrap(); + assert_eq!(vector.data, data); + } + + #[test] + #[should_panic(expected = "N must be greater than 0")] + fn test_new_failure_zero_length() { + let _ = StaticVector::::new([]).unwrap(); + } + + #[test] + fn test_fill() { + let vector = StaticVector::fill(5); + assert_eq!(vector.data, [5; 3]); + } + + #[test] + fn test_ones() { + let vector = StaticVector::::ones(); + assert_eq!(vector.data, [1; 3]); + } + + #[test] + fn test_zeros() { + let vector = StaticVector::::zeros(); + assert_eq!(vector.data, [0; 3]); + } + + #[test] + fn test_matmul() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let matrix = StaticMatrix::new([[1, 2], [3, 4], [5, 6]]).unwrap(); + let result = vector.matmul(&matrix); + assert_eq!(result.data, [22, 28]); + } + + #[test] + fn test_dot() { + let vector1 = StaticVector::new([1, 2, 3]).unwrap(); + let vector2 = StaticVector::new([4, 5, 6]).unwrap(); + let result = vector1.dot(&vector2); + assert_eq!(result, 32); + } + + #[test] + fn test_sum() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector.sum(); + assert_eq!(result, 6); + } + + #[test] + fn test_mean() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector.mean(); + assert_eq!(result, 2); + } + + #[test] + fn test_variance() { + let vector = StaticVector::new([1.0, 2.5, 4.0]).unwrap(); + let result = vector.variance(); + assert_eq!(result, 1.5); + } + + #[test] + fn test_max() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector.max(); + assert_eq!(result, 3); + } + + #[test] + fn test_min() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector.min(); + assert_eq!(result, 1); + } + + #[test] + fn test_dims() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector.dims(); + assert_eq!(result, 3); + } + + #[test] + fn test_len() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector.len(); + assert_eq!(result, 3); + } + + #[test] + fn test_pow() { + let vector = StaticVector::new([1.0, 2.0, 3.0]).unwrap(); + let result = vector.pow(2.0); + assert_eq!(result.data, [1.0, 4.0, 9.0]); + } + + #[test] + fn test_scalar_add() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector + 1; + assert_eq!(result.data, [2, 3, 4]); + } + + #[test] + fn test_vector_add() { + let vector1 = StaticVector::new([1, 2, 3]).unwrap(); + let vector2 = StaticVector::new([4, 5, 6]).unwrap(); + let result = vector1 + vector2; + assert_eq!(result.data, [5, 7, 9]); + } + + #[test] + fn test_scalar_sub() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector - 1; + assert_eq!(result.data, [0, 1, 2]); + } + + #[test] + fn test_vector_sub() { + let vector1 = StaticVector::new([4, 5, 6]).unwrap(); + let vector2 = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector1 - vector2; + assert_eq!(result.data, [3, 3, 3]); + } + + #[test] + fn test_scalar_mul() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector * 2; + assert_eq!(result.data, [2, 4, 6]); + } + + #[test] + fn test_vector_mul() { + let vector1 = StaticVector::new([1, 2, 3]).unwrap(); + let vector2 = StaticVector::new([4, 5, 6]).unwrap(); + let result = vector1 * vector2; + assert_eq!(result.data, [4, 10, 18]); + } + + #[test] + fn test_scalar_div() { + let vector = StaticVector::new([2, 4, 6]).unwrap(); + let result = vector / 2; + assert_eq!(result.data, [1, 2, 3]); + } + + #[test] + fn test_vector_div() { + let vector1 = StaticVector::new([2, 4, 6]).unwrap(); + let vector2 = StaticVector::new([1, 2, 3]).unwrap(); + let result = vector1 / vector2; + assert_eq!(result.data, [2, 2, 2]); + } + + #[test] + fn test_index() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + assert_eq!(vector[1], 2); + } + + #[test] + fn test_index_mut() { + let mut vector = StaticVector::new([1, 2, 3]).unwrap(); + vector[1] = 5; + assert_eq!(vector[1], 5); + } + + #[test] + fn test_deref() { + let vector = StaticVector::new([1, 2, 3]).unwrap(); + assert_eq!(*vector, [1, 2, 3]); + } + + #[test] + fn test_deref_mut() { + let mut vector = StaticVector::new([1, 2, 3]).unwrap(); + vector[1] = 5; + assert_eq!(*vector, [1, 5, 3]); + } +} diff --git a/src/tensor.rs b/src/tensor.rs index 595dc26..8962ada 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -3,13 +3,13 @@ use std::ops::{Add, Div, Mul, Sub}; use crate::axes::Axes; use crate::coordinate::Coordinate; +use crate::dynamic_matrix::DynamicMatrix; +use crate::dynamic_vector::DynamicVector; use crate::error::ShapeError; use crate::iter::IndexIterator; -use crate::matrix::DynamicMatrix; use crate::shape; use crate::shape::Shape; use crate::storage::DynamicStorage; -use crate::vector::DynamicVector; #[derive(Debug)] pub struct DynamicTensor {