codegen.py
author Simon MacMullen <simon@rabbitmq.com>
Fri Feb 03 15:59:12 2012 +0000 (3 months ago)
changeset 8922 4f87837a40be
parent 8430 86988974d7fa
permissions -rw-r--r--
Merge bug24702
     1 ##  The contents of this file are subject to the Mozilla Public License
     2 ##  Version 1.1 (the "License"); you may not use this file except in
     3 ##  compliance with the License. You may obtain a copy of the License
     4 ##  at http://www.mozilla.org/MPL/
     5 ##
     6 ##  Software distributed under the License is distributed on an "AS IS"
     7 ##  basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
     8 ##  the License for the specific language governing rights and
     9 ##  limitations under the License.
    10 ##
    11 ##  The Original Code is RabbitMQ.
    12 ##
    13 ##  The Initial Developer of the Original Code is VMware, Inc.
    14 ##  Copyright (c) 2007-2012 VMware, Inc.  All rights reserved.
    15 ##
    16 
    17 from __future__ import nested_scopes
    18 
    19 import sys
    20 sys.path.append("../rabbitmq-codegen")  # in case we're next to an experimental revision
    21 sys.path.append("codegen")              # in case we're building from a distribution package
    22 
    23 from amqp_codegen import *
    24 import string
    25 import re
    26 
    27 erlangTypeMap = {
    28     'octet': 'octet',
    29     'shortstr': 'shortstr',
    30     'longstr': 'longstr',
    31     'short': 'shortint',
    32     'long': 'longint',
    33     'longlong': 'longlongint',
    34     'bit': 'bit',
    35     'table': 'table',
    36     'timestamp': 'timestamp',
    37 }
    38 
    39 # Coming up with a proper encoding of AMQP tables in JSON is too much
    40 # hassle at this stage. Given that the only default value we are
    41 # interested in is for the empty table, we only support that.
    42 def convertTable(d):
    43     if len(d) == 0:
    44         return "[]"
    45     else:
    46         raise Exception('Non-empty table defaults not supported ' + d)
    47 
    48 erlangDefaultValueTypeConvMap = {
    49     bool : lambda x: str(x).lower(),
    50     str : lambda x: "<<\"" + x + "\">>",
    51     int : lambda x: str(x),
    52     float : lambda x: str(x),
    53     dict: convertTable,
    54     unicode: lambda x: "<<\"" + x.encode("utf-8") + "\">>"
    55 }
    56 
    57 def erlangize(s):
    58     s = s.replace('-', '_')
    59     s = s.replace(' ', '_')
    60     return s
    61 
    62 AmqpMethod.erlangName = lambda m: "'" + erlangize(m.klass.name) + '.' + erlangize(m.name) + "'"
    63 
    64 AmqpClass.erlangName = lambda c: "'" + erlangize(c.name) + "'"
    65 
    66 def erlangConstantName(s):
    67     return '_'.join(re.split('[- ]', s.upper()))
    68 
    69 class PackedMethodBitField:
    70     def __init__(self, index):
    71         self.index = index
    72         self.domain = 'bit'
    73         self.contents = []
    74 
    75     def extend(self, f):
    76         self.contents.append(f)
    77 
    78     def count(self):
    79         return len(self.contents)
    80 
    81     def full(self):
    82         return self.count() == 8
    83 
    84 def multiLineFormat(things, prologue, separator, lineSeparator, epilogue, thingsPerLine = 4):
    85     r = [prologue]
    86     i = 0
    87     for t in things:
    88         if i != 0:
    89             if i % thingsPerLine == 0:
    90                 r += [lineSeparator]
    91             else:
    92                 r += [separator]
    93         r += [t]
    94         i += 1
    95     r += [epilogue]
    96     return "".join(r)
    97 
    98 def prettyType(typeName, subTypes, typesPerLine = 4):
    99     """Pretty print a type signature made up of many alternative subtypes"""
   100     sTs = multiLineFormat(subTypes,
   101                           "( ", " | ", "\n       | ", " )",
   102                           thingsPerLine = typesPerLine)
   103     return "-type(%s ::\n       %s)." % (typeName, sTs)
   104 
   105 def printFileHeader():
   106     print """%%   Autogenerated code. Do not edit.
   107 %%
   108 %%  The contents of this file are subject to the Mozilla Public License
   109 %%  Version 1.1 (the "License"); you may not use this file except in
   110 %%  compliance with the License. You may obtain a copy of the License
   111 %%  at http://www.mozilla.org/MPL/
   112 %%
   113 %%  Software distributed under the License is distributed on an "AS IS"
   114 %%  basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
   115 %%  the License for the specific language governing rights and
   116 %%  limitations under the License.
   117 %%
   118 %%  The Original Code is RabbitMQ.
   119 %%
   120 %%  The Initial Developer of the Original Code is VMware, Inc.
   121 %%  Copyright (c) 2007-2012 VMware, Inc.  All rights reserved.
   122 %%"""
   123 
   124 def genErl(spec):
   125     def erlType(domain):
   126         return erlangTypeMap[spec.resolveDomain(domain)]
   127 
   128     def fieldTypeList(fields):
   129         return '[' + ', '.join([erlType(f.domain) for f in fields]) + ']'
   130 
   131     def fieldNameList(fields):
   132         return '[' + ', '.join([erlangize(f.name) for f in fields]) + ']'
   133 
   134     def fieldTempList(fields):
   135         return '[' + ', '.join(['F' + str(f.index) for f in fields]) + ']'
   136 
   137     def fieldMapList(fields):
   138         return ', '.join([erlangize(f.name) + " = F" + str(f.index) for f in fields])
   139 
   140     def genLookupMethodName(m):
   141         print "lookup_method_name({%d, %d}) -> %s;" % (m.klass.index, m.index, m.erlangName())
   142 
   143     def genLookupClassName(c):
   144         print "lookup_class_name(%d) -> %s;" % (c.index, c.erlangName())
   145 
   146     def genMethodId(m):
   147         print "method_id(%s) -> {%d, %d};" % (m.erlangName(), m.klass.index, m.index)
   148 
   149     def genMethodHasContent(m):
   150         print "method_has_content(%s) -> %s;" % (m.erlangName(), str(m.hasContent).lower())
   151 
   152     def genMethodIsSynchronous(m):
   153         hasNoWait = "nowait" in fieldNameList(m.arguments)
   154         if m.isSynchronous and hasNoWait:
   155           print "is_method_synchronous(#%s{nowait = NoWait}) -> not(NoWait);" % (m.erlangName())
   156         else:
   157           print "is_method_synchronous(#%s{}) -> %s;" % (m.erlangName(), str(m.isSynchronous).lower())
   158 
   159     def genMethodFieldTypes(m):
   160         """Not currently used - may be useful in future?"""
   161         print "method_fieldtypes(%s) -> %s;" % (m.erlangName(), fieldTypeList(m.arguments))
   162 
   163     def genMethodFieldNames(m):
   164         print "method_fieldnames(%s) -> %s;" % (m.erlangName(), fieldNameList(m.arguments))
   165 
   166     def packMethodFields(fields):
   167         packed = []
   168         bitfield = None
   169         for f in fields:
   170             if erlType(f.domain) == 'bit':
   171                 if not(bitfield) or bitfield.full():
   172                     bitfield = PackedMethodBitField(f.index)
   173                     packed.append(bitfield)
   174                 bitfield.extend(f)
   175             else:
   176                 bitfield = None
   177                 packed.append(f)
   178         return packed
   179 
   180     def methodFieldFragment(f):
   181         type = erlType(f.domain)
   182         p = 'F' + str(f.index)
   183         if type == 'shortstr':
   184             return p+'Len:8/unsigned, '+p+':'+p+'Len/binary'
   185         elif type == 'longstr':
   186             return p+'Len:32/unsigned, '+p+':'+p+'Len/binary'
   187         elif type == 'octet':
   188             return p+':8/unsigned'
   189         elif type == 'shortint':
   190             return p+':16/unsigned'
   191         elif type == 'longint':
   192             return p+':32/unsigned'
   193         elif type == 'longlongint':
   194             return p+':64/unsigned'
   195         elif type == 'timestamp':
   196             return p+':64/unsigned'
   197         elif type == 'bit':
   198             return p+'Bits:8'
   199         elif type == 'table':
   200             return p+'Len:32/unsigned, '+p+'Tab:'+p+'Len/binary'
   201 
   202     def genFieldPostprocessing(packed):
   203         for f in packed:
   204             type = erlType(f.domain)
   205             if type == 'bit':
   206                 for index in range(f.count()):
   207                     print "  F%d = ((F%dBits band %d) /= 0)," % \
   208                           (f.index + index,
   209                            f.index,
   210                            1 << index)
   211             elif type == 'table':
   212                 print "  F%d = rabbit_binary_parser:parse_table(F%dTab)," % \
   213                       (f.index, f.index)
   214             else:
   215                 pass
   216 
   217     def genMethodRecord(m):
   218         print "method_record(%s) -> #%s{};" % (m.erlangName(), m.erlangName())
   219 
   220     def genDecodeMethodFields(m):
   221         packedFields = packMethodFields(m.arguments)
   222         binaryPattern = ', '.join([methodFieldFragment(f) for f in packedFields])
   223         if binaryPattern:
   224             restSeparator = ', '
   225         else:
   226             restSeparator = ''
   227         recordConstructorExpr = '#%s{%s}' % (m.erlangName(), fieldMapList(m.arguments))
   228         print "decode_method_fields(%s, <<%s>>) ->" % (m.erlangName(), binaryPattern)
   229         genFieldPostprocessing(packedFields)
   230         print "  %s;" % (recordConstructorExpr,)
   231 
   232     def genDecodeProperties(c):
   233         def presentBin(fields):
   234             ps = ', '.join(['P' + str(f.index) + ':1' for f in fields])
   235             return '<<' + ps + ', _:%d, R0/binary>>' % (16 - len(fields),)
   236         def mkMacroName(field):
   237             return '?' + field.domain.upper() + '_PROP'
   238         def writePropFieldLine(field, bin_next = None):
   239             i = str(field.index)
   240             if not bin_next:
   241                 bin_next = 'R' + str(field.index + 1)
   242             if field.domain in ['octet', 'timestamp']:
   243                 print ("  {%s, %s} = %s(%s, %s, %s, %s)," %
   244                        ('F' + i, bin_next, mkMacroName(field), 'P' + i,
   245                         'R' + i, 'I' + i, 'X' + i))
   246             else:
   247                 print ("  {%s, %s} = %s(%s, %s, %s, %s, %s)," %
   248                        ('F' + i, bin_next, mkMacroName(field), 'P' + i,
   249                         'R' + i, 'L' + i, 'S' + i, 'X' + i))
   250 
   251         if len(c.fields) == 0:
   252             print "decode_properties(%d, _) ->" % (c.index,)
   253         else:
   254             print ("decode_properties(%d, %s) ->" %
   255                    (c.index, presentBin(c.fields)))
   256             for field in c.fields[:-1]:
   257                 writePropFieldLine(field)
   258             writePropFieldLine(c.fields[-1], "<<>>")
   259         print "  #'P_%s'{%s};" % (erlangize(c.name), fieldMapList(c.fields))
   260 
   261     def genFieldPreprocessing(packed):
   262         for f in packed:
   263             type = erlType(f.domain)
   264             if type == 'bit':
   265                 print "  F%dBits = (%s)," % \
   266                       (f.index,
   267                        ' bor '.join(['(bitvalue(F%d) bsl %d)' % (x.index, x.index - f.index)
   268                                      for x in f.contents]))
   269             elif type == 'table':
   270                 print "  F%dTab = rabbit_binary_generator:generate_table(F%d)," % (f.index, f.index)
   271                 print "  F%dLen = size(F%dTab)," % (f.index, f.index)
   272             elif type == 'shortstr':
   273                 print "  F%dLen = shortstr_size(F%d)," % (f.index, f.index)
   274             elif type == 'longstr':
   275                 print "  F%dLen = size(F%d)," % (f.index, f.index)
   276             else:
   277                 pass
   278 
   279     def genEncodeMethodFields(m):
   280         packedFields = packMethodFields(m.arguments)
   281         print "encode_method_fields(#%s{%s}) ->" % (m.erlangName(), fieldMapList(m.arguments))
   282         genFieldPreprocessing(packedFields)
   283         print "  <<%s>>;" % (', '.join([methodFieldFragment(f) for f in packedFields]))
   284 
   285     def genEncodeProperties(c):
   286         print "encode_properties(#'P_%s'{%s}) ->" % (erlangize(c.name), fieldMapList(c.fields))
   287         print "  rabbit_binary_generator:encode_properties(%s, %s);" % \
   288               (fieldTypeList(c.fields), fieldTempList(c.fields))
   289 
   290     def messageConstantClass(cls):
   291         # We do this because 0.8 uses "soft error" and 8.1 uses "soft-error".
   292         return erlangConstantName(cls)
   293 
   294     def genLookupException(c,v,cls):
   295         mCls = messageConstantClass(cls)
   296         if mCls == 'SOFT_ERROR': genLookupException1(c,'false')
   297         elif mCls == 'HARD_ERROR': genLookupException1(c, 'true')
   298         elif mCls == '': pass
   299         else: raise Exception('Unknown constant class' + cls)
   300 
   301     def genLookupException1(c,hardErrorBoolStr):
   302         n = erlangConstantName(c)
   303         print 'lookup_amqp_exception(%s) -> {%s, ?%s, <<"%s">>};' % \
   304               (n.lower(), hardErrorBoolStr, n, n)
   305 
   306     def genAmqpException(c,v,cls):
   307         n = erlangConstantName(c)
   308         print 'amqp_exception(?%s) -> %s;' % \
   309             (n, n.lower())
   310 
   311     methods = spec.allMethods()
   312 
   313     printFileHeader()
   314     module = "rabbit_framing_amqp_%d_%d" % (spec.major, spec.minor)
   315     if spec.revision != 0:
   316         module = "%s_%d" % (module, spec.revision)
   317     if module == "rabbit_framing_amqp_8_0":
   318         module = "rabbit_framing_amqp_0_8"
   319     print "-module(%s)." % module
   320     print """-include("rabbit_framing.hrl").
   321 
   322 -export([version/0]).
   323 -export([lookup_method_name/1]).
   324 -export([lookup_class_name/1]).
   325 
   326 -export([method_id/1]).
   327 -export([method_has_content/1]).
   328 -export([is_method_synchronous/1]).
   329 -export([method_record/1]).
   330 -export([method_fieldnames/1]).
   331 -export([decode_method_fields/2]).
   332 -export([decode_properties/2]).
   333 -export([encode_method_fields/1]).
   334 -export([encode_properties/1]).
   335 -export([lookup_amqp_exception/1]).
   336 -export([amqp_exception/1]).
   337 
   338 """
   339     print "%% Various types"
   340     print "-ifdef(use_specs)."
   341 
   342     print """-export_type([amqp_field_type/0, amqp_property_type/0,
   343               amqp_table/0, amqp_array/0, amqp_value/0,
   344               amqp_method_name/0, amqp_method/0, amqp_method_record/0,
   345               amqp_method_field_name/0, amqp_property_record/0,
   346               amqp_exception/0, amqp_exception_code/0, amqp_class_id/0]).
   347 
   348 -type(amqp_field_type() ::
   349       'longstr' | 'signedint' | 'decimal' | 'timestamp' |
   350       'table' | 'byte' | 'double' | 'float' | 'long' |
   351       'short' | 'bool' | 'binary' | 'void' | 'array').
   352 -type(amqp_property_type() ::
   353       'shortstr' | 'longstr' | 'octet' | 'shortint' | 'longint' |
   354       'longlongint' | 'timestamp' | 'bit' | 'table').
   355 
   356 -type(amqp_table() :: [{binary(), amqp_field_type(), amqp_value()}]).
   357 -type(amqp_array() :: [{amqp_field_type(), amqp_value()}]).
   358 -type(amqp_value() :: binary() |    % longstr
   359                       integer() |   % signedint
   360                       {non_neg_integer(), non_neg_integer()} | % decimal
   361                       amqp_table() |
   362                       amqp_array() |
   363                       byte() |      % byte
   364                       float() |     % double
   365                       integer() |   % long
   366                       integer() |   % short
   367                       boolean() |   % bool
   368                       binary() |    % binary
   369                       'undefined' | % void
   370                       non_neg_integer() % timestamp
   371      ).
   372 """
   373 
   374     print prettyType("amqp_method_name()",
   375                      [m.erlangName() for m in methods])
   376     print prettyType("amqp_method()",
   377                      ["{%s, %s}" % (m.klass.index, m.index) for m in methods],
   378                      6)
   379     print prettyType("amqp_method_record()",
   380                      ["#%s{}" % (m.erlangName()) for m in methods])
   381     fieldNames = set()
   382     for m in methods:
   383         fieldNames.update(m.arguments)
   384     fieldNames = [erlangize(f.name) for f in fieldNames]
   385     print prettyType("amqp_method_field_name()",
   386                      fieldNames)
   387     print prettyType("amqp_property_record()",
   388                      ["#'P_%s'{}" % erlangize(c.name) for c in spec.allClasses()])
   389     print prettyType("amqp_exception()",
   390                      ["'%s'" % erlangConstantName(c).lower() for (c, v, cls) in spec.constants])
   391     print prettyType("amqp_exception_code()",
   392                      ["%i" % v for (c, v, cls) in spec.constants])
   393     classIds = set()
   394     for m in spec.allMethods():
   395         classIds.add(m.klass.index)
   396     print prettyType("amqp_class_id()",
   397                      ["%i" % ci for ci in classIds])
   398     print prettyType("amqp_class_name()",
   399                      ["%s" % c.erlangName() for c in spec.allClasses()])
   400     print "-endif. % use_specs"
   401 
   402     print """
   403 %% Method signatures
   404 -ifdef(use_specs).
   405 -spec(version/0 :: () -> {non_neg_integer(), non_neg_integer(), non_neg_integer()}).
   406 -spec(lookup_method_name/1 :: (amqp_method()) -> amqp_method_name()).
   407 -spec(lookup_class_name/1 :: (amqp_class_id()) -> amqp_class_name()).
   408 -spec(method_id/1 :: (amqp_method_name()) -> amqp_method()).
   409 -spec(method_has_content/1 :: (amqp_method_name()) -> boolean()).
   410 -spec(is_method_synchronous/1 :: (amqp_method_record()) -> boolean()).
   411 -spec(method_record/1 :: (amqp_method_name()) -> amqp_method_record()).
   412 -spec(method_fieldnames/1 :: (amqp_method_name()) -> [amqp_method_field_name()]).
   413 -spec(decode_method_fields/2 ::
   414         (amqp_method_name(), binary()) -> amqp_method_record() | rabbit_types:connection_exit()).
   415 -spec(decode_properties/2 :: (non_neg_integer(), binary()) -> amqp_property_record()).
   416 -spec(encode_method_fields/1 :: (amqp_method_record()) -> binary()).
   417 -spec(encode_properties/1 :: (amqp_property_record()) -> binary()).
   418 -spec(lookup_amqp_exception/1 :: (amqp_exception()) -> {boolean(), amqp_exception_code(), binary()}).
   419 -spec(amqp_exception/1 :: (amqp_exception_code()) -> amqp_exception()).
   420 -endif. % use_specs
   421 
   422 bitvalue(true) -> 1;
   423 bitvalue(false) -> 0;
   424 bitvalue(undefined) -> 0.
   425 
   426 shortstr_size(S) ->
   427     case size(S) of
   428         Len when Len =< 255 -> Len;
   429         _                   -> exit(method_field_shortstr_overflow)
   430     end.
   431 
   432 -define(SHORTSTR_PROP(P, R, L, S, X),
   433         if P =:= 0 -> {undefined, R};
   434            true    -> <<L:8/unsigned, S:L/binary, X/binary>> = R,
   435                       {S, X}
   436         end).
   437 -define(TABLE_PROP(P, R, L, T, X),
   438         if P =:= 0 -> {undefined, R};
   439            true    -> <<L:32/unsigned, T:L/binary, X/binary>> = R,
   440                       {rabbit_binary_parser:parse_table(T), X}
   441         end).
   442 -define(OCTET_PROP(P, R, I, X),
   443         if P =:= 0 -> {undefined, R};
   444            true    -> <<I:8/unsigned, X/binary>> = R,
   445                       {I, X}
   446         end).
   447 -define(TIMESTAMP_PROP(P, R, I, X),
   448         if P =:= 0 -> {undefined, R};
   449            true    -> <<I:64/unsigned, X/binary>> = R,
   450                       {I, X}
   451         end).
   452 """
   453     version = "{%d, %d, %d}" % (spec.major, spec.minor, spec.revision)
   454     if version == '{8, 0, 0}': version = '{0, 8, 0}'
   455     print "version() -> %s." % (version)
   456 
   457     for m in methods: genLookupMethodName(m)
   458     print "lookup_method_name({_ClassId, _MethodId} = Id) -> exit({unknown_method_id, Id})."
   459 
   460     for c in spec.allClasses(): genLookupClassName(c)
   461     print "lookup_class_name(ClassId) -> exit({unknown_class_id, ClassId})."
   462 
   463     for m in methods: genMethodId(m)
   464     print "method_id(Name) -> exit({unknown_method_name, Name})."
   465 
   466     for m in methods: genMethodHasContent(m)
   467     print "method_has_content(Name) -> exit({unknown_method_name, Name})."
   468 
   469     for m in methods: genMethodIsSynchronous(m)
   470     print "is_method_synchronous(Name) -> exit({unknown_method_name, Name})."
   471 
   472     for m in methods: genMethodRecord(m)
   473     print "method_record(Name) -> exit({unknown_method_name, Name})."
   474 
   475     for m in methods: genMethodFieldNames(m)
   476     print "method_fieldnames(Name) -> exit({unknown_method_name, Name})."
   477 
   478     for m in methods: genDecodeMethodFields(m)
   479     print "decode_method_fields(Name, BinaryFields) ->"
   480     print "  rabbit_misc:frame_error(Name, BinaryFields)."
   481 
   482     for c in spec.allClasses(): genDecodeProperties(c)
   483     print "decode_properties(ClassId, _BinaryFields) -> exit({unknown_class_id, ClassId})."
   484 
   485     for m in methods: genEncodeMethodFields(m)
   486     print "encode_method_fields(Record) -> exit({unknown_method_name, element(1, Record)})."
   487 
   488     for c in spec.allClasses(): genEncodeProperties(c)
   489     print "encode_properties(Record) -> exit({unknown_properties_record, Record})."
   490 
   491     for (c,v,cls) in spec.constants: genLookupException(c,v,cls)
   492     print "lookup_amqp_exception(Code) ->"
   493     print "  rabbit_log:warning(\"Unknown AMQP error code '~p'~n\", [Code]),"
   494     print "  {true, ?INTERNAL_ERROR, <<\"INTERNAL_ERROR\">>}."
   495 
   496     for(c,v,cls) in spec.constants: genAmqpException(c,v,cls)
   497     print "amqp_exception(_Code) -> undefined."
   498 
   499 def genHrl(spec):
   500     def erlType(domain):
   501         return erlangTypeMap[spec.resolveDomain(domain)]
   502 
   503     def fieldNameList(fields):
   504         return ', '.join([erlangize(f.name) for f in fields])
   505 
   506     def fieldNameListDefaults(fields):
   507         def fillField(field):
   508             result = erlangize(f.name)
   509             if field.defaultvalue != None:
   510                 conv_fn = erlangDefaultValueTypeConvMap[type(field.defaultvalue)]
   511                 result += ' = ' + conv_fn(field.defaultvalue)
   512             return result
   513         return ', '.join([fillField(f) for f in fields])
   514 
   515     methods = spec.allMethods()
   516 
   517     printFileHeader()
   518     print "-define(PROTOCOL_PORT, %d)." % (spec.port)
   519 
   520     for (c,v,cls) in spec.constants:
   521         print "-define(%s, %s)." % (erlangConstantName(c), v)
   522 
   523     print "%% Method field records."
   524     for m in methods:
   525         print "-record(%s, {%s})." % (m.erlangName(), fieldNameListDefaults(m.arguments))
   526 
   527     print "%% Class property records."
   528     for c in spec.allClasses():
   529         print "-record('P_%s', {%s})." % (erlangize(c.name), fieldNameList(c.fields))
   530 
   531 
   532 def generateErl(specPath):
   533     genErl(AmqpSpec(specPath))
   534 
   535 def generateHrl(specPath):
   536     genHrl(AmqpSpec(specPath))
   537 
   538 if __name__ == "__main__":
   539     do_main_dict({"header": generateHrl,
   540                   "body": generateErl})
   541