Testing Your PyTorch Models with Torcheck

Suppose we have coded up a ConvNet model for classifying the MNIST dataset. The full training routine looks like this:
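(The embedded gist did not survive the export. Below is a minimal reconstruction sketch: the layer sizes, optimizer, and loop details are assumptions rather than the article's originals, and a comment marks the kind of line the article's bug lives on.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        output = F.relu(self.conv1(x))
        # The article's bug sat on a line like the one above: writing
        # x where output was intended silently drops conv1 from the
        # autograd graph, so its weights never receive gradients.
        output = F.max_pool2d(output, 2)      # 28x28 -> 14x14
        output = F.relu(self.conv2(output))
        output = F.max_pool2d(output, 2)      # 14x14 -> 7x7
        output = F.relu(self.conv3(output))
        output = torch.flatten(output, 1)
        output = F.relu(self.fc1(output))
        return self.fc2(output)               # raw logits, no softmax


model = ConvNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


def train(train_loader, num_epochs=5):
    # standard loop; train_loader yields batches of 1x28x28 MNIST images
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
```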

There is actually a subtle error in the model code. Some of you may have noticed it: in the forward method we carelessly put x on the right-hand side where output was intended, so conv1's result never feeds into the rest of the network.

Now let’s see how torcheck helps you detect this hidden error!

Step 0: Installation

Before we start, first install the package in one line.

$ pip install torcheck

Step 1: Adding torcheck code

Next we will add in the torcheck code. It always sits after your model and optimizer instantiation and right before the training for loop, as is shown below:

Step 1.1: Registering your optimizer(s)

First, register your optimizer(s) with torcheck:

torcheck.register(optimizer)

Step 1.2: Adding sanity checks

Next, add all the checks you want to perform in the four categories.

1. Parameters change/not change

For our example, we want all the model parameters to change during the training procedure. Adding the check is simple:

# check all the model parameters will change
# module_name is optional, but it makes error messages more informative when checks fail
torcheck.add_module_changing_check(model, module_name="my_model")

Side Note

To demonstrate the full capability of torcheck, let’s say later you freeze the convolutional layers and only want to fine tune the linear layers. Adding checks in this situation would be like:

# check the first convolutional layer's parameters won't change
torcheck.add_module_unchanging_check(model.conv1, module_name="conv_layer_1")
# check the second convolutional layer's parameters won't change
torcheck.add_module_unchanging_check(model.conv2, module_name="conv_layer_2")
# check the third convolutional layer's parameters won't change
torcheck.add_module_unchanging_check(model.conv3, module_name="conv_layer_3")
# check the first linear layer's parameters will change
torcheck.add_module_changing_check(model.fc1, module_name="linear_layer_1")
# check the second linear layer's parameters will change
torcheck.add_module_changing_check(model.fc2, module_name="linear_layer_2")
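For context, freezing the convolutional layers would typically be done by turning off their gradients. A plain-PyTorch sketch (the attribute names mirror the article's model; the stand-in layers and sizes are assumptions):

```python
import torch
import torch.nn as nn

# stand-in model with the same attribute names as the article's ConvNet
model = nn.Module()
model.conv1 = nn.Conv2d(1, 16, 3)
model.conv2 = nn.Conv2d(16, 32, 3)
model.conv3 = nn.Conv2d(32, 32, 3)
model.fc1 = nn.Linear(128, 64)
model.fc2 = nn.Linear(64, 10)

# freeze the conv layers: their parameters stop receiving gradients
for layer in (model.conv1, model.conv2, model.conv3):
    for p in layer.parameters():
        p.requires_grad_(False)

# hand the optimizer only the still-trainable parameters
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
```

The unchanging/changing checks above then verify that this freezing actually holds during training.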

2. Output range check

Since our model is a classification model, we want to add the check mentioned earlier: model outputs should not all be in the range (0, 1).

# check model outputs are not all within (0, 1)
# aka softmax hasn't been applied before loss calculation
torcheck.add_module_output_range_check(
    model,
    output_range=(0, 1),
    negate_range=True,
)

The negate_range=True argument expresses the “not all” part. If you simply want to check that model outputs all fall within a certain range, just remove that argument.
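In plain-tensor terms, the condition this check tests can be sketched as follows (torcheck's internals may differ; this only illustrates the semantics):

```python
import torch

logits = torch.randn(8, 10)      # raw model outputs
logits[0, 0] = 3.0               # guarantee at least one value outside (0, 1)
probs = torch.sigmoid(logits)    # squashed into (0, 1)


def all_within(t, low, high):
    # True when every element of t lies strictly inside (low, high)
    return bool(((t > low) & (t < high)).all())


# negate_range=True fails the check when *every* output falls in the range:
assert not all_within(logits, 0, 1)  # raw logits spill outside (0,1): passes
assert all_within(probs, 0, 1)       # all in (0,1): looks like softmax leaked in
```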

Although not applicable to our example, torcheck enables you to check the intermediate outputs of submodules as well.
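Intercepting a submodule's output is what PyTorch forward hooks are for, and a plain-PyTorch sketch shows the kind of intermediate value such a check would look at (the toy model here is an assumption, not the article's):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

captured = {}


def capture(module, inputs, output):
    # runs every time the hooked submodule finishes its forward pass
    captured["relu_out"] = output.detach()


hook = model[1].register_forward_hook(capture)
_ = model(torch.randn(3, 4))
hook.remove()

# ReLU output is non-negative, so a range check on this submodule
# could, for example, assert its values stay within [0, inf)
assert (captured["relu_out"] >= 0).all()
```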

3. NaN check

We definitely want to make sure model parameters don’t become NaN during training, and model outputs don’t contain NaN. Adding the NaN check is simple:

# check whether model parameters become NaN or outputs contain NaN
torcheck.add_module_nan_check(model)

4. Inf check

Similarly, add the Inf check:

# check whether model parameters become infinite or outputs contain infinite values
torcheck.add_module_inf_check(model)
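What these two checks guard against can be seen with PyTorch's own predicates:

```python
import torch

healthy = torch.tensor([0.5, -1.2, 3.0])
broken = torch.tensor([1.0, float("nan"), float("inf")])

assert not torch.isnan(healthy).any()
assert not torch.isinf(healthy).any()
assert torch.isnan(broken).any()  # e.g. from 0/0 during training
assert torch.isinf(broken).any()  # e.g. from overflow or division by zero
```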

After adding all the checks of interest, the final training code looks like this:

Step 2: Training and fixing

Now let’s run the training as usual and see what happens:

$ python run.py
Traceback (most recent call last):
(stack trace information here)
RuntimeError: The following errors are detected while training:
Module my_model's conv1.weight should change.
Module my_model's conv1.bias should change.

Bang! We immediately get an error message saying that our model’s conv1.weight and conv1.bias don’t change. There must be something wrong with model.conv1.

As expected, we head over to the model code, notice the error, fix it, and rerun the training. Now everything works like a charm šŸ™‚

(Optional) Step 3: Turning off checks

Yay! Our model has passed all the checks. To get rid of them, we can simply call

torcheck.disable()

This is useful when you want to run your model on a validation set, or you just want to remove the checking overhead from your model training.

If you ever want to turn on the checks again, just call

torcheck.enable()
