2021-12-23

How to apply custom encoders to multiple clients at once? How to use custom encoders in run_one_round?

My goal is to implement global top-k subsampling. Gradient sparsification itself is quite simple, and I have already done it by building on the stateful clients example, but now I would like to use encoders as you recommended here, on page 28. Additionally, I would like to average only the nonzero gradients: say we have 10 clients, but only 4 of them have nonzero gradients at a given position in a communication round; then I would like to divide the sum of these gradients by 4, not by 10. I am hoping to achieve this by summing the gradients in the numerator and the masks (1s and 0s) in the denominator. Later I will also add randomness to the gradient selection, so it is imperative that I create those masks at the same time as I select the gradients. The code I have right now is:

import tensorflow as tf

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_adaptive_encoding_stage
class GradientSparsificationEncodingStage(te.core.AdaptiveEncodingStageInterface):
  """Global top-k gradient sparsification with error compensation.

  All input tensors are flattened and concatenated, the error-compensation
  residual from the previous round is added, and only the k entries with the
  largest magnitude (k = 10% of all entries) are kept. The encoded structure
  holds the selected (signed) values, their flat indices and the original
  shapes; everything that was not selected becomes the new error-compensation
  state. decode returns the dense tensors together with 0/1 masks that mark
  the selected positions.
  """

  ENCODED_VALUES_KEY = 'stateful_topk_values'
  INDICES_KEY = 'indices'
  SHAPES_KEY = 'shapes'
  ERROR_COMPENSATION_KEY = 'error_compensation'

  def encode(self, x, encode_params):
    # Record the original shapes, then flatten everything into one vector.
    shapes_list = [tf.shape(y) for y in x]
    flattened = tf.nest.map_structure(lambda y: tf.reshape(y, [-1]), x)
    gradients = tf.concat(flattened, axis=0)
    error_compensation = encode_params[self.ERROR_COMPENSATION_KEY]

    gradients_and_error_compensation = tf.math.add(gradients, error_compensation)

    # k = 10% of all entries.
    percentage = tf.constant(0.1, dtype=tf.float32)
    k_float = tf.multiply(percentage, tf.cast(tf.size(gradients_and_error_compensation), tf.float32))
    k_int = tf.cast(tf.math.round(k_float), dtype=tf.int32)

    # Select positions by magnitude, but transmit the signed values found there
    # (top_k on abs() alone would return absolute values and lose the signs).
    _, indices = tf.math.top_k(tf.math.abs(gradients_and_error_compensation), k=k_int, sorted=False)
    values = tf.gather(gradients_and_error_compensation, indices)
    indices = tf.expand_dims(indices, 1)
    sparse_gradients_and_error_compensation = tf.scatter_nd(indices, values, tf.shape(gradients_and_error_compensation))

    # Whatever was not transmitted is carried over to the next round.
    new_error_compensation = tf.math.subtract(gradients_and_error_compensation, sparse_gradients_and_error_compensation)
    state_update_tensors = {self.ERROR_COMPENSATION_KEY: new_error_compensation}

    encoded_x = {self.ENCODED_VALUES_KEY: values,
                 self.INDICES_KEY: indices,
                 self.SHAPES_KEY: shapes_list}

    return encoded_x, state_update_tensors

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    del num_summands, decode_params, shape  # Unused.
    shapes = encoded_tensors[self.SHAPES_KEY]
    sizes_list = [tf.math.reduce_prod(s) for s in shapes]
    flat_shape = tf.math.reduce_sum(sizes_list)
    indices = encoded_tensors[self.INDICES_KEY]
    values = encoded_tensors[self.ENCODED_VALUES_KEY]
    # Dense vector with the transmitted values at their original positions.
    scatter_tensor = tf.scatter_nd(indices=indices, updates=values, shape=[flat_shape])
    # 0/1 mask built from the indices, so a selected value that is negative
    # (or exactly zero) is still counted as selected.
    nonzero_locations = tf.scatter_nd(indices=indices, updates=tf.ones_like(values), shape=[flat_shape])
    # Split the flat vectors back into the original shapes.
    reshaped_tensor = [tf.reshape(t, s) for t, s in
                       zip(tf.split(scatter_tensor, sizes_list), shapes)]
    reshaped_nonzero = [tf.reshape(t, s) for t, s in
                        zip(tf.split(nonzero_locations, sizes_list), shapes)]
    return reshaped_tensor, reshaped_nonzero

  def initial_state(self):
    # NOTE: this scalar broadcasts in the first round, but update_state replaces
    # it with a vector of shape [num_params], so the state shape changes after
    # round one.
    return {self.ERROR_COMPENSATION_KEY: tf.constant(0, dtype=tf.float32)}

  def update_state(self, state, state_update_tensors):
    return {self.ERROR_COMPENSATION_KEY: state_update_tensors[self.ERROR_COMPENSATION_KEY]}

  def get_params(self, state):
    encode_params = {self.ERROR_COMPENSATION_KEY: state[self.ERROR_COMPENSATION_KEY]}
    decode_params = {}
    return encode_params, decode_params

  @property
  def name(self):
    return 'gradient_sparsification_encoding_stage'

  @property
  def compressible_tensors_keys(self):
    # A list of keys (empty: nothing may be compressed further), not a bool.
    return []

  @property
  def commutes_with_sum(self):
    return False

  @property
  def decode_needs_input_shape(self):
    return False

  @property
  def state_update_aggregation_modes(self):
    return {}
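The core of the global top-k step can be checked outside TFF. Below is a minimal NumPy sketch, assuming the model weights arrive as a list of arrays with arbitrary (different) shapes. The helpers `global_topk_encode` and `global_topk_decode` are hypothetical names for illustration; `np.argpartition` plays the role of `tf.math.top_k(..., sorted=False)`, the signed values are gathered at the selected indices, and the mask is rebuilt from the indices rather than from the values.

```python
import numpy as np


def global_topk_encode(tensors, error, fraction=0.1):
    """Global top-k over a list of arrays with arbitrary shapes.

    Returns the signed values and flat indices of the k largest-magnitude
    entries of (concatenated tensors + error), the original shapes, and the
    new error-feedback residual (everything that was not selected).
    """
    shapes = [t.shape for t in tensors]
    flat = np.concatenate([np.reshape(t, -1) for t in tensors]) + error
    # max(1, ...) guards against k == 0 for very small inputs.
    k = max(1, int(round(fraction * flat.size)))
    # Unordered indices of the k largest |entries| (like top_k with sorted=False).
    indices = np.argpartition(np.abs(flat), -k)[-k:]
    values = flat[indices]  # signed values, not their absolute values
    sparse = np.zeros_like(flat)
    sparse[indices] = values
    new_error = flat - sparse  # carried over to the next round
    return values, indices, shapes, new_error


def global_topk_decode(values, indices, shapes):
    """Scatter the values back and split them into the original shapes.

    Also returns 0/1 masks built from the indices, so a selected entry that
    is exactly zero still counts as selected.
    """
    sizes = [int(np.prod(s)) for s in shapes]
    total = sum(sizes)
    flat = np.zeros(total, dtype=values.dtype)
    flat[indices] = values
    mask = np.zeros(total, dtype=values.dtype)
    mask[indices] = 1.0
    split_points = np.cumsum(sizes)[:-1]  # boundaries between tensors
    dense = [c.reshape(s) for c, s in zip(np.split(flat, split_points), shapes)]
    masks = [c.reshape(s) for c, s in zip(np.split(mask, split_points), shapes)]
    return dense, masks


# Two "weights" with different shapes, as in a CNN.
a = np.array([[0.1, -5.0], [0.2, 0.3]])
b = np.array([4.0, -0.05, 0.0])
values, indices, shapes, err = global_topk_encode([a, b], np.zeros(7), fraction=0.3)
# Selected: -5.0 (in a) and 4.0 (in b); the remaining entries end up in err.
dense, masks = global_topk_decode(values, indices, shapes)
```

Because only the values and indices are transmitted, the mask derived from the indices always matches the selection, which would still hold if the selection later became randomized.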

I have run some simple tests manually, following the steps you outlined here, on page 45. It works, but I have some questions and problems:

  1. When I use a list of tensors of the same shape (e.g. two 2x25 tensors) as the input x to encode, it works without any issues, but when I use a list of tensors of different shapes (2x20 and 6x10), it gives an error saying:

InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [2,20] != values[1].shape = [6,10] [Op:Pack] name: packed

How can I resolve this issue? As I said, I want to use global top-k, so it is essential that I encode all trainable model weights at once. Take the CNN model used here: all of its tensors have different shapes.

  2. How can I do the averaging I described at the beginning? For example, here you have done:

mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn),  # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn),  # denominator
)

Is there a way to replicate this, with one output of decode going to the numerator and the other going to the denominator? How can I handle dividing 0 by 0? TensorFlow has a divide_no_nan function; can I use it somehow, or do I need to add an epsilon to each denominator?

  3. How is partitioning handled when I use encoders? Does each client get its own encoder holding its own state? As you discussed here, on page 6, client states are used in cross-silo settings, but what happens if the client ordering changes?

  4. Here you recommended using the stateful clients example. Can you explain this a bit further? Specifically, where exactly do the encoders go in run_one_round, and how are they combined with the client update and the aggregation?

  5. I have some additional information, such as the sparsity level, that I want to pass to encode. What is the recommended way to do that?
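The averaging asked about above (sum of the sparse gradients in the numerator, sum of the 0/1 masks in the denominator) can be sketched in NumPy, assuming each client's sparse update and mask have already been decoded into dense arrays of the same shape and stacked along a leading client axis. `masked_mean` is a hypothetical helper; its zero-denominator handling follows the same convention as `tf.math.divide_no_nan`, which returns 0 where the denominator is 0, so in this sketch no epsilon is needed.

```python
import numpy as np


def masked_mean(client_values, client_masks):
    """Per-position mean over only the clients that selected that position.

    client_values: array [num_clients, ...] of dense (mostly zero) updates.
    client_masks:  array of the same shape, 1 where a client sent a value.
    """
    numerator = np.sum(client_values, axis=0)   # sum of the sparse gradients
    denominator = np.sum(client_masks, axis=0)  # how many clients contributed
    # 0/0 -> 0, the same convention as tf.math.divide_no_nan(numerator, denominator).
    safe_denominator = np.where(denominator > 0, denominator, 1.0)
    return np.where(denominator > 0, numerator / safe_denominator, 0.0)


# 3 clients, 4 positions; positions 1 and 3 were selected by nobody.
values = np.array([[2.0, 0.0, 0.0, 0.0],
                   [4.0, 0.0, 1.0, 0.0],
                   [0.0, 0.0, 3.0, 0.0]])
masks = np.array([[1.0, 0.0, 0.0, 0.0],
                  [1.0, 0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0]])
averaged = masked_mean(values, masks)  # [3.0, 0.0, 2.0, 0.0]
```

A plain mean over all 3 clients would give about [2.0, 0.0, 1.33, 0.0] instead; dividing only once, after both sums, is what turns the denominator into a per-position client count.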



from Recent Questions - Stack Overflow https://ift.tt/32j2800
https://ift.tt/eA8V8J
