
Commit 4498d5e

Added AddedToken class
1 parent 8931055 commit 4498d5e


6 files changed: +70, -6 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 ## 0.5.3 (unreleased)
 
+- Added `AddedToken` class
 - Added precompiled gem for Windows
 
 ## 0.5.2 (2024-08-26)

ext/tokenizers/src/lib.rs

Lines changed: 5 additions & 1 deletion
@@ -15,7 +15,7 @@ mod utils;
 
 use encoding::RbEncoding;
 use error::RbError;
-use tokenizer::RbTokenizer;
+use tokenizer::{RbAddedToken, RbTokenizer};
 use utils::RbRegex;
 
 use magnus::{function, method, prelude::*, value::Lazy, Error, RModule, Ruby};
@@ -109,6 +109,10 @@ fn init(ruby: &Ruby) -> RbResult<()> {
     let class = module.define_class("Regex", ruby.class_object())?;
     class.define_singleton_method("new", function!(RbRegex::new, 1))?;
 
+    let class = module.define_class("AddedToken", ruby.class_object())?;
+    class.define_singleton_method("_new", function!(RbAddedToken::new, 2))?;
+    class.define_method("content", method!(RbAddedToken::get_content, 0))?;
+
     let models = module.define_module("Models")?;
     let pre_tokenizers = module.define_module("PreTokenizers")?;
     let decoders = module.define_module("Decoders")?;
ext/tokenizers/src/tokenizer.rs

Lines changed: 48 additions & 5 deletions
@@ -22,20 +22,21 @@ use super::processors::RbPostProcessor;
 use super::trainers::RbTrainer;
 use super::{RbError, RbResult};
 
+#[magnus::wrap(class = "Tokenizers::AddedToken")]
 pub struct RbAddedToken {
     pub content: String,
-    pub is_special_token: bool,
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
     pub normalized: Option<bool>,
 }
 
 impl RbAddedToken {
-    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-            is_special_token: is_special_token.unwrap_or(false),
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -44,7 +45,7 @@ impl RbAddedToken {
     }
 
     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+        let mut token = tk::AddedToken::from(&self.content, self.special);
 
         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -71,11 +72,53 @@ impl From<tk::AddedToken> for RbAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-            is_special_token: !token.normalized,
+            special: !token.normalized,
         }
     }
 }
 
+impl RbAddedToken {
+    pub fn new(content: Option<String>, kwargs: RHash) -> RbResult<Self> {
+        let mut token = RbAddedToken::from(content.unwrap_or("".to_string()), None);
+
+        let value: Value = kwargs.delete(Symbol::new("single_word"))?;
+        if !value.is_nil() {
+            token.single_word = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("lstrip"))?;
+        if !value.is_nil() {
+            token.lstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("rstrip"))?;
+        if !value.is_nil() {
+            token.rstrip = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("normalized"))?;
+        if !value.is_nil() {
+            token.normalized = TryConvert::try_convert(value)?;
+        }
+
+        let value: Value = kwargs.delete(Symbol::new("special"))?;
+        if !value.is_nil() {
+            token.special = TryConvert::try_convert(value)?;
+        }
+
+        if !kwargs.is_empty() {
+            // TODO improve message
+            return Err(Error::new(exception::arg_error(), "unknown keyword"));
+        }
+
+        Ok(token)
+    }
+
+    pub fn get_content(&self) -> String {
+        self.content.to_string()
+    }
+}
+
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 
 impl<'s> TryConvert for TextInputSequence<'s> {
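
RbAddedToken::new pops each recognized option (single_word, lstrip, rstrip, normalized, special) out of the kwargs Hash, only overriding a field when the value is non-nil, and raises an ArgumentError via exception::arg_error for anything left over. A hedged sketch of the resulting Ruby behaviour, using the public new wrapper added later in this commit and illustrative values:

token = Tokenizers::AddedToken.new("<custom>", lstrip: true, special: true)
token.content # => "<custom>"

# Options that were not consumed above are rejected.
Tokenizers::AddedToken.new("<custom>", typo: true)
# => ArgumentError: unknown keyword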

lib/tokenizers.rb

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
 require_relative "tokenizers/trainers/word_piece_trainer"
 
 # other
+require_relative "tokenizers/added_token"
 require_relative "tokenizers/char_bpe_tokenizer"
 require_relative "tokenizers/encoding"
 require_relative "tokenizers/from_pretrained"

lib/tokenizers/added_token.rb

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+module Tokenizers
+  class AddedToken
+    def self.new(content, **kwargs)
+      _new(content, kwargs)
+    end
+  end
+end
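
This wrapper exists because the bound _new method takes exactly two positional arguments (the arity of 2 registered in lib.rs above): the content and an options Hash. Collecting **kwargs here lets callers use ordinary keyword arguments; a sketch of the equivalence, with illustrative values:

# These two calls behave the same way:
Tokenizers::AddedToken.new("[SEP]", special: true)
Tokenizers::AddedToken._new("[SEP]", { special: true })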

test/added_token_test.rb

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+require_relative "test_helper"
+
+class AddedTokenTest < Minitest::Test
+  def test_content
+    token = Tokenizers::AddedToken.new("test")
+    assert_equal "test", token.content
+  end
+end
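
The commit's test only covers the content round trip. Purely as a sketch (not part of this commit), assertions in the same Minitest style could also exercise the keyword handling added above, assuming ArgumentError is what magnus raises for exception::arg_error:

class AddedTokenKwargsTest < Minitest::Test
  def test_accepts_known_keywords
    token = Tokenizers::AddedToken.new("test", special: true, single_word: true)
    assert_equal "test", token.content
  end

  def test_rejects_unknown_keywords
    assert_raises(ArgumentError) do
      Tokenizers::AddedToken.new("test", typo: true)
    end
  end
end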
