Refactor event phase sorting system for use with dynamic registries ()

This commit is contained in:
Technici4n 2023-07-07 18:02:34 +02:00 committed by GitHub
parent 0eb9cc1769
commit bb7c8b8790
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 325 additions and 192 deletions
fabric-api-base/src
main/java/net/fabricmc/fabric/impl/base
testmod/java/net/fabricmc/fabric/test/base
fabric-registry-sync-v0/src/main/java/net/fabricmc/fabric/impl/registry/sync

View file

@ -18,22 +18,19 @@ package net.fabricmc.fabric.impl.base.event;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import net.minecraft.util.Identifier;
import net.fabricmc.fabric.api.event.Event;
import net.fabricmc.fabric.impl.base.toposort.NodeSorting;
class ArrayBackedEvent<T> extends Event<T> {
static final Logger LOGGER = LoggerFactory.getLogger("fabric-api-base");
private final Function<T[], T> invokerFactory;
private final Object lock = new Object();
private T[] handlers;
@ -82,7 +79,7 @@ class ArrayBackedEvent<T> extends Event<T> {
sortedPhases.add(phase);
if (sortIfCreate) {
PhaseSorting.sortPhases(sortedPhases);
NodeSorting.sort(sortedPhases, "event phases", Comparator.comparing(data -> data.id));
}
}
@ -121,9 +118,9 @@ class ArrayBackedEvent<T> extends Event<T> {
synchronized (lock) {
EventPhaseData<T> first = getOrCreatePhase(firstPhase, false);
EventPhaseData<T> second = getOrCreatePhase(secondPhase, false);
first.subsequentPhases.add(second);
second.previousPhases.add(first);
PhaseSorting.sortPhases(this.sortedPhases);
first.subsequentNodes.add(second);
second.previousNodes.add(first);
NodeSorting.sort(this.sortedPhases, "event phases", Comparator.comparing(data -> data.id));
rebuildInvoker(handlers.length);
}
}

View file

@ -17,21 +17,18 @@
package net.fabricmc.fabric.impl.base.event;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import net.minecraft.util.Identifier;
import net.fabricmc.fabric.impl.base.toposort.SortableNode;
/**
* Data of an {@link ArrayBackedEvent} phase.
*/
class EventPhaseData<T> {
class EventPhaseData<T> extends SortableNode<EventPhaseData<T>> {
final Identifier id;
T[] listeners;
final List<EventPhaseData<T>> subsequentPhases = new ArrayList<>();
final List<EventPhaseData<T>> previousPhases = new ArrayList<>();
int visitStatus = 0; // 0: not visited, 1: visiting, 2: visited
@SuppressWarnings("unchecked")
EventPhaseData(Identifier id, Class<?> listenerClass) {
@ -44,4 +41,9 @@ class EventPhaseData<T> {
listeners = Arrays.copyOf(listeners, oldLength + 1);
listeners[oldLength] = listener;
}
@Override
protected String getDescription() {
return id.toString();
}
}

View file

@ -1,162 +0,0 @@
/*
* Copyright (c) 2016, 2017, 2018, 2019 FabricMC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.fabricmc.fabric.impl.base.event;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import com.google.common.annotations.VisibleForTesting;
/**
* Contains phase-sorting logic for {@link ArrayBackedEvent}.
*/
public class PhaseSorting {
@VisibleForTesting
public static boolean ENABLE_CYCLE_WARNING = true;
/**
* Deterministically sort a list of phases.
* 1) Compute phase SCCs (i.e. cycles).
* 2) Sort phases by id within SCCs.
* 3) Sort SCCs with respect to each other by respecting constraints, and by id in case of a tie.
*/
static <T> void sortPhases(List<EventPhaseData<T>> sortedPhases) {
// FIRST KOSARAJU SCC VISIT
List<EventPhaseData<T>> toposort = new ArrayList<>(sortedPhases.size());
for (EventPhaseData<T> phase : sortedPhases) {
forwardVisit(phase, null, toposort);
}
clearStatus(toposort);
Collections.reverse(toposort);
// SECOND KOSARAJU SCC VISIT
Map<EventPhaseData<T>, PhaseScc<T>> phaseToScc = new IdentityHashMap<>();
for (EventPhaseData<T> phase : toposort) {
if (phase.visitStatus == 0) {
List<EventPhaseData<T>> sccPhases = new ArrayList<>();
// Collect phases in SCC.
backwardVisit(phase, sccPhases);
// Sort phases by id.
sccPhases.sort(Comparator.comparing(p -> p.id));
// Mark phases as belonging to this SCC.
PhaseScc<T> scc = new PhaseScc<>(sccPhases);
for (EventPhaseData<T> phaseInScc : sccPhases) {
phaseToScc.put(phaseInScc, scc);
}
}
}
clearStatus(toposort);
// Build SCC graph
for (PhaseScc<T> scc : phaseToScc.values()) {
for (EventPhaseData<T> phase : scc.phases) {
for (EventPhaseData<T> subsequentPhase : phase.subsequentPhases) {
PhaseScc<T> subsequentScc = phaseToScc.get(subsequentPhase);
if (subsequentScc != scc) {
scc.subsequentSccs.add(subsequentScc);
subsequentScc.inDegree++;
}
}
}
}
// Order SCCs according to priorities. When there is a choice, use the SCC with the lowest id.
// The priority queue contains all SCCs that currently have 0 in-degree.
PriorityQueue<PhaseScc<T>> pq = new PriorityQueue<>(Comparator.comparing(scc -> scc.phases.get(0).id));
sortedPhases.clear();
for (PhaseScc<T> scc : phaseToScc.values()) {
if (scc.inDegree == 0) {
pq.add(scc);
// Prevent adding the same SCC multiple times, as phaseToScc may contain the same value multiple times.
scc.inDegree = -1;
}
}
while (!pq.isEmpty()) {
PhaseScc<T> scc = pq.poll();
sortedPhases.addAll(scc.phases);
for (PhaseScc<T> subsequentScc : scc.subsequentSccs) {
subsequentScc.inDegree--;
if (subsequentScc.inDegree == 0) {
pq.add(subsequentScc);
}
}
}
}
private static <T> void forwardVisit(EventPhaseData<T> phase, EventPhaseData<T> parent, List<EventPhaseData<T>> toposort) {
if (phase.visitStatus == 0) {
// Not yet visited.
phase.visitStatus = 1;
for (EventPhaseData<T> data : phase.subsequentPhases) {
forwardVisit(data, phase, toposort);
}
toposort.add(phase);
phase.visitStatus = 2;
} else if (phase.visitStatus == 1 && ENABLE_CYCLE_WARNING) {
// Already visiting, so we have found a cycle.
ArrayBackedEvent.LOGGER.warn(String.format(
"Event phase ordering conflict detected.%nEvent phase %s is ordered both before and after event phase %s.",
phase.id,
parent.id
));
}
}
private static <T> void clearStatus(List<EventPhaseData<T>> phases) {
for (EventPhaseData<T> phase : phases) {
phase.visitStatus = 0;
}
}
private static <T> void backwardVisit(EventPhaseData<T> phase, List<EventPhaseData<T>> sccPhases) {
if (phase.visitStatus == 0) {
phase.visitStatus = 1;
sccPhases.add(phase);
for (EventPhaseData<T> data : phase.previousPhases) {
backwardVisit(data, sccPhases);
}
}
}
private static class PhaseScc<T> {
final List<EventPhaseData<T>> phases;
final List<PhaseScc<T>> subsequentSccs = new ArrayList<>();
int inDegree = 0;
private PhaseScc(List<EventPhaseData<T>> phases) {
this.phases = phases;
}
}
}

View file

@ -0,0 +1,181 @@
/*
* Copyright (c) 2016, 2017, 2018, 2019 FabricMC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.fabricmc.fabric.impl.base.toposort;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Contains a topological sort implementation, with tie breaking using a {@link Comparator}.
*
* <p>The final order is always deterministic (i.e. doesn't change with the order of the input elements or the edges),
* assuming that they are all different according to the comparator. This also holds in the presence of cycles.
*
* <p>The steps are as follows:
* <ol>
* <li>Compute node SCCs (Strongly Connected Components, i.e. cycles).</li>
* <li>Sort nodes within SCCs using the comparator.</li>
* <li>Sort SCCs with respect to each other by respecting constraints, and using the comparator in case of a tie.</li>
* </ol>
*/
public class NodeSorting {
private static final Logger LOGGER = LoggerFactory.getLogger("fabric-api-base");
@VisibleForTesting
public static boolean ENABLE_CYCLE_WARNING = true;
/**
* Sort a list of nodes.
*
* @param sortedPhases The list of nodes to sort. Will be modified in-place.
* @param elementDescription A description of the elements, used for logging in the presence of cycles.
* @param comparator The comparator to break ties and to order elements within a cycle.
*/
public static <N extends SortableNode<N>> void sort(List<N> sortedPhases, String elementDescription, Comparator<N> comparator) {
// FIRST KOSARAJU SCC VISIT
List<N> toposort = new ArrayList<>(sortedPhases.size());
for (N phase : sortedPhases) {
forwardVisit(phase, null, toposort);
}
clearStatus(toposort);
Collections.reverse(toposort);
// SECOND KOSARAJU SCC VISIT
Map<N, PhaseScc<N>> phaseToScc = new IdentityHashMap<>();
for (N phase : toposort) {
if (!phase.visited) {
List<N> sccPhases = new ArrayList<>();
// Collect phases in SCC.
backwardVisit(phase, sccPhases);
// Sort phases by id.
sccPhases.sort(comparator);
// Mark phases as belonging to this SCC.
PhaseScc<N> scc = new PhaseScc<>(sccPhases);
for (N phaseInScc : sccPhases) {
phaseToScc.put(phaseInScc, scc);
}
}
}
clearStatus(toposort);
// Build SCC graph
for (PhaseScc<N> scc : phaseToScc.values()) {
for (N phase : scc.phases) {
for (N subsequentPhase : phase.subsequentNodes) {
PhaseScc<N> subsequentScc = phaseToScc.get(subsequentPhase);
if (subsequentScc != scc) {
scc.subsequentSccs.add(subsequentScc);
subsequentScc.inDegree++;
}
}
}
}
// Order SCCs according to priorities. When there is a choice, use the SCC with the lowest id.
// The priority queue contains all SCCs that currently have 0 in-degree.
PriorityQueue<PhaseScc<N>> pq = new PriorityQueue<>(Comparator.comparing(scc -> scc.phases.get(0), comparator));
sortedPhases.clear();
for (PhaseScc<N> scc : phaseToScc.values()) {
if (scc.inDegree == 0) {
pq.add(scc);
// Prevent adding the same SCC multiple times, as phaseToScc may contain the same value multiple times.
scc.inDegree = -1;
}
}
while (!pq.isEmpty()) {
PhaseScc<N> scc = pq.poll();
sortedPhases.addAll(scc.phases);
// Print cycle warning
if (ENABLE_CYCLE_WARNING && scc.phases.size() > 1) {
StringBuilder builder = new StringBuilder();
builder.append("Found cycle while sorting ").append(elementDescription).append(":\n");
for (N phase : scc.phases) {
builder.append("\t").append(phase.getDescription()).append("\n");
}
LOGGER.warn(builder.toString());
}
for (PhaseScc<N> subsequentScc : scc.subsequentSccs) {
subsequentScc.inDegree--;
if (subsequentScc.inDegree == 0) {
pq.add(subsequentScc);
}
}
}
}
private static <N extends SortableNode<N>> void forwardVisit(N phase, N parent, List<N> toposort) {
if (!phase.visited) {
// Not yet visited.
phase.visited = true;
for (N data : phase.subsequentNodes) {
forwardVisit(data, phase, toposort);
}
toposort.add(phase);
}
}
private static <N extends SortableNode<N>> void clearStatus(List<N> phases) {
for (N phase : phases) {
phase.visited = false;
}
}
private static <N extends SortableNode<N>> void backwardVisit(N phase, List<N> sccPhases) {
if (!phase.visited) {
phase.visited = true;
sccPhases.add(phase);
for (N data : phase.previousNodes) {
backwardVisit(data, sccPhases);
}
}
}
private static class PhaseScc<N extends SortableNode<N>> {
final List<N> phases;
final List<PhaseScc<N>> subsequentSccs = new ArrayList<>();
int inDegree = 0;
private PhaseScc(List<N> phases) {
this.phases = phases;
}
}
}

View file

@ -0,0 +1,31 @@
/*
* Copyright (c) 2016, 2017, 2018, 2019 FabricMC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.fabricmc.fabric.impl.base.toposort;
import java.util.ArrayList;
import java.util.List;
public abstract class SortableNode<N extends SortableNode<N>> {
public final List<N> subsequentNodes = new ArrayList<>();
public final List<N> previousNodes = new ArrayList<>();
boolean visited = false;
/**
* @return Description of this node, used to print the cycle warning.
*/
protected abstract String getDescription();
}

View file

@ -29,7 +29,7 @@ import net.minecraft.util.Identifier;
import net.fabricmc.fabric.api.event.Event;
import net.fabricmc.fabric.api.event.EventFactory;
import net.fabricmc.fabric.impl.base.event.PhaseSorting;
import net.fabricmc.fabric.impl.base.toposort.NodeSorting;
public class EventTests {
private static final Logger LOGGER = LoggerFactory.getLogger("fabric-api-base");
@ -41,10 +41,10 @@ public class EventTests {
testMultipleDefaultPhases();
testAddedPhases();
testCycle();
PhaseSorting.ENABLE_CYCLE_WARNING = false;
NodeSorting.ENABLE_CYCLE_WARNING = false;
testDeterministicOrdering();
testTwoCycles();
PhaseSorting.ENABLE_CYCLE_WARNING = true;
NodeSorting.ENABLE_CYCLE_WARNING = true;
long time2 = System.currentTimeMillis();
LOGGER.info("Event unit tests succeeded in {} milliseconds.", time2 - time1);

View file

@ -17,12 +17,14 @@
package net.fabricmc.fabric.impl.registry.sync;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import com.mojang.serialization.Codec;
import org.jetbrains.annotations.Unmodifiable;
@ -33,12 +35,13 @@ import net.minecraft.registry.RegistryLoader;
import net.minecraft.registry.SerializableRegistries;
import net.fabricmc.fabric.api.event.registry.DynamicRegistries;
import net.fabricmc.fabric.impl.base.toposort.NodeSorting;
import net.fabricmc.fabric.impl.base.toposort.SortableNode;
public final class DynamicRegistriesImpl {
private static final List<RegistryLoader.Entry<?>> DYNAMIC_REGISTRIES = new ArrayList<>(RegistryLoader.DYNAMIC_REGISTRIES);
private static final Set<RegistryKey<? extends Registry<?>>> DYNAMIC_REGISTRY_KEYS = new HashSet<>();
private static final Map<RegistryKey<? extends Registry<?>>, SettingsImpl<?>> SETTINGS = new HashMap<>();
private static boolean sorted = true;
private static volatile List<RegistryLoader.Entry<?>> sortedRegistries = null;
static {
for (RegistryLoader.Entry<?> vanillaEntry : RegistryLoader.DYNAMIC_REGISTRIES) {
@ -50,16 +53,78 @@ public final class DynamicRegistriesImpl {
}
public static @Unmodifiable List<RegistryLoader.Entry<?>> getDynamicRegistries() {
if (!sorted) {
sort();
sorted = true;
List<RegistryLoader.Entry<?>> ret = sortedRegistries;
if (ret == null) {
sortedRegistries = ret = sort();
}
return List.copyOf(DYNAMIC_REGISTRIES);
return ret;
}
private static void sort() {
// TODO
private static List<RegistryLoader.Entry<?>> sort() {
Map<RegistryKey<? extends Registry<?>>, RegistryNode> nodes = new HashMap<>(RegistryLoader.DYNAMIC_REGISTRIES.size() + SETTINGS.size());
// Add vanilla nodes with their ordering
int vanillaIndex = 0;
RegistryNode previousVanillaNode = null;
for (RegistryLoader.Entry<?> vanillaEntry : RegistryLoader.DYNAMIC_REGISTRIES) {
RegistryNode node = new RegistryNode(vanillaEntry, vanillaIndex++);
nodes.put(vanillaEntry.key(), node);
if (previousVanillaNode != null) {
link(previousVanillaNode, node);
}
previousVanillaNode = node;
}
// Add modded nodes
for (SettingsImpl<?> settings : SETTINGS.values()) {
RegistryNode node = nodes.computeIfAbsent(settings.owner.key(), k -> new RegistryNode(settings.owner, RegistryNode.MODDED_INDEX));
nodes.put(settings.owner.key(), node);
}
// Add modded ordering
for (SettingsImpl<?> settings : SETTINGS.values()) {
RegistryNode node = nodes.get(settings.owner.key());
for (RegistryKey<? extends Registry<?>> before : settings.before) {
RegistryNode other = nodes.get(before);
if (other == null) {
throw new IllegalStateException("Registry " + settings.owner.key() + " has a dependency on " + before + ", which does not exist!");
}
link(node, other);
}
for (RegistryKey<? extends Registry<?>> after : settings.after) {
RegistryNode other = nodes.get(after);
if (other == null) {
throw new IllegalStateException("Registry " + settings.owner.key() + " has a dependency on " + after + ", which does not exist!");
}
link(other, node);
}
}
// Sort everything
List<RegistryNode> nodesToSort = new ArrayList<>(nodes.values());
NodeSorting.sort(nodesToSort, "dynamic registries", RegistryNode.COMPARATOR);
for (RegistryNode node : nodesToSort) {
System.out.println("Sorted node: " + node.entry.key());
}
return nodesToSort.stream().map(node -> node.entry).collect(Collectors.toUnmodifiableList());
}
private static void link(RegistryNode node1, RegistryNode node2) {
node1.subsequentNodes.add(node2);
node2.previousNodes.add(node1);
}
public static <T> DynamicRegistries.Settings<T> register(RegistryKey<? extends Registry<T>> key, Codec<T> codec) {
@ -71,8 +136,8 @@ public final class DynamicRegistriesImpl {
}
var entry = new RegistryLoader.Entry<>(key, codec);
DYNAMIC_REGISTRIES.add(entry);
sorted = false;
// TODO: may not be thread-safe
sortedRegistries = null;
var settings = new SettingsImpl<>(entry);
SETTINGS.put(key, settings);
return settings;
@ -118,7 +183,7 @@ public final class DynamicRegistriesImpl {
public DynamicRegistries.Settings<T> sortBefore(RegistryKey<? extends Registry<?>> before) {
Objects.requireNonNull(before, "Registry key to sort before");
this.before.add(before);
sorted = false;
sortedRegistries = null;
return this;
}
@ -126,8 +191,27 @@ public final class DynamicRegistriesImpl {
public DynamicRegistries.Settings<T> sortAfter(RegistryKey<? extends Registry<?>> after) {
Objects.requireNonNull(after, "Registry key to sort after");
this.after.add(after);
sorted = false;
sortedRegistries = null;
return this;
}
}
private static final class RegistryNode extends SortableNode<RegistryNode> {
private static final Comparator<RegistryNode> COMPARATOR = Comparator.<RegistryNode>comparingInt(node -> node.vanillaIndex)
.thenComparing(node -> node.entry.key().getValue());
private static final int MODDED_INDEX = 1000; // modded registries go after vanilla by default
private final RegistryLoader.Entry<?> entry;
private final int vanillaIndex;
private RegistryNode(RegistryLoader.Entry<?> entry, int vanillaIndex) {
this.entry = entry;
this.vanillaIndex = vanillaIndex;
}
@Override
protected String getDescription() {
return entry.key().toString();
}
}
}