# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
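
"""Streaming pipeline that reads text prompts from Pub/Sub, generates a
response for each prompt with a Gemma model via RunInference, and publishes
the responses back to Pub/Sub."""
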
from collections.abc import Iterable, Sequence
import logging
from typing import Any
from typing import Optional
import apache_beam as beam
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.options.pipeline_options import PipelineOptions
import keras_nlp
from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM


class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]):
    def __init__(
        self,
        model_name: str = "gemma_2B",
    ):
        """Implementation of the ModelHandler interface for Gemma using text as input.

        Example Usage::

            pcoll | RunInference(GemmaModelHandler())

        Args:
            model_name: The Gemma model name. Default is gemma_2B.
        """
        self._model_name = model_name
        self._env_vars = {}

    def share_model_across_processes(self) -> bool:
        """Indicates whether the model should be loaded once per VM rather
        than once per worker process on a VM. Because Gemma is a large
        language model, this always returns True to avoid out-of-memory
        errors.
        """
        return True

    def load_model(self) -> GemmaCausalLM:
        """Loads and initializes a model for processing."""
        # `from_preset` accepts a preset name or a local path; here the
        # default "gemma_2B" is a path to the model inside the custom
        # worker container (see the --model_path flag below).
        return keras_nlp.models.GemmaCausalLM.from_preset(self._model_name)

    def run_inference(
        self,
        batch: Sequence[str],
        model: GemmaCausalLM,
        inference_args: Optional[dict[str, Any]] = None,
    ) -> Iterable[PredictionResult]:
        """Runs inference on a batch of text strings.

        Args:
            batch: A sequence of examples as text strings.
            model: The Gemma model being used.
            inference_args: Any additional arguments for an inference.

        Returns:
            An Iterable of type PredictionResult.
        """
        # Generate a response for each text string and collect the results
        # in a list. max_length caps the combined token length of the prompt
        # and the generated text.
        predictions = []
        for one_text in batch:
            result = model.generate(one_text, max_length=64)
            predictions.append(result)
        # Pair each input with its prediction to build PredictionResults.
        return utils._convert_to_result(batch, predictions, self._model_name)


class FormatOutput(beam.DoFn):
    def process(self, element, *args, **kwargs):
        yield "Input: {input}, Output: {output}".format(
            input=element.example, output=element.inference)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--messages_subscription",
        required=True,
        help="Pub/Sub subscription for input text messages",
    )
    parser.add_argument(
        "--responses_topic",
        required=True,
        help="Pub/Sub topic for output text responses",
    )
    parser.add_argument(
        "--model_path",
        required=False,
        default="gemma_2B",
        help="path to the Gemma model in the custom worker container",
    )
    args, beam_args = parser.parse_known_args()
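
    # Example invocation (resource names are illustrative):
    #   python custom_model_gemma.py \
    #     --messages_subscription=projects/PROJECT_ID/subscriptions/messages-sub \
    #     --responses_topic=projects/PROJECT_ID/topics/responses \
    #     --model_path=gemma_2B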

    logging.getLogger().setLevel(logging.INFO)

    # streaming=True is required because Pub/Sub is an unbounded source.
    beam_options = PipelineOptions(
        beam_args,
        streaming=True,
    )
    pipeline = beam.Pipeline(options=beam_options)

    _ = (
        pipeline | "Read Topic" >>
        beam.io.ReadFromPubSub(subscription=args.messages_subscription)
        | "Parse" >> beam.Map(lambda x: x.decode("utf-8"))
        | "RunInference-Gemma" >> RunInference(
            GemmaModelHandler(args.model_path)
        )  # Send the prompts to the model and get responses.
        | "Format Output" >> beam.ParDo(FormatOutput())  # Format the output.
        # WriteStringsToPubSub UTF-8 encodes each element before publishing.
        | "Publish Result" >>
        beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic))
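
    # For a streaming job, run() returns after the job is submitted on
    # Dataflow; the pipeline then runs until it is cancelled.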
    pipeline.run()