The Wayback Machine - https://web.archive.org/web/20220618205250/https://github.com/google/jax/commit/5782210174e3933f6c6c7c61591f7923efd029b9
Skip to content
Permalink
Browse files
CI: fix flake8 ignore declarations
  • Loading branch information
jakevdp committed Apr 21, 2022
1 parent f1104cf commit 5782210174e3933f6c6c7c61591f7923efd029b9
Showing 67 changed files with 30 additions and 159 deletions.
@@ -29,8 +29,6 @@
del _warn
del _cloud_tpu_init

# flake8: noqa: F401

# Confusingly there are two things named "config": the module and the class.
# We want the exported object to be the class, so we first import the module
# to make sure a later import doesn't overwrite the class.
@@ -21,7 +21,6 @@
arrays.
"""

# flake8: noqa: F401
import collections
import functools
from functools import partial
@@ -1540,7 +1539,7 @@ def pmap(
in_axes=0,
out_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices: Optional[Sequence[xc.Device]] = None,
devices: Optional[Sequence[xc.Device]] = None, # noqa: F811
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
@@ -1876,7 +1875,7 @@ def _get_f_mapped(
in_axes=0,
out_axes=0,
static_broadcasted_tuple: Tuple[int],
devices: Optional[Sequence[xc.Device]],
devices: Optional[Sequence[xc.Device]], # noqa: F811
backend: Optional[str],
axis_size: Optional[int],
donate_tuple: Tuple[int],
@@ -1926,7 +1925,7 @@ def _python_pmap(
in_axes=0,
out_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices: Optional[Sequence[xc.Device]] = None,
devices: Optional[Sequence[xc.Device]] = None, # noqa: F811
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
@@ -1985,7 +1984,7 @@ def _cpp_pmap(
in_axes=0,
out_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices: Optional[Sequence[xc.Device]] = None,
devices: Optional[Sequence[xc.Device]] = None, # noqa: F811
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
@@ -2061,7 +2060,7 @@ def cache_miss(*args, **kwargs):


def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
devices, backend, axis_size, global_arg_shapes, donate_tuple):
devices, backend, axis_size, global_arg_shapes, donate_tuple): # noqa: F811
"""Make a ``lower`` method for pmapped functions."""
# If the function we returned from ``pmap`` were a class instance,
# this might naturally be a method, with ``fun`` as a ``self`` and
@@ -2667,7 +2666,7 @@ def device_put(x, device: Optional[xc.Device] = None):
return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)


def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):
def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # noqa: F811
"""Transfer array shards to specified devices and form ShardedDeviceArray(s).
Args:
@@ -2739,7 +2738,7 @@ def _device_put_sharded(*xs):
return tree_map(_device_put_sharded, *shards)


def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
"""Transfer array(s) to each specified device and form ShardedDeviceArray(s).
Args:
@@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401

from contextlib import contextmanager
import inspect
import functools
from functools import partial
import re
import os
import textwrap
from typing import Dict, List, Generator, Sequence, Tuple, Union, NamedTuple
from typing import Dict, List, Generator, Sequence, Tuple, Union
import unittest
import warnings
import zlib
@@ -32,7 +30,6 @@
import numpy as np
import numpy.random as npr

from jax import stages
from jax._src import api
from jax import core
from jax._src import dtypes as _dtypes
@@ -42,14 +39,11 @@
from jax.tree_util import tree_map, tree_all
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax._src.public_test_util import (
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.maps import Mesh
from jax.interpreters.pxla import PartitionSpec
from jax.experimental import pjit

# This submodule includes private test utilities that are not exported to
# jax.test_util. Functionality appearing here is for internal use only, and
@@ -14,7 +14,6 @@

# TODO(phawkins): fix users of these aliases and delete this file.

# flake8: noqa: F401
from jax._src.abstract_arrays import array_types
from jax.core import (
ShapedArray,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.ad_checkpoint import (
checkpoint,
checkpoint_policies,
@@ -13,7 +13,6 @@
# limitations under the License.


# flake8: noqa: F401
from jax._src.api_util import (
argnums_partial,
donation_vector,
@@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.cloud_tpu_init import cloud_tpu_init
@@ -14,5 +14,4 @@

# TODO(phawkins): fix users of this alias and delete this file.

# flake8: noqa: F401
from jax._src.config import config
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.custom_batching import (
custom_vmap,
sequential_vmap,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.custom_derivatives import (
_initial_style_jaxpr,
_sum_tangents,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.custom_transpose import (
custom_transpose,
)
@@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.distributed import initialize
@@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.dlpack import (to_dlpack, from_dlpack, SUPPORTED_DTYPES)
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.dtypes import (
_jax_types, # TODO(phawkins): fix users and remove?
bfloat16 as bfloat16,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.errors import (
JAXTypeError as JAXTypeError,
JAXIndexError as JAXIndexError,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax.interpreters.sharded_jit import (
sharded_jit as sharded_jit,
with_sharding_constraint as with_sharding_constraint,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.checkify import (
Error as Error,
ErrorCategory as ErrorCategory,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val,
split_to_logical_devices, PolyShape)
from jax.experimental.jax2tf.call_tf import call_tf
@@ -182,7 +182,6 @@
-0.670236 0.03132951 -0.05356663]
"""

# flake8: noqa: F401
from jax.experimental.sparse.ad import (
grad as grad,
value_and_grad as value_and_grad,
@@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.flatten_util import ravel_pytree
@@ -20,7 +20,6 @@
.. _PIX: https://github.com/deepmind/dm_pix
"""

# flake8: noqa: F401
from jax._src.image.scale import (
resize as resize,
ResizeMethod as ResizeMethod,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.lax.lax import (
DotDimensionNumbers as DotDimensionNumbers,
Precision as Precision,
@@ -24,11 +23,6 @@
acos_p as acos_p,
acosh as acosh,
acosh_p as acosh_p,
abs as abs,
abs_p as abs_p,
acos as acos,
acosh as acosh,
acosh_p as acosh_p,
add as add,
add_p as add_p,
after_all as after_all,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from jax._src.lax.linalg import (
cholesky,
cholesky_p,
@@ -14,8 +14,6 @@

"""Common functions for neural network libraries."""

# flake8: noqa: F401

from jax.numpy import tanh as tanh
from jax.nn import initializers as initializers
from jax._src.nn.functions import (
@@ -17,7 +17,6 @@
used in Keras and Sonnet.
"""

# flake8: noqa: F401
from jax._src.nn.initializers import (
constant as constant,
delta_orthogonal as delta_orthogonal,

0 comments on commit 5782210

Please sign in to comment.