Skip to content

Commit a93bd9e

Browse files
committed
[+] linear: food_truck_profit
1 parent 93a9f50 commit a93bd9e

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

data/food_truck_profit.csv

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
population of City in 10000's,profit in 10000$
2+
6.1101,17.592
3+
5.5277,9.1302
4+
8.5186,13.662
5+
7.0032,11.854
6+
5.8598,6.8233
7+
8.3829,11.886
8+
7.4764,4.3483
9+
8.5781,12
10+
6.4862,6.5987
11+
5.0546,3.8166
12+
5.7107,3.2522
13+
14.164,15.505
14+
5.734,3.1551
15+
8.4084,7.2258
16+
5.6407,0.71618
17+
5.3794,3.5129
18+
6.3654,5.3048
19+
5.1301,0.56077
20+
6.4296,3.6518
21+
7.0708,5.3893
22+
6.1891,3.1386
23+
20.27,21.767
24+
5.4901,4.263
25+
6.3261,5.1875
26+
5.5649,3.0825
27+
18.945,22.638
28+
12.828,13.501
29+
10.957,7.0467
30+
13.176,14.692
31+
22.203,24.147
32+
5.2524,-1.22
33+
6.5894,5.9966
34+
9.2482,12.134
35+
5.8918,1.8495
36+
8.2111,6.5426
37+
7.9334,4.5623
38+
8.0959,4.1164
39+
5.6063,3.3928
40+
12.836,10.117
41+
6.3534,5.4974
42+
5.4069,0.55657
43+
6.8825,3.9115
44+
11.708,5.3854
45+
5.7737,2.4406
46+
7.8247,6.7318
47+
7.0931,1.0463
48+
5.0702,5.1337
49+
5.8014,1.844
50+
11.7,8.0043
51+
5.5416,1.0179
52+
7.5402,6.7504
53+
5.3077,1.8396
54+
7.4239,4.2885
55+
7.6031,4.9981
56+
6.3328,1.4233
57+
6.3589,-1.4211
58+
6.2742,2.4756
59+
5.6397,4.6042
60+
9.3102,3.9624
61+
9.4536,5.4141
62+
8.8254,5.1694
63+
5.1793,-0.74279
64+
21.279,17.929
65+
14.908,12.054
66+
18.959,17.054
67+
7.2182,4.8852
68+
8.2951,5.7442
69+
10.236,7.7754
70+
5.4994,1.0173
71+
20.341,20.992
72+
10.136,6.6799
73+
7.3345,4.0259
74+
6.0062,1.2784
75+
7.2259,3.3411
76+
5.0269,-2.6807
77+
6.5479,0.29678
78+
7.5386,3.8845
79+
5.0365,5.7014
80+
10.274,6.7526
81+
5.1077,2.0576
82+
5.7292,0.47953
83+
5.1884,0.20421
84+
6.3557,0.67861
85+
9.7687,7.5435
86+
6.5159,5.3436
87+
8.5172,4.2415
88+
9.1802,6.7981
89+
6.002,0.92695
90+
5.5204,0.152
91+
5.0594,2.8214
92+
5.7077,1.8451
93+
7.6366,4.2959
94+
5.8707,7.2029
95+
5.3054,1.9869
96+
8.2934,0.14454
97+
13.394,9.0551
98+
5.4369,0.61705

tests/food_truck_profit.rs

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
extern crate ml_algo;
2+
3+
use ml_algo::matrix::DMatrix;
4+
use ml_algo::linear::{LinearRegression, LinearRegressionOptions, Stepping};
5+
use ml_algo::utils::{rmse_error, mae_error};
6+
7+
// This is taken from Machine Learning course by Anderw Ng
8+
// https://www.coursera.org/learn/machine-learning
9+
//
10+
// Week 2, exercise 1
11+
12+
#[test]
13+
fn food_truck_profit() {
14+
let mut lr = LinearRegression::new( LinearRegressionOptions::new()
15+
.max_iter(1500)
16+
.stepping(Stepping::Constant(0.01))
17+
.x_eps(1.0e-15)
18+
.eps(1.0e-15)
19+
);
20+
21+
let train_x: DMatrix<f64> = DMatrix::from_csv("data/food_truck_profit.csv", 1, ',', Some(&[0])).unwrap();
22+
let train_y: DMatrix<f64> = DMatrix::from_csv("data/food_truck_profit.csv", 1, ',', Some(&[1])).unwrap();
23+
24+
lr.fit(&train_x, train_y.data()).unwrap();
25+
let train_py = lr.predict(&train_x).unwrap();
26+
27+
let rmse = rmse_error(train_y.data(), &train_py);
28+
let mae = mae_error(train_y.data(), &train_py);
29+
30+
println!("Bias = {}, Coefficients = {:?}", lr.bias().unwrap(), lr.coefficients().unwrap());
31+
assert!((lr.bias().unwrap() - (-3.630291)).abs() < 1.0e-5);
32+
assert!((lr.coefficients().unwrap()[0] - 1.166362).abs() < 1.0e-5);
33+
println!("Train: RMSE = {}, MAE = {}", rmse, mae);
34+
35+
let mut test_x: DMatrix<f64> = DMatrix::new_zeros(0, 1);
36+
test_x.append_row(&[3.5]);
37+
test_x.append_row(&[7.0]);
38+
let test_py = lr.predict(&test_x).unwrap();
39+
println!("For population = 35,000 we predict a profit of {}", test_py[0] * 10000.0);
40+
assert!((test_py[0] * 10000.0 - 4519.77).abs() < 0.1);
41+
println!("For population = 70,000 we predict a profit of {}", test_py[1] * 10000.0);
42+
assert!((test_py[1] * 10000.0 - 45342.45).abs() < 0.1);
43+
}

0 commit comments

Comments
 (0)