package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingResult;
import com.knuddels.jtokkit.api.GptBytePairEncodingParams;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Byte pair encoding (BPE) {@link Encoding} configured from a {@link GptBytePairEncodingParams}.
 *
 * <p>Input text is split into pieces with the configured regex {@link Pattern}. Each piece is
 * converted to bytes; a piece that is itself a known token is emitted directly, otherwise it is
 * split into tokens by {@link #bytePairMerge(ImmutableByteArray)}. Special tokens can be decoded
 * but are rejected when encoding.
 */
final class GptBytePairEncoding implements Encoding {
    private final TokenEncoder<ImmutableByteArray, Integer> encoder;
    private final String name;
    private final Pattern pattern;
    private final TokenEncoder<String, Integer> specialTokensEncoder;

    /** One part boundary inside a piece being merged, plus the rank of its candidate merge. */
    private static final class PieceIndexToRank {
        /** Byte offset into the piece at which this part starts. */
        private final int index;
        /**
         * Rank (token id) of the bytes from this part's start to the start of the part after the
         * next one, or {@link Integer#MAX_VALUE} if that byte sequence is not a known token.
         */
        private int rank;

        PieceIndexToRank(int index, int rank) {
            this.index = index;
            this.rank = rank;
        }
    }

    /**
     * Creates an encoding from the given parameters.
     *
     * @param params source of the name, split pattern, ordinary encoder and special tokens
     */
    public GptBytePairEncoding(GptBytePairEncodingParams params) {
        this.name = params.getName();
        this.pattern = params.getPattern();
        // Encoder keys arrive as byte[] and are wrapped in ImmutableByteArray (presumably for
        // value-based equals/hashCode in map lookups -- confirm in ImmutableByteArray).
        this.encoder = new TokenEncoder<>(params.getEncoder(), (byte[] bytes) -> ImmutableByteArray.from(bytes));
        this.specialTokensEncoder = new TokenEncoder<>(params.getSpecialTokensEncoder());
    }

    /**
     * Appends {@code tokens} to {@code out} without letting {@code out} exceed
     * {@code maxTokenCount} entries.
     *
     * @param out           destination list
     * @param tokens        tokens to append, in order
     * @param maxTokenCount maximum total size of {@code out}, or {@code null} for no limit
     * @return the number of tokens actually appended
     */
    private int addTokens(List<Integer> out, List<Integer> tokens, Integer maxTokenCount) {
        if (maxTokenCount == null) {
            out.addAll(tokens);
            return tokens.size();
        }
        // Clamp at 0 so an already-full list simply appends nothing instead of making subList throw.
        int remaining = Math.max(0, maxTokenCount - out.size());
        List<Integer> toAdd = tokens.subList(0, Math.min(remaining, tokens.size()));
        out.addAll(toAdd);
        return toAdd.size();
    }

    /**
     * Splits a piece into tokens by repeatedly merging the adjacent pair of parts whose
     * concatenation has the lowest rank in the encoder, until no adjacent concatenation is a
     * known token.
     *
     * <p>Each merged part is expected to be present in the encoder; {@code encoder.encode} is
     * called on every final part.
     *
     * @param piece the bytes of one regex match
     * @return the token ids of the final parts, in order
     */
    private List<Integer> bytePairMerge(ImmutableByteArray piece) {
        // Initially every byte is its own part; parts.get(k).index is the k-th boundary
        // (length + 1 boundaries in total, the last one marking the end of the piece).
        List<PieceIndexToRank> parts = new ArrayList<>(piece.length() + 1);
        for (int i = 0; i < piece.length() + 1; i++) {
            parts.add(new PieceIndexToRank(i, Integer.MAX_VALUE));
        }
        for (int i = 0; i < parts.size() - 2; i++) {
            Optional<Integer> rank = getRank(piece, parts, i, 0);
            if (rank.isPresent()) {
                parts.get(i).rank = rank.get();
            }
        }

        while (parts.size() > 1) {
            // Find the lowest-ranked merge; ties resolve to the leftmost candidate.
            int minRank = Integer.MAX_VALUE;
            int minRankIndex = 0;
            for (int i = 0; i < parts.size() - 1; i++) {
                int rank = parts.get(i).rank;
                if (rank < minRank) {
                    minRank = rank;
                    minRankIndex = i;
                }
            }
            if (minRank == Integer.MAX_VALUE) {
                break;
            }
            // Re-rank the merged part and its left neighbour with skip = 1, i.e. as if the
            // boundary at minRankIndex + 1 had already been removed; then remove it.
            parts.get(minRankIndex).rank =
                    getRank(piece, parts, minRankIndex, 1).orElse(Integer.MAX_VALUE);
            if (minRankIndex > 0) {
                int left = minRankIndex - 1;
                parts.get(left).rank = getRank(piece, parts, left, 1).orElse(Integer.MAX_VALUE);
            }
            parts.remove(minRankIndex + 1);
        }

        List<Integer> tokens = new ArrayList<>(parts.size());
        for (int i = 0; i < parts.size() - 1; i++) {
            tokens.add(encoder.encode(piece.getBytesBetween(parts.get(i).index, parts.get(i + 1).index)));
        }
        return tokens;
    }

    /**
     * Returns the bytes of a single token, looking in the ordinary encoder first and then in
     * the special tokens (encoded as UTF-8).
     *
     * @param token the token id
     * @return the token's bytes
     * @throws IllegalArgumentException if the token is in neither encoder
     */
    private byte[] decodeToken(int token) {
        Optional<ImmutableByteArray> ordinary = encoder.decodeIfPresent(token);
        if (ordinary.isPresent()) {
            return ordinary.get().getRawArray();
        }
        Optional<String> special = specialTokensEncoder.decodeIfPresent(token);
        if (special.isPresent()) {
            return special.get().getBytes(StandardCharsets.UTF_8);
        }
        throw new IllegalArgumentException("Unknown token for decoding: " + token);
    }

    /**
     * Encodes text after verifying that it contains no special token.
     *
     * @param text          text to encode; {@code null} yields an empty, non-truncated result
     * @param maxTokenCount token limit, or {@code null} for no limit
     * @return the encoding result
     * @throws UnsupportedOperationException if {@code text} contains any special token
     */
    private EncodingResult encodeInternal(String text, Integer maxTokenCount) {
        if (text == null) {
            return new EncodingResult(Collections.emptyList(), false);
        }
        for (String specialToken : specialTokensEncoder.getDecodedTokens()) {
            if (text.contains(specialToken)) {
                throw new UnsupportedOperationException("Encoding special tokens is not supported yet.");
            }
        }
        return encodeOrdinaryInternal(text, maxTokenCount);
    }

    /**
     * Encodes text treating every substring as ordinary text (special tokens are not checked).
     *
     * <p>With a limit, encoding stops once {@code maxTokenCount} tokens are produced. The
     * trailing tokens are then dropped one at a time until the decoded tokens form a prefix of
     * {@code text}. This discards, e.g., a cut that ends partway through a multi-byte UTF-8
     * character, which would decode to a replacement character.
     *
     * @param text          text to encode; {@code null} yields an empty, non-truncated result
     * @param maxTokenCount token limit, or {@code null} for no limit
     * @return the tokens, flagged as truncated when their decoding is shorter than {@code text}
     */
    private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokenCount) {
        if (text == null) {
            return new EncodingResult(Collections.emptyList(), false);
        }
        List<Integer> out = new ArrayList<>();
        Matcher matcher = pattern.matcher(text);
        int tokenCount = 0;
        // Check the limit before find() so no regex search is wasted once the limit is reached.
        while (maxTokenCountNotReached(maxTokenCount, tokenCount) && matcher.find()) {
            ImmutableByteArray match = ImmutableByteArray.from(matcher.group());
            if (encoder.containsDecodedToken(match)) {
                out.add(encoder.encode(match));
                tokenCount++;
            } else {
                tokenCount += addTokens(out, bytePairMerge(match), maxTokenCount);
            }
        }

        if (maxTokenCount != null) {
            for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) {
                List<Integer> tokens = out.subList(0, out.size() - tokensToRemove);
                String decoded = decode(tokens);
                if (text.startsWith(decoded)) {
                    return new EncodingResult(tokens, text.length() > decoded.length());
                }
            }
        }
        return new EncodingResult(out, false);
    }

    /**
     * Looks up the rank of the bytes from boundary {@code i} to boundary {@code i + skip + 2}.
     *
     * @param piece the piece being merged
     * @param parts current part boundaries
     * @param i     index of the starting boundary
     * @param skip  number of extra boundaries to skip (1 when re-ranking around a pending merge)
     * @return the token id of that byte range, or empty if the end boundary does not exist or
     *         the bytes are not a known token
     */
    private Optional<Integer> getRank(ImmutableByteArray piece, List<PieceIndexToRank> parts, int i, int skip) {
        int end = i + skip + 2;
        if (end >= parts.size()) {
            return Optional.empty();
        }
        return encoder.encodeIfPresent(piece.getBytesBetween(parts.get(i).index, parts.get(end).index));
    }

    /** Returns {@code true} while {@code tokenCount} is below the (optional) limit. */
    private boolean maxTokenCountNotReached(Integer maxTokenCount, int tokenCount) {
        return !maxTokenCountReached(maxTokenCount, tokenCount);
    }

    /** Returns {@code true} if a limit is set and {@code tokenCount} has reached it. */
    private boolean maxTokenCountReached(Integer maxTokenCount, int tokenCount) {
        return maxTokenCount != null && maxTokenCount <= tokenCount;
    }

    /**
     * Returns the number of tokens {@link #encode(String)} produces for {@code text}.
     *
     * @throws UnsupportedOperationException if {@code text} contains a special token
     */
    @Override
    public int countTokens(String text) {
        return encode(text).size();
    }

    /** Returns the number of tokens {@link #encodeOrdinary(String)} produces for {@code text}. */
    @Override
    public int countTokensOrdinary(String text) {
        return encodeOrdinary(text).size();
    }

    /**
     * Decodes tokens to a string by interpreting their concatenated bytes as UTF-8.
     *
     * @throws IllegalArgumentException if any token is unknown
     */
    @Override
    public String decode(List<Integer> tokens) {
        return new String(decodeBytes(tokens), StandardCharsets.UTF_8);
    }

    /**
     * Concatenates the bytes of every token in order.
     *
     * @throws IllegalArgumentException if any token is unknown
     */
    @Override
    public byte[] decodeBytes(List<Integer> tokens) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (int token : tokens) {
            byte[] bytes = decodeToken(token);
            // write(byte[], int, int) declares no IOException, unlike OutputStream.write(byte[]).
            out.write(bytes, 0, bytes.length);
        }
        return out.toByteArray();
    }

    /**
     * Encodes {@code text} with at most {@code maxTokens} tokens.
     *
     * @throws UnsupportedOperationException if {@code text} contains a special token
     */
    @Override
    public EncodingResult encode(String text, int maxTokens) {
        return encodeInternal(text, maxTokens);
    }

    /**
     * Encodes {@code text} without a token limit; {@code null} yields an empty list.
     *
     * @throws UnsupportedOperationException if {@code text} contains a special token
     */
    @Override
    public List<Integer> encode(String text) {
        return encodeInternal(text, null).getTokens();
    }

    /** Encodes {@code text} as ordinary text with at most {@code maxTokens} tokens. */
    @Override
    public EncodingResult encodeOrdinary(String text, int maxTokens) {
        return encodeOrdinaryInternal(text, maxTokens);
    }

    /** Encodes {@code text} as ordinary text without a token limit. */
    @Override
    public List<Integer> encodeOrdinary(String text) {
        return encodeOrdinaryInternal(text, null).getTokens();
    }

    /** Returns the name given by the construction parameters. */
    @Override
    public String getName() {
        return name;
    }
}
