The official PyTorch code implementation of "Human Trajectory Prediction via Counterfactual Analysis" in ICCV 2021.
Forecasting human trajectories in complex dynamic environments plays a critical role in autonomous vehicles and intelligent robots. Most existing methods learn to predict future trajectories by behavior clues from history trajectories and interaction clues from environments. However, the inherent bias between training and deployment environments is ignored. Hence, we propose a counterfactual analysis method for human trajectory prediction to investigate the causality between the predicted trajectories and input clues and alleviate the negative effects brought by environment bias. We first build a causal graph for trajectory forecasting with history trajectory, future trajectory, and the environment interactions. Then, we cut off the inference from environment to trajectory by constructing the counterfactual intervention on the trajectory itself. Finally, we compare the factual and counterfactual trajectory clues to alleviate the effects of environment bias and highlight the trajectory clues. Our counterfactual analysis is a plug-and-play module that can be applied to any baseline prediction methods including RNN- and CNN-based ones. We show that our method achieves consistent improvement for different baselines and obtains the state-of-the-art results on public pedestrian trajectory forecasting benchmarks.
Human trajectory prediction aims to forecast the future trajectories of pedestrians from their past positions in complex, crowded environments. It is a critical and fundamental task for many applications, including planning and control for autonomous vehicles, robot navigation, and tracking and re-identification in crowd surveillance. Because of this importance, human trajectory prediction has attracted much attention over the past few years [1, 51, 54, 12, 23].
Despite recent progress, trajectory prediction remains challenging because of complex social and physical environment interactions. In a crowded environment, a pedestrian's trajectory is influenced by the social behavior of other pedestrians and by the shared scene context. For example, pedestrians may walk in parallel while talking, or stop for a few seconds to avoid a collision. In addition, pedestrian behavior may be affected by a traffic light, a crosswalk, or simply an obstacle such as a tree.
Most existing methods concentrate on modeling the environment interactions and aggregate these interaction clues with history behavior clues for trajectory prediction. For example, Social LSTM applies a social pooling module to extract clues from environment interactions, while Social-STGCNN uses a spatial-temporal graph CNN to encode them. However, these methods ignore the inherent bias between training and deployment environments. In part (a) of Figure 1, we statistically compare the interactions in the training and testing environments of the UNIV scene and observe an obvious gap between them. We also visualize the environment difference in part (b) of Figure 1, where the training environment is a street scene while the testing environment is a more crowded public square. This environment bias causes heavy overfitting of the complex environment interaction modules in most trajectory prediction methods. When the training data contains many pedestrians who turn left at a crossroads, the prediction of a left turn will be wrongly attributed to the crossroads, which misleads the prediction in a new environment where pedestrians turn right at the crossroads. Besides, deployment environments are often unpredictable in applications such as autonomous vehicles and intelligent robots, so it is difficult to apply transfer learning methods to reduce the environment bias.
To address this problem, we propose a counterfactual analysis method that alleviates the overdependence on the biased environment and highlights the trajectory clues themselves. Inspired by causal inference methods [35, 34, 9], we use counterfactual intervention to investigate the causality between the observed clues and the predicted trajectories. Different from conventional causal inference methods [32, 7], we further extend the analysis to the training process to optimize the prediction model. Specifically, we first construct a factual causal graph from human prior knowledge, whose nodes are the past trajectory, the environment interaction, and the future trajectory. As shown in the left part of Figure 1 (b), the environment interaction may be a negative confounder because of the bias between training and deployment environments. Then, we conduct the counterfactual intervention on the history trajectory, which cuts off the dependence between the environment and the trajectory. Motivated by [46, 43], we replace the history trajectory feature with a counterfactual trajectory, such as uniform rectilinear motion, the mean trajectory, or a random trajectory. The resulting counterfactual prediction indicates the effect of the biased environment clues. Finally, we subtract the counterfactual prediction from the original prediction to obtain the causality-aware prediction, in which the negative effect of the confounder is alleviated. We highlight that the proposed counterfactual analysis method is a plug-and-play module that can be applied to any baseline prediction method, including RNN- and CNN-based ones. In the experiments, we apply our method to two baselines: the RNN-based STGAT and the CNN-based Social-STGCNN. We show that our counterfactual analysis method consistently improves both baselines and obtains state-of-the-art results on public pedestrian trajectory prediction datasets.
We summarize three advantages of the proposed counterfactual analysis method as follows:
We show that the environment bias may cause the heavy overfitting of the complex environment interaction modules.
We propose a counterfactual analysis method that alleviates the overdependence on the biased environment and highlights the trajectory clues themselves.
Our counterfactual analysis is a simple plug-and-play module that can be easily applied to any baseline predictor and consistently improves performance on many human trajectory prediction benchmarks.
Environment Interaction: The environment interaction consists of social environment interaction and physical environment interaction. To model the complex social environment interaction, early methods describe human crowd motion with handcrafted rules or energy parameters, e.g. the Social Force Model, continuum dynamics, the Discrete Choice framework, Gaussian processes [44, 49], and crowd analysis [38, 52, 21]. Among deep learning methods, early studies [1, 12] employ a pooling module to capture the social interactions, while some methods [47, 8, 39, 53, 6] apply an attention model to distinguish the importance of different neighbors. Many recent methods employ a graph model to extract clues from interactions, e.g. the States Refinement module for message passing in SR-LSTM, the Graph Attention model for modeling interactions in [16, 19], and the data-driven adaptive online neighborhood recommendation in STR-GGRNN. In addition, some methods [20, 39, 51, 19, 4] also incorporate physical environment interactions from the scene context for more reliable prediction. However, most of these methods ignore the bias between training and deployment environments, which causes heavy overfitting of the environment interaction modules. To address this problem, we propose a counterfactual analysis method that encourages the model to focus more on the trajectory itself instead of the biased environment interaction.
Predictor: Deep learning based methods usually regard trajectory forecasting as a sequence prediction problem and apply an RNN model [1, 12, 19, 54, 39] or a temporal CNN model [29, 27] as the predictor. For example, Social LSTM employs an LSTM to encode human motion, then fuses the environment interaction clues extracted by a pooling module and feeds them into an LSTM decoder to predict the trajectory sequentially. Nikhil et al. first proposed to replace RNN-based predictors with a temporal CNN model. However, both RNN-based and CNN-based networks are likelihood-based predictors, in which the causalities between the observed clues and the predicted results are not transparent. With our counterfactual analysis, we can investigate the effects of different clues, including the history trajectory and the environment interactions, through intervention.
Some methods employ a generative model to predict multiple plausible paths instead of a single deterministic one. These methods use a Generative Adversarial Network (GAN) to implicitly incorporate a latent variable into the encoded embedding [12, 19, 39, 55], or a Variational Auto-Encoder (VAE) to explicitly model the multi-modal distribution with a latent variable [20, 17, 40, 26]. Social-STGCNN directly predicts a bi-variate Gaussian distribution, replacing a single deterministic trajectory with multi-modal predictions. Besides, Transformers have also been proposed for predicting future trajectories.
The pedestrian trajectory prediction task can be defined as a sequential prediction problem whose inputs are the history trajectory and the environment interaction. Many methods design complex models to learn clues from the environment interactions. However, the inherent bias between training and deployment environments may cause the interaction model to overfit. In the conventional trajectory prediction framework, given $N$ pedestrians in a scene, we define the history trajectory of the $i$-th pedestrian as $X_i = \{x_i^t \mid t = 1, \dots, T_{obs}\}$, where $x_i^t$ is the 2D location at time $t$. The ground-truth future trajectory of the $i$-th pedestrian is defined as $Y_i = \{y_i^t \mid t = T_{obs}+1, \dots, T_{pred}\}$. To normalize the scale, many methods replace the absolute locations with relative locations or even relative speeds. The environment interactions, extracted from the trajectories of other pedestrians and the physical scene context, are denoted as $E_i$. The sequential trajectory prediction process is modeled as

$$\hat{Y}_i = f(X_i, E_i), \tag{1}$$

where $\hat{Y}_i$ denotes the predicted trajectory. Given a set of training pedestrian trajectory data $\{(X_i, E_i, Y_i)\}$, the predictor is optimized by the L2 loss function as:

$$\mathcal{L}(\theta) = \sum_{i} \lVert Y_i - \hat{Y}_i \rVert_2^2, \tag{2}$$

where $\theta$ denotes the parameters of the pedestrian trajectory prediction system.
In this subsection, we reformulate the above trajectory prediction framework as a causal graph built from prior knowledge. As shown in part (c) of Figure 1, the nodes in the graph denote the variables: the history trajectory $X$, the environment interactions $E$ (including social and physical environment interactions), and the future trajectory $Y$. The causal links indicate the hidden causal relations, that is, how these variables affect each other. In a link $X \rightarrow Y$, we call node $X$ the parent node of $Y$, and $Y$ the child node of $X$. A link from a parent node to a child node denotes a causal relation. The link $E \rightarrow X$ indicates that all human trajectories are influenced by the environment. In causal inference theory, when a variable ($E$) simultaneously affects two variables ($X$ and $Y$), it becomes a confounder in the causal analysis of these two variables. For example, if turning left is often observed together with the environment interaction of a crossroads, the prediction of a left turn will be wrongly attributed to the crossroads, while the real causality underlying the history trajectory may be ignored.
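The graph structure described above can be illustrated with a small sketch. The dictionary encoding and the helper names `confounders` and `do` are hypothetical choices made for this illustration, not part of the released code; the sketch only shows that the environment node is a common parent of the history and future trajectory nodes, and that intervening on the history node removes its incoming link.

```python
# Hypothetical encoding of the causal graph: node -> set of parent nodes.
# E: environment interaction, X: history trajectory, Y: future trajectory.
CAUSAL_GRAPH = {
    "E": set(),
    "X": {"E"},        # E -> X
    "Y": {"X", "E"},   # X -> Y and E -> Y
}


def confounders(graph, cause, effect):
    """Return the variables that are parents of both `cause` and `effect`."""
    return graph[cause] & graph[effect]


def do(graph, variable):
    """Return a copy of the graph in which `variable` has no incoming links."""
    intervened = {node: set(parents) for node, parents in graph.items()}
    intervened[variable] = set()
    return intervened


print(confounders(CAUSAL_GRAPH, "X", "Y"))           # {'E'}
print(confounders(do(CAUSAL_GRAPH, "X"), "X", "Y"))  # set()
```

In this toy encoding, the intervention leaves the link from the environment to the future trajectory untouched, which matches the idea that the counterfactual prediction still depends on the environment.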
For conventional likelihood-based predictors, the causalities between the different observed clues and the predicted results are not transparent. The prediction model is easily “cheated” by the shortcut between the biased environment interaction and the final future trajectory. Inspired by causal inference methods [31, 32], which attempt to analyze the causalities among different clues, we propose to apply counterfactual analysis to the trajectory prediction system to mitigate the negative effect of the biased environment and encourage the model to focus more on the trajectory itself. Counterfactual intervention means imagining a nonexistent observation to replace the original factual clues in the trajectory prediction system. In causal inference methods, the intervention is formulated as $do(\cdot)$. Once a variable is intervened on, all its incoming links in the causal graph are cut off and its value is given independently, while the other variables that are not affected keep their original values. In our method, as shown in part (c) of Figure 1, we replace the history trajectory features $X_i$ in the system with an imagined trajectory or its embedding. The prediction under this intervention is:

$$\hat{Y}_i^{cf} = f(do(X_i = \tilde{X}), E_i), \tag{3}$$

where $\tilde{X}$ indicates the counterfactual value and $E_i$ denotes the unchanged environment interactions. Specifically, we can apply different counterfactual interventions, such as the zero vector (which denotes uniform rectilinear motion when the relative speed is used as input), the mean vector of all history trajectories, or a random trajectory.
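As a minimal sketch of the zero and mean interventions mentioned above, the snippet below builds a counterfactual vector from a batch of history-trajectory features and substitutes it for every pedestrian. The function names and the plain-list feature layout are assumptions made for illustration; a real model would operate on tensors.

```python
def zero_intervention(features):
    """Zero vector with the same dimensionality as one feature vector."""
    return [0.0] * len(features[0])


def mean_intervention(features):
    """Element-wise mean over all history-trajectory feature vectors."""
    count = len(features)
    return [sum(column) / count for column in zip(*features)]


def intervene(features, counterfactual):
    """Replace every pedestrian's history feature with the counterfactual one."""
    return [list(counterfactual) for _ in features]


# Toy history features for three pedestrians (illustrative values only).
history_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

print(zero_intervention(history_features))  # [0.0, 0.0]
print(mean_intervention(history_features))  # [3.0, 4.0]
print(intervene(history_features, zero_intervention(history_features)))
```

Because the substituted value is chosen independently of the scene, it carries no information about the environment, which is what cutting the incoming link to the history node amounts to in this sketch.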
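The original factual prediction depends on both the trajectory and the environment interactions, while the counterfactual prediction depends only on the environment interactions, since the trajectory is replaced by the counterfactual intervention. To investigate the real effect of the trajectory itself, we subtract the counterfactual prediction from the factual prediction and call the result the causal prediction:

$$\hat{Y}_i^{causal} = f(X_i, E_i) - f(do(X_i = \tilde{X}), E_i), \tag{4}$$

where $f$ is the predictor, $X_i$ the history trajectory, $E_i$ the environment interactions, and $\tilde{X}$ the counterfactual value. Compared with the original likelihood prediction, the causal prediction is more reliable because it avoids the biased effects of the environment confounder. In the training process, we optimize the network with respect to the causal prediction:

$$\mathcal{L}(\theta) = \sum_{i} \lVert Y_i - \hat{Y}_i^{causal} \rVert_2^2, \tag{5}$$

where $Y_i$ is the ground-truth future trajectory and $\theta$ denotes the network parameters.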
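The subtraction can be made concrete with a toy predictor. In the sketch below, `predict` is a hypothetical linear stand-in for the network (its weights are arbitrary), and the loss is assumed to be a squared L2 distance. With a linear predictor and a zero intervention, the environment term cancels exactly; a real nonlinear network gives no such guarantee, so this only illustrates the mechanism.

```python
def predict(history, environment, w_hist=0.8, w_env=0.5):
    """Toy stand-in predictor f(X, E): a per-coordinate linear mix."""
    return [w_hist * h + w_env * e for h, e in zip(history, environment)]


def causal_prediction(history, environment, counterfactual):
    """Factual prediction minus the prediction under do(X = counterfactual)."""
    factual = predict(history, environment)
    counterfactual_pred = predict(counterfactual, environment)
    return [f - c for f, c in zip(factual, counterfactual_pred)]


def l2_loss(prediction, target):
    """Squared L2 distance between a prediction and its ground truth."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target))


history = [1.0, 2.0]
environment = [4.0, -2.0]
zero = [0.0, 0.0]  # "Zero" counterfactual intervention

causal = causal_prediction(history, environment, zero)
loss = l2_loss(causal, target=[1.0, 1.5])
print(causal, loss)
```

In this toy setting, calling `causal_prediction` with a different `environment` returns the same values up to floating-point rounding, since only the history term survives the subtraction.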
It is worth noting that our counterfactual analysis is robust to small changes in the trajectory prediction system, such as a different predictor, whether the physical environment context is used, or whether a generative model is used. These changes do not affect the structure of the causal graph. When generative methods are used, a noise latent variable is added to the trajectory prediction system to predict stochastic future trajectories. Taking the GAN as an example of a generative model, the loss function in (2) is reformulated as

$$\min_{\theta} \max_{D} \; \mathbb{E}\big[\log D(Y_i)\big] + \mathbb{E}_{z}\big[\log\big(1 - D(\hat{Y}_i^{z})\big)\big], \qquad \hat{Y}_i^{z} = f(X_i, E_i, z), \tag{6}$$

where $\hat{Y}_i^{z}$ indicates that the trajectory is generated with a noise latent variable $z$ for multi-modal distributions, $X_i$ and $E_i$ are the history trajectory and environment interactions, and $D$ is the discriminator of the GAN, which judges whether a trajectory is generated from the noise latent variable. Our counterfactual analysis works together with generative methods: it only needs the noise latent variable $z$ as an extra input on the prediction side. The causal prediction with the latent variable is formulated as:

$$\hat{Y}_i^{causal,z} = f(X_i, E_i, z) - f(do(X_i = \tilde{X}), E_i, z), \tag{7}$$

where $\tilde{X}$ is the counterfactual value. Then, we only need to replace the original prediction with the causal prediction:

$$\min_{\theta} \max_{D} \; \mathbb{E}\big[\log D(Y_i)\big] + \mathbb{E}_{z}\big[\log\big(1 - D(\hat{Y}_i^{causal,z})\big)\big]. \tag{8}$$
As shown in Figure 2, the framework of a trajectory prediction system usually contains four modules: a history path encoder, a scene perceptron, an interaction analysis module, and a trajectory predictor. Instead of the conventional likelihood-based trajectory predictor, we conduct the counterfactual intervention on one of the clues and apply a causal predictor, which computes the difference between the predictions made with the factual and the counterfactual clues. Our causal prediction is simple yet effective in improving the reliability of the trajectory prediction system. To evaluate the generality of our counterfactual analysis method, we incorporate it as a plug-and-play module into two different baselines: the RNN-based STGAT and the CNN-based Social-STGCNN. We briefly introduce these two implementations below; the details can be found in the supplementary materials. Note that although the full counterfactual analysis method covers the features of the environment scene, we do not use scene features in these implementations because our baseline models do not use scene images. Additionally, we conduct experiments with scene features in the appendix.
|RNN-based Method||Performance (ADE/FDE)|
|MATF GAN ||1.01/1.75||0.43/0.80||0.26/0.45||0.26/0.57||0.44/0.91||0.48/0.90|
Causal-STGAT: The STGAT method employs two LSTMs and a graph attention network (GAT) to encode the history trajectory and the social interaction clues. Specifically, M-LSTM (the LSTM for motion encoding) focuses on the history trajectory features, while G-LSTM (the LSTM for the graph) and the GAT model extract the interaction features. The features from M-LSTM, from G-LSTM + GAT, and the latent noise are then concatenated as the input of the LSTM predictor. Taking the history trajectory as an example, we replace the feature vector from M-LSTM with the counterfactual intervention (the mean vector of all trajectory features, or the zero vector). We then make two predictions, one with the original concatenated feature and one with the counterfactual concatenated feature in which the history trajectory part is changed. Finally, we adopt the difference between the two results as the prediction of our Causal-STGAT.
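The feature swap described for Causal-STGAT can be sketched as follows. Everything here is a simplified stand-in: `toy_decoder` replaces the LSTM predictor, the vectors are hand-picked, and sharing the same noise vector between the two passes is an assumption of this sketch. The point is only that the motion slot of the concatenated input changes while the interaction and noise slots are reused.

```python
def concat(*parts):
    """Concatenate feature vectors into one flat predictor input."""
    return [value for part in parts for value in part]


def toy_decoder(features):
    """Hypothetical stand-in for the LSTM predictor: returns a 2D offset."""
    half = len(features) // 2
    return [sum(features[:half]), sum(features[half:])]


def causal_stgat_step(motion_feat, interaction_feat, noise, counterfactual_feat):
    """Decode factual and counterfactual inputs and return their difference."""
    factual = toy_decoder(concat(motion_feat, interaction_feat, noise))
    counterfactual = toy_decoder(concat(counterfactual_feat, interaction_feat, noise))
    return [f - c for f, c in zip(factual, counterfactual)]


motion = [0.5, -0.25]        # M-LSTM output (illustrative values)
interaction = [1.0, 2.0]     # G-LSTM + GAT output
noise = [0.125, 0.375]       # latent noise, reused in both passes (assumption)
zero = [0.0] * len(motion)   # "Zero" intervention on the motion feature

offset = causal_stgat_step(motion, interaction, noise, zero)
print(offset)  # [0.25, 0.0]
```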
Causal-STGCNN: Social-STGCNN employs a spatial-temporal graph convolutional neural network (ST-GCNN) as the encoder to extract features from the history trajectory and the social interactions. The trajectories in a scene are represented as a graph, where the nodes denote the past trajectories and the adjacency matrix of the GCN denotes the social interactions. A time-extrapolator CNN is then applied to the ST-GCNN outputs as the trajectory predictor. In this baseline, we replace the original nodes with the counterfactual intervention and keep the values of the interaction clues. We then feed the ST-GCNN outputs for the factual and the counterfactual nodes to the time-extrapolator CNN predictor separately and compute the difference to get the causal prediction.
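For the graph-based baseline, the same idea can be sketched with a single toy aggregation step. The function `graph_conv` is a hypothetical simplification of ST-GCNN (no temporal dimension, no learned weights), and the "Mean" intervention is used for variety; the sketch only shows that the node features are replaced while the adjacency matrix is reused unchanged.

```python
def graph_conv(nodes, adjacency):
    """One toy graph-convolution step: each node aggregates its neighbours."""
    size, dim = len(nodes), len(nodes[0])
    return [
        [sum(adjacency[i][j] * nodes[j][d] for j in range(size)) for d in range(dim)]
        for i in range(size)
    ]


def causal_stgcnn(nodes, adjacency, counterfactual_nodes):
    """Difference between factual and counterfactual passes on the same graph."""
    factual = graph_conv(nodes, adjacency)
    counterfactual = graph_conv(counterfactual_nodes, adjacency)  # same adjacency
    return [
        [f - c for f, c in zip(f_row, c_row)]
        for f_row, c_row in zip(factual, counterfactual)
    ]


nodes = [[1.0, 0.0], [3.0, 2.0]]       # per-pedestrian node features (toy values)
adjacency = [[1.0, 0.5], [0.5, 1.0]]   # social interaction weights (toy values)
mean_node = [sum(col) / len(nodes) for col in zip(*nodes)]
counterfactual_nodes = [mean_node[:] for _ in nodes]  # "Mean" intervention

causal = causal_stgcnn(nodes, adjacency, counterfactual_nodes)
print(causal)  # [[-0.5, -0.5], [0.5, 0.5]]
```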
Datasets: We conduct experiments on two publicly available human trajectory prediction datasets: ETH and UCY. (We also provide experimental results on the Stanford drone dataset in the appendix.) The human trajectories in both datasets are captured in real-world scenes with rich social interactions. Together, these datasets contain five unique scenes (Zara1, Zara2, Univ, ETH, and Hotel) with 1536 detected pedestrians. All trajectories are sampled every 0.4 seconds. The experimental settings follow previous methods such as Social-GAN, Social BiGAT, and STGAT. We use the leave-one-out approach: we train and validate our model on four scenes and test on the remaining one. During evaluation, the first 3.2 seconds (8 frames) are observed and the next 4.8 seconds (12 frames) are to be predicted.
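The leave-one-out protocol and the frame arithmetic above can be checked with a few lines. The constant and function names are chosen for this sketch only; the scene names, sampling interval, and frame counts come from the text.

```python
SCENES = ["ETH", "Hotel", "Univ", "Zara1", "Zara2"]
FRAME_INTERVAL_S = 0.4
OBS_FRAMES, PRED_FRAMES = 8, 12


def leave_one_out(scenes):
    """Yield (train_scenes, test_scene) pairs, holding out one scene each time."""
    for held_out in scenes:
        yield [scene for scene in scenes if scene != held_out], held_out


splits = list(leave_one_out(SCENES))
obs_seconds = round(OBS_FRAMES * FRAME_INTERVAL_S, 1)    # 3.2
pred_seconds = round(PRED_FRAMES * FRAME_INTERVAL_S, 1)  # 4.8

for train_scenes, test_scene in splits:
    print(f"train/val on {train_scenes}, test on {test_scene}")
```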
We adopt the same evaluation metrics as prior methods [12, 19, 16]: Average Displacement Error (ADE) and Final Displacement Error (FDE). ADE computes the average L2 distance between the predicted and ground-truth locations over all predicted time steps, while FDE computes the L2 distance between the final locations of the predicted and ground-truth trajectories. Since both baselines, Social-STGCNN and STGAT, are generative methods, we also follow the evaluation protocol of Social-GAN: we generate 20 samples from the predicted distribution and select the sample closest to the ground truth for the ADE and FDE metrics.
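A minimal sketch of these metrics is given below, assuming ADE is the mean per-step Euclidean distance (the usual convention) and assuming the best-of-K selection takes the minimum of each metric independently over the samples. Both the function names and that selection rule are assumptions of this sketch.

```python
import math


def displacement_errors(predicted, ground_truth):
    """Return (ADE, FDE) for one trajectory given lists of (x, y) points."""
    distances = [
        math.hypot(p[0] - g[0], p[1] - g[1])
        for p, g in zip(predicted, ground_truth)
    ]
    return sum(distances) / len(distances), distances[-1]


def best_of_k(samples, ground_truth):
    """Minimum ADE and minimum FDE over K sampled trajectories."""
    errors = [displacement_errors(sample, ground_truth) for sample in samples]
    return min(e[0] for e in errors), min(e[1] for e in errors)


ground_truth = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
sample_a = [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]  # constant 1 m offset
sample_b = [(0.0, 0.0), (1.0, 0.0), (2.0, 4.0)]  # drifts at the last step

print(displacement_errors(sample_a, ground_truth))  # (1.0, 1.0)
print(best_of_k([sample_a, sample_b], ground_truth))  # (1.0, 1.0)
```

In the evaluation described in the text, `samples` would hold the 20 trajectories generated for one pedestrian.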
Evaluation of Counterfactual Analysis: To verify the effectiveness of our method, we compare our Causal-STGAT and Causal-STGCNN with the RNN-based baseline STGAT and the CNN-based baseline Social-STGCNN. The comparison among our causal methods, the baselines, and other SOTA methods is summarized in Table 1. Note that the marked entries in Table 1 are our reproduced results, trained with the officially released code. (We reproduced STGAT with the released code from https://github.com/huang-xx/STGAT, while the code of Social-STGCNN comes from https://github.com/abduallahmohamed/Social-STGCNN.) Because of different implementation environments, the performance of our reproduced baselines is slightly lower than the performance reported in the original papers. For a fair comparison, all hyper-parameters and environments of our Causal-STGAT and Causal-STGCNN are the same as those of the reproduced baselines. As shown in Table 1, our causal prediction method consistently improves on the original likelihood-based prediction with different baselines on most of the datasets. Specifically, averaged over the 5 scenes, Causal-STGAT improves ADE/FDE by 0.07/0.16 over the baseline STGAT, while Causal-STGCNN outperforms the reproduced Social-STGCNN by more than 0.02/0.09 in ADE/FDE.
Besides, we also compare the proposed causal methods with other SOTA methods, such as IDL, Social-BiGAT, and MATF. As shown in Table 1, we achieve SOTA performance, with a slight improvement over the RNN-based IDL method and the CNN-based Social-STGCNN method. This also demonstrates the effectiveness of the proposed counterfactual analysis method. For both causal methods in Table 1, we conduct the counterfactual intervention on the history trajectory by replacing its features with the zero vector. Other choices are discussed below.
Evaluation of Different Counterfactual Implementation: We have tried different implementations of the counterfactual intervention: the zero vector, the mean of all feature vectors, and a random vector sampled from a zero-mean uniform distribution. These implementations (“Zero”, “Mean”, “Random”) have limited computing cost and are commonly adopted for causal models. For the zero vector and the mean vector, the counterfactual intervention is the same in the training and testing stages. For the random vector, we sample the vector from the uniform distribution in the training stage and use the expectation of the uniform distribution (the zero vector) in the testing stage to avoid introducing bias into the testing data. As shown in the bottom part of Table 2, all implementations obtain clear improvements over the baselines, which demonstrates the generality of the proposed counterfactual analysis method. Furthermore, 1) “Zero” obtains slightly better performance than the others. This might be because “Zero” is a harder counterfactual intervention, which better emphasizes the causal effect from history behaviors to future trajectories. 2) The performance of these implementations is close, which demonstrates that the method is robust to the choice of counterfactual intervention. All versions of the counterfactual intervention are conducted on the human past trajectory clue.
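The train/test asymmetry of the "Random" intervention can be sketched as below. The range of the uniform distribution is not recoverable from the text, so the symmetric interval controlled by `bound` is an assumption, as is the function name; the only property taken from the text is that the expectation is the zero vector.

```python
import random


def random_counterfactual(dim, training, bound=1.0, rng=random):
    """Counterfactual feature for the "Random" intervention.

    Training: sample each entry from U(-bound, bound); `bound` is an assumption.
    Testing: use the distribution's expectation, which is the zero vector.
    """
    if training:
        return [rng.uniform(-bound, bound) for _ in range(dim)]
    return [0.0] * dim


train_vector = random_counterfactual(4, training=True, rng=random.Random(0))
test_vector = random_counterfactual(4, training=False)
print(train_vector)
print(test_vector)  # [0.0, 0.0, 0.0, 0.0]
```

Using the expectation at test time makes the output deterministic, so repeated evaluations of the same trajectory do not differ because of the intervention itself.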
Evaluation of Inference Speed and Model Size: Model size and inference speed are also critical for deploying a method in a real-world environment. We evaluate the speed of all models on all data from the five scenes and compute the average inference time per trajectory on one RTX 2080 Ti GPU. For fairness, we fix the batch size to 1 for all methods and average over 3 repeated evaluations. The parameter counts and inference speeds of our methods and the baselines are summarized in Table 3. First, our counterfactual analysis method does not need any extra parameters, since all parameters are shared between the factual and counterfactual parts. Second, there is no such thing as a free lunch: our methods are slower than the baselines because of the extra computation for the counterfactual analysis. However, the extra cost is not heavy, since the counterfactual analysis only uses part of the network.
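A timing loop of the kind described above might look like the following sketch. The function name and the callable `model` are placeholders, and this is not the authors' benchmarking script. When timing a real GPU model, the device would also need to be synchronized before each clock reading, because GPU kernels are launched asynchronously; that step is omitted here to keep the sketch free of third-party dependencies.

```python
import time


def mean_inference_time(model, trajectories, repeats=3):
    """Average seconds per trajectory, one trajectory per call (batch size 1)."""
    run_means = []
    for _ in range(repeats):
        start = time.perf_counter()
        for trajectory in trajectories:
            model(trajectory)  # batch size fixed to 1
        elapsed = time.perf_counter() - start
        run_means.append(elapsed / len(trajectories))
    return sum(run_means) / repeats


def toy_model(trajectory):
    """Placeholder model: sums the coordinates of one trajectory."""
    return sum(x + y for x, y in trajectory)


toy_data = [[(0.0, 0.0), (1.0, 1.0)] for _ in range(100)]
print(f"{mean_inference_time(toy_model, toy_data):.2e} s per trajectory")
```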
We qualitatively analyze how our causal predictions are more reliable than conventional likelihood-based predictions. As shown in Figure 3, we analyze our methods in 4 different scenarios. Taking (b) as an example, given two trajectories that are converging, conventional likelihood-based methods predict that the pedestrians will meet each other because of the negative effect of the training bias in the environment interaction (such as “most trajectories that get closer finally meet”). These training biases mislead the predictor into learning spurious correlations from the environment and ignoring the real clues from the history trajectory (such as “these persons tend to gather and move forward together”). With counterfactual analysis, we remove the negative effect contained in the counterfactual prediction from the original prediction. This is effective in overcoming the training bias and encourages the model to highlight the real causation.
Besides, we also provide visualization examples with the environment scenes. As shown in Figure 4 and Figure 5, we compare our Causal-STGCNN and Causal-STGAT with their corresponding baselines in different scenes. We observe that our counterfactual analysis method effectively captures the real causal relations instead of the biased environment interactions. Taking the second scene in ETH as an example, both the CNN-based and the RNN-based counterfactual analysis methods significantly outperform the baselines and generate more reliable predicted trajectories.
In this paper, we have presented a counterfactual analysis method to investigate the causalities between the observed clues and the predicted trajectories. We apply the counterfactual intervention by replacing the features of the trajectory clues with counterfactual ones, and subtract the resulting counterfactual prediction from the original prediction. Through this comparison, we encourage the model to learn the real causation of the trajectory and alleviate the negative effects of the bias between training and deployment environments. Our counterfactual analysis is a plug-and-play module that can be used with off-the-shelf trajectory prediction models. In the experiments, we demonstrate the effectiveness of our counterfactual analysis method in different scenes, analyze the effects of different counterfactual implementations, and evaluate its generalization ability across different baseline methods.
Multi-agent tensor fusion for contextual trajectory prediction. In CVPR, pages 12126–12134, 2019.
We conduct experiments to evaluate whether our causal model can also be used with the physical environment. We add a visual feature branch to the original Social-STGCNN method as the baseline. The visual feature is extracted by ResNet34, passed through an MLP, and concatenated with the node feature and the position embedding. We then apply our causal model to this baseline to mitigate the effect of the biased physical and social environment. As shown in Table 4, our method achieves a significant improvement when the causal model is applied to the physical and social environment.
The Stanford drone dataset (SDD) is a well-established human trajectory prediction benchmark, consisting of 20 scenes and over 11,000 unique pedestrians. It provides the scenes in bird's-eye view and the locations of agents in pixel coordinates. More than 40,000 interactions between agents and the scene, and over 185,000 interactions between agents, are captured in the dataset. As in [26, 12, 39], we use the standard train/test split for the experiments on SDD.
For the SDD dataset, we use PECNet  as our baseline and apply our causal model to it. Similar to Causal-STGAT and Causal-STGCNN, we replace the trajectory feature vector from the past trajectory encoder with counterfactual intervention (the zero vector), after which we respectively use the original feature and counterfactual feature for both destination prediction and social pooling. Then, we adopt the difference between two outputs as the causal pooled feature, and finally use it to yield the prediction of our Causal-PECNet.
|Method||CF-VAE ||SimAug ||PECNet ||PECNet||Causal-PECNet|
The marked entry in Table 5 is our reproduced result, trained with the officially released code. The table shows the reported performance of PECNet, our reproduced performance of PECNet, and the performance of Causal-PECNet (our method applied to PECNet). Our method improves over the SOTA PECNet baseline on SDD, which demonstrates its effectiveness.