diff --git a/DEST/LICENSE b/DEST/LICENSE
new file mode 100644
index 00000000..20c1a5a1
--- /dev/null
+++ b/DEST/LICENSE
@@ -0,0 +1,62 @@
+NVIDIA Source Code License for DEST
+
+1. Definitions
+
+“Licensor” means any person or entity that distributes its Work.
+
+“Work” means (a) the original work of authorship made available under this license, which may include software,
+documentation, or other files, and (b) any additions to or derivative works thereof that are made available
+under this license.
+
+The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning
+as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works
+shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
+
+Works are “made available” under this license by including in or with the Work either (a) a copyright notice
+referencing the applicability of this license to the Work, or (b) a copy of this license.
+
+2. License Grant
+
+2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual,
+worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly
+display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
+
+3. Limitations
+
+3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b)
+you include a complete copy of this license with your distribution, and (c) you retain without modification any
+copyright, patent, trademark, or attribution notices that are present in the Work.
+
+3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and
+distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use
+limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that
+are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in
+Section 3.1) will continue to apply to the Work itself.
+
+3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use
+non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any
+derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
+
+3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim,
+cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then
+your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
+
+3.5 Trademarks. This license does not grant any rights to use any Licensor's or its affiliates' names, logos,
+or trademarks, except as necessary to reproduce the notices described in this license.
+
+3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant
+in Section 2.1) will terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
+WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR
+THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE),
+CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL,
+INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR
+MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
diff --git a/DEST/README.md b/DEST/README.md
new file mode 100644
index 00000000..89db2d89
--- /dev/null
+++ b/DEST/README.md
@@ -0,0 +1,178 @@
+# DEST: Depth Estimation with Simplified Transformer
+
+
+
+
+
+
+***[DEST: Depth Estimation with Simplified Transformer](https://arxiv.org/abs/2204.13791)***
+John Yang, Le An, Anurag Dixit, Jinkyu Koo, Su Inn Park
+CVPR Workshop on [Transformers For Vision](https://sites.google.com/view/t4v-cvpr22), 2022
+
+DEST leverages a simplified, GPU-friendly design of the attention block in the transformer. Compared to state-of-the-art methods, our model achieves over 80% reduction in model size and computation while being more accurate and faster. The proposed model was validated on both depth estimation and semantic segmentation tasks. This repository contains the official PyTorch model implementation and training configuration, which can be adapted to your training workflow.
+
+
+
+## Monocular Depth Estimation
+For depth estimation, we employ the same setup as [PackNet-sfm](https://github.com/TRI-ML/packnet-sfm). For details on environment preparation, data download, and training/evaluation scripts, please refer to the original repo.
+
+### Prerequisites
+
+Run the following commands
+
+```bash
+git clone https://github.com/TRI-ML/packnet-sfm.git
+cd packnet-sfm
+
+cp path/to/DEST/configs/train_kitti_dest.yaml configs/
+cp path/to/DEST/models/*_dest.py packnet_sfm/models/
+cp path/to/DEST/networks/DESTNet.py packnet_sfm/networks/depth/
+mkdir packnet_sfm/networks/DEST
+cp path/to/DEST/networks/DEST/*.py packnet_sfm/networks/DEST/
+```
+
+in order to place DEST and its config file within the [PackNet-sfm](https://github.com/TRI-ML/packnet-sfm) implementation as shown below:
+
+```yaml
+packnet-sfm
+ ├ configs
+ │ ...
+ │ └ train_kitti_dest.yaml
+ ├ packnet_sfm
+ │ ...
+ │ ├ models
+ │ │ ...
+ │ │ ├ SfmModel_dest.py
+ │ │ ├ SemiSupModel_dest.py
+ │ │ └ SelfSupModel_dest.py
+ │ ├ networks
+ │ │ ...
+ │ │ ├ depth
+ │ │ │ ...
+ │ │ │ └ DESTNet.py
+ │ │ └ DEST
+ │ │ ├ __init__.py
+ │ │ ├ DEST_EncDec.py
+ │ │ ├ simplified_attention.py
+ │ │ └ simplified_joint_attention.py
+...
+```
+
+### Modifications to make on PackNet repo
+Our work requires the ```timm``` library, so please add the following line to `docker/Dockerfile`.
+
+```bash
+RUN pip install timm
+```
+
+Before building the docker image, you also need to adjust the Python version, CUDNN version, NCCL version, etc. in the Dockerfile according to your machine. Note that the minimum supported Python version is 3.7. Base images can be found on [dockerhub](https://hub.docker.com/r/nvidia/cuda/tags?page=1&ordering=last_updated).
+
+After properly configuring Dockerfile, please follow [the instructions](https://github.com/TRI-ML/packnet-sfm#install) to build your docker image.
+
+Also, due to [an issue in the PackNet repository](https://github.com/TRI-ML/packnet-sfm/issues/107) during evaluation,
+you need to edit lines 295 and 302 of the file `packnet-sfm/packnet_sfm/models/model_wrapper.py`.
+
+Change lines
+```
+[L295] depth = inv2depth(inv_depths[0])
+...
+[L301] inv_depth_pp = post_process_inv_depth(
+[L302] inv_depths[0], inv_depths_flipped[0], method='mean')
+```
+to
+```
+[L295] depth = inv2depth(inv_depths)
+...
+[L301] inv_depth_pp = post_process_inv_depth(
+[L302] inv_depths, inv_depths_flipped, method='mean')
+```
+
+
+### Training
+
+To train DEST from scratch on the KITTI dataset, run the following command:
+```bash
+python scripts/train.py configs/train_kitti_dest.yaml
+```
+
+### Evaluation
+To evaluate a DEST model on the KITTI dataset, run the following:
+
+```bash
+python scripts/eval.py --checkpoint [--config ]
+```
+
+You can also run inference directly on a single image or folder:
+
+
+```bash
+python scripts/infer.py --checkpoint --input --output [--image_shape ]
+```
+
+
+
+## Semantic Segmentation
+For semantic segmentation, our implementation can be readily integrated into [OpenMMLab Semantic Segmentation Toolbox and Benchmark](https://github.com/open-mmlab/mmsegmentation) implementation for training and evaluation.
+
+Please refer to their instructions for [installation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/get_started.md#installation) and [dataset preparation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets).
+Our DEST is trained/evaluated on [CityScapes Dataset](https://www.cityscapes-dataset.com/login/).
+
+### Prerequisites
+To train with MMSegmentation, copy the files located in ```DEST/semseg/``` into
+the MMSegmentation repository by running the following commands:
+```bash
+git clone https://github.com/open-mmlab/mmsegmentation.git # first clone the MMSegmentation repo
+cd mmsegmentation
+mkdir configs/dest/
+
+cp path/to/DEST/semseg/dest_simpatt-b0.py configs/_base_/models/
+cp path/to/DEST/semseg/schedule_160k_adamw.py configs/_base_/schedules/
+cp path/to/DEST/semseg/cityscapes_1024x1024_repeat.py configs/_base_/datasets/
+cp path/to/DEST/semseg/dest_simpatt-*_1024x1024_160k_cityscapes.py configs/dest/
+cp path/to/DEST/semseg/simplified_attention_mmseg.py mmseg/models/backbones/
+cp path/to/DEST/semseg/dest_head.py mmseg/models/decode_heads/
+```
+
+Next, register the DEST backbone and decode head with MMSegmentation:
+```bash
+echo 'from .simplified_attention_mmseg import SimplifiedTransformer' >> mmseg/models/backbones/__init__.py
+echo 'from .dest_head import DestHead' >> mmseg/models/decode_heads/__init__.py
+```
+
+Then, you can start training/evaluating with a desired configuration of DEST.
+
+### Training
+Example: train DEST-B1 on CityScapes Dataset:
+
+```bash
+# Single-gpu training
+python tools/train.py configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py
+# Multi-gpu training
+./tools/dist_train.sh configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py
+```
+
+### Evaluation
+After training, you can evaluate the trained model (e.g. DEST-B1):
+
+```bash
+# Single-gpu testing
+python tools/test.py configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py /path/to/checkpoint_file
+# Multi-gpu testing
+./tools/dist_test.sh configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py /path/to/checkpoint_file
+# Multi-gpu, multi-scale testing
+./tools/dist_test.sh configs/dest/dest_simpatt-b1_1024x1024_160k_cityscapes.py /path/to/checkpoint_file --aug-test
+```
+
+
+## License
+The provided code can be used for research or other non-commercial purposes. For details please check the [LICENSE](LICENSE) file.
+
+## Citation
+```
+@article{YangDEST,
+ title={Depth Estimation with Simplified Transformer},
+ author={Yang, John and An, Le and Dixit, Anurag and Koo, Jinkyu and Park, Su Inn},
+ journal={arXiv preprint arXiv:2204.13791},
+ year={2022}
+}
+```
diff --git a/DEST/configs/train_kitti_dest.yaml b/DEST/configs/train_kitti_dest.yaml
new file mode 100644
index 00000000..1719fcdc
--- /dev/null
+++ b/DEST/configs/train_kitti_dest.yaml
@@ -0,0 +1,43 @@
+model:
+ name: 'SelfSupModel_dest'
+ optimizer:
+ name: 'Adam'
+ depth:
+ lr: 0.000007
+ pose:
+ lr: 0.00001
+ scheduler:
+ name: 'StepLR'
+ step_size: 10
+ gamma: 0.5
+ depth_net:
+ name: 'DESTNet'
+ version: '1A'
+ pose_net:
+ name: 'PoseNet'
+ params:
+ crop: 'garg'
+ min_depth: 0.0
+ max_depth: 80.0
+datasets:
+ augmentation:
+ image_shape: (192, 640)
+ train:
+ batch_size: 10
+ num_workers: 12
+ dataset: ['KITTI']
+ path: ['data/datasets/KITTI_raw']
+ split: ['data_splits/eigen_zhou_files.txt']
+ depth_type: ['velodyne']
+ repeat: [5]
+ validation:
+ dataset: ['KITTI']
+ path: ['data/datasets/KITTI_raw']
+ split: ['data_splits/eigen_val_files.txt',
+ 'data_splits/eigen_test_files.txt']
+ depth_type: ['velodyne']
+ test:
+ dataset: ['KITTI']
+ path: ['data/datasets/KITTI_raw']
+ split: ['data_splits/eigen_test_files.txt']
+ depth_type: ['velodyne']
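Reviewer note on the schedule in `train_kitti_dest.yaml`: with `StepLR`, `step_size: 10` and `gamma: 0.5`, both learning rates are halved every 10 scheduler steps. The sketch below reproduces that decay in plain Python using the closed form `base_lr * gamma ** (epoch // step_size)`. It assumes the scheduler is stepped once per epoch, which the config itself does not state, and `step_lr` is a hypothetical helper, not part of PackNet-SfM or PyTorch.

```python
def step_lr(base_lr, epoch, step_size=10, gamma=0.5):
    """Learning rate after `epoch` scheduler steps under a StepLR-style decay."""
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Base rates copied from train_kitti_dest.yaml.
    for name, base_lr in (("depth", 0.000007), ("pose", 0.00001)):
        for epoch in (0, 9, 10, 20, 30):
            print(f"{name:5s} epoch {epoch:2d}: lr = {step_lr(base_lr, epoch):.3e}")
```

Under that assumption, the depth network would train at 7e-6 for the first ten epochs and at half that rate for the next ten.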
diff --git a/DEST/models/SelfSupModel_dest.py b/DEST/models/SelfSupModel_dest.py
new file mode 100644
index 00000000..a05664fa
--- /dev/null
+++ b/DEST/models/SelfSupModel_dest.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import torch
+from packnet_sfm.models.SfmModel_dest import SfmModel_dest
+from packnet_sfm.losses.multiview_photometric_loss import MultiViewPhotometricLoss
+from packnet_sfm.models.model_utils import merge_outputs
+
+
+class SelfSupModel_dest(SfmModel_dest):
+ """
+ Model that inherits a depth and pose network from SfmModel and
+ includes the photometric loss for self-supervised training.
+
+ Parameters
+ ----------
+ kwargs : dict
+ Extra parameters
+ """
+ def __init__(self, **kwargs):
+ # Initializes SfmModel
+ super().__init__(**kwargs)
+ # Initializes the photometric loss
+ self._photometric_loss = MultiViewPhotometricLoss(**kwargs)
+
+ @property
+ def logs(self):
+ """Return logs."""
+ return {
+ **super().logs,
+ **self._photometric_loss.logs
+ }
+
+
+ def self_supervised_loss(self, image, ref_images, inv_depths, poses,
+ intrinsics, return_logs=False, progress=0.0):
+ """
+ Calculates the self-supervised photometric loss.
+
+ Parameters
+ ----------
+ image : torch.Tensor [B,3,H,W]
+ Original image
+ ref_images : list of torch.Tensor [B,3,H,W]
+ Reference images from context
+ inv_depths : torch.Tensor [B,1,H,W]
+ Predicted inverse depth maps from the original image
+ poses : list of Pose
+ List containing predicted poses between original and context images
+ intrinsics : torch.Tensor [B,3,3]
+ Camera intrinsics
+ return_logs : bool
+ True if logs are stored
+ progress :
+ Training progress percentage
+
+ Returns
+ -------
+ output : dict
+ Dictionary containing a "loss" scalar and a "metrics" dictionary
+ """
+ return self._photometric_loss(
+ image, ref_images, inv_depths, intrinsics, intrinsics, poses,
+ return_logs=return_logs, progress=progress)
+
+
+ def forward(self, batch, return_logs=False, progress=0.0):
+ """
+ Processes a batch.
+
+ Parameters
+ ----------
+ batch : dict
+ Input batch
+ return_logs : bool
+ True if logs are stored
+ progress :
+ Training progress percentage
+
+ Returns
+ -------
+ output : dict
+ Dictionary containing a "loss" scalar and different metrics and predictions
+ for logging and downstream usage.
+ """
+ # Calculate predicted depth and pose output
+ output = super().forward(batch, return_logs=return_logs)
+ if not self.training:
+ # If not training, no need for self-supervised loss
+ return output
+ else:
+ # Otherwise, calculate self-supervised loss
+ self_sup_output = self.self_supervised_loss(
+ batch['rgb_original'], batch['rgb_context_original'],
+ output['inv_depths'], output['poses'], batch['intrinsics'],
+ return_logs=return_logs, progress=progress)
+ # Return loss and metrics
+ return {
+ 'loss': self_sup_output['loss'],
+ **merge_outputs(output, self_sup_output),
+ }
diff --git a/DEST/models/SemiSupModel_dest.py b/DEST/models/SemiSupModel_dest.py
new file mode 100644
index 00000000..5dfb092a
--- /dev/null
+++ b/DEST/models/SemiSupModel_dest.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import torch
+
+from packnet_sfm.models.SelfSupModel_dest import SfmModel_dest, SelfSupModel_dest
+from packnet_sfm.losses.supervised_loss import SupervisedLoss
+from packnet_sfm.models.model_utils import merge_outputs
+from packnet_sfm.utils.depth import depth2inv
+
+
+class SemiSupModel_dest(SelfSupModel_dest):
+ """
+ Model that inherits depth and pose networks, plus the self-supervised loss from
+ SelfSupModel and includes a supervised loss for semi-supervision.
+
+ Parameters
+ ----------
+ supervised_loss_weight : float
+ Weight for the supervised loss
+ kwargs : dict
+ Extra parameters
+ """
+ def __init__(self, supervised_loss_weight=0.9, **kwargs):
+ # Initializes SelfSupModel
+ super().__init__(**kwargs)
+ # If supervision weight is 0.0, use SelfSupModel directly
+ assert 0. < supervised_loss_weight <= 1., "Model requires (0, 1] supervision"
+ # Store weight and initializes supervised loss
+ self.supervised_loss_weight = supervised_loss_weight
+ self._supervised_loss = SupervisedLoss(**kwargs)
+
+ # Pose network is only required if there is self-supervision
+ if self.supervised_loss_weight == 1:
+ self._network_requirements.remove('pose_net')
+ # GT depth is only required if there is supervision
+ if self.supervised_loss_weight > 0:
+ self._train_requirements.append('gt_depth')
+
+ @property
+ def logs(self):
+ """Return logs."""
+ return {
+ **super().logs,
+ **self._supervised_loss.logs
+ }
+
+ def supervised_loss(self, inv_depths, gt_inv_depths,
+ return_logs=False, progress=0.0):
+ """
+ Calculates the supervised loss.
+
+ Parameters
+ ----------
+ inv_depths : torch.Tensor [B,1,H,W]
+ Predicted inverse depth maps from the original image
+ gt_inv_depths : torch.Tensor [B,1,H,W]
+ Ground-truth inverse depth maps from the original image
+ return_logs : bool
+ True if logs are stored
+ progress :
+ Training progress percentage
+
+ Returns
+ -------
+ output : dict
+ Dictionary containing a "loss" scalar and a "metrics" dictionary
+ """
+ return self._supervised_loss(
+ inv_depths, gt_inv_depths,
+ return_logs=return_logs, progress=progress)
+
+ def forward(self, batch, return_logs=False, progress=0.0):
+ """
+ Processes a batch.
+
+ Parameters
+ ----------
+ batch : dict
+ Input batch
+ return_logs : bool
+ True if logs are stored
+ progress :
+ Training progress percentage
+
+ Returns
+ -------
+ output : dict
+ Dictionary containing a "loss" scalar and different metrics and predictions
+ for logging and downstream usage.
+ """
+ if not self.training:
+ # If not training, no need for self-supervised loss
+ return SfmModel_dest.forward(self, batch)
+ else:
+ if self.supervised_loss_weight == 1.:
+ # If no self-supervision, no need to calculate loss
+ self_sup_output = SfmModel_dest.forward(self, batch)
+ loss = torch.tensor([0.]).type_as(batch['rgb'])
+ else:
+ # Otherwise, calculate and weight self-supervised loss
+ self_sup_output = SelfSupModel_dest.forward(self, batch)
+ loss = (1.0 - self.supervised_loss_weight) * self_sup_output['loss']
+ # Calculate and weight supervised loss
+ sup_output = self.supervised_loss(
+ self_sup_output['inv_depths'], depth2inv(batch['depth']),
+ return_logs=return_logs, progress=progress)
+ loss += self.supervised_loss_weight * sup_output['loss']
+ # Merge and return outputs
+ return {
+ 'loss': loss,
+ **merge_outputs(self_sup_output, sup_output),
+ }
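Reviewer note: the training branch of `SemiSupModel_dest.forward` blends the two losses as `(1 - w) * self_sup + w * sup`, where `w` is `supervised_loss_weight`, and skips the self-supervised term entirely when `w == 1`. A minimal scalar sketch of that weighting follows, using plain floats in place of tensors; `blend_losses` is a hypothetical name used only for illustration.

```python
def blend_losses(self_sup_loss, sup_loss, supervised_loss_weight=0.9):
    """Combine self-supervised and supervised losses with a single weight."""
    w = supervised_loss_weight
    # Same range check as the constructor: w must lie in (0, 1].
    assert 0.0 < w <= 1.0, "Model requires (0, 1] supervision"
    if w == 1.0:
        # Fully supervised: the self-supervised loss is not used at all.
        loss = 0.0
    else:
        loss = (1.0 - w) * self_sup_loss
    loss += w * sup_loss
    return loss


if __name__ == "__main__":
    print(blend_losses(2.0, 1.0))       # default weight of 0.9
    print(blend_losses(2.0, 1.0, 1.0))  # supervised term only
```

With the default weight of 0.9, the photometric term is scaled by 0.1 and the supervised term by 0.9.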
diff --git a/DEST/models/SfmModel_dest.py b/DEST/models/SfmModel_dest.py
new file mode 100644
index 00000000..380f90bd
--- /dev/null
+++ b/DEST/models/SfmModel_dest.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import random
+import torch.nn as nn
+from packnet_sfm.geometry.pose import Pose
+from packnet_sfm.models.base_model import BaseModel
+from packnet_sfm.models.model_utils import flip_batch_input, flip_output, upsample_output
+from packnet_sfm.utils.misc import filter_dict
+
+
+class SfmModel_dest(BaseModel):
+ """
+ Model class encapsulating pose and depth networks.
+
+ Parameters
+ ----------
+ depth_net : nn.Module
+ Depth network to be used
+ pose_net : nn.Module
+ Pose network to be used
+ rotation_mode : str
+ Rotation mode for the pose network
+ flip_lr_prob : float
+ Probability of flipping when using the depth network
+ upsample_depth_maps : bool
+ True if depth map scales are upsampled to highest resolution
+ kwargs : dict
+ Extra parameters
+ """
+ def __init__(self, depth_net=None, pose_net=None,
+ rotation_mode='euler', flip_lr_prob=0.0,
+ upsample_depth_maps=False, **kwargs):
+ super().__init__()
+ self.depth_net = depth_net
+ self.pose_net = pose_net
+ self.rotation_mode = rotation_mode
+ self.flip_lr_prob = flip_lr_prob
+ self.upsample_depth_maps = upsample_depth_maps
+ self.mse_loss = nn.MSELoss(reduction='mean')
+ self._network_requirements = [
+ 'depth_net',
+ 'pose_net',
+ ]
+
+ def add_depth_net(self, depth_net):
+ """Add a depth network to the model"""
+ self.depth_net = depth_net
+
+ def add_pose_net(self, pose_net):
+ """Add a pose network to the model"""
+ self.pose_net = pose_net
+
+ def depth_net_flipping(self, batch, flip):
+ """
+ Runs depth net with the option of flipping
+
+ Parameters
+ ----------
+ batch : dict
+ Input batch
+ flip : bool
+ True if the flip is happening
+
+ Returns
+ -------
+ output : dict
+ Dictionary with depth network output (e.g. 'inv_depths' and 'uncertainty')
+ """
+ # Which keys are being passed to the depth network
+ batch_input = {key: batch[key] for key in filter_dict(batch, self._input_keys)}
+ if flip:
+ # Run depth network with flipped inputs
+ output = self.depth_net(**flip_batch_input(batch_input))
+ # Flip output back if training
+ output = flip_output(output)
+ else:
+ # Run depth network
+ output = self.depth_net(**batch_input)
+ return output
+
+ def compute_depth_net(self, batch, force_flip=False):
+ """Computes inverse depth maps from single images"""
+ # Randomly flip and estimate inverse depth maps
+ flag_flip_lr = random.random() < self.flip_lr_prob if self.training else force_flip
+ output = self.depth_net_flipping(batch, flag_flip_lr)
+ # If upsampling depth maps at training time
+ if self.training and self.upsample_depth_maps:
+ output = upsample_output(output, mode='nearest', align_corners=None)
+ # Return inverse depth maps
+ return output
+
+ def compute_pose_net(self, image, contexts):
+ """Compute poses from image and a sequence of context images"""
+ pose_vec = self.depth_net.pose(image, contexts)
+ # print('pose_vec.shape', pose_vec.size())
+ return [Pose.from_vec(pose_vec[:, i], self.rotation_mode)
+ for i in range(pose_vec.shape[1])]
+
+
+
+ def forward(self, batch, return_logs=False, force_flip=False):
+ """
+ Processes a batch.
+
+ Parameters
+ ----------
+ batch : dict
+ Input batch
+ return_logs : bool
+ True if logs are stored
+ force_flip : bool
+ If true, force batch flipping for inverse depth calculation
+
+ Returns
+ -------
+ output : dict
+ Dictionary containing the output of depth and pose networks
+ """
+ # Generate inverse depth predictions
+ depth_output = self.compute_depth_net(batch, force_flip=force_flip)
+
+ # Generate pose predictions if available
+ pose_output = None
+ if 'rgb_context' in batch :
+ pose_output = self.compute_pose_net(batch['rgb'], batch['rgb_context'])
+ # Return output dictionary
+ return {
+ **depth_output,
+ 'poses': pose_output,
+ }
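Reviewer note: `compute_depth_net` decides whether to mirror the input with `random.random() < self.flip_lr_prob if self.training else force_flip`. Python parses this as a conditional expression wrapped around the comparison, so random flipping applies only in training mode and `force_flip` only in evaluation mode. The sketch below isolates that decision and the flip/unflip round trip on a nested list standing in for an image; `choose_flip` and `hflip` are hypothetical helpers, not PackNet-SfM functions.

```python
import random


def choose_flip(training, flip_lr_prob, force_flip, rng=random):
    """Same decision rule as the flag computed in compute_depth_net."""
    return rng.random() < flip_lr_prob if training else force_flip


def hflip(rows):
    """Flip a 2D list left-to-right (stand-in for flipping a tensor's width axis)."""
    return [list(reversed(row)) for row in rows]


if __name__ == "__main__":
    image = [[1, 2, 3], [4, 5, 6]]
    if choose_flip(training=False, flip_lr_prob=0.5, force_flip=True):
        # Flip the input, run the network (identity here), then flip the output back.
        output = hflip(hflip(image))
    else:
        output = image
    print(output == image)
```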
diff --git a/DEST/models/__init__.py b/DEST/models/__init__.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/DEST/models/__init__.py
@@ -0,0 +1 @@
+
diff --git a/DEST/networks/DEST/DEST_EncDec.py b/DEST/networks/DEST/DEST_EncDec.py
new file mode 100644
index 00000000..3bcc954b
--- /dev/null
+++ b/DEST/networks/DEST/DEST_EncDec.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+from functools import partial
+import torch
+from torch import nn
+
+from packnet_sfm.networks.DEST.simplified_attention import SimplifiedTransformer as SimpTR
+from packnet_sfm.networks.DEST.simplified_joint_attention import SimplifiedJointTransformer as SimpTR_Joint
+
+
+def exists(val):
+ return val is not None
+
+def cast_tuple(val, depth):
+ return val if isinstance(val, tuple) else (val,) * depth
+
+
+class DEST_Encoder_Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ img_size=(192, 640),
+ dims=(32, 64, 160, 256),
+ heads=(1, 2, 4, 8),
+ ff_expansion=(8, 8, 4, 4),
+ reduction_ratio=(8, 4, 2, 1),
+ num_layers=(2, 2, 2, 2),
+ channels=3,
+ decoder_dim=128,
+ num_classes=64,
+ semseg=False
+ ):
+ super().__init__()
+ dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
+ assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
+
+ self.dest_encoder = SimpTR(
+ img_size=img_size, in_chans=channels, num_classes=num_classes,
+ embed_dims=dims, num_heads=heads, mlp_ratios=ff_expansion, qkv_bias=True, qk_scale=None, drop_rate=0,
+ drop_path_rate=0.1, attn_drop_rate=0., norm_layer=nn.LayerNorm, depths=num_layers, sr_ratios=reduction_ratio)
+
+ self.dims = dims
+ self.fuse_conv1 = nn.Sequential(nn.Conv2d(dims[-1], dims[-1], 1), nn.ReLU(inplace=True))
+ self.fuse_conv2 = nn.Sequential(nn.Conv2d(dims[-2], dims[-2], 1), nn.ReLU(inplace=True))
+ self.fuse_conv3 = nn.Sequential(nn.Conv2d(dims[-3], dims[-3], 1), nn.ReLU(inplace=True))
+ self.fuse_conv4 = nn.Sequential(nn.Conv2d(dims[-4], dims[-4], 1), nn.ReLU(inplace=True))
+
+ self.upsample = nn.ModuleList([nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'))]*len(dims))
+
+ self.fused_1 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(dims[-1], dims[-1], 3), nn.ReLU(inplace=True))
+ self.fused_2 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(dims[-2] + dims[-1], dims[-2], 3), nn.ReLU(inplace=True))
+ self.fused_3 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(dims[-3] + dims[-2], dims[-3], 3), nn.ReLU(inplace=True))
+ self.fused_4 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(dims[-4] + dims[-3], dims[-4], 3), nn.ReLU(inplace=True))
+ self.fused_5 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
+ nn.Conv2d(dims[-4], dims[-4], 1),
+ nn.ReLU(True))
+ self.semseg = semseg
+
+ def dest_decoder(self, lay_out):
+ fused_1 = self.fuse_conv1(lay_out[-1])
+ fused_1 = self.upsample[-1](fused_1)
+ fused_1 = self.fused_1(fused_1)
+ fused_2 = torch.cat([fused_1, self.fuse_conv2(lay_out[-2])], 1)
+
+ fused_2 = self.upsample[-2](fused_2)
+ fused_2 = self.fused_2(fused_2)
+ fused_3 = torch.cat([fused_2, self.fuse_conv3(lay_out[-3])], 1)
+
+ fused_3 = self.upsample[-3](fused_3)
+ fused_3 = self.fused_3(fused_3)
+ fused_4 = torch.cat([fused_3, self.fuse_conv4(lay_out[-4])], 1)
+
+ fused_4 = self.upsample[-4](fused_4)
+ fused_4 = self.fused_4(fused_4)
+
+ if self.semseg:
+ return fused_4
+
+ fused_5 = self.fused_5(fused_4)
+ return fused_5, fused_4, fused_3, fused_2
+
+ def forward(self, x):
+ layer_outputs, ref_feat = self.dest_encoder(x)
+
+ out = self.dest_decoder(layer_outputs)
+
+ return out, layer_outputs, ref_feat
+
+def DEST_Pose(
+ img_size=(192, 640),
+ dims = (32, 64, 160, 256),
+ heads = (1, 2, 5, 8),
+ ff_expansion = (8, 8, 8, 8),
+ reduction_ratio = (8, 4, 2, 1),
+ num_layers = (2, 2, 2, 2),
+ channels=3,
+ num_classes=512,
+ connectivity=True):
+
+ dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
+ assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
+
+
+ if connectivity :
+ model = SimpTR_Joint(
+ img_size=img_size, in_chans=channels, num_classes=num_classes,
+ embed_dims=dims, num_heads=heads, mlp_ratios=ff_expansion, qkv_bias=True, qk_scale=None, drop_rate=0.,
+ drop_path_rate=0.1, attn_drop_rate= 0., norm_layer=nn.LayerNorm, depths=num_layers, sr_ratios=reduction_ratio)
+ else:
+ model = SimpTR(
+ img_size=img_size, in_chans=channels, num_classes=num_classes,
+ embed_dims=dims, num_heads=heads, mlp_ratios=ff_expansion, qkv_bias=True, qk_scale=None, drop_rate=0.,
+ drop_path_rate=0.1, attn_drop_rate= 0., norm_layer=nn.LayerNorm, depths=num_layers, sr_ratios=reduction_ratio)
+
+ return num_classes, model
+
+
+
+def SimpleTR_B0(img_size=(192, 640), num_out_ch=64, semseg=False):
+ model = DEST_Encoder_Decoder(
+ img_size=img_size,
+ dims=(32, 64, 160, 256),
+ heads=(1, 2, 5, 8),
+ ff_expansion=(8, 8, 4, 4),
+ reduction_ratio=(8, 4, 2, 1),
+ num_layers=(2, 2, 2, 2),
+ channels=3,
+ decoder_dim=256,
+ num_classes=num_out_ch,
+ semseg=semseg)
+ return num_out_ch, model
+
+def SimpleTR_B1(img_size=(192, 640), num_out_ch=256, semseg=False):
+ model = DEST_Encoder_Decoder(
+ img_size=img_size,
+ dims=(64, 128, 250, 320),
+ heads=(1, 2, 5, 8),
+ ff_expansion=(8, 8, 4, 4),
+ reduction_ratio=(8, 4, 2, 1),
+ num_layers=(2, 2, 2, 2),
+ channels=3,
+ decoder_dim=num_out_ch,
+ num_classes=num_out_ch,
+ semseg=semseg)
+ return num_out_ch, model
+
+def SimpleTR_B2(img_size=(192, 640), num_out_ch=256, semseg=False):
+ model = DEST_Encoder_Decoder(
+ img_size=img_size,
+ dims=(64, 128, 250, 320),
+ heads=(1, 2, 5, 8),
+ ff_expansion=(8, 8, 4, 4),
+ reduction_ratio=(8, 4, 2, 1),
+ num_layers=(3, 3, 6, 3),
+ channels=3,
+ decoder_dim=num_out_ch,
+ num_classes=num_out_ch,
+ semseg=semseg)
+ return num_out_ch, model
+
+
+def SimpleTR_B3(img_size=(192, 640), num_out_ch=256, semseg=False):
+ model = DEST_Encoder_Decoder(
+ img_size=img_size,
+ dims=(64, 128, 250, 320),
+ heads=(1, 2, 5, 8),
+ ff_expansion=(8, 8, 4, 4),
+ reduction_ratio=(8, 4, 2, 1),
+ num_layers=(3, 6, 8, 3),
+ channels=3,
+ decoder_dim=512,
+ num_classes=256,
+ semseg=semseg)
+ return num_out_ch, model
+
+def SimpleTR_B4(img_size=(192, 640), num_out_ch=512, semseg=False):
+ model = DEST_Encoder_Decoder(
+ img_size=img_size,
+ dims=(64, 128, 250, 320),
+ heads=(1, 2, 5, 8),
+ ff_expansion=(8, 8, 4, 4),
+ reduction_ratio=(8, 4, 2, 1),
+ num_layers=(3, 8, 12, 5),
+ channels=3,
+ decoder_dim=num_out_ch,
+ num_classes=num_out_ch,
+ semseg=semseg)
+ return num_out_ch, model
+
+def SimpleTR_B5(img_size=(192, 640), num_out_ch=512, semseg=False):
+ model = DEST_Encoder_Decoder(
+ img_size=img_size,
+ dims=(64, 128, 250, 320),
+ heads=(1, 2, 5, 8),
+ ff_expansion=(8, 8, 4, 4),
+ reduction_ratio=(8, 4, 2, 1),
+ num_layers=(3, 10, 16, 5),
+ channels=3,
+ decoder_dim=num_out_ch,
+ num_classes=num_out_ch,
+ semseg=semseg)
+ return num_out_ch, model
+
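Reviewer note: to make the channel and resolution bookkeeping in `dest_decoder` easier to check, the sketch below traces each fusion stage in plain Python. It assumes the encoder's last stage runs at 1/32 of the input resolution and that each earlier stage doubles it; the encoder is not part of this file, so treat `last_stride` as an assumption. `decoder_shapes` is a hypothetical helper.

```python
def decoder_shapes(img_size=(192, 640), dims=(32, 64, 160, 256), last_stride=32):
    """Return {stage: (in_channels, out_channels, height, width)} for the decoder convs."""
    h, w = img_size[0] // last_stride, img_size[1] // last_stride
    shapes = {}
    prev_out = 0  # channels concatenated in from the previous (coarser) stage
    for i in range(1, 5):
        in_ch = dims[-i] + prev_out   # skip feature + upsampled previous stage
        out_ch = dims[-i]
        h, w = 2 * h, 2 * w           # each stage upsamples by 2 before its 3x3 conv
        shapes[f"fused_{i}"] = (in_ch, out_ch, h, w)
        prev_out = out_ch
    # Final stage: one more 2x upsample followed by a 1x1 conv, channels unchanged.
    h, w = 2 * h, 2 * w
    shapes["fused_5"] = (dims[-4], dims[-4], h, w)
    return shapes


if __name__ == "__main__":
    for name, shape in decoder_shapes().items():
        print(name, shape)
```

The input channel counts match the `nn.Conv2d` definitions of `fused_1` through `fused_4`. Under the stride assumption, `fused_5` comes out at the full input resolution for the default `img_size`.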
diff --git a/DEST/networks/DEST/__init__.py b/DEST/networks/DEST/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/DEST/networks/DEST/simplified_attention.py b/DEST/networks/DEST/simplified_attention.py
new file mode 100644
index 00000000..f02d9e6c
--- /dev/null
+++ b/DEST/networks/DEST/simplified_attention.py
@@ -0,0 +1,321 @@
+# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import math
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = nn.Conv1d(in_features, hidden_features, 1)
+ self.dwconv = DWConv(hidden_features)
+ self.act = nn.ReLU()
+ self.fc2 = nn.Conv1d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ self.norm1 = nn.BatchNorm1d(hidden_features)
+ self.norm2 = nn.BatchNorm1d(hidden_features)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Conv1d):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ x = self.norm1(x)
+ x = self.dwconv(x, H, W)
+ x = self.norm2(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention_MaxPool(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q = nn.Conv1d(dim, dim, 1, bias=qkv_bias)
+ self.k = nn.Conv1d(dim, dim, 1, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Conv1d(dim, dim, 1)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.BatchNorm1d(dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Conv1d):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, H, W):
+ B, C, N = x.shape
+ q = self.q(x)
+ q = q.reshape(B, self.num_heads, C // self.num_heads, N)
+ q = q.permute(0, 1, 3, 2)
+
+ if self.sr_ratio > 1:
+ x_ = x.reshape(B, C, H, W)
+ x_ = self.sr(x_).reshape(B, C, -1)
+ x_ = self.norm(x_)
+ k = self.k(x_).reshape(B, self.num_heads, C // self.num_heads, -1)
+ else:
+ k = self.k(x).reshape(B, self.num_heads, C // self.num_heads, -1)
+
+ v = torch.mean(x, 2, True).repeat(1, 1, self.num_heads).transpose(-2, -1)
+
+ attn = (q @ k) * self.scale
+ attn, _ = torch.max(attn, -1)
+
+ out = (attn.transpose(-2, -1) @ v)
+ out = out.transpose(-2, -1)
+ out = self.proj(out)
+ return out
+
+class Block(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm, sr_ratio=1):
+ super().__init__()
+ self.norm1 = nn.BatchNorm1d(dim)
+ self.norm2 = nn.BatchNorm1d(dim)
+
+ self.attn = Attention_MaxPool(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+ return x
+
+
+class OverlapPatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=(224,224), patch_size=7, stride=4, in_chans=3, embed_dim=768):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
+
+ self.norm = nn.BatchNorm2d(embed_dim)
+
+ self.H = (img_size[0] - patch_size[0] + 2 * (patch_size[0] // 2)) / stride + 1
+ self.W = (img_size[1] - patch_size[1] + 2 * (patch_size[1] // 2)) / stride + 1
+ self.feat_shape = (int(self.H), int(self.W))
+ self.N = int(self.feat_shape[0] * self.feat_shape[1])
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = self.norm(x)
+ x = x.flatten(2)
+ return x, H, W
+
+
+class SimplifiedTransformer(nn.Module):
+ def __init__(self, img_size=(224,224), patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+ self.embed_dims = embed_dims
+ self.sr_ratios = sr_ratios
+ self.num_layers = depths
+
+ # patch_embed
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
+ embed_dim=embed_dims[0])
+ self.patch_embed2 = OverlapPatchEmbed(img_size=(img_size[0] // 4, img_size[1] // 4), patch_size=3, stride=2, in_chans=embed_dims[0],
+ embed_dim=embed_dims[1])
+ self.patch_embed3 = OverlapPatchEmbed(img_size=(img_size[0] // 8, img_size[1] // 8), patch_size=3, stride=2, in_chans=embed_dims[1],
+ embed_dim=embed_dims[2])
+ self.patch_embed4 = OverlapPatchEmbed(img_size=(img_size[0] // 16, img_size[1] // 16), patch_size=3, stride=2, in_chans=embed_dims[2],
+ embed_dim=embed_dims[3])
+
+ # encoder
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ cur = 0
+ self.block1 = nn.ModuleList([Block(
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[0])
+ for i in range(depths[0])])
+
+ cur += depths[0]
+ self.block2 = nn.ModuleList([Block(
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[1])
+ for i in range(depths[1])])
+
+ cur += depths[1]
+ self.block3 = nn.ModuleList([Block(
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[2])
+ for i in range(depths[2])])
+
+ cur += depths[2]
+ self.block4 = nn.ModuleList([Block(
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[3])
+ for i in range(depths[3])])
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+
+ B = x.shape[0]
+ outs = []
+ ref_feat = {'1': [], '2': [], '3': [], '4': [],}
+
+ # stage 1
+ x, H, W = self.patch_embed1(x)
+ for i, blk in enumerate(self.block1):
+ x = blk(x, H, W)
+ ref_feat['1'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ # stage 2
+ x, H, W = self.patch_embed2(x)
+ for i, blk in enumerate(self.block2):
+ x = blk(x, H, W)
+ ref_feat['2'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ # stage 3
+ x, H, W = self.patch_embed3(x)
+ for i, blk in enumerate(self.block3):
+ x = blk(x, H, W)
+ ref_feat['3'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ # stage 4
+ x, H, W = self.patch_embed4(x)
+ for i, blk in enumerate(self.block4):
+ x = blk(x, H, W)
+ ref_feat['4'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ return outs, ref_feat
+
+ def forward(self, x):
+ x, ref_feat = self.forward_features(x)
+ return x, ref_feat
+
+
+class DWConv(nn.Module):
+ def __init__(self, dim=768):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x, H, W):
+ B, C, N = x.shape
+ x = x.reshape(B, C, H, W)
+ x = self.dwconv(x)
+ x = x.flatten(2)
+ return x
+
diff --git a/DEST/networks/DEST/simplified_joint_attention.py b/DEST/networks/DEST/simplified_joint_attention.py
new file mode 100644
index 00000000..680c2b18
--- /dev/null
+++ b/DEST/networks/DEST/simplified_joint_attention.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath, trunc_normal_
+
+import math
+
+from packnet_sfm.networks.DEST.simplified_attention import OverlapPatchEmbed, Mlp, Attention_MaxPool
+
+
+class Attention_Joint_MaxPool(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
+
+ self.dim = dim
+
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q = nn.Conv1d(dim, dim, 1, bias=qkv_bias)
+ self.k = nn.Conv1d(dim, dim, 1, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Conv1d(dim, dim, 1)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.norm2 = nn.BatchNorm1d(self.dim)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.BatchNorm1d(dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Conv1d):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, y, H, W):
+ B, C, N = x.shape
+
+ q = self.q(x)
+ q = q.reshape(B, self.num_heads, C // self.num_heads, N)
+ q = q.permute(0, 1, 3, 2)
+
+ if self.sr_ratio > 1:
+ x_ = x.reshape(B, C, H, W)
+ x_ = self.sr(x_).reshape(B, C, -1)
+ x_ = self.norm(x_)
+ k = self.k(x_).reshape(B, self.num_heads, C // self.num_heads, -1)
+ else:
+ k = self.k(x).reshape(B, self.num_heads, C // self.num_heads, -1)
+
+ v = torch.mean(x, 2, True).repeat(1, 1, self.num_heads).transpose(-2, -1)
+
+
+ attn = (q @ k) * self.scale
+
+ attn, _ = torch.max(attn, -1)
+
+ out = (attn.transpose(-2, -1) @ v)
+ out = out.transpose(-2, -1)
+
+ out = self.proj(out)
+
+ return out
+
+
+class JointBlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm, sr_ratio=1):
+ super().__init__()
+ self.norm0 = nn.BatchNorm1d(dim)
+ self.norm1_ref = nn.BatchNorm1d(dim)
+ self.norm1_src = nn.BatchNorm1d(dim)
+ self.norm2 = nn.BatchNorm1d(dim)
+
+ self.attn_joint = Attention_Joint_MaxPool(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+
+ def forward(self, ref_feat, src_feat, H, W):
+ src_feat = src_feat + self.drop_path(self.attn_joint(self.norm1_ref(ref_feat), self.norm1_src(src_feat), H, W))
+ src_feat = src_feat + self.drop_path(self.mlp(self.norm2(src_feat), H, W))
+ return src_feat
+
+
+
+class SimplifiedJointTransformer(nn.Module):
+ def __init__(self, img_size=(224,224), patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+ self.embed_dims = embed_dims
+ self.sr_ratios = sr_ratios
+
+ # patch_embed
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
+ embed_dim=embed_dims[0])
+ self.patch_embed2 = OverlapPatchEmbed(img_size=(img_size[0] // 4, img_size[1] // 4), patch_size=3, stride=2, in_chans=embed_dims[0],
+ embed_dim=embed_dims[1])
+ self.patch_embed3 = OverlapPatchEmbed(img_size=(img_size[0] // 8, img_size[1] // 8), patch_size=3, stride=2, in_chans=embed_dims[1],
+ embed_dim=embed_dims[2])
+ self.patch_embed4 = OverlapPatchEmbed(img_size=(img_size[0] // 16, img_size[1] // 16), patch_size=3, stride=2, in_chans=embed_dims[2],
+ embed_dim=embed_dims[3])
+
+ # encoder
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ cur = 0
+ self.block1 = nn.ModuleList([JointBlock(
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[0])
+ for i in range(depths[0])])
+ self.norm1 = nn.BatchNorm1d(self.patch_embed1.N)
+
+
+ cur += depths[0]
+ self.block2 = nn.ModuleList([JointBlock(
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[1])
+ for i in range(depths[1])])
+ self.norm2 = nn.BatchNorm1d(self.patch_embed2.N)
+
+ cur += depths[1]
+ self.block3 = nn.ModuleList([JointBlock(
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[2])
+ for i in range(depths[2])])
+ self.norm3 = nn.BatchNorm1d(self.patch_embed3.N)
+
+ cur += depths[2]
+ self.block4 = nn.ModuleList([JointBlock(
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[3])
+ for i in range(depths[3])])
+ self.norm4 = nn.BatchNorm1d(self.patch_embed4.N)
+
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+
+ def reset_drop_path(self, drop_path_rate):
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
+ cur = 0
+ for i in range(self.depths[0]):
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[0]
+ for i in range(self.depths[1]):
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[1]
+ for i in range(self.depths[2]):
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[2]
+ for i in range(self.depths[3]):
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
+
+ def freeze_patch_emb(self):
+ self.patch_embed1.requires_grad_(False)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, ref_feat, x):
+ B = x.shape[0]
+
+ # stage 1
+ x, H, W = self.patch_embed1(x)
+ for i, blk in enumerate(self.block1):
+ if i > len(ref_feat['1']) - 1: i = -1
+ x = blk(ref_feat['1'][i], x, H, W)
+ x = self.norm1(x.transpose(-2, -1)).transpose(-2, -1)
+ x = x.reshape(B, -1, H, W).contiguous()
+
+ # stage 2
+ x, H, W = self.patch_embed2(x)
+ for i, blk in enumerate(self.block2):
+ if i > len(ref_feat['2']) - 1: i = -1
+ x = blk(ref_feat['2'][i], x, H, W)
+ x = self.norm2(x.transpose(-2, -1)).transpose(-2, -1)
+ x = x.reshape(B, -1, H, W).contiguous()
+
+ # stage 3
+ x, H, W = self.patch_embed3(x)
+ for i, blk in enumerate(self.block3):
+ if i > len(ref_feat['3']) - 1: i = -1
+ x = blk(ref_feat['3'][i], x, H, W)
+ x = self.norm3(x.transpose(-2, -1)).transpose(-2, -1)
+ x = x.reshape(B, -1, H, W).contiguous()
+
+ # stage 4
+ x, H, W = self.patch_embed4(x)
+ for i, blk in enumerate(self.block4):
+ if i > len(ref_feat['4']) - 1: i = -1
+ x = blk(ref_feat['4'][i], x, H, W)
+ x = self.norm4(x.transpose(-2, -1)).transpose(-2, -1)
+ x = x.reshape(B, -1, H, W).contiguous()
+ return x
+
+ def forward(self, ref_feat, x):
+ return self.forward_features(ref_feat, x)
+
diff --git a/DEST/networks/DESTNet.py b/DEST/networks/DESTNet.py
new file mode 100644
index 00000000..71f4b603
--- /dev/null
+++ b/DEST/networks/DESTNet.py
@@ -0,0 +1,138 @@
+import math
+import torch
+import torch.nn as nn
+from packnet_sfm.networks.DEST.DEST_EncDec import DEST_Pose, SimpleTR_B0, SimpleTR_B1, SimpleTR_B2, SimpleTR_B3, SimpleTR_B4, SimpleTR_B5
+
+
+class InvDepth(nn.Module):
+ """Inverse depth layer"""
+ def __init__(self, in_channels, out_channels=1, min_depth=0.5):
+ """
+ Initializes an InvDepth object.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ out_channels : int
+ Number of output channels
+ min_depth : float
+ Minimum depth value to calculate
+ """
+ super().__init__()
+ self.min_depth = min_depth
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1)
+ self.pad = nn.ConstantPad2d([1] * 4, value=0)
+ self.activ = nn.Sigmoid()
+ self.scale = torch.tensor([self.min_depth])
+ def forward(self, x):
+ """Runs the InvDepth layer."""
+ x = self.conv1(self.pad(x))
+ return self.activ(x) / self.min_depth
+
+
+class DESTNet(nn.Module):
+ def __init__(self, model='B3', nb_ref_imgs=2, img_size=(192, 640)):
+ """
+ Builds the DEST depth and pose networks for the selected model size
+
+ Parameters
+ ----------
+ model : string
+ The size of DEST can be selected: 'B0' | 'B1' | 'B2' | 'B3' | 'B4' | 'B5'
+ nb_ref_imgs : int
+ Number of reference images for Pose-Net
+ img_size : tuple
+ Input image size (H, W)
+ """
+ super().__init__()
+ self.nb_ref_imgs = nb_ref_imgs
+ self.connectivity = True
+
+ if model == 'B0':
+ self.num_out_ch, self.dest = SimpleTR_B0(img_size=img_size)
+ elif model == 'B1':
+ self.num_out_ch, self.dest = SimpleTR_B1(img_size=img_size)
+ elif model == 'B2':
+ self.num_out_ch, self.dest = SimpleTR_B2(img_size=img_size)
+ elif model == 'B3':
+ self.num_out_ch, self.dest = SimpleTR_B3(img_size=img_size)
+ elif model == 'B4':
+ self.num_out_ch, self.dest = SimpleTR_B4(img_size=img_size)
+ elif model == 'B5':
+ self.num_out_ch, self.dest = SimpleTR_B5(img_size=img_size)
+
+ self.disp1_layer = InvDepth(self.dest.dims[-4])
+ self.disp2_layer = InvDepth(self.dest.dims[-4])
+ self.disp3_layer = InvDepth(self.dest.dims[-3])
+ self.disp4_layer = InvDepth(self.dest.dims[-2])
+
+ num_out_ch, self.dest_pose = DEST_Pose(dims=self.dest.dest_encoder.embed_dims, channels=16,
+ num_layers=self.dest.dest_encoder.depths,
+ reduction_ratio=self.dest.dest_encoder.sr_ratios,
+ connectivity=self.connectivity)
+
+ self.pose_pred = nn.Sequential(nn.Conv2d(self.dest.dest_encoder.embed_dims[3], 6 * self.nb_ref_imgs, kernel_size=1, padding=0))
+ self.channel_reduction_pose = nn.Sequential(nn.Conv2d(3 * (1 + self.nb_ref_imgs), 16, kernel_size=3, padding=0),
+ nn.BatchNorm2d(16),
+ nn.Tanh())
+
+ def measure_Complexity(self, input_size=(3, 192, 640), mode='Depth'):
+ input_shape = input_size
+
+ if mode == 'Depth':
+ model = Dummy_net_depth(self.dest, self.disp1_layer).eval()
+ macs, params = get_model_complexity_info(model.cpu(),
+ input_shape, as_strings=True,
+ print_per_layer_stat=False, verbose=False)
+ print('{:<30} {:<8}'.format('%sNet Computational complexity: ' % mode, macs))
+ print('{:<30} {:<8}'.format('%sNet Number of parameters: ' % mode, params))
+
+
+ def forward(self, rgb):
+ """
+ Runs the network and returns inverse depth maps
+ (4 scales if training and 1 if not).
+ """
+ out, _, self.ref_feat = self.dest(rgb)
+
+ x = self.disp1_layer(out[0])
+ if self.training:
+ x2 = self.disp2_layer(out[1])
+ x3 = self.disp3_layer(out[2])
+ x4 = self.disp4_layer(out[3])
+
+ if self.training:
+ return {'inv_depths': [x, x2, x3, x4]}
+ else:
+ return {'inv_depths': x}
+
+ def pose(self, image, context):
+ assert (len(context) == self.nb_ref_imgs)
+ input_ = [image]
+ input_.extend(context)
+ input_ = torch.cat(input_, 1)
+
+ return self._poseNet(input_)
+
+ def _poseNet(self, x_src):
+ x_src = self.channel_reduction_pose(x_src)
+
+ if self.connectivity:
+ x = self.dest_pose(self.ref_feat, x_src)
+ else:
+ x = self.dest_pose(x_src)
+
+ pose = self.pose_pred(x)
+ pose = pose.mean(3).mean(2)
+ pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6)
+ return pose
+
+
+ def reshape(self, ref_featss, b_):
+ for i, e_dim in enumerate(self.dest.dest_encoder.embed_dims):
+ ref_featss[i] = ref_featss[i].reshape(b_, -1, e_dim).repeat((1+self.nb_ref_imgs), 1, 1)
+ return ref_featss
+
+
+
diff --git a/DEST/networks/__init__.py b/DEST/networks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/DEST/resources/attentions.png b/DEST/resources/attentions.png
new file mode 100644
index 00000000..5dd8814b
Binary files /dev/null and b/DEST/resources/attentions.png differ
diff --git a/DEST/semseg/cityscapes_1024x1024_repeat.py b/DEST/semseg/cityscapes_1024x1024_repeat.py
new file mode 100644
index 00000000..2f5ec4f6
--- /dev/null
+++ b/DEST/semseg/cityscapes_1024x1024_repeat.py
@@ -0,0 +1,57 @@
+# dataset settings
+dataset_type = 'CityscapesDataset'
+data_root = 'data/cityscapes/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (1024, 1024)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 1024),
+ # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type='RepeatDataset',
+ times=500,
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/train',
+ ann_dir='gtFine/train',
+ pipeline=train_pipeline)),
+ val=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ data_root=data_root,
+ img_dir='leftImg8bit/val',
+ ann_dir='gtFine/val',
+ pipeline=test_pipeline))
\ No newline at end of file
diff --git a/DEST/semseg/dest_head.py b/DEST/semseg/dest_head.py
new file mode 100644
index 00000000..b648a5b0
--- /dev/null
+++ b/DEST/semseg/dest_head.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+from mmseg.models.builder import HEADS
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class DestHead(BaseDecodeHead):
+ def __init__(self, segm=True, **kwargs):
+ super().__init__(input_transform='multiple_select', **kwargs)
+
+ num_inputs = len(self.in_channels)
+ assert num_inputs == len(self.in_index)
+
+ self.fuse_conv1 = nn.Sequential(nn.Conv2d(self.in_channels[-1], self.in_channels[-1], 1), nn.ReLU(inplace=True))
+ self.fuse_conv2 = nn.Sequential(nn.Conv2d(self.in_channels[-2], self.in_channels[-2], 1), nn.ReLU(inplace=True))
+ self.fuse_conv3 = nn.Sequential(nn.Conv2d(self.in_channels[-3], self.in_channels[-3], 1), nn.ReLU(inplace=True))
+ self.fuse_conv4 = nn.Sequential(nn.Conv2d(self.in_channels[-4], self.in_channels[-4], 1), nn.ReLU(inplace=True))
+
+ self.upsample = nn.ModuleList([nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'))]*len(self.in_channels))
+
+ self.fused_1 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(self.in_channels[-1], self.in_channels[-1], 3), nn.ReLU(inplace=True))
+ self.fused_2 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(self.in_channels[-2] + self.in_channels[-1], self.in_channels[-2], 3), nn.ReLU(inplace=True))
+ self.fused_3 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(self.in_channels[-3] + self.in_channels[-2], self.in_channels[-3], 3), nn.ReLU(inplace=True))
+ self.fused_4 = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(self.in_channels[-4] + self.in_channels[-3], self.in_channels[-4], 3), nn.ReLU(inplace=True))
+
+ self.conv_seg = nn.Conv2d(self.in_channels[-4], self.num_classes, kernel_size=1)
+
+ def dest_decoder(self, lay_out):
+ lay_out = lay_out[0]
+ fused_1 = self.fuse_conv1(lay_out[-1])
+ fused_1 = self.upsample[-1](fused_1)
+ fused_1 = self.fused_1(fused_1)
+ fused_2 = torch.cat([fused_1, self.fuse_conv2(lay_out[-2])], 1)
+
+ fused_2 = self.upsample[-2](fused_2)
+ fused_2 = self.fused_2(fused_2)
+ fused_3 = torch.cat([fused_2, self.fuse_conv3(lay_out[-3])], 1)
+
+ fused_3 = self.upsample[-3](fused_3)
+ fused_3 = self.fused_3(fused_3)
+ fused_4 = torch.cat([fused_3, self.fuse_conv4(lay_out[-4])], 1)
+
+ fused_4 = self.upsample[-4](fused_4)
+ fused_4 = self.fused_4(fused_4)
+
+ return self.conv_seg(fused_4)
+
+ def forward(self, x):
+ return self.dest_decoder(x)
+
+
diff --git a/DEST/semseg/dest_simpatt-b0.py b/DEST/semseg/dest_simpatt-b0.py
new file mode 100644
index 00000000..41169475
--- /dev/null
+++ b/DEST/semseg/dest_simpatt-b0.py
@@ -0,0 +1,36 @@
+# model settings
+embed_dims = [32, 64, 160, 256]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='SimplifiedTransformer',
+ img_size=(1024, 1024),
+ in_chans=3,
+ num_classes=19,
+ embed_dims=embed_dims,
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ depths=[2, 2, 2, 2],
+ sr_ratios=[8, 4, 2, 1]),
+ decode_head=dict(
+ type='DestHead',
+ in_channels=[32, 64, 160, 256],
+ in_index=[0, 1, 2, 3],
+ channels=256,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ # test_cfg=dict(mode='whole'))
+ test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
diff --git a/DEST/semseg/dest_simpatt-b0_1024x1024_160k_cityscapes.py b/DEST/semseg/dest_simpatt-b0_1024x1024_160k_cityscapes.py
new file mode 100644
index 00000000..59edffa3
--- /dev/null
+++ b/DEST/semseg/dest_simpatt-b0_1024x1024_160k_cityscapes.py
@@ -0,0 +1,76 @@
+# _base_ = [
+# '../_base_/models/dest_simplemit-b0.py',
+# '../_base_/datasets/cityscapes_1024x1024.py',
+# '../_base_/default_runtime.py',
+# '../_base_/schedules/schedule_160k.py'
+# ]
+
+
+_base_ = [
+ '../_base_/models/dest_simpatt-b0.py',
+ '../_base_/datasets/cityscapes_1024x1024_repeat.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k_adamw.py'
+]
+
+
+
+evaluation = dict(interval=4000, metric='mIoU')
+# data = dict(samples_per_gpu=1)
+checkpoint_config = dict(by_epoch=False, interval=20000)
+
+
+# optimizer
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=1.)
+ }))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+embed_dims = [32, 64, 160, 256]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='SimplifiedTransformer',
+ img_size=(1024, 1024),  # doesn't matter
+ in_chans=3,
+ num_classes=19,
+ embed_dims=embed_dims,
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ depths=[2, 2, 2, 2],
+ sr_ratios=[8, 4, 2, 1]),
+ decode_head=dict(
+ type='DestHead',
+ in_channels=embed_dims,
+ in_index=[0, 1, 2, 3],
+ channels=256,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=True,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ # test_cfg=dict(mode='whole'))
+ test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
+
+
+
+
diff --git a/DEST/semseg/dest_simpatt-b1_1024x1024_160k_cityscapes.py b/DEST/semseg/dest_simpatt-b1_1024x1024_160k_cityscapes.py
new file mode 100644
index 00000000..04be2510
--- /dev/null
+++ b/DEST/semseg/dest_simpatt-b1_1024x1024_160k_cityscapes.py
@@ -0,0 +1,76 @@
+# _base_ = [
+# '../_base_/models/dest_simplemit-b0.py',
+# '../_base_/datasets/cityscapes_1024x1024.py',
+# '../_base_/default_runtime.py',
+# '../_base_/schedules/schedule_160k.py'
+# ]
+
+
+_base_ = [
+ '../_base_/models/dest_simpatt-b0.py',
+ '../_base_/datasets/cityscapes_1024x1024_repeat.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k_adamw.py'
+]
+
+
+
+evaluation = dict(interval=4000, metric='mIoU')
+# data = dict(samples_per_gpu=1)
+checkpoint_config = dict(by_epoch=False, interval=20000)
+
+
+# optimizer
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=1.)
+ }))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+embed_dims = [64, 128, 250, 320]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='SimplifiedTransformer',
+ img_size=(1024, 1024),  # nominal only: the patch embeddings read H and W from the input tensor
+ in_chans=3,
+ num_classes=19,
+ embed_dims=embed_dims,
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ depths=[2, 2, 2, 2],
+ sr_ratios=[8, 4, 2, 1]),
+ decode_head=dict(
+ type='DestHead',
+ in_channels=embed_dims,
+ in_index=[0, 1, 2, 3],
+ channels=256,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=True,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ # test_cfg=dict(mode='whole'))
+ test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
+
+
+
+
diff --git a/DEST/semseg/dest_simpatt-b2_1024x1024_160k_cityscapes.py b/DEST/semseg/dest_simpatt-b2_1024x1024_160k_cityscapes.py
new file mode 100644
index 00000000..c74b11fb
--- /dev/null
+++ b/DEST/semseg/dest_simpatt-b2_1024x1024_160k_cityscapes.py
@@ -0,0 +1,76 @@
+# _base_ = [
+# '../_base_/models/dest_simplemit-b0.py',
+# '../_base_/datasets/cityscapes_1024x1024.py',
+# '../_base_/default_runtime.py',
+# '../_base_/schedules/schedule_160k.py'
+# ]
+
+
+_base_ = [
+ '../_base_/models/dest_simpatt-b0.py',
+ '../_base_/datasets/cityscapes_1024x1024_repeat.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k_adamw.py'
+]
+
+
+
+evaluation = dict(interval=4000, metric='mIoU')
+# data = dict(samples_per_gpu=1)
+checkpoint_config = dict(by_epoch=False, interval=20000)
+
+
+# optimizer
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=1.)
+ }))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+embed_dims = [64, 128, 250, 320]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='SimplifiedTransformer',
+ img_size=(1024, 1024),  # nominal only: the patch embeddings read H and W from the input tensor
+ in_chans=3,
+ num_classes=19,
+ embed_dims=embed_dims,
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ depths=[3, 3, 6, 3],
+ sr_ratios=[8, 4, 2, 1]),
+ decode_head=dict(
+ type='DestHead',
+ in_channels=embed_dims,
+ in_index=[0, 1, 2, 3],
+ channels=256,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=True,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ # test_cfg=dict(mode='whole'))
+ test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
+
+
+
+
diff --git a/DEST/semseg/dest_simpatt-b3_1024x1024_160k_cityscapes.py b/DEST/semseg/dest_simpatt-b3_1024x1024_160k_cityscapes.py
new file mode 100644
index 00000000..b5c65340
--- /dev/null
+++ b/DEST/semseg/dest_simpatt-b3_1024x1024_160k_cityscapes.py
@@ -0,0 +1,77 @@
+# _base_ = [
+# '../_base_/models/dest_simplemit-b0.py',
+# '../_base_/datasets/cityscapes_1024x1024.py',
+# '../_base_/default_runtime.py',
+# '../_base_/schedules/schedule_160k.py'
+# ]
+
+
+_base_ = [
+ '../_base_/models/dest_simpatt-b0.py',
+ '../_base_/datasets/cityscapes_1024x1024_repeat.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k_adamw.py'
+]
+
+
+
+evaluation = dict(interval=1000, metric='mIoU')
+data = dict(samples_per_gpu=4)
+checkpoint_config = dict(by_epoch=False, interval=20000)
+
+
+# optimizer
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=1.)
+ }))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+embed_dims = [64, 128, 250, 320]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='SimplifiedTransformer',
+ img_size=(1024, 1024),
+ in_chans=3,
+ num_classes=19,
+ embed_dims=embed_dims,
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ depths=[3, 6, 8, 3],
+ sr_ratios=[8, 4, 2, 1]),
+ decode_head=dict(
+ type='DestHead',
+ in_channels=embed_dims,
+ in_index=[0, 1, 2, 3],
+ channels=512, #decoder param
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=True,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ # test_cfg=dict(mode='whole'))
+ test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
+
+
+
+
+
diff --git a/DEST/semseg/dest_simpatt-b4_1024x1024_160k_cityscapes.py b/DEST/semseg/dest_simpatt-b4_1024x1024_160k_cityscapes.py
new file mode 100644
index 00000000..767711ea
--- /dev/null
+++ b/DEST/semseg/dest_simpatt-b4_1024x1024_160k_cityscapes.py
@@ -0,0 +1,77 @@
+# _base_ = [
+# '../_base_/models/dest_simplemit-b0.py',
+# '../_base_/datasets/cityscapes_1024x1024.py',
+# '../_base_/default_runtime.py',
+# '../_base_/schedules/schedule_160k.py'
+# ]
+
+
+_base_ = [
+ '../_base_/models/dest_simpatt-b0.py',
+ '../_base_/datasets/cityscapes_1024x1024_repeat.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k_adamw.py'
+]
+
+
+
+evaluation = dict(interval=1000, metric='mIoU')
+data = dict(samples_per_gpu=4)
+checkpoint_config = dict(by_epoch=False, interval=20000)
+
+
+# optimizer
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=1.)
+ }))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+embed_dims = [64, 128, 250, 320]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='SimplifiedTransformer',
+ img_size=(1024, 1024),
+ in_chans=3,
+ num_classes=19,
+ embed_dims=embed_dims,
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ depths=[3, 6, 8, 3],
+ sr_ratios=[8, 4, 2, 1]),
+ decode_head=dict(
+ type='DestHead',
+ in_channels=embed_dims,
+ in_index=[0, 1, 2, 3],
+ channels=512, #decoder param
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=True,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ # test_cfg=dict(mode='whole'))
+ test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
+
+
+
+
+
diff --git a/DEST/semseg/dest_simpatt-b5_1024x1024_160k_cityscapes.py b/DEST/semseg/dest_simpatt-b5_1024x1024_160k_cityscapes.py
new file mode 100644
index 00000000..c5a23ae0
--- /dev/null
+++ b/DEST/semseg/dest_simpatt-b5_1024x1024_160k_cityscapes.py
@@ -0,0 +1,77 @@
+# _base_ = [
+# '../_base_/models/dest_simplemit-b0.py',
+# '../_base_/datasets/cityscapes_1024x1024.py',
+# '../_base_/default_runtime.py',
+# '../_base_/schedules/schedule_160k.py'
+# ]
+
+
+_base_ = [
+ '../_base_/models/dest_simpatt-b0.py',
+ '../_base_/datasets/cityscapes_1024x1024_repeat.py',
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_160k_adamw.py'
+]
+
+
+
+evaluation = dict(interval=1000, metric='mIoU')
+data = dict(samples_per_gpu=3)
+checkpoint_config = dict(by_epoch=False, interval=20000)
+
+
+# optimizer
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+ paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=1.)
+ }))
+
+lr_config = dict(_delete_=True, policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0, min_lr=0.0, by_epoch=False)
+
+embed_dims = [64, 128, 250, 320]
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='SimplifiedTransformer',
+ img_size=(1024, 1024),
+ in_chans=3,
+ num_classes=19,
+ embed_dims=embed_dims,
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ depths=[3, 10, 16, 5],
+ sr_ratios=[8, 4, 2, 1]),
+ decode_head=dict(
+ type='DestHead',
+ in_channels=embed_dims,
+ in_index=[0, 1, 2, 3],
+ channels=512, #decoder param
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=True,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ # test_cfg=dict(mode='whole'))
+ test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
+
+
+
+
+
diff --git a/DEST/semseg/schedule_160k_adamw.py b/DEST/semseg/schedule_160k_adamw.py
new file mode 100644
index 00000000..3d45739f
--- /dev/null
+++ b/DEST/semseg/schedule_160k_adamw.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='AdamW', lr=0.0002, weight_decay=0.0001)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=160000)
+checkpoint_config = dict(by_epoch=False, interval=4000)
+evaluation = dict(interval=4000, metric='mIoU')
\ No newline at end of file
diff --git a/DEST/semseg/simplified_attention_mmseg.py b/DEST/semseg/simplified_attention_mmseg.py
new file mode 100644
index 00000000..02f69203
--- /dev/null
+++ b/DEST/semseg/simplified_attention_mmseg.py
@@ -0,0 +1,378 @@
+# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import math
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from mmcv.cnn import build_norm_layer
+from mmcv.runner import BaseModule
+from ..builder import BACKBONES
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0., sync_norm=True):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = nn.Conv1d(in_features, hidden_features, 1)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()  # honor the act_layer argument (defaults to nn.ReLU)
+ self.fc2 = nn.Conv1d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ if sync_norm :
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, hidden_features, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, hidden_features, postfix=2)
+
+ self.add_module(self.norm1_name, norm1)
+ self.add_module(self.norm2_name, norm2)
+ else :
+ self.norm1 = nn.BatchNorm1d(hidden_features)
+ self.norm2 = nn.BatchNorm1d(hidden_features)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Conv1d):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ x = self.norm1(x)
+ x = self.dwconv(x, H, W)
+ x = self.norm2(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention_MaxPool(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, sync_norm=True):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q = nn.Conv1d(dim, dim, 1, bias=qkv_bias)
+ self.k = nn.Conv1d(dim, dim, 1, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Conv1d(dim, dim, 1)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ if sync_norm :
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, dim, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ else :
+ self.norm1 = nn.BatchNorm1d(dim)
+ self.apply(self._init_weights)
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer applied after the spatial-reduction convolution"""
+ return getattr(self, self.norm1_name)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Conv1d):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, H, W):
+ B, C, N = x.shape
+ q = self.q(x)
+ q = q.reshape(B, self.num_heads, C // self.num_heads, N)
+ q = q.permute(0, 1, 3, 2)
+
+ if self.sr_ratio > 1:
+ x_ = x.reshape(B, C, H, W)
+ x_ = self.sr(x_).reshape(B, C, -1)
+ x_ = self.norm1(x_)
+ k = self.k(x_).reshape(B, self.num_heads, C // self.num_heads, -1)
+ else:
+ k = self.k(x).reshape(B, self.num_heads, C // self.num_heads, -1)
+
+ v = torch.mean(x, 2, True).repeat(1, 1, self.num_heads).transpose(-2, -1)
+
+ attn = (q @ k) * self.scale
+ attn, _ = torch.max(attn, -1)
+
+ out = (attn.transpose(-2, -1) @ v)
+ out = out.transpose(-2, -1)
+
+ out = self.proj(out)
+
+ return out
+
+
+
+class Block(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm, sr_ratio=1, sync_norm=True):
+ super().__init__()
+ if sync_norm :
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, dim, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, dim, postfix=2)
+ self.add_module(self.norm1_name, norm1)
+ self.add_module(self.norm2_name, norm2)
+ else :
+ self.norm1 = nn.BatchNorm1d(dim)
+ self.norm2 = nn.BatchNorm1d(dim)
+
+ self.attn = Attention_MaxPool(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, sync_norm=sync_norm)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, sync_norm=sync_norm)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+ return x
+
+
+class OverlapPatchEmbed(BaseModule):
+ """ Image to Patch Embedding """
+ def __init__(self, img_size=(224,224), patch_size=7, stride=4, in_chans=3, embed_dim=768, sync_norm=True):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
+
+ self.H = (img_size[0] - patch_size[0] + 2 * (patch_size[0] // 2)) / stride + 1
+ self.W = (img_size[1] - patch_size[1] + 2 * (patch_size[1] // 2)) / stride + 1
+ self.feat_shape = (int(self.H), int(self.W))
+ self.N = int(self.feat_shape[0] * self.feat_shape[1])
+
+ if sync_norm:
+ norm_cfg = dict(type='SyncBN', requires_grad=True)
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, embed_dim, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ else :
+ self.norm1 = nn.BatchNorm2d(embed_dim)
+
+ self.apply(self._init_weights)
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = self.norm1(x)
+ x = x.flatten(2)
+ return x, H, W
+
+
+@BACKBONES.register_module()
+class SimplifiedTransformer(nn.Module):
+ def __init__(self, img_size=(224,224), patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], sync_norm=True
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+ self.embed_dims = embed_dims
+ self.sr_ratios = sr_ratios
+ self.num_layers = depths
+
+ # patch_embed
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
+ embed_dim=embed_dims[0], sync_norm=sync_norm)
+ self.patch_embed2 = OverlapPatchEmbed(img_size=(img_size[0] // 4, img_size[1] // 4), patch_size=3, stride=2, in_chans=embed_dims[0],
+ embed_dim=embed_dims[1], sync_norm=sync_norm)
+ self.patch_embed3 = OverlapPatchEmbed(img_size=(img_size[0] // 8, img_size[1] // 8), patch_size=3, stride=2, in_chans=embed_dims[1],
+ embed_dim=embed_dims[2], sync_norm=sync_norm)
+ self.patch_embed4 = OverlapPatchEmbed(img_size=(img_size[0] // 16, img_size[1] // 16), patch_size=3, stride=2, in_chans=embed_dims[2],
+ embed_dim=embed_dims[3], sync_norm=sync_norm)
+
+ # encoder
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ cur = 0
+ self.block1 = nn.ModuleList([Block(
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[0], sync_norm=sync_norm)
+ for i in range(depths[0])])
+
+ cur += depths[0]
+ self.block2 = nn.ModuleList([Block(
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[1], sync_norm=sync_norm)
+ for i in range(depths[1])])
+
+ cur += depths[1]
+ self.block3 = nn.ModuleList([Block(
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[2], sync_norm=sync_norm)
+ for i in range(depths[2])])
+
+ cur += depths[2]
+ self.block4 = nn.ModuleList([Block(
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[3], sync_norm=sync_norm)
+ for i in range(depths[3])])
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+
+ B = x.shape[0]
+ outs = []
+ ref_feat = {'1': [], '2': [], '3': [], '4': [],}
+
+ # stage 1
+ x, H, W = self.patch_embed1(x)
+ for i, blk in enumerate(self.block1):
+ x = blk(x, H, W)
+ ref_feat['1'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ # stage 2
+ x, H, W = self.patch_embed2(x)
+ for i, blk in enumerate(self.block2):
+ x = blk(x, H, W)
+ ref_feat['2'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ # stage 3
+ x, H, W = self.patch_embed3(x)
+ for i, blk in enumerate(self.block3):
+ x = blk(x, H, W)
+ ref_feat['3'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ # stage 4
+ x, H, W = self.patch_embed4(x)
+ for i, blk in enumerate(self.block4):
+ x = blk(x, H, W)
+ ref_feat['4'].append(x)
+ x = x.reshape(B, -1, H, W).contiguous()
+ outs.append(x)
+
+ return outs, ref_feat
+
+ def forward(self, x):
+ x, ref_feat = self.forward_features(x)
+ return x, ref_feat
+
+
+class DWConv(nn.Module):
+ def __init__(self, dim=768):
+ super(DWConv, self).__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x, H, W):
+ B, C, N = x.shape
+ x = x.reshape(B, C, H, W)
+ x = self.dwconv(x)
+ x = x.flatten(2)
+ return x
+
diff --git a/LeNetWithS3Pooling/training/requirements.txt b/LeNetWithS3Pooling/training/requirements.txt
index ed194c68..56dc13c5 100644
--- a/LeNetWithS3Pooling/training/requirements.txt
+++ b/LeNetWithS3Pooling/training/requirements.txt
@@ -1,6 +1,6 @@
-numpy==1.16.5
+numpy==1.22.0
Pillow>=6.2.2
-protobuf==3.10.0
+protobuf==3.18.3
six==1.12.0
summary==0.2.0
termcolor==1.1.0