24
24
from absl import logging
25
25
import jax
26
26
from jax_tpu_embedding .sparsecore .lib .fdo import fdo_client
27
+ from jax_tpu_embedding .sparsecore .lib .nn import embedding
27
28
import numpy as np
28
29
29
30
30
31
_FILE_NAME = 'fdo_stats'
31
32
_FILE_EXTENSION = 'npz'
32
- _MAX_ID_STATS_KEY = '_max_ids'
33
- _MAX_UNIQUE_ID_STATS_KEY = '_max_unique_ids'
33
+ _MAX_ID_STATS_SUFFIX = '_max_ids'
34
+ _MAX_UNIQUE_ID_STATS_SUFFIX = '_max_unique_ids'
34
35
35
36
36
37
class NPZFileFDOClient (fdo_client .FDOClient ):
@@ -57,7 +58,7 @@ def __init__(self, base_dir: str):
57
58
self ._max_ids_per_partition = collections .defaultdict (np .ndarray )
58
59
self ._max_unique_ids_per_partition = collections .defaultdict (np .ndarray )
59
60
60
- def record (self , data : Mapping [ str , Mapping [ str , np . ndarray ]] ) -> None :
61
+ def record (self , data : embedding . SparseDenseMatmulInputStats ) -> None :
61
62
"""Records stats per process.
62
63
63
64
Accumulates the max ids observed per process per sparsecore per device for
@@ -67,9 +68,7 @@ def record(self, data: Mapping[str, Mapping[str, np.ndarray]]) -> None:
67
68
Args:
68
69
data: A mapping representing data to be recorded.
69
70
"""
70
- if _MAX_ID_STATS_KEY [1 :] not in data :
71
- raise ValueError (f'Expected stat ({ _MAX_ID_STATS_KEY [1 :]} ) not found.' )
72
- max_ids_per_process = data [_MAX_ID_STATS_KEY [1 :]]
71
+ max_ids_per_process = data .max_ids_per_partition
73
72
for table_name , stats in max_ids_per_process .items ():
74
73
logging .vlog (
75
74
2 , 'Recording observed max ids for table: %s -> %s' , table_name , stats
@@ -80,11 +79,7 @@ def record(self, data: Mapping[str, Mapping[str, np.ndarray]]) -> None:
80
79
self ._max_ids_per_partition [table_name ] = np .vstack (
81
80
(self ._max_ids_per_partition [table_name ], stats )
82
81
)
83
- if _MAX_UNIQUE_ID_STATS_KEY [1 :] not in data :
84
- raise ValueError (
85
- f'Expected stats ({ _MAX_UNIQUE_ID_STATS_KEY [1 :]} ) not found.'
86
- )
87
- max_uniques_per_process = data [_MAX_UNIQUE_ID_STATS_KEY [1 :]]
82
+ max_uniques_per_process = data .max_unique_ids_per_partition
88
83
for table_name , stats in max_uniques_per_process .items ():
89
84
logging .vlog (
90
85
2 ,
@@ -107,7 +102,7 @@ def _generate_file_name(self) -> str:
107
102
_FILE_NAME , jax .process_index (), time .time_ns (), _FILE_EXTENSION
108
103
)
109
104
return os .path .join (self ._base_dir , filename )
110
- # LINT.ThenChange(:_get_latest_files_by_process)
105
+ # LINT.ThenChange(:_get_latest_files_by_process)
111
106
112
107
def _get_latest_files_by_process (self , files : list [str ]) -> list [str ]:
113
108
"""Returns the latest file for each process."""
@@ -150,11 +145,11 @@ def publish(self) -> None:
150
145
processes.
151
146
"""
152
147
merged_stats = {
153
- f'{ table_name } { _MAX_ID_STATS_KEY } ' : stats
148
+ f'{ table_name } { _MAX_ID_STATS_SUFFIX } ' : stats
154
149
for table_name , stats in self ._max_ids_per_partition .items ()
155
150
}
156
151
merged_stats .update ({
157
- f'{ table_name } { _MAX_UNIQUE_ID_STATS_KEY } ' : stats
152
+ f'{ table_name } { _MAX_UNIQUE_ID_STATS_SUFFIX } ' : stats
158
153
for table_name , stats in self ._max_unique_ids_per_partition .items ()
159
154
})
160
155
self ._write_to_file (merged_stats )
@@ -197,16 +192,16 @@ def load(
197
192
stats = self ._read_from_file (files_glob )
198
193
max_id_stats , max_unique_id_stats = {}, {}
199
194
for table_name , stats in stats .items ():
200
- if table_name .endswith (f'{ _MAX_ID_STATS_KEY } ' ):
201
- max_id_stats [table_name [: - len (_MAX_ID_STATS_KEY )]] = stats
202
- elif table_name .endswith (f'{ _MAX_UNIQUE_ID_STATS_KEY } ' ):
203
- max_unique_id_stats [table_name [: - len (_MAX_UNIQUE_ID_STATS_KEY )]] = (
195
+ if table_name .endswith (f'{ _MAX_ID_STATS_SUFFIX } ' ):
196
+ max_id_stats [table_name [: - len (_MAX_ID_STATS_SUFFIX )]] = stats
197
+ elif table_name .endswith (f'{ _MAX_UNIQUE_ID_STATS_SUFFIX } ' ):
198
+ max_unique_id_stats [table_name [: - len (_MAX_UNIQUE_ID_STATS_SUFFIX )]] = (
204
199
stats
205
200
)
206
201
else :
207
202
raise ValueError (
208
203
f'Unexpected table name and stats key: { table_name } , expected to'
209
- f' end with { _MAX_ID_STATS_KEY } or { _MAX_UNIQUE_ID_STATS_KEY } '
204
+ f' end with { _MAX_ID_STATS_SUFFIX } or { _MAX_UNIQUE_ID_STATS_SUFFIX } '
210
205
)
211
206
self ._max_ids_per_partition = max_id_stats
212
207
self ._max_unique_ids_per_partition = max_unique_id_stats
0 commit comments