// Definition of TreeNode in Java (for reference only; commented out here)
/*
public class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
 
    public TreeNode(int val) {
        this.val = val; 
        left = null;
        right = null; 
    }
}
*/
 
/**
 * Insertion and removal operations on a binary search tree built from {@code TreeNode}s.
 *
 * <p>Smaller values go to the left subtree and larger values to the right; duplicate values are
 * never stored. Both mutating operations walk the tree iteratively, so stack usage stays constant
 * even for degenerate (list-shaped) trees, e.g. ones built from sorted input, where a recursive
 * walk would be as deep as the tree is tall.
 */
public class InsertAndRemove {

    /**
     * Inserts {@code val} into the BST rooted at {@code root}.
     *
     * @param root root of the tree; may be {@code null} for an empty tree
     * @param val value to insert; if it is already present the tree is left unchanged
     * @return the root of the resulting tree (a new single-node tree if {@code root} was null)
     */
    public TreeNode insert(TreeNode root, int val) {
        if (root == null) {
            return new TreeNode(val);
        }

        TreeNode curr = root;
        while (true) {
            if (val > curr.val) {
                if (curr.right == null) {
                    curr.right = new TreeNode(val);
                    return root;
                }
                curr = curr.right;
            } else if (val < curr.val) {
                if (curr.left == null) {
                    curr.left = new TreeNode(val);
                    return root;
                }
                curr = curr.left;
            } else {
                // Duplicate value: keys are kept unique, so the tree is unchanged.
                return root;
            }
        }
    }

    /**
     * Returns the node holding the minimum value of the BST, i.e. its leftmost node.
     *
     * @param root root of the tree; may be {@code null}
     * @return the leftmost node, or {@code null} if {@code root} is {@code null}
     */
    public TreeNode minValueNode(TreeNode root) {
        TreeNode curr = root;
        while (curr != null && curr.left != null) {
            curr = curr.left;
        }
        return curr;
    }

    /**
     * Removes the node holding {@code val} from the BST rooted at {@code root}.
     *
     * <p>A node with at most one child is replaced by that child. A node with two children takes
     * the value of its in-order successor (the minimum of its right subtree), and the successor
     * node, which has no left child, is spliced out of the tree.
     *
     * @param root root of the tree; may be {@code null}
     * @param val value to remove; if absent the tree is left unchanged
     * @return the root of the resulting tree ({@code null} if the tree becomes empty)
     */
    public TreeNode remove(TreeNode root, int val) {
        // Locate the target node, remembering its parent so it can be unlinked.
        TreeNode parent = null;
        TreeNode target = root;
        while (target != null && target.val != val) {
            parent = target;
            target = val > target.val ? target.right : target.left;
        }
        if (target == null) {
            return root; // value not present (or empty tree)
        }

        if (target.left != null && target.right != null) {
            // Two children: find the in-order successor and its parent.
            TreeNode succParent = target;
            TreeNode succ = target.right;
            while (succ.left != null) {
                succParent = succ;
                succ = succ.left;
            }
            // Copy the successor's value up, then unlink the successor (it has no left child).
            target.val = succ.val;
            if (succParent == target) {
                succParent.right = succ.right;
            } else {
                succParent.left = succ.right;
            }
            return root;
        }

        // Zero or one child: replace the target with its only child (or null).
        TreeNode child = (target.left == null) ? target.right : target.left;
        if (parent == null) {
            return child; // the root itself was removed
        }
        if (parent.left == target) {
            parent.left = child;
        } else {
            parent.right = child;
        }
        return root;
    }
}