1. Introduction
1.1. Our Goal
Tensor shape mismatch is a critical bug in deep neural network machine learning applications. Training a neural network is an expensive process that is meant to terminate only after it has processed a huge amount of data through a sequence of tensor operations. If, in the middle of this time-consuming training process, the shape of an input datum fails to fit a tensor operation, the whole process stops abruptly, wasting the entire training cost spent thus far and losing any intermediate result that had been trained.
Our goal is to automatically predict, at compile time, such run-time tensor shape mismatch errors in PyTorch neural network training code.
1.2. Structure of PyTorch Programs
Figure 2: PyTorch error samples. Sample code for various types of tensor shape errors.
Figure 2(a) presents typical tensor shape errors, which are slight modifications of Figure LABEL:fig:lbbasic. In the first example, the second Linear layer (line 8), which multiplies its input by an 80 × 10 matrix, requires an input tensor of a specific shape. The first layer (line 6), however, returns a wrongly shaped tensor, so the overall pipeline will malfunction. This kind of error is called a tensor shape mismatch error, or simply a shape error.
A shape error is rather hard to find manually and is often detected only by running the program with an actual input. Indeed, the most notorious errors for machine learning engineers are those that occur only after an immense number of machine-hours.
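To make the mismatch between two consecutive Linear layers concrete, the following is a minimal sketch in plain Python (no PyTorch required). The function name check_linear_chain and the layer sizes 784 and 120 are hypothetical and chosen only for illustration; only the 80 × 10 second layer comes from the text.

```python
from typing import List, Tuple


def check_linear_chain(input_features: int, layers: List[Tuple[int, int]]) -> List[str]:
    """Report every layer whose expected input size differs from what it receives.

    Each layer is an (in_features, out_features) pair, mirroring a Linear layer.
    """
    errors = []
    current = input_features
    for index, (in_features, out_features) in enumerate(layers):
        if current != in_features:
            errors.append(
                f"layer {index}: expected {in_features} input features, got {current}"
            )
        # Continue with the declared output size so later layers are still checked.
        current = out_features
    return errors


# Hypothetical sizes: the first layer emits 120 features, the second expects 80.
broken = check_linear_chain(784, [(784, 120), (80, 10)])
# After the fix, the first layer emits the 80 features that the second expects.
fixed = check_linear_chain(784, [(784, 80), (80, 10)])
print(broken)
print(fixed)
```

The check needs only the declared layer sizes, not any training data, which is why this class of error can in principle be reported before the program runs.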
Section 1.2 shows another example. Its declaration of the training data loader (line 14) hides a shape error. The DataLoader class slices the dataset sequentially by batch_size and passes each slice to the model. If the total length of the dataset is not divisible by batch_size, however, the size of the residual minibatch will be the nonzero remainder of the total length. See line 16: because the third parameter, drop_last, is missing while the model assumes a consistent batch size (lines 10 and 6), the program will crash on the residual minibatch, losing all the training hours spent. Recent massive networks like GPT-3 (gpt3) require more than hundreds of machine-hours to train. This type of error must be noticed before the program runs.
Figure 2(a) illustrates another shape error that can arise from a dataset rather than from the structure of the model. The code does not take its input from the predefined MNIST dataset but reads an image from a file. If the image is RGB, with 3 × H × W dimensions, it will not fit into the reshape method, which requires a tensor of 28 × 28 elements. That means we have to convert it to a monochrome image before feeding it to the network. Even if the code has been tested successfully with monochrome images, a user may still run it with an RGB image and crash the execution.
Though several works (ariadne; pythia; shapeflow; pytropos; semistatic) have reported tools that detect shape mismatch errors in machine learning libraries, especially TensorFlow (tensorflow), none of them has presented a static analysis tool that statically detects shape errors in realistic Python ML applications. Real-world machine learning applications rely heavily on third-party libraries, external datasets, and configuration parameters, and handle their control flow with subtle branch conditions and loops. The existing tools still lack support for some of these elements and thus fail to analyze even a simple ML application. To ensure that a shape error will not happen for any input data, we should statically infer a precise yet conservative range for each tensor shape and track its transformations through all possible execution paths.
2. Overview of PyTea Analyzer
To find shape errors before run time, we present a static analyzer, PyTea (PyTorch Tensor Error Analyzer). PyTea statically scans PyTorch applications and detects possible shape errors. PyTea analyzes the full training and evaluation paths of real-world Python/PyTorch applications, including additional data processing and mixed use of other libraries (e.g., Torchvision (torchvision), NumPy (numpy), PIL (pillow)). Figure 3 illustrates the overall architecture of the PyTea analyzer. It first translates the original Python code into a kernel language, PyTea Internal Representation (PyTea IR). Then, it tracks every possible execution path of the translated IR and collects the constraints on tensor shapes that dictate the conditions for the code to run without a shape error. The collected constraint sets are given to the Satisfiability Modulo Theories (SMT) solver Z3 (z3) to judge whether those constraints are satisfiable for every possible input shape. Following the result of the solver, PyTea concludes which paths contain a shape error and which do not. If constraint solving by Z3 takes too much time, PyTea stops and reports “don’t know”.
2.1. Assumptions
Given the typical structure of PyTorch neural network training code (Section 1.2), we assume the following about PyTea’s input, i.e., the PyTorch deep neural network training code:

A1. Other than the training or evaluation dataset, every input value required to execute the code is injected by command-line arguments.

A2. There is no infinite loop or recursion. We assume that every loop bound, except for those over the datasets, is fixed to a constant.

A3. The unknown loop bound for a dataset is only the size of that dataset in an epoch, and every iteration processes either a fixed-size minibatch of the dataset or a smaller, residual minibatch.

A4. We assume that string-manipulation expressions have no effect on tensor shapes.
These assumptions are based on our observation that most PyTorch networks and code resolve statically to fixed structures once precise command-line arguments are given. Real-world PyTorch applications mostly construct their structures from command-line arguments or external configuration files such as JSON files. Therefore, PyTea chooses to analyze programs only with exact command-line arguments. For the few networks that do not resolve to a single fixed structure, we consider all possible structures. The number of possible structures is controlled by our path-pruning technique and, in unavoidable cases, by a timeout.
2.2. Handling path explosions
The number of possible paths is exponential in the number of branches in sequence. For some complex neural networks, such path explosion is possible. For example, Neural Architecture Search (nas) and Networks with Stochastic Depth (stochastic) have branches inside the networks themselves. Figure 5 shows a representative path explosion case that uses a run-time random variable. The feedforward function (forward(self, x)) has two execution paths in its body. The final network is built from 24 identical blocks (line 17), which yields 2^24, or about 16M, paths.
We handle this exponential cost blowup by means of conservative path pruning and simple-minded timeouts. If we can establish that the binding scope of the feedforward function is pure (i.e., it does not change any global value), and that its bound value is equal on every path and unrelated to the branch conditions, we can safely ignore all paths except one. If path explosion arises even with this method, we fall back on a timeout. See Section 3.2.3 for more details.
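The size of the blowup, and the effect of pruning, can be checked with a few lines of arithmetic. This is only an illustrative sketch; the helper name count_paths is hypothetical and not part of PyTea.

```python
def count_paths(paths_per_block: int, blocks: int) -> int:
    """Number of whole-network paths when each block contributes independent paths."""
    return paths_per_block ** blocks


# Two paths per block (skip the block or run it), 24 blocks in sequence.
unpruned = count_paths(2, 24)
# If pruning shows both paths of a block are equivalent, one path per block remains.
pruned = count_paths(1, 24)
print(unpruned, pruned)
```

With 24 two-way branches the unpruned count is 16,777,216, which matches the roughly 16M paths mentioned in the text.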
2.3. Handling Loops
For the loops in typical PyTorch neural network programs, as discussed in Section 1.2 and accordingly assumed in Section 2.1, we do not need the full power of static analysis (staticanalysis). PyTea unrolls constant-bound loops (Assumption A2 in Section 2.1) and analyzes their straight-line code version. For the unknown-bound loops over datasets, PyTea analyzes the loop body for just two cases, following the aforementioned Assumption A3. One is an iteration with a fixed-size regular minibatch of an epoch. The other is an iteration with the residual minibatch. For example, see the code in Figure 4. For the third code box of Figure 4, we can unroll the loop expression into three identical expressions. If we do not know the length of the dataset, as in the fourth code box of Figure 4, we use Assumption A3 and consider only two cases for the two different minibatch sizes.
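The reason two cases suffice can be seen from how sequential batching divides a dataset. The sketch below is an assumption-laden illustration, not PyTea code: the helper minibatch_cases is hypothetical, and the concrete length 60000 (the MNIST training set size) is used only to give the arithmetic a value, whereas PyTea treats the length as unknown.

```python
from typing import List


def minibatch_cases(dataset_len: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Distinct minibatch sizes seen when slicing a dataset sequentially."""
    full_batches, residual = divmod(dataset_len, batch_size)
    sizes = []
    if full_batches > 0:
        sizes.append(batch_size)      # the regular, fixed-size minibatch
    if residual > 0 and not drop_last:
        sizes.append(residual)        # the smaller, residual minibatch
    return sizes


print(minibatch_cases(60000, 64))                  # regular and residual sizes
print(minibatch_cases(60000, 64, drop_last=True))  # residual minibatch dropped
```

However long the dataset is, the loop body only ever sees the regular size and at most one residual size, so analyzing those two cases covers every iteration.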
3. Analysis Steps
3.1. PyTea IR
Figure 6: Abstract syntax of PyTea IR. Formal definitions of PyTea IR expressions.
Figure 8: Abstract syntax of constraints. Formal definitions of PyTea IR constraints.
As the first step of the analysis, the input Python code is translated into the kernel language, PyTea IR. See Figure 6. PyTorch APIs are translated into tensor expressions that only define shape transformations, which PyTea IR focuses on. The second step of the analysis is to scan the PyTea IR code and generate constraints.
3.2. Constraint generation
Constraints are the conditions that a PyTorch application must satisfy so that it can execute without any tensor shape error. For example, the two operands of a matrix multiplication must share the same middle dimension. For each tensor operation (mm, reshape, readImage, etc. of Figure 6), the shape of the input tensor must obey the requirement of the corresponding operation.
Figure 8 shows the abstract syntax of the constraints. A value expression represents the value of a PyTea IR expression and can be used inside shape constraints. When PyTea analyzes a PyTea IR program, it traces tensor shapes and primitive Python values and constructs symbolic value expressions. A shape expression represents the shape of a tensor, which is basically a tuple of integers . Figure 7 shows an example of a tensor with a shape . Each integer is a dimension size. We call the number of dimensions the rank of a shape. We can slice () a shape expression or concatenate () two shape expressions. For example, suppose a PyTea IR variable t has shape . Expression t[0], which means the first subtensor of t along the first axis, can be represented inside constraints as , or simply . If the shape of expression t is unknown (), the shape of a subtensor will be represented as .
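As a rough illustration of rank, slicing, and concatenation on shapes, the sketch below models a concrete shape as a Python tuple. The helper names and the example shape (3, 28, 28) are hypothetical; PyTea's shape expressions are symbolic, which this sketch does not attempt to reproduce.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def rank(shape: Shape) -> int:
    """The rank is the number of dimensions."""
    return len(shape)


def slice_shape(shape: Shape, start: int, end: int) -> Shape:
    """Keep the dimensions in the half-open range [start, end)."""
    return shape[start:end]


def concat(left: Shape, right: Shape) -> Shape:
    """Join two shapes into one."""
    return left + right


def subtensor_shape(shape: Shape) -> Shape:
    """Shape of t[0]: drop the first axis and keep the rest."""
    if rank(shape) < 1:
        raise ValueError("cannot index a rank-0 tensor")
    return slice_shape(shape, 1, rank(shape))


t = (3, 28, 28)  # hypothetical shape, for illustration only
print(rank(t), subtensor_shape(t), concat((3,), (28, 28)))
```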
3.2.1. Constraint generation rules for PyTea IR
To capture Python semantics and PyTorch shape transformations, PyTea follows the static semantics () of PyTea IR. Judgment () means that the PyTea IR expression is statically approximated by a symbolic value expression under environment in case the constraint set () is satisfied. The environment () is a finite table that maps variables to symbolic value expressions. The introduction of constraints happens for branch expressions or PyTorch APIs (See Section 3.2.2). The other expressions will collect constraints from their subexpressions. For example, for an add expression (), see:
The result value is symbolically () where and are symbolic results of and respectively. The result constraint set will be a union of the result constraint sets of and . Every symbolic variable originates from external input, e.g., random function or a dataset. Every expression in the constraints is constructed by these variables and constant values.
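The behavior described for the add expression (a symbolic sum as the value, and the union of the operands' constraint sets) can be sketched as follows. The class names Sym and Add, the function eval_add, and the string encoding of constraints are all assumptions made for this illustration; they are not PyTea's internal representation.

```python
from dataclasses import dataclass
from typing import FrozenSet, Tuple, Union


@dataclass(frozen=True)
class Sym:
    """A symbolic variable that originates from an external input."""
    name: str


@dataclass(frozen=True)
class Add:
    """A symbolic sum of two value expressions."""
    left: "Value"
    right: "Value"


Value = Union[int, Sym, Add]
Result = Tuple[Value, FrozenSet[str]]  # (symbolic value, constraint set)


def eval_add(left: Result, right: Result) -> Result:
    """Combine two evaluated operands: add the values, union the constraints."""
    (v1, c1), (v2, c2) = left, right
    if isinstance(v1, int) and isinstance(v2, int):
        value = v1 + v2          # both operands are constants: fold eagerly
    else:
        value = Add(v1, v2)      # otherwise keep the sum symbolic
    return value, c1 | c2


x = (Sym("x"), frozenset({"0 <= x", "x <= 1"}))  # e.g. a bounded random value
n = (Sym("n"), frozenset({"1 <= n"}))            # e.g. a dataset-dependent size
value, constraints = eval_add(x, n)
print(value, sorted(constraints))
```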
3.2.2. Constraint types
In order to help the constraint resolution engine Z3 come up with a sensible counterexample that violates the derived constraints, we classify the constraints into two exclusive classes: soft and hard constraints. For Z3 to generate counterexamples, soft constraints may be violated, while hard constraints may not. Hard constraints are, for example, those from branch conditions or those on the value range of an input. See Figure 4 again. The Python built-in random.randint function generates an unknown random variable within a given range [0, 1]. We mark that bound constraint as a hard constraint. On the other hand, the torch.mm API demands that the two input tensors be rank-2 () tensors and that the second dimension (coordinate) of the first tensor be equal to the first dimension (coordinate) of the second tensor. This condition can be violated depending on the shapes of the inputs, hence we mark it as a soft constraint.
Hard constraint generation
Hard constraints are those for inputs and branch conditions. Input conditions restrict the initial ranges of each input. Branch conditions split each path into two. Consider the following rule.
The readImage API is an image-fetching API that creates a new rank-3 tensor representing color channels, height, and width. The number of color channels ranges from 1 to 4, i.e., from monochrome to RGBA, hence the constraint in the above rule. The symbolic value is a tensor of shape ().
As another case, consider the following rule.
The randInt API generates a new random variable bounded by the two given numbers. This expression originates from the Python API random.randint.
For the branching case, see below:
The if expression creates two paths depending on the branch condition . If the branch condition evaluates to a constant boolean, we can safely drop one branch.
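The three hard-constraint cases described above can be approximated in a few lines. This is a sketch under stated assumptions: the functions read_image, rand_int, and branch, the fresh-symbol naming, and the string form of constraints are hypothetical stand-ins, not the formal rules of the paper.

```python
import itertools

_ids = itertools.count()


def fresh(prefix: str) -> str:
    """Create a new symbolic variable name."""
    return f"{prefix}{next(_ids)}"


def read_image():
    """A rank-3 symbolic shape (channels, height, width) with 1 <= channels <= 4."""
    c, h, w = fresh("c"), fresh("h"), fresh("w")
    return (c, h, w), [f"1 <= {c}", f"{c} <= 4"]


def rand_int(low: int, high: int):
    """A fresh symbolic integer bounded by the two given numbers."""
    x = fresh("x")
    return x, [f"{low} <= {x}", f"{x} <= {high}"]


def branch(condition, then_value, else_value):
    """Return (value, hard constraints) for every path the branch creates."""
    if isinstance(condition, bool):
        # Constant condition: only one branch is reachable, so drop the other.
        return [(then_value if condition else else_value, [])]
    return [
        (then_value, [condition]),
        (else_value, [f"not ({condition})"]),
    ]


shape, image_hard = read_image()
x, x_hard = rand_int(0, 1)
paths = branch(f"{x} == 0", "skip block", "run block")
print(shape, image_hard, x_hard, paths)
```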
Soft constraint generation
Soft constraints are the conditions with which PyTorch APIs must comply for them to run without a shape error. For instance, two operands of a matrix multiplication have to share the same middle dimension, and the reshape operation requires that the number of elements of the input tensor must be matched with the number of elements of the target shape. Each PyTorch API holds unique requirements of input conditions, and PyTea collects these requirements as soft constraints.
With the following three rules, for example, PyTea collects such constraints from three representative APIs (mm, reshape, and transpose):
The mm API calculates the matrix multiplication of two rank-2 matrices. The second dimension of the first matrix must be equal to the first dimension of the second matrix, following the basic rules of linear algebra.
The reshape API redefines the shape of a tensor. Reshaping a tensor does not change or drop any value of the tensor, so the target shape must have exactly the same number of values as the original shape ():
The transpose API swaps two dimensions of the tensor along the axis and axis. Unlike ordinary rank-2 matrix transposition, transpose slices a tensor with plane and transposes each matrix on each cross-section:
From this rule, we only consider the shape of the result, not the movement of the value inside the tensor.
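For concrete shapes, the requirements of the three APIs above reduce to simple checks. The sketch below is a simplified illustration with hypothetical function names; each function returns the result shape together with the list of soft constraints and whether each one holds. It assumes non-negative axis indices for transpose, which is a simplification made for this example.

```python
from math import prod
from typing import List, Optional, Tuple

Shape = Tuple[int, ...]
Checked = Tuple[Optional[Shape], List[Tuple[str, bool]]]


def mm(a: Shape, b: Shape) -> Checked:
    """Matrix multiplication: both operands rank 2, inner dimensions equal."""
    both_rank2 = len(a) == 2 and len(b) == 2
    soft = [
        ("rank(a) == 2", len(a) == 2),
        ("rank(b) == 2", len(b) == 2),
        ("a[1] == b[0]", both_rank2 and a[1] == b[0]),
    ]
    ok = all(holds for _, holds in soft)
    return ((a[0], b[1]) if ok else None), soft


def reshape(shape: Shape, target: Shape) -> Checked:
    """Reshape: the number of elements must be preserved."""
    soft = [("numel(shape) == numel(target)", prod(shape) == prod(target))]
    return (tuple(target) if soft[0][1] else None), soft


def transpose(shape: Shape, i: int, j: int) -> Checked:
    """Transpose: swap two axes; only the resulting shape is tracked."""
    soft = [
        ("0 <= i < rank", 0 <= i < len(shape)),
        ("0 <= j < rank", 0 <= j < len(shape)),
    ]
    if not all(holds for _, holds in soft):
        return None, soft
    out = list(shape)
    out[i], out[j] = out[j], out[i]
    return tuple(out), soft


print(mm((64, 784), (784, 80)))
print(reshape((3, 28, 28), (28, 28)))
print(transpose((2, 3, 4), 0, 2))
```

A result of None marks a violated soft constraint, i.e., the point at which the real API would raise a shape error.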
3.2.3. Handling path explosion
Splitting execution paths whenever the analyzer encounters a branch can make the analysis cost grow exponentially. We can ignore some of them using the online constraint check, but we cannot for branches that use runtime input values. However, we can still avoid path split if both paths behave identically in terms of tensor shape. The conservative conditions are as follows:

Constraints collected from each path are not dependent on the branch condition, and

Each path has no global side effect, and

Two paths’ result symbolic values are the same.
PyTea checks the above conditions locally, within the boundary of the let expression containing each branch. When PyTea cannot statically decide any of the three conditions, it conservatively assumes that the condition does not hold. Most branches in PyTorch neural network blocks satisfy the above conditions. Typically, a network block should produce a tensor with a fixed shape that matches the requirement of the next block or the training target tensor. The feedforward paths of those blocks are translated into nested let blocks with branches that return same-shaped tensors.
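The three conservative conditions can be read as a single predicate over per-path summaries. The following sketch assumes a hypothetical PathSummary record and a can_merge function; how PyTea actually summarizes a path is not shown in the text, so the fields here are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional, Tuple


@dataclass(frozen=True)
class PathSummary:
    result_shape: Tuple[object, ...]         # symbolic result of the path
    constraint_symbols: FrozenSet[str]       # symbols used by the path's constraints
    has_global_side_effect: Optional[bool]   # None means "could not decide"


def can_merge(condition_symbols: FrozenSet[str],
              then_path: PathSummary,
              else_path: PathSummary) -> bool:
    """True only if all three conditions are known to hold for both paths."""
    for path in (then_path, else_path):
        # Condition 1: constraints must not depend on the branch condition.
        if path.constraint_symbols & condition_symbols:
            return False
        # Condition 2: no global side effect; an undecided case counts as failure.
        if path.has_global_side_effect is not False:
            return False
    # Condition 3: both paths produce the same symbolic result.
    return then_path.result_shape == else_path.result_shape


cond = frozenset({"r"})  # the branch tests a run-time random value r
skip = PathSummary(("b", 64, 32, 32), frozenset({"b"}), False)
run = PathSummary(("b", 64, 32, 32), frozenset({"b"}), False)
other = PathSummary(("b", 128, 16, 16), frozenset({"b"}), False)
unsure = PathSummary(("b", 64, 32, 32), frozenset({"b"}), None)
print(can_merge(cond, skip, run), can_merge(cond, skip, other), can_merge(cond, skip, unsure))
```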
3.3. Constraint check
3.3.1. Online constraint check
To reduce the number of constraints and paths, our analyzer eagerly simplifies the symbolic expressions and constraints with primitive arithmetics and comparisons. By our eager, online constraint check, the ranges of each symbol can sometimes be known and be used to judge the subsequent constraints. If a branch condition can be simplified into constant true or false, we can trace only a single branch without splitting the path. If a constraint can be simplified to constant false, we can immediately report that the path is unsafe.
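A minimal sketch of eager simplification is shown below, assuming constraints are encoded as (left, operator, right) triples whose sides are either integer constants or symbol names. The encoding and the function names are hypothetical; the status labels are borrowed from the command-line output described later in the paper.

```python
import operator

OPS = {"==": operator.eq, "!=": operator.ne, "<": operator.lt, "<=": operator.le}


def simplify(constraint):
    """Fold a comparison to True or False when both sides are constants."""
    left, op, right = constraint
    if isinstance(left, int) and isinstance(right, int):
        return OPS[op](left, right)
    return constraint  # still symbolic: leave it for the offline check


def online_check(constraints):
    """Drop constraints that are trivially true; stop at the first one that is false."""
    remaining = []
    for constraint in constraints:
        simplified = simplify(constraint)
        if simplified is True:
            continue
        if simplified is False:
            return "immediate fail", [constraint]
        remaining.append(simplified)
    return "potential success", remaining


print(online_check([(784, "==", 784), ("b", "<=", 64)]))
print(online_check([(120, "==", 80)]))
```

Only the constraints that remain symbolic after this pass need to be handed to the solver.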
3.3.2. Offline constraint check
PyTea feeds the collected constraints of each path to Z3. Algorithm 1 describes how we classify Z3’s result. The final result of the PyTea analyzer falls into one of four cases:

Valid: Soft constraints are always satisfied under the hard constraints. This guarantees that no shape error will occur on this path.

Invalid: A possible shape error is detected. There is a counterexample that makes soft constraints false under the hard constraints. We also report the generation position of the first broken constraint.

Don’t know: Z3 failed to decide whether constraints are satisfiable or not.

Unreachable: There is a conflict among the hard constraints of this path. In other words, it is impossible to reach this path under the given conditions. This can happen if a path passes through two contradictory branches.
If every path is classified as either unreachable or valid, we can conclude that the input program has no tensor shape error.
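The classification can be illustrated without a solver by enumerating small finite domains. This sketch is not Algorithm 1 and does not use Z3: the brute-force search, the function classify, and the encoding of constraints as Python predicates are assumptions made for illustration. The "don't know" case, which corresponds to the solver giving up, has no counterpart in an exhaustive search and is therefore omitted.

```python
from itertools import product


def classify(domains, hard, soft):
    """Classify one path as 'valid', 'invalid', or 'unreachable'.

    domains maps each symbol to the finite values to try; hard and soft are
    lists of predicates over an assignment (a dict from symbol to value).
    """
    names = list(domains)
    reachable = False
    for values in product(*(domains[name] for name in names)):
        env = dict(zip(names, values))
        if not all(h(env) for h in hard):
            continue  # this assignment cannot reach the path
        reachable = True
        for index, s in enumerate(soft):
            if not s(env):
                # Counterexample found: report it with the first broken constraint.
                return "invalid", env, index
    if not reachable:
        return "unreachable", None, None
    return "valid", None, None


domains = {"c": range(0, 8)}
channels = [lambda e: 1 <= e["c"], lambda e: e["c"] <= 4]   # image channels
numel = [lambda e: e["c"] * 28 * 28 == 28 * 28]             # reshape to 28 x 28

print(classify(domains, channels, numel))
print(classify(domains, channels + [lambda e: e["c"] == 1], numel))
print(classify(domains, channels + [lambda e: e["c"] > 4], numel))
```

In the first call the search finds a two-channel image as a counterexample, in the second the extra hard constraint (a monochrome image) makes the path valid, and in the third the hard constraints contradict each other.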
4. Evaluation
Our experiments show PyTea’s practical performance for real-world applications. To assess the practicality of PyTea, we collected several complete PyTorch applications and shape-related PyTorch bugs. First, we analyzed the official PyTorch example projects from the GitHub repository pytorch/examples (pytorch_example). This repository consists of 11 complete PyTorch applications covering major machine learning tasks, from Generative Adversarial Networks (GAN) (dcgan) to Natural Language Processing. We also collected some PyTorch shape mismatch errors from StackOverflow and ran PyTea to detect them statically. Finally, we conducted case analyses of several fully functional, handmade PyTorch applications such as Stochastic ResNet (stochastic).
Experiment Settings
PyTea analyzer is written mainly in TypeScript (typescript) and communicates with Python scripts to run Z3. We also used Pyright (pyright) to parse and track Python syntax. The experiments were conducted on an R7 5800X CPU, with node.js 16.0.0, TypeScript 4.2.4, and Python 3.8.8 with Z3Py 4.8.10.0. We fixed the epoch size to 1 via the command-line arguments but used default values for the other settings. We measured the total elapsed time from the cold boot to the termination of PyTea. The full options and code are given in the supplementary materials (link: https://sf.snu.ac.kr/pytea/).
PyTea command-line tool
Figure 9 shows an example snapshot of the analysis result of the PyTea command-line tool. It has analyzed one of the PyTorch example projects and prints the result of each phase of PyTea. It first prints the online constraint check results and categorizes each path into three cases: potential success, potential unreachable, and immediate fail. The last indicates that the online checker has found a constraint that can be false on that path. A potential unreachable path is one on which the online checker has found a false constraint but where certain branch conditions remain unresolved. That path is checked in the next phase, where PyTea examines whether the path has conflicting constraints within the hard constraint set alone, which means that the path is unreachable from the beginning. In the second step, PyTea delivers the collected constraint set of each path to the Z3 solver and runs the offline constraint checks. The offline check reports the first conflicting constraint and the position where it was created, i.e., the exact tensor expression or PyTorch API that causes an error. If the solver does not find any conflicting constraint, PyTea concludes that all the paths are valid, hence no tensor shape error is possible.
4.1. Results
4.1.1. PyTea for PyTorch Examples
Network  LOC (main + lib)  PyTea  Hattori et al. (semistatic)  Total time (s) 
dcgan  3714 (214 + 3500)  1.75  
fast_neural_style  4394 (338 + 4056)  2.40  
imagenet  3820 (320 + 3500)  2.40  
mnist  3607 (116 + 3491)  1.59  
mnist_hogwild  3620 (129 + 3491)  1.94  
reinforcement_learning  180 (180 + )    
super_resolution  3886 (193 + 3693)  1.57  
snli  223 (223 + )    
time_sequence_prediction  3333 (88 + 3245)  1.88  
vae  3593 (102 + 3491)  1.70  
word_language_model  3278 (361 + 2912)  1.81  
Question  PyTea  Hattori et al. (semistatic) 
Case 1 (66995380)  
Case 2 (60121107)  
Case 3 (55124407)  
Case 4 (62157890)  
Case 5 (59108988)  
Case 6 (57534072)  
For the experiment, we pass each project to the analyzer twice. In the first pass, PyTea analyzes the unmodified main code, and we check that PyTea reports no false positives. In the second pass, we inject an artificial shape error by subtracting one from the first dimension of the target tensor right before the neural network’s loss calculation.
This simple method was chosen on purpose. In this experiment, we focus on the speed of PyTea, which indicates whether it is practical to integrate into a code editor such as VSCode. This configuration measures the analysis time of the main network and also confirms that PyTea tracks the tensor operations of the main network thoroughly; we check that PyTea reports no false negatives.
We have compared PyTea against another PyTorch analyzer of Hattori et al. (semistatic). Table 1 shows the overall results. Among the 11 projects, PyTea successfully analyzed 6 projects without any modification of the original source code. For three projects with a complex data preprocessing stage, PyTea needs a bypass (i.e., code modification) of that stage to infer the shapes of input tensors. PyTea has also succeeded in finding these injected errors. As these results show, PyTea is quick and effective enough to be integrated into code editors. Meanwhile, Hattori et al.’s analyzer failed for almost all benchmarks. Furthermore, since their semistatic approach requires an explicit shape of the input tensor, we needed to feed them an exact network model and input tensors to compare its performance with PyTea.
Although we aimed to analyze the code without any modification, two projects depend heavily on third-party data-managing libraries such as OpenAI Gym (gym). Because we are currently focusing on the analysis of PyTorch-centered applications, we decided not to support those libraries for now. Supporting more libraries is straightforward and is left as future work.
4.1.2. PyTea for StackOverflow questions
To show that PyTea can identify yet another set of real-world shape mismatches, we collected some PyTorch shape errors from StackOverflow questions. Recent TensorFlow analyzers (pythia; shapeflow) used a TensorFlow error dataset collected by Zhang et al. (zhangbug), but we manually gathered PyTorch shape mismatch cases rather than using their dataset, because of the fundamental structural differences between TensorFlow and PyTorch. We also considered porting the TensorFlow error dataset to PyTorch, but we concluded that the ported code would be fairly old and artificial and would not reflect the standard way of building a PyTorch application.
Table 2 gives the analysis results for the 6 questions that we collected. PyTea detected every shape mismatch case in those questions. Following the analysis results, we could find the exact error positions and fix the shape mismatches. For example, the main code (Figure 10) of Case 2 does not satisfy the shape conditions for the inputs of NLLLoss (line 9). The NLLLoss module requires that the shape of the first input tensor, without its second dimension, be equal to the shape of the second input tensor. In our experiment, PyTea found that NLLLoss could generate a shape error. We then fixed the code according to the StackOverflow answer, and PyTea confirmed that every path became valid.
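The NLLLoss condition stated above can be written as a one-line shape check. The sketch below follows only the rule as worded in the text (input shape with its second dimension removed must equal the target shape); the function name and the example shapes are hypothetical, and any further cases the real module may accept are not modeled.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def nll_loss_shapes_ok(input_shape: Shape, target_shape: Shape) -> bool:
    """Input is (N, C, d1, ...); target must be (N, d1, ...)."""
    if len(input_shape) < 2:
        return False
    # Remove the second (class) dimension from the input shape, then compare.
    without_classes = input_shape[:1] + input_shape[2:]
    return without_classes == tuple(target_shape)


print(nll_loss_shapes_ok((64, 10), (64,)))            # batch of class scores vs. labels
print(nll_loss_shapes_ok((64, 10), (64, 1)))          # extra target axis: mismatch
print(nll_loss_shapes_ok((8, 5, 32, 32), (8, 32, 32)))
```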
4.2. Discovered Errors in PyTorch Applications
We applied PyTea to several realistic PyTorch applications that contain potential shape errors or path explosion. The shape errors found by PyTea include the typical types of shape errors that we introduced in Section LABEL:sec:tensorerror. The complete projects and experiment scripts for this section are in the supplementary material.
4.2.1. Detecting insufficient data preprocessing
We found a potential error in the data preprocessing stage of the fast_neural_style application in the pytorch/examples repository. As shown in Figure LABEL:lbfast2, Image.open does not guarantee that the loaded image has 3 channels, i.e., that it is an RGB image. Therefore, any training or inference stage given a monochrome image will fail if we omit a channel-converting call like the one on line 4. This error had remained since the initial version and was fixed by the latest commit (a3f28a2) of the preprocessing script.