-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathimage-tfrecord-builder.py
89 lines (73 loc) · 3.13 KB
/
image-tfrecord-builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import sys
import random
import cv2
import numpy as np
from tqdm import tqdm
import tensorflow as tf
from settings import app
def _load_image(path):
    """Read an image file from disk as an RGB float32 array.

    Args:
        path: Filesystem path passed straight to ``cv2.imread``.

    Returns:
        A ``numpy.float32`` array whose channels were reordered from OpenCV's
        BGR layout to RGB, or ``None`` when ``cv2.imread`` could not read the
        file.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread reports unreadable / non-image files by returning None.
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return rgb.astype(np.float32)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    Args:
        value: One bytes object; it becomes the only entry of the feature's
            ``BytesList``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly ``[value]``.
    """
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def _build_examples_list(input_folder, seed):
examples = []
for classname in os.listdir(input_folder):
class_dir = os.path.join(input_folder, classname)
if (os.path.isdir(class_dir)):
for filename in os.listdir(class_dir):
filepath = os.path.join(class_dir, filename)
example = {
'classname': classname,
'path': filepath
}
examples.append(example)
random.seed(seed)
random.shuffle(examples)
return examples
def _split_list(alist, wanted_parts=1):
length = len(alist)
return [ alist[i*length // wanted_parts: (i+1)*length // wanted_parts]
for i in range(wanted_parts) ]
def _get_examples_share(examples, training_split):
examples_size = len(examples)
len_training_examples = int(examples_size * training_split)
return np.split(examples, [len_training_examples])
def _serialize_example(example):
    """Load, JPEG-encode and serialize one example; return None if it cannot be encoded."""
    image = _load_image(example['path'])
    if image is None:
        return None
    # NOTE(review): _load_image returns RGB data, but cv2.imencode interprets
    # its input as BGR, so standard JPEG decoders (e.g. tf.image.decode_jpeg)
    # will see red/blue swapped. Kept unchanged so existing readers/models stay
    # consistent -- confirm the reader's expectation before "fixing" it.
    encoded_ok, encoded_image = cv2.imencode('.jpg', image)
    if not encoded_ok:
        return None
    feature = {
        'train/label': _bytes_feature(tf.compat.as_bytes(example['classname'])),
        'train/image': _bytes_feature(tf.compat.as_bytes(encoded_image.tobytes())),
    }
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
    return tf_example.SerializeToString()


def _write_tfrecord(examples, output_filename):
    """Write ``examples`` to one TFRecord file as JPEG-encoded ``tf.train.Example``s.

    Each record has two bytes features: ``'train/label'`` (the class name) and
    ``'train/image'`` (the JPEG-encoded image). Examples whose image cannot be
    read, encoded or serialized are skipped and reported on stderr, so a single
    bad file does not abort the shard. Errors writing the output file itself
    are not swallowed.

    Args:
        examples: Iterable of dicts with ``'classname'`` and ``'path'`` keys.
        output_filename: Path of the TFRecord file to create.
    """
    # tf.io.TFRecordWriter replaces tf.python_io.TFRecordWriter (removed in TF2);
    # the context manager closes the file even if the loop raises.
    with tf.io.TFRecordWriter(output_filename) as writer:
        for example in tqdm(examples):
            try:
                record = _serialize_example(example)
            except Exception as inst:  # pylint: disable=broad-except
                # Deliberate best-effort: report which file failed, then move on.
                tqdm.write('Skipping {0}: {1}'.format(example.get('path'), inst),
                           file=sys.stderr)
                continue
            if record is None:
                tqdm.write('Skipping unreadable image: {0}'.format(example['path']),
                           file=sys.stderr)
                continue
            writer.write(record)
def _write_sharded_tfrecord(examples, number_of_shards, base_output_filename, is_training=True):
    """Split ``examples`` into contiguous shards and write each to its own TFRecord file.

    Files are named ``<base>_<training|test>_<NN>of<MM>.tfrecord`` with 1-based,
    zero-padded shard numbers, e.g. ``out_training_01of04.tfrecord``. A shard
    file is written even if its shard happens to be empty.

    Args:
        examples: Sliceable sequence of example dicts (see ``_write_tfrecord``).
        number_of_shards: How many shard files to produce.
        base_output_filename: Path prefix for every shard file.
        is_training: Selects the ``training`` (True) or ``test`` (False) infix.
    """
    sharded_examples = _split_list(examples, number_of_shards)
    split_name = 'training' if is_training else 'test'
    # enumerate() has no len(), so tqdm needs the total explicitly; without it
    # the outer bar shows only a running count, with no percentage or ETA.
    for count, shard in tqdm(enumerate(sharded_examples, start=1),
                             total=len(sharded_examples)):
        output_filename = '{0}_{1}_{2:02d}of{3:02d}.tfrecord'.format(
            base_output_filename,
            split_name,
            count,
            number_of_shards
        )
        _write_tfrecord(shard, output_filename)
def main():
    """Build the training and test TFRecord shards described by ``settings.app``.

    Reads ``IMAGES_INPUT_FOLDER``, ``SEED``, ``TRAINING_EXAMPLES_SPLIT``,
    ``NUMBER_OF_SHARDS`` and ``OUTPUT_FILENAME`` from ``app``. Both splits use
    the same number of shards and the same output prefix.
    """
    examples = _build_examples_list(app['IMAGES_INPUT_FOLDER'], app['SEED'])
    training_examples, test_examples = _get_examples_share(examples, app['TRAINING_EXAMPLES_SPLIT'])  # pylint: disable=unbalanced-tuple-unpacking
    print("Creating training shards", flush=True)
    _write_sharded_tfrecord(training_examples, app['NUMBER_OF_SHARDS'], app['OUTPUT_FILENAME'])
    print("\nCreating test shards", flush=True)
    _write_sharded_tfrecord(test_examples, app['NUMBER_OF_SHARDS'], app['OUTPUT_FILENAME'], is_training=False)
    print("\n", flush=True)


# Only run the (slow, file-writing) conversion when executed as a script,
# not when the module is imported for its helpers.
if __name__ == '__main__':
    main()