-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathprint.py
79 lines (59 loc) · 2.83 KB
/
print.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Parse arguments
import argparse

NLAYER_DISPLAY = 5  # Default number of layers shown at each end of a long summary


def _non_negative_int(text):
    """Parse *text* as an integer >= 0 for argparse.

    Negative counts would make the head/tail slicing of the summary index
    out of range, so they are rejected up front with a readable CLI error.
    Non-integer text raises ValueError, which argparse also reports cleanly.
    """
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(
            "expected a non-negative integer, got {}".format(text))
    return value


parser = argparse.ArgumentParser()
parser.add_argument("--verbose",
                    action="store_true",
                    help="Print all.")
parser.add_argument("-n",
                    dest="nlayer_display",
                    type=_non_negative_int,
                    default=NLAYER_DISPLAY,
                    help="Number of layers to be printed at the head and at "
                         "the tail of a long summary (default: %(default)s).")
args = parser.parse_args()
# Load required modules. The project import comes after the progress message
# so the user sees it before the import runs.
print("Loading the required Python modules ...")
from utils import get_predict_api

# Load model: get_predict_api() returns exactly two items; only the second
# one (bound to `model`) is used by this script.
print("\nLoading the model ...")
_, model = get_predict_api()
# Print model summary
class PrintFn:
    """Line sink for ``model.summary(print_fn=...)`` that abbreviates long summaries.

    Each call receives one line of the summary. Lines are buffered, and the
    sections are detected from the border lines: a ``====`` line opens the
    layer table, ``____`` lines separate layers, a second ``====`` line
    starts the summary (totals) section, and the next ``____`` line is
    treated as the last line. Only when that last line arrives is anything
    printed. Then either the full summary is printed, or, when it is long
    and ``verbose`` is false, only the header, the first and last
    ``nlayer_display`` layers, and the totals.

    Args:
        verbose: If true, always print the complete summary.
        nlayer_display: Number of layers kept at the head and at the tail
            of an abbreviated summary. Negative values are treated as 0.
    """

    def __init__(self, verbose=False, nlayer_display=NLAYER_DISPLAY):
        self.nlayer = 0  # Number of layers seen so far
        # Number of layers to be displayed at each end. Clamped so that the
        # head/tail indexing in _print_summary stays within range.
        self.nlayer_display = max(0, nlayer_display)
        # Index in `summary` of the "====" line that opens the layer table;
        # None while still in the header. None (not False) is used so that a
        # valid index of 0 is not mistaken for "not seen yet".
        self.layer_start = None
        # Index of the "====" line that closes the layer table; None until seen.
        self.layer_end = None
        self.outer_border = "===="  # Keras uses it to demarcate header/totals from the layers
        self.inner_border = "____"  # Keras uses it to demarcate layers
        self.verbose = verbose
        self.summary = []  # All lines received from model.summary()
        self.layer_index = []  # Index of the border line preceding each layer

    def _mark_layer(self):
        """Record that the most recently buffered line precedes a new layer."""
        self.nlayer += 1
        self.layer_index.append(len(self.summary) - 1)

    def __call__(self, line):
        """Buffer one summary line; print the summary once the last line arrives."""
        self.summary.append(line)
        if self.layer_start is None:  # Header section
            if line.startswith(self.outer_border):
                self.layer_start = len(self.summary) - 1
                self._mark_layer()
        elif self.layer_end is None:  # Layers section
            if line.startswith(self.inner_border):
                self._mark_layer()
            elif line.startswith(self.outer_border):  # Summary section begins
                self.layer_end = len(self.summary) - 1
        # Actual printing is carried out when the last line is fed
        elif line.startswith(self.inner_border):  # Last line
            self._print_summary()

    def _print_summary(self):
        """Print the buffered summary, abbreviating the layer table if it is long."""
        # Add number of layers to summary, just above the closing border line.
        # This insertion lies after every recorded index, so they stay valid.
        self.summary.insert(-1, "Different layers: {}".format(self.nlayer))
        # If verbose or model.summary() is short, print all. The "+ 3" means
        # abbreviation only happens when at least 3 layers would be hidden.
        if self.verbose or self.nlayer < (self.nlayer_display * 2 + 3):
            print("\n".join(self.summary))
            return
        # Otherwise print the header plus the first nlayer_display layers ...
        head = self.layer_index[self.nlayer_display] + 1
        # ... and the last nlayer_display layers plus the totals. Appending
        # layer_end makes nlayer_display == 0 resolve to the closing "===="
        # line instead of indexing past the end of layer_index.
        tail = (self.layer_index + [self.layer_end])[self.nlayer - self.nlayer_display]
        print("\n".join(self.summary[:head]))
        print("\n ...\n" * 2)
        print("\n".join(self.summary[tail:]))
# Route every summary line through PrintFn, which buffers the lines and
# prints the full or abbreviated summary according to the CLI options.
summary_printer = PrintFn(verbose=args.verbose,
                          nlayer_display=args.nlayer_display)
model.summary(print_fn=summary_printer)