13
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
+ import enum
16
17
import fnmatch
17
18
import importlib
18
19
import inspect
@@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
811
812
# in this case they are already instantiated in `kwargs`
812
813
# extract them here
813
814
expected_modules , optional_kwargs = cls ._get_signature_keys (pipeline_class )
815
+ expected_types = pipeline_class ._get_signature_types ()
814
816
passed_class_obj = {k : kwargs .pop (k ) for k in expected_modules if k in kwargs }
815
817
passed_pipe_kwargs = {k : kwargs .pop (k ) for k in optional_kwargs if k in kwargs }
816
818
init_dict , unused_kwargs , _ = pipeline_class .extract_init_dict (config_dict , ** kwargs )
@@ -833,6 +835,26 @@ def load_module(name, value):
833
835
834
836
init_dict = {k : v for k , v in init_dict .items () if load_module (k , v )}
835
837
838
+ for key in init_dict .keys ():
839
+ if key not in passed_class_obj :
840
+ continue
841
+ if "scheduler" in key :
842
+ continue
843
+
844
+ class_obj = passed_class_obj [key ]
845
+ _expected_class_types = []
846
+ for expected_type in expected_types [key ]:
847
+ if isinstance (expected_type , enum .EnumMeta ):
848
+ _expected_class_types .extend (expected_type .__members__ .keys ())
849
+ else :
850
+ _expected_class_types .append (expected_type .__name__ )
851
+
852
+ _is_valid_type = class_obj .__class__ .__name__ in _expected_class_types
853
+ if not _is_valid_type :
854
+ logger .warning (
855
+ f"Expected types for { key } : { _expected_class_types } , got { class_obj .__class__ .__name__ } ."
856
+ )
857
+
836
858
# Special case: safety_checker must be loaded separately when using `from_flax`
837
859
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj :
838
860
raise NotImplementedError (
0 commit comments