用Java实现自己的ArrayList和LinkedList

前言

Java中的ArrayList和LinkedList都是我们很常用的数据结构,了解它们的内部实现原理可以让我们更好地使用它们。

代码实现

ArrayList

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Predicate;

/**
 * A minimal resizable-array implementation of {@link List}, modeled on {@code java.util.ArrayList}.
 *
 * <p>Not thread-safe. Iterators are not fail-fast, and {@link #subList(int, int)} returns an
 * independent copy rather than a view of this list.
 *
 * @param <E> element type
 */
public class MyArrayList<E> implements List<E> {

  /**
   * Backing array. Slots {@code [0, size)} hold the elements; slots at or beyond {@code size} are
   * kept {@code null} so removed elements can be garbage-collected.
   */
  private Object[] data;
  /**
   * Number of elements actually stored (not the length of {@code data}).
   */
  private int size;

  /** Creates an empty list with an initial capacity of 10. */
  public MyArrayList() {
    this(10);
  }

  /**
   * Creates an empty list with the given initial capacity.
   *
   * @param capacity initial length of the backing array
   * @throws NegativeArraySizeException if {@code capacity} is negative
   */
  public MyArrayList(int capacity) {
    data = new Object[capacity];
  }

  /**
   * Inserts an element at the given index, shifting later elements one slot to the right.
   *
   * @param index insertion position, {@code 0 <= index <= size()}
   * @param e element to insert
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public void add(int index, E e) {
    rangeCheckForAdd(index);
    ensureCapacityFor(size + 1);
    System.arraycopy(data, index, data, index + 1, size - index);
    data[index] = e;
    size++;
  }

  /**
   * Removes the element at the given index, shifting later elements one slot to the left.
   *
   * @param index index of the element to remove
   * @return the removed element
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public E remove(int index) {
    Objects.checkIndex(index, size);
    E oldValue = elementData(index);
    fastRemove(index);
    return oldValue;
  }

  /**
   * Returns the index of the first element equal to {@code o} (searching front to back).
   *
   * @param o element to search for; may be {@code null}
   * @return the index, or -1 if absent
   */
  @Override
  public int indexOf(Object o) {
    for (int i = 0; i < size; i++) {
      if (Objects.equals(data[i], o)) {
        return i;
      }
    }
    return -1;
  }

  /**
   * Returns the index of the last element equal to {@code o} (searching back to front).
   *
   * @param o element to search for; may be {@code null}
   * @return the index, or -1 if absent
   */
  @Override
  public int lastIndexOf(Object o) {
    for (int i = size - 1; i >= 0; i--) {
      if (Objects.equals(data[i], o)) {
        return i;
      }
    }
    return -1;
  }

  /**
   * Returns a list iterator positioned before the first element.
   *
   * @return list iterator
   */
  @Override
  public ListIterator<E> listIterator() {
    return new MyArrayListListIterator(0);
  }

  /**
   * Returns a list iterator whose first {@code next()} returns the element at {@code index}.
   *
   * @param index starting position, {@code 0 <= index <= size()}
   * @return list iterator
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public ListIterator<E> listIterator(int index) {
    rangeCheckForAdd(index);
    return new MyArrayListListIterator(index);
  }

  /**
   * Returns a copy of the elements in {@code [fromIndex, toIndex)}.
   *
   * <p>NOTE: unlike the {@link List#subList} contract this is an independent copy, not a view;
   * changes to it are not reflected in this list.
   *
   * @param fromIndex start index (inclusive)
   * @param toIndex end index (exclusive)
   * @return a new list holding the range
   * @throws IndexOutOfBoundsException if an index is out of range
   * @throws IllegalArgumentException if {@code fromIndex > toIndex}
   */
  @Override
  public List<E> subList(int fromIndex, int toIndex) {
    subListRangeCheck(fromIndex, toIndex, size);
    List<E> subList = new MyArrayList<>(toIndex - fromIndex);
    for (int i = fromIndex; i < toIndex; i++) {
      subList.add(elementData(i));
    }
    return subList;
  }

  /**
   * Appends an element to the end of the list.
   *
   * @param e element to append
   * @return always {@code true}
   */
  @Override
  public boolean add(E e) {
    add(size, e);
    return true;
  }

  /**
   * Removes the first element equal to {@code o}.
   *
   * @param o element to remove; may be {@code null}
   * @return {@code true} if an element was removed
   */
  @Override
  public boolean remove(Object o) {
    int index = indexOf(o);
    if (index > -1) {
      fastRemove(index);
      return true;
    }
    return false;
  }

  /**
   * Returns whether every element of {@code c} is contained in this list.
   *
   * @param c collection to check
   * @return {@code true} if all elements are present
   */
  @Override
  public boolean containsAll(Collection<?> c) {
    for (Object e : c) {
      if (!contains(e)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Appends all elements of {@code c}, in its iteration order.
   *
   * @param c elements to add
   * @return {@code true} if the list changed (i.e. {@code c} was not empty)
   */
  @Override
  public boolean addAll(Collection<? extends E> c) {
    return addAll(size, c);
  }

  /**
   * Inserts all elements of {@code c} starting at {@code index}.
   *
   * @param index insertion position, {@code 0 <= index <= size()}
   * @param c elements to add
   * @return {@code true} if the list changed (i.e. {@code c} was not empty)
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public boolean addAll(int index, Collection<? extends E> c) {
    rangeCheckForAdd(index);
    // Snapshot before shifting anything, so addAll(this) copies the original contents.
    Object[] elements = c.toArray();
    int count = elements.length;
    if (count == 0) {
      return false;
    }
    ensureCapacityFor(size + count);
    System.arraycopy(data, index, data, index + count, size - index);
    System.arraycopy(elements, 0, data, index, count);
    size += count;
    return true;
  }

  /**
   * Removes every element that is contained in {@code c}.
   *
   * @param c elements to remove
   * @return {@code true} if the list changed
   */
  @Override
  public boolean removeAll(Collection<?> c) {
    return batchRemove(item -> !c.contains(item));
  }

  /**
   * Keeps only the elements contained in {@code c} and removes the rest.
   *
   * @param c elements to keep
   * @return {@code true} if the list changed
   */
  @Override
  public boolean retainAll(Collection<?> c) {
    return batchRemove(c::contains);
  }

  /**
   * Removes every element matching {@code filter}.
   *
   * @param filter removal condition
   * @return {@code true} if the list changed
   */
  @Override
  public boolean removeIf(Predicate<? super E> filter) {
    return batchRemove(filter.negate());
  }

  /**
   * Removes all elements; the backing array keeps its current length.
   */
  @Override
  public void clear() {
    Arrays.fill(data, 0, size, null);
    size = 0;
  }

  /**
   * Replaces the element at the given index.
   *
   * @param index index to replace
   * @param e new element
   * @return the previous element
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public E set(int index, E e) {
    Objects.checkIndex(index, size);
    E oldValue = elementData(index);
    data[index] = e;
    return oldValue;
  }

  /**
   * Returns the element at the given index.
   *
   * @param index index to read
   * @return the element
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public E get(int index) {
    Objects.checkIndex(index, size);
    return elementData(index);
  }

  /**
   * Returns the number of elements.
   *
   * @return element count
   */
  @Override
  public int size() {
    return size;
  }

  /**
   * Returns whether the list has no elements.
   *
   * @return {@code true} if empty
   */
  @Override
  public boolean isEmpty() {
    return size == 0;
  }

  /**
   * Returns whether the list contains an element equal to {@code o}.
   *
   * @param o element to search for; may be {@code null}
   * @return {@code true} if present
   */
  @Override
  public boolean contains(Object o) {
    return indexOf(o) >= 0;
  }

  /**
   * Returns an iterator over the elements in order.
   *
   * @return iterator
   */
  @Override
  public Iterator<E> iterator() {
    return new MyArrayListIterator();
  }

  /**
   * Returns a new array containing the elements in order.
   *
   * @return array copy
   */
  @Override
  public Object[] toArray() {
    return Arrays.copyOf(data, size);
  }

  /**
   * Copies the elements into {@code a}, or into a new array of the same runtime type if
   * {@code a} is too small. If {@code a} has spare room, the slot after the last element is set to
   * {@code null}.
   *
   * @param a destination array
   * @param <T> array component type
   * @return the array holding the elements
   */
  @Override
  @SuppressWarnings("unchecked")
  public <T> T[] toArray(T[] a) {
    if (a.length < size) {
      return (T[]) Arrays.copyOf(data, size, a.getClass());
    }
    System.arraycopy(data, 0, a, 0, size);
    if (a.length > size) {
      a[size] = null;
    }
    return a;
  }

  @Override
  public String toString() {
    return Arrays.toString(toArray());
  }

  /** Removes the element at a pre-validated index and clears the vacated last slot. */
  private void fastRemove(int index) {
    int numMoved = size - index - 1;
    if (numMoved > 0) {
      System.arraycopy(data, index + 1, data, index, numMoved);
    }
    // The last occupied slot is size - 1; data[size] may not even exist when the array is full.
    data[--size] = null;
  }

  /**
   * Compacts the list, keeping only elements for which {@code keep} returns true.
   *
   * @return {@code true} if at least one element was removed
   */
  private boolean batchRemove(Predicate<? super E> keep) {
    int low = 0;
    for (int high = 0; high < size; high++) {
      if (keep.test(elementData(high))) {
        data[low++] = data[high];
      }
    }
    if (low == size) {
      return false;
    }
    Arrays.fill(data, low, size, null);
    size = low;
    return true;
  }

  /** Grows the backing array if it cannot hold {@code minCapacity} elements. */
  private void ensureCapacityFor(int minCapacity) {
    int oldCapacity = data.length;
    if (minCapacity <= oldCapacity) {
      return;
    }
    // Grow by ~1.5x. The max() is needed for capacities 0 and 1, where 1.5x growth would not
    // increase the length at all, and for bulk inserts larger than one growth step.
    int newCapacity = Math.max(oldCapacity + (oldCapacity >> 1), minCapacity);
    resize(newCapacity);
  }

  /** Replaces the backing array with one of length {@code newCapacity}, preserving elements. */
  private void resize(int newCapacity) {
    Object[] newData = new Object[newCapacity];
    System.arraycopy(data, 0, newData, 0, size);
    data = newData;
  }

  /** Returns {@code data[index]} cast to the element type (only E values are ever stored). */
  @SuppressWarnings("unchecked")
  private E elementData(int index) {
    return (E) data[index];
  }

  /** Validates an insertion position, which may equal {@code size}. */
  private void rangeCheckForAdd(int index) {
    if (index > size || index < 0) {
      throw new IndexOutOfBoundsException(outOfBoundsMsg(index));
    }
  }

  /** Validates a {@code [fromIndex, toIndex)} range against {@code size}. */
  private void subListRangeCheck(int fromIndex, int toIndex, int size) {
    if (fromIndex < 0) {
      throw new IndexOutOfBoundsException("fromIndex = " + fromIndex);
    }
    if (toIndex > size) {
      throw new IndexOutOfBoundsException("toIndex = " + toIndex);
    }
    if (fromIndex > toIndex) {
      throw new IllegalArgumentException("fromIndex(" + fromIndex +
          ") > toIndex(" + toIndex + ")");
    }
  }

  private String outOfBoundsMsg(int index) {
    return "Index: " + index + ", Size: " + size;
  }

  /** Forward iterator; not fail-fast. */
  private class MyArrayListIterator implements Iterator<E> {

    /** Index of the element the next call to {@code next()} returns. */
    int cursor;
    /** Index of the element last returned, or -1 if none (or it was already removed/added). */
    int lastRet = -1;

    @Override
    public boolean hasNext() {
      return cursor < size;
    }

    @Override
    public E next() {
      if (cursor >= size) {
        throw new NoSuchElementException();
      }
      lastRet = cursor;
      return elementData(cursor++);
    }

    @Override
    public void remove() {
      if (lastRet < 0) {
        throw new IllegalStateException();
      }
      MyArrayList.this.remove(lastRet);
      // Elements after lastRet shifted left, so the next element now lives at lastRet.
      cursor = lastRet;
      lastRet = -1;
    }
  }

  /** Bidirectional list iterator; not fail-fast. */
  private class MyArrayListListIterator extends MyArrayListIterator implements ListIterator<E> {

    MyArrayListListIterator(int index) {
      cursor = index;
    }

    @Override
    public boolean hasPrevious() {
      return cursor > 0;
    }

    @Override
    public E previous() {
      if (cursor <= 0) {
        throw new NoSuchElementException();
      }
      lastRet = --cursor;
      return elementData(lastRet);
    }

    @Override
    public int nextIndex() {
      return cursor;
    }

    @Override
    public int previousIndex() {
      return cursor - 1;
    }

    @Override
    public void set(E e) {
      if (lastRet < 0) {
        throw new IllegalStateException();
      }
      MyArrayList.this.set(lastRet, e);
    }

    @Override
    public void add(E e) {
      // Insert before the implicit cursor and step past it, per the ListIterator contract.
      MyArrayList.this.add(cursor++, e);
      lastRet = -1;
    }
  }
}

LinkedList

/**
 * A minimal doubly-linked implementation of {@link List}, modeled on {@code java.util.LinkedList}.
 *
 * <p>Not thread-safe. Iterators are not fail-fast, and {@link #subList(int, int)} returns an
 * independent copy rather than a view of this list.
 *
 * @param <E> element type
 */
public class MyLinkedList<E> implements List<E> {

  /**
   * Sentinel head that never holds data; the first real node is {@code dummyHead.next}. The
   * {@code prev} link of the first real node points back to this sentinel.
   */
  private final Node<E> dummyHead;
  /**
   * Last real node, or {@code null} when the list is empty.
   */
  private Node<E> tail;
  /**
   * Number of elements.
   */
  private int size;

  /** Creates an empty list. */
  public MyLinkedList() {
    dummyHead = new Node<>(null, null, null);
  }

  /**
   * Inserts an element at the given index.
   *
   * @param index insertion position, {@code 0 <= index <= size()}
   * @param e element to insert
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public void add(int index, E e) {
    rangeCheckForAdd(index);
    linkAfter(nodeBefore(index), e);
  }

  /**
   * Removes the element at the given index.
   *
   * @param index index of the element to remove
   * @return the removed element
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public E remove(int index) {
    Objects.checkIndex(index, size);
    Node<E> node = node(index);
    E data = node.data;
    fastRemove(node);
    return data;
  }

  /**
   * Returns the index of the first element equal to {@code o}.
   *
   * @param o element to search for; may be {@code null}
   * @return the index, or -1 if absent
   */
  @Override
  public int indexOf(Object o) {
    Node<E> cur = dummyHead.next;
    for (int index = 0; cur != null; index++, cur = cur.next) {
      if (Objects.equals(cur.data, o)) {
        return index;
      }
    }
    return -1;
  }

  /**
   * Returns the index of the last element equal to {@code o}.
   *
   * @param o element to search for; may be {@code null}
   * @return the index, or -1 if absent
   */
  @Override
  public int lastIndexOf(Object o) {
    Node<E> cur = tail;
    // Bounded by index so the walk never inspects the data-less sentinel head.
    for (int index = size - 1; index >= 0; index--, cur = cur.prev) {
      if (Objects.equals(cur.data, o)) {
        return index;
      }
    }
    return -1;
  }

  /**
   * Returns a list iterator positioned before the first element.
   *
   * @return list iterator
   */
  @Override
  public ListIterator<E> listIterator() {
    return listIterator(0);
  }

  /**
   * Returns a list iterator whose first {@code next()} returns the element at {@code index}.
   *
   * @param index starting position, {@code 0 <= index <= size()}
   * @return list iterator
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public ListIterator<E> listIterator(int index) {
    return new MyLinkedListListIterator(index);
  }

  /**
   * Returns a copy of the elements in {@code [fromIndex, toIndex)}.
   *
   * <p>NOTE: unlike the {@link List#subList} contract this is an independent copy, not a view;
   * changes to it are not reflected in this list.
   *
   * @param fromIndex start index (inclusive)
   * @param toIndex end index (exclusive)
   * @return a new list holding the range
   * @throws IndexOutOfBoundsException if an index is out of range
   * @throws IllegalArgumentException if {@code fromIndex > toIndex}
   */
  @Override
  public List<E> subList(int fromIndex, int toIndex) {
    subListRangeCheck(fromIndex, toIndex, size);
    List<E> subList = new MyLinkedList<>();
    Node<E> cur = fromIndex < toIndex ? node(fromIndex) : null;
    for (int i = fromIndex; i < toIndex; i++) {
      subList.add(cur.data);
      cur = cur.next;
    }
    return subList;
  }

  /**
   * Appends an element to the end of the list.
   *
   * @param e element to append
   * @return always {@code true}
   */
  @Override
  public boolean add(E e) {
    add(size, e);
    return true;
  }

  /**
   * Removes the first element equal to {@code o}.
   *
   * @param o element to remove; may be {@code null}
   * @return {@code true} if an element was removed
   */
  @Override
  public boolean remove(Object o) {
    Node<E> cur = dummyHead.next;
    while (cur != null) {
      if (Objects.equals(cur.data, o)) {
        fastRemove(cur);
        return true;
      }
      cur = cur.next;
    }
    return false;
  }

  /**
   * Returns whether every element of {@code c} is contained in this list.
   *
   * @param c collection to check
   * @return {@code true} if all elements are present
   */
  @Override
  public boolean containsAll(Collection<?> c) {
    for (Object e : c) {
      if (!contains(e)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Appends all elements of {@code c}, in its iteration order.
   *
   * @param c elements to add
   * @return {@code true} if the list changed (i.e. {@code c} was not empty)
   */
  @Override
  public boolean addAll(Collection<? extends E> c) {
    return addAll(size, c);
  }

  /**
   * Inserts all elements of {@code c} starting at {@code index}.
   *
   * @param index insertion position, {@code 0 <= index <= size()}
   * @param c elements to add
   * @return {@code true} if the list changed (i.e. {@code c} was not empty)
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public boolean addAll(int index, Collection<? extends E> c) {
    rangeCheckForAdd(index);
    // Snapshot first so addAll(this) copies the original contents.
    Object[] elements = c.toArray();
    if (elements.length == 0) {
      return false;
    }
    Node<E> prev = nodeBefore(index);
    for (Object element : elements) {
      @SuppressWarnings("unchecked")
      E e = (E) element;
      prev = linkAfter(prev, e);
    }
    return true;
  }

  /**
   * Removes every element that is contained in {@code c}.
   *
   * @param c elements to remove
   * @return {@code true} if the list changed
   */
  @Override
  public boolean removeAll(Collection<?> c) {
    return removeIf(c::contains);
  }

  /**
   * Keeps only the elements contained in {@code c} and removes the rest.
   *
   * @param c elements to keep
   * @return {@code true} if the list changed
   */
  @Override
  public boolean retainAll(Collection<?> c) {
    return removeIf(o -> !c.contains(o));
  }

  /**
   * Removes every element matching {@code filter} in a single pass over the nodes.
   *
   * @param filter removal condition
   * @return {@code true} if the list changed
   */
  @Override
  public boolean removeIf(Predicate<? super E> filter) {
    Objects.requireNonNull(filter);
    boolean removed = false;
    Node<E> cur = dummyHead.next;
    while (cur != null) {
      // Capture the successor first: fastRemove clears the removed node's links.
      Node<E> next = cur.next;
      if (filter.test(cur.data)) {
        fastRemove(cur);
        removed = true;
      }
      cur = next;
    }
    return removed;
  }

  /**
   * Removes all elements, clearing every node's links to help garbage collection.
   */
  @Override
  public void clear() {
    Node<E> cur = dummyHead.next;
    while (cur != null) {
      Node<E> next = cur.next;
      cur.data = null;
      cur.prev = null;
      cur.next = null;
      cur = next;
    }
    dummyHead.next = null;
    tail = null;
    size = 0;
  }

  /**
   * Replaces the element at the given index.
   *
   * @param index index to replace
   * @param e new element
   * @return the previous element
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public E set(int index, E e) {
    Objects.checkIndex(index, size);
    Node<E> node = node(index);
    E oldValue = node.data;
    node.data = e;
    return oldValue;
  }

  /**
   * Returns the element at the given index.
   *
   * @param index index to read
   * @return the element
   * @throws IndexOutOfBoundsException if {@code index} is out of range
   */
  @Override
  public E get(int index) {
    Objects.checkIndex(index, size);
    return node(index).data;
  }

  /**
   * Returns the number of elements.
   *
   * @return element count
   */
  @Override
  public int size() {
    return size;
  }

  /**
   * Returns whether the list has no elements.
   *
   * @return {@code true} if empty
   */
  @Override
  public boolean isEmpty() {
    return size == 0;
  }

  /**
   * Returns whether the list contains an element equal to {@code o}.
   *
   * @param o element to search for; may be {@code null}
   * @return {@code true} if present
   */
  @Override
  public boolean contains(Object o) {
    return indexOf(o) >= 0;
  }

  /**
   * Returns an iterator over the elements in order.
   *
   * @return iterator
   */
  @Override
  public Iterator<E> iterator() {
    return listIterator();
  }

  /**
   * Returns a new array containing the elements in order.
   *
   * @return array copy
   */
  @Override
  public Object[] toArray() {
    Object[] res = new Object[size];
    Node<E> cur = dummyHead.next;
    for (int index = 0; index < size; index++) {
      res[index] = cur.data;
      cur = cur.next;
    }
    return res;
  }

  /**
   * Copies the elements into {@code a}, or into a new array of the same runtime type if
   * {@code a} is too small. If {@code a} has spare room, the slot after the last element is set to
   * {@code null}.
   *
   * @param a destination array
   * @param <T> array component type
   * @return the array holding the elements
   */
  @Override
  @SuppressWarnings("unchecked")
  public <T> T[] toArray(T[] a) {
    Object[] elements = toArray();
    if (a.length < size) {
      return (T[]) Arrays.copyOf(elements, size, a.getClass());
    }
    System.arraycopy(elements, 0, a, 0, size);
    if (a.length > size) {
      a[size] = null;
    }
    return a;
  }

  @Override
  public String toString() {
    return Arrays.toString(toArray());
  }

  /**
   * Returns the real node at a pre-validated index {@code 0 <= index < size}, walking from
   * whichever end is closer.
   */
  private Node<E> node(int index) {
    if (index <= (size >> 1)) {
      Node<E> cur = dummyHead.next;
      for (int i = 0; i < index; i++) {
        cur = cur.next;
      }
      return cur;
    } else {
      Node<E> cur = tail;
      for (int i = size - 1; i > index; i--) {
        cur = cur.prev;
      }
      return cur;
    }
  }

  /** Returns the node after which an element at {@code index} is inserted (sentinel for 0). */
  private Node<E> nodeBefore(int index) {
    return index == 0 ? dummyHead : node(index - 1);
  }

  /** Links a new node holding {@code e} directly after {@code prev} and returns the new node. */
  private Node<E> linkAfter(Node<E> prev, E e) {
    Node<E> next = prev.next;
    Node<E> newNode = new Node<>(e, prev, next);
    prev.next = newNode;
    if (next == null) {
      tail = newNode;
    } else {
      next.prev = newNode;
    }
    size++;
    return newNode;
  }

  /** Unlinks a real (non-sentinel) node, repairing both neighbours' links, and clears it. */
  private void fastRemove(Node<E> node) {
    Node<E> prev = node.prev;
    Node<E> next = node.next;
    prev.next = next;
    if (next == null) {
      tail = (prev == dummyHead) ? null : prev;
    } else {
      next.prev = prev;
    }
    node.data = null;
    node.prev = null;
    node.next = null;
    size--;
  }

  private String outOfBoundsMsg(int index) {
    return "Index: " + index + ", Size: " + size;
  }

  /** Validates an insertion position, which may equal {@code size}. */
  private void rangeCheckForAdd(int index) {
    if (index > size || index < 0) {
      throw new IndexOutOfBoundsException(outOfBoundsMsg(index));
    }
  }

  /** Validates a {@code [fromIndex, toIndex)} range against {@code size}. */
  private void subListRangeCheck(int fromIndex, int toIndex, int size) {
    if (fromIndex < 0) {
      throw new IndexOutOfBoundsException("fromIndex = " + fromIndex);
    }
    if (toIndex > size) {
      throw new IndexOutOfBoundsException("toIndex = " + toIndex);
    }
    if (fromIndex > toIndex) {
      throw new IllegalArgumentException("fromIndex(" + fromIndex +
          ") > toIndex(" + toIndex + ")");
    }
  }

  /** A list node holding one element and links to its neighbours. */
  private static class Node<E> {

    E data;
    Node<E> prev;
    Node<E> next;

    Node(E data, Node<E> prev, Node<E> next) {
      this.data = data;
      this.prev = prev;
      this.next = next;
    }
  }

  /** Bidirectional list iterator; not fail-fast. */
  private class MyLinkedListListIterator implements ListIterator<E> {

    /** Node the next call to {@code next()} returns; {@code null} when positioned at the end. */
    private Node<E> nextNode;
    /** Node last returned by next()/previous(); {@code null} if none or after remove()/add(). */
    private Node<E> lastReturned;
    /** Index of {@code nextNode}. */
    private int cursor;

    MyLinkedListListIterator(int index) {
      rangeCheckForAdd(index);
      nextNode = (index == size) ? null : node(index);
      cursor = index;
    }

    @Override
    public boolean hasNext() {
      return cursor < size;
    }

    @Override
    public E next() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      lastReturned = nextNode;
      nextNode = nextNode.next;
      cursor++;
      return lastReturned.data;
    }

    @Override
    public void remove() {
      if (lastReturned == null) {
        throw new IllegalStateException();
      }
      Node<E> lastNext = lastReturned.next;
      fastRemove(lastReturned);
      if (nextNode == lastReturned) {
        // Came from previous(): the cursor already sits before the removed node.
        nextNode = lastNext;
      } else {
        cursor--;
      }
      lastReturned = null;
    }

    @Override
    public boolean hasPrevious() {
      return cursor > 0;
    }

    @Override
    public E previous() {
      if (!hasPrevious()) {
        throw new NoSuchElementException();
      }
      nextNode = (nextNode == null) ? tail : nextNode.prev;
      lastReturned = nextNode;
      cursor--;
      return lastReturned.data;
    }

    @Override
    public int nextIndex() {
      return cursor;
    }

    @Override
    public int previousIndex() {
      return cursor - 1;
    }

    @Override
    public void set(E e) {
      if (lastReturned == null) {
        throw new IllegalStateException();
      }
      lastReturned.data = e;
    }

    @Override
    public void add(E e) {
      lastReturned = null;
      Node<E> prev;
      if (nextNode == null) {
        prev = (tail == null) ? dummyHead : tail;
      } else {
        prev = nextNode.prev;
      }
      linkAfter(prev, e);
      cursor++;
    }
  }
}

总结

本实现参考了JDK中ArrayList和LinkedList的实现,主要实现了增删改查、扩容等功能,待完善的地方有:

  • subList():目前返回的是一个独立的副本,而List接口要求返回原列表的视图(对子列表的修改应反映到原列表中),二者不符
  • 迭代器不支持快速失败(fail-fast):遍历过程中若通过其他方式修改了列表,迭代器无法检测到

查看源码可以让我们使用工具更加得心应手。
原文地址:https://www.cnblogs.com/strongmore/p/14195487.html