2  Grids & Warps

Here we show a few examples of the basic building blocks used in this work: coordinate grids, key-points, and the warps built from them.

Code
import sys
sys.path.append("../src")  # make the local transmotion package importable
import warnings
warnings.filterwarnings('ignore')

from PIL import Image
import numpy as np
import torch as th
from einops import rearrange, repeat
import altair as alt

from transmotion.utils import make_coordinate_grid
from transmotion.kp_detection import KPDetector, KPResult
from transmotion.configs import TPSConfig
from transmotion.dense_motion import (
    DenseMotionNetwork,
    keypoints_to_heatmap_representation,
    warp_to_keypoints_with_background,
    DenseMotionConf,
)


def make_data_source(*grids):
    """Flatten one or more (num_points, 2) arrays into Altair records; `idx` tells the arrays apart."""
    return alt.Data(values=[
        {"x": float(xi), "y": float(yi), "idx": idx}
        for (idx, grid) in enumerate(grids)
        for (xi, yi) in grid
    ])


def draw_grid(grid):
    """Scatter-plot a single (num_points, 2) array of coordinates."""
    data = make_data_source(grid)
    return alt.Chart(data).mark_circle().encode(alt.X("x:Q"), alt.Y("y:Q"))

Uniform Coordinate Grid

We use coordinate grids to sample and deform images. A coordinate grid is, well… a grid: each of its points is a coordinate in $[-1,1]\times[-1,1]$.

Both coordinates lie in the $[-1,1]$ range, where -1 means left (for x) or top (for y) and 1 means right or bottom. This is the same convention used by PyTorch's grid_sample function.
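A uniform grid in this convention can be sketched in plain PyTorch. Here make_uniform_grid is a hypothetical stand-in, not transmotion's make_coordinate_grid (whose exact layout is not shown on this page); it returns an (H, W, 2) tensor whose last axis is (x, y), the order grid_sample expects. Sampling an image on this grid with align_corners=True should return the image unchanged, because -1 and 1 then land on the centres of the corner pixels.

```python
import torch
import torch.nn.functional as F


def make_uniform_grid(h: int, w: int) -> torch.Tensor:
    """Hypothetical stand-in for make_coordinate_grid: (h, w, 2) tensor of (x, y) in [-1, 1]."""
    xs = torch.linspace(-1.0, 1.0, w)  # -1 = left column, 1 = right column
    ys = torch.linspace(-1.0, 1.0, h)  # -1 = top row, 1 = bottom row
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")  # both (h, w)
    return torch.stack([xx, yy], dim=-1)


grid = make_uniform_grid(4, 5)
img = torch.arange(20, dtype=torch.float32).reshape(1, 1, 4, 5)

# Sampling on the uniform grid is (numerically) the identity warp
same = F.grid_sample(img, grid.unsqueeze(0), mode="bilinear", align_corners=True)
```

The same grid, with its points moved around, is what the warps further down feed to grid_sample.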

Code
# 10x10 uniform grid, flattened to 100 coordinate pairs for plotting
grid = make_coordinate_grid(10, 10).numpy().reshape(-1, 2)
draw_grid(grid)

Uniform Coordinate Grid

Random Key-points

The key-point detector is untrained, so the key-points we get come from the random initialisation of the KPDetector weights. The input images are random noise as well.
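To see why an untrained detector still yields valid key-points, here is a toy stand-in. ToyKPDetector is hypothetical and is not transmotion's KPDetector; it assumes the detector bounds its raw outputs with a tanh (the real model may squash them differently), so even random weights produce K × N points inside $[-1,1]^2$, the same range as the coordinate grid.

```python
import torch
import torch.nn as nn


class ToyKPDetector(nn.Module):
    """Hypothetical stand-in for KPDetector: image -> (B, num_tps * kp_per_tps, 2) points in [-1, 1]."""

    def __init__(self, num_tps: int, kp_per_tps: int):
        super().__init__()
        self.num_tps, self.kp_per_tps = num_tps, kp_per_tps
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),  # (B, 16)
        )
        self.head = nn.Linear(16, num_tps * kp_per_tps * 2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # tanh squashes the (random, untrained) outputs into the grid range [-1, 1]
        kp = torch.tanh(self.head(self.features(img)))
        return kp.view(img.shape[0], self.num_tps * self.kp_per_tps, 2)


torch.manual_seed(0)
det = ToyKPDetector(num_tps=1, kp_per_tps=15)
kp = det(torch.randn(1, 3, 128, 128))  # 15 random-but-valid key-points
```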

Code
K = 1
N = 15
tps_conf = TPSConfig(K, N)
bs = 1  # batch size
kpd = KPDetector(tps_conf)  # untrained: weights are randomly initialised

# Two random-noise "images": a source frame and a driving frame
src_img = th.randn((bs, 3, 128, 128))
drv_img = th.randn((bs, 3, 128, 128))
src_res = kpd(src_img)
drv_res = kpd(drv_img)

# Plot both key-point sets; idx=0 is the source, idx=1 the driving frame
data = make_data_source(
    src_res.foregroud_kp.detach().numpy().squeeze(),
    drv_res.foregroud_kp.detach().numpy().squeeze(),
)

alt.Chart(data).mark_circle().encode(
    alt.X("x:Q"), alt.Y("y:Q"), color="idx:N"
).properties(title="Two Sets of Random Key-points")

Two Sets of Random Key-points

Flat Grid & Warped Grid

Code
GRID_HW = (128, 128)
transform = warp_to_keypoints_with_background(GRID_HW, src_res, drv_res, None)
many_grids = transform.warpped_cords.detach().numpy()

# Regular grid: every coordinate matches its grid location
flat_grid = many_grids[0, 0].reshape(-1, 2)
# Warped grid: coordinates are shifted according to the TPS transform
warp_grid = many_grids[0, 1].reshape(-1, 2)

data = make_data_source(flat_grid, warp_grid)
alt.Chart(data).mark_circle().encode(
    alt.X("x:Q"), alt.Y("y:Q"), color="idx:N"
).configure_axis(grid=False).configure_view(strokeWidth=0).properties(
    title="Flat Grid (idx=0) & Warped Grid (idx=1)"
)

The same key-points shown on two different grids

Code
alt.hconcat(
    draw_grid(flat_grid),
    draw_grid(warp_grid)
).configure_axis(grid=False).configure_view(strokeWidth=0).properties(
    title="Flat Grid (Left) & Warped Grid (Right)"
)
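The warped grid comes from a thin-plate-spline (TPS) transform fitted to the two key-point sets. Below is a minimal, self-contained sketch of a standard TPS fit in plain PyTorch; it is not transmotion's implementation. The kernel $U(r) = r^2 \log r^2$, the helper names fit_tps and apply_tps, and the mapping direction (driving key-points onto source key-points, the usual choice for backward warping) are all assumptions. The spline passes exactly through its control points, and when both key-point sets coincide it reduces to the identity, which gives back the flat grid.

```python
import torch


def tps_kernel(r2: torch.Tensor) -> torch.Tensor:
    """Radial basis U(r) = r^2 log r^2, written in terms of the squared distance r2 (U(0) = 0)."""
    return r2 * torch.log(r2 + 1e-12)


def sq_dist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise squared distances between rows of a (M, 2) and b (N, 2) -> (M, N)."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)


def fit_tps(ctrl: torch.Tensor, target: torch.Tensor):
    """Hypothetical helper: solve for f with f(ctrl[i]) = target[i].
    Returns (rbf weights (N, 2), affine part (3, 2))."""
    n = ctrl.shape[0]
    K = tps_kernel(sq_dist(ctrl, ctrl))                               # (N, N)
    P = torch.cat([torch.ones(n, 1, dtype=ctrl.dtype), ctrl], dim=1)  # (N, 3): [1, x, y]
    L = torch.zeros(n + 3, n + 3, dtype=ctrl.dtype)
    L[:n, :n] = K
    L[:n, n:] = P
    L[n:, :n] = P.T
    rhs = torch.zeros(n + 3, 2, dtype=ctrl.dtype)
    rhs[:n] = target  # the last 3 rows stay 0: side conditions P^T w = 0
    sol = torch.linalg.solve(L, rhs)
    return sol[:n], sol[n:]


def apply_tps(points, ctrl, weights, affine):
    """Evaluate the fitted spline at points (M, 2) -> (M, 2)."""
    U = tps_kernel(sq_dist(points, ctrl))                                 # (M, N)
    P = torch.cat([torch.ones(points.shape[0], 1, dtype=points.dtype), points], dim=1)
    return P @ affine + U @ weights


# Usage: random key-points, then warp a flattened 10x10 uniform grid
torch.manual_seed(0)
drv_kp = torch.rand(15, 2, dtype=torch.float64) * 2 - 1  # control points in [-1, 1]^2
src_kp = drv_kp + 0.1 * torch.randn(15, 2, dtype=torch.float64)
w, a = fit_tps(drv_kp, src_kp)

xs = torch.linspace(-1, 1, 10, dtype=torch.float64)
yy, xx = torch.meshgrid(xs, xs, indexing="ij")
flat = torch.stack([xx, yy], dim=-1).reshape(-1, 2)      # (100, 2) flat grid
warped = apply_tps(flat, drv_kp, w, a)                   # (100, 2) warped grid
```

Reshaping the warped points back to (H, W, 2) and adding a batch dimension gives a grid that torch.nn.functional.grid_sample can consume directly.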

Warping Images

Here we warp an actual image with the transformation fitted to our key-points. In this particular example we don't see any occlusions.
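Warping an image with a grid is backward warping: every output pixel looks up the location stored at its grid position and samples the input image there. A minimal sketch of that step with torch.nn.functional.grid_sample, assuming bilinear sampling, zero padding, and align_corners=True (transmotion's deform_frame may use different settings); the one-pixel shift is a hypothetical grid chosen so the effect is easy to check.

```python
import torch
import torch.nn.functional as F


def backward_warp(img: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """img: (B, C, H, W); grid: (B, H_out, W_out, 2) of (x, y) in [-1, 1].
    Output pixel (i, j) is img bilinearly sampled at grid[:, i, j]."""
    return F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=True)


h, w = 3, 5
yy, xx = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij")
identity = torch.stack([xx, yy], dim=-1).unsqueeze(0)  # (1, h, w, 2) flat grid

# With align_corners=True, neighbouring pixel centres are 2 / (w - 1) apart in x
shifted = identity + torch.tensor([2.0 / (w - 1), 0.0])  # every pixel reads its right neighbour

img = torch.arange(h * w, dtype=torch.float32).reshape(1, 1, h, w)
out = backward_warp(img, shifted)
# Content moves one pixel left; the last column samples outside the image and becomes 0
```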

Code
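GRID_HW = (480, 480)
transform = warp_to_keypoints_with_background(GRID_HW, src_res, drv_res, None)

# PIL's resize takes (width, height); that is fine here because the grid is square
img_pil = Image.open("data/example.png").convert("RGB").resize(GRID_HW)
img = np.asarray(img_pil)
# (H, W, C) uint8 -> (1, C, H, W) float32
img_t = repeat(th.from_numpy(img), "h w d -> b d h w", b=1).type(th.float32)

# Index 1 selects the TPS-warped result, matching warp_grid above (index 0 is the flat grid)
warped = transform.deform_frame(img_t).squeeze()[1]
warped_pil = Image.fromarray(
    warped.permute(1, 2, 0).detach().clamp(0, 255).type(th.uint8).numpy()
)

# Original (left) and warped (right) side by side
D = GRID_HW[1]
canvas = np.empty((GRID_HW[0], 2 * GRID_HW[1], 3), dtype=np.uint8)
canvas[:, :D, :] = img
canvas[:, D:, :] = np.asarray(warped_pil)
Image.fromarray(canvas)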