Skip to content

Commit 38edcc1

Browse files
hlkysayakpaulDN6
authored andcommitted
Check correct model type is passed to from_pretrained (huggingface#10189)
* Check correct model type is passed to `from_pretrained` * Flax, skip scheduler * test_wrong_model * Fix for scheduler * Update tests/pipelines/test_pipelines.py Co-authored-by: Sayak Paul <[email protected]> * EnumMeta * Flax * scheduler in expected types * make * type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name' * support union * fix typing in kandinsky * make * add LCMScheduler * 'LCMScheduler' object has no attribute 'sigmas' * tests for wrong scheduler * make * update * warning * tests * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Dhruv Nair <[email protected]> * import FlaxSchedulerMixin * skip scheduler --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Dhruv Nair <[email protected]>
1 parent 380398f commit 38edcc1

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import enum
1617
import fnmatch
1718
import importlib
1819
import inspect
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
811812
# in this case they are already instantiated in `kwargs`
812813
# extract them here
813814
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
815+
expected_types = pipeline_class._get_signature_types()
814816
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
815817
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
816818
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -833,6 +835,26 @@ def load_module(name, value):
833835

834836
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
835837

838+
for key in init_dict.keys():
839+
if key not in passed_class_obj:
840+
continue
841+
if "scheduler" in key:
842+
continue
843+
844+
class_obj = passed_class_obj[key]
845+
_expected_class_types = []
846+
for expected_type in expected_types[key]:
847+
if isinstance(expected_type, enum.EnumMeta):
848+
_expected_class_types.extend(expected_type.__members__.keys())
849+
else:
850+
_expected_class_types.append(expected_type.__name__)
851+
852+
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
853+
if not _is_valid_type:
854+
logger.warning(
855+
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
856+
)
857+
836858
# Special case: safety_checker must be loaded separately when using `from_flax`
837859
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
838860
raise NotImplementedError(

tests/pipelines/test_pipelines.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,16 @@ def test_pipe_same_device_id_offload(self):
18021802
sd.maybe_free_model_hooks()
18031803
assert sd._offload_gpu_id == 5
18041804

1805+
def test_wrong_model(self):
1806+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
1807+
with self.assertRaises(ValueError) as error_context:
1808+
_ = StableDiffusionPipeline.from_pretrained(
1809+
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
1810+
)
1811+
1812+
assert "is of type" in str(error_context.exception)
1813+
assert "but should be" in str(error_context.exception)
1814+
18051815

18061816
@slow
18071817
@require_torch_gpu

0 commit comments

Comments
 (0)