Skip to content

Commit 0dce3c5

Browse files
authored
Correct minor details in image classification (#371)
* Corrected minor typos in image classification notebook * Small fixes
1 parent b6d6b22 commit 0dce3c5

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

31_image_classification.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,7 +1932,8 @@
19321932
"metadata": {},
19331933
"source": [
19341934
"After initialising the trainer instance, check whether a trained model already exists.\n",
1935-
"If so, load the weights using ```model_weights = torch.load(model_path, weights_only=True)```.\n",
1935+
"If so, load the weights using ```model_weights = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))```. \n",
1936+
"The ```map_location=torch.device('cpu')``` is only needed if you are running the code in a computer that does not have CUDA cores.\n",
19361937
"Then, load the weights into the model using (```model.load_state_dict(model_weights)```).\n",
19371938
"Finally, set the model to evaluation model (```model.eval()```).\n",
19381939
"This step is essential because certain layers, such as batch normalization and dropout, behave differently during training and evaluation.\n",
@@ -1974,7 +1975,7 @@
19741975
"Overfitting occurs when the model performs well on the training data but poorly on the validation data, usually indicated by a widening gap between the two curves.\n",
19751976
"Underfitting, on the other hand, is suggested when both the training and validation curves show poor performance and fail to improve. By monitoring these curves, we can adjust hyperparameters or modify the model architecture to address such issues. \n",
19761977
"\n",
1977-
"First, load the log file using ```pandas``` (```training_log = pd.read_csv(\"training_log.txt\")```).\n",
1978+
"First, load the log file using ```pandas``` (```training_log = pd.read_csv(\"training_log.txt\")``` or ```training_log = pd.read_csv(dataset_folder / \"training_log.txt\")``` in case you did not train the model by yourself).\n",
19781979
"Then, use the ```matplotlib``` library to plot the learning curves."
19791980
]
19801981
},
@@ -1988,7 +1989,7 @@
19881989
"import pandas as pd\n",
19891990
"from matplotlib import pyplot as plt\n",
19901991
"\n",
1991-
"# Load the training log file (In case you want to use the already trained model, replace this by model_path = dataset_folder / \"training_log.txt\")\n",
1992+
"# Load the training log file\n",
19921993
"training_log = None\n",
19931994
"\n",
19941995
"plt.figure()\n",

0 commit comments

Comments
 (0)