fb-hydra: How to implement 2 nested structured configs?

2.1k views Asked by At

I have two sub-configs and one top-level ("master"?) config that contains those sub-configs. I designed the configs like this:

from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig

from typing import Any, List

@dataclass
class DBConfig:
    """Base settings shared by every database connection config.

    ``driver`` and ``port`` default to OmegaConf's ``MISSING`` marker (shown
    as ``???`` when printed); the database-specific subclasses override them.
    """

    host: str = field(default="localhost")
    driver: str = field(default=MISSING)
    port: int = field(default=MISSING)


@dataclass
class MySQLConfig(DBConfig):
    """DBConfig preset using the ``mysql`` driver on port 3306."""

    driver: str = field(default="mysql")
    port: int = field(default=3306)


@dataclass
class PostGreSQLConfig(DBConfig):
    """DBConfig preset using the ``postgresql`` driver on port 5432.

    Unlike the MySQL preset, it also carries a ``timeout`` field (default 10).
    """

    driver: str = field(default="postgresql")
    port: int = field(default=5432)
    timeout: int = field(default=10)


@dataclass
class ConnectionConfig:
    """Connection node: a target path plus a DBConfig-typed ``params`` node.

    NOTE: here ``defaults`` is just an ordinary field. Hydra 1.0 only acts on
    the defaults list of the primary config, so this entry does not select a
    ``params`` option; ``params`` stays MISSING and the list is printed as-is.
    """

    target: str = field(default="app.my_class.MyClass")
    params: DBConfig = field(default=MISSING)
    # Intended to pick the "mysql" option of the "params" group.
    defaults: List[Any] = field(default_factory=lambda: [{"params": "mysql"}])



@dataclass
class AConfig:
    """Base for the ``some_other`` params variants; holds only ``name``."""

    name: str = field(default="foo")


@dataclass
class BConfig(AConfig):
    """AConfig variant that adds an ``age`` field defaulting to 10."""

    age: int = field(default=10)


@dataclass
class CConfig(AConfig):
    """AConfig variant that adds an ``age`` field defaulting to 20."""

    age: int = field(default=20)


@dataclass
class SomeOtherConfig:
    """Second nested node: a target path plus an AConfig-typed ``params`` node.

    NOTE: Hydra 1.0 ignores a defaults list outside the primary config, so
    this ``defaults`` field does not select a ``params`` option; ``params``
    stays MISSING and the list is emitted verbatim.
    """

    target: str = field(default="app.my_class.MyClass2")
    params: AConfig = field(default=MISSING)
    # Intended to pick the "bconfig" option of the "params" group.
    defaults: List[Any] = field(default_factory=lambda: [{"params": "bconfig"}])



@dataclass
class Config:
    """Top-level config, registered in the ConfigStore under the name ``config``.

    The nested configs are created with ``default_factory`` rather than given a
    shared instance as the default. On Python 3.11+ ``dataclasses`` rejects
    unhashable defaults (such as non-frozen dataclass instances) with
    ``ValueError``. On older versions, a shared instance would be reused by
    every ``Config``.
    """

    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the configuration composed by Hydra as YAML.

    Args:
        cfg: The composed config for the ``config`` node registered below.
    """
    # DictConfig.pretty() is deprecated since OmegaConf 2.0 and removed in later
    # releases; OmegaConf.to_yaml() is the supported way to render a config.
    from omegaconf import OmegaConf

    print(OmegaConf.to_yaml(cfg))
    # connection = hydra.utils.instantiate(cfg)
    # print(connection)


if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(
        name="config",
        node=Config,
    )
    # Both the DB presets and the AConfig variants go into the same "params" group.
    cs.store(group="params", name="mysql", node=MySQLConfig)
    cs.store(group="params", name="postgresql", node=PostGreSQLConfig)

    cs.store(group="params", name="bconfig", node=BConfig)
    cs.store(group="params", name="cconfig", node=CConfig)

    my_app()

What I expected when I run the program without any options:

db_connection:
    target: app.my_class.MyClass
    params:   
        host: localhost
        driver: mysql
        port: 3306   

some_other:
    target: app.my_class.MyClass2
    params:
        name: "foo"
        age: 10

But the actual result was:

db_connection:
    target: app.my_class.MyClass
    params: ???
    defaults:
    - params: mysql
some_other:
    target: app.my_class.MyClass2
    params: ???
    defaults:
    - params: bconfig
1

There is 1 answer

5
Omry Yadan — best answer

First of all, as of Hydra 1.0, the defaults list is supported ONLY in the primary config. Below are two versions: the first changes your example as little as possible, and the second cleans things up a bit.

Example 1:

from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig

from typing import Any, List


@dataclass
class DBConfig:
    """Base settings shared by every database connection config.

    ``driver`` and ``port`` default to OmegaConf's ``MISSING`` marker; the
    database-specific subclasses override them.
    """

    host: str = field(default="localhost")
    driver: str = field(default=MISSING)
    port: int = field(default=MISSING)


@dataclass
class MySQLConfig(DBConfig):
    """DBConfig preset using the ``mysql`` driver on port 3306."""

    driver: str = field(default="mysql")
    port: int = field(default=3306)


@dataclass
class PostGreSQLConfig(DBConfig):
    """DBConfig preset using the ``postgresql`` driver on port 5432.

    Also carries a ``timeout`` field (default 10).
    """

    driver: str = field(default="postgresql")
    port: int = field(default=5432)
    timeout: int = field(default=10)


@dataclass
class ConnectionConfig:
    """Connection node: a target path plus a DBConfig-typed ``params`` node.

    ``params`` starts as MISSING; the primary config's defaults list fills it
    with an option from the ``db_connection/params`` group.
    """

    target: str = field(default="app.my_class.MyClass")
    params: DBConfig = field(default=MISSING)


@dataclass
class AConfig:
    """Base for the ``some_other`` params variants; holds only ``name``."""

    name: str = field(default="foo")


@dataclass
class BConfig(AConfig):
    """AConfig variant that adds an ``age`` field defaulting to 10."""

    age: int = field(default=10)


@dataclass
class CConfig(AConfig):
    """AConfig variant that adds an ``age`` field defaulting to 20."""

    age: int = field(default=20)


@dataclass
class SomeOtherConfig:
    """Second nested node: a target path plus an AConfig-typed ``params`` node.

    ``params`` starts as MISSING; the primary config's defaults list fills it
    with an option from the ``some_other/params`` group.
    """

    target: str = field(default="app.my_class.MyClass2")
    params: AConfig = field(default=MISSING)


@dataclass
class Config:
    """Primary config; in Hydra 1.0 the only config whose defaults list is used.

    The defaults list selects ``mysql`` from the ``db_connection/params``
    group and ``bconfig`` from the ``some_other/params`` group.

    The nested configs use ``default_factory`` instead of a shared instance as
    the default. On Python 3.11+ ``dataclasses`` rejects unhashable defaults
    (such as non-frozen dataclass instances) with ``ValueError``. On older
    versions, every ``Config`` would share one nested object.
    """

    db_connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    some_other: SomeOtherConfig = field(default_factory=SomeOtherConfig)
    defaults: List[Any] = field(
        default_factory=lambda: [
            {"db_connection/params": "mysql"},
            {"some_other/params": "bconfig"},
        ]
    )


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the configuration composed by Hydra as YAML.

    Args:
        cfg: The composed ``config`` node, with the group options chosen by
            its defaults list.
    """
    # DictConfig.pretty() is deprecated since OmegaConf 2.0 and removed in later
    # releases; OmegaConf.to_yaml() is the supported way to render a config.
    from omegaconf import OmegaConf

    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)
    # Group names match the nested field paths (db_connection.params and
    # some_other.params) that the primary config's defaults list refers to.
    cs.store(group="db_connection/params", name="mysql", node=MySQLConfig)
    cs.store(group="db_connection/params", name="postgresql", node=PostGreSQLConfig)

    cs.store(group="some_other/params", name="bconfig", node=BConfig)
    cs.store(group="some_other/params", name="cconfig", node=CConfig)

    my_app()

Example 2:

from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, DictConfig
from hydra.types import ObjectConf
from typing import Any, List


@dataclass
class DBConfig:
    """Base settings shared by every database params config.

    ``driver`` and ``port`` default to OmegaConf's ``MISSING`` marker; the
    database-specific subclasses override them.
    """

    host: str = field(default="localhost")
    driver: str = field(default=MISSING)
    port: int = field(default=MISSING)


@dataclass
class MySQLConfig(DBConfig):
    """DBConfig preset using the ``mysql`` driver on port 3306."""

    driver: str = field(default="mysql")
    port: int = field(default=3306)


@dataclass
class PostGreSQLConfig(DBConfig):
    """DBConfig preset using the ``postgresql`` driver on port 5432.

    Also carries a ``timeout`` field (default 10).
    """

    driver: str = field(default="postgresql")
    port: int = field(default=5432)
    timeout: int = field(default=10)


@dataclass
class AConfig:
    """Base params for the ``some_other`` options; holds only ``name``."""

    name: str = field(default="foo")


@dataclass
class BConfig(AConfig):
    """AConfig variant that adds an ``age`` field defaulting to 10."""

    age: int = field(default=10)


@dataclass
class CConfig(AConfig):
    """AConfig variant that adds an ``age`` field defaulting to 20."""

    age: int = field(default=20)


# Defaults list for the primary Config: one {group: option} entry per group.
defaults = [
    dict(db_connection="mysql"),  # "mysql" option of the db_connection group
    dict(some_other="bconfig"),  # "bconfig" option of the some_other group
]


@dataclass
class Config:
    """Primary config; in Hydra 1.0 the only config whose defaults list is used.

    Both group fields start as MISSING. Hydra fills them from the module-level
    ``defaults`` list (``db_connection=mysql``, ``some_other=bconfig``).
    """

    db_connection: ObjectConf = MISSING
    some_other: ObjectConf = MISSING
    # Build a fresh copy per instance. Returning the module-level list itself
    # would make every Config share (and be able to mutate) the same list and
    # dict objects, which defeats the purpose of default_factory.
    defaults: List[Any] = field(
        default_factory=lambda: [dict(entry) for entry in defaults]
    )


cs = ConfigStore.instance()
# Primary config: the only place Hydra 1.0 reads a defaults list from.
cs.store(name="config", node=Config)
# Options for the db_connection group; each wraps its params in an ObjectConf.
cs.store(
    group="db_connection",
    name="mysql",
    node=ObjectConf(target="MySQL", params=MySQLConfig),
)
cs.store(
    group="db_connection",
    name="postgresql",
    node=ObjectConf(target="PostgreSQL", params=PostGreSQLConfig),
)
# Options for the some_other group.
cs.store(
    group="some_other",
    name="bconfig",
    node=ObjectConf(target="ClassB", params=BConfig()),
)
cs.store(
    group="some_other",
    name="cconfig",
    # Use CConfig (age=20), not the AConfig base, which has no ``age`` field.
    node=ObjectConf(target="ClassC", params=CConfig()),
)


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
    """Print the configuration composed by Hydra as YAML.

    Args:
        cfg: The composed ``config`` node, with ``db_connection`` and
            ``some_other`` filled from the selected group options.
    """
    # DictConfig.pretty() is deprecated since OmegaConf 2.0 and removed in later
    # releases; OmegaConf.to_yaml() is the supported way to render a config.
    from omegaconf import OmegaConf

    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    # The ConfigStore entries are registered at module import time above, so
    # the script only needs to launch the Hydra app.
    my_app()