{ "cells": [ { "cell_type": "markdown", "id": "1b68c7d8", "metadata": {}, "source": [ "# Kryssvalidering\n", "\n", "Vi vil brukte et datasett som inneholder informasjon om 4000 epler og hvor hvert eple er karakterisert som enten bra eller dårlig. Vi vil lage en modell basert på beslutningstrær som kan klassifisere et eple som bra eller dårlig basert på 7 forskjellige parametere. Hovedfokuset i denne notebooken vil derimot være mer på hvordan vi kan bruke de tilgjengelig dataene til å validere modellen og unngå overtrening.\n", "\n", "Først leser vi inn datasettet, fjerner ugyldige verdier, gjør om *god* og *dårlig* til 0 og 1 og plukker ut variablene/predikatorene samt target (god eller dårlig) i egne numpy-arrays. Dette bør være kjent stoff fra tidligere så vil ikke kommentere så mye på det som blir gjort. " ] }, { "cell_type": "code", "execution_count": 34, "id": "28b5d6cf", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "#%matplotlib notebook\n", "import pandas as pd\n", "import numpy as np\n", "from scipy.special import expit\n", "from scipy.stats import expon\n", "import math\n", "from sklearn import tree\n", "from sklearn.metrics import accuracy_score, \\\n", " precision_score, recall_score, f1_score, mean_squared_error, \\\n", " mean_absolute_error, ConfusionMatrixDisplay, confusion_matrix\n", "from IPython.display import Image\n", "from IPython.core.display import HTML " ] }, { "cell_type": "code", "execution_count": 35, "id": "33f4c11c", "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"./apple_quality.csv\")" ] }, { "cell_type": "code", "execution_count": 36, "id": "94888b49", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | A_id | \n", "Size | \n", "Weight | \n", "Sweetness | \n", "Crunchiness | \n", "Juiciness | \n", "Ripeness | \n", "Acidity | \n", "Quality | \n", "
---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.0 | \n", "-3.970049 | \n", "-2.512336 | \n", "5.346330 | \n", "-1.012009 | \n", "1.844900 | \n", "0.329840 | \n", "-0.491590483 | \n", "good | \n", "
1 | \n", "1.0 | \n", "-1.195217 | \n", "-2.839257 | \n", "3.664059 | \n", "1.588232 | \n", "0.853286 | \n", "0.867530 | \n", "-0.722809367 | \n", "good | \n", "
2 | \n", "2.0 | \n", "-0.292024 | \n", "-1.351282 | \n", "-1.738429 | \n", "-0.342616 | \n", "2.838636 | \n", "-0.038033 | \n", "2.621636473 | \n", "bad | \n", "
3 | \n", "3.0 | \n", "-0.657196 | \n", "-2.271627 | \n", "1.324874 | \n", "-0.097875 | \n", "3.637970 | \n", "-3.413761 | \n", "0.790723217 | \n", "good | \n", "
4 | \n", "4.0 | \n", "1.364217 | \n", "-1.296612 | \n", "-0.384658 | \n", "-0.553006 | \n", "3.030874 | \n", "-1.303849 | \n", "0.501984036 | \n", "good | \n", "
\n", " | A_id | \n", "Size | \n", "Weight | \n", "Sweetness | \n", "Crunchiness | \n", "Juiciness | \n", "Ripeness | \n", "Acidity | \n", "Quality | \n", "
---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.0 | \n", "-3.970049 | \n", "-2.512336 | \n", "5.346330 | \n", "-1.012009 | \n", "1.844900 | \n", "0.329840 | \n", "-0.491590483 | \n", "good | \n", "
1 | \n", "1.0 | \n", "-1.195217 | \n", "-2.839257 | \n", "3.664059 | \n", "1.588232 | \n", "0.853286 | \n", "0.867530 | \n", "-0.722809367 | \n", "good | \n", "
2 | \n", "2.0 | \n", "-0.292024 | \n", "-1.351282 | \n", "-1.738429 | \n", "-0.342616 | \n", "2.838636 | \n", "-0.038033 | \n", "2.621636473 | \n", "bad | \n", "
3 | \n", "3.0 | \n", "-0.657196 | \n", "-2.271627 | \n", "1.324874 | \n", "-0.097875 | \n", "3.637970 | \n", "-3.413761 | \n", "0.790723217 | \n", "good | \n", "
4 | \n", "4.0 | \n", "1.364217 | \n", "-1.296612 | \n", "-0.384658 | \n", "-0.553006 | \n", "3.030874 | \n", "-1.303849 | \n", "0.501984036 | \n", "good | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
3995 | \n", "3995.0 | \n", "0.059386 | \n", "-1.067408 | \n", "-3.714549 | \n", "0.473052 | \n", "1.697986 | \n", "2.244055 | \n", "0.137784369 | \n", "bad | \n", "
3996 | \n", "3996.0 | \n", "-0.293118 | \n", "1.949253 | \n", "-0.204020 | \n", "-0.640196 | \n", "0.024523 | \n", "-1.087900 | \n", "1.854235285 | \n", "good | \n", "
3997 | \n", "3997.0 | \n", "-2.634515 | \n", "-2.138247 | \n", "-2.440461 | \n", "0.657223 | \n", "2.199709 | \n", "4.763859 | \n", "-1.334611391 | \n", "bad | \n", "
3998 | \n", "3998.0 | \n", "-4.008004 | \n", "-1.779337 | \n", "2.366397 | \n", "-0.200329 | \n", "2.161435 | \n", "0.214488 | \n", "-2.229719806 | \n", "good | \n", "
3999 | \n", "3999.0 | \n", "0.278540 | \n", "-1.715505 | \n", "0.121217 | \n", "-1.154075 | \n", "1.266677 | \n", "-0.776571 | \n", "1.599796456 | \n", "good | \n", "
4000 rows × 9 columns
\n", "