Skip to content

Commit

Permalink
Merge pull request #165 from SAYANTANDE/dev
Browse files Browse the repository at this point in the history
Fit data extract function from log and test cases
  • Loading branch information
akabla authored Jul 23, 2024
2 parents c8839cc + 744e758 commit cc1690e
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 0 deletions.
64 changes: 64 additions & 0 deletions docs/src/examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,67 @@ ax.set_xlabel("Time")
ax.set_ylabel("Stress")
ax.grid("on")
#!nb fig #hide

# ## Example 4

# This example demonstrates generating a timeline and stress data, fitting multiple models to the data,
# calling the `extractfitdata` function, listing the errors, and determining which model fits the best.

using RHEOS

# Generate Timeline
datat = timeline(t_start = 0, t_end = 20.0, step = 0.02) # Create a timeline from 0 to 20 seconds with a step size of 0.02 seconds

# Generate Stress Data (Ramp & hold)
dramp_stress = stressfunction(datat, ramp(offset = 4.0, gradient = 0.8)) # Generate a ramp stress function with offset 4.0 and gradient 0.8
dhold_stress = dramp_stress - stressfunction(datat, ramp(offset = 5.0, gradient = 0.8)) # Generate a hold stress function by subtracting a shifted ramp

# Define the rheological model and predict
model = RheoModel(SLS_Zener, (η = 1, kᵦ = 1, kᵧ = 1))
SLS_predict = modelpredict(dhold_stress, model)
data = SLS_predict

# Fit three models to the data
SLS_Zener_model = modelfit(data, SLS_Zener, strain_imposed)
Maxwell_model = modelfit(data, Maxwell, strain_imposed)
BurgersLiquid_model = modelfit(data, BurgersLiquid, strain_imposed)

# Call the extractfitdata function to extract fitting data
extracted_data = extractfitdata(data)

# Determine which model fits best by comparing errors
best_model = ""
min_error = Inf

for (model_name, entries) in extracted_data
total_error = sum(entry.info.error for entry in entries)

println("Model: $model_name, Total Error: $total_error")

if total_error < min_error
min_error = total_error
best_model = model_name
end
end

println("Best fitting model: $best_model with total error: $min_error")

# Create strain-only data for model predictions
strain_only_data = onlystrain(data)

# Get model predictions for plotting
SLS_Zener_predict = modelpredict(strain_only_data, SLS_Zener_model)
Maxwell_predict = modelpredict(strain_only_data, Maxwell_model)
BurgersLiquid_predict = modelpredict(strain_only_data, BurgersLiquid_model)

# Plot data and fitted models
fig, ax = subplots(1, 1, figsize = (7, 5))
ax.plot(data.t, data.σ, ".", color = "green", label = "Original Data")
ax.plot(SLS_Zener_predict.t, SLS_Zener_predict.σ, "-", color = "red", label = "SLS_Zener Model")
ax.plot(Maxwell_predict.t, Maxwell_predict.σ, "--", color = "blue", label = "Maxwell Model")
ax.plot(BurgersLiquid_predict.t, BurgersLiquid_predict.σ, ":", color = "purple", label = "BurgersLiquid Model")
ax.set_xlabel("Time")
ax.set_ylabel("Stress")
ax.legend()
ax.grid("on")
#!nb fig #hide
1 change: 1 addition & 0 deletions src/RHEOS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const RheovecOrNone = Union{Vector{RheoFloat}, Nothing}
######################################################################
export namedtuple, dict
export RheoFloat
export extractfitdata

# definitions.jl
export RheoLogItem, RheoLog, rheologrun, showlog
Expand Down
36 changes: 36 additions & 0 deletions src/rheodata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,42 @@ end



"""
extractfitdata(data)
Parse a log of actions from a `RheoTimeData` instance and extract model fitting entries into a dictionary.
# Arguments
- `data`: A `RheoTimeData` instance containing the log of actions.
# Returns
A dictionary where keys are model names and values are lists of named tuples,
each containing model parameters, error, info, and index from the log.
# Example
```julia
data.log = [...] # Define your log data
models_data = extractfitdata(data)
println(models_data)
"""
function extractfitdata(data)
models_dict = Dict{String, Vector{NamedTuple{(:params, :info, :index), Tuple{Any, Any, Int}}}}()

for (idx, entry) in enumerate(log)
if entry.action.funct == :modelfit
model_name = entry.info.model_name
model_params = entry.info.model_params
error = entry.info.error
info = entry.info

model_entry = (params = model_params, info = info, index = idx)

if haskey(models_dict, model_name)
push!(models_dict[model_name], model_entry)
else
models_dict[model_name] = [model_entry]
end
end
end

return models_dict
end



Expand Down
67 changes: 67 additions & 0 deletions test/rheodata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,70 @@ function __mapdata!()
end
@test __mapdata!()


function _test_extractfitdata()
# Step 1: Generate Timeline
datat = timeline(t_start = 0, t_end = 20.0, step = 0.02) # Create a timeline from 0 to 20 seconds with a step size of 0.02 seconds

# Step 2: Generate Stress Data (Ramp & hold)
dramp_stress = stressfunction(datat, ramp(offset = 4.0, gradient = 0.8)) # Generate a ramp stress function with offset 4.0 and gradient 0.8
dhold_stress = dramp_stress - stressfunction(datat, ramp(offset = 5.0, gradient = 0.8)) # Generate a hold stress function by subtracting a shifted ramp

# Define the rheological model
model = RheoModel(SLS_Zener, (η = 1, kᵦ = 1, kᵧ = 1))
SLS_predict = modelpredict(dhold_stress, model)
data = SLS_predict

# Fit the model to the data
SLS_Zener_model = modelfit(data, SLS_Zener, strain_imposed)

# Call the extractfitdata function with data
extracted_data = extractfitdata(data)
all_tests_passed = true

# Iterate through data.log to dynamically verify modelfit entries
for (index, log_entry) in enumerate(data.log)
if log_entry.action.funct == :modelfit
model_name = log_entry.info.model_name
expected_params = log_entry.info.model_params
expected_info = log_entry.info

if haskey(extracted_data, model_name)
model_entries = extracted_data[model_name]

# Find the corresponding entry in the extracted data
matching_entries = filter(x -> x.index == index, model_entries)

if length(matching_entries) == 1
extracted_entry = matching_entries[1]

if extracted_entry.params == expected_params && extracted_entry.info == expected_info && extracted_entry.index == index
println("Test passed for entry $index: Extracted data matches expected data.")
else
println("Test failed for entry $index: Extracted data does not match expected data.")
println("Extracted params: $(extracted_entry.params)")
println("Expected params: $expected_params")
println("Extracted info: $(extracted_entry.info)")
println("Expected info: $expected_info")
println("Extracted index: $(extracted_entry.index)")
println("Expected index: $index")
all_tests_passed = false
end
else
println("Test failed for entry $index: Number of matching entries does not match expected.")
println("Matching entries: $matching_entries")
all_tests_passed = false
end
else
println("Test failed for entry $index: $model_name model not found in extracted data.")
println("Extracted data: $extracted_data")
all_tests_passed = false
end
end
end

return all_tests_passed
end

# Run the test function
_test_extractfitdata()

0 comments on commit cc1690e

Please sign in to comment.