transmotion package#
Submodules#
transmotion.blocks module#
Building blocks with trainable parameters
- class transmotion.blocks.AntiAliasInterpolation2d(dim: int, scale: float)[source]#
Bases:
torch.nn.modules.module.Module
You can get more info on anti-aliasing for image resizing here: - [The dangers behind image resizing](https://blog.zuru.tech/machine-learning/2021/08/09/the-dangers-behind-image-resizing)
- class transmotion.blocks.ImagePyramid(scales: Sequence[float], in_dim: int)[source]#
Bases:
torch.nn.modules.module.Module
Create an image pyramid for computing the pyramid perceptual loss. See Sec. 3.3
- forward(fmap: torch.Tensor) List[torch.Tensor] [source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
transmotion.configs module#
- class transmotion.configs.BackgroundLoss(loss_weight: float)[source]#
Bases:
object
- loss_weight: float#
- class transmotion.configs.DataLoadingConfig(batch_size: int, num_workers: int)[source]#
Bases:
object
- batch_size: int#
- num_workers: int#
- class transmotion.configs.DenseMotionConf(base_dim: int, tps: transmotion.configs.TPSConfig, num_blocks: int, in_features: int, max_features: int, kp_variance: float = 0.01, scale_factor: float = 0.25, num_occlusion_masks: int = 4)[source]#
Bases:
object
- base_dim: int#
Dimension that the encoder and decoder networks start from. This dim is multiplied in block i by a factor of \(2^i\)
- in_features: int#
Input dimension to the dense motion network. The input is usually an image, so this value is 3
- kp_variance: float = 0.01#
Variance used in the Gaussian heatmap representation
- max_features: int#
- num_blocks: int#
- num_occlusion_masks: int = 4#
- scale_factor: float = 0.25#
- class transmotion.configs.DropoutConfig(start_epoch: int, init_prob: float, max_prob: float, prob_inc_epochs: int = 10)[source]#
Bases:
object
Drop-out is applied during training to the DenseMotion network. It's not a property of the network but an external parameter to the forward pass, so its config lives in the general train config and not in the dense-motion config. A sketch of such a schedule follows the field list below.
- Parameters
prob_inc_epochs – Number of epochs over which the dropout probability is incremented
- init_prob: float#
- max_prob: float#
- prob_inc_epochs: int = 10#
- start_epoch: int#
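A minimal sketch of such a ramp, assuming a linear increase from init_prob to max_prob over prob_inc_epochs starting at start_epoch; the helper below is hypothetical, the actual schedule is implemented by the training loop:

```python
from transmotion.configs import DropoutConfig

def dropout_prob_for_epoch(cfg: DropoutConfig, epoch: int) -> float:
    # Hypothetical helper: linearly ramp the dropout probability.
    if epoch < cfg.start_epoch:
        return 0.0
    # Fraction of the ramp completed so far, clamped to [0, 1].
    t = min((epoch - cfg.start_epoch) / max(cfg.prob_inc_epochs, 1), 1.0)
    return cfg.init_prob + t * (cfg.max_prob - cfg.init_prob)

cfg = DropoutConfig(start_epoch=0, init_prob=0.0, max_prob=0.3, prob_inc_epochs=10)
print([round(dropout_prob_for_epoch(cfg, e), 2) for e in range(12)])
```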
- class transmotion.configs.EquivarianceLoss(sigma_tps: float, sigma_affine: float, points_per_tps: int, loss_weight: float)[source]#
Bases:
object
This loss requires a random TPS transformation; the parameters of that transformation are specified here.
- loss_weight: float#
- points_per_tps: int#
number of control points to use for the TPS
- sigma_affine: float#
standard deviation for affine transform params (generated randomly)
- sigma_tps: float#
standard deviation for TPS kernel mixing params (generated randomly)
- class transmotion.configs.InpaintingConfig(base_dim: int, in_features: int, num_down_blocks: int, num_occlusion_masks: int, max_features: int)[source]#
Bases:
object
- base_dim: int#
- in_features: int#
- max_features: int#
- num_down_blocks: int#
- num_occlusion_masks: int#
- class transmotion.configs.LossConfig(perceptual: transmotion.configs.PerceptualLoss, equivariance: transmotion.configs.EquivarianceLoss, warp: transmotion.configs.WarpLoss, background: transmotion.configs.BackgroundLoss)[source]#
Bases:
object
- background: transmotion.configs.BackgroundLoss#
- equivariance: transmotion.configs.EquivarianceLoss#
- perceptual: transmotion.configs.PerceptualLoss#
- class transmotion.configs.OptimizerConfig(initial_lr: float, lr_decay_epoch_sched: Sequence[int], lr_decay_gamma: float, adam_beta1: float = 0.5, adam_beta2: float = 0.999, weight_decay: float = 0.0001)[source]#
Bases:
object
Learning rates, decay parameters and scheduling
- adam_beta1: float = 0.5#
- adam_beta2: float = 0.999#
- initial_lr: float#
base learning rate for the optimizer
- lr_decay_epoch_sched: Sequence[int]#
Sequence of epoch numbers at which the learning rate decays. Decay is by lr_decay_gamma
- lr_decay_gamma: float#
Learning rate decay parameter. Decay happens according to lr_decay_epoch_sched
- weight_decay: float = 0.0001#
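A sketch of how these fields could map onto a PyTorch optimizer and scheduler; pairing them with Adam and MultiStepLR is an assumption here, not a statement about the package's training loop:

```python
import torch
from transmotion.configs import OptimizerConfig

cfg = OptimizerConfig(initial_lr=2e-4, lr_decay_epoch_sched=[60, 90], lr_decay_gamma=0.1)

model = torch.nn.Linear(8, 8)  # stand-in for one of the networks
opt = torch.optim.Adam(
    model.parameters(),
    lr=cfg.initial_lr,
    betas=(cfg.adam_beta1, cfg.adam_beta2),
    weight_decay=cfg.weight_decay,
)
# Decay the learning rate by lr_decay_gamma at every epoch in lr_decay_epoch_sched.
sched = torch.optim.lr_scheduler.MultiStepLR(
    opt, milestones=list(cfg.lr_decay_epoch_sched), gamma=cfg.lr_decay_gamma
)
```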
- class transmotion.configs.PerceptualLoss(scales: Sequence[float], loss_weights: Sequence[float], in_dim: int)[source]#
Bases:
object
Get VGG features from multiple scales of the driving image and the generated image. Train the generated images with an \(L_1\) loss on these features.
- in_dim: int#
Input dimension of the multi-scale image pyramid used for the perceptual loss
- loss_weights: Sequence[float]#
Loss weights (one per scale) used when aggregating into the total loss value
- scales: Sequence[float]#
List of floats that represent the scaling factors to be applied to the images. This needs to conform to the number of feature maps you get from the perceptual-loss network. By default we use VGG19 here; this net has 5 feature-map outputs. #TODO: Might need to make this more general and explicitly link the number of feature maps to the number of scales.
- class transmotion.configs.TPSConfig(num_tps: int, points_per_tps: int)[source]#
Bases:
object
- num_tps: int#
We use multiple TPS transforms. Here you specify the number to use
- points_per_tps: int#
Each TPS transform has a number of control points that will be matched exactly. Here you specify the number of those points
- class transmotion.configs.TrainConfig(data_loading: transmotion.configs.DataLoadingConfig, num_epochs: int, optimizer: transmotion.configs.OptimizerConfig, dropout: transmotion.configs.DropoutConfig, dense_motion: transmotion.configs.DenseMotionConf, inpainting: transmotion.configs.InpaintingConfig, tps: transmotion.configs.TPSConfig, loss: transmotion.configs.LossConfig, background_start_epoch: int, image_size: int)[source]#
Bases:
object
Specifies all the parameters needed for training
- background_start_epoch: int#
You might want to start background training later; this specifies the epoch at which background training starts
- data_loading: transmotion.configs.DataLoadingConfig#
- dense_motion: transmotion.configs.DenseMotionConf#
- dropout: transmotion.configs.DropoutConfig#
- image_size: int#
TPS works on square images of size
image_size x image_size
- inpainting: transmotion.configs.InpaintingConfig#
- num_epochs: int#
- optimizer: transmotion.configs.OptimizerConfig#
- transmotion.configs.dummy_conf() transmotion.configs.TrainConfig [source]#
VOX model dummy config
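A quick sketch of inspecting the dummy config; the field names follow the dataclasses documented above:

```python
from transmotion.configs import dummy_conf

cfg = dummy_conf()
print(cfg.image_size, cfg.num_epochs)
print(cfg.tps.num_tps, cfg.tps.points_per_tps)
print(cfg.optimizer.initial_lr, cfg.optimizer.lr_decay_epoch_sched)
```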
transmotion.cord_warp module#
Thin-Plate Spline Transform implementation and some helper functions to work with points and grids.
- class transmotion.cord_warp.ThinPlateSpline(control_points: torch.Tensor, control_params: torch.Tensor, theta: torch.Tensor, _warp_cords: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor])[source]#
Bases:
object
This spline is equivalent to a function that transforms a set of points to be close to some other set of points while keeping the transformation smooth. The other set of points are the control points, and smoothness is measured via the second derivative of the transformation.
We have the following params in this transformation:
\(A_k, b_k\) = theta
\(W_k\) = control params
\(p_k\) = control points
To transform a point \(p\):
\[\hat{P} = (A_k p + b_k) + W_k U(\| p_k - p \|)\]
\(U\) is the kernel function. A worked numeric sketch of this formula follows the attribute and method listing below.
- control_params: torch.Tensor#
Weights that we apply on our kernel function
- control_points: torch.Tensor#
The points that our transformation needs to match exactly. Shape -
[b, K, 2]
Each new point we transform is an affine combination of the control-set points; each control point is weighted according to a kernel distance from the point being transformed
- classmethod fit(source_pts: torch.Tensor, destination_pts: torch.Tensor)[source]#
- Param
source_pts - [bs, num_tps, pts_per_tps, 2] source points for TPS transform
- Param
destination_pts - [bs, num_tps, pts_per_tps, 2] destination points for TPS transform
- classmethod random(sigma_tps: float, num_points: int, sigma_affine: float, batch_size: int = 1, num_transforms: int = 1)[source]#
A random transform is a special case of the full TPS: it applies a random affine transform and jitters the points diagonally (weighted by the TPS params)
- theta: torch.Tensor#
Affine transform parameters of shape [bs, K, 2, 3]; a batch of K TPS transforms
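A worked sketch of the formula above on raw tensors. The standard thin-plate kernel \(U(r) = r^2 \log r^2\) is assumed here for illustration; the kernel actually used inside ThinPlateSpline is an implementation detail of the class:

```python
import torch

def tps_kernel(r: torch.Tensor) -> torch.Tensor:
    # Standard thin-plate kernel U(r) = r^2 * log(r^2); assumed for illustration.
    r2 = r.pow(2)
    return r2 * torch.log(r2 + 1e-9)

def tps_point(p, A, b, W, control_points):
    # \hat{P} = (A p + b) + sum_k W_k * U(||p_k - p||)
    affine = A @ p + b
    dists = torch.linalg.norm(control_points - p, dim=-1)  # [K]
    return affine + (W * tps_kernel(dists).unsqueeze(-1)).sum(dim=0)

K = 5
p = torch.rand(2)                    # point to transform
A, b = torch.eye(2), torch.zeros(2)  # affine params (theta)
W = torch.randn(K, 2) * 0.01         # control params (kernel weights)
control_points = torch.rand(K, 2)    # control points p_k
print(tps_point(p, A, b, W, control_points))
```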
- transmotion.cord_warp.conform_to_type_and_device(leader: torch.Tensor) Callable[[torch.Tensor], torch.Tensor] [source]#
- transmotion.cord_warp.deform_with_4d_deformation(frame: torch.Tensor, deformation: torch.Tensor) torch.Tensor [source]#
Apply PyTorch's grid_sample with align_corners (see the sketch below).
- Parameters
frame – shape [b, d, h, w]
deformation – shape [b, h, w, 2]
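A standalone sketch of roughly what this wrapper does, using PyTorch's grid_sample with an identity deformation; the wrapper itself may add dtype/device handling on top:

```python
import torch
import torch.nn.functional as F

b, d, h, w = 1, 3, 8, 8
frame = torch.rand(b, d, h, w)

# Identity deformation: a [b, h, w, 2] grid of (x, y) coordinates in [-1, 1],
# the convention grid_sample expects.
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij"
)
deformation = torch.stack([xs, ys], dim=-1).unsqueeze(0)

out = F.grid_sample(frame, deformation, align_corners=True)
assert torch.allclose(out, frame, atol=1e-5)
```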
- transmotion.cord_warp.deform_with_5d_deformation(frame: torch.Tensor, deformation: torch.Tensor) torch.Tensor [source]#
Deform an image frame according to warped coordinates. The deformation is done with PyTorch's grid_sample, which takes 4D tensors as deformation coordinates.
- Parameters
deformation ([bs, K, h, w, d]) – Batch of K coordinate grids that describe where each pixel should go
transmotion.data_loading module#
- transmotion.data_loading.map_numpy(*funcs: Callable[[numpy.ndarray], numpy.ndarray], it: Iterable[numpy.ndarray]) Callable[[numpy.ndarray], numpy.ndarray] [source]#
Map a sequence of functions, each taking and returning a numpy array, over an iterator of numpy arrays.
This is a helper function to apply transformations on the video iterator.
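Conceptually this is function composition mapped over frames. A self-contained sketch of the idea (not the package function itself, whose exact return type may differ):

```python
import numpy as np
from typing import Callable, Iterable, Iterator

def map_numpy_sketch(
    *funcs: Callable[[np.ndarray], np.ndarray], it: Iterable[np.ndarray]
) -> Iterator[np.ndarray]:
    # Apply each function, in order, to every array coming out of the iterator.
    for arr in it:
        for f in funcs:
            arr = f(arr)
        yield arr

frames = (np.full((4, 4, 3), i, dtype=np.uint8) for i in range(3))
out = list(map_numpy_sketch(lambda a: a.astype(np.float32) / 255.0, it=frames))
print(out[1].mean())
```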
transmotion.datasets module#
During training we sample a video and then sample two frames from it: one frame serves as the source and the other as the driving frame.
We get a video array; in training it holds the two sampled frames, while in reconstruction mode it holds all of the frames.
The loaders read the whole video even when only two frames are used; in a demo run all of the frames are used.
- class transmotion.datasets.SourceDrivingBatch(source: torch.Tensor, driving: torch.Tensor)[source]#
Bases:
object
- Parameters
source – source image, shape [bs, d (=3), h, w]
driving – driving image, shape [bs, d (=3), h, w]
- driving: torch.Tensor#
- source: torch.Tensor#
- class transmotion.datasets.SourceDrivingSample(source: torch.Tensor, driving: torch.Tensor)[source]#
Bases:
object
- Parameters
source – source image, shape [d (=3), h, w]
driving – driving image, shape [d (=3), h, w]
- driving: torch.Tensor#
- source: torch.Tensor#
- transmotion.datasets.collate_fn(batch: Sequence[transmotion.datasets.SourceDrivingSample]) transmotion.datasets.SourceDrivingBatch [source]#
transmotion.debug module#
This is ad-hoc code used for debugging gaps in the dense-motion generation network. It can safely be ignored; it's here “for my records”.
transmotion.dense_motion module#
- class transmotion.dense_motion.BGMotionParam(bg_params: torch.Tensor)[source]#
Bases:
object
- bg_params: torch.Tensor#
Affine transform parameters (3x3, with the last row being the bias row [0, 0, 1]). Shape -
[bs, 3, 3]
- class transmotion.dense_motion.BGMotionPredictor[source]#
Bases:
torch.nn.modules.module.Module
Module for background estimation; returns a single transformation, parametrized as a 3x3 matrix whose third row is [0, 0, 1].
This is a separate module since we don't train it from the start: it starts as None and is added once the model has some trained weights.
- forward(source_image: torch.Tensor, driving_image: torch.Tensor) transmotion.dense_motion.BGMotionParam [source]#
- Parameters
source_image – [bs, d, h, w]
driving_image – [bs, d, h, w]
- class transmotion.dense_motion.DenseMotionNetwork(cfg: transmotion.configs.DenseMotionConf)[source]#
Bases:
torch.nn.modules.module.Module
Estimate optical flow from key-point transformations.
Estimate multi-resolution occlusion masks.
- Parameters
hg_to_num_mappings – Conv layer that translates HourGlass output dimension to the number of mappings (in this case num of TPS transforms + bkg transform)
- forward(source_img: torch.Tensor, src_kp: transmotion.kp_detection.KPResult, drv_kp: transmotion.kp_detection.KPResult, bg_param: Optional[transmotion.dense_motion.BGMotionParam] = None, dropout_prob: float = 0.0) transmotion.dense_motion.DenseMotionResult [source]#
Flow of the forward pass:
Downsample the source image
Generate heatmap representation out of the source and driving key-points
Fit Thin-Plate Spline to source and driving key-points
Deform the source image with the TPS transform (happens in
tps_warp_to_keypoints_with_background()
)
Concat the heatmaps and the deformed source and pass that through an hour-glass neural net (HG)
Infer optical-flow from last HG feature-map
Infer occlusion masks from the last HG feature-map
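A minimal call sketch wiring the key-point detector into the dense-motion network, using the dummy config and assuming its dimensions are mutually consistent (output shapes follow the DenseMotionResult documentation below):

```python
import torch
from transmotion.configs import dummy_conf
from transmotion.dense_motion import DenseMotionNetwork
from transmotion.kp_detection import KPDetector

cfg = dummy_conf()
kp_net = KPDetector(cfg.tps)
dense_motion = DenseMotionNetwork(cfg.dense_motion)

src = torch.rand(1, 3, cfg.image_size, cfg.image_size)
drv = torch.rand(1, 3, cfg.image_size, cfg.image_size)

src_kp, drv_kp = kp_net(src), kp_net(drv)
result = dense_motion(src, src_kp, drv_kp, bg_param=None, dropout_prob=0.0)
print(result.optical_flow.shape)    # [bs, h, w, 2]
print(len(result.occlusion_masks))  # one mask per scale
```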
- hg_to_num_mappings: torch.nn.modules.module.Module#
- infer_optical_flow(fmap: torch.Tensor, warpped_cords: torch.Tensor, dropout_prob: float) Tuple[torch.Tensor, torch.Tensor] [source]#
Turn the output of the hour-glass architecture into a softmax weights tensor and use it to reduce the K+1 transformed coordinates to a single optical-flow tensor.
- Parameters
fmap (4-D Batched Tensor) – Last output of the HourGlass architecture used to infer contribution maps
warpped_cords – Shape: [b, k, h, w, d] coordinates after being warped with a TPS transform
- Returns
Optical-flow tensor (shape: [b, h, w, d]) and the weights (shape: [b, k, h, w]) assigned to each of the transformations.
- class transmotion.dense_motion.DenseMotionResult(deformed_source: torch.Tensor, contribution_maps: torch.Tensor, optical_flow: torch.Tensor, occlusion_masks: List[torch.Tensor])[source]#
Bases:
object
Output of the DenseMotionNetwork
- Parameters
deformed_source – Source image deformed by K TPS transformations (not by the optical flow!)
contribution_maps – Convex weights on each of the TPS transforms & background transform
optical_flow – a grid of coordinates \(\in [0,1]\), every coordinate tells us where grid pixel should be moved
- contribution_maps: torch.Tensor#
Shape -
[bs, K+1, h, w]
- deformed_source: torch.Tensor#
Shape -
[bs, K+1, h, w]
- occlusion_masks: List[torch.Tensor]#
Shape of element i -
[bs, d_i, h_i, w_i]
- optical_flow: torch.Tensor#
Shape -
[bs, h, w, d (=2)]
- class transmotion.dense_motion.LearnedUpsample(in_dim: int, num_upsamples: int)[source]#
Bases:
torch.nn.modules.module.Module
- Parameters
upsampled_dims – Dimension sizes of the upsampled feature maps
- class transmotion.dense_motion.OcclusionMasks(fmap_channels: List[int])[source]#
Bases:
torch.nn.modules.module.Module
- Parameters
mask_predicotrs – Conv layers followed by a Sigmoid activation; kernel_size=7 with padding that preserves the spatial dimensions
- forward(fmaps: List[torch.Tensor]) List[torch.Tensor] [source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class transmotion.dense_motion.WarppedCords(warpped_cords: torch.Tensor, num_tps: int)[source]#
Bases:
object
- Parameters
warpped_cords – Shape
[bs, num_tps+1, grid_h, grid_w, 2]
- deform_frame(frame: torch.Tensor) torch.Tensor [source]#
Multi-deformation: one frame gets deformed by several different deformation coordinate grids. Deform the image frame according to the warped coordinates.
- Returns
shape [b, K+1, d, h, w]
- num_tps: int#
- warpped_cords: torch.Tensor#
- transmotion.dense_motion.keypoints_to_heatmap_representation(spatial_hw: Tuple[int, int], source: transmotion.kp_detection.KPResult, driving: transmotion.kp_detection.KPResult, kp_variance: float) torch.Tensor [source]#
Create a heatmap representation from key-point results. For motion estimation this should be the size of the source image, since the key-point heatmap is going to be concatenated with the source image.
- Param
spatial_hw - spatial size of the produced heatmap.
- Source
Predicted Key-points for the source image
- Driving
Predicted Key-points for the driving image
- Kp_variance
The key-point representation is based on a Gaussian distribution; this is the variance of the distribution we fit.
- Returns
Tensor with shape [bs, K*N+1, h, w]
- transmotion.dense_motion.tps_dropout_softmax(fmap: torch.Tensor, drop_prob: float) torch.Tensor [source]#
We want to apply dropout to the K TPS transforms so that the model learns to use all of them. We never drop the background features.
- Parameters
fmap ([bs, K+1, h, w]) – represents features of the background and the K TPS transforms
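A minimal sketch of the idea; which channel holds the background and how dropped channels enter the softmax are assumptions here, not the package's exact implementation:

```python
import torch

def tps_dropout_softmax_sketch(fmap: torch.Tensor, drop_prob: float) -> torch.Tensor:
    # fmap: [bs, K+1, h, w]; channel 0 is assumed to be the background here.
    bs, k_plus_1, h, w = fmap.shape
    keep = torch.rand(bs, k_plus_1, 1, 1) > drop_prob
    keep[:, 0] = True  # the background channel is never dropped
    masked = fmap.masked_fill(~keep, float("-inf"))
    return torch.softmax(masked, dim=1)

weights = tps_dropout_softmax_sketch(torch.randn(2, 11, 16, 16), drop_prob=0.3)
print(weights.sum(dim=1).mean())  # convex weights: sums to 1 over the K+1 channels
```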
- transmotion.dense_motion.tps_warp_to_keypoints_with_background(grid_hw: Tuple[int, int], source: transmotion.kp_detection.KPResult, driving: transmotion.kp_detection.KPResult, bkg_affine_param: transmotion.dense_motion.BGMotionParam) transmotion.dense_motion.WarppedCords [source]#
Warp a grid of coordinates of shape
grid_hw
Predicted key-points are for foreground objects. This function warps a grid according to the foreground transformations and adds a background transformation. Each point in the resulting grid says to which point we mapped the current point. The last dimension is the identity, that is, the grid without any mapping.
- Parameters
bkg_affine_param – Shape [bs, 3, 3] if not None
- Returns
Tensor with shape [bs, num_tps+1, grid_h, grid_w, 2]
transmotion.inpainting module#
In-painting network
- class transmotion.inpainting.InpaintingNetwork(cfg: transmotion.configs.InpaintingConfig)[source]#
Bases:
torch.nn.modules.module.Module
- forward(source_img: torch.Tensor, occlusion_masks: List[torch.Tensor], optical_flow: torch.Tensor) transmotion.inpainting.InpaintingResult [source]#
Forward pass has the following flow (a call sketch follows the parameter list below):
encode the source image to produce encoder features
produce optical-flow-deformed (
resize_deform()
) and occlusion-masked (
resize_occlude()
) features. We keep two versions of the encoder features: one with full gradients and another with gradients only on the optical flow.
use the deformed features as inputs to the decoder
deform and occlude the source image
the final prediction is a convex combination of the deformed-occluded source image and the inverse-occluded decoder output
- Parameters
source_img (
[bs, 3, h, w]
) – Source image batchocclusion_masks (
[bs, d_i, h_i, w_i]
for i-th element of the list) – List of multi scale occlusion masks produced bytransmotion.dense_motion.DenseMotionNetwork
.optical_flow (
[bs, h, w, d (=2)]
) – a grid of coordinates, every coordinate tells us where grid pixel should be moved
Note
A check that the number of occlusion masks is consistent with the number of down-blocks happens during config init.
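A call sketch chaining key-point detection, dense motion and inpainting together, again using the dummy config so the dimensions line up; this is a hedged sketch, not the package's training loop:

```python
import torch
from transmotion.configs import dummy_conf
from transmotion.dense_motion import DenseMotionNetwork
from transmotion.inpainting import InpaintingNetwork
from transmotion.kp_detection import KPDetector

cfg = dummy_conf()
kp_net, dense_motion = KPDetector(cfg.tps), DenseMotionNetwork(cfg.dense_motion)
inpaint = InpaintingNetwork(cfg.inpainting)

src = torch.rand(1, 3, cfg.image_size, cfg.image_size)
drv = torch.rand(1, 3, cfg.image_size, cfg.image_size)
motion = dense_motion(src, kp_net(src), kp_net(drv))

out = inpaint(src, motion.occlusion_masks, motion.optical_flow)
print(out.inpainted_img.shape)  # [bs, 3, h, w]
```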
- fwd_decoder(source_img: torch.Tensor, raw_encoding: Sequence[torch.Tensor], occlusion_masks: List[torch.Tensor], optical_flow: torch.Tensor) transmotion.inpainting.InpaintingResult [source]#
Decoder forward function; it takes the encoder result as input, then deforms, occludes, and decodes it.
- Parameters
source_img (
[bs, 3, h, w]
) – Source image batchraw_encoding – Output of the fwd_encoder func
occlusion_masks (
[bs, d_i, h_i, w_i]
for i-th element of the list) – List of multi scale occlusion masks produced bytransmotion.dense_motion.DenseMotionNetwork
.optical_flow (
[bs, h, w, d (=2)]
) – a grid of coordinates, every coordinate tells us where grid pixel should be moved
- fwd_encoder(source_img: torch.Tensor) Sequence[torch.Tensor] [source]#
Run the encoder's forward pass. This function is separate because you don't need to re-run the encoder during inference.
- Parameters
source_img (
[bs, 3, h, w]
) – Source image batch- Returns
Encodings and optical flow deformed image (same shape as input image)
- class transmotion.inpainting.InpaintingResult(deformed_src_fmaps: List[torch.Tensor], deformed_source: torch.Tensor, inpainted_img: torch.Tensor)[source]#
Bases:
object
- Parameters
deformed_src_fmaps – encoder feature maps after deformation and occlusion; these have gradients on the optical-flow tensor.
deformed_source – Source image deformed with the optical flow
inpainted_img – image generated from the source such that it matches the driving pose
- deformed_source: torch.Tensor#
[bs, d, h, w]
- deformed_src_fmaps: List[torch.Tensor]#
[bs, d_i, h_i, w_i]
i-th element of the list
- inpainted_img: torch.Tensor#
[bs, d, h, w]
- transmotion.inpainting.resize_deform(frame: torch.Tensor, deformation: torch.Tensor) torch.Tensor [source]#
Deform frame according to coordinate deformation.
- Parameters
deformation (
[bs, h1, w1, d (=2)]
) – A grid of coordinates \(\in [0,1]\).frame (
[bs, d, h2 ,w2]
) – An image of a feature map
- Returns
Deformed feature map (
[bs, d, h2 ,w2]
)
- transmotion.inpainting.resize_occlude(frame: torch.Tensor, occlusion_mask: torch.Tensor) torch.Tensor [source]#
Occlude the feature map frame using occlusion_mask. If the mask is not the same size as the feature map, bilinear interpolation is applied to the mask to align the sizes (see the sketch below).
- Parameters
frame – [bs, d, h1, w1]
occlusion_mask – [bs, 1, h2, w2]
- Returns
Occluded image
[bs, d, h1, w1]
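A rough standalone equivalent: upsample the mask to the frame's spatial size and gate the frame with it (the package function may differ in details such as interpolation flags):

```python
import torch
import torch.nn.functional as F

frame = torch.rand(2, 64, 32, 32)        # [bs, d, h1, w1]
occlusion_mask = torch.rand(2, 1, 8, 8)  # [bs, 1, h2, w2]

# Bilinearly resize the mask to the frame's spatial size, then multiply.
mask = F.interpolate(
    occlusion_mask, size=frame.shape[2:], mode="bilinear", align_corners=False
)
occluded = frame * mask
print(occluded.shape)  # [2, 64, 32, 32]
```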
transmotion.kp_detection module#
- class transmotion.kp_detection.KPDetector(cfg: transmotion.configs.TPSConfig)[source]#
Bases:
torch.nn.modules.module.Module
The network is a ResNet-18 that produces a num_tps*points_per_tps*spatial_dim dimensional feature. We later interpret the output as key-points of the shape:
[num_tps, points_per_tps, spatial_dim]
- forward(image: torch.Tensor) transmotion.kp_detection.KPResult [source]#
Runs on a batch of images
[b, c, h, w]
- spatial_dim: int = 2#
Spatial dimension of key-point coordinates
- class transmotion.kp_detection.KPResult(foreground_kp: torch.Tensor, num_tps: int, pts_per_tps: int, batch_size: int)[source]#
Bases:
object
The result of a key-point estimation network, with shape
[batch, num_tps_transforms, points_per_tps, 2]
The network produces parameters for a TPS transformation. Each TPS transformation has 5 control points, each a 2D spatial location.
- Parameters
foreground_kp – 2D coordinates in
[-1,1]
of shape
[batch, num_tps_transforms, points_per_tps, 2]
- batch_size: int#
- detach_to_cpu() transmotion.kp_detection.KPResult [source]#
- foreground_kp: torch.Tensor#
2D coordinates in [-1,1] of shape [batch, num_tps_transforms, points_per_tps, 2]
- normalize_relative_to(init_: transmotion.kp_detection.KPResult, cur_: transmotion.kp_detection.KPResult) transmotion.kp_detection.KPResult [source]#
Normalize key-points relative to some initial key-points and the current key-points.
Take the difference vectors between the current key-points and the initial key-points, and use them to translate the key-points.
- Parameters
init – initial key-points
cur – current key-points
Note
This normalization is used during inference. We modify the source key-points according to the change to the driving key-points. This way no identity information is leaked from driving to source.
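The idea from the note, expressed directly on coordinate tensors; the actual method operates on KPResult objects and may differ in details:

```python
import torch

# Shift the source key-points by how much the driving key-points moved
# relative to the first driving frame.
src_kp = torch.tensor([[0.1, 0.2], [0.3, -0.4]])
init_drv_kp = torch.tensor([[0.0, 0.0], [0.5, 0.5]])
cur_drv_kp = torch.tensor([[0.1, -0.1], [0.6, 0.4]])

normalized = src_kp + (cur_drv_kp - init_drv_kp)
print(normalized)
```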
- num_tps: int#
Number of Thin-Plate Splines in the predicted key-points
- pts_per_tps: int#
Number of control points per TPS transform
- to_gaussians(hw_of_grid: Tuple[int, int], kp_variance: float) torch.Tensor [source]#
Transform key-points of shape
[batch, num_tps, pts_per_tps, 2]
into a Gaussian-like representation of shape
[bs, num_tps, pts_per_tps, grid_h, grid_w]
By “Gaussian-like” we mean: measure the distance between the predicted key-points and a uniformly spread grid. We get back a uniformly spread grid with a Gaussian density centered around each key-point.
We sometimes refer to the number of TPS transforms as K, the number of points per TPS transform as N, and the number of points in a grid as P (= h*w).
- Parameters
hw_of_grid – Height and width of the Gaussian representation.
Note
This is a convenience method; it is used when generating heatmaps for the optical-flow estimation
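One common formulation of such a Gaussian-like map, sketched on raw tensors; the exact normalization used by to_gaussians() is an implementation detail:

```python
import torch

def gaussian_like_sketch(kp: torch.Tensor, grid_hw, kp_variance: float) -> torch.Tensor:
    # kp: [bs, K, N, 2] coordinates in [-1, 1].
    # Unnormalized density exp(-0.5 * ||grid - kp||^2 / variance) on a uniform grid.
    h, w = grid_hw
    ys, xs = torch.meshgrid(
        torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij"
    )
    grid = torch.stack([xs, ys], dim=-1)                    # [h, w, 2]
    diff = grid[None, None, None] - kp[..., None, None, :]  # [bs, K, N, h, w, 2]
    return torch.exp(-0.5 * diff.pow(2).sum(-1) / kp_variance)

heat = gaussian_like_sketch(torch.zeros(1, 10, 5, 2), (64, 64), kp_variance=0.01)
print(heat.shape)  # [1, 10, 5, 64, 64]
```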
transmotion.network_bundle module#
- class transmotion.network_bundle.NetworkBundleResult(source_keypoints: transmotion.kp_detection.KPResult, driving_keypoints: transmotion.kp_detection.KPResult, background_param: transmotion.dense_motion.BGMotionParam, dense_motion: transmotion.dense_motion.DenseMotionResult, inpainting: transmotion.inpainting.InpaintingResult)[source]#
Bases:
object
The result of a forward pass. This includes train-time debugging information such as source and target key-points, optical-flow maps, deformed feature maps, etc. The final generated frame is available as the property generated_image
- background_param: transmotion.dense_motion.BGMotionParam#
- dense_motion: transmotion.dense_motion.DenseMotionResult#
- driving_keypoints: transmotion.kp_detection.KPResult#
- property generated_image: torch.Tensor#
Final result. Generated source image that matches the pose of the driving image
Note
Detached from autograd and copied to the CPU
- inpainting: transmotion.inpainting.InpaintingResult#
- source_keypoints: transmotion.kp_detection.KPResult#
- class transmotion.network_bundle.NetworksBundle(key_points: transmotion.kp_detection.KPDetector, dense_motion: transmotion.dense_motion.DenseMotionNetwork, inpaint: transmotion.inpainting.InpaintingNetwork, background_motion: transmotion.dense_motion.BGMotionPredictor)[source]#
Bases:
object
- background_motion: transmotion.dense_motion.BGMotionPredictor#
- dense_motion: transmotion.dense_motion.DenseMotionNetwork#
- eval()[source]#
Switch all networks to eval mode and remove the background-prediction network. Background prediction is not used at inference time.
- forward(src_img: torch.Tensor, drv_img: torch.Tensor, dropout_prob: float = 0.0) transmotion.network_bundle.NetworkBundleResult [source]#
- key_points: transmotion.kp_detection.KPDetector#
- transmotion.network_bundle.build_nets(tps: transmotion.configs.TPSConfig, dense: transmotion.configs.DenseMotionConf, inpaint: transmotion.configs.InpaintingConfig) transmotion.network_bundle.NetworksBundle [source]#
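A minimal end-to-end sketch with the dummy config: build the bundle, run a forward pass on random tensors, and read the generated frame. This is a hedged sketch; real use would load trained weights first, e.g. via load_original_weights():

```python
import torch
from transmotion.configs import dummy_conf
from transmotion.network_bundle import build_nets

cfg = dummy_conf()
nets = build_nets(cfg.tps, cfg.dense_motion, cfg.inpainting)

src = torch.rand(1, 3, cfg.image_size, cfg.image_size)
drv = torch.rand(1, 3, cfg.image_size, cfg.image_size)

with torch.no_grad():
    result = nets.forward(src, drv)
print(result.generated_image.shape)
```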
- transmotion.network_bundle.get_kp_normalized_forward_func(src_img: torch.Tensor, intial_drv_img: torch.Tensor, nets: transmotion.network_bundle.NetworksBundle) Callable[[torch.Tensor], transmotion.network_bundle.NetworkBundleResult] [source]#
- transmotion.network_bundle.load_original_weights(fpath: str, nets: transmotion.network_bundle.NetworksBundle) transmotion.network_bundle.NetworksBundle [source]#
transmotion.nn_blocks module#
- class transmotion.nn_blocks.Decoder(block_expansion: int, in_dim: int, num_blocks: int = 3, max_features: int = 256)[source]#
Bases:
torch.nn.modules.module.Module
Hourglass Decoder
- forward(x: List[torch.Tensor]) torch.Tensor [source]#
Take a list of feature maps of shapes [ … [b, c_i, h_i, w_i]… ]. Upsample and add skip connections
TODO: - [ ] add network arch sketch
- Param
x - output of the transmotion.nn_blocks.Encoder
- out_channels: List[int]#
- out_dim: int#
- class transmotion.nn_blocks.DownBlock2d(in_features, out_features, kernel_size=3, padding=1, groups=1)[source]#
Bases:
torch.nn.modules.module.Module
Downsampling block with Average Pooling
- class transmotion.nn_blocks.Encoder(block_expansion: int, in_dim: int, num_blocks: int = 3, max_features=256)[source]#
Bases:
torch.nn.modules.module.Module
Hourglass Encoder
- class transmotion.nn_blocks.Hourglass(block_expansion: int, in_features: int, num_blocks=3, max_features=256)[source]#
Bases:
torch.nn.modules.module.Module
Hourglass architecture. The input is downsampled with a CNN-based encoder and upsampled with a CNN-based decoder. A bit similar to a U-Net architecture, where each upsampled feature map has a residual connection to a downsampled feature map from the encoder.
- forward(x: torch.Tensor) List[torch.Tensor] [source]#
Encode a feature map x and decode it back.
- Parameters
x – feature map of shape [b, c, h, w]
- Returns
List[th.Tensor] of intermediate decoder feature maps, with increasing spatial dimensions h_i & w_i and decreasing channel dimension c_i
- last_out_dim: int#
- out_channels: List[int]#
- class transmotion.nn_blocks.ResBlock2d(in_dim: int, kernel_size: int, padding: int)[source]#
Bases:
torch.nn.modules.module.Module
Residual block that preserves spatial resolution
- class transmotion.nn_blocks.SameBlock2d(in_dim: int, out_dim: int, groups: int = 1, kernel_size: int = 3, padding: int = 1)[source]#
Bases:
torch.nn.modules.module.Module
Simple block that preserves spatial resolution.
- class transmotion.nn_blocks.UpBlock2d(in_dim: int, out_dim: int, kernel_size: int = 3, padding: int = 1, groups: int = 1)[source]#
Bases:
torch.nn.modules.module.Module
Upsampling block for use in decoder.
- class transmotion.nn_blocks.Vgg19(vgg_pretrained_features: torch.nn.modules.module.Module, requires_grad: bool = False)[source]#
Bases:
torch.nn.modules.module.Module
Vgg19 network for perceptual loss. See Sec 3.3.
- forward(X)[source]#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
transmotion.viz module#
- transmotion.viz.draw_points_on_tensors(img_tensor: torch.Tensor, key_points: transmotion.kp_detection.KPResult, radius: int = 2) Sequence[PIL.Image.Image] [source]#
Draw predicted key-points on images.
- Parameters
img_tensor – An image tensor assumed to be in \([0,1]\) with shape \((batch\_size, dim, h, w)\)
key_points – Result of prediction from
- transmotion.viz.show_on_gird(*img_tnesors, **kwargs)[source]#
kwargs are passed to torchvision.utils.make_grid
- transmotion.viz.show_points_on_grid(img_tensor: torch.Tensor, key_points: transmotion.kp_detection.KPResult, radius: int = 2) PIL.Image.Image [source]#
transmotion.weights module#
Module for loading and saving module weights.
Key use cases:
Port pretrained weights from original repo to this one
- transmotion.weights.import_state_dict(old: Dict, new_: Dict) Dict [source]#
Use heuristics to load weights from the original paper into the current weights. The models haven't changed much from the original; mostly the ordering stayed the same.
The matching heuristics are as follows:
Match according to weight shapes (the model must have the same shapes and the same number of weights for each shape)
In case more than one weight shares a shape, perform text similarity on the weight names and make sure that the order of old weights matches the order of new weights
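A simplified sketch of that matching idea (the package's actual heuristic also checks that the ordering of old and new weights agrees):

```python
from difflib import SequenceMatcher
import torch

def match_by_shape_then_name(old: dict, new: dict) -> dict:
    # Simplified sketch: for each new weight, find old weights with the same
    # shape and pick the one whose name is most similar.
    remapped = {}
    for new_name, new_w in new.items():
        same_shape = [k for k, v in old.items() if v.shape == new_w.shape]
        if same_shape:
            best = max(same_shape, key=lambda k: SequenceMatcher(None, k, new_name).ratio())
            remapped[new_name] = old[best]
    return remapped

old = {"enc.conv1.weight": torch.zeros(8, 3, 3, 3)}
new = {"encoder.conv1.weight": torch.ones(8, 3, 3, 3)}
print(list(match_by_shape_then_name(old, new)))
```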