Saxml is an experimental system that serves Paxml, JAX, and PyTorch models for inference. A Sax cell (aka Sax cluster) consists of an admin server and a group of model servers. The admin server keeps track of model servers, assigns published models to model servers to serve, and helps clients locate model servers serving specific published models.
Today, users can interact with a Sax cluster through the saxutil command-line tool or directly through the Sax client.
This tutorial uses an HTTP server to handle HTTP requests to Sax. The server supports publishing, listing, updating, and unpublishing models, as well as generating predictions. It uses the Python Sax client to communicate with the Sax cluster and to handle routing within the Sax system. An HTTP server also lets clients reach Sax from beyond the VM it runs on. For example, integrating it with GKE and load balancing enables requests to Sax from both inside and outside the GKE cluster.
This tutorial focuses on deploying the HTTP server. It assumes that you have already deployed a Sax admin server and a Sax model server by following the OSS SAX Docker Guide.
Build the HTTP Server image:
docker build -f Dockerfile -t sax-http .
If you haven't already, create a Cloud Storage (GCS) bucket to store Sax cluster information:
GSBUCKET=${USER}-sax-data
gcloud storage buckets create gs://${GSBUCKET}
Run the HTTP server image, pointing SAX_ROOT at the bucket:
docker run -e SAX_ROOT=gs://${GSBUCKET}/sax-root -p 8888:8888 -it sax-http
In another terminal:
curl localhost:8888
You will see the output below:
{
"message": "HTTP Server for SAX Client"
}
Publish a model:
curl --request POST \
--header "Content-type: application/json" \
--silent \
localhost:8888/publish \
--data \
'{
"model": "/sax/test/lm2b",
"model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest",
"checkpoint": "None",
"replicas": 1
}'
You will see the output below:
{
"model": "/sax/test/lm2b",
"path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest",
"checkpoint": "None",
"replicas": 1
}
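If you prefer to call the server from Python instead of curl, you can send the publish request above using only the standard library. This is a minimal sketch that assumes the server is reachable at http://localhost:8888, as mapped by the docker run command. The helper names build_post and post_json are hypothetical and not part of Saxml.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8888"  # assumed: port mapped by `docker run -p 8888:8888`


def build_post(path: str, payload: dict) -> urllib.request.Request:
    """Build a JSON POST request for an endpoint such as /publish."""
    return urllib.request.Request(
        BASE_URL + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-type": "application/json"},
        method="POST",
    )


def post_json(path: str, payload: dict, timeout: float = 60.0):
    """Send the request and decode the JSON response body."""
    with urllib.request.urlopen(build_post(path, payload), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# The same body as the curl example above; building it does not contact the server.
publish_request = build_post(
    "/publish",
    {
        "model": "/sax/test/lm2b",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest",
        "checkpoint": "None",
        "replicas": 1,
    },
)
```

With the server running, post_json("/publish", payload) returns the decoded JSON response. The same helper also fits /generate and /unpublish, which this tutorial calls with POST.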
List all models in the /sax/test cell:
curl --request GET \
--header "Content-type: application/json" \
--silent \
localhost:8888/listall \
--data \
'{
"sax_cell": "/sax/test"
}'
You will see the output below:
["/sax/test/lm2b"]
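/listall (and /listcell below) take a JSON body on a GET request, which some HTTP clients do not allow. Python's urllib supports it if you set the method explicitly. This is a minimal sketch that assumes the same localhost:8888 address. The helper names build_get and get_json are hypothetical.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8888"  # assumed address from the docker run step


def build_get(path: str, payload: dict) -> urllib.request.Request:
    """Build a GET request that carries a JSON body, as /listall and /listcell expect."""
    return urllib.request.Request(
        BASE_URL + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-type": "application/json"},
        method="GET",  # without this, urllib defaults to POST because data is set
    )


def get_json(path: str, payload: dict, timeout: float = 30.0):
    """Send the GET request and decode the JSON response."""
    with urllib.request.urlopen(build_get(path, payload), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


listall_request = build_get("/listall", {"sax_cell": "/sax/test"})
```

After the publish step, get_json("/listall", {"sax_cell": "/sax/test"}) should return the same list that curl printed.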
Get the details of the published model:
curl --request GET \
--header "Content-type: application/json" \
--silent \
localhost:8888/listcell \
--data \
'{
"model": "/sax/test/lm2b"
}'
You will see the output below:
{
"model": "/sax/test/lm2b",
"model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest",
"checkpoint": "None",
"max_replicas": 1,
"active_replicas": 1
}
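Loading a model onto a model server can take a while after publishing. A script that publishes a model and then immediately calls /generate may therefore need to wait until /listcell reports an active replica. The sketch below assumes that active_replicas stays below 1 until the model is loaded. This tutorial does not document that behavior, so treat it as an assumption. fetch is any callable that takes the /listcell request body and returns the decoded JSON response. wait_for_model is a hypothetical name.

```python
import time
from typing import Callable, Dict


def wait_for_model(
    fetch: Callable[[Dict[str, str]], dict],
    model: str,
    min_active: int = 1,
    timeout_s: float = 600.0,
    interval_s: float = 10.0,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Poll /listcell until the model reports at least `min_active` active replicas."""
    deadline = time.monotonic() + timeout_s
    while True:
        status = fetch({"model": model})
        # Assumption: active_replicas counts replicas that are ready to serve.
        if status.get("active_replicas", 0) >= min_active:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{model} has no active replicas after {timeout_s}s")
        sleep(interval_s)
```

Injecting fetch and sleep keeps the polling logic independent of the transport, so you can reuse it with any HTTP client.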
Generate a response from the model:
curl --request POST \
--header "Content-type: application/json" \
--silent \
localhost:8888/generate \
--data \
'{
"model": "/sax/test/lm2b",
"query": "Q: Who is Harry Potter's mom? A: "
}'
The model's response is printed in the terminal.
Unpublish the model:
curl --request POST \
--header "Content-type: application/json" \
--silent \
localhost:8888/unpublish \
--data \
'{
"model": "/sax/test/lm2b"
}'
You will see the output below:
{
"model": "/sax/test/lm2b"
}
The following APIs are implemented in this HTTP server. The complete Python client interface is available in the google/saxml repository.
/generate is used to generate a response from a specific model.
Request: a JSON object of the following format:
{
"model": <String>,
"query": <String>,
"extra_inputs": {
"temperature": <Number>,
"per_example_max_decode_steps": <Number>,
"per_example_top_k": <Number>,
"per_example_top_p": <Number>
}
}
model is the name of the model to query.
query is the prompt to send to the model.
extra_inputs is an optional object that overrides the model's default decoding configuration.
temperature is a floating point number for the decoding temperature.
per_example_max_decode_steps is an integer for the maximum number of decoding steps for each request. It must be smaller than the max_decode_steps value configured for the published model.
per_example_top_k is an integer for the top-k value used for decoding.
per_example_top_p is a floating point number for the top-p value used for decoding.
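The sketch below builds a /generate request body and rejects unknown extra_inputs keys before they reach the server. None of the endpoints listed here return the published model's max_decode_steps, so the caller must pass it in if it wants that check. build_generate_payload is a hypothetical helper, not part of the Sax client.

```python
from typing import Optional

# Keys documented above for the optional extra_inputs object.
ALLOWED_EXTRA_INPUTS = {
    "temperature",
    "per_example_max_decode_steps",
    "per_example_top_k",
    "per_example_top_p",
}


def build_generate_payload(
    model: str,
    query: str,
    extra_inputs: Optional[dict] = None,
    max_decode_steps: Optional[int] = None,
) -> dict:
    """Return a /generate body, validating extra_inputs when provided."""
    payload = {"model": model, "query": query}
    if not extra_inputs:
        return payload
    unknown = set(extra_inputs) - ALLOWED_EXTRA_INPUTS
    if unknown:
        raise ValueError(f"unsupported extra_inputs: {sorted(unknown)}")
    steps = extra_inputs.get("per_example_max_decode_steps")
    # max_decode_steps is supplied by the caller (assumption: it is known out of band).
    if steps is not None and max_decode_steps is not None and steps >= max_decode_steps:
        raise ValueError("per_example_max_decode_steps must be smaller than max_decode_steps")
    payload["extra_inputs"] = dict(extra_inputs)
    return payload
```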
Response: a JSON object with the following format:
[
[
<String>,
<Number>
],
...
]
[[<String>, <Number>]] is an array of [<String>, <Number>] pairs.
<String> is the response from the model.
<Number> is a floating point number for the score of the response.
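A client usually wants the /generate pairs as typed values, often ranked. This sketch converts the response into (text, score) tuples sorted best first. It assumes that a higher score means a more likely response, which this document does not state. parse_generate_response is a hypothetical name.

```python
from typing import List, Tuple


def parse_generate_response(body: list) -> List[Tuple[str, float]]:
    """Convert [[text, score], ...] into (text, score) tuples, highest score first."""
    pairs = [(str(text), float(score)) for text, score in body]
    # Assumption: larger scores indicate better responses.
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)
```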
/listall is used to list all models in a specific cell.
Request: a JSON object of the following format:
{
"sax_cell": <String>
}
sax_cell is the path of the cell to list.
Response: a JSON object with the following format:
[
<String>,
...
]
[<String>] is an array of strings.
<String> is a model name.
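/listall returns full model names such as /sax/test/lm2b, where /sax/test is the cell passed in sax_cell in the earlier example. This small sketch splits a name into its cell and model ID. It assumes names always follow the /sax/<cell>/<model> shape used in this tutorial. split_model_name is a hypothetical helper.

```python
from typing import Tuple


def split_model_name(name: str) -> Tuple[str, str]:
    """Split '/sax/<cell>/<model>' into ('/sax/<cell>', '<model>')."""
    parts = name.split("/")
    # Expected shape: ['', 'sax', '<cell>', '<model>'] (assumption from this tutorial).
    if len(parts) != 4 or parts[0] != "" or parts[1] != "sax" or not parts[2] or not parts[3]:
        raise ValueError(f"unexpected model name: {name!r}")
    return "/" + "/".join(parts[1:3]), parts[3]
```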
/listcell is used to list the details of a specific model.
Request: a JSON object of the following format:
{
"model": <String>
}
model is the name of the model.
Response: a JSON object with the following format:
{
"model": <String>,
"model_path": <String>,
"checkpoint": <String>,
"max_replicas": <Number>,
"active_replicas": <Number>
}
model is the name of the model.
model_path is the path of the model in the Saxml model registry.
checkpoint is the location of the model checkpoint.
max_replicas is an integer for the maximum number of replicas the model can be deployed on.
active_replicas is an integer for the number of replicas the model is currently deployed on.
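The /listcell response maps naturally onto a typed record. This sketch uses a frozen dataclass with the documented field names. The fully_deployed property is an illustrative addition that compares active_replicas with max_replicas. ModelStatus is a hypothetical class name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelStatus:
    """Typed view of a /listcell response."""

    model: str
    model_path: str
    checkpoint: str
    max_replicas: int
    active_replicas: int

    @classmethod
    def from_json(cls, body: dict) -> "ModelStatus":
        """Build from the decoded JSON; raises KeyError if a field is missing."""
        return cls(
            model=body["model"],
            model_path=body["model_path"],
            checkpoint=body["checkpoint"],
            max_replicas=int(body["max_replicas"]),
            active_replicas=int(body["active_replicas"]),
        )

    @property
    def fully_deployed(self) -> bool:
        """True when every allowed replica is currently serving."""
        return self.active_replicas >= self.max_replicas
```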
/publish is used to publish a new model.
Request: a JSON object of the following format:
{
"model": <String>,
"model_path": <String>,
"checkpoint": <String>,
"replicas": <Number>
}
model is the name of the model.
model_path is the path of the model in the Saxml model registry.
checkpoint is the location of the model checkpoint.
replicas is an integer for the number of replicas of the model to deploy.
Response: a JSON object with the following format:
{
"model": <String>
}
model is the name of the published model.
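Catching obvious mistakes on the client side gives a clearer error than a failed request. This sketch builds a /publish body, which has the same fields as an /update body. Three checks here are assumptions rather than documented server rules: the /sax/<cell>/<model> name shape, the requirement that replicas be at least 1, and the substitution of the literal string "None" for a missing checkpoint (copied from the tutorial's example). build_publish_payload is a hypothetical name.

```python
from typing import Optional


def build_publish_payload(
    model: str,
    model_path: str,
    checkpoint: Optional[str] = None,
    replicas: int = 1,
) -> dict:
    """Build a /publish (or /update) request body with basic client-side checks."""
    # Assumption: model names follow the /sax/<cell>/<model> shape used in this tutorial.
    if not model.startswith("/sax/"):
        raise ValueError(f"model should look like /sax/<cell>/<model>, got {model!r}")
    if not model_path:
        raise ValueError("model_path is required")
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(replicas, bool) or not isinstance(replicas, int) or replicas < 1:
        raise ValueError("replicas must be a positive integer")
    return {
        "model": model,
        "model_path": model_path,
        # The tutorial passes the literal string "None" when there is no checkpoint.
        "checkpoint": "None" if checkpoint is None else checkpoint,
        "replicas": replicas,
    }
```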
/unpublish is used to unpublish a model.
Request: a JSON object of the following format:
{
"model": <String>
}
model is the name of the model to unpublish.
Response: a JSON object with the following format:
{
"model": <String>
}
model is the name of the unpublished model.
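You can combine /listall and /unpublish to clear every model from a cell, for example when tearing down a test deployment. This sketch takes the transport functions as arguments so that you can exercise it without a server. get_json and post_json are hypothetical callables of the form (path, payload) -> decoded JSON, such as thin wrappers around an HTTP client. unpublish_cell is also a hypothetical name.

```python
from typing import Callable, List


def unpublish_cell(
    get_json: Callable[[str, dict], list],
    post_json: Callable[[str, dict], dict],
    sax_cell: str,
) -> List[str]:
    """Unpublish every model that /listall reports for `sax_cell`; return their names."""
    removed = []
    for model in get_json("/listall", {"sax_cell": sax_cell}):
        response = post_json("/unpublish", {"model": model})
        # /unpublish echoes the name of the unpublished model.
        removed.append(response["model"])
    return removed
```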
/update is used to update an existing model.
Request: a JSON object of the following format:
{
"model": <String>,
"model_path": <String>,
"checkpoint": <String>,
"replicas": <Number>
}
model is the name of the model.
model_path is the path of the model in the Saxml model registry.
checkpoint is the location of the model checkpoint.
replicas is an integer for the number of replicas of the model to deploy.
Response: a JSON object with the following format:
{
"model": <String>,
"model_path": <String>,
"checkpoint": <String>,
"replicas": <Number>
}
model is the name of the model.
model_path is the path of the model in the Saxml model registry.
checkpoint is the location of the model checkpoint.
replicas is an integer for the number of replicas of the model to deploy.
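/update takes the full set of publish fields. To change only the replica count, you must re-send the current model_path and checkpoint. This sketch derives the update body from a /listcell response. It assumes that passing the new count as replicas is the right way to rescale, because /listcell reports max_replicas while /update accepts replicas. build_scale_payload is a hypothetical name.

```python
def build_scale_payload(listcell_body: dict, replicas: int) -> dict:
    """Build an /update body that keeps model_path and checkpoint, changing only replicas."""
    if isinstance(replicas, bool) or not isinstance(replicas, int) or replicas < 1:
        raise ValueError("replicas must be a positive integer")
    return {
        "model": listcell_body["model"],
        "model_path": listcell_body["model_path"],
        "checkpoint": listcell_body["checkpoint"],
        "replicas": replicas,
    }
```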