行为动作识别(一):TSM和TRN

更新时间:2023-05-16 03:50:35 阅读: 评论:0

⾏为动作识别(⼀):TSM和TRN
1 TSM和
2 参考
3 如今行为动作识别都在探讨如何更好地描述时域信息特征。该文章在 TSN 基础上提出 Temporal Shift Module(TSM),既能保持高效又能有高性能。TSM 模块参考了《Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions》(该论文探讨用 shift 操作代替卷积操作,该论文还没看明白),提出对时域进行 shift 操作:对于 offline 模式,选择 1/8 的通道从前到后 shift,另选 1/8 的通道从后到前 shift;对于 online 模式,选择 1/4 的通道全部从前到后 shift。shift 操作放到残差结构里面,既减少了数据移动操作,也提高了性能。
文章中给出的原因是:移动之后扩大了时域的感受野,因此能进行更复杂的时域建模(For each inserted temporal shift module, the temporal receptive field will be enlarged by 2, as if running a convolution with the kernel size of 3 along the temporal dimension. Therefore, our TSM model has a very large temporal receptive field to conduct highly complicated temporal modeling.)
4 对于 shift 操作,第一个超参是移动多少通道,最终选定 1/8 做 left shift,另外 1/8 做 right shift。shift 的放置方式选定为 residual TSM:对于每个 residual block,在 block 的 conv1 之前插入 shift 操作(代码中即用 TemporalShift 包装 conv1,先 shift 再执行 conv1)。
5 对于添加的Nonlocal操作,参照原⽂Nonlocal模块,对于resnet50在下图中前⾯4个block中,在第⼀个和第三个block后⾯增加了⼀个Nonlocal模块,然后对于后⾯6个block,在第⼀,三,五后⾯增加⼀个Nonlocal模块
6 代码中⼀些⼯作:
a. shift 操作:其实就是把前后帧的特征信息融入当前帧的特征,以增大时域感受野;shift 操作同样放在残差模块中。
class TemporalShift(nn.Module):
    """Shift part of the channels along the time axis, then run a wrapped module.

    The input to ``forward`` is a 4-D tensor whose first dimension packs
    ``n_batch * n_segment`` frames (see the ``view`` call in ``shift``).

    Args:
        net: Module applied to the shifted tensor.
        n_segment: Number of frames per clip packed into the first dimension.
        n_div: The channel count is divided by this value to get the number
            of channels shifted in each direction.
        inplace: If true, delegate the shift to ``InplaceShift`` (defined
            outside this snippet); otherwise use the out-of-place version.
    """

    def __init__(self, net, n_segment=3, n_div=8, inplace=True):
        super(TemporalShift, self).__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace
        if inplace:
            print('=> Using in-place shift...')
        print('=> Using fold div: {}'.format(self.fold_div))

    def forward(self, x):
        """Shift ``x`` along time, then apply the wrapped module."""
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        """Return ``x`` with two channel slices shifted by one frame.

        Args:
            x: Tensor of size ``(n_batch * n_segment, c, h, w)``.
            n_segment: Number of frames per clip.
            fold_div: ``c // fold_div`` channels are shifted in each direction.
                NOTE(review): the default here (3) differs from the class
                default ``n_div=8``; ``forward`` always passes the class value.
            inplace: Use ``InplaceShift.apply`` instead of allocating a new
                tensor.

        Returns:
            Tensor with the same size as ``x``.
        """
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # InplaceShift is not part of this snippet; it must be in scope.
            out = InplaceShift.apply(x, fold)
        else:
            # Positions with no source frame (clip boundaries) stay zero.
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift

        return out.view(nt, c, h, w)
a. 对于数据的稀疏采样和密集采样
def _sample_indices(lf, record):
"""
:param record: VideoRecord
:
return: list
"""
if lf.den_sample:  # i3d den sample
sample_pos = max(1, 1 + record.num_frames - 64)
t_stride = 64 // lf.num_gments
start_idx = 0 if sample_pos == 1 el np.random.randint(0, sample_pos - 1)
offts = [(idx * t_stride + start_idx) % record.num_frames for idx in range(lf.num_gments)]
return np.array(offts) + 1
el:  # normal sample
average_duration = (record.num_frames - lf.new_length + 1) // lf.num_gments
学位英语
if average_duration > 0:
offts = np.multiply(list(range(lf.num_gments)), average_duration) + randint(average_duration,
size=lf.num_gments)
elif record.num_frames > lf.num_gments:
offts = np.sort(randint(record.num_frames - lf.new_length + 1, size=lf.num_gments))
el:
字母表 26个 大小写offts = np.zeros((lf.num_gments,))
return offts + 1
# b. Standard image augmentation: the training set uses GroupMultiScaleCrop;
#    the test set is first scaled and then center-cropped.
class GroupMultiScaleCrop(object):
    """Crop a group of images with one random scale/offset, then resize them.

    The same crop box is applied to every image in the group.

    Args:
        input_size: Output size; an int is expanded to ``[size, size]``.
        scales: Crop sizes as fractions of the shorter image side.
        max_distort: Maximum index distance between the scale picked for the
            width and the one picked for the height.
        fix_crop: If true, choose the offset from a fixed grid of positions
            instead of uniformly at random.
        more_fix_crop: If true, extend the fixed grid from 5 to 13 positions.
    """

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        """Crop and resize every image in ``img_group``; return the new list.

        Presumably the items are PIL images: the code uses ``.size``,
        ``.crop`` and ``.resize``.
        """
        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group

    def _sample_crop_size(self, im_size):
        """Return ``(crop_w, crop_h, w_offset, h_offset)`` for one random crop."""
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Snap to the target size when a candidate is within 3 pixels of it.
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one offset at random from the fixed grid."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """List the candidate ``(w_offset, h_offset)`` positions.

        The free space in each direction is split into 4 steps. Returns 5
        positions (corners and center), or 13 when ``more_fix_crop`` is true.
        """
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret
c. 对于预训练模型,采用 partial BN:只保留第一层 BN 正常训练,其余 BN 层全部冻结(切换到 eval 模式,并关闭 weight/bias 的梯度)
def train(self, mode=True):
    """
    Override the default train() to freeze the BN parameters.

    When ``self._enable_pbn`` is set and ``mode`` is true, every
    ``nn.BatchNorm2d`` in ``self.base_model`` except the first one is put
    into eval mode and its weight/bias stop receiving gradients.

    :param mode: forwarded to the parent ``train``.
    :return: None
    """
    super(TSN, self).train(mode)
    count = 0
    if self._enable_pbn and mode:
        print("Freezing BatchNorm2D except the first one.")
        for m in self.base_model.modules():
            if isinstance(m, nn.BatchNorm2d):
                count += 1
                # The first BN (count == 1) stays trainable.
                if count >= 2:
                    m.eval()
                    # shutdown update in frozen mode
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
d. 对于不同层采⽤不同学习率进⾏训练,
def get_optim_policies(self):
    """Group the model's parameters into optimizer policies.

    Walks ``self.modules()`` and sorts parameters by layer type. Reads
    ``self.fc_lr5``, ``self._enable_pbn`` and ``self.modality``.

    :return: list of dicts with keys ``params``, ``lr_mult``, ``decay_mult``
        and ``name``.
    :raises ValueError: for a leaf module that has parameters but is not a
        conv, linear or BatchNorm2d/3d layer.
    """
    first_conv_weight = []
    first_conv_bias = []
    normal_weight = []
    normal_bias = []
    lr5_weight = []
    lr10_bias = []
    bn = []
    custom_ops = []

    conv_cnt = 0
    bn_cnt = 0
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Conv3d)):
            ps = list(m.parameters())
            conv_cnt += 1
            if conv_cnt == 1:
                first_conv_weight.append(ps[0])
                if len(ps) == 2:
                    first_conv_bias.append(ps[1])
            else:
                normal_weight.append(ps[0])
                if len(ps) == 2:
                    normal_bias.append(ps[1])
        elif isinstance(m, nn.Linear):
            ps = list(m.parameters())
            if self.fc_lr5:
                lr5_weight.append(ps[0])
            else:
                normal_weight.append(ps[0])
            if len(ps) == 2:
                if self.fc_lr5:
                    lr10_bias.append(ps[1])
                else:
                    normal_bias.append(ps[1])
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
            bn_cnt += 1
            # later BN's are frozen
            if not self._enable_pbn or bn_cnt == 1:
                bn.extend(list(m.parameters()))
        elif len(m._modules) == 0:
            # Leaf module of an unknown type: refuse to guess a policy.
            if len(list(m.parameters())) > 0:
                raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))

    return [
        {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1,
         'name': "first_conv_weight"},
        {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0,
         'name': "first_conv_bias"},
        {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
         'name': "normal_weight"},
        {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
         'name': "normal_bias"},
        {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
         'name': "BN scale/shift"},
        {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1,
         'name': "custom_ops"},
        # for fc
        {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1,
         'name': "lr5_weight"},
        {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0,
         'name': "lr10_bias"},
    ]
TRN模型
1 TRN 模型的 backbone 也是参照 TSN 模型。以代码为例:前面提取特征的方式一样,以 8 帧代表一个 clip,得到 8 帧一共 8x256 的特征;然后进入 TRN 模块,从 8 帧中分别选取 [8, 7, 6, 5, 4, 3, 2] 帧作为各个子模块的输入,例如对于 2,就是从 8 帧中按时间顺序随机取其中 2 帧作为该子模块的输入。对每个子模块的特征应用 2 个卷积(先将 channel 变成 256,再变成最终的 num_class),得到 num_class 维特征,例如最终分类为 10 类,则得到 batch×10 的特征;最后将所有子模块的特征相加,得到最终的分类特征。
2 但是该模块扩展性不好:对于较大的输入帧数(假设输入 64 帧),子模块数量太多,无法训练。

本文发布于:2023-05-16 03:50:35,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/90/110146.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:操作   模块   特征   时域   训练   移动   提出   数据
相关文章
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图