// Definition of the TreeNode class this file relies on (shown for reference only; it is not compiled here).
/*
public class TreeNode {
int val;
TreeNode left;
TreeNode right;
public TreeNode(int val) {
this.val = val;
left = null;
right = null;
}
}
*/
public class InsertAndRemove {
    /**
     * Inserts {@code val} into the binary search tree rooted at {@code root}.
     *
     * <p>If {@code val} is already present the tree is left unchanged, so trees built with this
     * method never contain duplicate values. The descent is iterative rather than recursive, so a
     * degenerate (list-shaped) tree cannot overflow the call stack.
     *
     * @param root root of the tree, or {@code null} for an empty tree
     * @param val value to insert
     * @return a new single-node tree when {@code root} is {@code null}; otherwise {@code root}
     */
    public TreeNode insert(TreeNode root, int val) {
        if (root == null) {
            return new TreeNode(val);
        }
        TreeNode curr = root;
        while (true) {
            if (val > curr.val) {
                if (curr.right == null) {
                    curr.right = new TreeNode(val);
                    return root;
                }
                curr = curr.right;
            } else if (val < curr.val) {
                if (curr.left == null) {
                    curr.left = new TreeNode(val);
                    return root;
                }
                curr = curr.left;
            } else {
                // Duplicate value: nothing to insert.
                return root;
            }
        }
    }

    /**
     * Returns the node holding the smallest value in the tree, i.e. its leftmost node.
     *
     * @param root root of the tree, may be {@code null}
     * @return the leftmost node, or {@code null} if {@code root} is {@code null}
     */
    public TreeNode minValueNode(TreeNode root) {
        TreeNode curr = root;
        while (curr != null && curr.left != null) {
            curr = curr.left;
        }
        return curr;
    }

    /**
     * Removes the node holding {@code val} from the binary search tree rooted at {@code root}.
     *
     * <p>If {@code val} is absent the tree is left unchanged. A node with two children is not
     * unlinked itself: it takes its in-order successor's value, and the successor node is
     * unlinked instead. The walk is iterative, so deep trees cannot overflow the call stack.
     *
     * @param root root of the tree, or {@code null} for an empty tree
     * @param val value to remove
     * @return the root of the resulting tree, which differs from {@code root} only when the
     *     root node itself (with at most one child) was removed
     */
    public TreeNode remove(TreeNode root, int val) {
        // Locate the target, remembering its parent so it can be unlinked in place.
        TreeNode parent = null;
        TreeNode target = root;
        while (target != null && target.val != val) {
            parent = target;
            target = (val > target.val) ? target.right : target.left;
        }
        if (target == null) {
            // Value not present (this also covers an empty tree).
            return root;
        }

        if (target.left != null && target.right != null) {
            // Two children: the in-order successor is the leftmost node of the right subtree.
            // It has no left child, so it can be spliced out by promoting its right child.
            TreeNode successorParent = target;
            TreeNode successor = target.right;
            while (successor.left != null) {
                successorParent = successor;
                successor = successor.left;
            }
            target.val = successor.val;
            if (successorParent == target) {
                successorParent.right = successor.right;
            } else {
                successorParent.left = successor.right;
            }
            return root;
        }

        // Zero or one child: replace the target with its only child (or null for a leaf).
        TreeNode child = (target.left != null) ? target.left : target.right;
        if (parent == null) {
            return child;
        }
        if (parent.left == target) {
            parent.left = child;
        } else {
            parent.right = child;
        }
        return root;
    }
}