Skip to content

Commit b1f0fd2

Browse files
committed
fix input output tests
1 parent f43fa1c commit b1f0fd2

File tree

7 files changed

+36
-24
lines changed

7 files changed

+36
-24
lines changed

src/inputoutput.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,10 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
197197
simplify = false,
198198
eval_expression = false,
199199
eval_module = @__MODULE__,
200+
check_simplified = true,
200201
kwargs...)
201-
202202
# Remove this when the ControlFunction gets merged.
203-
if !iscomplete(sys)
203+
if check_simplified && !iscomplete(sys)
204204
error("A completed `ODESystem` is required. Call `complete` or `structural_simplify` on the system before creating the control function.")
205205
end
206206
isempty(inputs) && @warn("No unbound inputs were found in system.")
@@ -259,7 +259,7 @@ end
259259
"""
260260
Turn input variables into parameters of the system.
261261
"""
262-
function inputs_to_parameters!(state::TransformationState, inputsyms)
262+
function inputs_to_parameters!(state::TransformationState, inputsyms; is_disturbance = false)
263263
check_bound = inputsyms === nothing
264264
@unpack structure, fullvars, sys = state
265265
@unpack var_to_diff, graph, solvable_graph = structure
@@ -414,7 +414,7 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwar
414414
@variables u(t)=0 [input = true] # New system input
415415
dsys = get_disturbance_system(dist)
416416

417-
if inputs === nothing
417+
if isempty(inputs)
418418
all_inputs = [u]
419419
else
420420
i = findfirst(isequal(dist.input), inputs)
@@ -429,8 +429,9 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwar
429429
dist.input ~ u + dsys.output.u[1]]
430430
augmented_sys = ODESystem(eqs, t, systems = [dsys], name = gensym(:outer))
431431
augmented_sys = extend(augmented_sys, sys)
432+
ssys = structural_simplify(augmented_sys, inputs = all_inputs, disturbance_inputs = [d])
432433

433-
(f_oop, f_ip), dvs, p, io_sys = generate_control_function(augmented_sys, all_inputs,
434-
[d]; kwargs...)
434+
(f_oop, f_ip), dvs, p, io_sys = generate_control_function(ssys, all_inputs,
435+
[d]; check_simplified = false, kwargs...)
435436
(f_oop, f_ip), augmented_sys, dvs, p, io_sys
436437
end

src/linearization.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -556,10 +556,11 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
556556
(; A, B, C, D, f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u), sys
557557
end
558558

559-
function markio!(state, orig_inputs, inputs, outputs; check = true)
559+
function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true)
560560
fullvars = get_fullvars(state)
561561
inputset = Dict{Any, Bool}(i => false for i in inputs)
562562
outputset = Dict{Any, Bool}(o => false for o in outputs)
563+
disturbanceset = Dict{Any, Bool}(d => false for d in disturbances)
563564
for (i, v) in enumerate(fullvars)
564565
if v in keys(inputset)
565566
if v in keys(outputset)
@@ -581,6 +582,12 @@ function markio!(state, orig_inputs, inputs, outputs; check = true)
581582
v = setio(v, false, false)
582583
fullvars[i] = v
583584
end
585+
586+
if v in keys(disturbanceset)
587+
v = setio(v, true, false)
588+
v = setdisturbance(v, true)
589+
fullvars[i] = v
590+
end
584591
end
585592
if check
586593
ikeys = keys(filter(!last, inputset))

src/systems/systems.jl

-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ function __structural_simplify(sys::AbstractSystem; simplify = false,
123123
for (i, v) in enumerate(fullvars)
124124
if !iszero(new_idxs[i]) &&
125125
invview(var_to_diff)[i] === nothing]
126-
# TODO: IO is not handled.
127126
ode_sys = structural_simplify(sys; simplify, inputs, outputs, disturbance_inputs, kwargs...)
128127
eqs = equations(ode_sys)
129128
sorted_g_rows = zeros(Num, length(eqs), size(g, 2))

src/systems/systemstructure.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,8 @@ function _structural_simplify!(state::TearingState; simplify = false,
691691
has_io = inputs !== nothing || outputs !== nothing
692692
orig_inputs = Set()
693693
if has_io
694-
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs)
695-
state = ModelingToolkit.inputs_to_parameters!(state, inputs)
694+
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
695+
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
696696
end
697697
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
698698
if check_consistency

src/variables.jl

+2
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ function isdisturbance(x)
349349
Symbolics.getmetadata(x, VariableDisturbance, false)
350350
end
351351

352+
setdisturbance(x, v) = setmetadata(x, VariableDisturbance, v)
353+
352354
function disturbances(sys)
353355
[filter(isdisturbance, unknowns(sys)); filter(isdisturbance, parameters(sys))]
354356
end

test/input_output_handling.jl

+16-13
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
77
@variables xx(t) some_input(t) [input = true]
88
eqs = [D(xx) ~ some_input]
99
@named model = ODESystem(eqs, t)
10-
@test_throws ExtraVariablesSystemException structural_simplify(model, ((), ()))
10+
@test_throws ExtraVariablesSystemException structural_simplify(model)
1111
if VERSION >= v"1.8"
1212
err = "In particular, the unset input(s) are:\n some_input(t)"
13-
@test_throws err structural_simplify(model, ((), ()))
13+
@test_throws err structural_simplify(model)
1414
end
1515

1616
# Test input handling
@@ -88,7 +88,7 @@ fsys4 = flatten(sys4)
8888
@variables x(t) y(t) [output = true]
8989
@test isoutput(y)
9090
@named sys = ODESystem([D(x) ~ -x, y ~ x], t) # both y and x are unbound
91-
syss = structural_simplify(sys) # This makes y an observed variable
91+
syss = structural_simplify(sys, outputs = [y]) # This makes y an observed variable
9292

9393
@named sys2 = ODESystem([D(x) ~ -sys.x, y ~ sys.y], t, systems = [sys])
9494

@@ -106,7 +106,7 @@ syss = structural_simplify(sys) # This makes y an observed variable
106106
@test isequal(unbound_outputs(sys2), [y])
107107
@test isequal(bound_outputs(sys2), [sys.y])
108108

109-
syss = structural_simplify(sys2)
109+
syss = structural_simplify(sys2, outputs = [sys.y])
110110

111111
@test !is_bound(syss, y)
112112
@test !is_bound(syss, x)
@@ -165,6 +165,7 @@ end
165165
]
166166

167167
@named sys = ODESystem(eqs, t)
168+
sys = structural_simplify(sys, inputs = [u])
168169
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split)
169170

170171
@test isequal(dvs[], x)
@@ -182,8 +183,8 @@ end
182183
]
183184

184185
@named sys = ODESystem(eqs, t)
185-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
186-
sys, [u], [d]; simplify, split)
186+
sys = structural_simplify(sys, inputs = [u], disturbance_inputs = [d])
187+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split)
187188

188189
@test isequal(dvs[], x)
189190
@test isempty(ps)
@@ -200,8 +201,8 @@ end
200201
]
201202

202203
@named sys = ODESystem(eqs, t)
203-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
204-
sys, [u], [d]; simplify, split, disturbance_argument = true)
204+
sys = structural_simplify(sys, inputs = [u], disturbance_inputs = [d])
205+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split, disturbance_argument = true)
205206

206207
@test isequal(dvs[], x)
207208
@test isempty(ps)
@@ -265,9 +266,9 @@ eqs = [connect_sd(sd, mass1, mass2)
265266
@named _model = ODESystem(eqs, t)
266267
@named model = compose(_model, mass1, mass2, sd);
267268

269+
model = structural_simplify(model, inputs = [u])
268270
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(model, simplify = true)
269271
@test length(dvs) == 4
270-
@test length(ps) == length(parameters(model))
271272
p = MTKParameters(io_sys, [io_sys.u => NaN])
272273
x = ModelingToolkit.varmap_to_vars(
273274
merge(ModelingToolkit.defaults(model),
@@ -389,7 +390,7 @@ sys = structural_simplify(model)
389390

390391
## Disturbance models when plant has multiple inputs
391392
using ModelingToolkit, LinearAlgebra
392-
using ModelingToolkit: DisturbanceModel, io_preprocessing, get_iv, get_disturbance_system
393+
using ModelingToolkit: DisturbanceModel, get_iv, get_disturbance_system
393394
using ModelingToolkitStandardLibrary.Blocks
394395
A, C = [randn(2, 2) for i in 1:2]
395396
B = [1.0 0; 0 1.0]
@@ -431,6 +432,7 @@ matrices, ssys = linearize(augmented_sys,
431432
]
432433

433434
@named sys = ODESystem(eqs, t)
435+
sys = structural_simplify(sys, inputs = [u])
434436
(; io_sys,) = ModelingToolkit.generate_control_function(sys, simplify = true)
435437
obsfn = ModelingToolkit.build_explicit_observed_function(
436438
io_sys, [x + u * t]; inputs = [u])
@@ -442,9 +444,9 @@ end
442444
@constants c = 2.0
443445
@variables x(t)
444446
eqs = [D(x) ~ c * x]
445-
@named sys = ODESystem(eqs, t, [x], [])
447+
@mtkbuild sys = ODESystem(eqs, t, [x], [])
446448

447-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
449+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys)
448450
@test f[1]([0.5], nothing, MTKParameters(io_sys, []), 0.0) [1.0]
449451
end
450452

@@ -453,7 +455,8 @@ end
453455
@parameters p(::Real) = (x -> 2x)
454456
eqs = [D(x) ~ -x + p(u)]
455457
@named sys = ODESystem(eqs, t)
456-
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
458+
sys = structural_simplify(sys, inputs = [u])
459+
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys)
457460
p = MTKParameters(io_sys, [])
458461
u = [1.0]
459462
x = [1.0]

test/reduction.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ eqs = [D(x) ~ σ * (y - x)
233233
u ~ z + a]
234234

235235
lorenz1 = ODESystem(eqs, t, name = :lorenz1)
236-
lorenz1_reduced, _ = structural_simplify(lorenz1, inputs = [z], outputs = [])
236+
lorenz1_reduced = structural_simplify(lorenz1, inputs = [z], outputs = [])
237237
@test z in Set(parameters(lorenz1_reduced))
238238

239239
# #2064

0 commit comments

Comments
 (0)