This demo shows how to perform cross-validation with `bigframes.ml`.
1. Prepare Data
```python
import bigframes.pandas as bpd

# Read the table and filter out rows with unavailable (missing) data
df = bpd.read_gbq("bigframes-dev.bqml_tutorial.penguins")
df = df.dropna()
df
```
| | species | island | culmen_length_mm | culmen_depth_mm | flipper_length_mm | body_mass_g | sex |
|---|---|---|---|---|---|---|---|
| 0 | Gentoo penguin (Pygoscelis papua) | Biscoe | 45.2 | 16.4 | 223.0 | 5950.0 | MALE |
| 1 | Gentoo penguin (Pygoscelis papua) | Biscoe | 46.5 | 14.5 | 213.0 | 4400.0 | FEMALE |
| 2 | Adelie Penguin (Pygoscelis adeliae) | Biscoe | 37.7 | 16.0 | 183.0 | 3075.0 | FEMALE |
| 3 | Gentoo penguin (Pygoscelis papua) | Biscoe | 46.4 | 15.6 | 221.0 | 5000.0 | MALE |
| 4 | Gentoo penguin (Pygoscelis papua) | Biscoe | 46.1 | 13.2 | 211.0 | 4500.0 | FEMALE |
| 5 | Adelie Penguin (Pygoscelis adeliae) | Torgersen | 43.1 | 19.2 | 197.0 | 3500.0 | MALE |
| 6 | Gentoo penguin (Pygoscelis papua) | Biscoe | 45.2 | 15.8 | 215.0 | 5300.0 | MALE |
| 7 | Adelie Penguin (Pygoscelis adeliae) | Dream | 36.2 | 17.3 | 187.0 | 3300.0 | FEMALE |
| 8 | Chinstrap penguin (Pygoscelis antarctica) | Dream | 46.0 | 18.9 | 195.0 | 4150.0 | FEMALE |
| 9 | Gentoo penguin (Pygoscelis papua) | Biscoe | 54.3 | 15.7 | 231.0 | 5650.0 | MALE |
| 11 | Adelie Penguin (Pygoscelis adeliae) | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | FEMALE |
| 12 | Gentoo penguin (Pygoscelis papua) | Biscoe | 42.7 | 13.7 | 208.0 | 3950.0 | FEMALE |
| 13 | Adelie Penguin (Pygoscelis adeliae) | Biscoe | 41.0 | 20.0 | 203.0 | 4725.0 | MALE |
| 14 | Gentoo penguin (Pygoscelis papua) | Biscoe | 48.5 | 15.0 | 219.0 | 4850.0 | FEMALE |
| 15 | Chinstrap penguin (Pygoscelis antarctica) | Dream | 49.6 | 18.2 | 193.0 | 3775.0 | MALE |
| 16 | Gentoo penguin (Pygoscelis papua) | Biscoe | 50.8 | 17.3 | 228.0 | 5600.0 | MALE |
| 17 | Gentoo penguin (Pygoscelis papua) | Biscoe | 46.2 | 14.1 | 217.0 | 4375.0 | FEMALE |
| 18 | Adelie Penguin (Pygoscelis adeliae) | Biscoe | 38.8 | 17.2 | 180.0 | 3800.0 | MALE |
| 19 | Chinstrap penguin (Pygoscelis antarctica) | Dream | 51.0 | 18.8 | 203.0 | 4100.0 | MALE |
| 20 | Gentoo penguin (Pygoscelis papua) | Biscoe | 42.9 | 13.1 | 215.0 | 5000.0 | FEMALE |
| 21 | Gentoo penguin (Pygoscelis papua) | Biscoe | 50.4 | 15.3 | 224.0 | 5550.0 | MALE |
| 22 | Gentoo penguin (Pygoscelis papua) | Biscoe | 49.0 | 16.1 | 216.0 | 5550.0 | MALE |
| 23 | Gentoo penguin (Pygoscelis papua) | Biscoe | 43.4 | 14.4 | 218.0 | 4600.0 | FEMALE |
| 24 | Gentoo penguin (Pygoscelis papua) | Biscoe | 45.0 | 15.4 | 220.0 | 5050.0 | MALE |
| 25 | Gentoo penguin (Pygoscelis papua) | Biscoe | 47.5 | 14.0 | 212.0 | 4875.0 | FEMALE |
25 rows × 7 columns
[334 rows x 7 columns in total]
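`df.dropna()` keeps only the rows in which every column has a value. By default (`how="any"`, as in pandas), one missing value is enough to drop a row. The surviving rows keep their original index labels, which fits the gap at index 10 in the preview above. The pure-Python sketch below illustrates that rule without BigQuery. The rows and the helper names `is_missing` and `dropna_any` are hypothetical.

```python
import math

# Hypothetical rows keyed by their original index label; row 2 has missing values.
rows = {
    0: {"island": "Biscoe", "body_mass_g": 5950.0, "sex": "MALE"},
    1: {"island": "Biscoe", "body_mass_g": 4400.0, "sex": "FEMALE"},
    2: {"island": "Dream", "body_mass_g": float("nan"), "sex": None},
    3: {"island": "Torgersen", "body_mass_g": 3500.0, "sex": "MALE"},
}


def is_missing(value):
    """Treat None and float NaN as missing, similar to pandas' isna()."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def dropna_any(table):
    """Keep only rows with no missing values, preserving the original labels."""
    return {
        label: row
        for label, row in table.items()
        if not any(is_missing(v) for v in row.values())
    }


cleaned = dropna_any(rows)
print(sorted(cleaned))  # [0, 1, 3]
```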
```python
# Select X and y from the dataset
X = df[
    [
        "species",
        "island",
        "culmen_length_mm",
    ]
]
y = df["body_mass_g"]
```
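`X` combines two string columns (`species`, `island`) with one numeric column. The bigframes `LinearRegression` trains a BigQuery ML linear regression model. BigQuery ML's automatic preprocessing one-hot encodes non-numeric feature columns, so each distinct string value effectively becomes its own 0/1 feature. The sketch below shows that idea in pure Python on three rows copied from the preview (indexes 0, 5 and 8). Two choices here are assumptions for illustration, not BigQuery ML's exact encoding: categories are sorted, and every level is kept (no reference category is dropped). `one_hot_encode` is a hypothetical helper.

```python
# Three rows copied from the preview, restricted to the columns selected for X.
rows = [
    {"species": "Gentoo penguin (Pygoscelis papua)", "island": "Biscoe",
     "culmen_length_mm": 45.2},
    {"species": "Adelie Penguin (Pygoscelis adeliae)", "island": "Torgersen",
     "culmen_length_mm": 43.1},
    {"species": "Chinstrap penguin (Pygoscelis antarctica)", "island": "Dream",
     "culmen_length_mm": 46.0},
]
CATEGORICAL = ["species", "island"]
NUMERIC = ["culmen_length_mm"]


def one_hot_encode(records, categorical, numeric):
    """Return (feature_names, matrix) with one 0/1 column per category value."""
    # Distinct values of each categorical column, in sorted order (an assumption).
    levels = {col: sorted({r[col] for r in records}) for col in categorical}
    names = [f"{col}={v}" for col in categorical for v in levels[col]]
    names += list(numeric)
    matrix = []
    for r in records:
        encoded = [1.0 if r[col] == v else 0.0
                   for col in categorical for v in levels[col]]
        encoded += [float(r[col]) for col in numeric]  # numeric values pass through
        matrix.append(encoded)
    return names, matrix


names, matrix = one_hot_encode(rows, CATEGORICAL, NUMERIC)
print(names)
print(matrix[0])
```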
2.1 Define a KFold Instance and Train/Test on Each Fold (Manual Approach)
```python
from bigframes.ml import linear_model, model_selection

# Create a KFold instance. n_splits sets how many folds the data is split into;
# for example, n_splits=5 splits the entire dataset into 5 pieces.
# In each fold, 4 pieces are used for training and the remaining piece is used
# for evaluation.
kf = model_selection.KFold(n_splits=5)

# Train and evaluate a model on each train/test split.
for X_train, X_test, y_train, y_test in kf.split(X, y):
    model = linear_model.LinearRegression()
    model.fit(X_train, y_train)
    score = model.score(X_test, y_test)
    print(score)
```
| | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error | r2_score | explained_variance |
|---|---|---|---|---|---|---|
| 0 | 297.36838 | 148892.914876 | 0.009057 | 238.424052 | 0.814613 | 0.816053 |

[1 rows x 6 columns]
| | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error | r2_score | explained_variance |
|---|---|---|---|---|---|---|
| 0 | 307.6149 | 139013.303482 | 0.007907 | 266.589811 | 0.782835 | 0.794297 |

[1 rows x 6 columns]
| | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error | r2_score | explained_variance |
|---|---|---|---|---|---|---|
| 0 | 348.412701 | 180661.063512 | 0.01125 | 313.29406 | 0.744053 | 0.74537 |

[1 rows x 6 columns]
| | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error | r2_score | explained_variance |
|---|---|---|---|---|---|---|
| 0 | 309.991882 | 151820.705254 | 0.008898 | 212.758708 | 0.694001 | 0.694287 |

[1 rows x 6 columns]
| | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error | r2_score | explained_variance |
|---|---|---|---|---|---|---|
| 0 | 256.569216 | 103495.042886 | 0.006605 | 222.940815 | 0.818589 | 0.832344 |

[1 rows x 6 columns]
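Each fold above prints a one-row DataFrame of six regression metrics from BigQuery ML's model evaluation. To make those columns concrete, the sketch below computes the same six quantities in pure Python from their standard textbook definitions. It is an explanatory approximation only: BigQuery ML's exact formulas may differ (for example, in how the log error treats negative values). The `y_true`/`y_pred` values are hypothetical.

```python
import math
import statistics


def regression_metrics(y_true, y_pred):
    """Compute the six per-fold metrics from their textbook definitions."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    abs_errors = [abs(e) for e in errors]
    mean_true = sum(y_true) / n
    ss_res = sum(e * e for e in errors)  # residual sum of squares
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total sum of squares
    return {
        "mean_absolute_error": sum(abs_errors) / n,
        "mean_squared_error": ss_res / n,
        # Mean squared difference of log(1 + value); assumes non-negative values.
        "mean_squared_log_error": sum(
            (math.log1p(t) - math.log1p(p)) ** 2 for t, p in zip(y_true, y_pred)
        ) / n,
        "median_absolute_error": statistics.median(abs_errors),
        "r2_score": 1.0 - ss_res / ss_tot,
        # 1 - Var(residuals) / Var(y), using population variances.
        "explained_variance": 1.0
        - statistics.pvariance(errors) / statistics.pvariance(y_true),
    }


metrics = regression_metrics(
    [3000.0, 4000.0, 5000.0, 6000.0],  # hypothetical true body masses (g)
    [3200.0, 3900.0, 5100.0, 5600.0],  # hypothetical predictions (g)
)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

In this example the residuals have a non-zero mean, so `explained_variance` (0.958) comes out slightly above `r2_score` (0.956). The gap equals the squared mean residual divided by the variance of `y`, so `explained_variance` is never below `r2_score`, which matches every fold above.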
2.2 Use the cross_validate Function to Do Cross-Validation (Automatic Approach)
```python
# model_selection.cross_validate automates the manual process from section 2.1.
# The returned `scores` dict contains the evaluation results for each fold
# ("test_score") together with the per-fold "fit_time" and "score_time".
model = linear_model.LinearRegression()
scores = model_selection.cross_validate(model, X, y, cv=5)
scores
```
```
{'test_score': [   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           322.341485       157616.627179                0.009137
   median_absolute_error  r2_score  explained_variance
0             269.412639  0.705594            0.724882
[1 rows x 6 columns],
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           289.682121       136550.318797                 0.00878
   median_absolute_error  r2_score  explained_variance
0             212.874686  0.799363             0.81416
[1 rows x 6 columns],
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           325.358522       155218.752974                0.009606
   median_absolute_error  r2_score  explained_variance
0             267.301671  0.777174              0.7782
[1 rows x 6 columns],
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           286.874056       120586.575364                0.007484
   median_absolute_error  r2_score  explained_variance
0             247.656578   0.79281            0.796001
[1 rows x 6 columns],
   mean_absolute_error  mean_squared_error  mean_squared_log_error  \
0           287.989397       145947.465344                0.008447
   median_absolute_error  r2_score  explained_variance
0             186.777549  0.791452            0.798825
[1 rows x 6 columns]],
 'fit_time': [18.79181448201416,
  19.092008439009078,
  75.7446747609647,
  17.520530884969048,
  21.157033596013207],
 'score_time': [4.247669544012751,
  6.792615927988663,
  4.502274781989399,
  4.484583999030292,
  4.224339194013737]}
```
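Conceptually, `cross_validate` automates the loop from section 2.1. For each fold it fits the model on the training part, scores it on the held-out part, and records how long each step took in `fit_time` and `score_time`. The pure-Python sketch below mirrors that structure so it can run without BigQuery. Everything in it is hypothetical: `kfold_indices`, `SimpleLinearRegression` (a one-feature least-squares model scored by R²) and `simple_cross_validate` are illustrations, not the bigframes implementation. The sketch also uses contiguous, unshuffled folds for simplicity. Because the per-fold scores in sections 2.1 and 2.2 differ, the bigframes splits evidently were not identical across the two runs.

```python
import time


def kfold_indices(n_samples, n_splits):
    """Yield (train_idx, test_idx) pairs for contiguous, unshuffled folds."""
    # The first n_samples % n_splits folds receive one extra sample.
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        test_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, test_idx
        start += size


class SimpleLinearRegression:
    """Hypothetical one-feature least-squares model: y = slope * x + intercept."""

    def fit(self, x, y):
        mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
        sxx = sum((xi - mean_x) ** 2 for xi in x)
        sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
        self.slope = sxy / sxx
        self.intercept = mean_y - self.slope * mean_x
        return self

    def score(self, x, y):
        """Return R^2 of the model's predictions on (x, y)."""
        preds = [self.slope * xi + self.intercept for xi in x]
        mean_y = sum(y) / len(y)
        ss_res = sum((yi - pi) ** 2 for yi, pi in zip(y, preds))
        ss_tot = sum((yi - mean_y) ** 2 for yi in y)
        return 1.0 - ss_res / ss_tot


def simple_cross_validate(model_factory, x, y, cv=5):
    """Fit and score a new model per fold, timing the fit and score steps."""
    scores = {"test_score": [], "fit_time": [], "score_time": []}
    for train_idx, test_idx in kfold_indices(len(x), cv):
        model = model_factory()
        start = time.perf_counter()
        model.fit([x[i] for i in train_idx], [y[i] for i in train_idx])
        scores["fit_time"].append(time.perf_counter() - start)

        start = time.perf_counter()
        score = model.score([x[i] for i in test_idx], [y[i] for i in test_idx])
        scores["score_time"].append(time.perf_counter() - start)
        scores["test_score"].append(score)
    return scores


# Hypothetical data: y = 2x + 1 with an alternating +/-0.5 offset.
x_demo = [float(i) for i in range(20)]
y_demo = [2.0 * xi + 1.0 + (0.5 if i % 2 else -0.5) for i, xi in enumerate(x_demo)]
result = simple_cross_validate(SimpleLinearRegression, x_demo, y_demo, cv=5)
print([round(s, 3) for s in result["test_score"]])
```

For comparison, the average of the five `r2_score` values in the real output above is about 0.773 (range 0.706 to 0.799). That is close to the roughly 0.771 average of the five manual folds in section 2.1.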