DistributedGradientBoostedTreesModel does not support Ranking task #209

Open
@JackGammack

Description

The documentation indicates that the ranking task can be used with this model, and nothing warns or fails until training time. The error message raised at that point (see below) does not make it clear that the ranking task is unavailable for this model, and I couldn't find any documentation stating this.

Are there plans to add support for distributed ranking models? I figure there may be limitations when examples from the same ranking_group end up on different workers, since NDCG has to be calculated over each whole group.
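To illustrate the concern about groups being split across workers, here is a minimal pure-Python sketch of per-group NDCG. It is not TF-DF's implementation: the function names, the `(group, label, score)` row layout, and the exponential gain formula are assumptions chosen for illustration. It shows that computing NDCG on two partial shards of one group gives values that differ from the NDCG of the complete group.

```python
import math
from collections import defaultdict


def dcg(relevances):
    """Discounted cumulative gain with the common 2^rel - 1 gain (an assumption)."""
    return sum(
        (2 ** rel - 1) / math.log2(rank + 2)
        for rank, rel in enumerate(relevances)
    )


def ndcg_per_group(rows):
    """Compute NDCG for each ranking group.

    rows: iterable of (group, label, score) tuples (hypothetical layout).
    """
    by_group = defaultdict(list)
    for group, label, score in rows:
        by_group[group].append((score, label))

    result = {}
    for group, items in by_group.items():
        # Labels ordered by descending predicted score.
        ranked = [label for _, label in sorted(items, key=lambda x: -x[0])]
        # Best possible ordering: labels sorted by descending relevance.
        ideal_dcg = dcg(sorted(ranked, reverse=True))
        result[group] = dcg(ranked) / ideal_dcg if ideal_dcg > 0 else 0.0
    return result


# One ranking group whose most relevant example gets the lowest score.
rows = [("q1", 3, 0.2), ("q1", 0, 0.9), ("q1", 1, 0.5)]

full = ndcg_per_group(rows)["q1"]
# The same group split across two hypothetical workers.
shard_a = ndcg_per_group(rows[:1])["q1"]
shard_b = ndcg_per_group(rows[1:])["q1"]

print(f"full group: {full:.4f}")
print(f"shard A:    {shard_a:.4f}")
print(f"shard B:    {shard_b:.4f}")
```

Shard A contains a single example, so its NDCG is trivially perfect, while the full group is ranked poorly. A distributed trainer would therefore need either to keep each group on one worker or to aggregate per-group statistics across workers.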

Minimal example

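import tensorflow as tf
import tensorflow_decision_forests as tfdf

# Strategy arguments are elided here.
strategy = tf.distribute.experimental.ParameterServerStrategy(...)

with strategy.scope():
    model = tfdf.keras.DistributedGradientBoostedTreesModel(
        task=tfdf.keras.Task.RANKING,
        ranking_group="group",
    )

# train_input_pattern is the path to the training data, defined elsewhere.
model.fit_on_dataset_path(
    train_path=train_input_pattern,
    label_key="label",
    weight_key="sample_weight",
    dataset_format="tfrecord+tfe",
)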

Error message below. Changing the task to regression makes the model train successfully.

File "/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core.py", line 1942, in fit_on_dataset_path
    tf_core.train_on_file_dataset(
  File "/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/tensorflow/core.py", line 779, in train_on_file_dataset
    training_op.SimpleMLCheckStatus(process_id=process_id) == 1
  File "/opt/conda/lib/python3.10/site-packages/tensorflow/python/util/tf_export.py", line 403, in wrapper
    return f(**kwargs)
  File "<string>", line 1373, in simple_ml_check_status
  File "/opt/conda/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 5883, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.UnknownError: {{function_node __wrapped__SimpleMLCheckStatus_device_/job:chief/replica:0/task:0/device:CPU:0}} TensorFlow: INVALID_ARGUMENT: Worker #0: INVALID_ARGUMENT: Not supported task [Op:SimpleMLCheckStatus] name:
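Until the library reports this more clearly, a caller-side check could fail before the distributed job starts. The sketch below is hypothetical and not part of TF-DF: `check_distributed_task` and the `ASSUMED_SUPPORTED_TASKS` set are names invented here, and the set's contents are an assumption. Only regression is confirmed to work by this report, and classification is assumed.

```python
# Assumption: tasks the distributed GBT learner accepts. Regression is
# confirmed by this report; classification is assumed, not verified here.
ASSUMED_SUPPORTED_TASKS = frozenset({"CLASSIFICATION", "REGRESSION"})


def check_distributed_task(task_name):
    """Raise a descriptive error before training starts (hypothetical helper).

    task_name: the task as an upper-case string, e.g. "RANKING".
    """
    if task_name not in ASSUMED_SUPPORTED_TASKS:
        supported = ", ".join(sorted(ASSUMED_SUPPORTED_TASKS))
        raise ValueError(
            f"Task {task_name} is not supported by "
            "DistributedGradientBoostedTreesModel; "
            f"assumed supported tasks: {supported}."
        )


check_distributed_task("REGRESSION")  # passes silently

try:
    check_distributed_task("RANKING")
except ValueError as err:
    print(err)
```

With the real API, the string could be derived from the enum member's name (for example `tfdf.keras.Task.RANKING.name`), assuming `Task` exposes standard enum names.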
