-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
74 lines (56 loc) · 2.29 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, MaxPooling2D, Flatten, Dropout, UpSampling2D
from tensorflow.keras.models import Model
class VGGBlock(Model):
    """
    Stack of same-padded Conv2D + ReLU layers followed by max pooling.

    Args:
        filters: number of output channels of every conv layer in the block.
        kernel_size: conv kernel size, passed straight to ``Conv2D``.
        repetitions: how many conv layers to stack before pooling.
        pool_size: pooling window passed to ``MaxPooling2D``.
        strides: pooling stride passed to ``MaxPooling2D`` (the conv layers
            always use stride 1).
        block: index used only to build layer names
            (``conv{block}_{i}`` and ``pool{block}``).
        dropout: optional dropout rate applied to the pooled output;
            ``None`` (or 0) disables dropout.
    """

    def __init__(self, filters: int, kernel_size: int, repetitions: int,
                 pool_size: int = 2, strides: int = 2, block: int = 0,
                 dropout: float = None):
        super().__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.repetitions = repetitions
        # Only built when a non-zero rate is given; call() skips it otherwise.
        self.dropout = Dropout(dropout) if dropout else None
        # "same" padding with stride 1 keeps the spatial size through the
        # conv stack; only the pooling layer downsamples.
        self.rows = tf.keras.Sequential([
            Conv2D(filters, kernel_size, strides=1, padding="same",
                   activation="relu", name=f"conv{block}_{i + 1}")
            for i in range(repetitions)
        ])
        self.pool = MaxPooling2D(pool_size=pool_size, strides=strides,
                                 name=f"pool{block}")

    def call(self, inputs: tf.Tensor, training=None):
        """Run the conv stack, pool, then apply dropout if configured.

        Args:
            inputs: input tensor for the first conv layer.
            training: forwarded to the dropout layer so it is active only in
                training mode; ``None`` lets Keras resolve it from context.

        Returns:
            The pooled (and possibly dropped-out) feature map.
        """
        x = self.rows(inputs)
        out = self.pool(x)
        if self.dropout is not None:
            out = self.dropout(out, training=training)
        return out
class VGG16(Model):
    """
    VGG16-style image classifier built from five ``VGGBlock`` stages.

    Layers, as constructed below:
      * conv blocks with 64, 128, 256, 512, 512 filters and 2, 2, 3, 3, 3
        conv layers respectively (3x3 kernels), each block ending in a
        2x2 / stride-2 max pool;
      * Flatten -> Dense(4096, relu) -> Dropout(0.5)
        -> Dense(4096, relu) -> Dropout(0.5)
        -> Dense(num_classes, softmax).

    Args:
        num_classes: size of the softmax output.
        upsample_input: optional ``UpSampling2D`` size applied to the inputs
            before the first block (e.g. to enlarge small images); ``None``
            or 0 disables upsampling.

    NOTE(review): inputs are presumably batches of images in Keras' default
    channels-last layout. Confirm this against the data pipeline. The five
    stride-2 pools shrink each spatial dimension by roughly 32x, so inputs
    (after upsampling) presumably need to be at least 32x32.
    """

    def __init__(self, num_classes: int = 10, upsample_input: int = None):
        super().__init__()
        self.num_classes = num_classes
        self.upsample_input = upsample_input
        # Build the upsampling layer once rather than instantiating a new
        # layer object inside every call()/trace.
        self.upsample = (UpSampling2D(size=upsample_input)
                         if upsample_input else None)
        self.block1 = VGGBlock(64, 3, 2, block=1)
        self.block2 = VGGBlock(128, 3, 2, block=2)
        self.block3 = VGGBlock(256, 3, 3, block=3)
        self.block4 = VGGBlock(512, 3, 3, block=4)
        self.block5 = VGGBlock(512, 3, 3, block=5)
        self.flatten = Flatten()
        self.fc1 = Dense(4096, activation="relu")
        self.fc2 = Dense(4096, activation="relu")
        # A single Dropout layer is reused after both fully-connected layers.
        self.dropout = Dropout(0.5)
        self.classifier = Dense(self.num_classes, activation="softmax")

    def call(self, inputs: tf.Tensor, training=None):
        """Forward pass returning per-class softmax probabilities.

        Args:
            inputs: image batch (see the class-level NOTE on layout/size).
            training: forwarded to the dropout layers; ``None`` lets Keras
                resolve the mode from the call context.

        Returns:
            Tensor of softmax outputs with ``num_classes`` entries per sample.
        """
        x = inputs
        if self.upsample is not None:
            x = self.upsample(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.flatten(x)
        x = self.dropout(self.fc1(x), training=training)
        x = self.dropout(self.fc2(x), training=training)
        out = self.classifier(x)
        return out