Commit 4181320

amyeroberts, sgugger, alaradirik, NielsRogge, harrydrippin authored
Add normalize to image transforms module (#19544)
* Adapt FE methods to transforms library * Mixin for saving the image processor * Base processor skeleton * BatchFeature for packaging image processor outputs * Initial image processor for GLPN * REmove accidental import * Fixup and docs * Mixin for saving the image processor * Fixup and docs * Import BatchFeature from feature_extraction_utils * Fixup and docs * Fixup and docs * Fixup and docs * Fixup and docs * BatchFeature for packaging image processor outputs * Import BatchFeature from feature_extraction_utils * Import BatchFeature from feature_extraction_utils * Fixup and docs * Fixup and docs * BatchFeature for packaging image processor outputs * Import BatchFeature from feature_extraction_utils * Fixup and docs * Mixin for saving the image processor * Fixup and docs * Add rescale back and remove ImageType * fix import mistake * Fix enum var reference * Can transform and specify image data format * Remove redundant function * Update reference * Data format flag for rescale * Fix typo * Fix dimension check * Fixes to make IP and FE outputs match * Add tests for transforms * Add test for utils * Update some docstrings * Make sure in channels last before converting to PIL * Remove default to numpy batching * Fix up * Add docstring and model_input_types * Use feature processor config from hub * Alias GLPN feature extractor to image processor * Alias feature extractor mixin * Add return_numpy=False flag for resize * Fix up * Fix up * Use different frameworks safely * Safely import PIL * Call function checking if PIL available * Only import if vision available * Address Sylvain PR comments Co-authored-by: [email protected] * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> * Update src/transformers/image_transforms.py Co-authored-by: Alara Dirik <[email protected]> * Update src/transformers/models/glpn/feature_extraction_glpn.py Co-authored-by: NielsRogge <[email protected]> * 
Add in docstrings * Fix TFSwinSelfAttention to have relative position index as non-trainable weight (#18226) Signed-off-by: Seunghwan Hong <[email protected]> * Refactor `TFSwinLayer` to increase serving compatibility (#18352) * Refactor `TFSwinLayer` to increase serving compatibility Signed-off-by: Seunghwan Hong <[email protected]> * Fix missed parameters while refactoring Signed-off-by: Seunghwan Hong <[email protected]> * Fix window_reverse to calculate batch size Signed-off-by: Seunghwan Hong <[email protected]> Co-Authored-By: amyeroberts <[email protected]> Co-authored-by: amyeroberts <[email protected]> * Add TF prefix to TF-Res test class (#18481) Co-authored-by: ydshieh <[email protected]> * Remove py.typed (#18485) * Fix pipeline tests (#18487) * Fix pipeline tests * Make sure all pipelines tests run with init changes * Use new huggingface_hub tools for download models (#18438) * Draft new cached_file * Initial draft for config and model * Small fixes * Fix first batch of tests * Look in cache when internet is down * Fix last tests * Bad black, not fixing all quality errors * Make diff less * Implement change for TF and Flax models * Add tokenizer and feature extractor * For compatibility with main * Add utils to move the cache and auto-do it at first use. 
* Quality * Deal with empty commit shas * Deal with empty etag * Address review comments * Fix `test_dbmdz_english` by updating expected values (#18482) Co-authored-by: ydshieh <[email protected]> * Move cache folder to huggingface/hub for consistency with hf_hub (#18492) * Move cache folder to just huggingface * Thank you VsCode for this needless import * Move to hub * Forgot one * Update some expected values in `quicktour.mdx` for `resampy 0.3.0` (#18484) Co-authored-by: ydshieh <[email protected]> * Forgot one new_ for cache migration * disable Onnx test for google/long-t5-tglobal-base (#18454) Co-authored-by: ydshieh <[email protected]> * Typo reported by Joel Grus on TWTR (#18493) * Just re-reading the whole doc every couple of months 😬 (#18489) * Delete valohai.yaml * NLP => ML * typo * website supports https * datasets * 60k + modalities * unrelated link fixing for accelerate * Ok those links were actually broken * Fix link * Make `AutoTokenizer` auto-link * wording tweak * add at least one non-nlp task * `transformers-cli login` => `huggingface-cli login` (#18490) * zero chance anyone's using that constant no? * `transformers-cli login` => `huggingface-cli login` * `transformers-cli repo create` => `huggingface-cli repo create` * `make style` * Add seed setting to image classification example (#18519) * [DX fix] Fixing QA pipeline streaming a dataset. (#18516) * [DX fix] Fixing QA pipeline streaming a dataset. QuestionAnsweringArgumentHandler would iterate over the whole dataset effectively killing all properties of the pipeline. This restores nice properties when using `Dataset` or `Generator` since those are meant to be consumed lazily. * Handling TF better. 
* Clean up hub (#18497) * Clean up utils.hub * Remove imports * More fixes * Last fix * update fsdp docs (#18521) * updating fsdp documentation * typo fix * Fix compatibility with 1.12 (#17925) * Fix compatibility with 1.12 * Remove pin from examples requirements * Update torch scatter version * Fix compatibility with 1.12 * Remove pin from examples requirements * Update torch scatter version * fix torch.onnx.symbolic_opset12 import * Reject bad version Co-authored-by: ydshieh <[email protected]> * Remove debug statement * Specify en in doc-builder README example (#18526) Co-authored-by: Ankur Goyal <[email protected]> * New cache fixes: add safeguard before looking in folders (#18522) * unpin resampy (#18527) Co-authored-by: ydshieh <[email protected]> * ✨ update to use interlibrary links instead of Markdown (#18500) * Add example of multimodal usage to pipeline tutorial (#18498) * 📝 add example of multimodal usage to pipeline tutorial * 🖍 apply feedbacks * 🖍 apply niels feedback * [VideoMAE] Add model to doc tests (#18523) * Add videomae to doc tests * Add pip install decord Co-authored-by: Niels Rogge <[email protected]> * Update perf_train_gpu_one.mdx (#18532) * Update no_trainer.py scripts to include accelerate gradient accumulation wrapper (#18473) * Added accelerate gradient accumulation wrapper to run_image_classification_no_trainer.py example script * make fixup changes * PR comments * changed input to Acceletor based on PR comment, ran make fixup * Added comment explaining the sync_gradients statement * Fixed lr scheduler max steps * Changed run_clm_no_trainer.py script to use accelerate gradient accum wrapper * Fixed all scripts except wav2vec2 pretraining to use accelerate gradient accum wrapper * Added accelerate gradient accum wrapper for wav2vec2_pretraining_no_trainer.py script * make fixup and lr_scheduler step inserted back into run_qa_beam_search_no_trainer.py * removed changes to run_wav2vec2_pretraining_no_trainer.py script and fixed using 
wrong constant in qa_beam_search_no_trainer.py script * Add Spanish translation of converting_tensorflow_models.mdx (#18512) * Add file in spanish docs to be translated * Finish translation to Spanish * Improve Spanish wording * Add suggested changes from review * Spanish translation of summarization.mdx (#15947) (#18477) * Add Spanish translation of summarization.mdx * Apply suggestions from code review Co-authored-by: Omar U. Espejel <[email protected]> Co-authored-by: Omar U. Espejel <[email protected]> * Let's not cast them all (#18471) * add correct dtypes when checking for params dtype * forward contrib credits * Update src/transformers/modeling_utils.py Co-authored-by: Thomas Wang <[email protected]> * more comments - added more comments on why we cast only floating point parameters * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: sgugger <[email protected]> Co-authored-by: Thomas Wang <[email protected]> * fix: data2vec-vision Onnx ready-made configuration. (#18427) * feat: add the data2vec conf that are missing https://huggingface.co/docs/transformers/serialization * fix: wrong config * Add mt5 onnx config (#18394) * update features * MT5OnnxConfig added with updated with tests and docs * fix imports * fix onnc_config_cls for mt5 Co-authored-by: Thomas Chaigneau <thomas.deeptools.ai> * Minor update of `run_call_with_unpacked_inputs` (#18541) Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: ydshieh <[email protected]> * BART - Fix attention mask device issue on copied models (#18540) * attempt to fix attn mask device * fix bart `_prepare_decoder_attention_mask` - add correct device - run `make fix-copies` to propagate the fix * Adding a new `align_to_words` param to qa pipeline. (#18010) * Adding a new `align_to_words` param to qa pipeline. * Update src/transformers/pipelines/question_answering.py Co-authored-by: Sylvain Gugger <[email protected]> * Import protection. 
Co-authored-by: Sylvain Gugger <[email protected]> * 📝 update metric with evaluate (#18535) * Restore _init_weights value in no_init_weights (#18504) * Recover _init_weights value in no_init_weights For potential nested use. In addition, users might modify private no_init_weights as well. * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> * Remove private variable change check Co-authored-by: Sylvain Gugger <[email protected]> * Clean up comment * 📝 update documentation build section (#18548) * `bitsandbytes` - `Linear8bitLt` integration into `transformers` models (#17901) * first commit * correct replace function * add final changes - works like charm! - cannot implement tests yet - tested * clean up a bit * add bitsandbytes dependencies * working version - added import function - added bitsandbytes utils file * small fix * small fix - fix import issue * fix import issues * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> * refactor a bit - move bitsandbytes utils to utils - change comments on functions * reformat docstring - reformat docstring on init_empty_weights_8bit * Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <[email protected]> * revert bad formatting * change to bitsandbytes * refactor a bit - remove init8bit since it is useless * more refactoring - fixed init empty weights issue - added threshold param * small hack to make it work * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * revmoe the small hack * modify utils file * make style + refactor a bit * create correctly device map * add correct dtype for device map creation * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> * apply suggestions - remove with torch.grad - do not rely on Python bool magic! 
* add docstring - add docstring for new kwargs * add docstring - comment `replace_8bit_linear` function - fix weird formatting * - added more documentation - added new utility function for memory footprint tracking - colab demo to add * few modifs - typo doc - force cast into float16 when load_in_8bit is enabled * added colab link * add test architecture + docstring a bit * refactor a bit testing class * make style + refactor a bit * enhance checks - add more checks - start writing saving test * clean up a bit * male style * add more details on doc * add more tests - still needs to fix 2 tests * replace by "or" - could not fix it from GitHub GUI Co-authored-by: Sylvain Gugger <[email protected]> * refactor a bit testing code + add readme * make style * fix import issue * Update src/transformers/modeling_utils.py Co-authored-by: Michael Benayoun <[email protected]> * add few comments * add more doctring + make style * more docstring * raise error when loaded in 8bit * make style * add warning if loaded on CPU * add small sanity check * fix small comment * add bitsandbytes on dockerfile * Improve documentation - improve documentation from comments * add few comments * slow tests pass on the VM but not on the CI VM * Fix merge conflict * make style * another test should pass on a multi gpu setup * fix bad import in testing file * Fix slow tests - remove dummy batches - no more CUDA illegal memory errors * odify dockerfile * Update docs/source/en/main_classes/model.mdx * Update Dockerfile * Update model.mdx * Update Dockerfile * Apply suggestions from code review * few modifications - lm head can stay on disk/cpu - change model name so that test pass * change test value - change test value to the correct output - torch bmm changed to baddmm in bloom modeling when merging * modify installation guidelines * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email 
protected]> * Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]> * replace `n`by `name` * merge `load_in_8bit` and `low_cpu_mem_usage` * first try - keep the lm head in full precision * better check - check the attribute `base_model_prefix` instead of computing the number of parameters * added more tests * Update src/transformers/utils/bitsandbytes.py Co-authored-by: Sylvain Gugger <[email protected]> * Merge branch 'integration-8bit' of https://github.com/younesbelkada/transformers into integration-8bit * improve documentation - fix typos for installation - change title in the documentation Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Michael Benayoun <[email protected]> * TF: XLA-trainable DeBERTa v2 (#18546) * fix deberta issues * add different code paths for gpu and tpu * shorter gpu take along axis * Stable Dropout without tf cond * variable must be float * Preserve hub-related kwargs in AutoModel.from_pretrained (#18545) * Preserve hub-related kwargs in AutoModel.from_pretrained * Fix tests * Remove debug statement * TF Examples Rewrite (#18451) * Finished QA example * Dodge a merge conflict * Update text classification and LM examples * Update NER example * New Keras metrics WIP, fix NER example * Update NER example * Update MC, summarization and translation examples * Add XLA warnings when shapes are variable * Make sure batch_size is consistently scaled by num_replicas * Add PushToHubCallback to all models * Add docs links for KerasMetricCallback * Add docs links for prepare_tf_dataset and jit_compile * Correct inferred model names * Don't assume the dataset has 'lang' * Don't assume the dataset has 'lang' * Write metrics in text classification * Add 'framework' to TrainingArguments and TFTrainingArguments * Export metrics in all examples and add tests * Fix training args for Flax * Update command line args for translation test * make fixup * Fix accidentally running other tests in fp16 * Remove 
do_train/do_eval from run_clm.py * Remove do_train/do_eval from run_mlm.py * Add tensorflow tests to circleci * Fix circleci * Update examples/tensorflow/language-modeling/run_mlm.py Co-authored-by: Joao Gante <[email protected]> * Update examples/tensorflow/test_tensorflow_examples.py Co-authored-by: Joao Gante <[email protected]> * Update examples/tensorflow/translation/run_translation.py Co-authored-by: Joao Gante <[email protected]> * Update examples/tensorflow/token-classification/run_ner.py Co-authored-by: Joao Gante <[email protected]> * Fix save path for tests * Fix some model card kwargs * Explain the magical -1000 * Actually enable tests this time * Skip text classification PR until we fix shape inference * make fixup Co-authored-by: Joao Gante <[email protected]> * Use commit hash to look in cache instead of calling head (#18534) * Use commit hash to look in cache instead of calling head * Add tests * Add attr for local configs too * Stupid typos * Fix tests * Update src/transformers/utils/hub.py Co-authored-by: Julien Chaumond <[email protected]> * Address Julien's comments Co-authored-by: Julien Chaumond <[email protected]> * `pipeline` support for `device="mps"` (or any other string) (#18494) * `pipeline` support for `device="mps"` (or any other string) * Simplify `if` nesting * Update src/transformers/pipelines/base.py Co-authored-by: Sylvain Gugger <[email protected]> * Fix? 
@sgugger * passing `attr=None` is not the same as not passing `attr` 🤯 Co-authored-by: Sylvain Gugger <[email protected]> * Update philosophy to include other preprocessing classes (#18550) * 📝 update philosophy to include other preprocessing classes * 🖍 apply feedbacks * Properly move cache when it is not in default path (#18563) * Adds CLIP to models exportable with ONNX (#18515) * onnx config for clip * default opset as 14 * changes from the original repo * input values order fix * outputs fix * remove unused import * ran make fix-copies * black format * review comments: forward ref, import fix, model change revert, .to cleanup * make style * formatting fixes * revert groupvit * comment for cast to int32 * comment fix * make .T as .t() for onnx conversion * ran make fix-copies * remove unneeded comment Co-authored-by: Sylvain Gugger <[email protected]> * fix copies * remove comment Co-authored-by: Sylvain Gugger <[email protected]> * raise atol for MT5OnnxConfig (#18560) Co-authored-by: ydshieh <[email protected]> * fix string (#18568) * Segformer TF: fix output size in documentation (#18572) * Segformer TF: fix output size in doc * Segformer pytorch: fix output size in doc Co-authored-by: Maxime Gardoni <[email protected]> * Fix resizing bug in OWL-ViT (#18573) * Fixes resizing bug in OWL-ViT * Defaults to square resize if size is set to an int * Sets do_center_crop default value to False * Fix LayoutLMv3 documentation (#17932) * fix typos * fix sequence_length docs of LayoutLMv3Model * delete trailing white spaces * fix layoutlmv3 docs more * apply make fixup & quality * change to two versions of input docstring * apply make fixup & quality * Skip broken tests * Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training (#18486) * changing BartLearnedPositionalEmbedding forward signature and references to it * removing debugging dead code (thanks style checker) * blackened modeling_bart file * removing copy inconsistencies via 
make fix-copies * changing references to copied signatures in Bart variants * make fix-copies once more * using expand over repeat (thanks @michaelbenayoun) * expand instead of repeat for all model copies Co-authored-by: Daniel Jones <[email protected]> * german docs translation (#18544) * Create _config.py * Create _toctree.yml * Create index.mdx not sure about "du / ihr" oder "sie" * Create quicktour.mdx * Update _toctree.yml * Update build_documentation.yml * Update build_pr_documentation.yml * fix build * Update index.mdx * Update quicktour.mdx * Create installation.mdx * Update _toctree.yml * Deberta V2: Fix critical trace warnings to allow ONNX export (#18272) * Fix critical trace warnings to allow ONNX export * Force input to `sqrt` to be float type * Cleanup code * Remove unused import statement * Update model sew * Small refactor Co-authored-by: Michael Benayoun <[email protected]> * Use broadcasting instead of repeat * Implement suggestion Co-authored-by: Michael Benayoun <[email protected]> * Match deberta v2 changes in sew_d * Improve code quality * Update code quality * Consistency of small refactor * Match changes in sew_d Co-authored-by: Michael Benayoun <[email protected]> * [FX] _generate_dummy_input supports audio-classification models for labels (#18580) * Support audio classification architectures for labels generation, as well as provides a flag to print warnings or not * Use ENV_VARS_TRUE_VALUES * Fix docstrings with last version of hf-doc-builder styler (#18581) * Fix docstrings with last version of hf-doc-builder styler * Remove empty Parameter block * Bump nbconvert from 6.0.1 to 6.3.0 in /examples/research_projects/lxmert (#18565) Bumps [nbconvert](https://github.com/jupyter/nbconvert) from 6.0.1 to 6.3.0. - [Release notes](https://github.com/jupyter/nbconvert/releases) - [Commits](jupyter/nbconvert@6.0.1...6.3.0) --- updated-dependencies: - dependency-name: nbconvert dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] <[email protected]> Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Bump nbconvert in /examples/research_projects/visual_bert (#18566) Bumps [nbconvert](https://github.com/jupyter/nbconvert) from 6.0.1 to 6.3.0. - [Release notes](https://github.com/jupyter/nbconvert/releases) - [Commits](jupyter/nbconvert@6.0.1...6.3.0) --- updated-dependencies: - dependency-name: nbconvert dependency-type: direct:production ... Signed-off-by: dependabot[bot] <[email protected]> Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * fix owlvit tests, update docstring examples (#18586) * Return the permuted hidden states if return_dict=True (#18578) * Load sharded pt to flax (#18419) * initial commit * add small test * add cross pt tf flag to test * fix quality * style * update test with new repo * fix failing test * update * fix wrong param ordering * style * update based on review * update related to recent new caching mechanism * quality * Update based on review Co-authored-by: sgugger <[email protected]> * quality and style * Update src/transformers/modeling_flax_utils.py Co-authored-by: sgugger <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> * Add type hints for ViLT models (#18577) * Add type hints for Vilt models * Add missing return type for TokenClassification class * update doc for perf_train_cpu_many, add intel mpi introduction (#18576) * update doc for perf_train_cpu_many, add mpi introduction Signed-off-by: Wang, Yi A <[email protected]> * Update docs/source/en/perf_train_cpu_many.mdx Co-authored-by: Sylvain Gugger <[email protected]> * Update docs/source/en/perf_train_cpu_many.mdx Signed-off-by: Wang, Yi A <[email protected]> Signed-off-by: Wang, Yi A <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> * typos (#18594) * FSDP bug 
fix for `load_state_dict` (#18596) * Add `TFAutoModelForSemanticSegmentation` to the main `__init__.py` (#18600) Co-authored-by: ydshieh <[email protected]> * Generate: validate `model_kwargs` (and catch typos in generate arguments) (#18261) * validate generate model_kwargs * generate tests -- not all models have an attn mask * Supporting seq2seq models for `bitsandbytes` integration (#18579) * Supporting seq2seq models for `bitsandbytes` integration - `bitsandbytes` integration supports now seq2seq models - check if a model has tied weights as an additional check * small modification - tie the weights before looking at tied weights! * Add Donut (#18488) * First draft * Improve script * Update script * Make conversion work * Add final_layer_norm attribute to Swin's config * Add DonutProcessor * Convert more models * Improve feature extractor and convert base models * Fix bug * Improve integration tests * Improve integration tests and add model to README * Add doc test * Add feature extractor to docs * Fix integration tests * Remove register_buffer * Fix toctree and add missing attribute * Add DonutSwin * Make conversion script work * Improve conversion script * Address comment * Fix bug * Fix another bug * Remove deprecated method from docs * Make Swin and Swinv2 untouched * Fix code examples * Fix processor * Update model_type to donut-swin * Add feature extractor tests, add token2json method, improve feature extractor * Fix failing tests, remove integration test * Add do_thumbnail for consistency * Improve code examples * Add code example for document parsing * Add DonutSwin to MODEL_NAMES_MAPPING * Add model to appropriate place in toctree * Update namespace to appropriate organization Co-authored-by: Niels Rogge <[email protected]> * Fix URLs (#18604) Co-authored-by: Niels Rogge <[email protected]> * Update BLOOM parameter counts (#18531) * Update BLOOM parameter counts * Update BLOOM parameter counts * [doc] fix anchors (#18591) the manual anchors end up being 
duplicated with automatically added anchors and no longer work. * [fsmt] deal with -100 indices in decoder ids (#18592) * [fsmt] deal with -100 indices in decoder ids Fixes: #17945 decoder ids get the default index -100, which breaks the model - like t5 and many other models add a fix to replace -100 with the correct pad index. For some reason this use case hasn't been used with this model until recently - so this issue was there since the beginning it seems. Any suggestions to how to add a simple test here? or perhaps we have something similar already? user's script is quite massive. * style * small change (#18584) * Flax Remat for LongT5 (#17994) * [Flax] Add remat (gradient checkpointing) * fix variable naming in test * flip: checkpoint using a method * fix naming * fix class naming * apply PVP's suggestions from code review * add gradient_checkpointing to examples * Add gradient_checkpointing to run_mlm_flax * Add remat to longt5 * Add gradient checkpointing test longt5 * Fix args errors * Fix remaining tests * Make fixup & quality fixes * replace kwargs * remove unecessary kwargs * Make fixup changes * revert long_t5_flax changes * Remove return_dict and copy to LongT5 * Remove test_gradient_checkpointing Co-authored-by: sanchit-gandhi <[email protected]> * mac m1 `mps` integration (#18598) * mac m1 `mps` integration * Update docs/source/en/main_classes/trainer.mdx Co-authored-by: Sylvain Gugger <[email protected]> * addressing comments * Apply suggestions from code review Co-authored-by: Dan Saattrup Nielsen <[email protected]> * resolve comment Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Dan Saattrup Nielsen <[email protected]> * Change scheduled CIs to use torch 1.12.1 (#18644) Co-authored-by: ydshieh <[email protected]> * Add checks for some workflow jobs (#18583) Co-authored-by: ydshieh <[email protected]> * TF: Fix generation repetition penalty with XLA (#18648) * Update longt5.mdx (#18634) * Update run_translation_no_trainer.py 
(#18637) * Update run_translation_no_trainer.py found an error in selecting `no_decay` parameters and some small modifications when the user continues to train from a checkpoint * fixs `no_decay` and `resume_step` issue 1. change `no_decay` list 2. if use continue to train their model from provided checkpoint, the `resume_step` will not be initialized properly if `args.gradient_accumulation_steps != 1` * [bnb] Minor modifications (#18631) * bnb minor modifications - refactor documentation - add troubleshooting README - add PyPi library on DockerFile * Apply suggestions from code review Co-authored-by: Stas Bekman <[email protected]> * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * put in one block - put bash instructions in one block * update readme - refactor a bit hardware requirements * change text a bit * Apply suggestions from code review Co-authored-by: Yih-Dar <[email protected]> * apply suggestions Co-authored-by: Yih-Dar <[email protected]> * add link to paper * Apply suggestions from code review Co-authored-by: Stas Bekman <[email protected]> * Update tests/mixed_int8/README.md * Apply suggestions from code review * refactor a bit * add instructions Turing & Amperer Co-authored-by: Stas Bekman <[email protected]> * add A6000 * clarify a bit * remove small part * Update tests/mixed_int8/README.md Co-authored-by: Stas Bekman <[email protected]> Co-authored-by: Yih-Dar <[email protected]> * Examples: add Bloom support for token classification (#18632) * examples: add Bloom support for token classification (FLAX, PyTorch and TensorFlow) * examples: remove support for Bloom in token classication (FLAX and TensorFlow currently have no support for it) * Fix Yolos ONNX export test (#18606) Co-authored-by: lewtun <[email protected]> Co-authored-by: ydshieh <[email protected]> * Fixup * Fix up * Move PIL default arguments inside function for safe imports * Add image utils to toctree * Update `rescale` 
method to reflect changes in #18677 * Update docs/source/en/internal/image_processing_utils.mdx Co-authored-by: NielsRogge <[email protected]> * Address Niels PR comments * Add normalize method to transforms library * Apply suggestions from code review - remove defaults to None Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> * Fix docstrings and revert to PIL.Image.XXX resampling Use PIL.Image.XXX resampling values instead of PIL.Image.Resampling.XXX enum as it's only in the recent version >= 9.10 and version is not yet pinned and older version support deprecated * Some more docstrings and PIL.Image tidy up * Reorganise arguments so flags by modifiers * Few last docstring fixes * Add normalize to docs * Accept PIL.Image inputs with deprecation warning * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <[email protected]> * Update warning to include version * Trigger CI - hash clash on doc build Signed-off-by: Seunghwan Hong <[email protected]> Signed-off-by: dependabot[bot] <[email protected]> Signed-off-by: Wang, Yi A <[email protected]> Co-authored-by: Amy Roberts <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Alara Dirik <[email protected]> Co-authored-by: NielsRogge <[email protected]> Co-authored-by: Seunghwan Hong <[email protected]> Co-authored-by: Yih-Dar <[email protected]> Co-authored-by: ydshieh <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: regisss <[email protected]> Co-authored-by: Nicolas Patry <[email protected]> Co-authored-by: Sourab Mangrulkar <[email protected]> Co-authored-by: Ankur Goyal <[email protected]> Co-authored-by: Ankur Goyal <[email protected]> Co-authored-by: Steven Liu <[email protected]> Co-authored-by: Niels Rogge <[email protected]> Co-authored-by: Mishig Davaadorj <[email protected]> Co-authored-by: Rasmus Arpe Fogh 
Jensen <[email protected]> Co-authored-by: Ian Castillo <[email protected]> Co-authored-by: AguilaCudicio <[email protected]> Co-authored-by: Omar U. Espejel <[email protected]> Co-authored-by: Younes Belkada <[email protected]> Co-authored-by: Thomas Wang <[email protected]> Co-authored-by: Niklas Hansson <[email protected]> Co-authored-by: Thomas Chaigneau <[email protected]> Co-authored-by: YouJiacheng <[email protected]> Co-authored-by: Michael Benayoun <[email protected]> Co-authored-by: Joao Gante <[email protected]> Co-authored-by: Matt <[email protected]> Co-authored-by: Dhruv Karan <[email protected]> Co-authored-by: Michael Wyatt <[email protected]> Co-authored-by: Maxime G <[email protected]> Co-authored-by: Maxime Gardoni <[email protected]> Co-authored-by: Wonseok Lee (Jack) <[email protected]> Co-authored-by: Dan Jones <[email protected]> Co-authored-by: Daniel Jones <[email protected]> Co-authored-by: flozi00 <[email protected]> Co-authored-by: iiLaurens <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Arthur <[email protected]> Co-authored-by: Wang, Yi <[email protected]> Co-authored-by: Stas Bekman <[email protected]> Co-authored-by: Niklas Muennighoff <[email protected]> Co-authored-by: Karim Foda <[email protected]> Co-authored-by: sanchit-gandhi <[email protected]> Co-authored-by: Dan Saattrup Nielsen <[email protected]> Co-authored-by: zhoutang776 <[email protected]> Co-authored-by: Stefan Schweter <[email protected]> Co-authored-by: lewtun <[email protected]>
1 parent 82e360b commit 4181320

5 files changed: +128 -2 lines changed

docs/source/en/internal/image_processing_utils.mdx

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ Most of those are only useful if you are studying the code of the image processo
 
 ## Image Transformations
 
+[[autodoc]] image_transforms.normalize
+
 [[autodoc]] image_transforms.rescale
 
 [[autodoc]] image_transforms.resize

src/transformers/image_transforms.py

Lines changed: 60 additions & 1 deletion
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+import warnings
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -25,11 +26,13 @@
 
 from .image_utils import (
     ChannelDimension,
+    get_channel_dimension_axis,
     get_image_size,
     infer_channel_dimension_format,
     is_jax_tensor,
     is_tf_tensor,
     is_torch_tensor,
+    to_numpy_array,
 )
 
 
@@ -257,3 +260,59 @@ def resize(
         resized_image = np.array(resized_image)
         resized_image = to_channel_dimension_format(resized_image, data_format)
     return resized_image
+
+
+def normalize(
+    image: np.ndarray,
+    mean: Union[float, Iterable[float]],
+    std: Union[float, Iterable[float]],
+    data_format: Optional[ChannelDimension] = None,
+) -> np.ndarray:
+    """
+    Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
+
+    image = (image - mean) / std
+
+    Args:
+        image (`np.ndarray`):
+            The image to normalize.
+        mean (`float` or `Iterable[float]`):
+            The mean to use for normalization.
+        std (`float` or `Iterable[float]`):
+            The standard deviation to use for normalization.
+        data_format (`ChannelDimension`, *optional*):
+            The channel dimension format of the output image. If `None`, will use the inferred format from the input.
+    """
+    if isinstance(image, PIL.Image.Image):
+        warnings.warn(
+            "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
+            FutureWarning,
+        )
+        # Convert PIL image to numpy array with the same logic as in the previous feature extractor normalize -
+        # casting to numpy array and dividing by 255.
+        image = to_numpy_array(image)
+        image = rescale(image, scale=1 / 255)
+
+    input_data_format = infer_channel_dimension_format(image)
+    channel_axis = get_channel_dimension_axis(image)
+    num_channels = image.shape[channel_axis]
+
+    if isinstance(mean, Iterable):
+        if len(mean) != num_channels:
+            raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
+    else:
+        mean = [mean] * num_channels
+
+    if isinstance(std, Iterable):
+        if len(std) != num_channels:
+            raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
+    else:
+        std = [std] * num_channels
+
+    if input_data_format == ChannelDimension.LAST:
+        image = (image - mean) / std
+    else:
+        image = ((image.T - mean) / std).T
+
+    image = to_channel_dimension_format(image, data_format) if data_format is not None else image
+    return image
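
The two arithmetic branches at the end of `normalize` rely on NumPy broadcasting against the trailing axis. The following is a minimal sketch of that idea only, not the library code: `normalize_sketch` and its `channels_last` flag are hypothetical names introduced for illustration, and the sketch assumes a single unbatched 3-D image.

```python
import numpy as np


def normalize_sketch(image, mean, std, channels_last=True):
    """Hypothetical stand-in for the per-channel arithmetic in `normalize`."""
    mean = np.asarray(mean, dtype=np.float64)
    std = np.asarray(std, dtype=np.float64)
    if channels_last:
        # (H, W, C) against (C,): broadcasting aligns on the trailing channel axis.
        return (image - mean) / std
    # (C, H, W): transpose to (W, H, C) so channels trail, normalize, transpose back.
    return ((image.T - mean) / std).T


rng = np.random.default_rng(0)
hwc = rng.random((4, 5, 3))      # channels-last image
chw = hwc.transpose(2, 0, 1)     # the same image, channels-first
mean, std = (0.5, 0.6, 0.7), (0.1, 0.2, 0.3)

out_last = normalize_sketch(hwc, mean, std, channels_last=True)
out_first = normalize_sketch(chw, mean, std, channels_last=False)

# Both layouts should give the same values, up to the axis order.
print(np.allclose(out_first, out_last.transpose(2, 0, 1)))
```

Under these assumptions the transpose trick works because reversing the axes of a `(C, H, W)` array puts the channel axis last. For a batched 4-D input the reversed shape would end in the batch axis instead, so the sketch should not be read as covering that case.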

src/transformers/image_utils.py

Lines changed: 19 additions & 0 deletions
@@ -112,6 +112,25 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
     raise ValueError("Unable to infer channel dimension format")
 
 
+def get_channel_dimension_axis(image: np.ndarray) -> int:
+    """
+    Returns the channel dimension axis of the image.
+
+    Args:
+        image (`np.ndarray`):
+            The image to get the channel dimension axis of.
+
+    Returns:
+        The channel dimension axis of the image.
+    """
+    channel_dim = infer_channel_dimension_format(image)
+    if channel_dim == ChannelDimension.FIRST:
+        return image.ndim - 3
+    elif channel_dim == ChannelDimension.LAST:
+        return image.ndim - 1
+    raise ValueError(f"Unsupported data format: {channel_dim}")
+
+
 def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
     """
     Returns the (height, width) dimensions of the image.
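
The `ndim - 3` / `ndim - 1` arithmetic in `get_channel_dimension_axis` can be checked in isolation. The sketch below is a shape-only illustration: `channel_axis_sketch` is a hypothetical helper, and it assumes the caller already knows the channel format, whereas the real function infers it with `infer_channel_dimension_format`.

```python
def channel_axis_sketch(shape, channels_first):
    """Hypothetical helper: index of the channel axis for a given array shape."""
    ndim = len(shape)
    # Channels-first layouts end in (C, H, W); channels-last layouts end in (H, W, C).
    return ndim - 3 if channels_first else ndim - 1


print(channel_axis_sketch((3, 4, 5), channels_first=True))      # 0
print(channel_axis_sketch((4, 5, 3), channels_first=False))     # 2
print(channel_axis_sketch((1, 3, 4, 5), channels_first=True))   # 1 (leading batch axis)
```

Counting from the end of the shape is what lets the same rule cover both unbatched 3-D arrays and arrays with a leading batch axis.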

tests/test_image_transforms.py

Lines changed: 23 additions & 0 deletions
@@ -36,6 +36,7 @@
 
     from transformers.image_transforms import (
         get_resize_output_image_size,
+        normalize,
         resize,
         to_channel_dimension_format,
         to_pil_image,
@@ -172,3 +173,25 @@ def test_resize(self):
         self.assertIsInstance(resized_image, PIL.Image.Image)
         # PIL size is in (width, height) order
         self.assertEqual(resized_image.size, (40, 30))
+
+    def test_normalize(self):
+        image = np.random.randint(0, 256, (224, 224, 3)) / 255
+
+        # Number of mean values != number of channels
+        with self.assertRaises(ValueError):
+            normalize(image, mean=(0.5, 0.6), std=1)
+
+        # Number of std values != number of channels
+        with self.assertRaises(ValueError):
+            normalize(image, mean=1, std=(0.5, 0.6))
+
+        # Test result is correct - output data format is channels_first and normalization
+        # correctly computed
+        mean = (0.5, 0.6, 0.7)
+        std = (0.1, 0.2, 0.3)
+        expected_image = ((image - mean) / std).transpose((2, 0, 1))
+
+        normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first")
+        self.assertIsInstance(normalized_image, np.ndarray)
+        self.assertEqual(normalized_image.shape, (3, 224, 224))
+        self.assertTrue(np.allclose(normalized_image, expected_image))
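
The two `ValueError` cases exercised by `test_normalize` come from the `mean`/`std` handling in `normalize`: a scalar is repeated once per channel, and an iterable must have exactly one entry per channel. A rough stand-alone sketch of that rule follows. `expand_stat` and `raises_value_error` are hypothetical names, and unlike the code in the diff this sketch converts the iterable to a list before calling `len`.

```python
from collections.abc import Iterable


def expand_stat(value, num_channels, name):
    """Hypothetical helper mirroring the mean/std checks in `normalize`."""
    if isinstance(value, Iterable):
        # Assumption: materialize first so that len() is always available.
        value = list(value)
        if len(value) != num_channels:
            raise ValueError(
                f"{name} must have {num_channels} elements if it is an iterable, got {len(value)}"
            )
        return value
    # A scalar applies to every channel.
    return [value] * num_channels


def raises_value_error(value, num_channels=3):
    """Return True if expanding `value` is rejected for `num_channels` channels."""
    try:
        expand_stat(value, num_channels, "mean")
    except ValueError:
        return True
    return False


print(expand_stat(1, 3, "std"))              # [1, 1, 1]
print(raises_value_error((0.5, 0.6)))        # True: two values for three channels
print(raises_value_error((0.5, 0.6, 0.7)))   # False
```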

tests/utils/test_image_utils.py

Lines changed: 24 additions & 1 deletion
@@ -20,7 +20,7 @@
 import pytest
 
 from transformers import is_torch_available, is_vision_available
-from transformers.image_utils import ChannelDimension
+from transformers.image_utils import ChannelDimension, get_channel_dimension_axis
 from transformers.testing_utils import require_torch, require_vision
 
 
@@ -535,3 +535,26 @@ def test_infer_channel_dimension(self):
         image = np.random.randint(0, 256, (1, 3, 4, 5))
         inferred_dim = infer_channel_dimension_format(image)
         self.assertEqual(inferred_dim, ChannelDimension.FIRST)
+
+    def test_get_channel_dimension_axis(self):
+        # Test we correctly identify the channel dimension
+        image = np.random.randint(0, 256, (3, 4, 5))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 0)
+
+        image = np.random.randint(0, 256, (1, 4, 5))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 0)
+
+        image = np.random.randint(0, 256, (4, 5, 3))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 2)
+
+        image = np.random.randint(0, 256, (4, 5, 1))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 2)
+
+        # We can take a batched array of images and find the dimension
+        image = np.random.randint(0, 256, (1, 3, 4, 5))
+        inferred_axis = get_channel_dimension_axis(image)
+        self.assertEqual(inferred_axis, 1)