Three different ways to use the PyTorch torch.max() function

Python has gained a solid reputation for flexibility and ease of use. And one of the most impressive things about the language is that most third-party libraries uphold the language’s design philosophy. This is even true for the major libraries which provide specific functionality for scientific, mathematical, and data analytic processing.

Even seemingly simple functions in a machine learning framework like PyTorch can be leveraged to achieve a wide variety of different effects. And this is exceedingly clear for something like PyTorch’s torch.max() function. In fact, you’ll soon discover three powerful ways to use this function in your own code. But before we move on to those examples we’ll need to take a closer look at PyTorch and its max function.

An Overview of PyTorch’s Max

In the simplest terms, PyTorch’s max function can be understood as a technique to find the maximum element of the library’s primary data container – tensors. This might seem like a somewhat useful if limited function. Many people’s first thought is that it’d be relatively easy to simply set up a loop to iterate through a tensor’s elements. However, that leads to the real power of the max function.

Torch max finds the maximum value of a wide variety of different elements supplied by PyTorch. This means that you can use max to find the largest numbers within tensors. But it also means you can work with various elements inside that tensor while specifying dimensional coordinates. Or you could work with multiple tensors with that same level of precision. But without further ado, we can take a deeper dive into the max function and tensors.

A Deeper Look Into Max

We can begin by considering exactly why it’s best to use max rather than iterating with a for loop. The reason comes down to how Python’s math and science libraries are developed. The most popular libraries have had years or even decades to mature. They typically benefit from an extremely high level of code optimization which improves processing time while decreasing memory requirements. This can range from leveraging a computer’s GPU to using highly specialized techniques that work in full pairity with the library’s new data types. All of these reasons and more translate into one solid fact.

You’ll typically find that a library’s own special functions work far more efficiently with its new data types than standard Python techniques are capable of. However, there’s one additional benefit to using a library’s native functionality. The people writing those functions are the undisputed experts in using the library. They’ve spent a lot of time considering the underlying structure of the language and how to creatively leverage functions. And this is why, for example, max can be used in multiple contexts. But to understand what that means we need to see it in action.

Three Ways To Use Max

We can begin by looking at the most common PyTorch max usage scenario. Take a look at this simple code sample.

import torch as pt

ourTensor = pt.Tensor([[1, 2, 3], [4, 5, 6]])
ourMaxValue1 = pt.max(ourTensor)

print(ourTensor)
print(ourMaxValue1)

We begin with PyTorch’s import assigned to pt. Next, we leverage that import to create a 2D tensor consisting of the numbers 1 to 6. On the following line, we use max on ourTensor and assign the results to ourMaxValue1. We finish things off by printing the contents of ourTensor and ourMaxValue1. You’ll note that this code prints out the original tensor’s data just as we’d expect while ourMaxValue1 consists of the number 6. This is, of course, the max value found in ourTensor. The only major caveat to keep in mind is that if you have more than one maximal value within a reduced row, then the return value will consist of the index of the first maximal item.

The previous example is the most bare-bones way to use max, but you’ll note that it’s effective and to the point. But let’s try working out some more complex usage scenarios to see what else we can do with max. Take a look at the following example.

import torch as pt

ourTensor = pt.Tensor([[1, 2, 3], [4, 5, 6]])
ourMaxElements1, ourMaxDimensions1 = pt.max(ourTensor, dim=0)
ourMaxElements2, ourMaxDimensions2 = pt.max(ourTensor, dim=1)

print(ourMaxElements1)
print(ourMaxDimensions1)

print(ourMaxElements2)
print(ourMaxDimensions2)

Things are similar this time around, right up until line 4 and 5. We are using the max function on ourTensor again. But note that this time around we’re passing the results to new variables and receiving multiple maximal values. The ourMaxElements and ourMaxDimensions variables will hold their corresponding data. And we’re able to funnel that information into them by passing an additional value to max. The new argument is dim. And by passing a dim value we can work over multiple dimensions.

Starting from line 7 we proceed to print out the results of our new approach to max. You’ll note that we have very different results this time around. This is because max adapts its behavior to the element passed to it and the arguments. In this case, max is examining ourTensor along separate dimensions. The ourMaxElements output tensors contain the maximum value of individual elements along the supplied dimension. For ourMaxElements1 this is 4,5,6. But with ourMaxElements2 we’re moving in a different axis. So we now have the elements of 3 and 6 as the maximal value. The ourMaxDimensions variables show which indices contain the respective max values. The max indices are automatically generated if we supply a dim value when calling max.

At this point, you’ve seen that we can dramatically change results by passing 1 fewer dimension values as dim. But you might wonder what would happen if we simply passed multiple input tensors to max rather than a dim argument. You can see the answer in the following code.

import torch as pt

ourTensor = pt.Tensor([[13, 2, 3], [4, 5, 6]])
ourTensor2 = pt.Tensor([[7, 8, 9], [10, 11, 12]])

ourMaxElements = pt.max(ourTensor, ourTensor2)

print(ourMaxElements)

In this example, we create and populate two tensors as ourTensor and ourTensor2. Note that ourTensor’s first value is 13 rather than 1. Next, we pass those two tensors to PyTorch’s max and assign the result to ourMaxElements. And, finally, we print out the result.

This time around you can see that ourMaxElements is, once again, a 2D tensor that correlates to the original structure found in ourTensor and ourTensor2. But note that the values begin with 13 and then jump to 8. This is because the first value found in ourTensor is now the largest. And, as such, it winds up as a maximal value within ourmaxElements. The remaining space in the 2D tensor’s structure consists of the other largest values from the original two tensors.

Three different ways to use the PyTorch torch.max() function
Scroll to top