Skip to content

math.Complex: genericize arithmetics with reals, replace mulbyi with comptime-generic i-shifted arithmetic #19202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 114 additions & 32 deletions lib/std/math/complex.zig
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,92 @@ pub fn Complex(comptime T: type) type {
}

/// Returns the sum of two complex numbers.
pub fn add(self: Self, other: Self) Self {
return Self{
pub fn add(self: Self, other: anytype) Self {
return if (Self == @TypeOf(other)) .{
.re = self.re + other.re,
.im = self.im + other.im,
} else .{
.re = self.re + other,
.im = self.im,
};
}

/// Returns the subtraction of two complex numbers.
pub fn sub(self: Self, other: Self) Self {
return Self{
pub fn sub(self: Self, other: anytype) Self {
return if (Self == @TypeOf(other)) .{
.re = self.re - other.re,
.im = self.im - other.im,
} else .{
.re = self.re - other,
.im = self.im,
};
}

/// Returns the product of two complex numbers.
pub fn mul(self: Self, other: Self) Self {
return Self{
pub fn mul(self: Self, other: anytype) Self {
return if (Self == @TypeOf(other)) .{
.re = self.re * other.re - self.im * other.im,
.im = self.im * other.re + self.re * other.im,
} else .{
.re = self.re * other,
.im = self.im * other,
};
}

/// Returns the quotient of two complex numbers.
pub fn div(self: Self, other: Self) Self {
const re_num = self.re * other.re + self.im * other.im;
const im_num = self.im * other.re - self.re * other.im;
const den = other.re * other.re + other.im * other.im;
pub fn div(self: Self, other: anytype) Self {
const abs2 = if (Self == @TypeOf(other)) other.re * other.re + other.im * other.im else other;
return if (Self == @TypeOf(other)) .{
.re = (self.re * other.re + self.im * other.im) / abs2,
.im = (self.im * other.re - self.re * other.im) / abs2,
} else .{
.re = self.re / other,
.im = self.im / other,
};
}

return Self{
.re = re_num / den,
.im = im_num / den,
/// Add self by a number rotated by the imaginary unit.
pub fn iadd(self: Self, other: anytype) Self {
return if (Self == @TypeOf(other)) .{
.re = self.re - other.im,
.im = self.im + other.re,
} else .{
.re = self.re,
.im = self.im + other,
};
}

/// Subtract self by a number rotated by the imaginary unit.
pub fn isub(self: Self, other: anytype) Self {
return if (Self == @TypeOf(other)) .{
.re = self.re + other.im,
.im = self.im - other.re,
} else .{
.re = self.re,
.im = self.im - other,
};
}

/// Multiply self by a number rotated by the imaginary unit.
pub fn imul(self: Self, other: anytype) Self {
return if (Self == @TypeOf(other)) .{
.re = -(self.im * other.re + self.re * other.im),
.im = self.re * other.re - self.im * other.im,
} else .{
.re = -self.im * other,
.im = self.re * other,
};
}

/// Divide self by a number rotated by the imaginary unit.
pub fn idiv(self: Self, other: anytype) Self {
const abs2 = if (Self == @TypeOf(other)) other.re * other.re + other.im * other.im else other;
return if (Self == @TypeOf(other)) .{
.re = (self.im * other.re - self.re * other.im) / abs2,
.im = -(self.re * other.re + self.im * other.im) / abs2,
} else .{
.re = self.im / other,
.im = -self.re / other,
};
}

Expand All @@ -94,14 +148,6 @@ pub fn Complex(comptime T: type) type {
};
}

/// Returns the product of complex number and i=sqrt(-1)
pub fn mulbyi(self: Self) Self {
return Self{
.re = -self.im,
.im = self.re,
};
}

/// Returns the reciprocal of a complex number.
pub fn reciprocal(self: Self) Self {
const m = self.re * self.re + self.im * self.im;
Expand All @@ -124,33 +170,76 @@ test "add" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.add(b);

const d = c.add(1);
try testing.expect(c.re == 7 and c.im == 10);
try testing.expect(d.re == 8 and d.im == 10);
}

test "sub" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.sub(b);

const d = c.sub(1);
try testing.expect(c.re == 3 and c.im == -4);
try testing.expect(d.re == 2 and d.im == -4);
}

test "mul" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.mul(b);

const d = c.mul(2);
try testing.expect(c.re == -11 and c.im == 41);
try testing.expect(d.re == -22 and d.im == 82);
}

test "div" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.div(b);

const d = c.div(2);
try testing.expect(math.approxEqAbs(f32, c.re, @as(f32, 31) / 53, epsilon) and
math.approxEqAbs(f32, c.im, @as(f32, -29) / 53, epsilon));
try testing.expect(math.approxEqAbs(f32, d.re, @as(f32, 31) / 106, epsilon) and
math.approxEqAbs(f32, d.im, @as(f32, -29) / 106, epsilon));
}

test "iadd" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.iadd(b);
const d = c.iadd(1);
try testing.expect(c.re == -2 and c.im == 5);
try testing.expect(d.re == -2 and d.im == 6);
}

test "isub" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.isub(b);
const d = c.isub(1);
try testing.expect(c.re == 12 and c.im == 1);
try testing.expect(d.re == 12 and d.im == 0);
}

test "imul" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.imul(b);
const d = c.imul(2);
try testing.expect(c.re == -41 and c.im == -11);
try testing.expect(d.re == 22 and d.im == -82);
}

test "idiv" {
const a = Complex(f32).init(5, 3);
const b = Complex(f32).init(2, 7);
const c = a.idiv(b);
const d = c.idiv(2);
try testing.expect(math.approxEqAbs(f32, c.re, @as(f32, -29) / 53, epsilon) and
math.approxEqAbs(f32, c.im, @as(f32, -31) / 53, epsilon));
try testing.expect(math.approxEqAbs(f32, d.re, @as(f32, -31) / 106, epsilon) and
math.approxEqAbs(f32, d.im, @as(f32, 29) / 106, epsilon));
}

test "conjugate" {
Expand All @@ -167,13 +256,6 @@ test "neg" {
try testing.expect(c.re == -5 and c.im == -3);
}

test "mulbyi" {
const a = Complex(f32).init(5, 3);
const c = a.mulbyi();

try testing.expect(c.re == -3 and c.im == 5);
}

test "reciprocal" {
const a = Complex(f32).init(5, 3);
const c = a.reciprocal();
Expand Down