diff --git a/src/matrix.rs b/src/matrix.rs index 215be90..c28e800 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -29,6 +29,11 @@ impl DynamicMatrix { Ok(DynamicMatrix { tensor }) } + pub fn tile(tensor: &DynamicMatrix, reps: &Shape) -> Result, ShapeError> { + let result = DynamicTensor::tile(tensor, reps)?; + Ok(DynamicMatrix { tensor: result }) + } + pub fn fill(shape: &Shape, value: T) -> Result, ShapeError> { let data = vec![value; shape.size()]; DynamicMatrix::new(shape, &data) @@ -279,6 +284,34 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_tile() { + let shape = shape![2, 2].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0]; + let matrix = DynamicMatrix::new(&shape, &data).unwrap(); + let reps = shape![2, 2].unwrap(); + + let result = DynamicMatrix::tile(&matrix, &reps).unwrap(); + + assert_eq!(result.shape(), &shape![4, 4].unwrap()); + assert_eq!(result[coord![0, 0].unwrap()], 1.0); + assert_eq!(result[coord![0, 1].unwrap()], 2.0); + assert_eq!(result[coord![0, 2].unwrap()], 1.0); + assert_eq!(result[coord![0, 3].unwrap()], 2.0); + assert_eq!(result[coord![1, 0].unwrap()], 3.0); + assert_eq!(result[coord![1, 1].unwrap()], 4.0); + assert_eq!(result[coord![1, 2].unwrap()], 3.0); + assert_eq!(result[coord![1, 3].unwrap()], 4.0); + assert_eq!(result[coord![2, 0].unwrap()], 1.0); + assert_eq!(result[coord![2, 1].unwrap()], 2.0); + assert_eq!(result[coord![2, 2].unwrap()], 1.0); + assert_eq!(result[coord![2, 3].unwrap()], 2.0); + assert_eq!(result[coord![3, 0].unwrap()], 3.0); + assert_eq!(result[coord![3, 1].unwrap()], 4.0); + assert_eq!(result[coord![3, 2].unwrap()], 3.0); + assert_eq!(result[coord![3, 3].unwrap()], 4.0); + } + #[test] fn test_fill() { let shape = shape![2, 2].unwrap(); diff --git a/src/shape.rs b/src/shape.rs index 7901559..35f874c 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -24,6 +24,10 @@ impl Shape { self.dims.len() } + pub fn iter(&self) -> std::slice::Iter<'_, usize> { + self.dims.iter() + } + pub fn stack(&self, rhs: &Shape) -> Shape { let mut new_dims = self.dims.clone(); new_dims.extend(rhs.dims.iter()); diff --git a/src/tensor.rs b/src/tensor.rs index 595dc26..4faf738 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -29,6 +29,50 @@ impl Tensor { }) } + pub fn tile(tensor: &Tensor, reps: &Shape) -> Result, ShapeError> { + if tensor.shape.order() != reps.order() { + return Err(ShapeError::new("Tensor and reps must have the same order")); + } + + // Calculate the new shape by multiplying each dimension of the tensor by the corresponding rep + let new_shape = tensor + .shape + .iter() + .zip(reps.iter()) + .map(|(dim, &rep)| dim * rep) + .collect::>(); + let new_shape = Shape::new(new_shape).unwrap(); + let mut new_data = Vec::with_capacity(new_shape.size()); + + // Initialize indices to keep track of the current position in the new tensor + let mut indices = vec![0; tensor.shape.order()]; + for _ in 0..new_shape.size() { + let mut original_index = 0; + let mut stride = 1; + + // Calculate the corresponding index in the original tensor + for (i, &dim) in tensor.shape.iter().enumerate().rev() { + original_index += (indices[i] % dim) * stride; + stride *= dim; + } + + // Push the value from the original tensor to the new data + new_data.push(tensor.data[original_index]); + + // Update indices for the next position + for i in (0..indices.len()).rev() { + indices[i] += 1; + if indices[i] < new_shape[i] { + break; + } + indices[i] = 0; + } + } + + // Create and return the new tensor with the new shape and new data + Tensor::new(&new_shape, &new_data) + } + pub fn fill(shape: &Shape, value: T) -> Tensor { let mut vec = Vec::with_capacity(shape.size()); for _ in 0..shape.size() { @@ -594,6 +638,73 @@ mod tests { assert_eq!(tensor.data, DynamicStorage::new(vec![1.0; shape.size()])); } + #[test] + fn test_tile_tensor() { + let shape = shape![2, 2].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0]; + let tensor = Tensor::new(&shape, &data).unwrap(); + let reps = shape![2, 2].unwrap(); + + let result = Tensor::tile(&tensor, &reps).unwrap(); + + assert_eq!(result.shape(), &shape![4, 4].unwrap()); + assert_eq!( + result.data, + DynamicStorage::new(vec![ + 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0 + ]) + ); + } + + #[test] + fn test_tile_tensor_mismatched_order() { + let shape = shape![2, 2].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0]; + let tensor = Tensor::new(&shape, &data).unwrap(); + let reps = shape![2, 2, 2].unwrap(); // Mismatched order + + let result = Tensor::tile(&tensor, &reps); + + assert!(result.is_err()); + } + + #[test] + fn test_tile_tensor_1d() { + let shape = shape![3].unwrap(); + let data = vec![1.0, 2.0, 3.0]; + let tensor = Tensor::new(&shape, &data).unwrap(); + let reps = shape![2].unwrap(); + + let result = Tensor::tile(&tensor, &reps).unwrap(); + + assert_eq!(result.shape(), &shape![6].unwrap()); + assert_eq!( + result.data, + DynamicStorage::new(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]) + ); + } + + #[test] + fn test_tile_tensor_3d() { + let shape = shape![2, 2, 2].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let tensor = Tensor::new(&shape, &data).unwrap(); + let reps = shape![2, 2, 2].unwrap(); + + let result = Tensor::tile(&tensor, &reps).unwrap(); + + assert_eq!(result.shape(), &shape![4, 4, 4].unwrap()); + assert_eq!( + result.data, + DynamicStorage::new(vec![ + 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, + 5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0, 5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0, + 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, + 5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0, 5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0, + ]) + ); + } + #[test] fn test_fill_tensor() { let shape = shape![2, 3].unwrap(); diff --git a/src/vector.rs b/src/vector.rs index ff2b939..01ddfde 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -28,6 +28,11 @@ impl DynamicVector { Ok(DynamicVector { tensor }) } + pub fn tile(tensor: &DynamicVector, reps: &Shape) -> Result, ShapeError> { + let result = DynamicTensor::tile(tensor, reps)?; + Ok(DynamicVector { tensor: result }) + } + pub fn fill(shape: &Shape, value: T) -> Result, ShapeError> { if shape.order() != 1 { return Err(ShapeError::new("Shape must have order of 1")); @@ -267,6 +272,23 @@ mod tests { assert!(result.is_err()); } + #[test] + fn test_tile() { + let data = vec![1.0, 2.0]; + let vector = DynamicVector::new(&data).unwrap(); + let reps = shape![3].unwrap(); + + let result = DynamicVector::tile(&vector, &reps).unwrap(); + + assert_eq!(result.shape(), &shape![6].unwrap()); + assert_eq!(result[0], 1.0); + assert_eq!(result[1], 2.0); + assert_eq!(result[2], 1.0); + assert_eq!(result[3], 2.0); + assert_eq!(result[4], 1.0); + assert_eq!(result[5], 2.0); + } + #[test] fn test_fill() { let shape = shape![4].unwrap();