Is the full training set used by dask_lightgbm?

I'm reading through the implementation of the dask-lightgbm estimators (specifically, the _train_part function in dask_lightgbm/core.py), and I can't see how the entire training set gets used to fit the final estimator.

The _train_part function accepts a boolean argument, return_model. In the train function, which uses client.submit to call _train_part on each worker, return_model is True only for the "master_worker" (which appears to be a randomly chosen Dask worker). As I understand it, each worker is dispatched roughly 1/n of the training set, where n is the total number of workers, and then trains its own independent model on that subset. Because return_model controls whether _train_part returns the worker's model, every worker except one returns None, so only one of the n models ever comes back.
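
The control flow described above can be mimicked in plain Python. This is a minimal sketch, not dask-lightgbm code: `fake_train_part`, `collect`, the worker addresses, and the string standing in for a fitted model are all hypothetical. It only shows that when return_model is True for a single worker, exactly one non-None result comes back.

```python
from typing import List, Optional

def fake_train_part(worker: str, return_model: bool) -> Optional[str]:
    # Hypothetical stand-in for _train_part: every worker ends up holding a model,
    # but only the worker asked to return it hands it back.
    model = "fitted-model"  # placeholder for a fitted estimator
    return model if return_model else None

def collect(workers: List[str], master_worker: str) -> List[Optional[str]]:
    # One call per worker, with return_model=True only for the chosen master worker.
    return [fake_train_part(w, return_model=(w == master_worker)) for w in workers]

workers = ["tcp://10.0.0.1:8786", "tcp://10.0.0.2:8786", "tcp://10.0.0.3:8786"]
results = collect(workers, master_worker=workers[0])
models = [r for r in results if r is not None]
print(results)  # ['fitted-model', None, None]
```

Whether those None results mean "discarded models" or "redundant copies of the same model" is exactly the open question.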

Code:

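from dask.distributed import get_worker
from lightgbm.basic import _LIB, _safe_call
# build_network_params and concat are helpers defined elsewhere in dask_lightgbm/core.py

def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400,
                time_out=120, **kwargs):

    # Network settings so LightGBM on this worker can reach the LightGBM instances on the other workers
    network_params = build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
    params.update(network_params)

    # Concatenate many parts into one
    # (list_of_parts holds (data, label[, weight]) tuples; zip(*...) regroups them by field)
    parts = tuple(zip(*list_of_parts))
    data = concat(parts[0])
    label = concat(parts[1])
    weight = concat(parts[2]) if len(parts) == 3 else None

    try:
        model = model_factory(**params)
        model.fit(data, label, sample_weight=weight)
    finally:
        # Release LightGBM's network resources even if fit() raises
        _safe_call(_LIB.LGBM_NetworkFree())

    # Only the worker flagged with return_model sends its model back
    return model if return_model else None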
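
The "Concatenate many parts into one" step relies on zip(*list_of_parts) to transpose a list of per-partition (data, label[, weight]) tuples into one tuple per field. The sketch below shows that transposition with plain Python lists; `concat_lists` is a hypothetical stand-in for the module's concat helper, and the partition values are made up for illustration.

```python
from itertools import chain
from typing import List, Sequence, Tuple

def concat_lists(seq: Sequence[List[float]]) -> List[float]:
    # Hypothetical stand-in for concat(): join per-partition chunks end to end.
    return list(chain.from_iterable(seq))

# Each tuple is one partition held by this worker: (data, label, weight).
list_of_parts: List[Tuple[List[float], List[float], List[float]]] = [
    ([1.0, 2.0], [0.0, 1.0], [1.0, 1.0]),
    ([3.0], [1.0], [0.5]),
]

parts = tuple(zip(*list_of_parts))  # (data chunks, label chunks, weight chunks)
data = concat_lists(parts[0])
label = concat_lists(parts[1])
weight = concat_lists(parts[2]) if len(parts) == 3 else None
print(data, label, weight)  # [1.0, 2.0, 3.0] [0.0, 1.0, 1.0] [1.0, 1.0, 0.5]
```

This only merges the partitions that live on the current worker, so each worker's fit() call does see just its own slice of the data.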

Isn't this equivalent to training a non-distributed LightGBM estimator on a 1/n subsample of the training set? Am I missing something? I expected to find a step where either the workers' independent models are combined into one, or a single estimator is updated with the trees learned by the separate workers.

Thank you!

Answer by Frank Fineis:

Ah, the answer to the title question is yes: dask_lightgbm uses all available training samples. Dask's only responsibility is to distribute the data across workers. Once its network parameters are set, LightGBM handles all of the distributed learning itself, with the LightGBM instances on the workers communicating over that network during fit(). So the workers are not training independent models. LightGBM is training a single model, and each worker ends up with its own copy of it. That is why only the chosen worker returns the fitted estimator, and every other worker returns None.
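
To see why per-worker data can still produce one model trained on all rows, consider data-parallel, histogram-based split finding: workers do not exchange rows, they exchange per-bin gradient statistics computed on their local rows. Summing the local histograms gives exactly the histogram of the full dataset, so every worker agrees on the same split. The sketch below is a simplified illustration under those assumptions, not LightGBM's implementation (which, among other things, merges histograms with reduce-scatter across many features rather than a plain sum). `local_histogram`, `merge`, `best_split`, the toy data, and the squared-loss gain formula are all hypothetical.

```python
from typing import List, Tuple

N_BINS = 4  # toy setting: one feature, already bucketed into 4 bins
Hist = Tuple[List[float], List[int]]

def local_histogram(bins: List[int], grads: List[float]) -> Hist:
    # Per-bin gradient sum and row count over one worker's own rows.
    g, n = [0.0] * N_BINS, [0] * N_BINS
    for b, grad in zip(bins, grads):
        g[b] += grad
        n[b] += 1
    return g, n

def merge(hists: List[Hist]) -> Hist:
    # Stand-in for the network step: element-wise sum of all local histograms.
    g = [sum(h[0][i] for h in hists) for i in range(N_BINS)]
    n = [sum(h[1][i] for h in hists) for i in range(N_BINS)]
    return g, n

def best_split(g: List[float], n: List[int]) -> Tuple[int, float]:
    # Try "bin <= t goes left" for each t; squared-loss gain with hessian 1 per row.
    G, N = sum(g), sum(n)
    best_t, best_gain = -1, float("-inf")
    gl, nl = 0.0, 0
    for t in range(N_BINS - 1):
        gl, nl = gl + g[t], nl + n[t]
        gr, nr = G - gl, N - nl
        if nl == 0 or nr == 0:
            continue  # a split must leave rows on both sides
        gain = gl * gl / nl + gr * gr / nr - G * G / N
        if gain > best_gain:
            best_t, best_gain = t, gain
    return best_t, best_gain

# Full training set (bin index and gradient per row), split across two workers.
bins = [0, 0, 1, 1, 2, 2, 3, 3]
grads = [-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0]
hist_a = local_histogram(bins[:3], grads[:3])  # worker A's rows
hist_b = local_histogram(bins[3:], grads[3:])  # worker B's rows

full = local_histogram(bins, grads)
merged = merge([hist_a, hist_b])
print(merged == full)       # True: summed local histograms equal the full-data histogram
print(best_split(*merged))  # (1, 12.5), the same split a single machine would choose
print(best_split(*hist_a))  # (0, 0.375), what worker A would pick from its rows alone
```

The last line shows the scenario the question worried about: a worker fitting only on its own slice would choose a different split, whereas the merged statistics reproduce the full-data decision on every worker.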