Skip to content

Commit

Permalink
Union/IntersectionEvent: Freeze root realization
Browse files Browse the repository at this point in the history
  • Loading branch information
josephmure authored Nov 13, 2023
1 parent 83f8f9c commit 614b5d3
Show file tree
Hide file tree
Showing 26 changed files with 445 additions and 241 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ AdaptiveDirectionalStratification::AdaptiveDirectionalStratification()
AdaptiveDirectionalStratification::AdaptiveDirectionalStratification(const RandomVector & event,
const RootStrategy & rootStrategy,
const SamplingStrategy & samplingStrategy)
: EventSimulation(event)
, standardEvent_(StandardEvent(event))
: EventSimulation(event.getImplementation()->asComposedEvent())
, standardEvent_(StandardEvent(getEvent()))
, rootStrategy_(rootStrategy)
, samplingStrategy_(samplingStrategy)
, gamma_(ResourceMap::GetAsUnsignedInteger("AdaptiveDirectionalStratification-DefaultNumberOfSteps"), ResourceMap::GetAsScalar("AdaptiveDirectionalStratification-DefaultGamma"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ CrossEntropyImportanceSampling::CrossEntropyImportanceSampling()
// Default constructor
CrossEntropyImportanceSampling::CrossEntropyImportanceSampling(const RandomVector & event,
const Scalar quantileLevel)
: EventSimulation(event)
, initialDistribution_(event.getAntecedent().getDistribution())
: EventSimulation(event.getImplementation()->asComposedEvent())
, initialDistribution_(getEvent().getAntecedent().getDistribution())
{
if (quantileLevel > 1.)
throw InvalidArgumentException(HERE) << "In CrossEntropyImportanceSampling::CrossEntropyImportanceSampling, quantileLevel parameter value should be between 0.0 and 1.0";
Expand All @@ -52,7 +52,7 @@ CrossEntropyImportanceSampling::CrossEntropyImportanceSampling(const RandomVecto
throw InvalidArgumentException(HERE) << "In CrossEntropyImportanceSampling::CrossEntropyImportanceSampling, quantileLevel parameter value should be between 0.0 and 1.0";


quantileLevel_ = (event.getOperator()(0, 1) ? quantileLevel : 1.0 - quantileLevel);
quantileLevel_ = (getEvent().getOperator()(0, 1) ? quantileLevel : 1.0 - quantileLevel);
}

/* Virtual constructor */
Expand Down Expand Up @@ -93,7 +93,7 @@ Point CrossEntropyImportanceSampling::optimizeAuxiliaryDistributionParameters(co


// Reset auxiliary distribution parameters
void CrossEntropyImportanceSampling::resetAuxiliaryDistribution()
void CrossEntropyImportanceSampling::resetAuxiliaryDistribution()
{
throw NotYetImplementedException(HERE) << "In CrossEntropyImportanceSampling::resetAuxiliaryDistribution()";
}
Expand All @@ -104,7 +104,7 @@ void CrossEntropyImportanceSampling::run()

// Initialization of auxiliary distribution (in case of multiple runs of algorithms)
resetAuxiliaryDistribution();

const UnsignedInteger sampleSize = getMaximumOuterSampling() * getBlockSize();

// Drawing of samples using initial density
Expand All @@ -117,7 +117,7 @@ void CrossEntropyImportanceSampling::run()
Scalar currentQuantile = auxiliaryOutputSample.computeQuantile(quantileLevel_)[0];

Point auxiliaryDistributionParameters;

const ComparisonOperator comparator(getEvent().getOperator());
const Scalar threshold = getEvent().getThreshold();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ DirectionalSampling::DirectionalSampling()

/* Constructor with parameters */
DirectionalSampling::DirectionalSampling(const RandomVector & event)
: EventSimulation(event)
: EventSimulation(event.getImplementation()->asComposedEvent())
{
if (!event.isEvent() || !event.isComposite()) throw InvalidArgumentException(HERE) << "DirectionalSampling requires a composite event";
standardEvent_ = StandardEvent(event);
standardEvent_ = StandardEvent(getEvent());
standardFunction_ = standardEvent_.getImplementation()->getFunction();
inputDistribution_ = standardEvent_.getImplementation()->getAntecedent().getDistribution();
samplingStrategy_ = SamplingStrategy(inputDistribution_.getDimension());
Expand All @@ -61,11 +61,11 @@ DirectionalSampling::DirectionalSampling(const RandomVector & event)
DirectionalSampling::DirectionalSampling(const RandomVector & event,
const RootStrategy & rootStrategy,
const SamplingStrategy & samplingStrategy)
: EventSimulation(event)
: EventSimulation(event.getImplementation()->asComposedEvent())
, rootStrategy_(rootStrategy)
{
if (!event.isEvent() || !event.isComposite()) throw InvalidArgumentException(HERE) << "DirectionalSampling requires a composite event";
standardEvent_ = StandardEvent(event);
standardEvent_ = StandardEvent(getEvent());
standardFunction_ = standardEvent_.getImplementation()->getFunction();
inputDistribution_ = standardEvent_.getImplementation()->getAntecedent().getDistribution();
setSamplingStrategy(samplingStrategy);
Expand Down
2 changes: 2 additions & 0 deletions lib/src/Uncertainty/Algorithm/Simulation/EventSimulation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "openturns/Uniform.hxx"
#include "openturns/IdentityFunction.hxx"
#include "openturns/CompositeRandomVector.hxx"
#include "openturns/IntersectionEvent.hxx"
#include "openturns/UnionEvent.hxx"

BEGIN_NAMESPACE_OPENTURNS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ PhysicalSpaceCrossEntropyImportanceSampling::PhysicalSpaceCrossEntropyImportance
, solver_(NLopt("LD_LBFGS"))
{
auxiliaryDistribution_ = auxiliaryDistribution;
quantileLevel_ = (event.getOperator()(0, 1) ? quantileLevel : 1.0 - quantileLevel);
quantileLevel_ = (getEvent().getOperator()(0, 1) ? quantileLevel : 1.0 - quantileLevel);
bounds_ = bounds;
initialAuxiliaryDistributionParameters_ = initialAuxiliaryDistributionParameters;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ Sample ProbabilitySimulationAlgorithm::computeBlockSampleComposite()
{
Point weights;
const Sample inputSample(experiment_.generateWithWeights(weights));
Sample blockSample(getEvent().getImplementation()->getFunction()(inputSample));
const DomainImplementation::BoolCollection isRealized(getEvent().getDomain().contains(blockSample));
Sample blockSample(blockSize_, 1);
const RandomVector event(getEvent());
for (UnsignedInteger i = 0; i < blockSize_; ++ i)
blockSample(i, 0) = isRealized[i];
blockSample[i] = event.getFrozenRealization(inputSample[i]);
if (!experiment_.hasUniformWeights())
for (UnsignedInteger i = 0; i < blockSize_; ++ i)
blockSample(i, 0) *= weights[i];
Expand Down
4 changes: 2 additions & 2 deletions lib/src/Uncertainty/Algorithm/Simulation/SubsetSampling.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ SubsetSampling::SubsetSampling()
SubsetSampling::SubsetSampling(const RandomVector & event,
const Scalar proposalRange,
const Scalar conditionalProbability)
: EventSimulation(event)
: EventSimulation(event.getImplementation()->asComposedEvent())
, proposalRange_(proposalRange)
, conditionalProbability_(conditionalProbability)
, iSubset_(false)
Expand All @@ -68,7 +68,7 @@ SubsetSampling::SubsetSampling(const RandomVector & event,
{
if (!event.isEvent() || !event.isComposite()) throw InvalidArgumentException(HERE) << "SubsetSampling requires a composite event";
setMaximumOuterSampling(ResourceMap::GetAsUnsignedInteger("SubsetSampling-DefaultMaximumOuterSampling"));// override simulation default outersampling
UnsignedInteger outputDimension = event.getFunction().getOutputDimension();
UnsignedInteger outputDimension = getEvent().getFunction().getOutputDimension();
if (outputDimension > 1)
throw InvalidArgumentException(HERE) << "Output dimension for SubsetSampling cannot be greater than 1, here output dimension=" << outputDimension;
}
Expand Down
16 changes: 16 additions & 0 deletions lib/src/Uncertainty/Model/CompositeRandomVector.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ Point CompositeRandomVector::getRealization() const
return function_(antecedent_.getRealization());
}

Point CompositeRandomVector::getFrozenRealization(const Point & fixedPoint) const
{
return function_(antecedent_.getFrozenRealization(fixedPoint));
}

/* Numerical sample accessor */
Sample CompositeRandomVector::getSample(const UnsignedInteger size) const
{
Expand All @@ -101,6 +106,17 @@ Sample CompositeRandomVector::getSample(const UnsignedInteger size) const
return sample;
}

Sample CompositeRandomVector::getFrozenSample(const Sample & fixedSample) const
{
Sample sample(function_(antecedent_.getFrozenSample(fixedSample)));
const Description description(getDescription());
// It may append that the description has been overloaded by a child class
// FIXME: change this ugly hack to something reasonable
if (description.getSize() == sample.getDimension()) sample.setDescription(description);
else sample.setDescription(function_.getOutputDescription());
return sample;
}

/* Get the random vector corresponding to the i-th marginal component */
RandomVector CompositeRandomVector::getMarginal(const UnsignedInteger i) const
{
Expand Down
26 changes: 26 additions & 0 deletions lib/src/Uncertainty/Model/DomainEvent.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ Point DomainEvent::getRealization() const
return Point(1, domain_.contains(CompositeRandomVector::getRealization()));
}

Point DomainEvent::getFrozenRealization(const Point & fixedPoint) const
{
return Point(1, domain_.contains(CompositeRandomVector::getFrozenRealization(fixedPoint)));
}

/* Numerical sample accessor */
Sample DomainEvent::getSample(const UnsignedInteger size) const
{
Expand All @@ -111,11 +116,32 @@ Sample DomainEvent::getSample(const UnsignedInteger size) const
return result;
}

Sample DomainEvent::getFrozenSample(const Sample & fixedSample) const
{
// First, compute the sample of the event antecedent that fits fixedSample
const Sample returnSample(CompositeRandomVector::getFrozenSample(fixedSample));
// Then, we loop over the sample to check each point in sequence
Sample result(fixedSample.getSize(), 1);
for (UnsignedInteger i = 0; i < fixedSample.getSize(); ++i)
result(i, 0) = domain_.contains(returnSample[i]);
result.setName("DomainEvent sample");
result.setDescription(getDescription());
return result;
}

Bool DomainEvent::isEvent() const
{
return true;
}

RandomVector DomainEvent::asComposedEvent() const
{
if (domain_.getImplementation()->getClassName() != "LevelSet")
throw NotYetImplementedException(HERE) << "DomainEvent is not based on a LevelSet.";

return RandomVector(clone());
}

/* Method save() stores the object through the StorageManager */
void DomainEvent::save(Advocate & adv) const
{
Expand Down
150 changes: 71 additions & 79 deletions lib/src/Uncertainty/Model/IntersectionEvent.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -76,83 +76,68 @@ void IntersectionEvent::setEventCollection(const RandomVectorCollection & collec
const UnsignedInteger size = collection.getSize();
if (!size) throw InvalidArgumentException(HERE) << "Empty collection";

// Explore the deepest leftmost node of the tree which is not Intersection/Union to get the root cause
// Also we initialize composedEvent_ if Intersection/Union use getComposedEvent from top node else take the ThresholdEvent
if (!collection[0].isEvent())
throw InvalidArgumentException(HERE) << "Element 0 is not an event";
UnsignedInteger depth = 0;
RandomVector current = collection[0];
String implementationName = current.getImplementation()->getClassName();
while ((implementationName == "IntersectionEvent") || (implementationName == "UnionEvent"))
for (UnsignedInteger i = 0; i < size; ++ i)
{
Collection<RandomVector> children;
if (implementationName == "IntersectionEvent")
{
IntersectionEvent *intersectionEvent = static_cast<IntersectionEvent*>(current.getImplementation().get());
if (depth == 0)
composedEvent_ = intersectionEvent->getComposedEvent();
children = intersectionEvent->getEventCollection();
}
else if (implementationName == "UnionEvent")
{
UnionEvent *unionEvent = static_cast<UnionEvent*>(current.getImplementation().get());
if (depth == 0)
composedEvent_ = unionEvent->getComposedEvent();
children = unionEvent->getEventCollection();
}
current = children[0];
++ depth;
implementationName = current.getImplementation()->getClassName();
if (!collection[i].isEvent())
throw InvalidArgumentException(HERE) << "Element " << i << " is not an event";
}
// store root cause
antecedent_ = current.getAntecedent();

antecedent_ = collection[0].getAntecedent();
const UnsignedInteger rootCauseId = antecedent_.getImplementation()->getId();
if (depth == 0) // no IntersectionEvent/Union was found, take the first node
composedEvent_ = collection[0];

// Explore the tree, check root cause, compose top-nodes
// Explore the tree, check root cause
for (UnsignedInteger i = 1; i < size; ++ i)
{
if (!collection[i].isEvent())
throw InvalidArgumentException(HERE) << "Element " << i << " is not an event";
if (collection[i].getImplementation()->getClassName() == "IntersectionEvent")
{
// IntersectionEvent
IntersectionEvent* intersectionEvent = static_cast<IntersectionEvent*>(collection[i].getImplementation().get());
if (intersectionEvent->getAntecedent().getImplementation()->getId() != rootCauseId)
throw InvalidArgumentException(HERE) << "Different root cause";
composedEvent_ = composedEvent_.intersect(intersectionEvent->getComposedEvent());
}
else if (collection[i].getImplementation()->getClassName() == "UnionEvent")
{
// UnionEvent
UnionEvent* unionEvent = static_cast<UnionEvent*>(collection[i].getImplementation().get());
if (unionEvent->getAntecedent().getImplementation()->getId() != rootCauseId)
throw InvalidArgumentException(HERE) << "Different root cause";
composedEvent_ = composedEvent_.intersect(unionEvent->getComposedEvent());
}
else
{
// ThresholdEvent
if (collection[i].getAntecedent().getImplementation()->getId() != rootCauseId)
throw NotYetImplementedException(HERE) << "Root cause not found";
composedEvent_ = composedEvent_.intersect(collection[i]);
}
if (collection[i].getAntecedent().getImplementation()->getId() != rootCauseId)
throw NotYetImplementedException(HERE) << "Root cause not found";
}
eventCollection_ = collection;
setDescription(composedEvent_.getDescription());
setDescription(collection[0].getDescription());
}

/* Realization accessor */
Point IntersectionEvent::getRealization() const
{
return composedEvent_.getRealization();
return getFrozenRealization(antecedent_.getRealization());
}

/* Fixed value accessor */
Point IntersectionEvent::getFrozenRealization(const Point & fixedPoint) const
{
LOGINFO(OSS() << "antecedent value = " << fixedPoint);
Point realization(1);
for (UnsignedInteger i = 0; i < eventCollection_.getSize(); ++ i)
if (eventCollection_[i].getFrozenRealization(fixedPoint)[0] == 0.0)
return realization;
realization[0] = 1.0;
return realization;
}

/* Sample accessor */
Sample IntersectionEvent::getSample(const UnsignedInteger size) const
{
return composedEvent_.getSample(size);
return getFrozenSample(antecedent_.getSample(size));
}

/* Fixed sample accessor */
Sample IntersectionEvent::getFrozenSample(const Sample & fixedSample) const
{
Indices stillInIntersection(fixedSample.getSize());
stillInIntersection.fill();
Indices noLongerInIntersection(0);

for (UnsignedInteger i = 0; i < eventCollection_.getSize(); ++ i)
{
const Sample currentEventSample(eventCollection_[i].getFrozenSample(fixedSample.select(stillInIntersection)));
for (UnsignedInteger j = 0; j < stillInIntersection.getSize(); ++ j)
if (currentEventSample(j, 0) == 0.0) noLongerInIntersection.add(stillInIntersection[j]);
stillInIntersection = noLongerInIntersection.complement(fixedSample.getSize());
}

Sample sample(fixedSample.getSize(), 1);
for (UnsignedInteger j = 0; j < stillInIntersection.getSize(); ++ j)
sample(stillInIntersection[j], 0) = 1.0;
return sample;
}

Bool IntersectionEvent::isEvent() const
Expand All @@ -170,24 +155,36 @@ RandomVector IntersectionEvent::getAntecedent() const
return antecedent_;
}

Function IntersectionEvent::getFunction() const
{
return composedEvent_.getFunction();
}

Domain IntersectionEvent::getDomain() const
RandomVector IntersectionEvent::asComposedEvent() const
{
return composedEvent_.getDomain();
}
const UnsignedInteger size = eventCollection_.getSize();
if (!size) throw InvalidArgumentException(HERE) << "Intersection has been improperly initialized: event collection is empty";

ComparisonOperator IntersectionEvent::getOperator() const
{
return composedEvent_.getOperator();
}
RandomVector composedEvent;
try
{
// We get the first event in the collection as a composed event if possible.
composedEvent = eventCollection_[0].getImplementation()->asComposedEvent();
}
catch (const NotYetImplementedException &)
{
throw NotYetImplementedException(HERE) << "Event #0 could not be rebuilt as a ThresholdEvent.";
}

Scalar IntersectionEvent::getThreshold() const
{
return composedEvent_.getThreshold();
// Further build composedEvent by composing with the other events in the eventCollection_
for (UnsignedInteger i = 1; i < size; ++ i)
{
try
{
// We try to compose with the next event in the collection.
composedEvent = composedEvent.intersect(eventCollection_[i].getImplementation()->asComposedEvent());
}
catch (const NotYetImplementedException &)
{
throw NotYetImplementedException(HERE) << "Event #" << i << " could not be rebuilt as a ThresholdEvent.";
}
}
return composedEvent;
}

/* Method save() stores the object through the StorageManager */
Expand All @@ -206,9 +203,4 @@ void IntersectionEvent::load(Advocate & adv)
setEventCollection(eventCollection);
}

RandomVector IntersectionEvent::getComposedEvent() const
{
return composedEvent_;
}

END_NAMESPACE_OPENTURNS
Loading

0 comments on commit 614b5d3

Please sign in to comment.