The Wayback Machine - https://web.archive.org/web/20220618205254/https://github.com/google/jax/commit/17de89b16ac5ee05aee03115d858e67489eab973
Skip to content
Permalink
Browse files
feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
  • Loading branch information
JeppeKlitgaard committed May 17, 2022
1 parent ea868b4 commit 17de89b16ac5ee05aee03115d858e67489eab973
Showing 112 changed files with 417 additions and 422 deletions.
@@ -101,7 +101,7 @@ def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
times = []
for params in params_list:
f = prepare(**params)
subname = name + "".join("_%s=%s" % (n, _param_str(p))
subname = name + "".join(f"_{n}={_param_str(p)}"
for n, p in params.items())
times.append(benchmark(f, name=subname,
target_total_secs=target_total_secs))
@@ -126,7 +126,7 @@ def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],

if FLAGS.export_dir:
filename = _export_results(data_header, data, FLAGS.export_dir, name)
print("Wrote %s results to %s" % (name, filename))
print(f"Wrote {name} results to {filename}")
print()


@@ -135,7 +135,7 @@ def download_and_verify_bazel():

if not os.access(package.file, os.X_OK):
uri = (package.base_uri or BAZEL_BASE_URI) + package.file
sys.stdout.write("Downloading bazel from: {}\n".format(uri))
sys.stdout.write(f"Downloading bazel from: {uri}\n")

def progress(block_count, block_size, total_size):
if total_size <= 0:
@@ -291,7 +291,7 @@ def _parse_string_as_bool(s):
elif lower == "false":
return False
else:
raise ValueError("Expected either 'true' or 'false'; got {}".format(s))
raise ValueError(f"Expected either 'true' or 'false'; got {s}")


def add_boolean_argument(parser, name, default=False, help_str=None):
@@ -438,42 +438,42 @@ def main():
print(f"Bazel version: {bazel_version}")

python_bin_path = get_python_bin_path(args.python_bin_path)
print("Python binary path: {}".format(python_bin_path))
print(f"Python binary path: {python_bin_path}")
python_version = get_python_version(python_bin_path)
print("Python version: {}".format(".".join(map(str, python_version))))
check_python_version(python_version)

numpy_version = check_numpy_version(python_bin_path)
print("NumPy version: {}".format(numpy_version))
print(f"NumPy version: {numpy_version}")

print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
print("Target CPU: {}".format(wheel_cpu))
print("Target CPU features: {}".format(args.target_cpu_features))
print(f"Target CPU: {wheel_cpu}")
print(f"Target CPU features: {args.target_cpu_features}")

cuda_toolkit_path = args.cuda_path
cudnn_install_path = args.cudnn_path
rocm_toolkit_path = args.rocm_path
print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no"))
if args.enable_cuda:
if cuda_toolkit_path:
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
print(f"CUDA toolkit path: {cuda_toolkit_path}")
if cudnn_install_path:
print("CUDNN library path: {}".format(cudnn_install_path))
print(f"CUDNN library path: {cudnn_install_path}")
if args.cuda_compute_capabilities is not None:
print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}")
if args.cuda_version:
print("CUDA version: {}".format(args.cuda_version))
print(f"CUDA version: {args.cuda_version}")
if args.cudnn_version:
print("CUDNN version: {}".format(args.cudnn_version))
print(f"CUDNN version: {args.cudnn_version}")
print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no"))

print("TPU enabled: {}".format("yes" if args.enable_tpu else "no"))

print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no"))
if args.enable_rocm:
if rocm_toolkit_path:
print("ROCm toolkit path: {}".format(rocm_toolkit_path))
print("ROCm amdgpu targets: {}".format(args.rocm_amdgpu_targets))
print(f"ROCm toolkit path: {rocm_toolkit_path}")
print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}")

write_bazelrc(
python_bin_path=python_bin_path,
@@ -165,7 +165,7 @@ def verify_mac_libraries_dont_reference_chkstack():
["nm", "-g",
r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so")
],
stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True,
capture_output=True, text=True,
check=False)
if nm.returncode != 0:
raise RuntimeError(f"nm process failed: {nm.stdout} {nm.stderr}")
@@ -29,8 +29,7 @@ def run_shell_command(cmd, shell=False, env_vars={}):
env = {**env, **env_vars}
result = subprocess.run(cmd,
shell=shell,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
capture_output=True,
env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
@@ -3809,7 +3809,7 @@
"outputs": [],
"source": [
"def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):\n",
" undef_primals = tuple([type(x) is UndefPrimal for x in invals])\n",
" undef_primals = tuple(type(x) is UndefPrimal for x in invals)\n",
" true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)\n",
" false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)\n",
" true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
@@ -2983,7 +2983,7 @@ Transposition is a fairly straightforward application of `transpose_jaxpr`:

```{code-cell} ipython3
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
undef_primals = tuple([type(x) is UndefPrimal for x in invals])
undef_primals = tuple(type(x) is UndefPrimal for x in invals)
true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
@@ -2971,7 +2971,7 @@ def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,
# Transposition is a fairly straightforward application of `transpose_jaxpr`:

def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
undef_primals = tuple([type(x) is UndefPrimal for x in invals])
undef_primals = tuple(type(x) is UndefPrimal for x in invals)
true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
@@ -32,7 +32,7 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None,
if not text.isdigit():
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.")
options = {} if options is None else options
url = "https://github.com/google/jax/issues/{}".format(text)
url = f"https://github.com/google/jax/issues/{text}"
node = nodes.reference(rawtext, '#' + text, refuri=url, **options)
return [node], []

@@ -92,7 +92,7 @@ def objective(params, t):
approx_dist = lambda x, params: jnp.exp(diag_gaussian_logpdf(x, *params))

def callback(params, t):
print("Iteration {} lower bound {}".format(t, objective(params, t)))
print(f"Iteration {t} lower bound {objective(params, t)}")

plt.cla()
X, Y, Z = mesh_eval(target_dist, x_limits, y_limits, 1)
@@ -35,7 +35,7 @@ def _download(url, filename):
out_file = path.join(_DATA, filename)
if not path.isfile(out_file):
urllib.request.urlretrieve(url, out_file)
print("downloaded {} to {}".format(url, _DATA))
print(f"downloaded {url} to {_DATA}")


def _partial_flatten(x):
@@ -230,7 +230,7 @@ def private_update(rng, i, opt_state, batch):
opt_state = update(
key, next(itercount), opt_state, shape_as_image(*next(batches)))
epoch_time = time.time() - start_time
print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))
print(f'Epoch {epoch} in {epoch_time:0.2f} sec')

# evaluate test accuracy
params = get_params(opt_state)
@@ -245,7 +245,7 @@ def private_update(rng, i, opt_state, batch):
num_examples = 60000
eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
print(
'For delta={:.0e}, the current epsilon is: {:.2f}'.format(delta, eps))
f'For delta={delta:.0e}, the current epsilon is: {eps:.2f}')
else:
print('Trained with vanilla non-private SGD optimizer')

@@ -50,7 +50,7 @@ def setUp(self):
self.rng = np.random.default_rng(zlib.adler32(self.__class__.__name__.encode()))

@parameterized.named_parameters(
{"testcase_name": "_input_shape={}".format(input_shape),
{"testcase_name": f"_input_shape={input_shape}",
"input_shape": input_shape}
for input_shape in [(2, 20, 25, 2)])
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
@@ -59,7 +59,7 @@ def testIdentityBlockShape(self, input_shape):
_CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

@parameterized.named_parameters(
{"testcase_name": "_input_shape={}".format(input_shape),
{"testcase_name": f"_input_shape={input_shape}",
"input_shape": input_shape}
for input_shape in [(2, 20, 25, 3)])
@unittest.skipIf(config.x64_enabled, "skip in x64 mode")
@@ -95,8 +95,8 @@ def gp(params, x, y, xtest=None, compute_marginal_likelihood=False):
params = {"amplitude": jnp.zeros((1, 1)),
"noise": jnp.zeros((1, 1)) - 5.,
"lengthscale": jnp.zeros((1, 1))}
momentums = dict([(k, p * 0.) for k, p in params.items()])
scales = dict([(k, p * 0. + 1.) for k, p in params.items()])
momentums = {k: p * 0. for k, p in params.items()}
scales = {k: p * 0. + 1. for k, p in params.items()}

lr = 0.01 # Learning rate
def train_step(params, momentums, scales, x, y):
@@ -92,6 +92,6 @@ def update(i, opt_state, batch):
params = get_params(opt_state)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
print(f"Training set accuracy {train_acc}")
print(f"Test set accuracy {test_acc}")
@@ -90,6 +90,6 @@ def update(params, batch):

train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
print(f"Training set accuracy {train_acc}")
print(f"Test set accuracy {test_acc}")
@@ -133,5 +133,5 @@ def evaluate(opt_state, images):
tic = time.time()
opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
test_elbo, sampled_images = evaluate(opt_state, test_images)
print("{: 3d} {} ({:.3f} sec)".format(epoch, test_elbo, time.time() - tic))
print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
plt.imsave(imfile.format(epoch), sampled_images, cmap=plt.cm.gray)
@@ -123,6 +123,6 @@ def spmd_update(params, batch):
params = tree_map(lambda x: x[0], replicated_params)
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
print(f"Training set accuracy {train_acc}")
print(f"Test set accuracy {test_acc}")
@@ -66,7 +66,7 @@ class Zero:
def __init__(self, aval):
self.aval = aval
def __repr__(self):
return 'Zero({})'.format(self.aval)
return f'Zero({self.aval})'
@staticmethod
def from_value(val):
return Zero(raise_to_shaped(get_aval(val)))
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -68,7 +68,7 @@ def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten((py_args, {}))
if in_tree != in_tree_expected:
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
raise TypeError(f"Expected {in_tree_expected}, got {in_tree}")
ans = fun(*args)
return tree_unflatten(out_tree, ans)

@@ -82,7 +82,7 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten(py_args)
if in_tree != in_tree_expected:
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
raise TypeError(f"Expected {in_tree_expected}, got {in_tree}")
ans = fun(*args)
return tree_unflatten(out_tree, ans)

@@ -46,7 +46,7 @@ def bool_env(varname: str, default: bool) -> bool:
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return False
else:
raise ValueError("invalid truth value %r for environment %r" % (val, varname))
raise ValueError(f"invalid truth value {val!r} for environment {varname!r}")

def int_env(varname: str, default: int) -> int:
"""Read an environment variable and interpret it as an integer."""
@@ -81,7 +81,7 @@ def update(self, name, val):
else:
self.check_exists(name)
if name not in self.values:
raise Exception("Unrecognized config option: {}".format(name))
raise Exception(f"Unrecognized config option: {name}")
self.values[name] = val

hook = self._update_hooks.get(name, None)
@@ -105,7 +105,7 @@ def _read(self, name):
def add_option(self, name, default, opt_type, meta_args, meta_kwargs,
update_hook=None):
if name in self.values:
raise Exception("Config option {} already defined".format(name))
raise Exception(f"Config option {name} already defined")
self.values[name] = default
self.meta[name] = (opt_type, meta_args, meta_kwargs)
if update_hook:
@@ -114,7 +114,7 @@ def add_option(self, name, default, opt_type, meta_args, meta_kwargs,

def check_exists(self, name):
if name not in self.values:
raise AttributeError("Unrecognized config option: {}".format(name))
raise AttributeError(f"Unrecognized config option: {name}")

def DEFINE_bool(self, name, default, *args, **kwargs):
update_hook = kwargs.pop("update_hook", None)
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -233,7 +232,7 @@ def _add_args(f, extra_args):

@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
extra_args = tuple([arg.val for arg in extra_args])
extra_args = tuple(arg.val for arg in extra_args)
all_args = (extra_args + args)
yield (yield all_args, kwargs)

@@ -271,7 +270,7 @@ def _flatten_jvp(in_tree, *args):
msg = ("Custom JVP rule must produce primal and tangent outputs with "
"equal shapes and dtypes, but got:\n{}")
disagreements = (
" primal {} for tangent {}".format(av1.str_short(), av2.str_short())
f" primal {av1.str_short()} for tangent {av2.str_short()}"
for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, out_tree
@@ -212,7 +212,7 @@ def __repr__(self):
sep = ' '
if last_line_len + len(dtype_str) + 1 > line_width:
sep = ' ' * len(prefix)
return "{}{},{}{}".format(prefix, s, sep, dtype_str)
return f"{prefix}{s},{sep}{dtype_str}"

setattr(device_array, "__repr__", __repr__)

@@ -298,7 +298,7 @@ def raise_not_implemented():
# pylint: enable=protected-access


class DeletedBuffer(object): pass
class DeletedBuffer: pass
deleted_buffer = DeletedBuffer()


@@ -545,7 +545,7 @@ def elaborate_and_pad(explicit_args):
if args[i] != args[j].shape[k]:
raise Exception("inconsistent argument axis sizes for type")
if needs_padding:
args = tuple([pad(x) if pad else x for x, pad in zip(args, padders)])
args = tuple(pad(x) if pad else x for x, pad in zip(args, padders))
return args
return elaborate_and_pad

@@ -92,7 +92,7 @@ def scalar_type_of(x):
elif np.issubdtype(typ, np.complexfloating):
return complex
else:
raise TypeError("Invalid scalar value {}".format(x))
raise TypeError(f"Invalid scalar value {x}")


def _scalar_type_to_dtype(typ: type, value: Any = None):

0 comments on commit 17de89b

Please sign in to comment.