Skip to content

Add tile constructor #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ impl<T: Num + PartialOrd + Copy> DynamicMatrix<T> {
Ok(DynamicMatrix { tensor })
}

pub fn tile(tensor: &DynamicMatrix<T>, reps: &Shape) -> Result<DynamicMatrix<T>, ShapeError> {
let result = DynamicTensor::tile(tensor, reps)?;
Ok(DynamicMatrix { tensor: result })
}

pub fn fill(shape: &Shape, value: T) -> Result<DynamicMatrix<T>, ShapeError> {
let data = vec![value; shape.size()];
DynamicMatrix::new(shape, &data)
Expand Down Expand Up @@ -279,6 +284,34 @@ mod tests {
assert!(result.is_err());
}

#[test]
fn test_tile() {
let shape = shape![2, 2].unwrap();
let data = vec![1.0, 2.0, 3.0, 4.0];
let matrix = DynamicMatrix::new(&shape, &data).unwrap();
let reps = shape![2, 2].unwrap();

let result = DynamicMatrix::tile(&matrix, &reps).unwrap();

assert_eq!(result.shape(), &shape![4, 4].unwrap());
assert_eq!(result[coord![0, 0].unwrap()], 1.0);
assert_eq!(result[coord![0, 1].unwrap()], 2.0);
assert_eq!(result[coord![0, 2].unwrap()], 1.0);
assert_eq!(result[coord![0, 3].unwrap()], 2.0);
assert_eq!(result[coord![1, 0].unwrap()], 3.0);
assert_eq!(result[coord![1, 1].unwrap()], 4.0);
assert_eq!(result[coord![1, 2].unwrap()], 3.0);
assert_eq!(result[coord![1, 3].unwrap()], 4.0);
assert_eq!(result[coord![2, 0].unwrap()], 1.0);
assert_eq!(result[coord![2, 1].unwrap()], 2.0);
assert_eq!(result[coord![2, 2].unwrap()], 1.0);
assert_eq!(result[coord![2, 3].unwrap()], 2.0);
assert_eq!(result[coord![3, 0].unwrap()], 3.0);
assert_eq!(result[coord![3, 1].unwrap()], 4.0);
assert_eq!(result[coord![3, 2].unwrap()], 3.0);
assert_eq!(result[coord![3, 3].unwrap()], 4.0);
}

#[test]
fn test_fill() {
let shape = shape![2, 2].unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ impl Shape {
self.dims.len()
}

pub fn iter(&self) -> std::slice::Iter<'_, usize> {
self.dims.iter()
}

pub fn stack(&self, rhs: &Shape) -> Shape {
let mut new_dims = self.dims.clone();
new_dims.extend(rhs.dims.iter());
Expand Down
111 changes: 111 additions & 0 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,50 @@ impl<T: Num + PartialOrd + Copy> Tensor<T> {
})
}

pub fn tile(tensor: &Tensor<T>, reps: &Shape) -> Result<Tensor<T>, ShapeError> {
if tensor.shape.order() != reps.order() {
return Err(ShapeError::new("Tensor and reps must have the same order"));
}

// Calculate the new shape by multiplying each dimension of the tensor by the corresponding rep
let new_shape = tensor
.shape
.iter()
.zip(reps.iter())
.map(|(dim, &rep)| dim * rep)
.collect::<Vec<_>>();
let new_shape = Shape::new(new_shape).unwrap();
let mut new_data = Vec::with_capacity(new_shape.size());

// Initialize indices to keep track of the current position in the new tensor
let mut indices = vec![0; tensor.shape.order()];
for _ in 0..new_shape.size() {
let mut original_index = 0;
let mut stride = 1;

// Calculate the corresponding index in the original tensor
for (i, &dim) in tensor.shape.iter().enumerate().rev() {
original_index += (indices[i] % dim) * stride;
stride *= dim;
}

// Push the value from the original tensor to the new data
new_data.push(tensor.data[original_index]);

// Update indices for the next position
for i in (0..indices.len()).rev() {
indices[i] += 1;
if indices[i] < new_shape[i] {
break;
}
indices[i] = 0;
}
}

// Create and return the new tensor with the new shape and new data
Tensor::new(&new_shape, &new_data)
}

pub fn fill(shape: &Shape, value: T) -> Tensor<T> {
let mut vec = Vec::with_capacity(shape.size());
for _ in 0..shape.size() {
Expand Down Expand Up @@ -594,6 +638,73 @@ mod tests {
assert_eq!(tensor.data, DynamicStorage::new(vec![1.0; shape.size()]));
}

#[test]
fn test_tile_tensor() {
let shape = shape![2, 2].unwrap();
let data = vec![1.0, 2.0, 3.0, 4.0];
let tensor = Tensor::new(&shape, &data).unwrap();
let reps = shape![2, 2].unwrap();

let result = Tensor::tile(&tensor, &reps).unwrap();

assert_eq!(result.shape(), &shape![4, 4].unwrap());
assert_eq!(
result.data,
DynamicStorage::new(vec![
1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0
])
);
}

#[test]
fn test_tile_tensor_mismatched_order() {
let shape = shape![2, 2].unwrap();
let data = vec![1.0, 2.0, 3.0, 4.0];
let tensor = Tensor::new(&shape, &data).unwrap();
let reps = shape![2, 2, 2].unwrap(); // Mismatched order

let result = Tensor::tile(&tensor, &reps);

assert!(result.is_err());
}

#[test]
fn test_tile_tensor_1d() {
let shape = shape![3].unwrap();
let data = vec![1.0, 2.0, 3.0];
let tensor = Tensor::new(&shape, &data).unwrap();
let reps = shape![2].unwrap();

let result = Tensor::tile(&tensor, &reps).unwrap();

assert_eq!(result.shape(), &shape![6].unwrap());
assert_eq!(
result.data,
DynamicStorage::new(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
);
}

#[test]
fn test_tile_tensor_3d() {
let shape = shape![2, 2, 2].unwrap();
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let tensor = Tensor::new(&shape, &data).unwrap();
let reps = shape![2, 2, 2].unwrap();

let result = Tensor::tile(&tensor, &reps).unwrap();

assert_eq!(result.shape(), &shape![4, 4, 4].unwrap());
assert_eq!(
result.data,
DynamicStorage::new(vec![
1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0,
5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0, 5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0,
1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0,
5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0, 5.0, 6.0, 5.0, 6.0, 7.0, 8.0, 7.0, 8.0,
])
);
}

#[test]
fn test_fill_tensor() {
let shape = shape![2, 3].unwrap();
Expand Down
22 changes: 22 additions & 0 deletions src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ impl<T: Num + PartialOrd + Copy> DynamicVector<T> {
Ok(DynamicVector { tensor })
}

pub fn tile(tensor: &DynamicVector<T>, reps: &Shape) -> Result<DynamicVector<T>, ShapeError> {
let result = DynamicTensor::tile(tensor, reps)?;
Ok(DynamicVector { tensor: result })
}

pub fn fill(shape: &Shape, value: T) -> Result<DynamicVector<T>, ShapeError> {
if shape.order() != 1 {
return Err(ShapeError::new("Shape must have order of 1"));
Expand Down Expand Up @@ -267,6 +272,23 @@ mod tests {
assert!(result.is_err());
}

#[test]
fn test_tile() {
let data = vec![1.0, 2.0];
let vector = DynamicVector::new(&data).unwrap();
let reps = shape![3].unwrap();

let result = DynamicVector::tile(&vector, &reps).unwrap();

assert_eq!(result.shape(), &shape![6].unwrap());
assert_eq!(result[0], 1.0);
assert_eq!(result[1], 2.0);
assert_eq!(result[2], 1.0);
assert_eq!(result[3], 2.0);
assert_eq!(result[4], 1.0);
assert_eq!(result[5], 2.0);
}

#[test]
fn test_fill() {
let shape = shape![4].unwrap();
Expand Down