!pip install -q iterative-stratification
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
import pandas as pd
from sklearn.model_selection import KFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
# Load the per-discourse training annotations and one-hot encode the
# `discourse_type` column, then keep only the first five rows for a quick look.
df = pd.read_csv('/root/.cache/data/train.csv')
dfx = pd.get_dummies(df, columns=["discourse_type"])
dfx = dfx.head(5)
dfx
|
id |
discourse_id |
discourse_start |
discourse_end |
discourse_text |
discourse_type_num |
predictionstring |
discourse_type_Claim |
discourse_type_Concluding Statement |
discourse_type_Counterclaim |
discourse_type_Evidence |
discourse_type_Lead |
discourse_type_Position |
discourse_type_Rebuttal |
0 |
423A1CA112E2 |
1.622628e+12 |
8.0 |
229.0 |
Modern humans today are always on their phone…. |
Lead 1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 1… |
0 |
0 |
0 |
0 |
1 |
0 |
0 |
1 |
423A1CA112E2 |
1.622628e+12 |
230.0 |
312.0 |
They are some really bad consequences when stu… |
Position 1 |
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
0 |
0 |
0 |
0 |
0 |
1 |
0 |
2 |
423A1CA112E2 |
1.622628e+12 |
313.0 |
401.0 |
Some certain areas in the United States ban ph… |
Evidence 1 |
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
0 |
0 |
0 |
1 |
0 |
0 |
0 |
3 |
423A1CA112E2 |
1.622628e+12 |
402.0 |
758.0 |
When people have phones, they know about certa… |
Evidence 2 |
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9… |
0 |
0 |
0 |
1 |
0 |
0 |
0 |
4 |
423A1CA112E2 |
1.622628e+12 |
759.0 |
886.0 |
Driving is one of the way how to get around. P… |
Claim 1 |
139 140 141 142 143 144 145 146 147 148 149 15… |
1 |
0 |
0 |
0 |
0 |
0 |
0 |
# Aggregate to one row per essay (`id`). Summing the one-hot dummy columns
# gives, per essay, the number of discourse elements of each type.
# numeric_only=True restricts the sum to numeric/boolean columns. Under
# pandas >= 2.0 (where numeric_only defaults to False), text columns such as
# discourse_text and predictionstring would otherwise be string-concatenated
# per essay instead of being dropped.
dfx = (
    pd.get_dummies(df, columns=["discourse_type"])
    .groupby(["id"], as_index=False)
    .sum(numeric_only=True)
)
dfx
|
id |
discourse_id |
discourse_start |
discourse_end |
discourse_type_Claim |
discourse_type_Concluding Statement |
discourse_type_Counterclaim |
discourse_type_Evidence |
discourse_type_Lead |
discourse_type_Position |
discourse_type_Rebuttal |
0 |
0000D23A521A |
1.294188e+13 |
4166.0 |
5506.0 |
1 |
1 |
1 |
3 |
0 |
1 |
1 |
1 |
00066EA9880D |
1.458994e+13 |
12618.0 |
16058.0 |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
2 |
000E6DE9E817 |
1.940756e+13 |
8760.0 |
10092.0 |
5 |
1 |
1 |
3 |
0 |
1 |
1 |
3 |
001552828BD0 |
1.622844e+13 |
12881.0 |
15580.0 |
4 |
0 |
0 |
4 |
1 |
1 |
0 |
4 |
0016926B079C |
1.783190e+13 |
5102.0 |
6414.0 |
7 |
0 |
0 |
3 |
0 |
1 |
0 |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
15589 |
FFF1442D6698 |
1.618644e+13 |
14374.0 |
17948.0 |
2 |
1 |
1 |
3 |
1 |
1 |
1 |
15590 |
FFF1ED4F8544 |
1.454313e+13 |
6944.0 |
9435.0 |
5 |
0 |
0 |
2 |
1 |
1 |
0 |
15591 |
FFF868E06176 |
1.456920e+13 |
8210.0 |
10507.0 |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
15592 |
FFFD0AF13501 |
1.295859e+13 |
4408.0 |
5395.0 |
4 |
1 |
0 |
2 |
0 |
1 |
0 |
15593 |
FFFF80B8CC2F |
1.617042e+12 |
0.0 |
990.0 |
0 |
0 |
0 |
1 |
0 |
0 |
0 |
15594 rows × 11 columns
Index(['id', 'discourse_id', 'discourse_start', 'discourse_end',
'discourse_type_Claim', 'discourse_type_Concluding Statement',
'discourse_type_Counterclaim', 'discourse_type_Evidence',
'discourse_type_Lead', 'discourse_type_Position',
'discourse_type_Rebuttal'],
dtype='object')
# Keep the essay id plus the per-type count columns produced by get_dummies.
# The parentheses matter: without them `and` binds tighter than `or`, and the
# "discourse_type_num" exclusion would only apply to the `id` comparison.
cols = [
    c for c in dfx.columns
    if (c.startswith("discourse_type_") or c == "id") and c != "discourse_type_num"
]
cols
['id',
'discourse_type_Claim',
'discourse_type_Concluding Statement',
'discourse_type_Counterclaim',
'discourse_type_Evidence',
'discourse_type_Lead',
'discourse_type_Position',
'discourse_type_Rebuttal']
|
id |
discourse_type_Claim |
discourse_type_Concluding Statement |
discourse_type_Counterclaim |
discourse_type_Evidence |
discourse_type_Lead |
discourse_type_Position |
discourse_type_Rebuttal |
0 |
0000D23A521A |
1 |
1 |
1 |
3 |
0 |
1 |
1 |
1 |
00066EA9880D |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
2 |
000E6DE9E817 |
5 |
1 |
1 |
3 |
0 |
1 |
1 |
3 |
001552828BD0 |
4 |
0 |
0 |
4 |
1 |
1 |
0 |
4 |
0016926B079C |
7 |
0 |
0 |
3 |
0 |
1 |
0 |
… |
… |
… |
… |
… |
… |
… |
… |
… |
15589 |
FFF1442D6698 |
2 |
1 |
1 |
3 |
1 |
1 |
1 |
15590 |
FFF1ED4F8544 |
5 |
0 |
0 |
2 |
1 |
1 |
0 |
15591 |
FFF868E06176 |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
15592 |
FFFD0AF13501 |
4 |
1 |
0 |
2 |
0 |
1 |
0 |
15593 |
FFFF80B8CC2F |
0 |
0 |
0 |
1 |
0 |
0 |
0 |
15594 rows × 8 columns
# Exploratory 5-fold multilabel-stratified splitter over essays.
mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Pick the stratification targets by name, not as "every column except id".
# This keeps aggregate columns (discourse_id / discourse_start / discourse_end)
# and any later helper column (e.g. kfold) out of the label matrix, whether or
# not dfx has already been narrowed to `cols`.
labels = [
    c for c in dfx.columns
    if c.startswith("discourse_type_") and c != "discourse_type_num"
]
dfx_labels = dfx[labels]
dfx_labels
|
discourse_type_Claim |
discourse_type_Concluding Statement |
discourse_type_Counterclaim |
discourse_type_Evidence |
discourse_type_Lead |
discourse_type_Position |
discourse_type_Rebuttal |
0 |
1 |
1 |
1 |
3 |
0 |
1 |
1 |
1 |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
2 |
5 |
1 |
1 |
3 |
0 |
1 |
1 |
3 |
4 |
0 |
0 |
4 |
1 |
1 |
0 |
4 |
7 |
0 |
0 |
3 |
0 |
1 |
0 |
… |
… |
… |
… |
… |
… |
… |
… |
15589 |
2 |
1 |
1 |
3 |
1 |
1 |
1 |
15590 |
5 |
0 |
0 |
2 |
1 |
1 |
0 |
15591 |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
15592 |
4 |
1 |
0 |
2 |
0 |
1 |
0 |
15593 |
0 |
0 |
0 |
1 |
0 |
0 |
0 |
15594 rows × 7 columns
/tmp/ipykernel_27806/3539168384.py:1: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
dfx["kfold"] = -1
# Build one row per essay (`id`) with the count of each discourse type, then
# set up a 10-fold multilabel-stratified splitter over those rows.
df = pd.read_csv('/root/.cache/data/train.csv')

# numeric_only=True: sum only the numeric/one-hot columns. Under pandas >= 2.0
# (numeric_only defaults to False), text columns would otherwise be
# string-concatenated per essay instead of dropped.
dfx = (
    pd.get_dummies(df, columns=["discourse_type"])
    .groupby(["id"], as_index=False)
    .sum(numeric_only=True)
)

# Keep `id` plus the per-type count columns. The parentheses make the
# "discourse_type_num" exclusion apply to the prefix match; without them
# `and` binds tighter than `or` and the exclusion has no effect.
cols = [
    c for c in dfx.columns
    if (c.startswith("discourse_type_") or c == "id") and c != "discourse_type_num"
]
# .copy() so the "kfold" assignment below writes to an independent frame
# rather than a slice of the aggregate (avoids SettingWithCopyWarning).
dfx = dfx[cols].copy()

# NOTE(review): iterstrat appears to treat y as a binary indicator matrix
# (count > 0 -> label present), so folds are balanced on presence of each
# type rather than on counts. Confirm this is the intended stratification.
mskf = MultilabelStratifiedKFold(n_splits=10, shuffle=True, random_state=42)
labels = [c for c in dfx.columns if c != "id"]
dfx_labels = dfx[labels]
dfx["kfold"] = -1  # placeholder; overwritten with the validation-fold index
dfx
|
id |
discourse_type_Claim |
discourse_type_Concluding Statement |
discourse_type_Counterclaim |
discourse_type_Evidence |
discourse_type_Lead |
discourse_type_Position |
discourse_type_Rebuttal |
kfold |
0 |
0000D23A521A |
1 |
1 |
1 |
3 |
0 |
1 |
1 |
-1 |
1 |
00066EA9880D |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
-1 |
2 |
000E6DE9E817 |
5 |
1 |
1 |
3 |
0 |
1 |
1 |
-1 |
3 |
001552828BD0 |
4 |
0 |
0 |
4 |
1 |
1 |
0 |
-1 |
4 |
0016926B079C |
7 |
0 |
0 |
3 |
0 |
1 |
0 |
-1 |
… |
… |
… |
… |
… |
… |
… |
… |
… |
… |
15589 |
FFF1442D6698 |
2 |
1 |
1 |
3 |
1 |
1 |
1 |
-1 |
15590 |
FFF1ED4F8544 |
5 |
0 |
0 |
2 |
1 |
1 |
0 |
-1 |
15591 |
FFF868E06176 |
3 |
1 |
0 |
3 |
1 |
1 |
0 |
-1 |
15592 |
FFFD0AF13501 |
4 |
1 |
0 |
2 |
0 |
1 |
0 |
-1 |
15593 |
FFFF80B8CC2F |
0 |
0 |
0 |
1 |
0 |
0 |
0 |
-1 |
15594 rows × 9 columns
# Label each essay with the fold in which it appears as validation data.
for fold, (trn_, val_) in enumerate(mskf.split(dfx, dfx_labels)):
    print(len(trn_), len(val_))
    # split() yields positional indices; translate them through dfx.index so
    # the assignment is correct even if dfx does not have a 0..n-1 RangeIndex.
    dfx.loc[dfx.index[val_], "kfold"] = fold

# Every essay should land in exactly one validation fold.
if (dfx["kfold"] < 0).any():
    raise ValueError("some essays were not assigned to a validation fold")

# Propagate the essay-level fold to every discourse row of that essay.
# Any "kfold" column left over from a previous run of this cell is dropped
# first; otherwise the merge would produce kfold_x / kfold_y and df.kfold
# would no longer exist.
df = df.drop(columns="kfold", errors="ignore").merge(
    dfx[["id", "kfold"]], on="id", how="left"
)
print(df.kfold.value_counts())
# df.to_csv("train_folds.csv", index=False)
14036 1558
14036 1558
14033 1561
14035 1559
14031 1563
14035 1559
14034 1560
14036 1558
14036 1558
14034 1560
6 14633
9 14532
7 14529
8 14509
5 14466
3 14431
4 14365
1 14358
2 14271
0 14199
Name: kfold, dtype: int64
# Non-null counts per column for each fold, at the discourse-row level.
df.groupby(by=["kfold"]).count()
|
id |
discourse_id |
discourse_start |
discourse_end |
discourse_text |
discourse_type |
discourse_type_num |
predictionstring |
kfold |
|
|
|
|
|
|
|
|
0 |
14199 |
14199 |
14199 |
14199 |
14199 |
14199 |
14199 |
14199 |
1 |
14358 |
14358 |
14358 |
14358 |
14358 |
14358 |
14358 |
14358 |
2 |
14271 |
14271 |
14271 |
14271 |
14271 |
14271 |
14271 |
14271 |
3 |
14431 |
14431 |
14431 |
14431 |
14431 |
14431 |
14431 |
14431 |
4 |
14365 |
14365 |
14365 |
14365 |
14365 |
14365 |
14365 |
14365 |
5 |
14466 |
14466 |
14466 |
14466 |
14466 |
14466 |
14466 |
14466 |
6 |
14633 |
14633 |
14633 |
14633 |
14633 |
14633 |
14633 |
14633 |
7 |
14529 |
14529 |
14529 |
14529 |
14529 |
14529 |
14529 |
14529 |
8 |
14509 |
14509 |
14509 |
14509 |
14509 |
14509 |
14509 |
14509 |
9 |
14532 |
14532 |
14532 |
14532 |
14532 |
14532 |
14532 |
14532 |
Share on:
Twitter
❄ Facebook
❄ Email