Skip to content

Commit

Permalink
Rely on helper Ops to decrease amount of code
Browse files Browse the repository at this point in the history
  • Loading branch information
gselzer committed Jan 18, 2019
1 parent 2b44404 commit dae52ed
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 162 deletions.
157 changes: 76 additions & 81 deletions src/main/java/net/imagej/ops/filter/sharpen/DefaultSharpen.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
import net.imagej.Extents;
import net.imagej.Position;
import net.imagej.ops.Ops;
import net.imagej.ops.special.computer.Computers;
import net.imagej.ops.special.computer.UnaryComputerOp;
import net.imagej.ops.special.function.AbstractUnaryFunctionOp;
import net.imglib2.Cursor;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccess;
import net.imagej.ops.special.function.UnaryFunctionOp;
import net.imagej.ops.special.inplace.Inplaces;
import net.imagej.ops.special.inplace.UnaryInplaceOp;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.neighborhood.Neighborhood;
import net.imglib2.algorithm.neighborhood.RectangleNeighborhood;
import net.imglib2.algorithm.neighborhood.RectangleNeighborhoodFactory;
import net.imglib2.algorithm.neighborhood.RectangleShape.NeighborhoodsAccessible;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.util.Util;
import net.imglib2.view.Views;

Expand All @@ -32,109 +33,103 @@ public class DefaultSharpen<T extends RealType<T>> extends
implements Ops.Filter.Sharpen
{

final double[] kernel = { -1, -1, -1, -1, 12, -1, -1, -1, -1 };
double scale;
final double[][] kernel = { { -1, -1, -1 }, { -1, 12, -1 }, { -1, -1, -1 } };

RandomAccessibleInterval<ByteType> kernelRAI;

/**
* The sum of all of the kernel values
*/
double scale = 4;

UnaryFunctionOp<double[][], RandomAccessibleInterval<T>> kernelCreator;
UnaryComputerOp<RandomAccessible<T>, RandomAccessibleInterval<DoubleType>> convolveOp;
UnaryInplaceOp<IterableInterval<DoubleType>, IterableInterval<DoubleType>> addConstOp;
UnaryInplaceOp<IterableInterval<DoubleType>, IterableInterval<DoubleType>> divConstOp;
UnaryComputerOp<DoubleType, T> clipTypesOp;
UnaryComputerOp<IterableInterval<DoubleType>, IterableInterval<T>> convertOp;

@SuppressWarnings({ "unchecked", "rawtypes" })
@Override
public void initialize() {
T inType = Util.getTypeFromInterval(in());

// convolution kernel
kernelRAI = ops().create().kernel(kernel, new ByteType());

IterableInterval<DoubleType> dummyDoubleType = ops().create().img(in(), new DoubleType());

convolveOp = (UnaryComputerOp) Computers.unary(ops(),
Ops.Filter.Convolve.class, RandomAccessibleInterval.class,
RandomAccessible.class, kernelRAI);
addConstOp = (UnaryInplaceOp) Inplaces.unary(ops(), Ops.Math.Add.class,
dummyDoubleType, new DoubleType(scale / 2));
divConstOp = (UnaryInplaceOp) Inplaces.unary(ops(), Ops.Math.Divide.class,
dummyDoubleType, new DoubleType(scale));
clipTypesOp = (UnaryComputerOp) Computers.unary(ops(),
Ops.Convert.Clip.class, new DoubleType(), inType);
convertOp = Computers.unary(ops(),
Ops.Convert.ImageType.class, Views.iterable(in()), dummyDoubleType, clipTypesOp);

}

@SuppressWarnings("unchecked")
@Override
public RandomAccessibleInterval<T> calculate(
final RandomAccessibleInterval<T> input)
{
final RandomAccessibleInterval<T> output = ops().copy().rai(input);
// intermediate image to hold the convolution data. Depending on the input
// type we have to be able create an intermediate of wide enough type.
final RandomAccessibleInterval<DoubleType> intermediate = ops().create().img(in(), new DoubleType());
// output image
final RandomAccessibleInterval<T> output = (RandomAccessibleInterval<T>) ops().create().img(in());

// compute the sharpening on 2D slices of the image
final long[] planeDims = new long[input.numDimensions() - 2];
for (int i = 0; i < planeDims.length; i++)
planeDims[i] = input.dimension(i + 2);
final Extents extents = new Extents(planeDims);
final Position planePos = extents.createPosition();
if (planeDims.length == 0) {
computePlanar(planePos, input, output);
computePlanar(planePos, input, intermediate);
}
else {
while (planePos.hasNext()) {
planePos.fwd();
computePlanar(planePos, input, output);
computePlanar(planePos, input, intermediate);
}

}

T inputType = Util.getTypeFromInterval(input);
IterableInterval<DoubleType> iterableIntermediate = Views.iterable(
intermediate);

// divide by the scale, if integerType input also add scale / 2.
if (inputType instanceof IntegerType) addConstOp.mutate(
iterableIntermediate);
divConstOp.mutate(iterableIntermediate);

//convert the result back to the input type.
convertOp.compute(iterableIntermediate, Views.iterable(output));

return output;
}

private void computePlanar(final Position planePos,
final RandomAccessibleInterval<T> input,
final RandomAccessibleInterval<T> output)
final RandomAccessibleInterval<DoubleType> intermediate)
{
// TODO can we just set scale to 4?
scale = 0;
for (final double d : kernel)
scale += d;

final T type = Util.getTypeFromInterval(input);

final long[] imageDims = new long[input.numDimensions()];
input.dimensions(imageDims);

// create all objects needed for NeighborhoodsAccessible
RandomAccessibleInterval<T> slicedInput = ops().copy().rai(input);
// create 2D slice in the case of a (N>2)-dimensional image.
RandomAccessible<T> slicedInput = Views.extendMirrorSingle(input);
RandomAccessibleInterval<DoubleType> slicedIntermediate = intermediate;
for (int i = planePos.numDimensions() - 1; i >= 0; i--) {
slicedInput = Views.hyperSlice(slicedInput, input.numDimensions() - 1 - i,
planePos.getLongPosition(i));
slicedIntermediate = Views.hyperSlice(slicedIntermediate, intermediate
.numDimensions() - 1 - i, planePos.getLongPosition(i));
}

final RandomAccessible<T> refactoredInput = Views.extendMirrorSingle(
slicedInput);
final RectangleNeighborhoodFactory<T> factory = RectangleNeighborhood
.factory();
final FinalInterval neighborhoodSpan = new FinalInterval(new long[] { -1,
-1 }, new long[] { 1, 1 });

final NeighborhoodsAccessible<T> neighborhoods =
new NeighborhoodsAccessible<>(refactoredInput, neighborhoodSpan, factory);

// create cursors and random accesses for loop.
final Cursor<T> cursor = Views.iterable(input).localizingCursor();
final RandomAccess<T> outputRA = output.randomAccess();
for (int i = 0; i < planePos.numDimensions(); i++) {
outputRA.setPosition(planePos.getLongPosition(i), i + 2);
}
final RandomAccess<Neighborhood<T>> neighborhoodsRA = neighborhoods
.randomAccess();

int algorithmIndex = 0;
double sum;
final double[] n = new double[9];
while (cursor.hasNext()) {
cursor.fwd();
if (cursor.getLongPosition(0) == 14 && cursor.getLongPosition(1) == 0)
System.out.println("Hit 14");
neighborhoodsRA.setPosition(cursor);
final Neighborhood<T> current = neighborhoodsRA.get();
final Cursor<T> neighborhoodCursor = current.cursor();

algorithmIndex = 0;
sum = 0;
while (algorithmIndex < n.length) {
neighborhoodCursor.fwd();
n[algorithmIndex++] = neighborhoodCursor.get().getRealDouble();
}

for (int i = 0; i < kernel.length; i++) {
sum += kernel[i] * n[i];
}

//find the value for the output
double value;
if(type instanceof IntegerType) {
value = (sum + scale / 2) / scale;
}
else {
value = sum / scale;
}

outputRA.setPosition(cursor.getLongPosition(0), 0);
outputRA.setPosition(cursor.getLongPosition(1), 1);
if (value > type.getMaxValue()) value = type.getMaxValue();
if (value < type.getMinValue()) value = type.getMinValue();
outputRA.get().setReal(value);
}
convolveOp.compute(slicedInput, slicedIntermediate);
}
}
Loading

0 comments on commit dae52ed

Please sign in to comment.