
Commit afb228d

feat: add T5 implementation (#35)
* chore(dev): fix vscode code actions
* style(ruff): fmt
* style(mypy): fix
* feat(t5): init
* style(mypy): fix
* feat: add t5 attn
* feat: add container layers
* feat: add causal lm
* docs: update docstrings for module and config
1 parent 81dedd6 commit afb228d

11 files changed (+1141, -62 lines)

.devcontainer/devcontainer.json

Lines changed: 2 additions & 2 deletions

@@ -20,8 +20,8 @@
         "python.linting.enabled": true,
         "editor.formatOnSave": true,
         "editor.codeActionsOnSave": {
-          "source.organizeImports": "true",
-          "source.fixAll": "true"
+          "source.organizeImports": "always",
+          "source.fixAll": "always"
         },
         "python.formatting.provider": "none",
         "[python]": {

examples/t5_inference_example.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+from flax import nnx
+
+from jaxgarden import T5Config, T5ForCausalLM, Tokenizer
+
+if __name__ == "__main__":
+    config = T5Config()
+    model = T5ForCausalLM(config, rngs=nnx.Rngs(0))
+    model_id = "google-t5/t5-base"
+
+    # download checkpoint from HuggingFace Hub
+    model.from_hf(model_id, force_download=True)
+
+    tokenizer = Tokenizer.from_pretrained(model_id)
+
+    text = "The meaning of life is"
+    model_inputs = tokenizer.encode(text)
+    output = model.generate(**model_inputs, max_length=20, do_sample=True)
+    output_text = tokenizer.decode(output)
+
+    print(output, output.shape)
+    print(output_text)

jaxgarden/__init__.py

Lines changed: 22 additions & 1 deletion

@@ -29,9 +29,21 @@
     ModernBertLayer,
     ModernBertMLP,
 )
-from jaxgarden.tokenization import Tokenizer
+from jaxgarden.models.t5 import (
+    T5MLP,
+    T5Attention,
+    T5Block,
+    T5Config,
+    T5CrossAttention,
+    T5ForCausalLM,
+    T5LayerNorm,
+    T5SelfAttention,
+    T5Stack,
+)
+from jaxgarden.tokenization import Tokenizer  # type: ignore

 __all__ = [
+    "T5MLP",
     # Base classes
     "BaseConfig",
     "BaseModel",
@@ -60,6 +72,15 @@
     "ModernBertMLP",
     # Attention modules
     "MultiHeadAttention",
+    # T5 Models
+    "T5Attention",
+    "T5Block",
+    "T5Config",
+    "T5CrossAttention",
+    "T5ForCausalLM",
+    "T5LayerNorm",
+    "T5SelfAttention",
+    "T5Stack",
     # tokenization
     "Tokenizer",
     # Functional interfaces
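With these exports in place, the T5 classes can be imported from the package root. A minimal sketch that mirrors the new example script above (the default T5Config and the nnx.Rngs(0) seed are simply copied from that example):

from flax import nnx

from jaxgarden import T5Config, T5ForCausalLM

# Build a randomly initialized T5 causal LM from the default config;
# weights can then be loaded from the Hub as in examples/t5_inference_example.py.
config = T5Config()
model = T5ForCausalLM(config, rngs=nnx.Rngs(0))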

jaxgarden/attention/rope_multi_head_attention.py

Lines changed: 29 additions & 29 deletions

@@ -47,8 +47,9 @@ def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndar
     return (x * cos_emb) + (rotate_half(x) * sin_emb)


-def precompute_rotary_embeddings(seq_len: int, head_dim: int,
-                                 base: float = 10000.0) -> tuple[jnp.ndarray, jnp.ndarray]:
+def precompute_rotary_embeddings(
+    seq_len: int, head_dim: int, base: float = 10000.0
+) -> tuple[jnp.ndarray, jnp.ndarray]:
     """Precomputes the RoPE cosine and sine embeddings.

     Args:
@@ -91,11 +92,11 @@ class RoPEMultiHeadAttention(nn.Module):
     rope_base: float = 10000.0
     dtype: jnp.dtype = jnp.float32

-    def setup(self) -> None: # Added -> None return type
+    def setup(self) -> None:  # Added -> None return type
         """Initializes the attention projections."""
         # Check head_dim validity early during setup
         if self.head_dim % 2 != 0:
-             raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")
+            raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")

         # Define layers here - they will be initialized when the module is first called
         total_head_dim = self.num_heads * self.head_dim
@@ -109,13 +110,12 @@ def setup(self) -> None:  # Added -> None return type
             features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj"
         )
         self.output_proj = nn.Dense(
-            features=self.num_heads * self.head_dim, # Output should match embed_dim
+            features=self.num_heads * self.head_dim,  # Output should match embed_dim
             use_bias=False,
             dtype=self.dtype,
-            name="output_proj"
+            name="output_proj",
         )

-
     @nn.compact
     # Also using Optional for the mask type hint for clarity with None default
     def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray:
@@ -136,8 +136,7 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr

         if embed_dim != total_head_dim:
             raise ValueError(
-                f"embed_dim ({embed_dim}) must equal num_heads*head_dim"
-                f" ({total_head_dim})"
+                f"embed_dim ({embed_dim}) must equal num_heads*head_dim ({total_head_dim})"
             )
         # Note: head_dim even check moved to setup for earlier failure

@@ -159,7 +158,6 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
         cos_emb = cos_emb.astype(self.dtype)
         sin_emb = sin_emb.astype(self.dtype)

-
         # 4. Apply RoPE to Query and Key
         query = apply_rotary_pos_emb(query, cos_emb, sin_emb)
         key = apply_rotary_pos_emb(key, cos_emb, sin_emb)
@@ -172,44 +170,46 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
         # 6. Scaled Dot-Product Attention
         # Attention scores: (batch, num_heads, seq_len, seq_len)
         attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt(
-            self.head_dim).astype(self.dtype) # Ensure sqrt is correct dtype
+            self.head_dim
+        ).astype(self.dtype)  # Ensure sqrt is correct dtype

         # Apply mask (if provided)
         if mask is not None:
             # Standard Flax causal mask is boolean (True means mask)
             # nn.make_causal_mask returns (1, seq_len, seq_len) or (batch, 1, seq_len, seq_len)
             # Check if mask needs broadcasting or conversion
-            if mask.ndim == 2: # Likely (seq_len, seq_len)
-                mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len)
+            if mask.ndim == 2:  # Likely (seq_len, seq_len)
+                mask = mask[None, None, :, :]  # -> (1, 1, seq_len, seq_len)
             elif mask.ndim == 3 and mask.shape[1] != self.num_heads:
-                 # Likely (batch, seq_len, seq_len) or causal (1, sl, sl)
+                # Likely (batch, seq_len, seq_len) or causal (1, sl, sl)
                 mask = mask[:, None, :, :]
-                 # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)
+                # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)

             # Ensure mask is broadcastable to attn_scores shape
             mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len)
             if mask.shape != mask_shape_expected:
-                 # Attempt broadcasting common causal mask shapes
-                 if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (batch_size, 1,
-                     seq_len, seq_len): # Causal mask for all batches/heads
-                     mask = jnp.broadcast_to(mask, mask_shape_expected)
-                 # Add other broadcasting cases if needed
-                 else:
-                     raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")
-
+                # Attempt broadcasting common causal mask shapes
+                if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (
+                    batch_size,
+                    1,
+                    seq_len,
+                    seq_len,
+                ):  # Causal mask for all batches/heads
+                    mask = jnp.broadcast_to(mask, mask_shape_expected)
+                # Add other broadcasting cases if needed
+                else:
+                    raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")

             # Apply mask: Use large negative number where mask is True
             # (or where mask value is 0 if using 0/-inf convention)
             # Assuming boolean mask convention (True = mask) common in Flax examples
             # If using 0/-inf mask, the logic would be: attn_scores = attn_scores + mask
             attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores)

-
         # Softmax to get attention weights
-        attn_weights = jax.nn.softmax(
-            attn_scores, axis=-1
-        ).astype(self.dtype) # Shape: (batch, num_heads, seq_len, seq_len)
-
+        attn_weights = jax.nn.softmax(attn_scores, axis=-1).astype(
+            self.dtype
+        )  # Shape: (batch, num_heads, seq_len, seq_len)

         # Apply attention weights to Value
         # Output per head: (batch, num_heads, seq_len, head_dim)
@@ -222,6 +222,6 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
         attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim)

         # Final linear projection
-        output = self.output_proj(attn_output) # Use self.output_proj defined in setup
+        output = self.output_proj(attn_output)  # Use self.output_proj defined in setup

         return output
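The hunks above are formatting-only (line wrapping and comment spacing from ruff). For context, a minimal usage sketch of the module follows; it assumes RoPEMultiHeadAttention takes num_heads and head_dim as fields alongside the rope_base and dtype shown above, is imported from its module path, and follows the boolean "True means mask out" convention described in the comments:

import jax
import jax.numpy as jnp
from flax import linen as nn

from jaxgarden.attention.rope_multi_head_attention import RoPEMultiHeadAttention

batch, seq_len, num_heads, head_dim = 2, 16, 4, 32  # head_dim must be even for RoPE
x = jnp.ones((batch, seq_len, num_heads * head_dim))

# nn.make_causal_mask marks attendable positions as True, so invert it to match
# the "True = mask out" convention used in __call__ above.
causal_mask = jnp.logical_not(nn.make_causal_mask(jnp.ones((batch, seq_len)), dtype=bool))

attn = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim)
params = attn.init(jax.random.PRNGKey(0), x, causal_mask)
out = attn.apply(params, x, causal_mask)  # (batch, seq_len, num_heads * head_dim)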

jaxgarden/models/base.py

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ def __init__(
     @property
     def state(self) -> nnx.State:
         """Splits state from the graph and returns it"""
-        return nnx.split(self, nnx.Param, ...)[1]
+        return nnx.split(self, nnx.Param, ...)[1]  # type: ignore

     @property
     def state_dict(self) -> dict[str, jnp.ndarray]:
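The added # type: ignore only silences mypy; the behaviour of the property is unchanged. As a reminder of what it computes, a small sketch using the standard flax.nnx split/merge API (the Tiny module here is hypothetical):

from flax import nnx

class Tiny(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(2, 3, rngs=rngs)

model = Tiny(rngs=nnx.Rngs(0))

# nnx.split with the (nnx.Param, ...) filters returns (graphdef, params, rest);
# index [1] is the Param-only state, which is what the `state` property exposes.
graphdef, params, rest = nnx.split(model, nnx.Param, ...)
restored = nnx.merge(graphdef, params, rest)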

jaxgarden/models/gemma2.py

Lines changed: 1 addition & 1 deletion

@@ -422,7 +422,7 @@ def __call__(


 # 3. Main Model
-class Gemma2ForCausalLM(BaseModel, GenerationMixin):
+class Gemma2ForCausalLM(GenerationMixin, BaseModel):
     config: Gemma2Config  # This helps to fix a mypy issue

     def __init__(self, config: Gemma2Config, *, rngs: nnx.Rngs) -> None:
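The only change is the base-class order: with GenerationMixin listed first, its attributes and methods take precedence in Python's method resolution order whenever both parents define the same name (the same swap is applied to LlamaForCausalLM below). A generic sketch, using hypothetical Base and Mixin classes rather than the jaxgarden ones, shows why the order matters:

class Base:
    def run(self) -> str:
        return "base"

class Mixin:
    def run(self) -> str:
        return "mixin"

class OldOrder(Base, Mixin):  # MRO: OldOrder -> Base -> Mixin
    pass

class NewOrder(Mixin, Base):  # MRO: NewOrder -> Mixin -> Base
    pass

print(OldOrder().run())  # "base"
print(NewOrder().run())  # "mixin"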

jaxgarden/models/llama.py

Lines changed: 2 additions & 2 deletions

@@ -431,7 +431,7 @@ def __call__(
         return x


-class LlamaForCausalLM(BaseModel, GenerationMixin):
+class LlamaForCausalLM(GenerationMixin, BaseModel):
     """LLama model for causal language modeling.

     This implements the full LLama model for generating text.
@@ -511,7 +511,7 @@ def __call__(
         assert input_ids.shape[0] == 1, "Only batch size 1 is supported"
         print(input_ids.shape)
         position_ids = jnp.arange(input_ids.shape[-1])[None, :].astype(jnp.int32)
-        attention_mask = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]
+        attention_mask = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]  # type: ignore
         x = self.token_embed(input_ids)
         for layer in self.layers:
             x = layer(x, position_ids, attention_mask)
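Besides the same base-class swap as in Gemma2ForCausalLM, the only change is a # type: ignore on the mask conversion. That line turns a boolean attention mask into an additive one; a small standalone sketch of what it computes:

import jax.numpy as jnp

# Toy padding mask for one sequence of length 4: True = real token, False = padding.
attention_mask = jnp.array([True, True, True, False])

# Same expression as above: attendable positions become 0.0, padded positions -inf,
# and two leading axes are added for (batch, heads) broadcasting.
additive = jnp.where(attention_mask, 0.0, -jnp.inf)[None, None, ...]
print(additive.shape)  # (1, 1, 4)
print(additive)        # [[[  0.   0.   0. -inf]]]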
