Skip to content

Commit 5ca01a1

Browse files
committed
add flow module
1 parent 6794bd5 commit 5ca01a1

File tree

8 files changed

+14012
-75
lines changed

8 files changed

+14012
-75
lines changed

datasets/ucf101.py

Lines changed: 137 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,20 @@
66
import numpy as np
77
import cv2
88

9-
IMG_EXTENSIONS = [
10-
'.jpg', '.JPG', '.jpeg', '.JPEG',
11-
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12-
]
13-
14-
def is_image_file(filename):
15-
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16-
179
def find_classes(dir):
    """Discover class names from the immediate sub-directories of ``dir``.

    Returns:
        A tuple ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of sub-directory names and ``class_to_idx`` maps each name to
        its position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx
2214

23-
def cv2_loader(path):
24-
return cv2.imread(path)
25-
26-
def read_split_file(root, split_file):
15+
def make_dataset(root, source):
2716

28-
if not os.path.exists(split_file):
29-
print("Split file for ucf101 dataset doesn't exist.")
17+
if not os.path.exists(source):
18+
print("Setting file %s for ucf101 dataset doesn't exist." % (source))
3019
sys.exit()
3120
else:
3221
clips = []
33-
with open(split_file) as split_f:
22+
with open(source) as split_f:
3423
data = split_f.readlines()
3524
for line in data:
3625
line_info = line.split()
@@ -40,53 +29,160 @@ def read_split_file(root, split_file):
4029
item = (clip_path, duration, target)
4130
clips.append(item)
4231
return clips
43-
32+
33+
def ReadSegmentRGB(path, offsets, new_height, new_width, new_length, is_color, name_pattern):
    """Read image frames for each segment and stack them along the channel axis.

    For every entry in ``offsets``, ``new_length`` consecutive frames are
    loaded; the index substituted into ``name_pattern`` runs from
    ``offset + 1`` to ``offset + new_length``.

    Args:
        path: Directory containing the frame images.
        offsets: Sequence of integer frame offsets, one per segment.
        new_height: Target height; resizing happens only when both
            ``new_height`` and ``new_width`` are positive.
        new_width: Target width (see ``new_height``).
        new_length: Number of consecutive frames to read per offset.
        is_color: If true, frames are read with ``cv2.IMREAD_COLOR`` and
            converted from BGR to RGB; otherwise they are read with
            ``cv2.IMREAD_GRAYSCALE`` and kept as a single channel.
        name_pattern: ``%``-format string taking the frame index.

    Returns:
        The frames concatenated with ``np.concatenate(..., axis=2)``.

    Note:
        If a frame cannot be loaded, a message is printed and the process
        exits through ``sys.exit()``.
    """
    cv_read_flag = cv2.IMREAD_COLOR if is_color else cv2.IMREAD_GRAYSCALE
    interpolation = cv2.INTER_LINEAR
    should_resize = new_width > 0 and new_height > 0

    sampled_list = []
    for offset in offsets:
        for length_id in range(1, new_length + 1):
            frame_name = name_pattern % (length_id + offset)
            frame_path = os.path.join(path, frame_name)
            cv_img_origin = cv2.imread(frame_path, cv_read_flag)
            if cv_img_origin is None:
                # TODO: error handling here
                print("Could not load file %s" % (frame_path))
                sys.exit()
            if should_resize:
                # interpolation must be passed by keyword: the third
                # positional parameter of cv2.resize is ``dst``.
                cv_img = cv2.resize(cv_img_origin, (new_width, new_height),
                                    interpolation=interpolation)
            else:
                cv_img = cv_img_origin
            if is_color:
                cv_img = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
            else:
                # Grayscale frames are 2-D; BGR2RGB does not apply, and a
                # channel axis is needed so the axis=2 concatenation works.
                cv_img = np.expand_dims(cv_img, 2)
            sampled_list.append(cv_img)
    clip_input = np.concatenate(sampled_list, axis=2)
    return clip_input
59+
60+
def ReadSegmentFlow(path, offsets, new_height, new_width, new_length, is_color, name_pattern):
    """Read paired x/y flow frames for each segment and stack them on axis 2.

    For every entry in ``offsets``, ``new_length`` consecutive frame pairs
    are loaded; ``name_pattern`` is formatted with ``("x", index)`` and
    ``("y", index)`` where ``index`` runs from ``offset + 1`` to
    ``offset + new_length``. The x frame is appended before the y frame.

    Args:
        path: Directory containing the flow images.
        offsets: Sequence of integer frame offsets, one per segment.
        new_height: Target height; resizing happens only when both
            ``new_height`` and ``new_width`` are positive.
        new_width: Target width (see ``new_height``).
        new_length: Number of consecutive frame pairs to read per offset.
        is_color: Selects ``cv2.IMREAD_COLOR`` or ``cv2.IMREAD_GRAYSCALE``.
        name_pattern: ``%``-format string taking a direction string and the
            frame index.

    Returns:
        The frames, each given an extra axis with ``np.expand_dims(img, 2)``,
        concatenated with ``np.concatenate(..., axis=2)``.

    Note:
        If either frame of a pair cannot be loaded, a message is printed and
        the process exits through ``sys.exit()``.
    """
    # NOTE(review): with is_color=True the frames are presumably 3-channel,
    # so expand_dims would yield 4-D arrays; callers likely need
    # is_color=False for flow -- confirm.
    cv_read_flag = cv2.IMREAD_COLOR if is_color else cv2.IMREAD_GRAYSCALE
    interpolation = cv2.INTER_LINEAR
    should_resize = new_width > 0 and new_height > 0

    sampled_list = []
    for offset in offsets:
        for length_id in range(1, new_length + 1):
            frame_index = length_id + offset
            frame_path_x = os.path.join(path, name_pattern % ("x", frame_index))
            cv_img_origin_x = cv2.imread(frame_path_x, cv_read_flag)
            frame_path_y = os.path.join(path, name_pattern % ("y", frame_index))
            cv_img_origin_y = cv2.imread(frame_path_y, cv_read_flag)
            if cv_img_origin_x is None or cv_img_origin_y is None:
                # TODO: error handling here
                print("Could not load file %s or %s" % (frame_path_x, frame_path_y))
                sys.exit()
            if should_resize:
                # interpolation must be passed by keyword: the third
                # positional parameter of cv2.resize is ``dst``.
                cv_img_x = cv2.resize(cv_img_origin_x, (new_width, new_height),
                                      interpolation=interpolation)
                cv_img_y = cv2.resize(cv_img_origin_y, (new_width, new_height),
                                      interpolation=interpolation)
            else:
                cv_img_x = cv_img_origin_x
                cv_img_y = cv_img_origin_y
            sampled_list.append(np.expand_dims(cv_img_x, 2))
            sampled_list.append(np.expand_dims(cv_img_y, 2))

    clip_input = np.concatenate(sampled_list, axis=2)
    return clip_input
92+
4493

4594
class ucf101(data.Dataset):
4695

47-
def __init__(self, root, split_file, phase, new_length=1, transform=None, target_transform=None,
48-
video_transform=None, loader=cv2_loader):
96+
def __init__(self,
97+
root,
98+
source,
99+
phase,
100+
modality,
101+
name_pattern=None,
102+
is_color=True,
103+
num_segments=1,
104+
new_length=1,
105+
new_width=0,
106+
new_height=0,
107+
transform=None,
108+
target_transform=None,
109+
video_transform=None):
110+
49111
classes, class_to_idx = find_classes(root)
50-
clips = read_split_file(root, split_file)
112+
clips = make_dataset(root, source)
51113

52114
if len(clips) == 0:
53115
raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
54116
"Check your data directory."))
55117

56118
self.root = root
57-
self.split_file = split_file
119+
self.source = source
58120
self.phase = phase
59-
self.clips = clips
121+
self.modality = modality
122+
60123
self.classes = classes
61124
self.class_to_idx = class_to_idx
125+
self.clips = clips
126+
127+
if name_pattern:
128+
self.name_pattern = name_pattern
129+
else:
130+
if self.modality == "rgb":
131+
self.name_pattern = "image_%04d.jpg"
132+
elif self.modality == "flow":
133+
self.name_pattern = "flow_%s_%04d.jpg"
134+
135+
self.is_color = is_color
136+
self.num_segments = num_segments
62137
self.new_length = new_length
138+
self.new_width = new_width
139+
self.new_height = new_height
140+
63141
self.transform = transform
64142
self.target_transform = target_transform
65143
self.video_transform = video_transform
66-
self.loader = loader
67144

68145
def __getitem__(self, index):
69146
path, duration, target = self.clips[index]
70-
frame_list = os.listdir(path)
71-
frame_list.sort()
72-
if self.phase == "train":
73-
sampled_frameID = random.randint(0, duration-self.new_length)
74-
elif self.phase == "val":
75-
if duration >= self.new_length:
76-
sampled_frameID = int((duration - self.new_length + 1)/2)
147+
average_duration = int(duration / self.num_segments)
148+
offsets = []
149+
for seg_id in range(self.num_segments):
150+
if self.phase == "train":
151+
if average_duration >= self.new_length:
152+
offset = random.randint(0, average_duration - self.new_length)
153+
# No +1 because randint(a,b) return a random integer N such that a <= N <= b.
154+
offsets.append(offset + seg_id * average_duration)
155+
else:
156+
offsets.append(0)
157+
elif self.phase == "val":
158+
if average_duration >= self.new_length:
159+
offsets.append(int((average_duration - self.new_length + 1)/2 + seg_id * average_duration))
160+
else:
161+
offsets.append(0)
77162
else:
78-
sampled_frameID = 0
163+
print("Only phase train and val are supported.")
164+
165+
166+
if self.modality == "rgb":
167+
clip_input = ReadSegmentRGB(path,
168+
offsets,
169+
self.new_height,
170+
self.new_width,
171+
self.new_length,
172+
self.is_color,
173+
self.name_pattern
174+
)
175+
elif self.modality == "flow":
176+
clip_input = ReadSegmentFlow(path,
177+
offsets,
178+
self.new_height,
179+
self.new_width,
180+
self.new_length,
181+
self.is_color,
182+
self.name_pattern
183+
)
79184
else:
80-
print("No such phase. Only train and val are supported.")
81-
82-
sampled_list = []
83-
for frame_id in range(self.new_length):
84-
fname = os.path.join(path, frame_list[sampled_frameID+frame_id])
85-
if is_image_file(fname):
86-
img = self.loader(fname)
87-
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
88-
sampled_list.append(img)
89-
clip_input = np.concatenate(sampled_list, axis=2)
185+
print("No such modality %s" % (self.modality))
90186

91187
if self.transform is not None:
92188
clip_input = self.transform(clip_input)
@@ -97,5 +193,6 @@ def __getitem__(self, index):
97193

98194
return clip_input, target
99195

196+
100197
def __len__(self):
101198
return len(self.clips)

main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@
4545
help='number of total epochs to run')
4646
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
4747
help='manual epoch number (useful on restarts)')
48-
parser.add_argument('-b', '--batch-size', default=50, type=int,
48+
# Help text now reports the actual default (32), not the previous value.
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N', help='mini-batch size (default: 32)')
# Per the help text this mirrors Caffe's iter size; how the value is consumed
# is not visible in this hunk. Help text now reports the actual default (4).
parser.add_argument('--iter-size', default=4, type=int,
                    metavar='I', help='iter size as in Caffe to reduce memory usage (default: 4)')
5052
parser.add_argument('--new_length', default=1, type=int,
5153
metavar='N', help='length of sampled video frames (default: 1)')
5254
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,

0 commit comments

Comments
 (0)