uncertainty
# train LeNet network with expected mean square error loss
def LeNet_EDL(logits2evidence=relu_evidence, loss_function=mse_loss, lmb=0.005):
    """Build a TF1 graph for a LeNet-style classifier trained with an evidential loss.

    Architecture:
        conv 5x5 (20) -> ReLU -> max_pool
        -> conv 5x5 (50) -> ReLU -> max_pool
        -> flatten -> fully connected (500) -> ReLU -> dropout
        -> fully connected (10 logits).

    The logits are mapped to evidence by ``logits2evidence``, and
    ``alpha = evidence + 1``. With ``S = sum(alpha)`` per sample, the expected
    class probabilities are ``alpha / S`` and the uncertainty is
    ``num_classes / S``.

    NOTE(review): ``var``, ``conv`` and ``max_pool`` are helpers defined
    elsewhere in the file; their padding/initialisation is not visible here.

    Args:
        logits2evidence: callable applied to the logits tensor to produce the
            evidence tensor.
        loss_function: callable invoked as
            ``loss_function(Y, alpha, global_step, annealing_step)``; its
            result is averaged with ``tf.reduce_mean`` to form the data loss.
        lmb: weight of the L2 penalty on the fully connected weights W3 and W4.

    Returns:
        Tuple ``(g, step, X, Y, annealing_step, keep_prob, prob, acc, loss, u,
        evidence, mean_ev, mean_ev_succ, mean_ev_fail)``:

        - ``g``: the graph.
        - ``step``: the Adam training op, which also increments ``global_step``.
        - ``X``: placeholder of 784 values per sample, reshaped to 28x28x1.
        - ``Y``: placeholder of 10 targets per sample; the argmax is taken as
          the true class, presumably one-hot.
        - ``annealing_step``: int placeholder passed to ``loss_function``.
        - ``keep_prob``: dropout keep-probability placeholder.
        - ``acc``: batch accuracy.
        - ``mean_ev``, ``mean_ev_succ``, ``mean_ev_fail``: mean total evidence
          over all, correctly classified and misclassified samples.
    """
    # Width of the output layer. It is also the class count K in u = K / S,
    # so both must stay in sync.
    num_classes = 10

    g = tf.Graph()
    with g.as_default():
        X = tf.placeholder(shape=[None, 28 * 28], dtype=tf.float32)
        Y = tf.placeholder(shape=[None, num_classes], dtype=tf.float32)
        keep_prob = tf.placeholder(dtype=tf.float32)
        global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
        annealing_step = tf.placeholder(dtype=tf.int32)

        # first hidden layer - conv
        W1 = var('W1', [5, 5, 1, 20])
        b1 = var('b1', [20])
        out1 = max_pool(tf.nn.relu(conv(tf.reshape(X, [-1, 28, 28, 1]), W1, strides=[1, 1, 1, 1]) + b1))

        # second hidden layer - conv
        W2 = var('W2', [5, 5, 20, 50])
        b2 = var('b2', [50])
        out2 = max_pool(tf.nn.relu(conv(out1, W2, strides=[1, 1, 1, 1]) + b2))

        # flatten the output
        Xflat = tf.contrib.layers.flatten(out2)

        # third hidden layer - fully connected
        W3 = var('W3', [Xflat.get_shape()[1].value, 500])
        b3 = var('b3', [500])
        out3 = tf.nn.relu(tf.matmul(Xflat, W3) + b3)
        out3 = tf.nn.dropout(out3, keep_prob=keep_prob)

        # output layer
        W4 = var('W4', [500, num_classes])
        b4 = var('b4', [num_classes])
        logits = tf.matmul(out3, W4) + b4

        evidence = logits2evidence(logits)
        alpha = evidence + 1
        # Per-sample sum of alpha, kept as a column so it broadcasts against alpha.
        S = tf.reduce_sum(alpha, axis=1, keepdims=True)
        u = num_classes / S  # uncertainty
        prob = alpha / S

        loss = tf.reduce_mean(loss_function(Y, alpha, global_step, annealing_step))
        l2_loss = (tf.nn.l2_loss(W3) + tf.nn.l2_loss(W4)) * lmb

        step = tf.train.AdamOptimizer().minimize(loss + l2_loss, global_step=global_step)

        # Calculate accuracy
        pred = tf.argmax(logits, 1)
        truth = tf.argmax(Y, 1)
        # 1.0 where the prediction is correct, 0.0 otherwise; shape (batch, 1).
        match = tf.reshape(tf.cast(tf.equal(pred, truth), tf.float32), (-1, 1))
        acc = tf.reduce_mean(match)

        total_evidence = tf.reduce_sum(evidence, 1, keepdims=True)
        mean_ev = tf.reduce_mean(total_evidence)
        # The epsilon is added once to each denominator, so a batch with no
        # correct (or no wrong) predictions yields 0 instead of NaN.
        mean_ev_succ = tf.reduce_sum(total_evidence * match) / (tf.reduce_sum(match) + 1e-20)
        mean_ev_fail = tf.reduce_sum(total_evidence * (1 - match)) / (tf.reduce_sum(1 - match) + 1e-20)

    return g, step, X, Y, annealing_step, keep_prob, prob, acc, loss, u, evidence, mean_ev, mean_ev_succ, mean_ev_fail