-
Notifications
You must be signed in to change notification settings - Fork 882
/
Copy pathllm_launcher.py
287 lines (245 loc) · 7.61 KB
/
llm_launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import argparse
import contextlib
import os
import shutil
import subprocess
from pathlib import Path
from signal import pause
import torch
import yaml
from model_archiver import ModelArchiverConfig
from model_archiver.model_packaging import generate_model_archive
from ts.launcher import start, stop
from ts.utils.hf_utils import download_model
def create_tensorrt_llm_engine(
    model_store,
    model_name,
    dtype,
    snapshot_path,
    max_batch_size,
    trt_llm_repo_dir="/tmp/TensorRT-LLM",
    trt_llm_version="v0.12.0",
):
    """Build a TensorRT-LLM engine for a downloaded model, reusing cached steps.

    Three steps run in order; each is skipped when its output already exists:

    1. Clone the TensorRT-LLM repository (needed for its conversion script).
    2. Convert the snapshot at ``snapshot_path`` into a TensorRT-LLM checkpoint
       in ``<model_store>/<model_name>/tllm_checkpoint_1gpu_bf16``.
    3. Run ``trtllm-build`` to produce the engine in
       ``<model_store>/<model_name>/<model_name>-engine``.

    Args:
        model_store: Model store directory.
        model_name: Model name; outputs go under ``<model_store>/<model_name>``.
        dtype: Data type passed to the converter and used for the GEMM and
            GPT-attention plugins of ``trtllm-build``.
        snapshot_path: Local directory of the downloaded model snapshot.
        max_batch_size: Value for ``trtllm-build --max_batch_size``.
        trt_llm_repo_dir: Directory the TensorRT-LLM repository is (or will be)
            cloned into.
        trt_llm_version: Git branch or tag to clone.

    Raises:
        subprocess.CalledProcessError: If any external command exits with a
            non-zero status, so later steps never run against missing inputs.
        FileNotFoundError: If ``git`` or ``trtllm-build`` is not on PATH.
    """
    import sys

    repo_dir = Path(trt_llm_repo_dir)
    model_dir = Path(model_store) / model_name
    # NOTE: the directory name is kept as-is regardless of ``dtype`` so that
    # checkpoints cached by earlier runs are still found.
    checkpoint_dir = model_dir / "tllm_checkpoint_1gpu_bf16"
    engine_dir = model_dir / f"{model_name}-engine"

    if not repo_dir.exists():
        subprocess.run(
            [
                "git",
                "clone",
                "https://github.com/NVIDIA/TensorRT-LLM.git",
                "-b",
                trt_llm_version,
                str(repo_dir),
            ],
            check=True,
        )

    if not checkpoint_dir.exists():
        # This is the Llama example converter; other architectures presumably
        # need a different script — TODO confirm for non-Llama model ids.
        # sys.executable runs the converter with this interpreter rather than
        # whichever ``python`` happens to be first on PATH.
        subprocess.run(
            [
                sys.executable,
                str(repo_dir / "examples" / "llama" / "convert_checkpoint.py"),
                "--model_dir",
                snapshot_path,
                "--output_dir",
                str(checkpoint_dir),
                "--dtype",
                dtype,
            ],
            check=True,
        )

    if not engine_dir.exists():
        subprocess.run(
            [
                "trtllm-build",
                "--checkpoint_dir",
                str(checkpoint_dir),
                "--gemm_plugin",
                dtype,
                "--gpt_attention_plugin",
                dtype,
                "--max_batch_size",
                str(max_batch_size),
                "--output_dir",
                str(engine_dir),
            ],
            check=True,
        )
def get_model_config(args, model_snapshot_path=None):
    """Build the TorchServe model-config dictionary for the selected LLM engine.

    Args:
        args: Parsed command-line namespace. Options with dotted names (e.g.
            ``vllm_engine.max_num_seqs``) are read with ``getattr`` because the
            dot is kept in the attribute name.
        model_snapshot_path: Path of the downloaded model snapshot; required when
            ``args.engine == "trt_llm"``. A relative path is taken relative to
            the current working directory.

    Returns:
        dict: Worker settings shared by all engines plus an engine-specific
        ``handler`` section (and ``parallelLevel`` for vLLM).

    Raises:
        NotImplementedError: If ``args`` has a ``lora_adapter_ids`` attribute
            (LoRA support is not implemented yet).
        ValueError: If the engine is ``"trt_llm"`` and no snapshot path is given.
        RuntimeError: If ``args.engine`` is neither ``"vllm"`` nor ``"trt_llm"``.
    """
    model_config = {
        "minWorkers": 1,
        "maxWorkers": 1,
        "batchSize": 1,
        "maxBatchDelay": 100,
        "responseTimeout": 1200,
        "startupTimeout": args.startup_timeout,
        "deviceType": "gpu",
        "asyncCommunication": True,
    }

    if args.engine == "vllm":
        download_dir = getattr(args, "vllm_engine.download_dir")
        if download_dir:
            download_dir = Path(download_dir).resolve().as_posix()
        # ``is_available`` must be *called*: the bare function object is always
        # truthy, which made CPU-only hosts report 0 GPUs instead of 1.
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
        model_config.update(
            {
                "parallelLevel": num_gpus,
                "handler": {
                    "model_path": args.model_id,
                    "vllm_engine_config": {
                        "max_num_seqs": getattr(args, "vllm_engine.max_num_seqs"),
                        "max_model_len": getattr(args, "vllm_engine.max_model_len"),
                        "download_dir": download_dir,
                        "tensor_parallel_size": num_gpus,
                    },
                },
            }
        )
        if hasattr(args, "lora_adapter_ids"):
            # TODO: wire up LoRA (e.g. set "enable_lora" in vllm_engine_config
            # and register args.lora_adapter_ids) once it is supported.
            raise NotImplementedError("Lora setting needs to be implemented")
    elif args.engine == "trt_llm":
        if model_snapshot_path is None:
            raise ValueError("model_snapshot_path is required for the trt_llm engine")
        model_config.update(
            {
                "handler": {
                    "tokenizer_dir": os.path.join(os.getcwd(), model_snapshot_path),
                    "engine_dir": f"{args.model_name}-engine",
                    "kv_cache_config": {
                        "free_gpu_memory_fraction": getattr(
                            args, "trt_llm_engine.kv_cache_free_gpu_memory_fraction"
                        ),
                    },
                }
            }
        )
    else:
        raise RuntimeError("Unsupported LLM Engine")
    return model_config
@contextlib.contextmanager
def create_mar_file(args, model_snapshot_path=None):
    """Ensure the model directory exists in the model store and yield its path.

    Unless ``<model_store>/<model_name>`` already exists (in which case it is
    reused as-is), it is generated by ``generate_model_archive`` in
    ``no-archive`` format. The archiver is fed a temporary
    ``<model_store>/model-config.yaml`` built by ``get_model_config``. That file
    is removed afterwards, even if archiving fails.

    For the ``vllm`` engine the model directory is deleted when the context
    exits, including when the ``with`` body raises, so a later run never reuses
    a stale config. For other engines it is kept. The TensorRT-LLM build
    outputs are written under this same directory.

    Args:
        args: Parsed command-line namespace (``model_store``, ``model_name``,
            ``engine`` and the engine options read by ``get_model_config``).
        model_snapshot_path: Downloaded snapshot path, forwarded to
            ``get_model_config``.

    Yields:
        str: POSIX path of the model directory.

    Raises:
        RuntimeError: If the model directory does not exist after generation.
    """
    mar_file_path = Path(args.model_store) / args.model_name
    # Build the config before touching the filesystem. Invalid engine or option
    # combinations then fail early, even when an existing directory is reused,
    # and never leave a half-written YAML file behind.
    model_config = get_model_config(args, model_snapshot_path)

    if not mar_file_path.exists():
        model_config_yaml = Path(args.model_store) / "model-config.yaml"
        with model_config_yaml.open("w") as f:
            yaml.dump(model_config, f)
        try:
            config = ModelArchiverConfig(
                model_name=args.model_name,
                version="1.0",
                handler=f"{args.engine}_handler",
                serialized_file=None,
                export_path=args.model_store,
                requirements_file=None,
                runtime="python",
                force=True,
                config_file=model_config_yaml.as_posix(),
                archive_format="no-archive",
            )
            generate_model_archive(config)
        finally:
            # The YAML is only an input for the archiver.
            model_config_yaml.unlink()

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not mar_file_path.exists():
        raise RuntimeError(f"Model directory was not created: {mar_file_path}")

    try:
        yield mar_file_path.as_posix()
    finally:
        if args.engine == "vllm":
            shutil.rmtree(mar_file_path)
def main(args):
    """Prepare the requested model and serve it with TorchServe.

    Creates the model store if needed. For the ``trt_llm`` engine it also
    downloads the model snapshot and builds a TensorRT-LLM engine inside the
    model directory. It then starts TorchServe and waits in ``signal.pause()``.
    A ``KeyboardInterrupt`` (Ctrl+C) is swallowed, and TorchServe is stopped on
    every exit path.
    """
    store = Path(args.model_store)
    store.mkdir(parents=True, exist_ok=True)

    use_trt_llm = args.engine == "trt_llm"
    snapshot = None
    if use_trt_llm:
        snapshot = download_model(args.model_id)

    with create_mar_file(args, snapshot):
        if use_trt_llm:
            batch_limit = getattr(args, "trt_llm_engine.max_batch_size")
            create_tensorrt_llm_engine(
                args.model_store, args.model_name, args.dtype, snapshot, batch_limit
            )
        try:
            # Ctrl+C is the normal way to end the session; suppress it quietly.
            with contextlib.suppress(KeyboardInterrupt):
                start(
                    model_store=args.model_store,
                    no_config_snapshots=True,
                    models=args.model_name,
                    disable_token=args.disable_token_auth,
                )
                pause()
        finally:
            stop(wait=False)
if __name__ == "__main__":
    # Command-line entry point. Options whose names contain a dot (for example
    # ``--vllm_engine.max_num_seqs``) keep the dot in their namespace attribute,
    # which is why the rest of the script reads them with ``getattr``.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="model",
        help="Model name",
    )
    parser.add_argument(
        "--model_store",
        type=str,
        default="model_store",
        help="Model store",
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Meta-Llama-3.1-8B-Instruct",
        help="Model id",
    )
    parser.add_argument(
        "--disable_token_auth",
        action="store_true",
        help="Disable token authentication",
    )
    # --- vLLM engine options (used only when --engine vllm) ---
    parser.add_argument(
        "--vllm_engine.max_num_seqs",
        type=int,
        default=256,
        help="Max sequences in vllm engine",
    )
    parser.add_argument(
        "--vllm_engine.max_model_len",
        type=int,
        default=None,
        help="Model context length",
    )
    parser.add_argument(
        "--vllm_engine.download_dir",
        type=str,
        default=None,
        help="Cache dir",
    )
    parser.add_argument(
        "--startup_timeout",
        type=int,
        default=1200,
        help="Model startup timeout in seconds",
    )
    parser.add_argument(
        "--engine",
        type=str,
        default="vllm",
        # The only engines get_model_config knows how to configure; rejecting
        # anything else here gives a usage error instead of a late traceback.
        choices=["vllm", "trt_llm"],
        help="LLM engine",
    )
    # --- TensorRT-LLM options (used only when --engine trt_llm) ---
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Data type",
    )
    parser.add_argument(
        "--trt_llm_engine.max_batch_size",
        type=int,
        default=4,
        help="Max batch size",
    )
    parser.add_argument(
        "--trt_llm_engine.kv_cache_free_gpu_memory_fraction",
        # A fractional value (default 0.1); ``type=int`` rejected e.g. ``0.2``.
        type=float,
        default=0.1,
        help="KV Cache free gpu memory fraction",
    )
    args = parser.parse_args()
    main(args)