|
| 1 | +extern crate ml_algo; |
| 2 | + |
| 3 | +use ml_algo::matrix::DMatrix; |
| 4 | +use ml_algo::linear::{LinearRegression, LinearRegressionOptions, Stepping}; |
| 5 | +use ml_algo::utils::{rmse_error, mae_error}; |
| 6 | + |
| 7 | +// This is taken from Machine Learning course by Anderw Ng |
| 8 | +// https://www.coursera.org/learn/machine-learning |
| 9 | +// |
| 10 | +// Week 2, exercise 1 |
| 11 | + |
| 12 | +#[test] |
| 13 | +fn food_truck_profit() { |
| 14 | + let mut lr = LinearRegression::new( LinearRegressionOptions::new() |
| 15 | + .max_iter(1500) |
| 16 | + .stepping(Stepping::Constant(0.01)) |
| 17 | + .x_eps(1.0e-15) |
| 18 | + .eps(1.0e-15) |
| 19 | + ); |
| 20 | + |
| 21 | + let train_x: DMatrix<f64> = DMatrix::from_csv("data/food_truck_profit.csv", 1, ',', Some(&[0])).unwrap(); |
| 22 | + let train_y: DMatrix<f64> = DMatrix::from_csv("data/food_truck_profit.csv", 1, ',', Some(&[1])).unwrap(); |
| 23 | + |
| 24 | + lr.fit(&train_x, train_y.data()).unwrap(); |
| 25 | + let train_py = lr.predict(&train_x).unwrap(); |
| 26 | + |
| 27 | + let rmse = rmse_error(train_y.data(), &train_py); |
| 28 | + let mae = mae_error(train_y.data(), &train_py); |
| 29 | + |
| 30 | + println!("Bias = {}, Coefficients = {:?}", lr.bias().unwrap(), lr.coefficients().unwrap()); |
| 31 | + assert!((lr.bias().unwrap() - (-3.630291)).abs() < 1.0e-5); |
| 32 | + assert!((lr.coefficients().unwrap()[0] - 1.166362).abs() < 1.0e-5); |
| 33 | + println!("Train: RMSE = {}, MAE = {}", rmse, mae); |
| 34 | + |
| 35 | + let mut test_x: DMatrix<f64> = DMatrix::new_zeros(0, 1); |
| 36 | + test_x.append_row(&[3.5]); |
| 37 | + test_x.append_row(&[7.0]); |
| 38 | + let test_py = lr.predict(&test_x).unwrap(); |
| 39 | + println!("For population = 35,000 we predict a profit of {}", test_py[0] * 10000.0); |
| 40 | + assert!((test_py[0] * 10000.0 - 4519.77).abs() < 0.1); |
| 41 | + println!("For population = 70,000 we predict a profit of {}", test_py[1] * 10000.0); |
| 42 | + assert!((test_py[1] * 10000.0 - 45342.45).abs() < 0.1); |
| 43 | +} |
0 commit comments