Functioning decode on multimodal Gemma3-4b #1689
Conversation
Great work, thanks Hengtao!
MaxText/layers/models.py (outdated)

```diff
@@ -703,5 +718,6 @@ def __call__(
         slot=slot,
         page_state=page_state,
         bidirectional_mask=bidirectional_mask,
+        image_embeddings=image_embeddings if self.config.use_multimodal else None,
```
Nit: line 702 sets image_embeddings to None, so we can keep just one of the two defaults. We could also handle it the same way as bidirectional_mask, to be consistent.
Done: now just passing image_embeddings=image_embeddings, to align with bidirectional_mask. A sketch of the resulting pattern is below.
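For readers following the thread, here is a minimal sketch of the agreed pattern: set the None default in one place and forward the value unconditionally, mirroring bidirectional_mask. The simplified signature and the encode_images hook are hypothetical stand-ins; the real call site is Transformer.__call__ in models.py.

```python
# Sketch only: names other than use_multimodal / image_embeddings /
# bidirectional_mask are hypothetical stand-ins for the MaxText call site.
def forward(decoder, config, tokens, bidirectional_mask, encode_images=None):
  image_embeddings = None  # single default (the "line 702" assignment above)
  if config.use_multimodal and encode_images is not None:
    image_embeddings = encode_images()  # hypothetical vision-encoder hook
  return decoder(
      decoder_input_tokens=tokens,
      bidirectional_mask=bidirectional_mask,
      # Forwarded unconditionally, exactly like bidirectional_mask; no
      # `if config.use_multimodal else None` at the call site anymore.
      image_embeddings=image_embeddings,
  )
```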
```diff
@@ -34,6 +34,7 @@
 from MaxText.layers import normalizations, quantizations
 from MaxText.layers import pipeline
 from MaxText import maxtext_utils
+from MaxText import multimodal_utils
```
I think this will cause a circular dependency, since I added this in the BUILD file (https://source.corp.google.com/piper///depot/google3/third_party/py/maxtext/BUILD;l=360?q=BUILD%20f:maxtext). Since only a few image-related variables are used in gemma3.py (https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/layers/gemma3.py#L48-L53), it may be better to move them to multimodal_utils.
Thank you, Aireen, for the catch!
I have moved all the Gemma-related static values from gemma3.py to multimodal_utils.py; a sketch of the layout after the move follows this thread. Alongside this PR, I will amend the copybara config to remove multimodal_utils.py's dependency on :layers. Let me know if this sounds good to you!
Sounds good, thank you!
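For context, a sketch of what the move looks like. The constant names and values below are assumptions based on the gemma3.py lines linked above, not the verbatim MaxText definitions:

```python
# multimodal_utils.py: a leaf module with no imports from MaxText.layers,
# which breaks the layers -> multimodal_utils -> layers cycle noted above.
# Values are illustrative assumptions for Gemma3 image preprocessing.
GEMMA_DEFAULT_IMAGE_SIZE = 896     # assumed vision-encoder input resolution
GEMMA_IMAGE_MEAN = (127.5,) * 3    # assumed per-channel normalization mean
GEMMA_IMAGE_STD = (127.5,) * 3     # assumed per-channel normalization std
NUM_IMAGE_CHANNELS = 3             # RGB

# gemma3.py then imports the shared values instead of defining them:
# from MaxText import multimodal_utils
# image_size = multimodal_utils.GEMMA_DEFAULT_IMAGE_SIZE
```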
Awesome, thanks Hengtao!!
Description
Insert the vision embeddings into the text embeddings and enable a fully functioning decode forward pass on the multimodal Gemma3-4b model.
merge_mm_embeddings inserts the image_embeddings into the text_embeddings at the image placeholder token positions marked by the bidirectional_mask; a sketch of one way to implement such a merge follows.
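A minimal JAX sketch of the merge, assuming [batch, seq, dim] text embeddings, [batch, num_image_tokens, dim] image embeddings, and a boolean [batch, seq] mask. This illustrates the idea only and is not the verbatim MaxText implementation:

```python
import jax.numpy as jnp

def merge_mm_embeddings(text_embeddings, image_embeddings, bidirectional_mask):
  """Scatter image embeddings into the placeholder token positions.

  Assumed shapes (a sketch, not the exact MaxText signature):
    text_embeddings:    [batch, seq_len, emb_dim]
    image_embeddings:   [batch, num_image_tokens, emb_dim]
    bidirectional_mask: [batch, seq_len] bool, True at <start_of_image> slots
  """
  # The k-th True position in each row should receive the k-th image embedding.
  positions = jnp.cumsum(bidirectional_mask, axis=-1) - 1
  positions = jnp.clip(positions, 0, image_embeddings.shape[1] - 1)
  # Gather, for every sequence position, the image embedding it would receive.
  gathered = jnp.take_along_axis(image_embeddings, positions[..., None], axis=1)
  # Keep the text embedding everywhere except at placeholder positions.
  return jnp.where(bidirectional_mask[..., None], gathered, text_embeddings)
```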
Tests
A full decode forward pass from the command line, using prompt='Describe image <start_of_image>' and an input image, yields logs containing the expected image description.
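An illustrative invocation is below. Only prompt and use_multimodal appear verbatim in this PR; model_name, load_parameters_path, and image_path are assumed flag names, and all paths are placeholders:

```sh
# Illustrative only; paths are placeholders and some flag names are assumptions.
python3 -m MaxText.decode MaxText/configs/base.yml \
  model_name=gemma3-4b \
  use_multimodal=true \
  load_parameters_path=gs://your-bucket/gemma3-4b/checkpoint \
  prompt="Describe image <start_of_image>" \
  image_path=/path/to/test_image.jpg
```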
Checklist
Before submitting this PR, please make sure (put X in square brackets):