[WIP] Proper tool calling support in torchtune #2794
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2794
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey @krammnic, does this support tool calls for all formats (like openai, sharegpt etc)?
It's still WIP, but yes, it will.
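For reference, an OpenAI-style assistant tool call looks like this (illustrative values); this is the shape a format converter would have to map onto torchtune's `Message`:

```python
# OpenAI-style assistant message carrying a tool call (field layout follows
# the OpenAI chat format; the concrete values are made up for illustration).
openai_style_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            "function": {
                "name": "get_weather",
                # Arguments are a JSON-encoded string in the OpenAI format.
                "arguments": '{"city": "Paris"}',
            },
        }
    ],
}
```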
Thanks for working on this! I left a few small comments. We should also add a test to ensure that this actually works and generates the expected outputs on a tool-calling dataset.
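A rough sketch of such a test, exercising only what this PR's diffs show (`Message.from_dict` with `tool_calls`/`tool`); the names and schema are assumptions until the PR settles, and a full test would tokenize a tool-calling dataset and compare token ids:

```python
from torchtune.data import Message


def test_message_from_dict_tool_calls():
    # Build a Message from a dict the way dataset transforms do; the
    # tool_calls/tool fields come from this PR's proposed from_dict change.
    msg = Message.from_dict(
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}],
        }
    )
    assert msg.tool_calls == [{"name": "get_weather", "arguments": {"city": "Paris"}}]
    assert msg.tool is False
```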
```
@@ -136,9 +136,16 @@ def encode(
        list[int]: The list of token ids.
        """
        token_ids = self.tokenizer.encode(text).ids
        if add_bos and not self.hf_adds_bos and self.bos_token not in text:
            ...
        # Both bos_id and eos_id might be None (null). Therefore, we need an additional check.
```
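For context, a sketch of what the guarded encode path might look like with the None checks folded in. It is written as a free function here, and the `hf_adds_bos`/`hf_adds_eos` flags and token/id names are assumed from the diff above, not taken from the PR's exact code:

```python
from typing import Optional


def encode_with_guards(
    hf_tokenizer,  # a tokenizers.Tokenizer instance
    text: str,
    bos_token: Optional[str],
    bos_id: Optional[int],
    eos_id: Optional[int],
    hf_adds_bos: bool,
    hf_adds_eos: bool,
    add_bos: bool = True,
    add_eos: bool = True,
) -> list[int]:
    token_ids = hf_tokenizer.encode(text).ids
    # bos/eos may be None (null) in tokenizer_config.json,
    # hence the extra checks before touching them.
    if add_bos and not hf_adds_bos and bos_token is not None and bos_token not in text:
        token_ids.insert(0, bos_id)
    if add_eos and not hf_adds_eos and eos_id is not None:
        token_ids.append(eos_id)
    return token_ids
```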
Is this related to tool-calling? Or a separate issue?
It is caused by a separate issue in HuggingfaceBaseTokenizer.
```python
try:
    self.bos_token = self._get_token_from_config(self.config, "bos_token")
    self.eos_token = self._get_token_from_config(self.config, "eos_token")
except ValueError:
    pass
```
In this case I wonder whether we should just modify `_get_token_from_config` to directly return None (possibly logging a warning) rather than use this try/except.
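A minimal sketch of that suggestion, as a free function; the dict/"content" handling mirrors typical HF `tokenizer_config.json` layouts and is an assumption, as is the logging:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def get_token_from_config(config: dict, key: str) -> Optional[str]:
    """Return the token string for ``key``, or None (with a warning)
    instead of raising ValueError when no usable entry exists."""
    token = config.get(key)
    if token is None:
        logger.warning("'%s' not found in tokenizer config; treating it as absent.", key)
        return None
    # HF configs store special tokens either as plain strings or as dicts
    # like {"content": "<s>", ...}.
    if isinstance(token, dict):
        content = token.get("content")
        if content is None:
            logger.warning("'%s' has no 'content' field; treating it as absent.", key)
            return None
        return content
    return token
```

The call sites above could then simply check for None instead of catching ValueError.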
```diff
@@ -52,6 +52,7 @@ class Message:
     masked (bool): whether the message is masked in the sample. If True, do not use
         in loss calculation. Default: False
     ipython (bool): whether the message is a tool call. Default: False
+    tool_calls (Optional[list]): list of tool calls related to this message. Default: None
```
Should we also update the role "ipython" to "tool" to match what's done by Hugging Face?
Yes, good catch, that argument seemed weird to me.
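Hypothetical usage once the rename lands (field names follow this PR's diffs; the exact `tool_calls` schema is illustrative and may still change):

```python
from torchtune.data import Message

# Assistant turn that requests a tool call.
call = Message(
    role="assistant",
    content="",
    tool_calls=[{"name": "get_weather", "arguments": {"city": "Paris"}}],
)

# The tool's response comes back under the renamed "tool" role
# (previously "ipython").
result = Message(role="tool", content='{"temp_c": 21}')
```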
```diff
@@ -48,7 +48,7 @@ tune = "torchtune._cli.tune:main"
 
 [project.optional-dependencies]
 dev = [
-    "bitsandbytes>=0.43.0",
+    # "bitsandbytes>=0.43.0",
```
Was this an intentional removal?
Nope :/ I have to comment it out on Mac (otherwise it won't install) and then remember to remove the comment afterwards. Will open a separate PR to address this.
```diff
@@ -108,7 +111,8 @@ def from_dict(cls, d: dict) -> "Message":
         role=d["role"],
         content=d["content"],
         masked=d.get("masked", False),
-        ipython=d.get("ipython", False),
+        tool_calls=d.get("tool_calls", []),
+        tool=d.get("tool", False),
```
While I agree with this change, it currently breaks existing tokenizers. Repro:

```python
from torchtune.datasets import alpaca_cleaned_dataset
from torchtune.models.qwen2_5 import qwen2_5_tokenizer

vocab_path = "/tmp/Qwen2.5-14B-Instruct/vocab.json"
merges_path = "/tmp/Qwen2.5-14B-Instruct/merges.txt"
tokenizer_json_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer.json"
tokenizer_config_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer_config.json"

tokenizer_qwen = qwen2_5_tokenizer(
    path=vocab_path,
    merges_file=merges_path,
    max_seq_len=512,
)
dataset_qwen = alpaca_cleaned_dataset(tokenizer=tokenizer_qwen, packed=False)
```
Hmm, yep. We need to remove ipython everywhere.
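One possible way to drop `ipython` without breaking existing callers (an assumption, not something in this PR): accept both keys in `from_dict` and warn on the deprecated one. Sketched here on a simplified stand-in for torchtune's `Message`:

```python
import warnings
from dataclasses import dataclass, field


@dataclass
class Message:
    # Simplified stand-in; the real class has more fields and validation.
    role: str
    content: str
    masked: bool = False
    tool_calls: list = field(default_factory=list)
    tool: bool = False

    @classmethod
    def from_dict(cls, d: dict) -> "Message":
        tool = d.get("tool", False)
        if "ipython" in d:
            # Keep backward compatibility with datasets/tokenizers that
            # still emit the old field name, but nudge users toward "tool".
            warnings.warn("'ipython' is deprecated, use 'tool' instead", DeprecationWarning)
            tool = d["ipython"]
        return cls(
            role=d["role"],
            content=d["content"],
            masked=d.get("masked", False),
            tool_calls=d.get("tool_calls", []),
            tool=tool,
        )
```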
@krammnic This is great progress. It looks like a bit of work needs to be done on BC with existing tokenizers (unless the plan is to fully deprecate them). In addition I'm seeing some issues with the jinja rendering - I think it may require explicitly passing tools to the renderer (hf ref). Repro:

```python
from torchtune.datasets import alpaca_cleaned_dataset
from torchtune.modules.transforms.tokenizers import HuggingFaceModelTokenizer

tokenizer_json_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer.json"
tokenizer_config_path = "/tmp/Qwen2.5-14B-Instruct/tokenizer_config.json"

tokenizer_hf = HuggingFaceModelTokenizer(
    tokenizer_json_path=tokenizer_json_path,
    tokenizer_config_json_path=tokenizer_config_path,
    max_seq_len=512,
)
dataset_hf = alpaca_cleaned_dataset(tokenizer=tokenizer_hf, packed=True)
```

Basically optionally propagating …
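For reference, "explicitly passing tools" in the Hugging Face convention looks roughly like this with transformers' `apply_chat_template`; `HuggingFaceModelTokenizer` would need to inject the same `tools` variable into its jinja context (model path and tool function are illustrative):

```python
from transformers import AutoTokenizer


def get_weather(city: str) -> str:
    """Get the current weather for a city.

    Args:
        city: Name of the city to look up.
    """
    ...


tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-14B-Instruct")
rendered = tok.apply_chat_template(
    [{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[get_weather],  # tool schemas are exposed to the jinja template
    tokenize=False,
    add_generation_prompt=True,
)
```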
Context
What is the purpose of this PR? Is it to
- add a new feature
- fix a bug
- update tests and/or documentation
- other (please add here)
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.