Skip to content

Commit 713a34a

Browse files
committed
remove broadcast for bit and ring binary ops
1 parent b983c0a commit 713a34a

File tree

2 files changed

+5
-112
lines changed

2 files changed

+5
-112
lines changed

rust/moose/src/bit.rs

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,14 @@ impl From<BitTensor> for ArrayD<u8> {
6161
impl BitXor for BitTensor {
6262
type Output = BitTensor;
6363
fn bitxor(self, other: Self) -> Self::Output {
64-
match self.0.broadcast(other.0.dim()) {
65-
Some(self_broadcasted) => BitTensor(self_broadcasted.to_owned() ^ other.0),
66-
None => BitTensor(self.0 ^ other.0),
67-
}
64+
BitTensor(self.0 ^ other.0)
6865
}
6966
}
7067

7168
impl BitAnd for BitTensor {
7269
type Output = BitTensor;
7370
fn bitand(self, other: Self) -> Self::Output {
74-
match self.0.broadcast(other.0.dim()) {
75-
Some(self_broadcasted) => BitTensor(self_broadcasted.to_owned() & other.0),
76-
None => BitTensor(self.0 & other.0),
77-
}
71+
BitTensor(self.0 & other.0)
7872
}
7973
}
8074

@@ -137,27 +131,4 @@ mod tests {
137131
BitTensor::fill(&shape, 0)
138132
);
139133
}
140-
141-
#[test]
142-
fn test_bit_ops_broadcasting() {
143-
// test xor
144-
assert_eq!(
145-
BitTensor::fill(&Shape(vec![5]), 0) ^ BitTensor::fill(&Shape(vec![1]), 1),
146-
BitTensor::fill(&Shape(vec![5]), 1)
147-
);
148-
assert_eq!(
149-
BitTensor::fill(&Shape(vec![1]), 0) ^ BitTensor::fill(&Shape(vec![5]), 1),
150-
BitTensor::fill(&Shape(vec![5]), 1)
151-
);
152-
153-
// test and
154-
assert_eq!(
155-
BitTensor::fill(&Shape(vec![5]), 0) & BitTensor::fill(&Shape(vec![1]), 1),
156-
BitTensor::fill(&Shape(vec![5]), 0)
157-
);
158-
assert_eq!(
159-
BitTensor::fill(&Shape(vec![1]), 0) & BitTensor::fill(&Shape(vec![5]), 1),
160-
BitTensor::fill(&Shape(vec![5]), 0)
161-
);
162-
}
163134
}

rust/moose/src/ring.rs

Lines changed: 3 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,7 @@ where
164164
{
165165
type Output = ConcreteRingTensor<T>;
166166
fn add(self, other: ConcreteRingTensor<T>) -> Self::Output {
167-
match self.0.broadcast(other.0.dim()) {
168-
Some(self_broadcasted) => {
169-
ConcreteRingTensor::<T>(self_broadcasted.to_owned() + other.0)
170-
}
171-
None => ConcreteRingTensor::<T>(self.0 + other.0),
172-
}
167+
ConcreteRingTensor::<T>(self.0 + other.0)
173168
}
174169
}
175170

@@ -180,12 +175,7 @@ where
180175
{
181176
type Output = ConcreteRingTensor<T>;
182177
fn mul(self, other: ConcreteRingTensor<T>) -> Self::Output {
183-
match self.0.broadcast(other.0.dim()) {
184-
Some(self_broadcasted) => {
185-
ConcreteRingTensor::<T>(self_broadcasted.to_owned() * other.0)
186-
}
187-
None => ConcreteRingTensor::<T>(self.0 * other.0),
188-
}
178+
ConcreteRingTensor::<T>(self.0 * other.0)
189179
}
190180
}
191181

@@ -196,12 +186,7 @@ where
196186
{
197187
type Output = ConcreteRingTensor<T>;
198188
fn sub(self, other: ConcreteRingTensor<T>) -> Self::Output {
199-
match self.0.broadcast(other.0.dim()) {
200-
Some(self_broadcasted) => {
201-
ConcreteRingTensor::<T>(self_broadcasted.to_owned() - other.0)
202-
}
203-
None => ConcreteRingTensor::<T>(self.0 - other.0),
204-
}
189+
ConcreteRingTensor::<T>(self.0 - other.0)
205190
}
206191
}
207192

@@ -412,69 +397,6 @@ mod tests {
412397
assert_eq!(out, exp)
413398
}
414399

415-
#[test]
416-
fn test_add_broadcasting() {
417-
let x_1_backing: ArrayD<i64> = array![2].into_dimensionality::<IxDyn>().unwrap();
418-
let x_1 = Ring64Tensor::from(x_1_backing);
419-
let y_1_backing: ArrayD<i64> = array![1, 2].into_dimensionality::<IxDyn>().unwrap();
420-
let y_1 = Ring64Tensor::from(y_1_backing);
421-
let z_1 = x_1.add(y_1);
422-
let z_1_exp_backing: ArrayD<i64> = array![3, 4].into_dimensionality::<IxDyn>().unwrap();
423-
let z_1_exp = Ring64Tensor::from(z_1_exp_backing);
424-
let x_2_backing: ArrayD<i64> = array![1, 2].into_dimensionality::<IxDyn>().unwrap();
425-
let x_2 = Ring64Tensor::from(x_2_backing);
426-
let y_2_backing: ArrayD<i64> = array![2].into_dimensionality::<IxDyn>().unwrap();
427-
let y_2 = Ring64Tensor::from(y_2_backing);
428-
let z_2 = x_2.add(y_2);
429-
let z_2_exp_backing: ArrayD<i64> = array![3, 4].into_dimensionality::<IxDyn>().unwrap();
430-
let z_2_exp = Ring64Tensor::from(z_2_exp_backing);
431-
432-
assert_eq!(z_1, z_1_exp);
433-
assert_eq!(z_2, z_2_exp);
434-
}
435-
436-
#[test]
437-
fn test_sub_broadcasting() {
438-
let x_1_backing: ArrayD<i64> = array![2].into_dimensionality::<IxDyn>().unwrap();
439-
let x_1 = Ring64Tensor::from(x_1_backing);
440-
let y_1_backing: ArrayD<i64> = array![1, 2].into_dimensionality::<IxDyn>().unwrap();
441-
let y_1 = Ring64Tensor::from(y_1_backing);
442-
let z_1 = x_1.sub(y_1);
443-
let z_1_exp_backing: ArrayD<i64> = array![1, 0].into_dimensionality::<IxDyn>().unwrap();
444-
let z_1_exp = Ring64Tensor::from(z_1_exp_backing);
445-
let x_2_backing: ArrayD<i64> = array![1, 2].into_dimensionality::<IxDyn>().unwrap();
446-
let x_2 = Ring64Tensor::from(x_2_backing);
447-
let y_2_backing: ArrayD<i64> = array![2].into_dimensionality::<IxDyn>().unwrap();
448-
let y_2 = Ring64Tensor::from(y_2_backing);
449-
let z_2 = x_2.sub(y_2);
450-
let z_2_exp_backing: ArrayD<i64> = array![-1, 0].into_dimensionality::<IxDyn>().unwrap();
451-
let z_2_exp = Ring64Tensor::from(z_2_exp_backing);
452-
453-
assert_eq!(z_1, z_1_exp);
454-
assert_eq!(z_2, z_2_exp);
455-
}
456-
457-
#[test]
458-
fn test_mul_broadcasting() {
459-
let x_1_backing: ArrayD<i64> = array![2].into_dimensionality::<IxDyn>().unwrap();
460-
let x_1 = Ring64Tensor::from(x_1_backing);
461-
let y_1_backing: ArrayD<i64> = array![1, 2].into_dimensionality::<IxDyn>().unwrap();
462-
let y_1 = Ring64Tensor::from(y_1_backing);
463-
let z_1 = x_1.mul(y_1);
464-
let z_1_exp_backing: ArrayD<i64> = array![2, 4].into_dimensionality::<IxDyn>().unwrap();
465-
let z_1_exp = Ring64Tensor::from(z_1_exp_backing);
466-
let x_2_backing: ArrayD<i64> = array![1, 2].into_dimensionality::<IxDyn>().unwrap();
467-
let x_2 = Ring64Tensor::from(x_2_backing);
468-
let y_2_backing: ArrayD<i64> = array![2].into_dimensionality::<IxDyn>().unwrap();
469-
let y_2 = Ring64Tensor::from(y_2_backing);
470-
let z_2 = x_2.mul(y_2);
471-
let z_2_exp_backing: ArrayD<i64> = array![2, 4].into_dimensionality::<IxDyn>().unwrap();
472-
let z_2_exp = Ring64Tensor::from(z_2_exp_backing);
473-
474-
assert_eq!(z_1, z_1_exp);
475-
assert_eq!(z_2, z_2_exp);
476-
}
477-
478400
#[test]
479401
fn bit_extract() {
480402
let shape = Shape(vec![5]);

0 commit comments

Comments
 (0)