-
Notifications
You must be signed in to change notification settings - Fork 882
/
Copy pathdensenet_handler.py
147 lines (119 loc) · 4.63 KB
/
densenet_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Module for image classification default handler
"""
import inspect
import logging
import os
import importlib.util
import time
import io
import torch
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
class DenseNetHandler:
    """
    DenseNetHandler handler class. This handler takes an image
    and returns the name of object in that image.
    """

    def __init__(self):
        # All fields are populated later by initialize(); the handler is
        # unusable until `initialized` is True.
        self.model = None
        self.device = None
        self.initialized = False
        self.context = None
        self.manifest = None
        self.map_location = None

    def initialize(self, context):
        """First try to load torchscript else load eager mode state_dict based model"""
        properties = context.system_properties
        gpu_id = properties.get("gpu_id")
        use_gpu = torch.cuda.is_available() and gpu_id is not None
        self.map_location = "cuda" if use_gpu else "cpu"
        # Pin to the specific GPU index when one was assigned, else plain CPU.
        self.device = torch.device(
            "{}:{}".format(self.map_location, gpu_id) if use_gpu else self.map_location
        )
        self.manifest = context.manifest

        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        # A "modelFile" entry in the manifest signals an eager-mode model;
        # its absence means the serialized file is a TorchScript archive.
        model_file = self.manifest["model"].get("modelFile", "")
        if model_file:
            logger.debug("Loading eager model")
            self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
        else:
            logger.debug("Loading torchscript model")
            self.model = self._load_torchscript_model(model_pt_path)

        self.model.to(self.device)
        self.model.eval()
        logger.debug("Model file %s loaded successfully", model_pt_path)
        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Deserialize a TorchScript archive, remapping storage to the configured device."""
        return torch.jit.load(model_pt_path, map_location=self.map_location)

    def _load_pickled_model(self, model_dir, model_file, model_pt_path):
        """
        Import the single model class defined in *model_file* and load the
        state_dict weights from *model_pt_path* into a fresh instance.
        """
        model_def_path = os.path.join(model_dir, model_file)
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model.py file")
        module = importlib.import_module(model_file.split(".")[0])
        model_class_definitions = list_classes_from_module(module)
        # Exactly one class must be defined, or we cannot know which to build.
        if len(model_class_definitions) != 1:
            raise ValueError(
                "Expected only one class as model definition. {}".format(
                    model_class_definitions
                )
            )
        model = model_class_definitions[0]()
        state_dict = torch.load(model_pt_path, map_location=self.map_location)
        model.load_state_dict(state_dict)
        return model

    def inference(self, data, *args, **kwargs):
        """
        Override to customize the inference
        :param data: Torch tensor, matching the model input shape
        :return: Prediction output as Torch tensor
        """
        batch = data.to(self.device)
        with torch.no_grad():
            return self.model(batch, *args, **kwargs)

    def handle(self, data, context):
        """
        Entry point for default handler
        """
        # It can be used for pre or post processing if needed as additional request
        # information is available in context
        started = time.time()
        self.context = context
        metrics = self.context.metrics

        # Each request row carries a torch-serialized tensor under "data" or "body".
        # NOTE(review): torch.load deserializes request bytes via pickle — safe only
        # for trusted clients; confirm the deployment's threat model.
        tensors = [
            torch.load(io.BytesIO(row.get("data") or row.get("body"))) for row in data
        ]
        output = self.inference(torch.stack(tensors))

        metrics.add_time(
            "HandlerTime", round((time.time() - started) * 1000, 2), None, "ms"
        )
        return [output]
def list_classes_from_module(module, parent_class=None):
    """
    Parse user defined module to get all model service classes in it.

    :param module: imported module object to inspect
    :param parent_class: optional base class used to filter the result
    :return: List of model service class definitions
    """

    def _defined_in_module(member):
        # Keep only classes declared in *module* itself, not ones it imported.
        return inspect.isclass(member) and member.__module__ == module.__name__

    classes = [obj for _, obj in inspect.getmembers(module, _defined_in_module)]
    if parent_class is None:
        return classes
    # Narrow to subclasses of the requested base when one was given.
    return [cls for cls in classes if issubclass(cls, parent_class)]