jina_server.py


Overview

The jina_server.py file implements a Jina deployment that serves a transformer-based causal language model (for example, a GPT-style model) through two endpoints: /chat, which generates a complete response, and /stream, which streams the response token by token. It uses Hugging Face's transformers library for model loading and inference, and Jina for asynchronous request handling.

The file defines DocArray data models for the input and output documents, an executor class that encapsulates the generation logic, and a CLI entry point that launches the deployment. Generation parameters are passed per request in the gen_conf field of the input document, and the model is loaded with device_map="auto" and torch_dtype="auto" so that device placement and precision are chosen automatically.


Detailed Description

Data Models

The file uses DocArray BaseDoc subclasses to define the input and output data structures for the service:

Prompt class

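class Prompt(BaseDoc):
    # Chat messages, e.g. {"role": "user", "content": "..."}
    message: list[dict]
    # Generation parameters, e.g. max_new_tokens, temperature
    gen_conf: dict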

Generation class

class Generation(BaseDoc):
    text: str
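
A request carries the two Prompt fields and a response carries a single text field. As a rough illustration of the request shape without DocArray, the sketch below checks a plain dictionary against the fields Prompt declares. The helper validate_prompt_payload is hypothetical and not part of jina_server.py; the role/content keys follow the usage example in this document.

```python
def validate_prompt_payload(payload: dict) -> dict:
    """Check that a dict has the two fields the Prompt class declares."""
    message = payload.get("message")
    gen_conf = payload.get("gen_conf")
    if not isinstance(message, list) or not all(isinstance(m, dict) for m in message):
        raise TypeError("message must be a list of dicts")
    if not isinstance(gen_conf, dict):
        raise TypeError("gen_conf must be a dict")
    return {"message": message, "gen_conf": gen_conf}


if __name__ == "__main__":
    payload = {
        "message": [{"role": "user", "content": "Hello, how are you?"}],
        "gen_conf": {"max_new_tokens": 50, "temperature": 0.7},
    }
    print(validate_prompt_payload(payload))
```

In the service itself this kind of check is performed by the BaseDoc type annotations when a Prompt is constructed, so a helper like this would only be useful on a client that builds payloads by hand.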

Global Variables

# Module-level placeholders; TokenStreamingExecutor.__init__ reads model_name.
tokenizer = None
model_name = ""

TokenStreamingExecutor Class

class TokenStreamingExecutor(Executor):
    ...

Constructor: __init__

def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", torch_dtype="auto"
    )

Method: generate

@requests(on="/chat")
async def generate(self, doc: Prompt, **kwargs) -> Generation:
    ...
    yield Generation(text=response)

Method: task

@requests(on="/stream")
async def task(self, doc: Prompt, **kwargs) -> Generation:
    ...
    yield Generation(text=...)
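
Both endpoint methods are async generators: they yield Generation documents instead of returning one. The pattern can be sketched without Jina or a model, as below. Here Generation is a plain dataclass standing in for the BaseDoc class, and fake_token_stream and collect are hypothetical names introduced only for illustration; the real methods produce text from the model, not from a fixed list.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator


@dataclass
class Generation:
    """Stand-in for the BaseDoc-based Generation class (assumption)."""
    text: str


async def fake_token_stream(tokens: list[str]) -> AsyncIterator[Generation]:
    """Yield one Generation per token, as a streaming endpoint might."""
    for token in tokens:
        # Hand control back to the event loop between chunks.
        await asyncio.sleep(0)
        yield Generation(text=token)


async def collect(tokens: list[str]) -> str:
    """Consume the stream with `async for` and join the chunks."""
    pieces = []
    async for chunk in fake_token_stream(tokens):
        pieces.append(chunk.text)
    return "".join(pieces)


if __name__ == "__main__":
    print(asyncio.run(collect(["Hello", ",", " world"])))
```

A generator that yields exactly once behaves like generate on /chat, while one that yields many small chunks behaves like task on /stream; the consumer loop is the same in both cases.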

Main Execution Block

if __name__ == "__main__":
    ...

Important Implementation Details and Algorithms


Integration with Other System Components


Usage Examples

Running the Server

python jina_server.py --model_name gpt2 --port 12345
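
The command above passes two flags, --model_name and --port. A minimal sketch of how such flags could be parsed with argparse follows. The help strings, the default port, and the parse_args helper name are assumptions made for illustration and are not taken from the file.

```python
import argparse


def parse_args(argv=None):
    """Parse the two flags shown in the command above (sketch only)."""
    parser = argparse.ArgumentParser(description="Serve a causal LM with Jina")
    parser.add_argument("--model_name", type=str,
                        help="Hugging Face model name or local path")
    # The default port is an assumption, chosen to match the example command.
    parser.add_argument("--port", type=int, default=12345,
                        help="Port the deployment listens on")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args(["--model_name", "gpt2", "--port", "12345"])
    print(args.model_name, args.port)
```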

Sending a Batch Generation Request to /chat

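import asyncio

from jina import Client

# Prompt and Generation are the BaseDoc classes defined in jina_server.py.
# Assumption: the file is importable as a module named jina_server.
from jina_server import Generation, Prompt

prompt = Prompt(
    message=[{"role": "user", "content": "Hello, how are you?"}],
    gen_conf={"max_new_tokens": 50, "temperature": 0.7},
)

async def chat():
    # generate is an async generator, so this sketch consumes it with
    # stream_doc on an asyncio client; it yields a single Generation.
    client = Client(host="grpc://localhost:12345", asyncio=True)
    async for response in client.stream_doc(on="/chat", inputs=prompt, return_type=Generation):
        print(response.text)

asyncio.run(chat())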

Receiving Streaming Tokens from /stream

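# Run inside an async function, with client = Client(host="grpc://localhost:12345", asyncio=True).
# prompt must be a Prompt instance; Generation is the output class from Data Models.
async for chunk in client.stream_doc(on="/stream", inputs=prompt, return_type=Generation):
    print(chunk.text, end="", flush=True)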

Mermaid Class Diagram

classDiagram
    class Prompt {
        +message: list[dict]
        +gen_conf: dict
    }
    class Generation {
        +text: str
    }
    class TokenStreamingExecutor {
        -model: AutoModelForCausalLM
        +__init__(**kwargs)
        +generate(doc: Prompt, **kwargs) async Generator
        +task(doc: Prompt, **kwargs) async Generator
    }
    TokenStreamingExecutor ..> Prompt : uses
    TokenStreamingExecutor ..> Generation : returns

Summary

jina_server.py is a Jina executor deployment that serves transformer causal language models through two asynchronous endpoints: /chat for complete responses and /stream for token-by-token output. It gives chat-based applications a single interface to Hugging Face models, with Jina handling deployment and request routing.