5 Easy PyTorch Functions To Get You Started With PyTorch

Learning PyTorch can feel daunting, but sometimes you have to start with small steps before you can take bigger leaps.

PyTorch is an open-source machine learning framework released in 2016 by Facebook’s AI Research lab (FAIR) to make deep learning algorithms more accessible and easier to implement. It makes the already complicated task of implementing deep learning models much easier and more intuitive.

To dive in deeper, check out the official PyTorch website. It hosts the documentation along with tutorials and other learning resources for everyone from beginners to advanced developers.

Here are the 5 PyTorch functions that we will cover to get you started with this framework:

  • torch.arange()
  • torch.gather()
  • torch.where()
  • torch.broadcast_to()
  • torch.sort()

Before we begin, let’s install and import PyTorch.
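
A minimal setup sketch, assuming a pip-based Python environment (the exact install command can differ by platform and hardware, so treat the one in the comment as an assumption):

```python
# Install first if needed (run in a shell, not inside Python):
#   pip install torch
import torch

# Printing the version is a quick check that the import worked.
print(torch.__version__)
```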

Function 1 — torch.arange()

torch.arange() returns a 1-D tensor of values from the half-open interval [start, end), beginning at start and spaced by a common difference of step.
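
A minimal sketch of the single-argument form. The value 5 is chosen here purely for illustration:

```python
import torch

# A single argument is treated as `end`; start defaults to 0 and step to 1.
a = torch.arange(5)
print(a)  # tensor([0, 1, 2, 3, 4])
```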

In the above example we passed a single value to torch.arange(), which returned a tensor of values in a range from 0 to the value passed, excluding that value.
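
A similar sketch with all three arguments, again using illustrative values:

```python
import torch

# start=2, end=10, step=2: the values run 2, 4, 6, 8 and stop before 10.
b = torch.arange(2, 10, 2)
print(b)  # tensor([2, 4, 6, 8])
```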

Here we passed in three values, where the first is the *start* value and the second is the *end* value. The third value, *step*, is the spacing between consecutive values in the range. Notice that the *end* value is not included in our tensor.
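
A sketch of a call that fails, assuming the default positive step. The error is caught here only so that the snippet runs to completion:

```python
import torch

raised = False
try:
    # start (10) is larger than end (2) while step keeps its default of 1.
    torch.arange(10, 2)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```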

In the above example, the code breaks because the first value passed to torch.arange() was larger than the second while the step was positive, so we get a RuntimeError. With a positive step, the start value must be less than the end value; to count downwards, pass a negative step instead.

torch.arange() can be used whenever you need to generate a range of numbers. It is similar to range in Python: it defines a tensor over a range of numbers, and the step value lets you skip values within that range.

Function 2 — torch.gather()

The function torch.gather() gathers values from a tensor along the axis specified by the dim parameter. Syntax: torch.gather(input, dim, index)
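
A minimal sketch of gathering along dim=1, using a small illustrative 2 x 2 tensor:

```python
import torch

t = torch.tensor([[1, 2],
                  [3, 4]])
index = torch.tensor([[0, 0],
                      [1, 0]])

# With dim=1, each output element is picked from the same row of t:
# out[i][j] = t[i][index[i][j]]
out = torch.gather(t, 1, index)
print(out)  # tensor([[1, 1], [4, 3]])
```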

torch.gather() takes a tensor, an axis and a tensor of index values as input, and returns a new tensor with the same shape as index. For a 2-D input, dim=1 gives out[i][j] = input[i][index[i][j]], and dim=0 gives out[i][j] = input[index[i][j]][j].
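
The same illustrative tensor and index, this time gathered along dim=0:

```python
import torch

t = torch.tensor([[1, 2],
                  [3, 4]])
index = torch.tensor([[0, 0],
                      [1, 0]])

# With dim=0, each output element is picked from the same column of t:
# out[i][j] = t[index[i][j]][j]
out0 = torch.gather(t, 0, index)
print(out0)  # tensor([[1, 2], [3, 2]])
```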

In the above example, we can see that the dim value has been changed to 0 and hence the output tensor has also changed in response.

This seems quite confusing, but it is actually much simpler than it looks. Have a look at this image to understand what gather() does. Hint: look at the numbers and trace each one to its box.
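
A sketch of an out-of-range index, using illustrative values. The exact exception class is treated as an assumption here, so the snippet catches both likely candidates:

```python
import torch

t = torch.tensor([[1, 2],
                  [3, 4]])
# Index 2 does not exist along dim=1, which only has positions 0 and 1.
bad_index = torch.tensor([[0, 2],
                          [1, 0]])

gather_failed = False
try:
    torch.gather(t, 1, bad_index)
except (RuntimeError, IndexError) as err:
    gather_failed = True
    print(type(err).__name__, ":", err)
```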

Here, to cause an error, we passed an index value that is out of range for the gathered dimension. The call failed with an out-of-bounds index error because 2 is not a valid index for the given tensor.

torch.gather() can be used whenever we want to gather values at certain indices and create a new tensor.

Function 3 — torch.where()

The function torch.where() returns a tensor of elements selected from either x or y tensors, depending on the given condition.
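
A minimal sketch with two illustrative tensors of the same shape:

```python
import torch

x = torch.tensor([[-1.0, 2.0],
                  [3.0, -4.0]])
y = torch.full((2, 2), 10.0)

# Element by element: take x where the condition holds, otherwise take y.
picked = torch.where(x > 0, x, y)
print(picked)  # tensor([[10., 2.], [3., 10.]])
```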

In the above example, the tensors x and y were defined. torch.where() checked the condition element by element: wherever the value in x was greater than 0, the output took that value from x; everywhere else, it took the value at the same position in y.
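
A sketch with a scalar replacement, assuming a reasonably recent PyTorch release in which a Python float is accepted in place of the y tensor:

```python
import torch

x = torch.tensor([-1.5, 0.5, 2.0])

# Keep positive values from x and use the float 0. everywhere else.
clipped = torch.where(x > 0, x, 0.)
print(clipped)  # tensor([0.0000, 0.5000, 2.0000])
```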

Here, the replacement value was the single number 0., so every value that was not greater than 0 was replaced by 0.0.
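
A sketch of the same call with an integer replacement. Whether this is accepted depends on the installed PyTorch version (an assumption based on how type promotion has changed between releases), so the snippet records either outcome instead of expecting one:

```python
import torch

x = torch.tensor([-1.5, 0.5, 2.0])

try:
    # Integer scalar 0 mixed with a float tensor.
    result = torch.where(x > 0, x, 0)
    outcome = "accepted"
    print("accepted:", result)
except (RuntimeError, TypeError) as err:
    result = None
    outcome = "rejected"
    print("rejected:", err)
```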

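Here the replacement value 0 is an integer. In the PyTorch version used when this was written, mixing an integer scalar with a float tensor made torch.where() fail, so a float such as 0. was required. Newer releases may apply type promotion and accept the integer, but matching the replacement's dtype to the tensor's keeps the call portable.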

torch.where() can be used whenever we want to choose between two sources of values, element by element, based on a condition, and create a new tensor from the result.

Function 4 — torch.broadcast_to()

torch.broadcast_to() broadcasts the input tensor to the given shape. It takes an input tensor and returns a tensor of the requested shape in which the input values appear repeated. The result is a view of the input, so no data is actually copied.
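
A minimal sketch, broadcasting a 1-D tensor of length 3 to a 3 x 3 shape:

```python
import torch

v = torch.tensor([1, 2, 3])

# Each of the 3 rows in the result is the original 1-D tensor.
m = torch.broadcast_to(v, (3, 3))
print(m)  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
```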

In the above example, we passed in [1, 2, 3] as the input tensor and a shape of (3, 3). torch.broadcast_to() broadcast the 1-D tensor into a 3 x 3 matrix, with the input repeated as each row.
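
One possible reading of a 3 x 2 case, assuming a length-2 input so that the shapes are compatible (the input values are illustrative):

```python
import torch

w = torch.tensor([1, 2])

# A length-2 input fits a target whose last dimension is 2.
m32 = torch.broadcast_to(w, (3, 2))
print(m32)  # tensor([[1, 2], [1, 2], [1, 2]])
```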

Here we tried it out with a 3 x 2 matrix. Notice that the target shape must be compatible with the dimensions of the input tensor.
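
A sketch of an incompatible shape, with the error caught only so that the snippet runs to completion:

```python
import torch

v = torch.tensor([1, 2, 3])

broadcast_failed = False
try:
    # The input has length 3, but the target's last dimension is 2.
    torch.broadcast_to(v, (3, 2))
except RuntimeError as err:
    broadcast_failed = True
    print("RuntimeError:", err)
```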

Here we passed an input tensor of length three and asked for a tensor of shape 3 x 2. A 1-D input of length three can only be broadcast to a matrix whose number of columns equals that length (a dimension can be stretched only when its size is 1). Hence, we got an error because the shapes did not match.

This function can be useful when you need a large tensor built from a smaller one: the input tensor can be broadcast to the required shape.

Function 5 — torch.sort()

The function torch.sort() sorts the elements of the input tensor along a given dimension (the last one by default) in ascending order by value. It returns a namedtuple of (values, indices), where values holds the sorted values and indices holds the positions those elements had in the original input tensor.
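
A minimal sketch that unpacks the result, using an illustrative 2 x 3 tensor named x:

```python
import torch

x = torch.tensor([[3, 1, 2],
                  [9, 7, 8]])

# With the default dim=-1, each row is sorted independently.
sorted_x, indices = torch.sort(x)
print(sorted_x)  # tensor([[1, 2, 3], [7, 8, 9]])
print(indices)   # tensor([[1, 2, 0], [1, 2, 0]])
```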

For the above example, start with the sorted tensor. This is our tensor with the values in each row arranged in ascending order. Then have a look at the indices tensor. These are the positions that the sorted values had in the original x tensor.
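
A sketch of the same call without unpacking, plus the descending variant, again on an illustrative tensor:

```python
import torch

x = torch.tensor([[3, 1, 2],
                  [9, 7, 8]])

# Without unpacking, the result is a namedtuple with .values and .indices.
result = torch.sort(x, descending=False)
print(result)

# descending=True reverses the order within each row.
desc = torch.sort(x, descending=True)
print(desc.values)  # tensor([[3, 2, 1], [9, 8, 7]])
```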

When we don’t unpack the result of torch.sort(), the output shows both the values tensor and the indices tensor together. We also passed descending=False, but that is the default for sort(); set it to True for descending order.
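
A sketch of an invalid dim on a 2-D tensor, with the error caught only so that the snippet runs to completion:

```python
import torch

x = torch.tensor([[3, 1, 2],
                  [9, 7, 8]])

sort_failed = False
try:
    # A 2-D tensor only has dims 0 and 1 (or -2 and -1), so dim=2 is invalid.
    torch.sort(x, dim=2)
except IndexError as err:
    sort_failed = True
    print("IndexError:", err)
```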

Here we passed in dim=2, which caused an IndexError. The x tensor is a 2-dimensional matrix, so it accepts only dim values in the range [-2, 1].

torch.sort() is a useful function when you need to arrange values in ascending or descending order and create a new tensor. By specifying the dim value we can sort the elements row-wise or column-wise. We also get a tensor of indices that tells us where each sorted value sat in the original tensor.


So, we saw 5 PyTorch functions as promised in the beginning and tried to understand them with a hands-on, practical approach. We saw how they are used and, more importantly, when they fail.

I hope this helped you get started with the PyTorch library. Check out the official documentation for more info.
