jina_server.py
Overview
The jina_server.py file implements a Jina deployment that serves a transformer-based causal language model (such as a GPT-style model) with two primary functionalities: full-response text generation, where the complete reply is returned in one message, and token-by-token streaming generation. It uses Hugging Face's transformers library for model loading and inference, and Jina for scalable, asynchronous serving.
The file defines data models for input/output documents, an executor class encapsulating the generation logic, and a CLI entry point to launch the deployment. It supports customizable generation parameters through the input prompt and uses device-aware model loading with PyTorch.
Detailed Description
Data Models
The file uses DocArray BaseDoc subclasses to define the input and output data structures for the service:
Prompt class
class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict
Purpose: Represents the input prompt sent to the service.
Attributes:
message: A list of dictionaries representing chat messages or prompt components. Each dictionary typically encodes role/content pairs for chat-based generation.
gen_conf: A dictionary of generation configuration parameters passed to the model's GenerationConfig (e.g., max_new_tokens, temperature, top_k).
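A request payload for the Prompt schema might look like the sketch below. It mirrors the two fields with a stdlib dataclass (a hypothetical PromptSketch, so it runs without DocArray installed); the role/content keys follow the common chat-message convention, and the gen_conf values are ordinary GenerationConfig parameters chosen for illustration.

```python
from dataclasses import dataclass

# Hypothetical stdlib stand-in for the DocArray Prompt schema (same two fields)
@dataclass
class PromptSketch:
    message: list[dict]
    gen_conf: dict

prompt = PromptSketch(
    message=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize Jina in one sentence."},
    ],
    # Keys are forwarded to Hugging Face GenerationConfig on the server side
    gen_conf={"max_new_tokens": 64, "temperature": 0.7, "top_k": 50},
)

roles = [m["role"] for m in prompt.message]
print(roles)  # ['system', 'user']
```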
Generation class
class Generation(BaseDoc):
    text: str
Purpose: Represents the output generated text from the model.
Attributes:
text: The generated string text output.
Global Variables
tokenizer = None
model_name = ""
tokenizer: Global variable for the tokenizer instance, loaded at runtime.
model_name: Global string storing the model identifier or path.
TokenStreamingExecutor Class
class TokenStreamingExecutor(Executor):
    ...
Purpose: Jina executor implementing methods to generate text from prompts, supporting both full generation and streaming token-by-token output.
Inheritance: Inherits from jina.Executor.
Constructor: __init__
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", torch_dtype="auto"
    )
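Loads the causal language model specified by the global model_name.
Uses device_map="auto" to place model layers automatically across the available devices (GPUs first, then CPU memory if needed); this option requires the accelerate package to be installed.
Uses torch_dtype="auto" to load the weights in the dtype recorded in the model's configuration (for example float16 or bfloat16) rather than the default float32.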
Method: generate
@requests(on="/chat")
async def generate(self, doc: Prompt, **kwargs) -> Generation:
    ...
    yield Generation(text=response)
Purpose: Handles full-response (non-streaming) generation requests at the /chat endpoint.
Parameters:
doc (Prompt): Input document containing chat messages and generation configuration.
**kwargs: Additional parameters (ignored).
Returns: An asynchronous generator yielding a single Generation document containing the full generated text.
Workflow:
Applies a chat template to the input messages to produce a single prompt string (via tokenizer.apply_chat_template).
Tokenizes the prompt into model inputs.
Creates a GenerationConfig object from the provided generation parameters, setting the EOS and PAD token IDs.
Invokes the model's .generate() method to produce token IDs.
Extracts the newly generated tokens by removing the prompt tokens from the output (for decoder-only models, generate() returns the prompt followed by the continuation).
Decodes the generated tokens into a string.
Yields the generated text wrapped in a Generation document.
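The workflow above can be traced without downloading a model. This is a minimal sketch, assuming a hypothetical word-level vocabulary (encode/decode) in place of the tokenizer and a hypothetical toy_generate function that, like model.generate() on a decoder-only model, returns the prompt IDs followed by the new IDs. Reusing the EOS ID as the PAD ID is also an assumption; the source only says both are set.

```python
# Hypothetical word-level vocabulary standing in for a real tokenizer
VOCAB = ["<eos>", "hello", "how", "are", "you", "i", "am", "fine"]
EOS_ID = 0

def encode(text: str) -> list[int]:
    return [VOCAB.index(word) for word in text.split()]

def decode(ids: list[int]) -> str:
    # Skip the EOS id, similar to skip_special_tokens=True
    return " ".join(VOCAB[i] for i in ids if i != EOS_ID)

def toy_generate(input_ids: list[int], max_new_tokens: int = 20,
                 eos_token_id: int = EOS_ID, **_) -> list[int]:
    # Like model.generate() on a decoder-only model: prompt ids + new ids.
    # This "model" always answers "i am fine" and then emits EOS.
    reply = encode("i am fine") + [eos_token_id]
    return input_ids + reply[:max_new_tokens]

def chat(prompt_text: str, gen_conf: dict) -> str:
    input_ids = encode(prompt_text)
    # Mirror the GenerationConfig step: client parameters plus EOS/PAD ids
    # (PAD reusing the EOS id is an assumption)
    config = {**gen_conf, "eos_token_id": EOS_ID, "pad_token_id": EOS_ID}
    output = toy_generate(input_ids, **config)
    new_ids = output[len(input_ids):]  # drop the echoed prompt tokens
    return decode(new_ids)

print(chat("hello how are you", {"max_new_tokens": 50, "temperature": 0.7}))  # i am fine
```

With real transformers tensors the same slice is usually written output[0][input_ids.shape[1]:], because generate() returns a batch of sequences.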
Method: task
@requests(on="/stream")
async def task(self, doc: Prompt, **kwargs) -> Generation:
    ...
    yield Generation(text=...)
Purpose: Handles token-by-token streaming generation requests at the /stream endpoint.
Parameters:
doc (Prompt): Input prompt with generation configuration.
**kwargs: Additional parameters (ignored).
Returns: An asynchronous generator yielding an incremental Generation document for each generated token or chunk.
Workflow:
Applies the chat template and tokenizes the input prompt.
Determines the maximum number of tokens to generate (max_new_tokens), defaulting to 512 if not specified.
Iteratively calls model.generate() with max_new_tokens=1 to produce one token at a time.
Stops if the generated token is the EOS token.
Yields each newly generated token as a decoded string.
Updates the input tokens to include the newly generated token for the next iteration.
Usage: Enables real-time streaming of generated text, useful for interactive chat or UI feedback.
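The streaming loop can be sketched as an async generator. This is a minimal sketch, assuming a hypothetical lookup-table "model" (generate_one) in place of model.generate(..., max_new_tokens=1) and a four-word vocabulary in place of the tokenizer; in the real executor each yielded string would be wrapped in a Generation document.

```python
import asyncio

# Hypothetical toy model: the next token depends only on the last one
VOCAB = ["<eos>", "hi", "there", "friend"]
EOS_ID = 0
NEXT = {1: 2, 2: 3, 3: EOS_ID}  # hi -> there -> friend -> <eos>

def generate_one(input_ids: list[int]) -> list[int]:
    # Stand-in for model.generate(..., max_new_tokens=1): input plus one new token
    return input_ids + [NEXT.get(input_ids[-1], EOS_ID)]

async def stream(input_ids: list[int], gen_conf: dict):
    max_new_tokens = gen_conf.get("max_new_tokens", 512)  # default described for /stream
    for _ in range(max_new_tokens):
        output = generate_one(input_ids)
        new_id = output[-1]
        if new_id == EOS_ID:
            break  # stop at end-of-sequence
        yield VOCAB[new_id]  # hand the decoded token to the caller immediately
        input_ids = output  # feed the extended sequence into the next step

async def collect(input_ids: list[int], gen_conf: dict) -> list[str]:
    return [token async for token in stream(input_ids, gen_conf)]

print(asyncio.run(collect([1], {})))  # ['there', 'friend']
```

One consequence of this design: assuming only the token IDs are carried between iterations, as the workflow describes, each step starts a fresh generate() call, so the key/value cache from the previous step is not reused and every step re-encodes the whole sequence. Hugging Face's TextIteratorStreamer, which streams tokens out of a single generate() call running in a background thread, is a common alternative.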
Main Execution Block
if __name__ == "__main__":
    ...
Parses command-line arguments:
--model_name: Hugging Face model name or path (required).
--port: Port number for the Jina gRPC server (default 12345).
Loads the tokenizer globally using AutoTokenizer.from_pretrained.
Launches a Jina Deployment serving the TokenStreamingExecutor on the specified port over the gRPC protocol.
Blocks indefinitely to serve requests.
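The argument handling above can be sketched with the standard argparse module. The flag names and the 12345 default come from the description; the help strings and the parse_args wrapper are illustrative assumptions. The commented lines show how the remaining steps might follow using AutoTokenizer.from_pretrained and jina.Deployment; they are an assumption about the script's exact code and are not executed here.

```python
import argparse

def parse_args(argv=None):
    # Flag names and default port follow the description; help text is illustrative
    parser = argparse.ArgumentParser(description="Serve a causal LM with Jina")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Hugging Face model name or local path")
    parser.add_argument("--port", type=int, default=12345,
                        help="port for the Jina gRPC server")
    return parser.parse_args(argv)

args = parse_args(["--model_name", "gpt2"])
print(args.model_name, args.port)  # gpt2 12345

# Assumed continuation in the real script (not run here; needs transformers and jina):
# model_name = args.model_name
# tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# with Deployment(uses=TokenStreamingExecutor, port=args.port, protocol="grpc") as dep:
#     dep.block()
```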
Important Implementation Details and Algorithms
Model Loading: Uses Hugging Face's AutoModelForCausalLM and AutoTokenizer to dynamically load any causal language model compatible with transformers.
Device Mapping: The model is automatically placed on available hardware (CPU/GPU) using device_map="auto".
GenerationConfig: Uses Hugging Face's GenerationConfig to configure generation parameters passed from the client, such as temperature, top-k, and maximum new tokens.
Token Streaming: Implements streaming by generating one token at a time and yielding it immediately, so clients receive partial results without waiting for the full generation.
Prompt Handling: Uses tokenizer.apply_chat_template to convert structured chat messages into prompt strings suitable for model input, enabling chat-style generation.
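To make the prompt-handling step concrete: apply_chat_template renders the message list with a Jinja template shipped with each model's tokenizer, and is typically called as tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) to obtain a string. The sketch below hand-writes a ChatML-style template purely as an illustration; real templates differ from model to model, and the source does not say which arguments jina_server.py passes.

```python
# Hypothetical ChatML-style renderer; it only mimics what a model's own
# chat template (applied by tokenizer.apply_chat_template) might produce.
def render_chatml(messages: list[dict], add_generation_prompt: bool = True) -> str:
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Open an assistant turn so the model continues as the assistant
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)

text = render_chatml([{"role": "user", "content": "Hello, how are you?"}])
print(text)
```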
Integration with Other System Components
Jina Framework: This file is designed as a Jina executor and deployment script, integrating with the Jina ecosystem for scalable serving.
DocArray: Uses DocArray's BaseDoc for input/output schema enforcement and serialization.
Hugging Face Transformers: Depends heavily on transformers models and tokenizers for natural language processing.
CLI Usage: The file is intended to be executed as a standalone server exposing the gRPC endpoints /chat and /stream. The executor can also be integrated into larger pipelines or microservices using Jina's orchestration.
Usage Examples
Running the Server
python jina_server.py --model_name gpt2 --port 12345
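Sending a Full-Response Request to /chat
Both endpoints are implemented as async generators, so they are consumed with an asyncio client's stream_doc method; the client needs local copies of the Prompt and Generation schemas.
import asyncio
from docarray import BaseDoc
from jina import Client
# Client-side copies of the server schemas; field names must match jina_server.py
class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict
class Generation(BaseDoc):
    text: str
prompt = Prompt(
    message=[{"role": "user", "content": "Hello, how are you?"}],
    gen_conf={"max_new_tokens": 50, "temperature": 0.7},
)
async def query(endpoint: str):
    client = Client(host="localhost", port=12345, protocol="grpc", asyncio=True)
    async for chunk in client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation):
        print(chunk.text, end="", flush=True)
    print()
asyncio.run(query("/chat"))  # one Generation holding the full reply
Receiving Streaming Tokens from /stream
asyncio.run(query("/stream"))  # one Generation per generated token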
Mermaid Class Diagram
classDiagram
    class Prompt {
        +message: list[dict]
        +gen_conf: dict
    }
    class Generation {
        +text: str
    }
    class TokenStreamingExecutor {
        -model: AutoModelForCausalLM
        +__init__(**kwargs)
        +generate(doc: Prompt, **kwargs) AsyncGenerator~Generation~
        +task(doc: Prompt, **kwargs) AsyncGenerator~Generation~
    }
    TokenStreamingExecutor ..> Prompt : uses
    TokenStreamingExecutor ..> Generation : returns
Summary
jina_server.py is a specialized Jina executor deployment for serving transformer causal language models with both full-response and token-by-token streaming text generation. It provides a flexible, scalable interface for chat-based AI applications, combining Hugging Face models with the Jina ecosystem for deployment and request handling.