def _my_model_fn(features, labels, mode):
    """Model function for a 3-class classifier built on ``MultiClassHead``.

    Args:
        features: Input features passed through by the Estimator and fed
            directly to the Keras model (assumed to be whatever that model
            accepts -- TODO confirm against the input_fn).
        labels: Class labels, forwarded unchanged to the head.
        mode: The Estimator mode key (TRAIN / EVAL / PREDICT), forwarded to
            the head.

    Returns:
        The ``EstimatorSpec`` produced by ``MultiClassHead.create_estimator_spec``.
    """
    my_head = tf.estimator.MultiClassHead(n_classes=3)
    # Keep a handle on the Keras model (not just its output) so its variables
    # can be handed to the head below.
    model = tf.keras.Model(...)  # placeholder: build the real network here
    logits = model(features)
    return my_head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        # `learning_rate` is the documented keyword; `lr` is deprecated.
        optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1),
        # TF2 has no graph collections for the head to fall back on, so the
        # variables to optimize must be passed explicitly when a Keras
        # optimizer is used (otherwise TRAIN mode raises ValueError).
        trainable_variables=model.trainable_variables,
        logits=logits,
    )
# Build an Estimator whose training/evaluation/prediction graph comes from
# the model function defined above.
my_estimator = tf.estimator.Estimator(
    model_fn=_my_model_fn,
)