-
Notifications
You must be signed in to change notification settings - Fork 3k
/
Copy pathlax.py
8625 lines (7248 loc) · 333 KB
/
lax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import builtins
from collections.abc import Callable, Sequence
import dataclasses
import enum
import functools
from functools import partial
import itertools
import math
import operator
from typing import Any, NamedTuple, TypeVar, Union, cast as type_cast, overload
import warnings
import numpy as np
from jax import tree_util
from jax.sharding import Sharding
from jax.tree_util import tree_map
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.core import (Primitive, UnshapedArray, ShapedArray,
abstract_token, canonicalize_shape)
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src import mesh as mesh_lib
from jax._src.lax.utils import (
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_primitive)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding_impls import (PmapSharding, NamedSharding,
PartitionSpec as P, canonicalize_sharding)
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
safe_map, safe_zip, split_list, weakref_lru_cache,
foreach)
# Handles on builtins captured before this module defines same-named lax
# functions (``abs``, ``round``, ``pow`` and ``complex`` are redefined below;
# presumably ``max``/``min`` too — TODO confirm). Use these inside this module
# whenever the Python builtin is meant.
_max = builtins.max
_min = builtins.min
_reduce = functools.reduce
T = TypeVar("T")
# Inside this module, ``map``/``zip`` refer to the ``safe_*`` variants from
# jax._src.util (presumably length-checking); the original builtins remain
# reachable as ``unsafe_map``/``unsafe_zip``.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Decorator applied to public functions in this file; presumably it sets their
# ``__module__`` to "jax.lax" for documentation/repr purposes.
export = util.set_module("jax.lax")
def _matrix_transpose(x: Array) -> Array:
  """Swaps the two trailing axes of ``x``, leaving all leading axes in place.

  ``x`` must have at least two dimensions.
  """
  assert x.ndim >= 2
  *batch_axes, row_axis, col_axis = range(x.ndim)
  return transpose(x, [*batch_axes, col_axis, row_axis])
def _clip_int_to_valid_range(val: DimSize, dtype, where: str) -> int:
info = np.iinfo(dtype)
val = core.concrete_dim_or_error(val, where)
return core.max_dim(info.min, core.min_dim(val, info.max))
def _validate_shapes(shapes: Sequence[Shape]):
  """Checks that every shape in the non-empty ``shapes`` is a valid static shape.

  Validation is skipped entirely when dynamic shapes are enabled.

  Raises:
    TypeError: if a canonicalized shape contains a negative dimension.
  """
  assert shapes
  if config.dynamic_shapes.value:
    # Dynamic shapes may contain tracers; pass them through unchecked.
    return
  for shape in shapes:
    canonical = canonicalize_shape(shape)
    if not all(dim >= 0 for dim in canonical):
      raise TypeError("Only non-negative indices are allowed when broadcasting"
                      f" static shapes, but got shape {shape!r}.")
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
"""
Attempt to broadcast shapes, raising a TypeError if broadcasting fails.
"""
if not shapes:
raise TypeError(f"{name}: At least one shape is required.")
ranks = {len(shape) for shape in shapes}
if len(ranks) != 1:
raise TypeError(f'{name}: arrays must have the same number of dimensions,'
f' got {ranks}')
result_shape = []
for ds in zip(*shapes):
if all(core.same_referent(d, ds[0]) for d in ds[1:]):
# if all axes are identical objects, the resulting size is the object
result_shape.append(ds[0])
else:
# if all dims are equal (or 1), the result is the non-1 size
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
if not non_1s:
result_shape.append(1)
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
result_shape.append(non_1s[0])
else:
raise TypeError(f'{name} got incompatible shapes for broadcasting: '
f'{", ".join(map(str, map(tuple, shapes)))}.')
return tuple(result_shape)
def asarray(x: ArrayLike) -> Array:
  """Lightweight conversion of an ``ArrayLike`` value to an ``Array``.

  An ``Array`` is returned unchanged. Python bools and NumPy arrays/scalars
  are converted with ``weak_type=False``. Python int/float/complex scalars are
  converted with ``weak_type=True``.

  Raises:
    TypeError: if ``x`` is not one of the accepted types.
  """
  if isinstance(x, Array):
    return x
  # bool is tested before int because bool is a subclass of int.
  if isinstance(x, (bool, np.ndarray, np.generic)):
    return _convert_element_type(x, weak_type=False)  # pytype: disable=bad-return-type
  # ``builtins.complex`` is needed: ``complex`` is redefined in this module.
  if isinstance(x, (int, float, builtins.complex)):
    as_array = dtypes.coerce_to_array(x)
    return _convert_element_type(as_array, weak_type=True)
  raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
@overload
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...
@overload
def broadcast_shapes(*shapes: tuple[int | core.Tracer, ...]
                     ) -> tuple[int | core.Tracer, ...]: ...
@export
def broadcast_shapes(*shapes):
  """Returns the shape that results from NumPy broadcasting of `shapes`.

  This follows the rules of `NumPy broadcasting`_.

  Args:
    shapes: one or more tuples of integers containing the shapes of arrays
      to be broadcast.

  Returns:
    A tuple of integers representing the broadcasted shape.

  Raises:
    ValueError: if shapes are not broadcast-compatible.
    TypeError: if a static shape contains a negative dimension.

  See Also:
    - :func:`jax.numpy.broadcast_shapes`: similar API in the JAX NumPy namespace

  Examples:
    Some examples of broadcasting compatible shapes:

    >>> jnp.broadcast_shapes((1,), (4,))
    (4,)
    >>> jnp.broadcast_shapes((3, 1), (4,))
    (3, 4)
    >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
    (5, 3, 4)

    Error when attempting to broadcast incompatible shapes:

    >>> jnp.broadcast_shapes((3, 1), (4, 1))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]

  .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
  """
  # NOTE: We have both cached and uncached versions to handle Tracers in shapes.
  # The cached path presumably fails for arguments it cannot hash (e.g.
  # Tracers), so any failure falls back to the uncached path, which re-raises
  # genuine broadcasting errors. ``except Exception`` (not a bare ``except``)
  # keeps KeyboardInterrupt/SystemExit from being swallowed and retried.
  try:
    return _broadcast_shapes_cached(*shapes)
  except Exception:
    return _broadcast_shapes_uncached(*shapes)
@cache()
def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
  """Memoized front end for ``_broadcast_shapes_uncached`` (hashable shapes only)."""
  result = _broadcast_shapes_uncached(*shapes)
  return result
def _broadcast_shapes_uncached(*shapes):
  """Computes the NumPy-style broadcast of ``shapes`` without memoization.

  Raises:
    ValueError: if the shapes cannot be broadcast together.
  """
  _validate_shapes(shapes)
  first, *others = shapes
  if not others:
    return first
  # Fast path: if every shape is a suffix of the longest one, only rank
  # promotion is needed and the longest shape is the answer.
  longest = _max(shapes, key=len)
  rank = len(longest)
  suffix_only = rank == 0 or all(
      core.definitely_equal_shape(longest[rank - len(s):], s) for s in shapes)
  if suffix_only:
    return longest
  # General path: left-pad every shape with 1s to the common rank, then
  # broadcast singleton dimensions.
  padded = tuple((1,) * (rank - len(s)) + tuple(s) for s in shapes)
  try:
    return _try_broadcast_shapes(*padded, name='broadcast_shapes')
  except TypeError as err:
    # Raise ValueError here for backward compatibility.
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
def broadcast_shardings(*avals):
  """Returns the sharding of the result of broadcasting ``avals`` together.

  With a single aval its sharding is returned as-is. If every aval's spec
  already matches the trailing entries of the highest-rank aval's spec, that
  aval's sharding is returned. Otherwise each aval is left-padded to the
  common rank, using size-1 dimensions with ``None`` spec entries, and the
  result is handed to ``broadcasting_sharding_rule``.
  """
  first, *others = avals
  if not others:
    return first.sharding
  # Fast path: only rank promotion is needed.
  widest = _max(avals, key=lambda a: a.ndim)
  rank = widest.ndim
  specs_align = rank == 0 or all(
      widest.sharding.spec[rank - a.ndim:] == a.sharding.spec for a in avals)
  if specs_align:
    return widest.sharding
  # General path: pad out ranks with singleton, unsharded leading dims.
  promoted = []
  for a in avals:
    pad = rank - a.ndim
    padded_spec = P(*((None,) * pad + a.sharding.spec))
    padded_shape = (1,) * pad + a.shape
    promoted.append(a.update(shape=padded_shape,
                             sharding=a.sharding.with_spec(padded_spec)))
  return broadcasting_sharding_rule('broadcast_shardings', *promoted)
def _identity(x, **_): return x
def _extract_tracers_dyn_shape(
    shape: Sequence[int | core.Tracer]
) -> tuple[list[core.Tracer], list[int | None]]:
  """Splits ``shape`` into its Tracer entries and a static template.

  When dynamic shapes are enabled, returns ``(tracers, template)``: the
  Tracers in ``shape`` in order, and ``shape`` with each Tracer replaced by
  ``None``. Otherwise returns ``([], list(shape))`` untouched.
  """
  if not config.dynamic_shapes.value:
    return [], list(shape)  # type: ignore
  # We must gate this behavior under a flag because otherwise the errors
  # raised are different (and have worse source provenance information).
  tracers = [d for d in shape if isinstance(d, core.Tracer)]
  template = [None if isinstance(d, core.Tracer) else d for d in shape]
  return tracers, template
def _merge_dyn_shape(
static_shape: Sequence[int | None],
dyn_shape: Sequence[Any],
) -> tuple[int | mlir.Value | core.Tracer, ...]:
# Replace Nones in static_shape with elements of dyn_shape, in order
dyn_shape_it = iter(dyn_shape)
shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
assert next(dyn_shape_it, None) is None
return shape
def _dyn_shape_staging_rule(trace, prim, out_aval, *args, **params):
  """Stages one application of ``prim`` into ``trace``'s frame.

  Creates a ``DynamicJaxprTracer`` for ``out_aval``, then records an equation
  (with no effects) from the vars of ``args`` to that tracer's new var. The
  new tracer is returned.
  """
  src_info = source_info_util.current()
  result = pe.DynamicJaxprTracer(trace, out_aval, src_info)
  in_vars = [trace.getvar(a) for a in args]
  out_vars = [trace.makevar(result)]
  trace.frame.add_eqn(pe.new_jaxpr_eqn(in_vars, out_vars, prim, params,
                                       core.no_effects, src_info))
  return result
### traceables
@export
def neg(x: ArrayLike) -> Array:
  r"""Elementwise negation: :math:`-x`.

  Lowers directly to the `stablehlo.negate`_ operation.

  Args:
    x: input array.

  Returns:
    An array with the same shape and dtype as ``x`` holding the negation of
    each element.

  Notes:
    For unsigned integer inputs, the result is ``2 ** nbits - x``, where
    ``nbits`` is the bit width of the integer type.

  .. _stablehlo.negate: https://openxla.org/stablehlo/spec#negate
  """
  return neg_p.bind(x)
@export
def sign(x: ArrayLike) -> Array:
  r"""Elementwise sign.

  Lowers directly to the `stablehlo.sign`_ operation.

  Args:
    x: input array.

  Returns:
    An array with the same shape and dtype as ``x`` holding the sign of each
    element, as defined in the Notes below.

  Notes:
    For floating-point inputs the result is

    .. math::
      \mathrm{sign}(x) = \begin{cases}
        -1 & x < 0\\
        -0 & x = -0\\
        \mathit{NaN} & x = \mathit{NaN}\\
        +0 & x = +0\\
        1 & x > 0
      \end{cases}

    For signed integer inputs the result is

    .. math::
      \mathrm{sign}(x) = \begin{cases}
        -1 & x < 0\\
        0 & x = 0\\
        1 & x > 0
      \end{cases}

    For complex inputs the result is the complex phase,
    :math:`\mathrm{sign}(x) = x / |x|`.

  .. _stablehlo.sign: https://openxla.org/stablehlo/spec#sign
  """
  return sign_p.bind(x)
@export
def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
  """Returns the next representable value after ``x1`` in the direction of ``x2``.

  This function lowers directly to the ``chlo.next_after`` operation.

  Args:
    x1, x2: input arrays. Must have matching floating-point dtypes. If neither
      is a scalar, they must have the same number of dimensions and be
      broadcast-compatible.

  Returns:
    Array with the dtype and broadcast shape of the inputs, containing the
    next representable floating-point value after ``x1`` in the direction of
    ``x2``.

  Notes:
    In some environments flush-denormal-to-zero semantics are used.
    This means that, around zero, this function returns strictly non-zero
    values which appear as zero in any operations. Consider this example::

      >>> from jax import lax
      >>> lax.nextafter(0.0, 1.0)  # denormal numbers are representable
      Array(1.e-45, dtype=float32, weak_type=True)
      >>> lax.nextafter(0.0, 1.0) * 1  # but are flushed to zero
      Array(0., dtype=float32, weak_type=True)

    For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
  """
  return nextafter_p.bind(x1, x2)
@export
def floor(x: ArrayLike) -> Array:
  r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.

  Lowers directly to the `stablehlo.floor`_ operation.

  Args:
    x: input array of floating-point type.

  Returns:
    An array with the same shape and dtype as ``x`` in which every value is
    rounded toward negative infinity to the next integer.

  See also:
    - :func:`jax.lax.ceil`: round to the next integer toward positive infinity
    - :func:`jax.lax.round`: round to the nearest integer

  .. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor
  """
  return floor_p.bind(x)
@export
def ceil(x: ArrayLike) -> Array:
  r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.

  Lowers directly to the `stablehlo.ceil`_ operation.

  Args:
    x: input array of floating-point type.

  Returns:
    An array with the same shape and dtype as ``x`` in which every value is
    rounded toward positive infinity to the next integer.

  See also:
    - :func:`jax.lax.floor`: round to the next integer toward negative infinity
    - :func:`jax.lax.round`: round to the nearest integer

  .. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil
  """
  return ceil_p.bind(x)
class RoundingMethod(enum.IntEnum):
  """Rounding strategies for handling halfway values (e.g., 0.5) in
  :func:`jax.lax.round`.
  """
  # Members are plain ints (IntEnum). ``round`` normalizes its argument with
  # ``RoundingMethod(rounding_method)``, so the raw values 0 and 1 are accepted
  # there as well as the members themselves.
  AWAY_FROM_ZERO = 0
  """Rounds halfway values away from zero (e.g., 0.5 -> 1, -0.5 -> -1)."""
  TO_NEAREST_EVEN = 1
  """Rounds halfway values to the nearest even integer. This is also known
  as “banker’s rounding” (e.g., 0.5 -> 0, 1.5 -> 2).
  """
@export
def round(x: ArrayLike,
          rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
          ) -> Array:
  r"""Elementwise round.

  Rounds values to the nearest integer. This function lowers directly to the
  `stablehlo.round`_ operation.

  Args:
    x: an array or scalar value to round. Must have floating-point type.
    rounding_method: the method to use when rounding halfway values
      (e.g., ``0.5``). See :class:`jax.lax.RoundingMethod` for possible values.
      The integer value of a ``RoundingMethod`` member is also accepted.

  Returns:
    An array of the same shape and dtype as ``x``, containing the elementwise
    rounding of ``x``.

  Raises:
    ValueError: if ``rounding_method`` is not a valid ``RoundingMethod`` value.

  See also:
    - :func:`jax.lax.floor`: round to the next integer toward negative infinity
    - :func:`jax.lax.ceil`: round to the next integer toward positive infinity

  Examples:
    >>> import jax.numpy as jnp
    >>> from jax import lax
    >>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
    >>> lax.round(x)  # default method is AWAY_FROM_ZERO
    Array([-2., -1., -1.,  0.,  1.,  1.,  2.], dtype=float32)
    >>> lax.round(x, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
    Array([-2., -1., -0.,  0.,  0.,  1.,  2.], dtype=float32)

  .. _stablehlo.round: https://openxla.org/stablehlo/spec#round
  """
  # Normalize ints to the enum; invalid values raise ValueError here.
  rounding_method = RoundingMethod(rounding_method)
  return round_p.bind(x, rounding_method=rounding_method)
@export
def is_finite(x: ArrayLike) -> Array:
  r"""Elementwise :math:`\mathrm{isfinite}`.

  Lowers directly to the `stablehlo.is_finite`_ operation.

  Args:
    x: input array of floating-point type.

  Returns:
    A boolean array shaped like ``x`` that is ``False`` wherever ``x`` is
    :math:`\pm\infty` or :math:`\mathit{NaN}` and ``True`` everywhere else.

  See also:
    - :func:`jax.numpy.isinf`: return True where array is infinite.
    - :func:`jax.numpy.isnan`: return True where array is NaN.

  .. _stablehlo.is_finite: https://openxla.org/stablehlo/spec#is_finite
  """
  return is_finite_p.bind(x)
@export
def exp(x: ArrayLike) -> Array:
  r"""Elementwise exponential: :math:`e^x`.

  This function lowers directly to the `stablehlo.exponential`_ operation.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    exponential.

  See also:
    - :func:`jax.lax.exp2`: elementwise base-2 exponential: :math:`2^x`.
    - :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
  """
  return exp_p.bind(x)
@export
def exp2(x: ArrayLike) -> Array:
  r"""Elementwise base-2 exponential: :math:`2^x`.

  This function is implemented in terms of the `stablehlo.exponential`_
  and `stablehlo.multiply`_ operations.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    base-2 exponential.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
    - :func:`jax.lax.log`: elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential
  .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
  """
  return exp2_p.bind(x)
@export
def expm1(x: ArrayLike) -> Array:
  r"""Elementwise :math:`e^{x} - 1`.

  This function lowers directly to the `stablehlo.exponential_minus_one`_
  operation. Compared to the naive expression ``lax.exp(x) - 1``, it is
  more accurate for ``x`` near zero.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    exponential minus 1.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.
    - :func:`jax.lax.log1p`: elementwise :math:`\mathrm{log}(1 + x)`.

  .. _stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one
  """
  return expm1_p.bind(x)
@export
def log(x: ArrayLike) -> Array:
  r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`.

  This function lowers directly to the `stablehlo.log`_ operation.

  Args:
    x: input array. Must have floating-point or complex type.

  Returns:
    Array of the same shape and dtype as ``x`` containing the element-wise
    natural logarithm.

  See also:
    - :func:`jax.lax.exp`: elementwise exponential: :math:`e^x`.

  .. _stablehlo.log: https://openxla.org/stablehlo/spec#log
  """
  return log_p.bind(x)
@export
def log1p(x: ArrayLike) -> Array:
  r"""Elementwise :math:`\mathrm{log}(1 + x)`.

  Lowers directly to the `stablehlo.log_plus_one`_ operation. For ``x`` close
  to zero this is more accurate than writing ``lax.log(1 + x)`` directly.

  Args:
    x: input array of floating-point or complex type.

  Returns:
    An array with the same shape and dtype as ``x`` holding the natural
    logarithm of ``x + 1`` for each element.

  See also:
    - :func:`jax.lax.expm1`: elementwise :math:`e^x - 1`.
    - :func:`jax.lax.log`: elementwise natural logarithm :math:`\mathrm{log}(x)`.

  .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one
  """
  return log1p_p.bind(x)
@export
def tanh(x: ArrayLike) -> Array:
  r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`.

  Lowers directly to the `stablehlo.tanh`_ operation.

  Args:
    x: input array of floating-point or complex type.

  Returns:
    An array with the same shape and dtype as ``x`` holding the hyperbolic
    tangent of each element.

  See also:
    - :func:`jax.lax.atanh`: elementwise inverse hyperbolic tangent.
    - :func:`jax.lax.cosh`: elementwise hyperbolic cosine.
    - :func:`jax.lax.sinh`: elementwise hyperbolic sine.

  .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh
  """
  return tanh_p.bind(x)
@export
def logistic(x: ArrayLike) -> Array:
  r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.

  HLO has no logistic/sigmoid primitive, so this is lowered to a sequence of
  HLO arithmetic operations instead.

  Args:
    x: input array of floating-point or complex dtype.

  Returns:
    An array with the same shape and dtype as ``x`` holding the
    logistic/sigmoid of each element.

  See also:
    - :func:`jax.nn.sigmoid`: an alternative API for this functionality.
  """
  return logistic_p.bind(x)
@export
def sin(x: ArrayLike) -> Array:
  r"""Elementwise sine: :math:`\mathrm{sin}(x)`.

  Real floating-point inputs lower directly to the `stablehlo.sine`_
  operation. Complex inputs lower to a sequence of HLO operations that
  implement the complex sine.

  Args:
    x: input array of floating-point or complex type.

  Returns:
    An array with the same shape and dtype as ``x`` holding the sine of each
    element.

  See also:
    - :func:`jax.lax.cos`: elementwise cosine.
    - :func:`jax.lax.tan`: elementwise tangent.
    - :func:`jax.lax.asin`: elementwise arc sine.

  .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine
  """
  return sin_p.bind(x)
@export
def cos(x: ArrayLike) -> Array:
  r"""Elementwise cosine: :math:`\mathrm{cos}(x)`.

  Real floating-point inputs lower directly to the `stablehlo.cosine`_
  operation. Complex inputs lower to a sequence of HLO operations that
  implement the complex cosine.

  Args:
    x: input array of floating-point or complex type.

  Returns:
    An array with the same shape and dtype as ``x`` holding the cosine of each
    element.

  See also:
    - :func:`jax.lax.sin`: elementwise sine.
    - :func:`jax.lax.tan`: elementwise tangent.
    - :func:`jax.lax.acos`: elementwise arc cosine.

  .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine
  """
  return cos_p.bind(x)
@export
def atan2(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Elementwise two-term arc tangent: :math:`\mathrm{atan}({x \over y})`.

  This function lowers directly to the `stablehlo.atan2`_ operation.

  Args:
    x, y: input arrays. Must have matching floating-point or complex dtypes.
      If neither is a scalar, the two arrays must have the same number of
      dimensions and be broadcast-compatible.

  Returns:
    Array with the broadcast shape and the common dtype of ``x`` and ``y``,
    containing the element-wise arc tangent of :math:`x \over y`, respecting
    the quadrant indicated by the sign of each input.

  See also:
    - :func:`jax.lax.tan`: elementwise tangent.
    - :func:`jax.lax.atan`: elementwise one-term arc tangent.

  .. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2
  """
  return atan2_p.bind(x, y)
@export
def real(x: ArrayLike) -> Array:
  r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.

  Lowers directly to the `stablehlo.real`_ operation.

  Args:
    x: input array of complex dtype.

  Returns:
    An array shaped like ``x`` holding the real part of each element. The
    dtype is float32 for complex64 input and float64 for complex128 input.

  See also:
    - :func:`jax.lax.complex`: elementwise construct complex number.
    - :func:`jax.lax.imag`: elementwise extract imaginary part.
    - :func:`jax.lax.conj`: elementwise complex conjugate.

  .. _stablehlo.real: https://openxla.org/stablehlo/spec#real
  """
  return real_p.bind(x)
@export
def imag(x: ArrayLike) -> Array:
  r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.

  Lowers directly to the `stablehlo.imag`_ operation.

  Args:
    x: input array of complex dtype.

  Returns:
    An array shaped like ``x`` holding the imaginary part of each element.
    The dtype is float32 for complex64 input and float64 for complex128 input.

  See also:
    - :func:`jax.lax.complex`: elementwise construct complex number.
    - :func:`jax.lax.real`: elementwise extract real part.
    - :func:`jax.lax.conj`: elementwise complex conjugate.

  .. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
  """
  return imag_p.bind(x)
@export
def complex(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Elementwise make complex number: :math:`x + jy`.

  Lowers directly to the `stablehlo.complex`_ operation.

  Args:
    x, y: input arrays with matching floating-point dtypes. When neither is a
      scalar, both must have the same number of dimensions and be
      broadcast-compatible.

  Returns:
    A complex array whose real part comes from ``x`` and whose imaginary part
    comes from ``y``. float32 inputs yield complex64; float64 inputs yield
    complex128.

  See also:
    - :func:`jax.lax.real`: elementwise extract real part.
    - :func:`jax.lax.imag`: elementwise extract imaginary part.
    - :func:`jax.lax.conj`: elementwise complex conjugate.

  .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
  """
  return complex_p.bind(x, y)
@export
def conj(x: ArrayLike) -> Array:
  r"""Elementwise complex conjugate function: :math:`\overline{x}`.

  Lowered as a combination of `stablehlo.real`_, `stablehlo.imag`_ and
  `stablehlo.complex`_.

  Args:
    x: input array of complex dtype.

  Returns:
    An array with the same shape and dtype as ``x`` holding the complex
    conjugate of each element.

  See also:
    - :func:`jax.lax.complex`: elementwise construct complex number.
    - :func:`jax.lax.real`: elementwise extract real part.
    - :func:`jax.lax.imag`: elementwise extract imaginary part.
    - :func:`jax.lax.abs`: elementwise absolute value / complex magnitude.

  .. _stablehlo.real: https://openxla.org/stablehlo/spec#real
  .. _stablehlo.imag: https://openxla.org/stablehlo/spec#imag
  .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex
  """
  # TODO(mattjj): remove input_dtype, not needed anymore
  in_dtype = _dtype(x)
  return conj_p.bind(x, input_dtype=in_dtype)
@export
def abs(x: ArrayLike) -> Array:
  r"""Elementwise absolute value: :math:`|x|`.

  Lowers directly to the `stablehlo.abs`_ operation.

  Args:
    x: input array of signed integer, floating-point, or complex dtype.

  Returns:
    An array shaped like ``x`` holding the absolute value of each element.
    For a complex element :math:`a + ib` the result is :math:`\sqrt{a^2+b^2}`.

  See also:
    - :func:`jax.numpy.abs`: a more flexible NumPy-style ``abs`` implementation.

  .. _stablehlo.abs: https://openxla.org/stablehlo/spec#abs
  """
  return abs_p.bind(x)
@export
def pow(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Elementwise power: :math:`x^y`.

  Lowers directly to the `stablehlo.pow`_ operation, preceded by a
  `stablehlo.convert`_ when the two argument dtypes differ.

  Args:
    x: base values, of floating or complex type.
    y: exponent values, of integer, floating, or complex type. ``y`` is cast
      to ``x.dtype`` when necessary. When neither ``x`` nor ``y`` is a scalar,
      they must have the same number of dimensions and be broadcast-compatible.

  Returns:
    An array with dtype ``x.dtype`` holding the elementwise power.

  See also:
    - :func:`jax.lax.integer_pow`: Elementwise power where ``y`` is a static integer.

  .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
  .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow
  """
  return pow_p.bind(x, y)
@export
def integer_pow(x: ArrayLike, y: int) -> Array:
  r"""Elementwise power: :math:`x^y`, where :math:`y` is a static integer.

  This will lower to a sequence of :math:`O(\log_2(y))` repetitions of
  `stablehlo.multiply`_.

  Args:
    x: Input array giving the base value. Must have numerical dtype.
    y: Static scalar integer giving the exponent.

  Returns:
    An array of the same shape and dtype as ``x`` containing the elementwise power.

  See also:
    - :func:`jax.lax.pow`: Elementwise power where ``y`` is an array.

  .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply
  """
  return integer_pow_p.bind(x, y=y)
@export
def sqrt(x: ArrayLike) -> Array:
  r"""Elementwise square root: :math:`\sqrt{x}`.

  This function lowers directly to the `stablehlo.sqrt`_ operation.

  Args:
    x: Input array. Must have floating or complex dtype.

  Returns:
    An array of the same shape and dtype as ``x`` containing the square root.

  See also:
    - :func:`jax.lax.pow`: Elementwise power.
    - :func:`jax.lax.cbrt`: Elementwise cube root.
    - :func:`jax.lax.rsqrt`: Elementwise reciprocal square root.

  .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt
  """
  return sqrt_p.bind(x)
@export
def rsqrt(x: ArrayLike) -> Array:
  r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`.

  This function lowers directly to the `stablehlo.rsqrt`_ operation.

  Args:
    x: Input array. Must have floating or complex dtype.

  Returns:
    An array of the same shape and dtype as ``x`` containing the
    reciprocal square root.

  See also:
    - :func:`jax.lax.pow`: Elementwise power.
    - :func:`jax.lax.sqrt`: Elementwise square root.
    - :func:`jax.lax.cbrt`: Elementwise cube root.

  .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt
  """
  return rsqrt_p.bind(x)
@export
def cbrt(x: ArrayLike) -> Array:
  r"""Elementwise cube root: :math:`\sqrt[3]{x}`.

  This function lowers directly to the `stablehlo.cbrt`_ operation.

  Args:
    x: Input array. Must have floating or complex dtype.

  Returns:
    An array of the same shape and dtype as ``x`` containing the cube root.

  See also:
    - :func:`jax.lax.pow`: Elementwise power.
    - :func:`jax.lax.sqrt`: Elementwise square root.
    - :func:`jax.lax.rsqrt`: Elementwise reciprocal square root.

  .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt
  """
  return cbrt_p.bind(x)
@export
def bitwise_not(x: ArrayLike) -> Array:
  r"""Elementwise NOT: :math:`\neg x`.

  Lowers directly to the `stablehlo.not`_ operation.

  Args:
    x: input array of boolean or integer dtype.

  Returns:
    An array with the same shape and dtype as ``x`` in which every entry is
    bitwise inverted.

  See also:
    - :func:`jax.numpy.invert`: NumPy wrapper for this API, also accessible
      via the ``~x`` operator on JAX arrays.
    - :func:`jax.lax.bitwise_and`: Elementwise AND.
    - :func:`jax.lax.bitwise_or`: Elementwise OR.
    - :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.

  .. _stablehlo.not: https://openxla.org/stablehlo/spec#not
  """
  return not_p.bind(x)
@export
def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array:
  r"""Elementwise AND: :math:`x \wedge y`.

  Lowers directly to the `stablehlo.and`_ operation.

  Args:
    x, y: input arrays with matching boolean or integer dtypes. When neither
      is a scalar, both must have the same number of dimensions and be
      broadcast compatible.

  Returns:
    An array with the common dtype of ``x`` and ``y`` holding the bitwise AND
    of each pair of broadcast entries.

  See also:
    - :func:`jax.numpy.bitwise_and`: NumPy wrapper for this API, also accessible
      via the ``x & y`` operator on JAX arrays.
    - :func:`jax.lax.bitwise_not`: Elementwise NOT.
    - :func:`jax.lax.bitwise_or`: Elementwise OR.
    - :func:`jax.lax.bitwise_xor`: Elementwise exclusive OR.

  .. _stablehlo.and: https://openxla.org/stablehlo/spec#and
  """
  return and_p.bind(x, y)
@export
def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise OR: :math:`x \vee y`.
This function lowers directly to the `stablehlo.or`_ operation.
Args:
x, y: Input arrays. Must have matching boolean or integer dtypes.
If neither is a scalar, ``x`` and ``y`` must have the same number
of dimensions and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the bitwise
OR of each pair of broadcasted entries.
See also:
- :func:`jax.numpy.invert`: NumPy wrapper for this API, also accessible
via the ``x | y`` operator on JAX arrays.