gemma4-pytorch-codex is a clean PyTorch implementation of the Gemma 4 model family.
It is meant to stay close to the original JAX structure while still looking like normal PyTorch code:
- text, vision, and audio towers are separate and reusable
- the top-level model is thin
- generation and KV-cache are built in
- multimodal prompt expansion is explicit instead of hidden in a large framework wrapper
- checkpoint conversion supports both Hugging Face safetensors and Orbax/JAX checkpoints
The install/import split is:
- distribution name: gemma4-pytorch-codex
- import package: gemma4_pt_codex

Typical usage:

```python
import gemma4_pt_codex as gemma4
```

What is working today:
- text generation
- KV-cache decode
- image preprocessing and image-conditioned generation
- audio preprocessing and audio-conditioned generation
- native save/load
- conversion from local HF safetensors checkpoints
- conversion from local Orbax/JAX checkpoints
The most-tested public checkpoint path right now is the small instruction-tuned model:
google/gemma-4-e2b-it
That checkpoint has been exercised locally for:
- plain text generation
- image captioning
- audio transcription-style prompts
Editable install:

```bash
cd gemma4_pytorch_codex
python -m pip install -e .
```

With PDM:

```bash
cd gemma4_pytorch_codex
pdm install
```

Conversion extras:

```bash
cd gemma4_pytorch_codex
python -m pip install -e ".[convert]"
```

Dev/test extras:

```bash
cd gemma4_pytorch_codex
python -m pip install -e ".[dev]"
```

Native checkpoints use:
- config.json
- model.safetensors
- tokenizer assets such as tokenizer.model or tokenizer.json
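
Before calling from_pretrained(), it can help to confirm a directory actually has the files listed above. The sketch below is a hypothetical helper, not part of gemma4_pt_codex: it only checks file names (not contents) and only recognizes the two tokenizer asset names mentioned here.

```python
from pathlib import Path

# Hypothetical helper, not part of gemma4_pt_codex: checks file names only.
REQUIRED_FILES = ("config.json", "model.safetensors")
TOKENIZER_ASSETS = ("tokenizer.model", "tokenizer.json")


def missing_native_files(checkpoint_dir):
    """Return the expected native-checkpoint files that are absent (empty list = looks complete)."""
    root = Path(checkpoint_dir)
    missing = [name for name in REQUIRED_FILES if not (root / name).is_file()]
    # Either tokenizer format is enough; only one of them needs to exist.
    if not any((root / name).is_file() for name in TOKENIZER_ASSETS):
        missing.append("tokenizer.model or tokenizer.json")
    return missing
```

For example, missing_native_files("/path/to/checkpoint_dir") returns an empty list when the layout matches.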
Load a native checkpoint:

```python
import torch
import gemma4_pt_codex as gemma4

tokenizer = gemma4.Gemma4Tokenizer.from_pretrained("/path/to/checkpoint_dir")
model = gemma4.Gemma4Model.from_pretrained(
    "/path/to/checkpoint_dir",
    dtype=torch.bfloat16,
)
model.eval()
```

If you want a different attention backend at load time:

```python
model = gemma4.Gemma4Model.from_pretrained(
    "/path/to/checkpoint_dir",
    attn_impl="sdpa",
)
```

Supported attention implementations:
- "eager"
- "sdpa"

"sdpa" currently applies to text and vision. Audio remains eager.
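
Put differently, the requested attn_impl is per-model but the effective backend is per-tower. The hypothetical helper below (not a gemma4_pt_codex function) just spells out that rule; whether the package itself validates unknown values the same way is an assumption.

```python
SUPPORTED_ATTN_IMPLS = ("eager", "sdpa")
SDPA_TOWERS = {"text", "vision"}  # audio stays eager regardless of the request


def effective_attn_impl(requested):
    """Hypothetical helper: backend each tower ends up using for a requested attn_impl."""
    if requested not in SUPPORTED_ATTN_IMPLS:
        raise ValueError(f"attn_impl must be one of {SUPPORTED_ATTN_IMPLS}, got {requested!r}")
    return {
        tower: requested if tower in SDPA_TOWERS else "eager"
        for tower in ("text", "vision", "audio")
    }
```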
The simplest path is generate_text():

```python
import gemma4_pt_codex as gemma4

tokenizer = gemma4.Gemma4Tokenizer.from_pretrained("/path/to/checkpoint_dir")
model = gemma4.Gemma4Model.from_pretrained("/path/to/checkpoint_dir")
model.eval()

text = model.generate_text(
    tokenizer,
    "Write a haiku about cedar trees in coastal fog.",
    max_new_tokens=32,
    do_sample=False,
)
print(text)
```

For instruction-tuned checkpoints, use the turn format explicitly:

```python
prompt = (
    "<|turn>user\n"
    "Explain rotary position embeddings in plain English.\n"
    "<turn|>\n"
    "<|turn>model\n"
)

text = model.generate_text(
    tokenizer,
    prompt,
    max_new_tokens=128,
    do_sample=False,
)
print(text)
```

If you want lower-level control, call prepare_inputs() and generate() directly:

```python
prepared = model.prepare_inputs(
    tokenizer,
    prompt,
)
generated = model.generate(
    prepared.input_ids,
    attention_mask=prepared.attention_mask,
    max_new_tokens=128,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
)
# generate() returns prompt + continuation, so drop the prompt positions
continuation = generated[:, prepared.input_ids.shape[1] :]
print(tokenizer.batch_decode(continuation, skip_special_tokens=True)[0])
```

The package includes a Gemma4ImageProcessor for raw image preprocessing and a higher-level Gemma4Processor that handles multimodal prompt expansion.
The user-facing rule is simple:
- put one visible <|image|> token in the prompt for each image
- pass the actual images through images=...
Example captioning flow:

```python
from PIL import Image
import gemma4_pt_codex as gemma4

tokenizer = gemma4.Gemma4Tokenizer.from_pretrained("/path/to/checkpoint_dir")
model = gemma4.Gemma4Model.from_pretrained("/path/to/checkpoint_dir")

image = Image.open("/path/to/image.jpg").convert("RGB")
prompt = (
    "<|turn>user\n"
    "<|image|>Caption this image in one short sentence.\n"
    "<turn|>\n"
    "<|turn>model\n"
)

text = model.generate_text(
    tokenizer,
    prompt,
    images=image,
    max_new_tokens=64,
    do_sample=False,
)
print(text)
```

Multiple images work the same way:
- include multiple visible <|image|> tokens
- pass the images in the same order as those tokens
Internally, the processor expands each visible image token into the Gemma4-style image span:
- \n\n<|image>
- internal soft-image placeholder repeated N times
- <image|>\n\n
Those internal placeholder positions are later replaced with projected vision tokens before the text decoder runs.
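
A string-level sketch of that expansion, which also enforces the one-visible-token-per-image rule. It assumes the expansion can be pictured on the prompt text before tokenization and uses <image_soft_token> as a hypothetical stand-in for the internal placeholder, which this README does not name; the real processor may operate on token ids instead.

```python
IMAGE_TOKEN = "<|image|>"          # visible token written in the prompt
BOI, EOI = "<|image>", "<image|>"  # span delimiters described above
SOFT_IMAGE = "<image_soft_token>"  # hypothetical name for the internal placeholder


def expand_image_tokens(prompt, tokens_per_image):
    """Replace the i-th visible image token with a span of tokens_per_image[i] placeholders."""
    pieces = prompt.split(IMAGE_TOKEN)
    n_visible = len(pieces) - 1
    if n_visible != len(tokens_per_image):
        raise ValueError(
            f"prompt has {n_visible} {IMAGE_TOKEN} tokens but {len(tokens_per_image)} images"
        )
    out = [pieces[0]]
    # Images are matched to visible tokens strictly by order.
    for n, rest in zip(tokens_per_image, pieces[1:]):
        out.append("\n\n" + BOI + SOFT_IMAGE * n + EOI + "\n\n")
        out.append(rest)
    return "".join(out)
```

Each N has to equal the number of projected vision tokens for that image, since those placeholder positions are overwritten with them before the decoder runs. The audio span follows the same idea with <|audio> and <audio|> and no surrounding blank lines.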
The raw-image path is:
- Convert to RGB
- Resize with aspect ratio preserved
- No crop
- Scale pixels to [0, 1]
- Patchify
- Pad to the configured patch budget
- Build 2D patch positions
The vision stack then applies the model-side [0, 1] -> [-1, 1] patch normalization before the patch
projection, matching the effective JAX behavior while keeping the patch embed layer itself clean.
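
A rough sketch of the geometry and value-range steps, assuming a square patch size and a patch budget expressed as a maximum patch count. The helper names, the (row, col) raster order, and the (-1, -1) padding position are all assumptions for illustration; the actual processor reads its settings from model.config.vision and may round differently.

```python
import math


def aspect_preserving_size(height, width, patch_size, max_patches):
    """Largest (H, W), both multiples of patch_size, near the input aspect ratio, within max_patches."""
    # Scale so the patch grid would hold exactly max_patches, then round down to whole patches.
    scale = math.sqrt(max_patches * patch_size * patch_size / (height * width))
    rows = min(max(1, math.floor(height * scale / patch_size)), max_patches)
    cols = min(max(1, math.floor(width * scale / patch_size)), max_patches // rows)
    return rows * patch_size, cols * patch_size


def patch_positions(rows, cols, max_patches, pad_position=(-1, -1)):
    """(row, col) for each real patch in raster order, padded up to the patch budget."""
    positions = [(r, c) for r in range(rows) for c in range(cols)]
    n_pad = max_patches - len(positions)
    if n_pad < 0:
        raise ValueError("patch grid exceeds the patch budget")
    valid = [True] * len(positions) + [False] * n_pad
    return positions + [pad_position] * n_pad, valid


def to_unit_range(pixel):
    """Processor side: uint8 value in [0, 255] -> [0, 1]."""
    return pixel / 255.0


def to_model_range(x):
    """Model side, before patch projection: [0, 1] -> [-1, 1]."""
    return 2.0 * x - 1.0
```

Under these assumptions, a 480x640 image with 16-pixel patches and a 256-patch budget resizes to 208x288, a 13x18 grid of 234 real patches plus 22 padding slots.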
If you want the processor output directly:

```python
processor = gemma4.Gemma4ImageProcessor.from_config(model.config.vision)
batch = processor.preprocess(image)
vision_tokens, vision_mask = model.encode_images_to_text(
    batch.pixel_values,
    batch.image_position_ids,
)
```

Current practical note:
- general captioning works
- OCR-style prompts may work on simple large text
- dense text-heavy images are still a harder case
The package includes a Gemma4AudioProcessor and the same high-level Gemma4Processor handles audio
placeholder expansion.
The user-facing rule is the audio equivalent of image prompting:
- put one visible <|audio|> token in the prompt for each audio clip
- pass the actual clip through audios=...
Example transcription-style flow:

```python
import gemma4_pt_codex as gemma4

tokenizer = gemma4.Gemma4Tokenizer.from_pretrained("/path/to/checkpoint_dir")
model = gemma4.Gemma4Model.from_pretrained("/path/to/checkpoint_dir")

prompt = (
    "<|turn>user\n"
    "<|audio|>\n"
    "Transcribe this audio clip exactly.\n"
    "<turn|>\n"
    "<|turn>model\n"
)

text = model.generate_text(
    tokenizer,
    prompt,
    audios="/path/to/audio.wav",
    max_new_tokens=128,
    do_sample=False,
)
print(text)
```

You can also pass waveform data directly:

```python
import torch

# 5 seconds of random noise at 16 kHz, standing in for a real recording
waveform = torch.randn(16000 * 5)
text = model.generate_text(
    tokenizer,
    prompt,
    audios=(waveform, 16000),
    max_new_tokens=128,
    do_sample=False,
)
```

Important detail:
- if you pass a raw tensor or array without a sample rate, it is assumed to already be 16 kHz
- if you pass (waveform, sample_rate) or a file path, the audio processor will resample to 16 kHz
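
The two rules for in-memory inputs can be sketched in plain Python. to_16k() is a hypothetical helper; it uses linear interpolation only to show where resampling happens, which is much cruder than the band-limited resampling an audio library would do. File paths are not handled here; a real loader would first decode them into a waveform plus the file's own sample rate.

```python
TARGET_SAMPLE_RATE = 16_000


def to_16k(audio):
    """Return float samples at 16 kHz from a raw sequence or a (waveform, sample_rate) pair."""
    if isinstance(audio, tuple):
        waveform, sample_rate = audio
    else:
        # No sample rate given: assumed to already be 16 kHz.
        waveform, sample_rate = audio, TARGET_SAMPLE_RATE
    samples = [float(x) for x in waveform]
    if sample_rate == TARGET_SAMPLE_RATE or len(samples) < 2:
        return samples
    n_out = round(len(samples) * TARGET_SAMPLE_RATE / sample_rate)
    step = sample_rate / TARGET_SAMPLE_RATE  # input samples advanced per output sample
    out = []
    for i in range(n_out):
        pos = i * step
        j = min(int(pos), len(samples) - 2)
        frac = min(pos - j, 1.0)  # clamp so trailing outputs hold the final sample
        out.append(samples[j] * (1.0 - frac) + samples[j + 1] * frac)
    return out
```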
Internally, the processor expands each visible audio token into:
- <|audio>
- internal soft-audio placeholder repeated N times
- <audio|>
The audio path is:
- Load the waveform
- Convert to mono
- Convert to float32
- Resample to 16 kHz if needed
- Compute log-mel features
- Build a valid-frame mask
- Compute the number of soft audio tokens implied by the waveform length
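
The mono, mask, and token-count steps come down to simple arithmetic. The sketch below assumes a 25 ms window (400 samples), a 10 ms hop (160 samples), no edge padding, and 4x time downsampling in the encoder; these are illustrative placeholders, not values read from model.config.audio, so its counts will not necessarily match the package.

```python
import math

# Hypothetical front-end settings for illustration only.
FRAME_LENGTH = 400  # 25 ms window at 16 kHz
HOP_LENGTH = 160    # 10 ms hop at 16 kHz
SUBSAMPLING = 4     # assumed encoder time downsampling


def to_mono(channels):
    """Average equal-length channels into one float waveform."""
    return [float(sum(frame)) / len(frame) for frame in zip(*channels)]


def num_frames(num_samples):
    """Number of log-mel frames for a waveform (no edge padding)."""
    if num_samples < FRAME_LENGTH:
        return 0
    return 1 + (num_samples - FRAME_LENGTH) // HOP_LENGTH


def num_soft_audio_tokens(num_samples):
    """Soft audio tokens left after the encoder's time downsampling."""
    return math.ceil(num_frames(num_samples) / SUBSAMPLING)


def frame_mask(lengths):
    """Valid-frame mask for a batch of waveforms padded to the longest one."""
    frames = [num_frames(n) for n in lengths]
    longest = max(frames)
    return [[i < f for i in range(longest)] for f in frames]
```

With these assumed settings, a 5-second clip at 16 kHz (80,000 samples) gives 498 frames and 125 soft audio tokens.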
If you want the processor output directly:

```python
audio_processor = gemma4.Gemma4AudioProcessor.from_config(model.config.audio)
audio_batch = audio_processor.preprocess("/path/to/audio.wav")
audio_tokens, audio_mask = model.encode_audio_to_text(
    audio_batch.input_features,
    audio_batch.input_features_mask,
)
```

The current public-checkpoint path has been validated on transcription-style prompts against the small instruction-tuned checkpoint.
Gemma4Model.prepare_inputs() is the easiest way to see exactly what will be fed into the model:

```python
prepared = model.prepare_inputs(
    tokenizer,
    prompt,
    images=image,
    audios="/path/to/audio.wav",
)
print(prepared.input_ids.shape)
print(None if prepared.vision_tokens is None else prepared.vision_tokens.shape)
print(None if prepared.audio_tokens is None else prepared.audio_tokens.shape)
```

That returns a Gemma4PreparedInputs object with:
- input_ids
- attention_mask
- vision_tokens
- vision_token_mask
- audio_tokens
- audio_token_mask
It can be moved and unpacked directly:

```python
prepared = prepared.to("cuda", dtype=torch.bfloat16)
output = model(**prepared.as_forward_kwargs(), return_hidden_states=True)
```

If you want the tokenizer plus multimodal preprocessing layer directly, build a processor:

```python
processor = model.build_processor(tokenizer)
batch = processor(
    prompt,
    images=image,
    audios="/path/to/audio.wav",
    add_bos=True,
    padding=True,
)
```

Gemma4Tokenizer supports:
- SentencePiece tokenizer assets such as tokenizer.model
- HF fast-tokenizer assets such as tokenizer.json
Example:

```python
import gemma4_pt_codex as gemma4

tokenizer = gemma4.Gemma4Tokenizer.from_pretrained("/path/to/checkpoint_dir")
ids = tokenizer.encode("hello world", add_bos=True)
text = tokenizer.decode(ids)
```

Convert a local Hugging Face safetensors checkpoint to the native format from the CLI:

```bash
gemma4-pt-codex-convert hf /path/to/hf_checkpoint /path/to/native_checkpoint
```

Or from Python:

```python
import gemma4_pt_codex as gemma4

gemma4.convert_hf_checkpoint(
    "/path/to/hf_checkpoint",
    "/path/to/native_checkpoint",
)
```

Supported HF inputs:
- single-file safetensors checkpoints
- sharded safetensors checkpoints
- tokenizer assets from either SentencePiece or tokenizer.json
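
Sharded Hugging Face checkpoints ship a model.safetensors.index.json whose weight_map maps each tensor name to the shard file that stores it. A converter can group names by shard so each file is opened once. The sketch below covers only that bookkeeping; load_index() and group_by_shard() are hypothetical helpers, not gemma4_pt_codex functions.

```python
import json
from collections import defaultdict
from pathlib import Path


def load_index(checkpoint_dir):
    """Parsed index for a sharded checkpoint, or None for a single-file one."""
    path = Path(checkpoint_dir) / "model.safetensors.index.json"
    return json.loads(path.read_text()) if path.is_file() else None


def group_by_shard(index):
    """Group tensor names by shard file, sorted for deterministic iteration."""
    grouped = defaultdict(list)
    for tensor_name, shard_file in index["weight_map"].items():
        grouped[shard_file].append(tensor_name)
    return {shard: sorted(names) for shard, names in sorted(grouped.items())}
```

Each group could then be read from its shard with safetensors' safe_open(...) and get_tensor(name); how the package's converter actually iterates shards is not shown here.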
Convert a local Orbax/JAX checkpoint from the CLI:

```bash
gemma4-pt-codex-convert orbax /path/to/orbax_checkpoint /path/to/native_checkpoint --variant gemma-4-e2b
```

Or from Python:

```python
import gemma4_pt_codex as gemma4

gemma4.convert_orbax_checkpoint(
    "/path/to/orbax_checkpoint",
    "/path/to/native_checkpoint",
    variant="gemma-4-e2b",
)
```

The Orbax path is the closest route to the original JAX parameter tree.
The modalities are intentionally separated.
Text-only:

```python
import gemma4_pt_codex as gemma4

config = gemma4.gemma4_e2b_config(text_only=True)
text = gemma4.Gemma4TextTower(config.text)
```

Vision:

```python
import gemma4_pt_codex as gemma4

config = gemma4.gemma4_e2b_config()
encoder = gemma4.Gemma4VisionEncoder(config.vision)
tower = gemma4.Gemma4VisionTower(config.vision, text_hidden_size=config.text.hidden_size)
```

Audio:

```python
import gemma4_pt_codex as gemma4

config = gemma4.gemma4_e2b_config()
tower = gemma4.Gemma4AudioTower(config.audio, text_hidden_size=config.text.hidden_size)
```

Top-level model layout:

```python
model.text
model.vision
model.audio
```

This package tries to stay close to the original JAX implementation where it matters:
- released preset configs
- alternating sliding/global attention patterns
- q/k/v normalization behavior
- RoPE choices for local and global attention
- trainable layer scaling
- multimodal placeholder expansion and merge behavior
- MoE support for 26B_A4B
- cache-aware decode
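
The alternating sliding/global pattern can be pictured as a per-layer schedule plus two kinds of causal mask. The one-global-layer-every-N schedule and the window size below are placeholders for illustration; the real ratio and window come from the released preset configs.

```python
def layer_attention_types(num_layers, global_every):
    """'global' on every global_every-th layer, 'sliding' elsewhere (illustrative schedule)."""
    return ["global" if (i + 1) % global_every == 0 else "sliding" for i in range(num_layers)]


def can_attend(query_pos, key_pos, attention_type, window):
    """Causal rule; sliding layers additionally drop keys window or more positions back."""
    if key_pos > query_pos:
        return False
    if attention_type == "sliding":
        return query_pos - key_pos < window
    return True


def causal_mask(seq_len, attention_type, window):
    """Boolean [query][key] mask for one layer type."""
    return [
        [can_attend(q, k, attention_type, window) for k in range(seq_len)]
        for q in range(seq_len)
    ]
```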
It deliberately does not copy the style of the Transformers implementation:
- fewer wrappers
- flatter ownership
- modality towers live with their own embedders/projectors
- no einsum in the main path
Current local coverage includes:
- text, vision, audio, save/load, and generation smoke tests
- config creation coverage for the main released variants
- HF conversion tests
- JAX parity tests for text and vision
- image preprocessing and multimodal expansion tests
- real conversion and generation smoke tests with google/gemma-4-e2b-it
Run the tests with:

```bash
cd gemma4_pytorch_codex
pytest tests -q
```

Project layout:

```text
gemma4_pytorch_codex/
  pyproject.toml
  README.md
  src/gemma4_pt_codex/
    config.py
    layers.py
    image_processing.py
    audio_processing.py
    processing.py
    text.py
    vision.py
    audio.py
    model.py
    tokenizer.py
    convert.py
  tests/
```
This is a practical local implementation, not a polished upstream distribution.
It is a good fit if you want to:
- study Gemma 4 without framework noise
- convert checkpoints and run local generation
- reuse the image or audio towers in another project
- work on a PyTorch codebase that stays relatively close to the original model structure
It is not trying to be:
- a drop-in replacement for the full Transformers ecosystem
- a large training framework
- a kitchen-sink inference server
Apache 2.0, matching the project metadata in pyproject.toml.