5 Easy PyTorch Functions To Get You Started With PyTorch
Learning PyTorch can feel like a daunting task, but sometimes you just have to take small steps before you can make bigger leaps.
PyTorch is an open-source machine learning framework developed in 2016 by Facebook's AI Research Lab (FAIR) to make deep learning algorithms more accessible and easier to implement. PyTorch makes the already complicated task of implementing deep learning models much easier and more intuitive.
To dive in deeper, check out the official PyTorch documentation website. It has something for everyone from beginners to advanced developers, from tutorials to a range of other resources for learning PyTorch.
Here are the 5 PyTorch functions we will learn to get you started with this amazing framework:
1. torch.arange()
2. torch.gather()
3. torch.where()
4. torch.broadcast_to()
5. torch.sort()
Before we begin, let's install and import PyTorch.
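PyTorch is usually installed with pip (for example, pip install torch; the install selector on the official PyTorch website gives the exact command for your OS and CUDA setup). A minimal sketch to confirm that the import works:

```python
import torch

# Print the installed version to confirm the import worked
print(torch.__version__)
```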
Function 1: torch.arange()
torch.arange() returns a 1-D tensor of values from the half-open interval [start, end), taken with a common difference of step, beginning from start.
When we pass a single value to torch.arange(), it returns a tensor of values ranging from 0 up to the value passed, excluding that value.
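A small sketch of the single-argument form (the value 5 is just an illustrative choice):

```python
import torch

# A single argument is treated as `end`; start defaults to 0 and step to 1
t = torch.arange(5)
print(t)  # tensor([0, 1, 2, 3, 4])
```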
When we pass in three values, the first is the *start* value and the second is the *end* value. The third value, *step*, sets the gap between consecutive values in the range. Notice that the *end* value is not included in the tensor.
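For example, with start=0, end=10 and step=2 (values chosen here for illustration), the sequence stops before 10:

```python
import torch

# start=0, end=10, step=2: values begin at 0 and increase by 2; 10 itself is excluded
evens = torch.arange(0, 10, 2)
print(evens.tolist())  # [0, 2, 4, 6, 8]
```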
The code breaks when the first value passed to torch.arange() is larger than the second value while the step is positive (the default step is 1): PyTorch raises a RuntimeError. Hence, with a positive step the start value should not exceed the end value; to count down instead, pass a negative step.
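A sketch of both cases. The error type shown is what recent PyTorch versions raise; the exact message may differ between releases.

```python
import torch

# With the default step of 1, a start larger than end is rejected
raised = False
try:
    torch.arange(10, 0)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Counting down works once the step is negative
countdown = torch.arange(10, 0, -2)
print(countdown.tolist())  # [10, 8, 6, 4, 2]
```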
The torch.arange() function can be used whenever you need to generate a range of numbers. It is similar to range in Python and can be used to define tensors over a particular range of numbers while also skipping numbers by providing a step value.
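A quick comparison with Python's built-in range, as a sketch. One difference worth knowing: torch.arange also accepts a floating-point step, which range does not.

```python
import torch

# Integer arguments give the same sequence as Python's built-in range
print(torch.arange(2, 11, 3).tolist())  # [2, 5, 8]
print(list(range(2, 11, 3)))            # [2, 5, 8]

# Unlike range, torch.arange also accepts a floating-point step
quarters = torch.arange(0, 1, 0.25)
print(quarters.tolist())  # [0.0, 0.25, 0.5, 0.75]
```

With steps that are not exactly representable in binary floating point, rounding can affect how many values are produced, so torch.linspace is often the safer choice for evenly spaced floats.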
Function 2: torch.gather()
torch.gather() gathers values from a tensor along an axis specified by the dim parameter. Syntax:
torch.gather(input, dim, index)
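torch.gather() takes a tensor, an axis and a tensor of index values as input parameters, and returns a new tensor with the same shape as index. Each output value is read from input at the position given by index along the chosen axis, while the other coordinates stay the same as the output position.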
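A minimal sketch with a 2 x 2 tensor (the values are illustrative). For a 2-D input and dim=1, the rule is out[i][j] = input[i][index[i][j]]:

```python
import torch

t = torch.tensor([[1, 2],
                  [3, 4]])
# index must be an int64 (long) tensor with the same number of dimensions as t
index = torch.tensor([[0, 0],
                      [1, 0]])

# dim=1: out[i][j] = t[i][index[i][j]], i.e. pick columns within each row
out = torch.gather(t, 1, index)
print(out.tolist())  # [[1, 1], [4, 3]]
```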
Changing the dim value to 0 changes the output tensor, because the same index values now pick rows within each column instead of columns within each row.
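The same tensor and index with dim=0, as a sketch. The rule becomes out[i][j] = input[index[i][j]][j]:

```python
import torch

t = torch.tensor([[1, 2],
                  [3, 4]])
index = torch.tensor([[0, 0],
                      [1, 0]])

# dim=0: out[i][j] = t[index[i][j]][j], i.e. pick rows within each column
out = torch.gather(t, 0, index)
print(out.tolist())  # [[1, 2], [3, 2]]
```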
This seems quite confusing, but it is actually much simpler than it looks. Have a look at this image to understand what gather() does. Hint: look at the numbers and trace each one to its box.
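The same tracing can be written as plain loops. This is a sketch for 2-D tensors only; gather_2d is a hypothetical helper written for this illustration, not part of PyTorch, and it is checked against torch.gather for both axes.

```python
import torch

def gather_2d(t, dim, index):
    """Loop-based re-implementation of torch.gather for 2-D tensors (illustration only)."""
    out = torch.empty_like(index, dtype=t.dtype)
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            k = index[i, j].item()
            # dim=0 replaces the row coordinate with k; dim=1 replaces the column coordinate
            out[i, j] = t[k, j] if dim == 0 else t[i, k]
    return out

t = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
for d in (0, 1):
    print(d, torch.equal(gather_2d(t, d, index), torch.gather(t, d, index)))  # True for both
```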
Here, to cause an error, a value outside the valid range for the tensor's dimensions was passed. This raised an error because 2 is out of range for the given tensor.
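One way to reproduce this kind of failure, as a sketch: put an out-of-range value in the index tensor. Along an axis of size 2, the only valid indices are 0 and 1. Recent PyTorch versions raise a RuntimeError on CPU here; the exact type and message may vary by version and device.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
# Along dim=1 the valid indices are 0 and 1, so 2 is out of range
bad_index = torch.tensor([[0, 2], [1, 0]])

raised = False
try:
    torch.gather(t, 1, bad_index)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```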
torch.gather() can be used whenever we want to gather values at certain indices and create a new tensor.
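A common pattern is picking one value per row, for example the score at each sample's target column. The scores and targets below are made-up illustrative data.

```python
import torch

# Hypothetical scores: one row per sample, one column per class
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
targets = torch.tensor([1, 0])  # column to pick from each row

# unsqueeze(1) turns targets into shape (2, 1) so it has the same number of dims as scores;
# squeeze(1) then drops that extra dimension from the result
picked = scores.gather(1, targets.unsqueeze(1)).squeeze(1)
print(picked)  # tensor([0.7000, 0.5000])
```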
Function 3: torch.where()
torch.where(condition, x, y) returns a tensor of elements selected from either x or y, depending on the given condition: where the condition is True the element comes from x, otherwise it comes from y.
Given two tensors x and y and the condition x > 0, torch.where() takes each element of x that is greater than 0. At every position where that is not the case, the output takes the element of y at the same position.
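A minimal sketch with illustrative 2 x 2 tensors:

```python
import torch

x = torch.tensor([[-1.0, 2.0],
                  [3.0, -4.0]])
y = torch.ones(2, 2)

# Where x > 0 the value comes from x, everywhere else from y
z = torch.where(x > 0, x, y)
print(z.tolist())  # [[1.0, 2.0], [3.0, 1.0]]
```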
Here, the replacement value was the single number 0., so every value not greater than 0 was replaced by that zero.
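The scalar form, sketched with the same illustrative x:

```python
import torch

x = torch.tensor([[-1.0, 2.0],
                  [3.0, -4.0]])

# A float scalar as the replacement: every element that is not > 0 becomes 0.
clipped = torch.where(x > 0, x, 0.)
print(clipped.tolist())  # [[0.0, 2.0], [3.0, 0.0]]
```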
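Here, however, the replacement value 0 is an integer. In the PyTorch version used for this article, that broke the function, so use a float value such as 0. as the replacement when x is a float tensor. Newer releases may accept the integer through type promotion, so whether this fails can depend on your PyTorch version.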
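One way to sidestep the scalar-type question regardless of version is to pass a replacement tensor that already has x's dtype, for example one built with torch.zeros_like. A sketch:

```python
import torch

x = torch.tensor([[-1.0, 2.0],
                  [3.0, -4.0]])

# zeros_like builds a zero tensor with x's shape and dtype, so no scalar conversion is needed
safe = torch.where(x > 0, x, torch.zeros_like(x))
print(safe.dtype)     # torch.float32
print(safe.tolist())  # [[0.0, 2.0], [3.0, 0.0]]
```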
torch.where() can be used whenever we want to choose, element by element, between the values of two tensors based on a certain condition and build a new tensor from the result.
Function 4: torch.broadcast_to()
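torch.broadcast_to(input, shape) broadcasts the input tensor to the dimensions given by shape. It takes an input tensor and returns a tensor of the given shape whose values repeat those of the input. The PyTorch documentation describes it as equivalent to calling input.expand(shape), so the result is a view of the input rather than a copy.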
Passing [1, 2, 3] as the input tensor with a shape of (3, 3) makes torch.broadcast_to() broadcast the row into a 3 × 3 matrix.
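A sketch of that call:

```python
import torch

row = torch.tensor([1, 2, 3])
# Each of the 3 rows of the result repeats `row`
m = torch.broadcast_to(row, (3, 3))
print(m)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])
```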
Here we tried it out with a target shape of 3 × 2. Notice that the target shape must be compatible with the dimensions of the input tensor.
Passing an input tensor of length three with a target shape of 3 × 2 fails. A 1-D input tensor can only be broadcast to a shape whose last dimension (the number of columns, for a matrix) equals the input's length, unless that length is 1. Since the shapes did not match, we got an error.
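A sketch of the failing call next to two inputs that are compatible with a 3 x 2 target. The compatible inputs are illustrative choices; recent PyTorch versions report the mismatch as a RuntimeError.

```python
import torch

row = torch.tensor([1, 2, 3])

# The last dimension of the target (2) must match the input's length (3): this fails
raised = False
try:
    torch.broadcast_to(row, (3, 2))
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Shapes that do work for a 3 x 2 target: a length-2 row, or a 3 x 1 column
print(torch.broadcast_to(torch.tensor([1, 2]), (3, 2)).tolist())          # [[1, 2], [1, 2], [1, 2]]
print(torch.broadcast_to(torch.tensor([[1], [2], [3]]), (3, 2)).tolist())  # [[1, 1], [2, 2], [3, 3]]
```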
This function can be useful when a small tensor needs to stand in for a much larger one: the input tensor can be broadcast to any compatible shape without copying its data.
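Because the result is a view, a large broadcast tensor shares memory with the small input. A sketch that makes this visible through strides and data pointers:

```python
import torch

small = torch.tensor([1.0, 2.0, 3.0])
big = torch.broadcast_to(small, (1000, 3))

# broadcast_to returns a view: same underlying storage and a stride of 0 along the new dimension
print(big.shape)                           # torch.Size([1000, 3])
print(big.stride())                        # (0, 1)
print(big.data_ptr() == small.data_ptr())  # True

# Call .clone() if you need an independent tensor you can safely write into
independent = big.clone()
print(independent.data_ptr() != small.data_ptr())  # True
```

Since every row of big points at the same three numbers, writing into it in place is unsafe; the clone gives each row its own memory.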
Function 5: torch.sort()
torch.sort() sorts the elements of the input tensor along a given dimension (the last one by default) in ascending order by value. A namedtuple of (values, indices) is returned, where values are the sorted values and indices are the indices of those elements in the original input tensor.
To read the result, start from the sorted tensor: this is our tensor with the values in each row arranged in ascending order. Then look at the indices tensor: these are the positions the sorted values occupied in the original tensor.
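A sketch with an illustrative 2 x 3 tensor (distinct values, so the indices are unambiguous):

```python
import torch

x = torch.tensor([[3, 1, 2],
                  [9, 7, 8]])

# dim defaults to -1 (the last dimension), so each row is sorted independently
sorted_vals, indices = torch.sort(x)
print(sorted_vals.tolist())  # [[1, 2, 3], [7, 8, 9]]
print(indices.tolist())      # [[1, 2, 0], [1, 2, 0]]
```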
When we don't unpack the result of torch.sort(), we get a single namedtuple holding both the values and the indices tensors. We also passed descending=False, but that is already the default for sort(); set it to True to sort in descending order.
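A sketch of the namedtuple fields together with descending=True, reusing the illustrative tensor from above:

```python
import torch

x = torch.tensor([[3, 1, 2],
                  [9, 7, 8]])

# Without unpacking, sort returns a namedtuple whose fields are .values and .indices
result = torch.sort(x, descending=True)
print(result.values.tolist())   # [[3, 2, 1], [9, 8, 7]]
print(result.indices.tolist())  # [[0, 2, 1], [0, 2, 1]]
```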
Here we passed in dim=2, which caused an IndexError. The x tensor is two-dimensional, so it only accepts dim values in the range [-2, 1].
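A sketch of the out-of-range dim, plus a reminder that negative dims count from the end:

```python
import torch

x = torch.tensor([[3, 1, 2],
                  [9, 7, 8]])

# A 2-D tensor only accepts dim in [-2, 1]; dim=2 is out of range
raised = False
try:
    torch.sort(x, dim=2)
except IndexError as err:
    raised = True
    print("IndexError:", err)

# Negative dims count from the end: dim=-2 is the same as dim=0 here
print(torch.equal(torch.sort(x, dim=-2).values, torch.sort(x, dim=0).values))  # True
```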
torch.sort() is a useful function when you need to arrange values in ascending or descending order in a new tensor. By specifying the dim value, we can sort the elements row-wise or column-wise. We also get a tensor of indices that shows where each sorted value sat in the original tensor.
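A sketch contrasting row-wise and column-wise sorting on one illustrative tensor:

```python
import torch

x = torch.tensor([[3, 8, 2],
                  [9, 1, 5]])

# dim=1 sorts within each row; dim=0 sorts within each column
by_row = torch.sort(x, dim=1)
by_col = torch.sort(x, dim=0)
print(by_row.values.tolist(), by_row.indices.tolist())
# [[2, 3, 8], [1, 5, 9]] [[2, 0, 1], [1, 2, 0]]
print(by_col.values.tolist(), by_col.indices.tolist())
# [[3, 1, 2], [9, 8, 5]] [[0, 1, 0], [1, 0, 1]]
```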
So, we covered 5 PyTorch functions as promised at the beginning and explored them with a hands-on, practical approach. We saw how they are used and, more importantly, when they break.
I hope this helped you get started with the PyTorch library. Check out the official documentation for more info.