Skip to content

Include mypy in run_all_tests #2670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ jobs:
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
test-type: [doctest, pytest, pytype]
test-type: [doctest, pytest, pytype, mypy]
exclude:
- test-type: pytype
python-version: 3.8
- test-type: pytype
python-version: 3.10
- test-type: mypy
python-version: 3.8
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -119,11 +121,13 @@ jobs:
- name: Test with ${{ matrix.test-type }}
run: |
if [[ "${{ matrix.test-type }}" == "doctest" ]]; then
tests/run_all_tests.sh --no-pytest --no-pytype --use-venv
tests/run_all_tests.sh --no-pytest --no-pytype --no-mypy --use-venv
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
tests/run_all_tests.sh --no-doctest --no-pytype --with-cov --use-venv
tests/run_all_tests.sh --no-doctest --no-pytype --no-mypy --with-cov --use-venv
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
tests/run_all_tests.sh --no-doctest --no-pytest --use-venv
tests/run_all_tests.sh --no-doctest --no-pytest --no-mypy --use-venv
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
tests/run_all_tests.sh --no-doctest --no-pytest --no-pytype --use-venv
else
echo "Unknown test type: ${{ matrix.test-type }}"
exit 1
Expand Down
2 changes: 1 addition & 1 deletion flax/core/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __hash__(self):

def copy(self, add_or_replace: Mapping[K, V]) -> 'FrozenDict[K, V]':
"""Create a new FrozenDict with additional or replaced entries."""
return type(self)({**self, **unfreeze(add_or_replace)})
return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type]

def keys(self):
return FrozenKeysView(self)
Expand Down
6 changes: 3 additions & 3 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax.experimental import pjit


TAxisMetadata = TypeVar('TAxisMetadata', bound='AxisMetadata')
TAxisMetadata = Any # TypeVar('TAxisMetadata', bound='AxisMetadata')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy doesn't support bound type-vars?! is this intentional Any cast?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy has had lots of trouble dealing with type-vars in general. I've tried to preserve their use, and just define them as Any. That way, we can reintroduce stricter checking later where possible.



class AxisMetadata(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -252,8 +252,8 @@ def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
axis_name = self._get_partition_name(params)
names = list(self.names)
while len(names) < index:
names.append(None)
names.insert(index, axis_name)
names.append(None) # type: ignore
names.insert(index, axis_name) # type: ignore
return self.replace(names=tuple(names))

def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata:
Expand Down
4 changes: 2 additions & 2 deletions flax/core/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def init_fn(shape, dtype=jnp.float32):
indices[attn_dim] = i // attn_size
i = i % attn_size

key = lax.dynamic_update_slice(cache_entry.key, key, indices)
value = lax.dynamic_update_slice(cache_entry.value, value, indices)
key = lax.dynamic_update_slice(cache_entry.key, key, indices) # type: ignore
value = lax.dynamic_update_slice(cache_entry.value, value, indices) # type: ignore
one = jnp.array(1, jnp.uint32)
cache_entry = cache_entry.replace(i=cache_entry.i + one,
key=key,
Expand Down
2 changes: 1 addition & 1 deletion flax/core/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,4 @@ def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=default_
Embedding dataclass with lookup and attend methods.
"""
table = scope.param('table', init_fn, (num_embeddings, features))
return Embedding(table)
return Embedding(table) # type: ignore
8 changes: 4 additions & 4 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey:
m.update(x.encode('utf-8'))
d = m.digest()
hash_int = int.from_bytes(d[:4], byteorder='big')
rng = random.fold_in(rng, jnp.uint32(hash_int))
rng = random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore
elif isinstance(x, int):
rng = random.fold_in(rng, x)
else:
Expand Down Expand Up @@ -148,7 +148,7 @@ def _fold_in_static(rng: PRNGKey,
raise ValueError(f'Expected int or string, got: {x}')
d = m.digest()
hash_int = int.from_bytes(d[:4], byteorder='big')
return random.fold_in(rng, jnp.uint32(hash_int))
return random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore


def is_filter_empty(filter_like: Filter) -> bool:
Expand Down Expand Up @@ -570,10 +570,10 @@ def push(self,
rngs = {key: LazyRng.create(rng, name) for key, rng in self.rngs.items()}
rng_key = (child_rng_token, name)
if rng_key in self.rng_counters:
rng_counters = self.rng_counters.get(rng_key)
rng_counters = self.rng_counters.get(rng_key) # type: ignore
else:
rng_counters = {key: 0 for key in rngs}
self.rng_counters[rng_key] = rng_counters
self.rng_counters[rng_key] = rng_counters # type: ignore
scope = Scope({},
name=name,
rngs=rngs,
Expand Down
4 changes: 2 additions & 2 deletions flax/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class BackendMode(Enum):
gfile = None

if importlib.util.find_spec('tensorflow'):
from tensorflow.io import gfile # pytype: disable=import-error
from tensorflow.io import gfile # type: ignore
io_mode = BackendMode.TF
else:
logging.warning("Tensorflow library not found, tensorflow.io.gfile "
Expand Down Expand Up @@ -176,4 +176,4 @@ def rmtree(path):
elif io_mode == BackendMode.TF:
return gfile.rmtree(path)
else:
raise ValueError("Unknown IO Backend Mode.")
raise ValueError("Unknown IO Backend Mode.")
4 changes: 2 additions & 2 deletions flax/linen/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def dot_product_attention_weights(query: Array,
if broadcast_dropout:
# dropout is broadcast across the batch + head dimensions
dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore
else:
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore
multiplier = (keep.astype(dtype) /
jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
Expand Down
4 changes: 2 additions & 2 deletions flax/linen/dotgetter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def is_leaf(x):
# We subclass dict so that freeze, unfreeze work transparently:
# i.e freeze(DotGetter(d)) == freeze(d)
# unfreeze(DotGetter(d)) == unfreeze(d)
class DotGetter(MutableMapping, dict): # pytype: disable=mro-error
class DotGetter(MutableMapping, dict): # type: ignore[misc] # pytype: disable=mro-error
"""Dot-notation helper for interactive access of variable trees."""
__slots__ = ('_data',)

Expand Down Expand Up @@ -94,7 +94,7 @@ def copy(self, **kwargs):
tree_util.register_pytree_node(
DotGetter,
lambda x: ((x._data,), ()), # pylint: disable=protected-access
lambda _, data: data[0])
lambda _, data: data[0]) # type: ignore

# Note: restores as raw dict, intentionally.
serialization.register_serialization_state(
Expand Down
9 changes: 5 additions & 4 deletions flax/linen/experimental/layers_with_named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@


# Type annotations
IntSequence = Any
Array = jnp.ndarray
Axes = Union[int, Iterable[int]]
DType = jnp.dtype
Axes = Union[int, IntSequence]
DType = Any
PRNGKey = jnp.ndarray
Shape = Iterable[int]
Shape = Any
Activation = Callable[..., Array]
# Parameter initializers.
Initializer = Callable[[PRNGKey, Shape, DType], Array]
Expand Down Expand Up @@ -62,7 +63,7 @@ class Dense(nn.Module):
"""
features: int
use_bias: bool = True
dtype: Any = jnp.float32
dtype: DType = jnp.float32
param_dtype: DType = jnp.float32
precision: PrecisionLike = None
kernel_init: Callable[[PRNGKey, Shape, DType], Array] = default_kernel_init
Expand Down
2 changes: 1 addition & 1 deletion flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class _Conv(Module):
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

@property
def shared_weights(self) -> bool:
def shared_weights(self) -> bool: # type: ignore
"""Defines whether weights are shared or not between different pixels.

Returns:
Expand Down
8 changes: 4 additions & 4 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def __getattr__(self, name: str) -> Any:
def __dir__(self) -> List[str]:
"""Call setup() before listing attributes."""
self._try_setup()
return object.__dir__(self) # pytype: disable=attribute-error
return object.__dir__(self) # type: ignore

def __post_init__(self) -> None:
# DO NOT REMOVE - Marker for internal logging.
Expand Down Expand Up @@ -974,7 +974,7 @@ def run_setup_only(x):

def _name_taken(self,
name: str,
module: 'Module' = None,
module: Optional['Module'] = None,
reuse_scopes: bool = False) -> bool:
if name in _all_names_on_object(self):
val = getattr(self, name, None)
Expand Down Expand Up @@ -1398,7 +1398,7 @@ def variables(self) -> VariableDict:
raise ValueError("Can't access variables on unbound modules")
return self.scope.variables()

def get_variable(self, col: str, name: str, default: T = None) -> T:
def get_variable(self, col: str, name: str, default: Optional[T] = None) -> T:
"""Retrieves the value of a Variable.

Args:
Expand Down Expand Up @@ -1559,7 +1559,7 @@ def loss(params, perturbations, inputs, targets):
# [-1.456924 -0.44332537 0.02422847]]

"""
value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we messing up internal type annotations in flax module methods to require this opt-out?

return value

def tabulate(
Expand Down
10 changes: 5 additions & 5 deletions flax/linen/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _logical_to_mesh_axes(
def logical_to_mesh_axes(
array_dim_names: Optional[Sequence[Optional[str]]],
rules: Optional[LogicalRules] = None,
) -> pjit.PartitionSpec:
) -> Optional[pjit.PartitionSpec]:
"""Compute layout for an array.

The rules are in order of precedence, and consist of pairs:
Expand Down Expand Up @@ -211,7 +211,7 @@ class RulesFallback(enum.Enum):
NO_CONSTRAINT = 'no_constraint'


def _with_sharding_constraint(x: Array, axis_resources: pjit.PartitionSpec):
def _with_sharding_constraint(x: Array, axis_resources: Optional[pjit.PartitionSpec]):
"""Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
if jax.devices()[0].platform == 'cpu' or not _global_mesh_defined():
return x
Expand Down Expand Up @@ -340,7 +340,7 @@ def param_with_axes(
pjit.PartitionSpec(*axes))
# record logical axis constraint for global axis metadata
module.sow(
'params_axes', f'{name}_axes', AxisMetadata(axes),
'params_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore
reduce_fn=_param_with_axes_sow_reduce_fn)
return module_param

Expand Down Expand Up @@ -459,7 +459,7 @@ def variable_with_axes(
if axes is not None:
# record logical axis constraint for global axis metadata
module.sow(
f'{collection}_axes', f'{name}_axes', AxisMetadata(axes),
f'{collection}_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore
reduce_fn=_param_with_axes_sow_reduce_fn)
return module_var

Expand Down Expand Up @@ -625,7 +625,7 @@ def vmap_with_axes(target: flax.linen.transforms.Target,
out_axes=0,
axis_size: Optional[int] = None,
axis_name: Optional[str] = None,
partitioning_axis_names: Mapping[str, str] = {},
partitioning_axis_names: Mapping[Any, str] = {},
spmd_axis_name: Optional[str] = None,
methods=None) -> flax.linen.transforms.Target:
"""Wrapped version of nn.vmap that handles logical axis metadata."""
Expand Down
2 changes: 2 additions & 0 deletions flax/linen/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def _concat_dense(inputs: Array,
inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
y = jnp.dot(inputs, kernel)
if use_bias:
# This assert is here since mypy can't infer that bias cannot be None
assert bias is not None
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

# Split the result back into individual (i, f, g, o) outputs.
Expand Down
2 changes: 1 addition & 1 deletion flax/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np


_STATE_DICT_REGISTRY = {}
_STATE_DICT_REGISTRY: Dict[Any, Any] = {}


class _ErrorContext(threading.local):
Expand Down
4 changes: 2 additions & 2 deletions flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def create(cls, kernel):
if '_flax_dataclass' in clz.__dict__:
return clz

data_clz = dataclasses.dataclass(frozen=True)(clz)
data_clz = dataclasses.dataclass(frozen=True)(clz) # type: ignore
meta_fields = []
data_fields = []
for field_info in dataclasses.fields(data_clz):
Expand Down Expand Up @@ -165,7 +165,7 @@ def from_state_dict(x, state):
# add a _flax_dataclass flag to distinguish from regular dataclasses
data_clz._flax_dataclass = True # type: ignore[attr-defined]

return data_clz
return data_clz # type: ignore


TNode = TypeVar('TNode', bound='PyTreeNode')
Expand Down
Loading